Refactor request processing at mitmproxy's core.

Gradually cleaning up towards a state machine model.
This commit is contained in:
Aldo Cortesi 2012-06-10 16:02:48 +12:00
parent 236447c65f
commit 1f659948cd

View File

@ -104,7 +104,7 @@ def read_chunked(fp, limit):
return content return content
def read_http_body(rfile, connection, headers, all, limit): def read_http_body(rfile, client_conn, headers, all, limit):
if 'transfer-encoding' in headers: if 'transfer-encoding' in headers:
if not ",".join(headers["transfer-encoding"]).lower() == "chunked": if not ",".join(headers["transfer-encoding"]).lower() == "chunked":
raise IOError('Invalid transfer-encoding') raise IOError('Invalid transfer-encoding')
@ -121,7 +121,7 @@ def read_http_body(rfile, connection, headers, all, limit):
content = rfile.read(l) content = rfile.read(l)
elif all: elif all:
content = rfile.read(limit if limit else None) content = rfile.read(limit if limit else None)
connection.close = True client_conn.close = True
else: else:
content = "" content = ""
return content return content
@ -203,6 +203,18 @@ def should_connection_close(httpversion, headers):
return True return True
def read_http_body_request(rfile, wfile, client_conn, headers, httpversion, limit):
if "expect" in headers:
# FIXME: Should be forwarded upstream
expect = ",".join(headers['expect'])
if expect == "100-continue" and httpversion >= (1, 1):
wfile.write('HTTP/1.1 100 Continue\r\n')
wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
wfile.write('\r\n')
del headers['expect']
return read_http_body(rfile, client_conn, headers, False, limit)
class FileLike: class FileLike:
def __init__(self, o): def __init__(self, o):
self.o = o self.o = o
@ -262,10 +274,10 @@ class RequestReplayThread(threading.Thread):
class ServerConnection: class ServerConnection:
def __init__(self, config, scheme, host, port): def __init__(self, config, scheme, host, port):
self.config, self.scheme, self.host, self.port = config, scheme, host, port self.config, self.scheme, self.host, self.port = config, scheme, host, port
self.close = False
self.cert = None self.cert = None
self.sock, self.rfile, self.wfile = None, None, None self.sock, self.rfile, self.wfile = None, None, None
self.connect() self.connect()
self.requestcount = 0
def connect(self): def connect(self):
try: try:
@ -288,6 +300,7 @@ class ServerConnection:
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
def send(self, request): def send(self, request):
self.requestcount += 1
try: try:
d = request._assemble() d = request._assemble()
if not d: if not d:
@ -336,29 +349,38 @@ class ServerConnection:
class ProxyHandler(SocketServer.StreamRequestHandler): class ProxyHandler(SocketServer.StreamRequestHandler):
def __init__(self, config, request, client_address, server, q): def __init__(self, config, request, client_address, server, q):
self.config = config
self.mqueue = q self.mqueue = q
self.config = config
self.server_conn = None self.server_conn = None
self.proxy_connect_state = None
SocketServer.StreamRequestHandler.__init__(self, request, client_address, server) SocketServer.StreamRequestHandler.__init__(self, request, client_address, server)
def handle(self): def handle(self):
cc = flow.ClientConnect(self.client_address) cc = flow.ClientConnect(self.client_address)
cc._send(self.mqueue) cc._send(self.mqueue)
while not cc.close: while self.handle_request(cc) and not cc.close:
self.handle_request(cc) pass
cc.close = True
cd = flow.ClientDisconnect(cc) cd = flow.ClientDisconnect(cc)
cd._send(self.mqueue) cd._send(self.mqueue)
self.finish() self.finish()
def server_connect(self, scheme, host, port):
sc = self.server_conn
if sc and (scheme, host, port) != (sc.scheme, sc.host, sc.port):
sc.terminate()
self.server_conn = None
if not self.server_conn:
self.server_conn = ServerConnection(self.config, scheme, host, port)
def handle_request(self, cc): def handle_request(self, cc):
server_conn, request, err = None, None, None
try: try:
request, err = None, None
try: try:
request = self.read_request(cc) request = self.read_request(cc)
except IOError, v: except IOError, v:
raise IOError, "Reading request: %s"%v raise IOError, "Reading request: %s"%v
if request is None: if request is None:
cc.close = True
return return
cc.requestcount += 1 cc.requestcount += 1
@ -368,7 +390,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
else: else:
request = request._send(self.mqueue) request = request._send(self.mqueue)
if request is None: if request is None:
cc.close = True
return return
if isinstance(request, flow.Response): if isinstance(request, flow.Response):
@ -380,31 +401,30 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
scheme, host, port = self.config.reverse_proxy scheme, host, port = self.config.reverse_proxy
else: else:
scheme, host, port = request.scheme, request.host, request.port scheme, host, port = request.scheme, request.host, request.port
server_conn = ServerConnection(self.config, scheme, host, port) self.server_connect(scheme, host, port)
server_conn.send(request) self.server_conn.send(request)
try: try:
response = server_conn.read_response(request) response = self.server_conn.read_response(request)
except IOError, v: except IOError, v:
raise IOError, "Reading response: %s"%v raise IOError, "Reading response: %s"%v
response = response._send(self.mqueue) response = response._send(self.mqueue)
if response is None: if response is None:
server_conn.terminate() self.server_conn.terminate()
if response is None: if response is None:
cc.close = True
return return
self.send_response(response) self.send_response(response)
if should_connection_close(request.httpversion, request.headers):
return
except IOError, v: except IOError, v:
cc.connection_error = v cc.connection_error = v
cc.close = True
except ProxyError, e: except ProxyError, e:
cc.close = True
cc.connection_error = "%s: %s"%(e.code, e.msg) cc.connection_error = "%s: %s"%(e.code, e.msg)
if request: if request:
err = flow.Error(request, e.msg) err = flow.Error(request, e.msg)
err._send(self.mqueue) err._send(self.mqueue)
self.send_error(e.code, e.msg) self.send_error(e.code, e.msg)
if server_conn: else:
server_conn.terminate() return True
def find_cert(self, host, port): def find_cert(self, host, port):
if self.config.certfile: if self.config.certfile:
@ -435,26 +455,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
self.rfile = FileLike(self.connection) self.rfile = FileLike(self.connection)
self.wfile = FileLike(self.connection) self.wfile = FileLike(self.connection)
def read_contents(self, client_conn, headers, httpversion):
if "expect" in headers:
# FIXME: Should be forwarded upstream
expect = ",".join(headers['expect'])
if expect == "100-continue" and httpversion >= (1, 1):
self.wfile.write('HTTP/1.1 100 Continue\r\n')
self.wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
self.wfile.write('\r\n')
del headers['expect']
if httpversion < (1, 1):
client_conn.close = True
if "connection" in headers:
for value in ",".join(headers['connection']).split(","):
value = value.strip()
if value == "close":
client_conn.close = True
if value == "keep-alive":
client_conn.close = False
return read_http_body(self.rfile, client_conn, headers, False, self.config.body_size_limit)
def read_request(self, client_conn): def read_request(self, client_conn):
line = self.rfile.readline() line = self.rfile.readline()
if line == "\r\n" or line == "\n": # Possible leftover from previous message if line == "\r\n" or line == "\n": # Possible leftover from previous message
@ -466,34 +466,45 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
scheme, host, port = self.config.reverse_proxy scheme, host, port = self.config.reverse_proxy
method, path, httpversion = parse_init_http(line) method, path, httpversion = parse_init_http(line)
headers = read_headers(self.rfile) headers = read_headers(self.rfile)
content = self.read_contents(client_conn, headers, httpversion) content = read_http_body_request(
self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit
)
return flow.Request(client_conn, httpversion, host, port, "http", method, path, headers, content) return flow.Request(client_conn, httpversion, host, port, "http", method, path, headers, content)
elif line.startswith("CONNECT"):
host, port, httpversion = parse_init_connect(line)
# FIXME: Discard additional headers sent to the proxy. Should I expose
# these to users?
while 1:
d = self.rfile.readline()
if d == '\r\n' or d == '\n':
break
self.wfile.write(
'HTTP/1.1 200 Connection established\r\n' +
('Proxy-agent: %s\r\n'%version.NAMEVERSION) +
'\r\n'
)
self.wfile.flush()
certfile = self.find_cert(host, port)
self.convert_to_ssl(certfile)
method, path, httpversion = parse_init_http(self.rfile.readline(line))
headers = read_headers(self.rfile)
content = self.read_contents(client_conn, headers, httpversion)
return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content)
else: else:
method, scheme, host, port, path, httpversion = parse_init_proxy(line) if line.startswith("CONNECT"):
headers = read_headers(self.rfile) host, port, httpversion = parse_init_connect(line)
content = self.read_contents(client_conn, headers, httpversion) # FIXME: Discard additional headers sent to the proxy. Should I expose
return flow.Request(client_conn, httpversion, host, port, scheme, method, path, headers, content) # these to users?
while 1:
d = self.rfile.readline()
if d == '\r\n' or d == '\n':
break
self.wfile.write(
'HTTP/1.1 200 Connection established\r\n' +
('Proxy-agent: %s\r\n'%version.NAMEVERSION) +
'\r\n'
)
self.wfile.flush()
certfile = self.find_cert(host, port)
self.convert_to_ssl(certfile)
self.proxy_connect_state = (host, port, httpversion)
line = self.rfile.readline(line)
if self.proxy_connect_state:
host, port, httpversion = self.proxy_connect_state
method, path, httpversion = parse_init_http(line)
headers = read_headers(self.rfile)
content = read_http_body_request(
self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit
)
return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content)
else:
method, scheme, host, port, path, httpversion = parse_init_proxy(line)
headers = read_headers(self.rfile)
content = read_http_body_request(
self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit
)
return flow.Request(client_conn, httpversion, host, port, scheme, method, path, headers, content)
def send_response(self, response): def send_response(self, response):
d = response._assemble() d = response._assemble()