Add a simple server playback state object.

We use a loose hash to match incoming requests with recorded flows. At the
moment, this hash is over the host, port, scheme, method, path and content of
the request. Note that headers are not included here - if we do want to include
headers, we would have to do some work to normalize them to remove variations
between user agents, header order, etc. etc.
This commit is contained in:
Aldo Cortesi 2011-02-21 08:47:19 +13:00
parent aa16194518
commit deb79a9c5a
4 changed files with 87 additions and 7 deletions

View File

@ -160,7 +160,7 @@ class ConnectionItem(WWrap):
self.master.statusbar.message("Can't delete connection mid-intercept.")
self.master.sync_list_view()
elif key == "r":
r = self.state.replay(self.flow, self.master.masterq)
r = self.state.replay_request(self.flow, self.master.masterq)
if r:
self.master.statusbar.message(r)
self.master.sync_list_view()
@ -511,7 +511,7 @@ class ConnectionView(WWrap):
elif key == "p":
self.master.view_prev_flow(self.flow)
elif key == "r":
r = self.state.replay(self.flow, self.master.masterq)
r = self.state.replay_request(self.flow, self.master.masterq)
if r:
self.master.statusbar.message(r)
self.master.refresh_connection(self.flow)

View File

@ -2,7 +2,7 @@
This module provides more sophisticated flow tracking. These match requests
with their responses, and provide filtering and interception facilities.
"""
import subprocess, base64, sys, json
import subprocess, base64, sys, json, hashlib
import proxy, threading, netstring
import controller
@ -14,7 +14,7 @@ class RunException(Exception):
# begin nocover
class ReplayThread(threading.Thread):
class RequestReplayThread(threading.Thread):
def __init__(self, flow, masterq):
self.flow, self.masterq = flow, masterq
threading.Thread.__init__(self)
@ -31,6 +31,49 @@ class ReplayThread(threading.Thread):
# end nocover
class ServerPlaybackState:
def __init__(self):
self.fmap = {}
def __len__(self):
return sum([len(i) for i in self.fmap.values()])
def load(self, flows):
"""
Load a sequence of flows. We assume that the sequence is in
chronological order.
"""
for i in flows:
h = self._hash(i)
l = self.fmap.setdefault(self._hash(i), [])
l.append(i)
def _hash(self, flow):
"""
Calculates a loose hash of the flow request.
"""
r = flow.request
key = [
str(r.host),
str(r.port),
str(r.scheme),
str(r.method),
str(r.path),
str(r.content),
]
return hashlib.sha256(repr(key)).digest()
def next_flow(self, request):
"""
Returns the next flow object, or None if no matching flow was
found.
"""
l = self.fmap.get(self._hash(request))
if not l:
return None
return l.pop(0)
class Flow:
def __init__(self, request):
self.request = request
@ -262,7 +305,7 @@ class State:
def revert(self, f):
f.revert()
def replay(self, f, masterq):
def replay_request(self, f, masterq):
"""
Returns None if successful, or error message if not.
"""
@ -276,7 +319,7 @@ class State:
f.request.headers["content-length"] = [str(len(f.request.content))]
f.response = None
f.error = None
rt = ReplayThread(f, masterq)
rt = RequestReplayThread(f, masterq)
rt.start()
#end nocover

View File

@ -3,6 +3,43 @@ from libmproxy import console, proxy, filt, flow
import utils
import libpry
class uServerPlaybackState(libpry.AutoTree):
def test_hash(self):
s = flow.ServerPlaybackState()
r = utils.tflow()
r2 = utils.tflow()
assert s._hash(r)
assert s._hash(r) == s._hash(r2)
r.request.headers["foo"] = ["bar"]
assert s._hash(r) == s._hash(r2)
r.request.path = "voing"
assert s._hash(r) != s._hash(r2)
def test_load(self):
s = flow.ServerPlaybackState()
r = utils.tflow()
r.request.headers["key"] = ["one"]
r2 = utils.tflow()
r2.request.headers["key"] = ["two"]
s.load([r, r2])
assert len(s) == 2
assert len(s.fmap.keys()) == 1
n = s.next_flow(r)
assert n.request.headers["key"] == ["one"]
assert len(s) == 1
n = s.next_flow(r)
assert n.request.headers["key"] == ["two"]
assert len(s) == 0
assert not s.next_flow(r)
class uFlow(libpry.AutoTree):
def test_run_script(self):
f = utils.tflow()
@ -275,6 +312,7 @@ class uFlowMaster(libpry.AutoTree):
tests = [
uServerPlaybackState(),
uFlow(),
uState(),
uSerialize(),

View File

@ -3,7 +3,6 @@ from cStringIO import StringIO
import libpry
class uNetstring(libpry.AutoTree):
def setUp(self):
self.test_data = "Netstring module by Will McGugan"