From 7fdb09b766a84534dadfb56e46033697292cab60 Mon Sep 17 00:00:00 2001 From: Idan Gazit Date: Thu, 17 Nov 2011 13:44:13 +0200 Subject: [PATCH] Converted auth to use callable objects instead of tuples. My attempt to address #275 on kennethreitz/requests. --- docs/user/advanced.rst | 33 ++++++++--------- docs/user/quickstart.rst | 6 ++-- requests/auth.py | 76 ++++++++++++++-------------------------- requests/models.py | 10 ++---- test_requests.py | 7 ++-- 5 files changed, 55 insertions(+), 77 deletions(-) diff --git a/docs/user/advanced.rst b/docs/user/advanced.rst index 0c26f2d3..05267816 100644 --- a/docs/user/advanced.rst +++ b/docs/user/advanced.rst @@ -211,32 +211,33 @@ Custom Authentication Requests allows you to use specify your own authentication mechanism. -When you pass our authentication tuple to a request method, the first -string is the type of authentication. 'basic' is inferred if none is -provided. +Any callable which is passed as the ``auth`` argument to a request method will +have the opportunity to modify the request before it is dispatched. -You can pass in a callable object instead of a string for the first item -in the tuple, and it will be used in place of the built in authentication -callbacks. +Authentication implementations are subclasses of ``requests.auth.AuthBase``, +and are easy to define. Requests provides two common authentication scheme +implementations in ``requests.auth``: ``HTTPBasicAuth`` and ``HTTPDigestAuth``. Let's pretend that we have a web service that will only respond if the ``X-Pizza`` header is set to a password value. Unlikely, but just go with it. -We simply need to define a callback function that will be used to update the -Request object, right before it is dispatched. - :: - def pizza_auth(r, username): - """Attaches HTTP Pizza Authentication to the given Request object. - """ - r.headers['X-Pizza'] = username - - return r + from requests.auth import AuthBase + class PizzaAuth(AuthBase): + """Attaches HTTP Pizza Authentication to the given Request object.""" + def __init__(self, username): + # setup any auth-related data here + self.username = username + + def __call__(self, r): + # modify and return the request + r.headers['X-Pizza'] = self.username + return r Then, we can make a request using our Pizza Auth:: - >>> requests.get('http://pizzabin.org/admin', auth=(pizza_auth, 'kenneth')) + >>> requests.get('http://pizzabin.org/admin', auth=PizzaAuth('kenneth')) diff --git a/docs/user/quickstart.rst b/docs/user/quickstart.rst index 542a5bc6..cf8d1a71 100644 --- a/docs/user/quickstart.rst +++ b/docs/user/quickstart.rst @@ -235,7 +235,8 @@ authentication, but the most common is HTTP Basic Auth. Making requests with Basic Auth is extremely simple:: - >>> requests.get('https://api.github.com/user', auth=('user', 'pass')) + >>> from requests.auth import HTTPBasicAuth + >>> requests.get('https://api.github.com/user', auth=HTTPBasicAuth('user', 'pass')) OAuth Authentication @@ -249,8 +250,9 @@ Digest Authentication Another popular form of web service protection is Digest Authentication:: + >>> from requests.auth import HTTPDigestAuth >>> url = 'http://httpbin.org/digest-auth/auth/user/pass' - >>> requests.get(url, auth=('digest', 'user', 'pass')) + >>> requests.get(url, auth=HTTPDigestAuth('user', 'pass')) diff --git a/requests/auth.py b/requests/auth.py index aabeb866..fad6eb79 100644 --- a/requests/auth.py +++ b/requests/auth.py @@ -16,26 +16,32 @@ from urlparse import urlparse from .utils import randombytes, parse_dict_header -def http_basic(r, username, password): - """Attaches HTTP Basic Authentication to the given Request object. - Arguments should be considered non-positional. +class AuthBase(object): + """Base class that all auth implementations derive from""" - """ - username = str(username) - password = str(password) - - auth_s = b64encode('%s:%s' % (username, password)) - r.headers['Authorization'] = ('Basic %s' % auth_s) - - return r + def __call__(self, r): + raise NotImplementedError('Auth hooks must be callable.') -def http_digest(r, username, password): - """Attaches HTTP Digest Authentication to the given Request object. - Arguments should be considered non-positional. - """ +class HTTPBasicAuth(AuthBase): + """Attaches HTTP Basic Authentication to the given Request object.""" + def __init__(self, username, password): + self.username = str(username) + self.password = str(password) - def handle_401(r): + def __call__(self, r): + auth_s = b64encode('%s:%s' % (self.username, self.password)) + r.headers['Authorization'] = ('Basic %s' % auth_s) + return r + + +class HTTPDigestAuth(AuthBase): + """Attaches HTTP Digest Authentication to the given Request object.""" + def __init__(self, username, password): + self.username = username + self.password = password + + def handle_401(self, r): """Takes the given response and tries digest-auth, if needed.""" s_auth = r.headers.get('www-authenticate', '') @@ -70,7 +76,7 @@ def http_digest(r, username, password): p_parsed = urlparse(r.request.url) path = p_parsed.path + p_parsed.query - A1 = "%s:%s:%s" % (username, realm, password) + A1 = "%s:%s:%s" % (self.username, realm, self.password) A2 = "%s:%s" % (r.request.method, path) if qop == 'auth': @@ -95,7 +101,7 @@ def http_digest(r, username, password): # XXX should the partial digests be encoded too? base = 'username="%s", realm="%s", nonce="%s", uri="%s", ' \ - 'response="%s"' % (username, realm, nonce, path, respdig) + 'response="%s"' % (self.username, realm, nonce, path, respdig) if opaque: base += ', opaque="%s"' % opaque if entdig: @@ -104,7 +110,6 @@ def http_digest(r, username, password): if qop: base += ', qop=auth, nc=%s, cnonce="%s"' % (ncvalue, cnonce) - r.request.headers['Authorization'] = 'Digest %s' % (base) r.request.send(anyway=True) _r = r.request.response @@ -114,33 +119,6 @@ def http_digest(r, username, password): return r - r.hooks['response'] = handle_401 - return r - - -def dispatch(t): - """Given an auth tuple, return an expanded version.""" - - if not t: - return t - else: - t = list(t) - - # Make sure they're passing in something. - assert len(t) >= 2 - - # If only two items are passed in, assume HTTPBasic. - if (len(t) == 2): - t.insert(0, 'basic') - - # Allow built-in string referenced auths. - if isinstance(t[0], basestring): - if t[0] in ('basic', 'forced_basic'): - t[0] = http_basic - elif t[0] in ('digest',): - t[0] = http_digest - - # Return a custom callable. - return (t[0], tuple(t[1:])) - - + def __call__(self, r): + r.hooks['response'] = self.handle_401 + return r diff --git a/requests/models.py b/requests/models.py index 97237e77..9c4d1333 100644 --- a/requests/models.py +++ b/requests/models.py @@ -14,7 +14,6 @@ from Cookie import SimpleCookie from urlparse import urlparse, urlunparse, urljoin, urlsplit from datetime import datetime -from .auth import dispatch as auth_dispatch from .hooks import dispatch_hook from .structures import CaseInsensitiveDict from .status_codes import codes @@ -99,8 +98,7 @@ class Request(object): self.response = Response() #: Authentication tuple to attach to :class:`Request `. - self._auth = auth - self.auth = auth_dispatch(auth) + self.auth = auth #: CookieJar to attach to :class:`Request `. self.cookies = dict(cookies or []) @@ -235,7 +233,7 @@ class Request(object): files=self.files, method=method, params=self.session.params, - auth=self._auth, + auth=self.auth, cookies=cookies, redirect=True, config=self.config, @@ -392,10 +390,8 @@ class Request(object): if self.auth: - auth_func, auth_args = self.auth - # Allow auth to make its changes. - r = auth_func(self, *auth_args) + r = self.auth(self) # Update self to reflect the auth changes. self.__dict__.update(r.__dict__) diff --git a/test_requests.py b/test_requests.py index 61953a37..1cd73cfb 100755 --- a/test_requests.py +++ b/test_requests.py @@ -10,6 +10,7 @@ import unittest import requests import envoy from requests import HTTPError +from requests.auth import HTTPBasicAuth, HTTPDigestAuth try: import omnijson as json @@ -144,7 +145,7 @@ class RequestsTestSuite(unittest.TestCase): for service in SERVICES: - auth = ('user', 'pass') + auth = HTTPBasicAuth('user', 'pass') url = service('basic-auth', 'user', 'pass') r = requests.get(url, auth=auth) @@ -163,7 +164,7 @@ class RequestsTestSuite(unittest.TestCase): for service in SERVICES: - auth = ('digest', 'user', 'pass') + auth = HTTPDigestAuth('user', 'pass') url = service('digest-auth', 'auth', 'user', 'pass') r = requests.get(url, auth=auth) @@ -270,7 +271,7 @@ class RequestsTestSuite(unittest.TestCase): def test_httpauth_recursion(self): - http_auth = ('user', 'BADpass') + http_auth = HTTPBasicAuth('user', 'BADpass') for service in SERVICES: r = requests.get(service('basic-auth', 'user', 'pass'), auth=http_auth)