Refactor Flow primitives to remove HTTP1.0 assumption.

This is a big patch removing the assumption that there's one connection per
Request/Response pair. It touches pretty much every part of mitmproxy, so
expect glitches until everything is ironed out.
This commit is contained in:
Aldo Cortesi 2011-02-19 17:00:24 +13:00
parent cd4eea3934
commit 5da27a9905
10 changed files with 141 additions and 194 deletions

View File

@ -66,7 +66,7 @@ def format_flow(f, focus, extended=False, padding=2):
f.request.url(),
),
]
if f.response or f.error or f.is_replay():
if f.response or f.error or f.request.is_replay():
tsr = f.response or f.error
if extended and tsr:
ts = ("highlight", utils.format_timestamp(tsr.timestamp) + " ")
@ -77,7 +77,7 @@ def format_flow(f, focus, extended=False, padding=2):
txt.append(("text", ts))
txt.append(" "*(padding+2))
met = ""
if f.is_replay():
if f.request.is_replay():
txt.append(("method", "[replay] "))
elif f.modified():
txt.append(("method", "[edited] "))
@ -715,17 +715,13 @@ class ConsoleState(flow.State):
self.last_script = ""
self.last_saveload = ""
def add_browserconnect(self, f):
flow.State.add_browserconnect(self, f)
def add_request(self, req):
f = flow.State.add_request(self, req)
if self.focus is None:
self.set_focus(0)
else:
self.set_focus(self.focus + 1)
def add_request(self, req):
if self.focus is None:
self.set_focus(0)
return flow.State.add_request(self, req)
return f
def add_response(self, resp):
if self.store is not None:
@ -1305,7 +1301,7 @@ class ConsoleMaster(flow.FlowMaster):
def process_flow(self, f, r):
if f.match(self.state.beep):
urwid.curses_display.curses.beep()
if f.match(self.state.intercept) and not f.is_replay():
if f.match(self.state.intercept) and not f.request.is_replay():
f.intercept()
else:
r.ack()
@ -1313,8 +1309,8 @@ class ConsoleMaster(flow.FlowMaster):
self.refresh_connection(f)
# Handlers
def handle_clientconnection(self, r):
f = flow.FlowMaster.handle_clientconnection(self, r)
def handle_clientconnect(self, r):
f = flow.FlowMaster.handle_clientconnect(self, r)
if f:
self.sync_list_view()

View File

@ -38,14 +38,6 @@ class DumpMaster(flow.FlowMaster):
except IOError, v:
raise DumpError(v.strerror)
def handle_clientconnection(self, r):
flow.FlowMaster.handle_clientconnection(self, r)
r.ack()
def handle_error(self, r):
flow.FlowMaster.handle_error(self, r)
r.ack()
def _runscript(self, f, script):
try:
ret = f.run_script(script)
@ -80,12 +72,12 @@ class DumpMaster(flow.FlowMaster):
return
sz = utils.pretty_size(len(f.response.content))
if self.o.verbosity == 1:
print >> self.outfile, f.client_conn.address[0],
print >> self.outfile, f.request.client_conn.address[0],
print >> self.outfile, f.request.short()
print >> self.outfile, " <<",
print >> self.outfile, f.response.short(), sz
elif self.o.verbosity == 2:
print >> self.outfile, f.client_conn.address[0],
print >> self.outfile, f.request.client_conn.address[0],
print >> self.outfile, f.request.short()
print >> self.outfile, self.indent(4, f.request.headers)
print >> self.outfile
@ -93,7 +85,7 @@ class DumpMaster(flow.FlowMaster):
print >> self.outfile, self.indent(4, f.response.headers)
print >> self.outfile, "\n"
elif self.o.verbosity == 3:
print >> self.outfile, f.client_conn.address[0],
print >> self.outfile, f.request.client_conn.address[0],
print >> self.outfile, f.request.short()
print >> self.outfile, self.indent(4, f.request.headers)
if utils.isBin(f.request.content):

View File

@ -32,9 +32,9 @@ class ReplayThread(threading.Thread):
class Flow:
def __init__(self, client_conn):
self.client_conn = client_conn
self.request, self.response, self.error = None, None, None
def __init__(self, request):
self.request = request
self.response, self.error = None, None
self.intercepting = False
self._backup = None
@ -90,7 +90,6 @@ class Flow:
request = self.request.get_state() if self.request else None,
response = self.response.get_state() if self.response else None,
error = self.error.get_state() if self.error else None,
client_conn = self.client_conn.get_state()
)
if nobackup:
d["backup"] = None
@ -99,10 +98,8 @@ class Flow:
return d
def load_state(self, state):
self.client_conn = proxy.ClientConnection.from_state(state["client_conn"])
self._backup = state["backup"]
if state["request"]:
self.request = proxy.Request.from_state(self.client_conn, state["request"])
self.request = proxy.Request.from_state(state["request"])
if state["response"]:
self.response = proxy.Response.from_state(self.request, state["response"])
if state["error"]:
@ -141,9 +138,6 @@ class Flow:
return pattern(self.request)
return False
def is_replay(self):
return self.client_conn.is_replay()
def kill(self):
if self.request and not self.request.acked:
self.request.ack(None)
@ -165,35 +159,43 @@ class Flow:
class State:
def __init__(self):
self.client_connections = []
self.flow_map = {}
self.flow_list = []
# These are compiled filt expressions:
self.limit = None
self.intercept = None
def add_browserconnect(self, f):
def clientconnect(self, cc):
if not isinstance(cc, proxy.ClientConnect):
assert False
self.client_connections.append(cc)
def clientdisconnect(self, dc):
"""
Start a browser connection.
"""
self.flow_list.insert(0, f)
self.flow_map[f.client_conn] = f
self.client_connections.remove(dc.client_conn)
def add_request(self, req):
"""
Add a request to the state. Returns the matching flow.
"""
f = self.flow_map.get(req.client_conn)
if not f:
f = Flow(req.client_conn)
self.add_browserconnect(f)
f.request = req
if not isinstance(req, proxy.Request):
assert False
f = Flow(req)
self.flow_list.insert(0, f)
self.flow_map[req] = f
return f
def add_response(self, resp):
"""
Add a response to the state. Returns the matching flow.
"""
f = self.flow_map.get(resp.request.client_conn)
if not isinstance(resp, proxy.Response):
assert False
f = self.flow_map.get(resp.request)
if not f:
return False
f.response = resp
@ -204,7 +206,7 @@ class State:
Add an error response to the state. Returns the matching flow, or
None if there isn't one.
"""
f = self.flow_map.get(err.client_conn)
f = self.flow_map.get(err.flow.request)
if not f:
return None
f.error = err
@ -213,7 +215,7 @@ class State:
def load_flows(self, flows):
self.flow_list.extend(flows)
for i in flows:
self.flow_map[i.client_conn] = i
self.flow_map[i.request] = i
def set_limit(self, limit):
"""
@ -229,27 +231,17 @@ class State:
return tuple(self.flow_list[:])
def get_client_conn(self, itm):
if isinstance(itm, proxy.ClientConnection):
if isinstance(itm, proxy.ClientConnect):
return itm
elif hasattr(itm, "client_conn"):
return itm.client_conn
elif hasattr(itm, "request"):
return itm.request.client_conn
def lookup(self, itm):
"""
Checks for matching client_conn, using a Flow, Replay Connection,
ClientConnection, Request, Response or Error object. Returns None
if not found.
"""
client_conn = self.get_client_conn(itm)
return self.flow_map.get(client_conn)
def delete_flow(self, f):
if not f.intercepting:
c = self.get_client_conn(f)
if c in self.flow_map:
del self.flow_map[c]
if f.request in self.flow_map:
del self.flow_map[f.request]
self.flow_list.remove(f)
return True
return False
@ -280,7 +272,7 @@ class State:
if f.request:
f.backup()
conn = self.get_client_conn(f)
f.client_conn.set_replay()
f.request.set_replay()
if f.request.content:
f.request.headers["content-length"] = [str(len(f.request.content))]
f.response = None
@ -295,12 +287,13 @@ class FlowMaster(controller.Master):
controller.Master.__init__(self, server)
self.state = state
# Handlers
def handle_clientconnection(self, r):
f = Flow(r)
self.state.add_browserconnect(f)
def handle_clientconnect(self, r):
self.state.clientconnect(r)
r.ack()
def handle_clientdisconnect(self, r):
self.state.clientdisconnect(r)
r.ack()
return f
def handle_error(self, r):
f = self.state.add_error(r)

View File

@ -136,11 +136,21 @@ class Request(controller.Msg):
self.close = False
controller.Msg.__init__(self)
def set_replay(self):
self.client_conn = None
def is_replay(self):
if self.client_conn:
return False
else:
return True
def is_cached(self):
return False
def get_state(self):
return dict(
client_conn = self.client_conn.get_state(),
host = self.host,
port = self.port,
scheme = self.scheme,
@ -152,9 +162,9 @@ class Request(controller.Msg):
)
@classmethod
def from_state(klass, client_conn, state):
def from_state(klass, state):
return klass(
client_conn,
ClientConnect.from_state(state["client_conn"]),
state["host"],
state["port"],
state["scheme"],
@ -165,6 +175,9 @@ class Request(controller.Msg):
state["timestamp"]
)
def __hash__(self):
return id(self)
def __eq__(self, other):
return self.get_state() == other.get_state()
@ -296,7 +309,13 @@ class Response(controller.Msg):
return self.FMT%data
class ClientConnection(controller.Msg):
class ClientDisconnect(controller.Msg):
def __init__(self, client_conn):
controller.Msg.__init__(self)
self.client_conn = client_conn
class ClientConnect(controller.Msg):
def __init__(self, address):
"""
address is an (address, port) tuple, or None if this connection has
@ -313,22 +332,13 @@ class ClientConnection(controller.Msg):
def from_state(klass, state):
return klass(state)
def set_replay(self):
self.address = None
def is_replay(self):
if self.address:
return False
else:
return True
def copy(self):
return copy.copy(self)
class Error(controller.Msg):
def __init__(self, client_conn, msg, timestamp=None):
self.client_conn, self.msg = client_conn, msg
def __init__(self, flow, msg, timestamp=None):
self.flow, self.msg = flow, msg
self.timestamp = timestamp or time.time()
controller.Msg.__init__(self)
@ -453,11 +463,12 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
SocketServer.StreamRequestHandler.__init__(self, request, client_address, server)
def handle(self):
cc = ClientConnection(self.client_address)
cc = ClientConnect(self.client_address)
cc.send(self.mqueue)
while not cc.close:
self.handle_request(cc)
cc = cc.copy()
cd = ClientDisconnect(cc)
cd.send(self.mqueue)
self.finish()
def handle_request(self, cc):

View File

@ -10,11 +10,9 @@ class uState(libpry.AutoTree):
connect -> request -> response
"""
bc = proxy.ClientConnection(("address", 22))
c = console.ConsoleState()
f = flow.Flow(bc)
c.add_browserconnect(f)
assert c.lookup(bc)
f = self._add_request(c)
assert f.request in c.flow_map
assert c.get_focus() == (f, 0)
def test_focus(self):
@ -24,18 +22,14 @@ class uState(libpry.AutoTree):
connect -> request -> response
"""
c = console.ConsoleState()
f = self._add_request(c)
bc = proxy.ClientConnection(("address", 22))
f = flow.Flow(bc)
c.add_browserconnect(f)
assert c.get_focus() == (f, 0)
assert c.get_from_pos(0) == (f, 0)
assert c.get_from_pos(1) == (None, None)
assert c.get_next(0) == (None, None)
bc2 = proxy.ClientConnection(("address", 22))
f2 = flow.Flow(bc2)
c.add_browserconnect(f2)
f2 = self._add_request(c)
assert c.get_focus() == (f, 1)
assert c.get_next(0) == (f, 1)
assert c.get_prev(1) == (f2, 0)
@ -52,25 +46,14 @@ class uState(libpry.AutoTree):
assert c.get_focus() == (None, None)
def _add_request(self, state):
f = utils.tflow()
state.add_browserconnect(f)
q = utils.treq(f.client_conn)
state.add_request(q)
return f
r = utils.treq()
return state.add_request(r)
def _add_response(self, state):
f = self._add_request(state)
r = utils.tresp(f.request)
state.add_response(r)
def test_add_request(self):
c = console.ConsoleState()
f = utils.tflow()
c.add_browserconnect(f)
q = utils.treq(f.client_conn)
c.focus = None
assert c.add_request(q)
def test_add_response(self):
c = console.ConsoleState()
f = self._add_request(c)
@ -118,11 +101,12 @@ class uformat_flow(libpry.AutoTree):
assert ('method', '[edited] ') in console.format_flow(f, True)
assert ('method', '[edited] ') in console.format_flow(f, True, True)
f.client_conn = proxy.ClientConnection(None)
f.request.set_replay()
assert ('method', '[replay] ') in console.format_flow(f, True)
assert ('method', '[replay] ') in console.format_flow(f, True, True)
class uPathCompleter(libpry.AutoTree):
def test_lookup_construction(self):
c = console._PathCompleter()

View File

@ -10,7 +10,7 @@ class uDumpMaster(libpry.AutoTree):
req = utils.treq()
cc = req.client_conn
resp = utils.tresp(req)
m.handle_clientconnection(cc)
m.handle_clientconnect(cc)
m.handle_request(req)
m.handle_response(resp)

View File

@ -72,7 +72,7 @@ class uParsing(libpry.AutoTree):
class uMatching(libpry.AutoTree):
def req(self):
conn = proxy.ClientConnection(("one", 2222))
conn = proxy.ClientConnect(("one", 2222))
headers = utils.Headers()
headers["header"] = ["qvalue"]
return proxy.Request(

View File

@ -46,38 +46,6 @@ class uFlow(libpry.AutoTree):
state = f.get_state()
assert f == flow.Flow.from_state(state)
def test_simple(self):
f = utils.tflow()
assert console.format_flow(f, True)
assert console.format_flow(f, False)
f.request = utils.treq()
assert console.format_flow(f, True)
assert console.format_flow(f, False)
f.response = utils.tresp()
f.response.headers["content-type"] = ["text/html"]
assert console.format_flow(f, True)
assert console.format_flow(f, False)
f.response.code = 404
assert console.format_flow(f, True)
assert console.format_flow(f, False)
assert console.format_flow(f, True)
assert console.format_flow(f, False)
f.client_conn.set_replay()
assert console.format_flow(f, True)
assert console.format_flow(f, False)
f.response = None
assert console.format_flow(f, True)
assert console.format_flow(f, False)
f.error = proxy.Error(200, "test")
assert console.format_flow(f, True)
assert console.format_flow(f, False)
def test_kill(self):
f = utils.tflow()
f.request = utils.treq()
@ -115,10 +83,10 @@ class uFlow(libpry.AutoTree):
class uState(libpry.AutoTree):
def test_backup(self):
bc = proxy.ClientConnection(("address", 22))
bc = proxy.ClientConnect(("address", 22))
c = flow.State()
f = flow.Flow(bc)
c.add_browserconnect(f)
req = utils.treq()
f = c.add_request(req)
f.backup()
c.revert(f)
@ -129,92 +97,98 @@ class uState(libpry.AutoTree):
connect -> request -> response
"""
bc = proxy.ClientConnection(("address", 22))
bc = proxy.ClientConnect(("address", 22))
c = flow.State()
f = flow.Flow(bc)
c.add_browserconnect(f)
assert c.lookup(bc)
c.clientconnect(bc)
assert len(c.client_connections) == 1
req = utils.treq(bc)
assert c.add_request(req)
f = c.add_request(req)
assert f
assert len(c.flow_list) == 1
assert c.lookup(req)
assert c.flow_map.get(req)
newreq = utils.treq()
assert c.add_request(newreq)
assert c.lookup(newreq)
assert c.flow_map.get(newreq)
resp = utils.tresp(req)
assert c.add_response(resp)
assert len(c.flow_list) == 2
assert c.lookup(resp)
assert c.flow_map.get(resp.request)
newresp = utils.tresp()
assert not c.add_response(newresp)
assert not c.lookup(newresp)
assert not c.flow_map.get(newresp.request)
dc = proxy.ClientDisconnect(bc)
c.clientdisconnect(dc)
assert not c.client_connections
def test_err(self):
bc = proxy.ClientConnection(("address", 22))
bc = proxy.ClientConnect(("address", 22))
c = flow.State()
f = flow.Flow(bc)
c.add_browserconnect(f)
e = proxy.Error(bc, "message")
req = utils.treq()
f = c.add_request(req)
e = proxy.Error(f, "message")
assert c.add_error(e)
e = proxy.Error(proxy.ClientConnection(("address", 22)), "message")
e = proxy.Error(utils.tflow(), "message")
assert not c.add_error(e)
def test_view(self):
c = flow.State()
f = utils.tflow()
c.add_browserconnect(f)
assert len(c.view) == 1
c.set_limit(filt.parse("~q"))
req = utils.treq()
c.clientconnect(req.client_conn)
assert len(c.view) == 0
f = c.add_request(req)
assert len(c.view) == 1
c.set_limit(filt.parse("~s"))
assert len(c.view) == 0
resp = utils.tresp(req)
c.add_response(resp)
assert len(c.view) == 1
c.set_limit(None)
assert len(c.view) == 1
f = utils.tflow()
req = utils.treq(f.client_conn)
c.add_browserconnect(f)
req = utils.treq()
c.clientconnect(req.client_conn)
c.add_request(req)
assert len(c.view) == 2
c.set_limit(filt.parse("~q"))
assert len(c.view) == 1
c.set_limit(filt.parse("~s"))
assert len(c.view) == 0
assert len(c.view) == 1
def _add_request(self, state):
f = utils.tflow()
state.add_browserconnect(f)
q = utils.treq(f.client_conn)
state.add_request(q)
req = utils.treq()
f = state.add_request(req)
return f
def _add_response(self, state):
f = self._add_request(state)
r = utils.tresp(f.request)
state.add_response(r)
req = utils.treq()
f = state.add_request(req)
resp = utils.tresp(req)
state.add_response(resp)
def _add_error(self, state):
f = utils.tflow()
f.error = proxy.Error(None, "msg")
state.add_browserconnect(f)
q = utils.treq(f.client_conn)
state.add_request(q)
req = utils.treq()
f = state.add_request(req)
f.error = proxy.Error(f, "msg")
def test_kill_flow(self):
c = flow.State()
f = utils.tflow()
c.add_browserconnect(f)
req = utils.treq()
f = c.add_request(req)
c.kill_flow(f)
assert not c.flow_list
def test_clear(self):
c = flow.State()
f = utils.tflow()
c.add_browserconnect(f)
f = self._add_request(c)
f.intercepting = True
c.clear()
@ -265,15 +239,12 @@ class uFlowMaster(libpry.AutoTree):
def test_one(self):
s = flow.State()
f = flow.FlowMaster(None, s)
req = utils.treq()
f.handle_clientconnection(req.client_conn)
assert len(s.flow_list) == 1
f.handle_request(req)
assert len(s.flow_list) == 1
f.handle_request(req)
resp = utils.tresp()
resp.request = req
resp = utils.tresp(req)
f.handle_response(resp)
assert len(s.flow_list) == 1

View File

@ -213,7 +213,7 @@ class uRequest(libpry.AutoTree):
def test_simple(self):
h = utils.Headers()
h["test"] = ["test"]
c = proxy.ClientConnection(("addr", 2222))
c = proxy.ClientConnect(("addr", 2222))
r = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content")
u = r.url()
assert r.set_url(u)
@ -225,17 +225,17 @@ class uRequest(libpry.AutoTree):
def test_getset_state(self):
h = utils.Headers()
h["test"] = ["test"]
c = proxy.ClientConnection(("addr", 2222))
c = proxy.ClientConnect(("addr", 2222))
r = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content")
state = r.get_state()
assert proxy.Request.from_state(c, state) == r
assert proxy.Request.from_state(state) == r
class uResponse(libpry.AutoTree):
def test_simple(self):
h = utils.Headers()
h["test"] = ["test"]
c = proxy.ClientConnection(("addr", 2222))
c = proxy.ClientConnect(("addr", 2222))
req = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content")
resp = proxy.Response(req, 200, "msg", h.copy(), "content")
assert resp.short()
@ -244,7 +244,7 @@ class uResponse(libpry.AutoTree):
def test_getset_state(self):
h = utils.Headers()
h["test"] = ["test"]
c = proxy.ClientConnection(("addr", 2222))
c = proxy.ClientConnect(("addr", 2222))
r = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content")
req = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content")
resp = proxy.Response(req, 200, "msg", h.copy(), "content")

View File

@ -2,7 +2,7 @@ from libmproxy import proxy, utils, filt, flow
def treq(conn=None):
if not conn:
conn = proxy.ClientConnection(("address", 22))
conn = proxy.ClientConnect(("address", 22))
headers = utils.Headers()
headers["header"] = ["qvalue"]
return proxy.Request(conn, "host", 80, "http", "GET", "/path", headers, "content")
@ -17,6 +17,6 @@ def tresp(req=None):
def tflow():
bc = proxy.ClientConnection(("address", 22))
return flow.Flow(bc)
r = treq()
return flow.Flow(r)