Safer usage requests sessions and account for the verify_ssl requirements of each source. (#5728)

This commit is contained in:
Matt Davis
2023-06-11 02:58:39 -04:00
committed by GitHub
parent 5ebad396fd
commit 2d52034cf2
2 changed files with 25 additions and 15 deletions
+9 -10
View File
@@ -1,21 +1,18 @@
import re
from urllib.parse import urlparse
from pipenv.patched.pip._vendor import requests
from pipenv.patched.pip._vendor.requests.adapters import HTTPAdapter
from pipenv.patched.pip._vendor.urllib3 import util as urllib3_util
requests_session = None # type: ignore
def _get_requests_session(max_retries=1):
def _get_requests_session(max_retries=1, verify_ssl=True):
"""Load requests lazily."""
global requests_session
if requests_session is not None:
return requests_session
from pipenv.patched.pip._vendor import requests
requests_session = requests.Session()
adapter = requests.adapters.HTTPAdapter(max_retries=max_retries)
requests_session.mount("https://pypi.org/pypi", adapter)
adapter = HTTPAdapter(max_retries=max_retries)
requests_session.mount("https://", adapter)
if verify_ssl is False:
requests_session.verify = False
return requests_session
@@ -46,6 +43,7 @@ def create_mirror_source(url, name):
def download_file(url, filename, max_retries=1):
"""Downloads file from url to a path with filename"""
r = _get_requests_session(max_retries).get(url, stream=True)
r.close()
if not r.ok:
raise OSError("Unable to download file")
@@ -117,6 +115,7 @@ def proper_case(package_name):
r = _get_requests_session().get(
f"https://pypi.org/pypi/{package_name}/json", timeout=0.3, stream=True
)
r.close()
if not r.ok:
raise OSError(f"Unable to find package {package_name} in PyPI repository.")
+16 -5
View File
@@ -180,6 +180,7 @@ class Resolver:
self._pip_command = None
self._retry_attempts = 0
self._hash_cache = None
self._sessions = {}
def __repr__(self):
return (
@@ -594,7 +595,7 @@ class Resolver:
session=self.session,
)
# It would be nice if `shims.get_package_finder` took an
# `ignore_compatibility` parameter, but that's some vendorered code
# `ignore_compatibility` parameter, but that's some vendored code
# we'd rather avoid touching.
index_lookup = self.prepare_index_lookup()
ignore_compatibility_finder._ignore_compatibility = True
@@ -744,9 +745,19 @@ class Resolver:
cleaned_checksums.add(checksum)
return cleaned_checksums
def _get_hashes_from_pypi(self, ireq):
def _get_requests_session_for_source(self, source):
if self._sessions.get(source["name"]):
session = self._sessions[source["name"]]
else:
session = _get_requests_session(
self.project.s.PIPENV_MAX_RETRIES, source.get("verify_ssl", True)
)
self._sessions[source["name"]] = session
return session
def _get_hashes_from_pypi(self, ireq, source):
pkg_url = f"https://pypi.org/pypi/{ireq.name}/json"
session = _get_requests_session(self.project.s.PIPENV_MAX_RETRIES)
session = self._get_requests_session_for_source(source)
try:
collected_hashes = set()
# Grab the hashes from the new warehouse API.
@@ -776,7 +787,7 @@ class Resolver:
def _get_hashes_from_remote_index_urls(self, ireq, source):
pkg_url = f"{source['url']}/{ireq.name}/"
session = _get_requests_session(self.project.s.PIPENV_MAX_RETRIES)
session = self._get_requests_session_for_source(source)
try:
collected_hashes = set()
# Grab the hashes from the new warehouse API.
@@ -828,7 +839,7 @@ class Resolver:
source = sources[0] if len(sources) else None
if source:
if is_pypi_url(source["url"]):
hashes = self._get_hashes_from_pypi(ireq)
hashes = self._get_hashes_from_pypi(ireq, source)
if hashes:
return hashes
else: