Add a way for users to specify header significance in server replay.

Also add the --rheader command-line option to mitmdump to let the user specify
an arbitrary number of significant headers. The default is to treat no headers
as significant.
This commit is contained in:
Aldo Cortesi 2011-02-23 10:54:51 +13:00
parent c80214ba55
commit 39207ffdd2
4 changed files with 61 additions and 20 deletions

View File

@ -12,6 +12,7 @@ class Options(object):
"replay", "replay",
"verbosity", "verbosity",
"wfile", "wfile",
"rheaders",
] ]
def __init__(self, **kwargs): def __init__(self, **kwargs):
for k, v in kwargs.items(): for k, v in kwargs.items():
@ -52,7 +53,7 @@ class DumpMaster(flow.FlowMaster):
flows = list(flow.FlowReader(f).stream()) flows = list(flow.FlowReader(f).stream())
except IOError, v: except IOError, v:
raise DumpError(v.strerror) raise DumpError(v.strerror)
self.start_playback(flows, options.kill) self.start_playback(flows, options.kill, options.rheaders)
def _runscript(self, f, script): def _runscript(self, f, script):
try: try:

View File

@ -32,7 +32,12 @@ class RequestReplayThread(threading.Thread):
class ServerPlaybackState: class ServerPlaybackState:
def __init__(self): def __init__(self, headers):
"""
headers: A case-insensitive list of request headers that should be
included in request-response matching.
"""
self.headers = headers
self.fmap = {} self.fmap = {}
def count(self): def count(self):
@ -62,6 +67,15 @@ class ServerPlaybackState:
str(r.path), str(r.path),
str(r.content), str(r.content),
] ]
if self.headers:
hdrs = []
for i in self.headers:
v = r.headers.get(i, [])
# Slightly subtle: we need to convert everything to strings
# to prevent a mismatch between unicode/non-unicode.
v = [str(x) for x in v]
hdrs.append((i, v))
key.append(repr(hdrs))
return hashlib.sha256(repr(key)).digest() return hashlib.sha256(repr(key)).digest()
def next_flow(self, request): def next_flow(self, request):
@ -342,12 +356,12 @@ class FlowMaster(controller.Master):
def set_request_script(self, s): def set_request_script(self, s):
self.scripts["request"] = s self.scripts["request"] = s
def start_playback(self, flows, kill): def start_playback(self, flows, kill, headers):
""" """
flows: A list of flows. flows: A list of flows.
kill: Boolean, should we kill requests not part of the replay? kill: Boolean, should we kill requests not part of the replay?
""" """
self.playback = ServerPlaybackState() self.playback = ServerPlaybackState(headers)
self.playback.load(flows) self.playback.load(flows)
self.kill_nonreplay = kill self.kill_nonreplay = kill

View File

@ -28,31 +28,40 @@ if __name__ == '__main__':
) )
proxy.certificate_option_group(parser) proxy.certificate_option_group(parser)
parser.add_option( parser.add_option(
"-p", "--port", action="store", "-p", action="store",
type = "int", dest="port", default=8080, type = "int", dest="port", default=8080,
help = "Port." help = "Port."
) )
parser.add_option("-q", "--quiet", parser.add_option("-q",
action="store_true", dest="quiet", action="store_true", dest="quiet",
help="Quiet.") help="Quiet.")
parser.add_option("-v", "--verbose",
action="count", dest="verbose", default=1,
help="Increase verbosity. Can be passed multiple times.")
parser.add_option("-w", "--writefile",
action="store", dest="wfile", default=None,
help="Write flows to file.")
parser.add_option("", "--reqscript", parser.add_option("", "--reqscript",
action="store", dest="request_script", default=None, action="store", dest="request_script", default=None,
help="Script to run when a request is recieved.") help="Script to run when a request is recieved.")
parser.add_option("", "--respscript", parser.add_option("", "--respscript",
action="store", dest="response_script", default=None, action="store", dest="response_script", default=None,
help="Script to run when a response is recieved.") help="Script to run when a response is recieved.")
parser.add_option("-r", "--replay", parser.add_option("-v",
action="store", dest="replay", default=None, action="count", dest="verbose", default=1,
help="Increase verbosity. Can be passed multiple times.")
parser.add_option("-w",
action="store", dest="wfile", default=None,
help="Write flows to file.")
group = OptionGroup(parser, "Server Replay")
group.add_option("-r", action="store", dest="replay", default=None, metavar="PATH",
help="Replay server responses from a saved file.") help="Replay server responses from a saved file.")
parser.add_option("-k", "--kill", group.add_option("-k", "--kill",
action="store_true", dest="kill", default=False, action="store_true", dest="kill", default=False,
help="Kill extra requests during replay.") help="Kill extra requests during replay.")
group.add_option("--rheader",
action="append", dest="rheaders", type="str",
help="Request headers to be considered during replay. "
"Can be passed multiple times.")
parser.add_option_group(group)
options, args = parser.parse_args() options, args = parser.parse_args()
@ -60,6 +69,7 @@ if __name__ == '__main__':
if options.quiet: if options.quiet:
options.verbose = 0 options.verbose = 0
config = proxy.process_certificate_option_group(parser, options) config = proxy.process_certificate_option_group(parser, options)
server = proxy.ProxyServer(config, options.port) server = proxy.ProxyServer(config, options.port)
dumpopts = dump.Options( dumpopts = dump.Options(
@ -68,7 +78,8 @@ if __name__ == '__main__':
request_script = options.request_script, request_script = options.request_script,
response_script = options.response_script, response_script = options.response_script,
replay = options.replay, replay = options.replay,
kill = options.kill kill = options.kill,
rheaders = options.rheaders
) )
if args: if args:
filt = " ".join(args) filt = " ".join(args)

View File

@ -6,7 +6,7 @@ import libpry
class uServerPlaybackState(libpry.AutoTree): class uServerPlaybackState(libpry.AutoTree):
def test_hash(self): def test_hash(self):
s = flow.ServerPlaybackState() s = flow.ServerPlaybackState(None)
r = utils.tflow() r = utils.tflow()
r2 = utils.tflow() r2 = utils.tflow()
@ -17,8 +17,23 @@ class uServerPlaybackState(libpry.AutoTree):
r.request.path = "voing" r.request.path = "voing"
assert s._hash(r) != s._hash(r2) assert s._hash(r) != s._hash(r2)
def test_headers(self):
s = flow.ServerPlaybackState(["foo"])
r = utils.tflow_full()
r.request.headers["foo"] = ["bar"]
r2 = utils.tflow_full()
assert not s._hash(r) == s._hash(r2)
r2.request.headers["foo"] = ["bar"]
assert s._hash(r) == s._hash(r2)
r2.request.headers["oink"] = ["bar"]
assert s._hash(r) == s._hash(r2)
r = utils.tflow_full()
r2 = utils.tflow_full()
assert s._hash(r) == s._hash(r2)
def test_load(self): def test_load(self):
s = flow.ServerPlaybackState() s = flow.ServerPlaybackState(None)
r = utils.tflow_full() r = utils.tflow_full()
r.request.headers["key"] = ["one"] r.request.headers["key"] = ["one"]
@ -319,10 +334,10 @@ class uFlowMaster(libpry.AutoTree):
fm = flow.FlowMaster(None, s) fm = flow.FlowMaster(None, s)
assert not fm.do_playback(utils.tflow()) assert not fm.do_playback(utils.tflow())
fm.start_playback(pb, False) fm.start_playback(pb, False, [])
assert fm.do_playback(utils.tflow()) assert fm.do_playback(utils.tflow())
fm.start_playback(pb, False) fm.start_playback(pb, False, [])
r = utils.tflow() r = utils.tflow()
r.request.content = "gibble" r.request.content = "gibble"
assert not fm.do_playback(r) assert not fm.do_playback(r)