mitmproxy/libmproxy/flow.py
2015-05-30 12:03:28 +12:00

1101 lines
30 KiB
Python

"""
This module provides more sophisticated flow tracking and provides filtering and interception facilities.
"""
from __future__ import absolute_import
from abc import abstractmethod, ABCMeta
import hashlib
import Cookie
import cookielib
import os
import re
from netlib import odict, wsgi, tcp
import netlib.http
from . import controller, protocol, tnetstring, filt, script, version
from .onboarding import app
from .protocol import http, handle
from .proxy.config import HostMatcher
from .proxy.connection import ClientConnection, ServerConnection
import urlparse
class AppRegistry:
    """
    Registry of WSGI apps, keyed by the (domain, port) pair they are
    served for.
    """

    def __init__(self):
        # Maps (domain, port) -> wsgi.WSGIAdaptor.
        self.apps = {}

    def add(self, app, domain, port):
        """
        Register a WSGI app to be served for requests to the specified
        domain on the specified port. An existing entry for the same
        (domain, port) pair is overwritten.
        """
        adaptor = wsgi.WSGIAdaptor(app, domain, port, version.NAMEVERSION)
        self.apps[(domain, port)] = adaptor

    def get(self, request):
        """
        Return the WSGIAdaptor registered for this request, or None.

        The request's own host/port is tried first; if that does not
        match, the first value of the "host" header (when present) is
        tried together with the request port.
        """
        port = request.port
        direct_key = (request.host, port)
        if direct_key in self.apps:
            return self.apps[direct_key]
        if "host" not in request.headers:
            return None
        header_host = request.headers["host"][0]
        return self.apps.get((header_host, port), None)
class ReplaceHooks:
    """
    An ordered collection of regex replacement hooks, each guarded by a
    filter pattern.
    """

    def __init__(self):
        # Entries are (fpatt, rex, s, compiled_filter) tuples.
        self.lst = []

    def set(self, r):
        """
        Replace all current hooks with those in r, an iterable of
        argument tuples for add().
        """
        self.clear()
        for spec in r:
            self.add(*spec)

    def add(self, fpatt, rex, s):
        """
        Add a replacement hook.

        fpatt: A string specifying a filter pattern.
        rex: A regular expression.
        s: The replacement string.

        Returns True if the hook was added, False if the filter pattern
        could not be parsed or the regular expression does not compile.
        """
        compiled_filter = filt.parse(fpatt)
        if not compiled_filter:
            return False
        try:
            # Only a validity check; the pattern string itself is stored.
            re.compile(rex)
        except re.error:
            return False
        self.lst.append((fpatt, rex, s, compiled_filter))
        return True

    def get_specs(self):
        """
        Retrieve the hook specifications as a list of (fpatt, rex, s)
        tuples.
        """
        return [entry[:3] for entry in self.lst]

    def count(self):
        """
        Number of registered hooks.
        """
        return len(self.lst)

    def run(self, f):
        """
        Apply every hook whose filter matches f. The replacement is run
        on the response if the flow has one, otherwise on the request.
        """
        for _, rex, s, compiled_filter in self.lst:
            if not compiled_filter(f):
                continue
            target = f.response if f.response else f.request
            target.replace(rex, s)

    def clear(self):
        """
        Remove all hooks.
        """
        self.lst = []
class SetHeaders:
    """
    An ordered collection of "set header" hooks, each guarded by a filter
    pattern.
    """

    def __init__(self):
        # Entries are (fpatt, header, value, compiled_filter) tuples.
        self.lst = []

    def set(self, r):
        """
        Replace all current hooks with those in r, an iterable of
        argument tuples for add().
        """
        self.clear()
        for spec in r:
            self.add(*spec)

    def add(self, fpatt, header, value):
        """
        Add a set header hook.

        fpatt: String specifying a filter pattern.
        header: Header name.
        value: Header value string.

        Returns True if the hook was added, False if the pattern could
        not be parsed.
        """
        compiled_filter = filt.parse(fpatt)
        if not compiled_filter:
            return False
        self.lst.append((fpatt, header, value, compiled_filter))
        return True

    def get_specs(self):
        """
        Retrieve the hook specifications as a list of
        (fpatt, header, value) tuples.
        """
        return [entry[:3] for entry in self.lst]

    def count(self):
        """
        Number of registered hooks.
        """
        return len(self.lst)

    def clear(self):
        """
        Remove all hooks.
        """
        self.lst = []

    def run(self, f):
        """
        Apply all matching hooks to f. Headers are modified on the
        response if the flow has one, otherwise on the request.

        Two passes are made: first every matching header name is deleted,
        then every matching (header, value) is added, so several hooks
        for the same header name all end up in the result.
        """
        for _, header, value, compiled_filter in self.lst:
            if compiled_filter(f):
                message = f.response if f.response else f.request
                del message.headers[header]
        for _, header, value, compiled_filter in self.lst:
            if compiled_filter(f):
                message = f.response if f.response else f.request
                message.headers.add(header, value)
class StreamLargeBodies(object):
    """
    Marks messages for streaming when their expected body size is not
    within [0, max_size].
    """

    def __init__(self, max_size):
        self.max_size = max_size

    def run(self, flow, is_request):
        """
        Inspect the request (is_request true) or response of flow and
        set its stream attribute if the expected body size falls outside
        the allowed range.
        """
        if is_request:
            message = flow.request
        else:
            message = flow.response
        code = None
        if flow.response:
            code = flow.response.code
        expected_size = netlib.http.expected_http_body_size(
            message.headers, is_request, flow.request.method, code
        )
        if 0 <= expected_size <= self.max_size:
            return
        # message.stream may already be a callable, which we want to
        # preserve.
        message.stream = message.stream or True
class ClientPlaybackState:
    """
    Tracks a queue of flows to be replayed one at a time from the client
    side.
    """

    def __init__(self, flows, exit):
        self.flows = flows
        self.exit = exit
        # The flow currently being replayed, if any.
        self.current = None
        # Disables actual replay for testing.
        self.testing = False

    def count(self):
        """
        Number of flows still waiting to be replayed.
        """
        return len(self.flows)

    def done(self):
        """
        True once the queue is empty and no flow is in progress.
        """
        return len(self.flows) == 0 and not self.current

    def clear(self, flow):
        """
        A request has returned in some way - if this is the one we're
        servicing, go to the next flow.
        """
        if flow is self.current:
            self.current = None

    def tick(self, master):
        """
        Start replaying the next queued flow unless one is already in
        progress or the queue is empty.
        """
        if not self.flows or self.current:
            return
        self.current = self.flows.pop(0).copy()
        if self.testing:
            # Drive the master's handlers directly instead of replaying.
            self.current.reply = controller.DummyReply()
            master.handle_request(self.current)
            if self.current.response:
                master.handle_response(self.current)
        else:
            master.replay_request(self.current)
class ServerPlaybackState:
    """
    Indexes recorded flows by a loose hash of their request so that
    recorded responses can be looked up for incoming requests.
    """

    def __init__(
            self,
            headers,
            flows,
            exit,
            nopop,
            ignore_params,
            ignore_content,
            ignore_payload_params,
            ignore_host):
        """
        headers: Case-insensitive list of request headers that should be
        included in request-response matching.
        """
        self.headers = headers
        self.exit = exit
        self.nopop = nopop
        self.ignore_params = ignore_params
        self.ignore_content = ignore_content
        self.ignore_payload_params = ignore_payload_params
        self.ignore_host = ignore_host
        # Maps request hash -> list of flows, in the order given.
        self.fmap = {}
        for flow in flows:
            # Flows without a response are not indexed.
            if flow.response:
                self.fmap.setdefault(self._hash(flow), []).append(flow)

    def count(self):
        """
        Total number of flows still held for playback.
        """
        return sum(len(bucket) for bucket in self.fmap.values())

    def _hash(self, flow):
        """
        Calculates a loose hash of the flow request.

        The key is built from port, scheme, method and path, then
        (depending on the ignore_* settings) the content or form
        parameters, the host, the query parameters that are not ignored,
        and the values of the configured headers.
        """
        r = flow.request
        parsed = urlparse.urlparse(r.url)
        path, query = parsed[2], parsed[4]
        query_pairs = urlparse.parse_qsl(query, keep_blank_values=True)

        key = [str(r.port), str(r.scheme), str(r.method), str(path)]

        if not self.ignore_content:
            form_contents = r.get_form()
            if self.ignore_payload_params and form_contents:
                key.extend(
                    p for p in form_contents
                    if p[0] not in self.ignore_payload_params
                )
            else:
                key.append(str(r.content))

        if not self.ignore_host:
            key.append(r.host)

        ignored = self.ignore_params or []
        for name, value in query_pairs:
            if name not in ignored:
                key.extend((name, value))

        if self.headers:
            # Slightly subtle: we need to convert everything to strings
            # to prevent a mismatch between unicode/non-unicode.
            hdrs = [
                (name, [str(x) for x in r.headers[name]])
                for name in self.headers
            ]
            key.append(hdrs)

        return hashlib.sha256(repr(key)).digest()

    def next_flow(self, request):
        """
        Returns the next flow object, or None if no matching flow was
        found. Unless nopop is set, the returned flow is removed from
        the state.
        """
        candidates = self.fmap.get(self._hash(request))
        if not candidates:
            return None
        return candidates[0] if self.nopop else candidates.pop(0)
class StickyCookieState:
    """
    Remembers cookies set by responses and re-attaches them to later
    matching requests.
    """

    def __init__(self, flt):
        """
        flt: Compiled filter.
        """
        # Maps (domain, port, path) -> cookie morsel.
        self.jar = {}
        self.flt = flt

    def ckey(self, m, f):
        """
        Returns a (domain, port, path) tuple. The domain falls back to
        the request host and the path to "/" when the morsel leaves them
        empty.
        """
        domain = m["domain"] or f.request.host
        port = f.request.port
        path = m["path"] or "/"
        return (domain, port, path)

    def domain_match(self, a, b):
        """
        True if a domain-matches b, either as given or with leading and
        trailing dots stripped from b.
        """
        if cookielib.domain_match(a, b):
            return True
        if cookielib.domain_match(a, b.strip(".")):
            return True
        return False

    def handle_response(self, f):
        """
        Store every cookie from the response's set-cookie headers whose
        domain matches the request host.
        """
        for raw in f.response.headers["set-cookie"]:
            # FIXME: We now know that Cookie.py screws up some cookies with
            # valid RFC 822/1123 datetime specifications for expiry. Sigh.
            parsed = Cookie.SimpleCookie(str(raw))
            for morsel in parsed.values():
                key = self.ckey(morsel, f)
                if self.domain_match(f.request.host, key[0]):
                    self.jar[key] = morsel

    def handle_request(self, f):
        """
        If f matches the filter, replace its cookie header with all
        stored cookies whose domain, port and path match the request.
        """
        if not f.match(self.flt):
            return
        cookies = []
        for key in self.jar.keys():
            domain, port, path = key
            domain_ok = self.domain_match(f.request.host, domain)
            port_ok = f.request.port == port
            path_ok = f.request.path.startswith(path)
            if domain_ok and port_ok and path_ok:
                cookies.append(self.jar[key].output(header="").strip())
        if cookies:
            f.request.stickycookie = True
            f.request.headers["cookie"] = cookies
class StickyAuthState:
    """
    Remembers the authorization header seen for each host and re-applies
    it to later matching requests that lack one.
    """

    def __init__(self, flt):
        """
        flt: Compiled filter.
        """
        self.flt = flt
        # Maps host -> last authorization header value seen for it.
        self.hosts = {}

    def handle_request(self, f):
        """
        Record the request's authorization header if it has one;
        otherwise, if f matches the filter and a value is known for the
        host, set it on the request.
        """
        host = f.request.host
        headers = f.request.headers
        if "authorization" in headers:
            self.hosts[host] = headers["authorization"]
        elif f.match(self.flt) and host in self.hosts:
            headers["authorization"] = self.hosts[host]
class FlowList(object):
    """
    Abstract base for list-like flow containers. Subclasses provide a
    _list attribute holding the flows and implement _add, _update and
    _remove.
    """
    __metaclass__ = ABCMeta

    def __iter__(self):
        return iter(self._list)

    def __contains__(self, item):
        return item in self._list

    def __getitem__(self, item):
        return self._list[item]

    def __nonzero__(self):
        return bool(self._list)

    def __len__(self):
        return len(self._list)

    def index(self, f):
        """
        Position of f in the underlying list.
        """
        return self._list.index(f)

    @abstractmethod
    def _add(self, f):
        """
        Add a flow to the container.
        """
        return None

    @abstractmethod
    def _update(self, f):
        """
        React to a flow having been modified.
        """
        return None

    @abstractmethod
    def _remove(self, f):
        """
        Remove a flow from the container.
        """
        return None
class FlowView(FlowList):
    """
    A filtered view onto a flow store. The view registers itself with
    the store on creation and holds the subset of flows accepted by its
    filter.
    """

    def __init__(self, store, filt=None):
        self._list = []
        # Without a filter, every flow is accepted.
        self._build(store, filt or (lambda flow: True))

        self.store = store
        self.store.views.append(self)

    def _close(self):
        """
        Unregister this view from its store.
        """
        self.store.views.remove(self)

    def _build(self, flows, filt=None):
        """
        Rebuild the view's list from flows. If filt is given, it becomes
        the view's filter first; otherwise the current filter is kept.
        """
        if filt:
            self.filt = filt
        self._list = [f for f in flows if self.filt(f)]

    def _add(self, f):
        if self.filt(f):
            self._list.append(f)

    def _update(self, f):
        # A flow that is not shown yet may match now; a flow that is
        # shown may have stopped matching.
        if f in self._list:
            if not self.filt(f):
                self._remove(f)
        else:
            self._add(f)

    def _remove(self, f):
        if f in self._list:
            self._list.remove(f)

    def _recalculate(self, flows):
        """
        Rebuild the view from flows using the current filter.
        """
        self._build(flows)
class FlowStore(FlowList):
    """
    Responsible for handling flows in the state:
    Keeps a list of all flows and provides views on them.
    """

    def __init__(self):
        self._list = []
        self._set = set()  # Used for O(1) lookups
        self.views = []
        self._recalculate_views()

    def get(self, flow_id):
        """
        Return the first flow whose id equals flow_id, or None.
        """
        return next((f for f in self._list if f.id == flow_id), None)

    def __contains__(self, f):
        return f in self._set

    def _add(self, f):
        """
        Adds a flow to the state.
        The flow to add must not be present in the state.
        """
        self._list.append(f)
        self._set.add(f)
        for view in self.views:
            view._add(f)

    def _update(self, f):
        """
        Notifies the state that a flow has been updated.
        The flow must be present in the state.
        """
        if f not in self:
            return
        for view in self.views:
            view._update(f)

    def _remove(self, f):
        """
        Deletes a flow from the state.
        The flow must be present in the state.
        """
        self._list.remove(f)
        self._set.remove(f)
        for view in self.views:
            view._remove(f)

    # Expensive bulk operations

    def _extend(self, flows):
        """
        Adds a list of flows to the state.
        The list of flows to add must not contain flows that are already
        in the state.
        """
        self._list.extend(flows)
        self._set.update(flows)
        self._recalculate_views()

    def _clear(self):
        """
        Drop all flows and rebuild the views.
        """
        self._list = []
        self._set = set()
        self._recalculate_views()

    def _recalculate_views(self):
        """
        Expensive operation: Recalculate all the views after a bulk change.
        """
        for view in self.views:
            view._recalculate(self)

    # Utility functions.
    # There are some common cases where we need to argue about all flows
    # irrespective of filters on the view etc (i.e. on shutdown).

    def active_count(self):
        """
        Number of flows that have neither a response nor an error.
        """
        return sum(1 for f in self._list if not f.response and not f.error)

    # TODO: Should accept_all operate on views or on all flows?
    def accept_all(self, master):
        """
        Call accept_intercept(master) on every flow.
        """
        for f in self._list:
            f.accept_intercept(master)

    def kill_all(self, master):
        """
        Call kill(master) on every flow.
        """
        for f in self._list:
            f.kill(master)
class State(object):
    """
    Holds the flow store, the current (optionally limited) view onto it,
    and the intercept filter.
    """

    def __init__(self):
        self.flows = FlowStore()
        self.view = FlowView(self.flows, None)

        # These are compiled filt expressions:
        self.intercept = None

    @property
    def limit_txt(self):
        """
        The "pattern" attribute of the current view filter, or None if
        the filter has none.
        """
        return getattr(self.view.filt, "pattern", None)

    def flow_count(self):
        return len(self.flows)

    # TODO: All functions regarding flows that don't cause side-effects
    # should be moved into FlowStore.
    def index(self, f):
        return self.flows.index(f)

    def active_flow_count(self):
        return self.flows.active_count()

    def add_flow(self, f):
        """
        Add a flow to the state and return it.
        """
        self.flows._add(f)
        return f

    def update_flow(self, f):
        """
        Notify the state that a flow has changed and return the flow.
        """
        self.flows._update(f)
        return f

    def delete_flow(self, f):
        self.flows._remove(f)

    def load_flows(self, flows):
        self.flows._extend(flows)

    def set_limit(self, txt):
        """
        Replace the current view with one limited by the filter
        expression txt, or with an unfiltered view if txt is empty.
        Returns an error message if txt cannot be parsed, None otherwise.
        """
        if txt == self.limit_txt:
            return
        if txt:
            new_filter = filt.parse(txt)
            if not new_filter:
                return "Invalid filter expression."
        else:
            new_filter = None
        self.view._close()
        self.view = FlowView(self.flows, new_filter)

    def set_intercept(self, txt):
        """
        Set the intercept filter from the expression txt, or clear it if
        txt is empty. Returns an error message if txt cannot be parsed,
        None otherwise.
        """
        if not txt:
            self.intercept = None
            return
        parsed = filt.parse(txt)
        if not parsed:
            return "Invalid filter expression."
        self.intercept = parsed

    @property
    def intercept_txt(self):
        """
        The "pattern" attribute of the intercept filter, or None.
        """
        return getattr(self.intercept, "pattern", None)

    def clear(self):
        self.flows._clear()

    def accept_all(self, master):
        self.flows.accept_all(master)

    def backup(self, f):
        f.backup()
        self.update_flow(f)

    def revert(self, f):
        f.revert()
        self.update_flow(f)

    def killall(self, master):
        self.flows.kill_all(master)
class FlowMaster(controller.Master):
    """
    A controller master that records flows in a State object and applies
    the configured features (scripts, sticky cookies/auth, replace and
    set-header hooks, client/server playback, streaming flows to a file)
    as events are handled.
    """

    def __init__(self, server, state):
        controller.Master.__init__(self, server)
        self.state = state
        self.server_playback = None
        self.client_playback = None
        self.kill_nonreplay = False
        self.scripts = []
        self.pause_scripts = False

        self.stickycookie_state = False
        self.stickycookie_txt = None

        self.stickyauth_state = False
        self.stickyauth_txt = None

        self.anticache = False
        self.anticomp = False
        self.stream_large_bodies = False
        self.refresh_server_playback = False
        self.replacehooks = ReplaceHooks()
        self.setheaders = SetHeaders()
        self.replay_ignore_params = False
        self.replay_ignore_content = None
        self.replay_ignore_host = False

        # A FilteredFlowWriter while flows are being streamed to a file.
        self.stream = None
        self.apps = AppRegistry()

    def start_app(self, host, port):
        """
        Register the onboarding app for the given host and port.
        """
        self.apps.add(app.mapp, host, port)

    def add_event(self, e, level="info"):
        """
        level: debug, info, error

        Does nothing in this class.
        """
        pass

    def unload_scripts(self):
        """
        Unload all loaded scripts.
        """
        # Iterate over a copy: unload_script() mutates self.scripts.
        for s in self.scripts[:]:
            self.unload_script(s)

    def unload_script(self, script):
        script.unload()
        self.scripts.remove(script)

    def load_script(self, command):
        """
        Loads a script. Returns an error description if something went
        wrong.
        """
        try:
            s = script.Script(command, self)
        except script.ScriptError as v:
            return v.args[0]
        self.scripts.append(s)

    def run_single_script_hook(self, script, name, *args, **kwargs):
        """
        Run hook name on one script unless scripts are paused. A failed
        run that carries error details is reported through add_event().
        """
        if script and not self.pause_scripts:
            ret = script.run(name, *args, **kwargs)
            if not ret[0] and ret[1]:
                e = "Script error:\n" + ret[1][1]
                self.add_event(e, "error")

    def run_script_hook(self, name, *args, **kwargs):
        """
        Run hook name on every loaded script.
        """
        for s in self.scripts:
            self.run_single_script_hook(s, name, *args, **kwargs)

    def get_ignore_filter(self):
        return self.server.config.check_ignore.patterns

    def set_ignore_filter(self, host_patterns):
        self.server.config.check_ignore = HostMatcher(host_patterns)

    def get_tcp_filter(self):
        return self.server.config.check_tcp.patterns

    def set_tcp_filter(self, host_patterns):
        self.server.config.check_tcp = HostMatcher(host_patterns)

    def set_stickycookie(self, txt):
        """
        Enable sticky cookies for flows matching the filter expression
        txt, or disable them if txt is empty. Returns an error message
        if txt cannot be parsed, None otherwise.
        """
        if txt:
            flt = filt.parse(txt)
            if not flt:
                return "Invalid filter expression."
            self.stickycookie_state = StickyCookieState(flt)
            self.stickycookie_txt = txt
        else:
            self.stickycookie_state = None
            self.stickycookie_txt = None

    def set_stream_large_bodies(self, max_size):
        """
        Stream bodies whose expected size is outside [0, max_size];
        passing None turns the feature off.
        """
        if max_size is not None:
            self.stream_large_bodies = StreamLargeBodies(max_size)
        else:
            self.stream_large_bodies = False

    def set_stickyauth(self, txt):
        """
        Enable sticky auth for flows matching the filter expression txt,
        or disable it if txt is empty. Returns an error message if txt
        cannot be parsed, None otherwise.
        """
        if txt:
            flt = filt.parse(txt)
            if not flt:
                return "Invalid filter expression."
            self.stickyauth_state = StickyAuthState(flt)
            self.stickyauth_txt = txt
        else:
            self.stickyauth_state = None
            self.stickyauth_txt = None

    def start_client_playback(self, flows, exit):
        """
        flows: List of flows.
        """
        self.client_playback = ClientPlaybackState(flows, exit)

    def stop_client_playback(self):
        self.client_playback = None

    def start_server_playback(
            self,
            flows,
            kill,
            headers,
            exit,
            nopop,
            ignore_params,
            ignore_content,
            ignore_payload_params,
            ignore_host):
        """
        flows: List of flows.
        kill: Boolean, should we kill requests not part of the replay?
        ignore_params: list of parameters to ignore in server replay
        ignore_content: true if request content should be ignored in
        server replay
        ignore_payload_params: list of content params to ignore in
        server replay
        ignore_host: true if request host should be ignored in server
        replay
        """
        self.server_playback = ServerPlaybackState(
            headers,
            flows,
            exit,
            nopop,
            ignore_params,
            ignore_content,
            ignore_payload_params,
            ignore_host)
        self.kill_nonreplay = kill

    def stop_server_playback(self):
        """
        End server playback, shutting down first if the playback state
        was created with exit set.
        """
        if self.server_playback.exit:
            self.shutdown()
        self.server_playback = None

    def do_server_playback(self, flow):
        """
        This method should be called by child classes in the
        handle_request handler. Returns True if playback has taken
        place, None if not.
        """
        if self.server_playback:
            rflow = self.server_playback.next_flow(flow)
            if not rflow:
                return None
            # Reply with a copy rebuilt from the recorded response state.
            response = http.HTTPResponse.from_state(
                rflow.response.get_state())
            response.is_replay = True
            if self.refresh_server_playback:
                response.refresh()
            flow.reply(response)
            if self.server_playback.count() == 0:
                self.stop_server_playback()
            return True
        return None

    def tick(self, q, timeout):
        """
        Advance client playback (if active), then delegate to the base
        class tick.
        """
        if self.client_playback:
            e = [
                self.client_playback.done(),
                self.client_playback.exit,
                self.state.active_flow_count() == 0
            ]
            if all(e):
                self.shutdown()
            self.client_playback.tick(self)
            if self.client_playback.done():
                self.client_playback = None

        return super(FlowMaster, self).tick(q, timeout)

    def duplicate_flow(self, f):
        return self.load_flow(f.copy())

    def create_request(self, method, scheme, host, port, path):
        """
        Create a new artificial, minimal request, add it to the flow
        list via load_flow() and return the resulting flow.
        """
        c = ClientConnection.from_state(dict(
            address=dict(address=(host, port), use_ipv6=False),
            clientcert=None
        ))

        s = ServerConnection.from_state(dict(
            address=dict(address=(host, port), use_ipv6=False),
            state=[],
            source_address=None,
            cert=None,
            sni=host,
            ssl_established=True
        ))
        f = http.HTTPFlow(c, s)
        # ODictCaseless is only reachable through the imported odict
        # module; the bare name is not defined in this file.
        headers = odict.ODictCaseless()

        req = http.HTTPRequest(
            "absolute",
            method,
            scheme,
            host,
            port,
            path,
            (1, 1),
            headers,
            None,
            None,
            None,
            None)
        f.request = req
        return self.load_flow(f)

    def load_flow(self, f):
        """
        Loads a flow by running it through the request, response and
        error handlers that apply to it, and returns the flow.
        """
        # NOTE(review): mode is compared with a string and then read as
        # an object with a .dst attribute - confirm against proxy config.
        if self.server and self.server.config.mode == "reverse":
            f.request.host, f.request.port = self.server.config.mode.dst[2:]
            f.request.scheme = "https" if self.server.config.mode.dst[
                1] else "http"

        f.reply = controller.DummyReply()
        if f.request:
            self.handle_request(f)
        if f.response:
            self.handle_responseheaders(f)
            self.handle_response(f)
        if f.error:
            self.handle_error(f)
        return f

    def load_flows(self, fr):
        """
        Load flows from a FlowReader object. Returns the number of flows
        loaded.
        """
        cnt = 0
        for i in fr.stream():
            cnt += 1
            self.load_flow(i)
        return cnt

    def load_flows_file(self, path):
        """
        Load all flows from the file at path (with ~ expanded) and
        return the number of flows loaded.

        Raises FlowReadError if the file cannot be opened.
        """
        path = os.path.expanduser(path)
        try:
            f = open(path, "rb")
        except IOError as v:
            raise FlowReadError(v.strerror)
        # Close the handle once all flows have been read, also on error.
        with f:
            return self.load_flows(FlowReader(f))

    def process_new_request(self, f):
        """
        Apply sticky cookies/auth, anticache/anticomp and server
        playback to a new request.
        """
        if self.stickycookie_state:
            self.stickycookie_state.handle_request(f)
        if self.stickyauth_state:
            self.stickyauth_state.handle_request(f)

        if self.anticache:
            f.request.anticache()
        if self.anticomp:
            f.request.anticomp()

        if self.server_playback:
            pb = self.do_server_playback(f)
            if not pb:
                if self.kill_nonreplay:
                    f.kill(self)
                else:
                    f.reply()

    def process_new_response(self, f):
        if self.stickycookie_state:
            self.stickycookie_state.handle_response(f)

    def replay_request(self, f, block=False, run_scripthooks=True):
        """
        Returns None if successful, or error message if not.

        block: If true, wait for the replay thread to finish.
        run_scripthooks: If false, the replay thread is started without
        the master queue, and live flows are not rejected.
        """
        if f.live and run_scripthooks:
            return "Can't replay live request."
        if f.intercepted:
            return "Can't replay while intercepting..."
        if f.request.content == http.CONTENT_MISSING:
            return "Can't replay request with missing content..."
        if f.request:
            f.backup()
            f.request.is_replay = True
            if f.request.content:
                f.request.headers[
                    "Content-Length"] = [str(len(f.request.content))]
            f.response = None
            f.error = None
            self.process_new_request(f)
            rt = http.RequestReplayThread(
                self.server.config,
                f,
                self.masterq if run_scripthooks else False,
                self.should_exit
            )
            rt.start()  # pragma: no cover
            if block:
                rt.join()

    def handle_log(self, l):
        self.add_event(l.msg, l.level)
        l.reply()

    def handle_clientconnect(self, cc):
        self.run_script_hook("clientconnect", cc)
        cc.reply()

    def handle_clientdisconnect(self, r):
        self.run_script_hook("clientdisconnect", r)
        r.reply()

    def handle_serverconnect(self, sc):
        self.run_script_hook("serverconnect", sc)
        sc.reply()

    def handle_error(self, f):
        self.state.update_flow(f)
        self.run_script_hook("error", f)
        if self.client_playback:
            self.client_playback.clear(f)
        f.reply()
        return f

    def handle_request(self, f):
        """
        Handle a new request. Live requests that match a registered WSGI
        app are served by that app and then killed; in that case nothing
        is returned. Otherwise the flow is recorded, hooks are run and
        the flow is returned.
        """
        if f.live:
            wsgi_app = self.apps.get(f.request)
            if wsgi_app:
                err = wsgi_app.serve(
                    f,
                    f.client_conn.wfile,
                    **{"mitmproxy.master": self}
                )
                if err:
                    self.add_event("Error in wsgi app. %s" % err, "error")
                f.reply(protocol.KILL)
                return
        if f not in self.state.flows:  # don't add again on replay
            self.state.add_flow(f)
        self.replacehooks.run(f)
        self.setheaders.run(f)
        self.run_script_hook("request", f)
        self.process_new_request(f)
        return f

    def handle_responseheaders(self, f):
        """
        Run the responseheaders hook and decide whether the body should
        be streamed. The flow is killed (and nothing returned) if the
        size check raises an HttpError.
        """
        self.run_script_hook("responseheaders", f)

        try:
            if self.stream_large_bodies:
                self.stream_large_bodies.run(f, False)
        except netlib.http.HttpError:
            f.reply(protocol.KILL)
            return

        f.reply()
        return f

    def handle_response(self, f):
        self.state.update_flow(f)
        self.replacehooks.run(f)
        self.setheaders.run(f)
        self.run_script_hook("response", f)
        if self.client_playback:
            self.client_playback.clear(f)
        self.process_new_response(f)
        if self.stream:
            self.stream.add(f)
        return f

    def handle_intercept(self, f):
        self.state.update_flow(f)

    def handle_accept_intercept(self, f):
        self.state.update_flow(f)

    def shutdown(self):
        """
        Unload scripts, shut down the base master and, if flows are
        being streamed to a file, write out the flows that have no
        response before closing the stream.
        """
        self.unload_scripts()
        controller.Master.shutdown(self)
        if self.stream:
            for i in self.state.flows:
                if not i.response:
                    self.stream.add(i)
            self.stop_stream()

    def start_stream(self, fp, filt):
        self.stream = FilteredFlowWriter(fp, filt)

    def stop_stream(self):
        self.stream.fo.close()
        self.stream = None
def read_flows_from_paths(paths):
    """
    Given a list of filepaths, read all flows and return a list of them.
    From a performance perspective, streaming would be advisable -
    however, if there's an error with one of the files, we want it to be
    raised immediately.

    A leading ~ in each path is expanded. If an error occurs, a
    FlowReadError will be raised.
    """
    try:
        flows = []
        for path in paths:
            path = os.path.expanduser(path)
            # open() rather than the Python-2-only file() constructor.
            with open(path, "rb") as f:
                # Consume the reader fully while the file is still open.
                flows.extend(FlowReader(f).stream())
    except IOError as e:
        raise FlowReadError(e.strerror)
    return flows
class FlowWriter:
    """
    Serializes flows to a file object as tnetstrings.
    """

    def __init__(self, fo):
        self.fo = fo

    def add(self, flow):
        """
        Write the state of flow to the file object.
        """
        state = flow.get_state()
        tnetstring.dump(state, self.fo)
class FlowReadError(Exception):
    """
    Raised when flows cannot be read from a file or stream.
    """

    @property
    def strerror(self):
        """
        The error message, i.e. the first constructor argument.
        """
        message = self.args[0]
        return message
class FlowReader:
    """
    Reads flows serialized as tnetstrings from a file object.
    """

    def __init__(self, fo):
        self.fo = fo

    def stream(self):
        """
        Yields Flow objects from the dump.

        Raises FlowReadError if an entry's major/minor version differs
        from this version's, or if the data cannot be parsed.
        """
        # File offset just after the last successfully loaded entry.
        last_good = 0
        try:
            while True:
                data = tnetstring.load(self.fo)
                if tuple(data["version"][:2]) != version.IVERSION[:2]:
                    found = ".".join(str(i) for i in data["version"])
                    raise FlowReadError(
                        "Incompatible serialized data version: %s" % found
                    )
                last_good = self.fo.tell()
                flow_class = handle.protocols[data["type"]]["flow"]
                yield flow_class.from_state(data)
        except ValueError:
            # Error is due to EOF
            if self.fo.tell() == last_good and self.fo.read() == '':
                return
            raise FlowReadError("Invalid data format.")
class FilteredFlowWriter:
    """
    Serializes flows to a file object as tnetstrings, skipping flows
    that do not match the filter (when one is set).
    """

    def __init__(self, fo, filt):
        self.fo = fo
        self.filt = filt

    def add(self, f):
        """
        Write the state of f to the file object if no filter is set or
        f matches the filter.
        """
        if not self.filt or f.match(self.filt):
            state = f.get_state()
            tnetstring.dump(state, self.fo)