diff --git a/tests/__init__.py b/tests/__init__.py index 57d631c3..3e222031 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1,22 @@ # coding: utf-8 + +"""Requests test package initialisation.""" + +import warnings + +try: + import urllib3 as urllib3_package +except ImportError: + urllib3_package = False + +from requests.packages import urllib3 as urllib3_bundle + +if urllib3_package is urllib3_bundle: + from urllib3.exceptions import SNIMissingWarning +else: + from requests.packages.urllib3.exceptions import SNIMissingWarning + +# urllib3 sets SNIMissingWarning to only go off once, +# while this test suite requires it to always fire +# so that it occurs during test_requests.test_https_warnings +warnings.simplefilter('always', SNIMissingWarning) diff --git a/tests/test_requests.py b/tests/test_requests.py index 167d5ca7..9031a9d6 100755 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -9,6 +9,7 @@ import os import pickle import collections import contextlib +import warnings import io import requests @@ -36,6 +37,19 @@ from .utils import override_environ # listening on that port) TARPIT = 'http://10.255.255.1' +try: + from ssl import SSLContext + del SSLContext + HAS_MODERN_SSL = True +except ImportError: + HAS_MODERN_SSL = False + +try: + requests.pyopenssl + HAS_PYOPENSSL = True +except AttributeError: + HAS_PYOPENSSL = False + class TestRequests: @@ -606,6 +620,27 @@ class TestRequests: def test_pyopenssl_redirect(self, httpbin_secure, httpbin_ca_bundle): requests.get(httpbin_secure('status', '301'), verify=httpbin_ca_bundle) + def test_https_warnings(self, httpbin_secure, httpbin_ca_bundle): + """warnings are emitted with requests.get""" + if HAS_MODERN_SSL or HAS_PYOPENSSL: + warnings_expected = ('SubjectAltNameWarning', ) + else: + warnings_expected = ('SNIMissingWarning', + 'InsecurePlatformWarning', + 'SubjectAltNameWarning', ) + + with pytest.warns(None) as warning_records: + warnings.simplefilter('always') + requests.get(httpbin_secure('status', '200'), + verify=httpbin_ca_bundle) + + warning_records = [item for item in warning_records + if item.category.__name__ != 'ResourceWarning'] + + warnings_category = tuple( + item.category.__name__ for item in warning_records) + assert warnings_category == warnings_expected + def test_urlencoded_get_query_multivalued_param(self, httpbin): r = requests.get(httpbin('get'), params=dict(test=['foo', 'baz']))