Converted auth to use callable objects instead of tuples.

My attempt to address #275 on kennethreitz/requests.
This commit is contained in:
Idan Gazit
2011-11-17 13:44:13 +02:00
parent 76dacaf231
commit 7fdb09b766
5 changed files with 55 additions and 77 deletions
+17 -16
View File
@@ -211,32 +211,33 @@ Custom Authentication
Requests allows you to use specify your own authentication mechanism.
When you pass our authentication tuple to a request method, the first
string is the type of authentication. 'basic' is inferred if none is
provided.
Any callable which is passed as the ``auth`` argument to a request method will
have the opportunity to modify the request before it is dispatched.
You can pass in a callable object instead of a string for the first item
in the tuple, and it will be used in place of the built in authentication
callbacks.
Authentication implementations are subclasses of ``requests.auth.AuthBase``,
and are easy to define. Requests provides two common authentication scheme
implementations in ``requests.auth``: ``HTTPBasicAuth`` and ``HTTPDigestAuth``.
Let's pretend that we have a web service that will only respond if the
``X-Pizza`` header is set to a password value. Unlikely, but just go with it.
We simply need to define a callback function that will be used to update the
Request object, right before it is dispatched.
::
def pizza_auth(r, username):
"""Attaches HTTP Pizza Authentication to the given Request object.
"""
r.headers['X-Pizza'] = username
return r
from requests.auth import AuthBase
class PizzaAuth(AuthBase):
"""Attaches HTTP Pizza Authentication to the given Request object."""
def __init__(self, username):
# setup any auth-related data here
self.username = username
def __call__(self, r):
# modify and return the request
r.headers['X-Pizza'] = self.username
return r
Then, we can make a request using our Pizza Auth::
>>> requests.get('http://pizzabin.org/admin', auth=(pizza_auth, 'kenneth'))
>>> requests.get('http://pizzabin.org/admin', auth=PizzaAuth('kenneth'))
<Response [200]>
+4 -2
View File
@@ -235,7 +235,8 @@ authentication, but the most common is HTTP Basic Auth.
Making requests with Basic Auth is extremely simple::
>>> requests.get('https://api.github.com/user', auth=('user', 'pass'))
>>> from requests.auth import HTTPBasicAuth
>>> requests.get('https://api.github.com/user', auth=HTTPBasicAuth('user', 'pass'))
<Response [200]>
OAuth Authentication
@@ -249,8 +250,9 @@ Digest Authentication
Another popular form of web service protection is Digest Authentication::
>>> from requests.auth import HTTPDigestAuth
>>> url = 'http://httpbin.org/digest-auth/auth/user/pass'
>>> requests.get(url, auth=('digest', 'user', 'pass'))
>>> requests.get(url, auth=HTTPDigestAuth('user', 'pass'))
<Response [200]>
+27 -49
View File
@@ -16,26 +16,32 @@ from urlparse import urlparse
from .utils import randombytes, parse_dict_header
def http_basic(r, username, password):
"""Attaches HTTP Basic Authentication to the given Request object.
Arguments should be considered non-positional.
class AuthBase(object):
"""Base class that all auth implementations derive from"""
"""
username = str(username)
password = str(password)
auth_s = b64encode('%s:%s' % (username, password))
r.headers['Authorization'] = ('Basic %s' % auth_s)
return r
def __call__(self, r):
raise NotImplementedError('Auth hooks must be callable.')
def http_digest(r, username, password):
"""Attaches HTTP Digest Authentication to the given Request object.
Arguments should be considered non-positional.
"""
class HTTPBasicAuth(AuthBase):
"""Attaches HTTP Basic Authentication to the given Request object."""
def __init__(self, username, password):
self.username = str(username)
self.password = str(password)
def handle_401(r):
def __call__(self, r):
auth_s = b64encode('%s:%s' % (self.username, self.password))
r.headers['Authorization'] = ('Basic %s' % auth_s)
return r
class HTTPDigestAuth(AuthBase):
"""Attaches HTTP Digest Authentication to the given Request object."""
def __init__(self, username, password):
self.username = username
self.password = password
def handle_401(self, r):
"""Takes the given response and tries digest-auth, if needed."""
s_auth = r.headers.get('www-authenticate', '')
@@ -70,7 +76,7 @@ def http_digest(r, username, password):
p_parsed = urlparse(r.request.url)
path = p_parsed.path + p_parsed.query
A1 = "%s:%s:%s" % (username, realm, password)
A1 = "%s:%s:%s" % (self.username, realm, self.password)
A2 = "%s:%s" % (r.request.method, path)
if qop == 'auth':
@@ -95,7 +101,7 @@ def http_digest(r, username, password):
# XXX should the partial digests be encoded too?
base = 'username="%s", realm="%s", nonce="%s", uri="%s", ' \
'response="%s"' % (username, realm, nonce, path, respdig)
'response="%s"' % (self.username, realm, nonce, path, respdig)
if opaque:
base += ', opaque="%s"' % opaque
if entdig:
@@ -104,7 +110,6 @@ def http_digest(r, username, password):
if qop:
base += ', qop=auth, nc=%s, cnonce="%s"' % (ncvalue, cnonce)
r.request.headers['Authorization'] = 'Digest %s' % (base)
r.request.send(anyway=True)
_r = r.request.response
@@ -114,33 +119,6 @@ def http_digest(r, username, password):
return r
r.hooks['response'] = handle_401
return r
def dispatch(t):
"""Given an auth tuple, return an expanded version."""
if not t:
return t
else:
t = list(t)
# Make sure they're passing in something.
assert len(t) >= 2
# If only two items are passed in, assume HTTPBasic.
if (len(t) == 2):
t.insert(0, 'basic')
# Allow built-in string referenced auths.
if isinstance(t[0], basestring):
if t[0] in ('basic', 'forced_basic'):
t[0] = http_basic
elif t[0] in ('digest',):
t[0] = http_digest
# Return a custom callable.
return (t[0], tuple(t[1:]))
def __call__(self, r):
r.hooks['response'] = self.handle_401
return r
+3 -7
View File
@@ -14,7 +14,6 @@ from Cookie import SimpleCookie
from urlparse import urlparse, urlunparse, urljoin, urlsplit
from datetime import datetime
from .auth import dispatch as auth_dispatch
from .hooks import dispatch_hook
from .structures import CaseInsensitiveDict
from .status_codes import codes
@@ -99,8 +98,7 @@ class Request(object):
self.response = Response()
#: Authentication tuple to attach to :class:`Request <Request>`.
self._auth = auth
self.auth = auth_dispatch(auth)
self.auth = auth
#: CookieJar to attach to :class:`Request <Request>`.
self.cookies = dict(cookies or [])
@@ -235,7 +233,7 @@ class Request(object):
files=self.files,
method=method,
params=self.session.params,
auth=self._auth,
auth=self.auth,
cookies=cookies,
redirect=True,
config=self.config,
@@ -392,10 +390,8 @@ class Request(object):
if self.auth:
auth_func, auth_args = self.auth
# Allow auth to make its changes.
r = auth_func(self, *auth_args)
r = self.auth(self)
# Update self to reflect the auth changes.
self.__dict__.update(r.__dict__)
+4 -3
View File
@@ -10,6 +10,7 @@ import unittest
import requests
import envoy
from requests import HTTPError
from requests.auth import HTTPBasicAuth, HTTPDigestAuth
try:
import omnijson as json
@@ -144,7 +145,7 @@ class RequestsTestSuite(unittest.TestCase):
for service in SERVICES:
auth = ('user', 'pass')
auth = HTTPBasicAuth('user', 'pass')
url = service('basic-auth', 'user', 'pass')
r = requests.get(url, auth=auth)
@@ -163,7 +164,7 @@ class RequestsTestSuite(unittest.TestCase):
for service in SERVICES:
auth = ('digest', 'user', 'pass')
auth = HTTPDigestAuth('user', 'pass')
url = service('digest-auth', 'auth', 'user', 'pass')
r = requests.get(url, auth=auth)
@@ -270,7 +271,7 @@ class RequestsTestSuite(unittest.TestCase):
def test_httpauth_recursion(self):
http_auth = ('user', 'BADpass')
http_auth = HTTPBasicAuth('user', 'BADpass')
for service in SERVICES:
r = requests.get(service('basic-auth', 'user', 'pass'), auth=http_auth)