diff --git a/requests/packages/__init__.py b/requests/packages/__init__.py index ec6a9e06..4dcf870f 100644 --- a/requests/packages/__init__.py +++ b/requests/packages/__init__.py @@ -27,9 +27,13 @@ import sys class VendorAlias(object): - def __init__(self): + def __init__(self, package_names): + self._package_names = package_names self._vendor_name = __name__ self._vendor_pkg = self._vendor_name + "." + self._vendor_pkgs = [ + self._vendor_pkg + name for name in self._package_names + ] def find_module(self, fullname, path=None): if fullname.startswith(self._vendor_pkg): @@ -44,6 +48,14 @@ class VendorAlias(object): ) ) + if not (name == self._vendor_name or + any(name.startswith(pkg) for pkg in self._vendor_pkgs)): + raise ImportError( + "Cannot import %s, must be one of %s." % ( + name, self._vendor_pkgs + ) + ) + # Check to see if we already have this item in sys.modules, if we do # then simply return that. if name in sys.modules: @@ -92,4 +104,4 @@ class VendorAlias(object): return module -sys.meta_path.append(VendorAlias()) +sys.meta_path.append(VendorAlias(["urllib3", "chardet"])) diff --git a/test_requests.py b/test_requests.py index 69a10616..07430a8e 100755 --- a/test_requests.py +++ b/test_requests.py @@ -1611,5 +1611,12 @@ def test_urllib3_retries(): with pytest.raises(RetryError): s.get(httpbin('status/500')) +def test_vendor_aliases(): + from requests.packages import urllib3 + from requests.packages import chardet + + with pytest.raises(ImportError): + from requests.packages import webbrowser + if __name__ == '__main__': unittest.main()