More work on proxy auth

- Strip auth header if auth succeeds, so it's not passed upstream
- Actually use realm specification to BasicProxyAuth, and make it mandatory
- Cleanups and unit tests
This commit is contained in:
Aldo Cortesi 2012-12-31 10:56:44 +13:00
parent 3b84111493
commit 5347cb9c26
3 changed files with 56 additions and 14 deletions

View File

@ -9,9 +9,15 @@ class NullProxyAuth():
self.password_manager = password_manager self.password_manager = password_manager
self.username = "" self.username = ""
def clean(self, headers):
"""
Clean up authentication headers, so they're not passed upstream.
"""
pass
def authenticate(self, headers): def authenticate(self, headers):
""" """
Tests that the specified user is allowed to use the proxy (stub) Tests that the user is allowed to use the proxy
""" """
return True return True
@ -23,12 +29,17 @@ class NullProxyAuth():
class BasicProxyAuth(NullProxyAuth): class BasicProxyAuth(NullProxyAuth):
def __init__(self, password_manager, realm="mitmproxy"): CHALLENGE_HEADER = 'Proxy-Authenticate'
AUTH_HEADER = 'Proxy-Authorization'
def __init__(self, password_manager, realm):
NullProxyAuth.__init__(self, password_manager) NullProxyAuth.__init__(self, password_manager)
self.realm = "mitmproxy" self.realm = realm
def clean(self, headers):
del headers[self.AUTH_HEADER]
def authenticate(self, headers): def authenticate(self, headers):
auth_value = headers.get('Proxy-Authorization', []) auth_value = headers.get(self.AUTH_HEADER, [])
if not auth_value: if not auth_value:
return False return False
try: try:
@ -43,7 +54,7 @@ class BasicProxyAuth(NullProxyAuth):
return True return True
def auth_challenge_headers(self): def auth_challenge_headers(self):
return {'Proxy-Authenticate':'Basic realm="%s"'%self.realm} return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm}
def unparse_auth_value(self, scheme, username, password): def unparse_auth_value(self, scheme, username, password):
v = binascii.b2a_base64(username + ":" + password) v = binascii.b2a_base64(username + ":" + password)

View File

@ -356,7 +356,10 @@ class ProxyHandler(tcp.BaseHandler):
headers = http.read_headers(self.rfile) headers = http.read_headers(self.rfile)
if headers is None: if headers is None:
raise ProxyError(400, "Invalid headers") raise ProxyError(400, "Invalid headers")
if authenticate and self.config.authenticator and not self.config.authenticator.authenticate(headers): if authenticate and self.config.authenticator:
if self.config.authenticator.authenticate(headers):
self.config.authenticator.clean(headers)
else:
raise ProxyError( raise ProxyError(
407, 407,
"Proxy Authentication Required", "Proxy Authentication Required",
@ -552,7 +555,7 @@ def process_proxy_options(parser, options):
password_manager = authentication.HtpasswdPasswordManager(options.auth_htpasswd) password_manager = authentication.HtpasswdPasswordManager(options.auth_htpasswd)
# in the meanwhile, basic auth is the only true authentication scheme we support # in the meanwhile, basic auth is the only true authentication scheme we support
# so just use it # so just use it
authenticator = authentication.BasicProxyAuth(password_manager) authenticator = authentication.BasicProxyAuth(password_manager, "mitmproxy")
else: else:
authenticator = authentication.NullProxyAuth(None) authenticator = authentication.NullProxyAuth(None)

View File

@ -9,17 +9,18 @@ class TestNullProxyAuth:
na = authentication.NullProxyAuth(authentication.PermissivePasswordManager()) na = authentication.NullProxyAuth(authentication.PermissivePasswordManager())
assert not na.auth_challenge_headers() assert not na.auth_challenge_headers()
assert na.authenticate("foo") assert na.authenticate("foo")
na.clean({})
class TestBasicProxyAuth: class TestBasicProxyAuth:
def test_simple(self): def test_simple(self):
ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager()) ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager(), "test")
h = odict.ODictCaseless() h = odict.ODictCaseless()
assert ba.auth_challenge_headers() assert ba.auth_challenge_headers()
assert not ba.authenticate(h) assert not ba.authenticate(h)
def test_parse_auth_value(self): def test_parse_auth_value(self):
ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager()) ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager(), "test")
vals = ("basic", "foo", "bar") vals = ("basic", "foo", "bar")
assert ba.parse_auth_value(ba.unparse_auth_value(*vals)) == vals assert ba.parse_auth_value(ba.unparse_auth_value(*vals)) == vals
tutils.raises(ValueError, ba.parse_auth_value, "") tutils.raises(ValueError, ba.parse_auth_value, "")
@ -28,3 +29,30 @@ class TestBasicProxyAuth:
v = "basic " + binascii.b2a_base64("foo") v = "basic " + binascii.b2a_base64("foo")
tutils.raises(ValueError, ba.parse_auth_value, v) tutils.raises(ValueError, ba.parse_auth_value, v)
def test_authenticate_clean(self):
ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager(), "test")
hdrs = odict.ODictCaseless()
vals = ("basic", "foo", "bar")
hdrs[ba.AUTH_HEADER] = [ba.unparse_auth_value(*vals)]
assert ba.authenticate(hdrs)
ba.clean(hdrs)
assert not ba.AUTH_HEADER in hdrs
hdrs[ba.AUTH_HEADER] = [""]
assert not ba.authenticate(hdrs)
hdrs[ba.AUTH_HEADER] = ["foo"]
assert not ba.authenticate(hdrs)
vals = ("foo", "foo", "bar")
hdrs[ba.AUTH_HEADER] = [ba.unparse_auth_value(*vals)]
assert not ba.authenticate(hdrs)
ba = authentication.BasicProxyAuth(authentication.PasswordManager(), "test")
vals = ("basic", "foo", "bar")
hdrs[ba.AUTH_HEADER] = [ba.unparse_auth_value(*vals)]
assert not ba.authenticate(hdrs)