mitmproxy/libmproxy/flow.py
2015-11-13 16:55:27 -05:00

1112 lines
31 KiB
Python

"""
This module provides more sophisticated flow tracking and provides filtering and interception facilities.
"""
from __future__ import absolute_import
from abc import abstractmethod, ABCMeta
import hashlib
import Cookie
import cookielib
import os
import re
import urlparse
import inspect
from netlib import wsgi
from netlib.exceptions import HttpException
from netlib.http import CONTENT_MISSING, Headers, http1
import netlib.http
from . import controller, tnetstring, filt, script, version
from .onboarding import app
from .proxy.config import HostMatcher
from .protocol.http_replay import RequestReplayThread
from .protocol import Kill
from .models import ClientConnection, ServerConnection, HTTPResponse, HTTPFlow, HTTPRequest
from . import contentviews as cv
class AppRegistry:

    """Maps (domain, port) pairs to WSGI apps served directly by the proxy."""

    def __init__(self):
        self.apps = {}

    def add(self, app, domain, port):
        """
        Register a WSGI app in the registry, to be served for requests to
        the given domain on the given port.
        """
        adaptor = wsgi.WSGIAdaptor(app, domain, port, version.NAMEVERSION)
        self.apps[(domain, port)] = adaptor

    def get(self, request):
        """
        Returns an WSGIAdaptor instance if request matches an app, or None.
        """
        key = (request.host, request.port)
        if key in self.apps:
            return self.apps[key]
        # Fall back to matching on the Host header, if present.
        if "host" in request.headers:
            return self.apps.get((request.headers["host"], request.port), None)
class ReplaceHooks:

    """Filter-scoped regex replacement hooks applied to request/response bodies."""

    def __init__(self):
        self.lst = []

    def set(self, r):
        self.clear()
        for spec in r:
            self.add(*spec)

    def add(self, fpatt, rex, s):
        """
        add a replacement hook.

        fpatt: a string specifying a filter pattern.
        rex: a regular expression.
        s: the replacement string

        returns true if hook was added, false if the pattern could not be
        parsed.
        """
        cpatt = filt.parse(fpatt)
        if not cpatt:
            return False
        try:
            re.compile(rex)
        except re.error:
            return False
        self.lst.append((fpatt, rex, s, cpatt))
        return True

    def get_specs(self):
        """
        Retrieve the hook specifications as a list of (fpatt, rex, s) tuples.
        """
        return [spec[:3] for spec in self.lst]

    def count(self):
        return len(self.lst)

    def run(self, f):
        # Apply every hook whose filter matches; prefer the response once
        # the flow has one, otherwise rewrite the request.
        for _, rex, s, cpatt in self.lst:
            if not cpatt(f):
                continue
            target = f.response if f.response else f.request
            target.replace(rex, s)

    def clear(self):
        self.lst = []
class SetHeaders:

    """Filter-scoped hooks that force specific header values on flows."""

    def __init__(self):
        self.lst = []

    def set(self, r):
        self.clear()
        for spec in r:
            self.add(*spec)

    def add(self, fpatt, header, value):
        """
        Add a set header hook.

        fpatt: String specifying a filter pattern.
        header: Header name.
        value: Header value string

        Returns True if hook was added, False if the pattern could not be
        parsed.
        """
        cpatt = filt.parse(fpatt)
        if not cpatt:
            return False
        self.lst.append((fpatt, header, value, cpatt))
        return True

    def get_specs(self):
        """
        Retrieve the hook specifications as a list of (fpatt, header, value)
        tuples.
        """
        return [spec[:3] for spec in self.lst]

    def count(self):
        return len(self.lst)

    def clear(self):
        self.lst = []

    def run(self, f):
        # First pass: strip any headers that a matching hook will set, so
        # configured values replace rather than accumulate.
        for _, header, value, cpatt in self.lst:
            if cpatt(f):
                message = f.response if f.response else f.request
                message.headers.pop(header, None)
        # Second pass: append the configured header values.
        for _, header, value, cpatt in self.lst:
            if cpatt(f):
                message = f.response if f.response else f.request
                message.headers.fields.append((header, value))
class StreamLargeBodies(object):

    """Enables streaming for message bodies exceeding a size threshold."""

    def __init__(self, max_size):
        self.max_size = max_size

    def run(self, flow, is_request):
        if is_request:
            r, response = flow.request, None
        else:
            r, response = flow.response, flow.response
        expected_size = http1.expected_http_body_size(flow.request, response)
        if not (0 <= expected_size <= self.max_size):
            # r.stream may already be a callable, which we want to preserve.
            r.stream = r.stream or True
class ClientPlaybackState:

    """Tracks progress of a client-side replay of recorded flows."""

    def __init__(self, flows, exit):
        self.flows, self.exit = flows, exit
        self.current = None
        self.testing = False  # Disables actual replay for testing.

    def count(self):
        return len(self.flows)

    def done(self):
        return not self.flows and not self.current

    def clear(self, flow):
        """
        A request has returned in some way - if this is the one we're
        servicing, go to the next flow.
        """
        if flow is self.current:
            self.current = None

    def tick(self, master):
        if not self.flows or self.current:
            return
        self.current = self.flows.pop(0).copy()
        if self.testing:
            self.current.reply = controller.DummyReply()
            master.handle_request(self.current)
            if self.current.response:
                master.handle_response(self.current)
        else:
            master.replay_request(self.current)
class ServerPlaybackState:

    """Matches incoming requests against recorded flows for server replay."""

    def __init__(self, headers, flows, exit, nopop, ignore_params,
                 ignore_content, ignore_payload_params, ignore_host):
        """
        headers: Case-insensitive list of request headers that should be
        included in request-response matching.
        """
        self.headers = headers
        self.exit = exit
        self.nopop = nopop
        self.ignore_params = ignore_params
        self.ignore_content = ignore_content
        self.ignore_payload_params = ignore_payload_params
        self.ignore_host = ignore_host
        # Map of request hash -> list of recorded flows with that hash.
        self.fmap = {}
        for flow in flows:
            if flow.response:
                self.fmap.setdefault(self._hash(flow), []).append(flow)

    def count(self):
        return sum(len(v) for v in self.fmap.values())

    def _hash(self, flow):
        """
        Calculates a loose hash of the flow request.
        """
        r = flow.request
        _, _, path, _, query, _ = urlparse.urlparse(r.url)
        key = [str(r.port), str(r.scheme), str(r.method), str(path)]
        if not self.ignore_content:
            form_contents = r.urlencoded_form or r.multipart_form
            if self.ignore_payload_params and form_contents:
                key.extend(
                    p for p in form_contents
                    if p[0] not in self.ignore_payload_params
                )
            else:
                key.append(str(r.content))
        if not self.ignore_host:
            key.append(r.host)
        ignored = self.ignore_params or []
        for name, value in urlparse.parse_qsl(query, keep_blank_values=True):
            if name not in ignored:
                key.append(name)
                key.append(value)
        if self.headers:
            key.append([(h, r.headers.get(h)) for h in self.headers])
        return hashlib.sha256(repr(key)).digest()

    def next_flow(self, request):
        """
        Returns the next flow object, or None if no matching flow was
        found.
        """
        matches = self.fmap.get(self._hash(request))
        if not matches:
            return None
        return matches[0] if self.nopop else matches.pop(0)
class StickyCookieState:

    """Re-attaches previously seen cookies to matching outgoing requests."""

    def __init__(self, flt):
        """
        flt: Compiled filter.
        """
        self.jar = {}  # (domain, port, path) -> cookie morsel
        self.flt = flt

    def ckey(self, m, f):
        """
        Returns a (domain, port, path) tuple.
        """
        domain = m["domain"] or f.request.host
        path = m["path"] or "/"
        return (domain, f.request.port, path)

    def domain_match(self, a, b):
        # Also try with the dots stripped, since stored cookie domains may
        # carry a leading dot.
        return bool(
            cookielib.domain_match(a, b) or
            cookielib.domain_match(a, b.strip("."))
        )

    def handle_response(self, f):
        for header in f.response.headers.get_all("set-cookie"):
            # FIXME: We now know that Cookie.py screws up some cookies with
            # valid RFC 822/1123 datetime specifications for expiry. Sigh.
            c = Cookie.SimpleCookie(str(header))
            for m in c.values():
                k = self.ckey(m, f)
                if self.domain_match(f.request.host, k[0]):
                    self.jar[k] = m

    def handle_request(self, f):
        matched = []
        if f.match(self.flt):
            for k in self.jar.keys():
                domain_ok = self.domain_match(f.request.host, k[0])
                port_ok = f.request.port == k[1]
                path_ok = f.request.path.startswith(k[2])
                if domain_ok and port_ok and path_ok:
                    matched.append(self.jar[k].output(header="").strip())
        if matched:
            f.request.stickycookie = True
            f.request.headers.set_all("cookie", matched)
class StickyAuthState:

    """Remembers Authorization headers per host and re-applies them."""

    def __init__(self, flt):
        """
        flt: Compiled filter.
        """
        self.flt = flt
        self.hosts = {}  # host -> last seen Authorization header value

    def handle_request(self, f):
        host = f.request.host
        headers = f.request.headers
        if "authorization" in headers:
            # Record the credentials the client just sent.
            self.hosts[host] = headers["authorization"]
        elif f.match(self.flt) and host in self.hosts:
            # Re-apply previously seen credentials for this host.
            headers["authorization"] = self.hosts[host]
class FlowList(object):

    """
    Sequence-like base class for flow collections backed by self._list.
    Subclasses must provide the _add/_update/_remove mutation hooks.
    """

    __metaclass__ = ABCMeta

    def __iter__(self):
        return iter(self._list)

    def __contains__(self, item):
        return item in self._list

    def __getitem__(self, item):
        return self._list[item]

    def __nonzero__(self):
        # Python 2 truthiness: non-empty list means truthy.
        return bool(self._list)

    def __len__(self):
        return len(self._list)

    def index(self, f):
        return self._list.index(f)

    @abstractmethod
    def _add(self, f):
        """Insert a flow into the collection."""
        pass

    @abstractmethod
    def _update(self, f):
        """React to a flow in the collection having changed."""
        pass

    @abstractmethod
    def _remove(self, f):
        """Delete a flow from the collection."""
        pass
class FlowView(FlowList):

    """A filtered, live-updating view onto a FlowStore."""

    def __init__(self, store, filt=None):
        self._list = []
        # With no filter, the view shows everything.
        self._build(store, filt or (lambda flow: True))
        self.store = store
        self.store.views.append(self)

    def _close(self):
        # Detach from the store so we no longer receive updates.
        self.store.views.remove(self)

    def _build(self, flows, filt=None):
        if filt:
            self.filt = filt
        self._list = [f for f in flows if self.filt(f)]

    def _add(self, f):
        if self.filt(f):
            self._list.append(f)

    def _update(self, f):
        if f not in self._list:
            self._add(f)
        elif not self.filt(f):
            self._remove(f)

    def _remove(self, f):
        if f in self._list:
            self._list.remove(f)

    def _recalculate(self, flows):
        self._build(flows)
class FlowStore(FlowList):

    """
    Responsible for handling flows in the state:
    Keeps a list of all flows and provides views on them.
    """

    def __init__(self):
        self._list = []
        self._set = set()  # mirror of _list for O(1) membership checks
        self.views = []
        self._recalculate_views()

    def get(self, flow_id):
        """Return the flow with the given id, or None if absent."""
        for flow in self._list:
            if flow.id == flow_id:
                return flow

    def __contains__(self, f):
        return f in self._set

    def _add(self, f):
        """
        Adds a flow to the state.
        The flow to add must not be present in the state.
        """
        self._list.append(f)
        self._set.add(f)
        for view in self.views:
            view._add(f)

    def _update(self, f):
        """
        Notifies the state that a flow has been updated.
        The flow must be present in the state.
        """
        if f in self:
            for view in self.views:
                view._update(f)

    def _remove(self, f):
        """
        Deletes a flow from the state.
        The flow must be present in the state.
        """
        self._list.remove(f)
        self._set.remove(f)
        for view in self.views:
            view._remove(f)

    # Expensive bulk operations

    def _extend(self, flows):
        """
        Adds a list of flows to the state.
        The list of flows to add must not contain flows that are already in the state.
        """
        self._list.extend(flows)
        self._set.update(flows)
        self._recalculate_views()

    def _clear(self):
        """Remove every flow and refresh all views."""
        self._list = []
        self._set = set()
        self._recalculate_views()

    def _recalculate_views(self):
        """
        Expensive operation: Recalculate all the views after a bulk change.
        """
        for view in self.views:
            view._recalculate(self)

    # Utility functions.
    # There are some common cases where we need to argue about all flows
    # irrespective of filters on the view etc (i.e. on shutdown).

    def active_count(self):
        """Number of flows that have neither a response nor an error yet."""
        return sum(
            1 for flow in self._list
            if not flow.response and not flow.error
        )

    # TODO: Should accept_all operate on views or on all flows?
    def accept_all(self, master):
        for flow in self._list:
            flow.accept_intercept(master)

    def kill_all(self, master):
        for flow in self._list:
            flow.kill(master)
class State(object):

    """Aggregates the flow store, the current limit view and the intercept filter."""

    def __init__(self):
        self.flows = FlowStore()
        self.view = FlowView(self.flows, None)

        # These are compiled filt expressions:
        self.intercept = None

    @property
    def limit_txt(self):
        return getattr(self.view.filt, "pattern", None)

    def flow_count(self):
        return len(self.flows)

    # TODO: All functions regarding flows that don't cause side-effects should
    # be moved into FlowStore.
    def index(self, f):
        return self.flows.index(f)

    def active_flow_count(self):
        return self.flows.active_count()

    def add_flow(self, f):
        """
        Add a request to the state.
        """
        self.flows._add(f)
        return f

    def update_flow(self, f):
        """
        Add a response to the state.
        """
        self.flows._update(f)
        return f

    def delete_flow(self, f):
        self.flows._remove(f)

    def load_flows(self, flows):
        self.flows._extend(flows)

    def set_limit(self, txt):
        if txt == self.limit_txt:
            return
        new_filt = None
        if txt:
            new_filt = filt.parse(txt)
            if not new_filt:
                return "Invalid filter expression."
        # Swap the current view for one built from the new filter.
        self.view._close()
        self.view = FlowView(self.flows, new_filt)

    def set_intercept(self, txt):
        if not txt:
            self.intercept = None
            return
        compiled = filt.parse(txt)
        if not compiled:
            return "Invalid filter expression."
        self.intercept = compiled

    @property
    def intercept_txt(self):
        return getattr(self.intercept, "pattern", None)

    def clear(self):
        self.flows._clear()

    def accept_all(self, master):
        self.flows.accept_all(master)

    def backup(self, f):
        f.backup()
        self.update_flow(f)

    def revert(self, f):
        f.revert()
        self.update_flow(f)

    def killall(self, master):
        self.flows.kill_all(master)
class FlowMaster(controller.Master):

    """
    Controller master that ties together flow state, loaded scripts,
    client/server replay, sticky cookies/auth, replacement and header
    hooks, and optional flow streaming for the proxy event loop.
    """

    def __init__(self, server, state):
        controller.Master.__init__(self, server)
        self.state = state
        self.server_playback = None  # ServerPlaybackState or None
        self.client_playback = None  # ClientPlaybackState or None
        self.kill_nonreplay = False  # kill requests with no replay match
        self.scripts = []            # loaded script.Script objects
        self.pause_scripts = False   # suppresses all script hooks when set

        self.stickycookie_state = False
        self.stickycookie_txt = None

        self.stickyauth_state = False
        self.stickyauth_txt = None

        self.anticache = False
        self.anticomp = False
        self.stream_large_bodies = False  # StreamLargeBodies or False
        self.refresh_server_playback = False
        self.replacehooks = ReplaceHooks()
        self.setheaders = SetHeaders()
        self.replay_ignore_params = False
        self.replay_ignore_content = None
        self.replay_ignore_host = False

        self.stream = None  # FilteredFlowWriter when streaming to disk
        self.apps = AppRegistry()

    def start_app(self, host, port):
        """Serve the onboarding app for requests to host:port."""
        self.apps.add(
            app.mapp,
            host,
            port
        )

    def add_event(self, e, level="info"):
        """
        Log an event. No-op here; subclasses are expected to override.

        level: debug, info, error
        """
        pass

    def unload_scripts(self):
        # Iterate over a copy: unload_script mutates self.scripts.
        for s in self.scripts[:]:
            self.unload_script(s)

    def unload_script(self, script_obj):
        """Unload a single script, reporting (not raising) script errors."""
        try:
            script_obj.unload()
        except script.ScriptError as e:
            self.add_event("Script error:\n" + str(e), "error")
        self.scripts.remove(script_obj)

    def load_script(self, command):
        """
        Loads a script. Returns an error description if something went
        wrong.
        """
        try:
            s = script.Script(command, self)
        except script.ScriptError as v:
            return v.args[0]
        self.scripts.append(s)

    def _run_single_script_hook(self, script_obj, name, *args, **kwargs):
        # Hooks are skipped entirely while pause_scripts is set.
        if script_obj and not self.pause_scripts:
            try:
                script_obj.run(name, *args, **kwargs)
            except script.ScriptError as e:
                self.add_event("Script error:\n" + str(e), "error")

    def run_script_hook(self, name, *args, **kwargs):
        """Run the named hook on every loaded script, in load order."""
        for script_obj in self.scripts:
            self._run_single_script_hook(script_obj, name, *args, **kwargs)

    def get_ignore_filter(self):
        """Return the current list of ignore host patterns."""
        return self.server.config.check_ignore.patterns

    def set_ignore_filter(self, host_patterns):
        """Replace the ignore host matcher with new patterns."""
        self.server.config.check_ignore = HostMatcher(host_patterns)

    def get_tcp_filter(self):
        """Return the current list of raw-TCP passthrough patterns."""
        return self.server.config.check_tcp.patterns

    def set_tcp_filter(self, host_patterns):
        """Replace the raw-TCP passthrough matcher with new patterns."""
        self.server.config.check_tcp = HostMatcher(host_patterns)

    def set_stickycookie(self, txt):
        """
        Enable sticky cookies for flows matching the filter expression txt,
        or disable them if txt is falsy. Returns an error string if the
        expression does not parse.
        """
        if txt:
            flt = filt.parse(txt)
            if not flt:
                return "Invalid filter expression."
            self.stickycookie_state = StickyCookieState(flt)
            self.stickycookie_txt = txt
        else:
            self.stickycookie_state = None
            self.stickycookie_txt = None

    def set_stream_large_bodies(self, max_size):
        """Stream message bodies larger than max_size bytes; None disables."""
        if max_size is not None:
            self.stream_large_bodies = StreamLargeBodies(max_size)
        else:
            self.stream_large_bodies = False

    def set_stickyauth(self, txt):
        """
        Enable sticky auth for flows matching the filter expression txt,
        or disable it if txt is falsy. Returns an error string if the
        expression does not parse.
        """
        if txt:
            flt = filt.parse(txt)
            if not flt:
                return "Invalid filter expression."
            self.stickyauth_state = StickyAuthState(flt)
            self.stickyauth_txt = txt
        else:
            self.stickyauth_state = None
            self.stickyauth_txt = None

    def start_client_playback(self, flows, exit):
        """
        flows: List of flows.
        exit: shut down the master when playback completes.
        """
        self.client_playback = ClientPlaybackState(flows, exit)

    def stop_client_playback(self):
        self.client_playback = None

    def start_server_playback(
            self,
            flows,
            kill,
            headers,
            exit,
            nopop,
            ignore_params,
            ignore_content,
            ignore_payload_params,
            ignore_host):
        """
        flows: List of flows.
        kill: Boolean, should we kill requests not part of the replay?
        ignore_params: list of parameters to ignore in server replay
        ignore_content: true if request content should be ignored in server replay
        ignore_payload_params: list of content params to ignore in server replay
        ignore_host: true if request host should be ignored in server replay
        """
        self.server_playback = ServerPlaybackState(
            headers,
            flows,
            exit,
            nopop,
            ignore_params,
            ignore_content,
            ignore_payload_params,
            ignore_host)
        self.kill_nonreplay = kill

    def stop_server_playback(self):
        # If playback was started with exit=True, shut the master down too.
        if self.server_playback.exit:
            self.shutdown()
        self.server_playback = None

    def do_server_playback(self, flow):
        """
        This method should be called by child classes in the handle_request
        handler. Returns True if playback has taken place, None if not.
        """
        if self.server_playback:
            rflow = self.server_playback.next_flow(flow)
            if not rflow:
                return None
            # Reply with a copy of the recorded response.
            response = HTTPResponse.from_state(rflow.response.get_state())
            response.is_replay = True
            if self.refresh_server_playback:
                response.refresh()
            flow.reply(response)
            if self.server_playback.count() == 0:
                self.stop_server_playback()
            return True
        return None

    def tick(self, q, timeout):
        if self.client_playback:
            e = [
                self.client_playback.done(),
                self.client_playback.exit,
                self.state.active_flow_count() == 0
            ]
            # Shut down once playback is complete, exit was requested, and
            # no flows are still in flight.
            if all(e):
                self.shutdown()
            self.client_playback.tick(self)
            if self.client_playback.done():
                self.client_playback = None

        return super(FlowMaster, self).tick(q, timeout)

    def duplicate_flow(self, f):
        """Copy a flow and run the copy through the usual load handlers."""
        return self.load_flow(f.copy())

    def create_request(self, method, scheme, host, port, path):
        """
        Create a new artificial and minimalist request, and add it to the
        flow list via load_flow.
        """
        c = ClientConnection.from_state(dict(
            address=dict(address=(host, port), use_ipv6=False),
            clientcert=None
        ))

        s = ServerConnection.from_state(dict(
            address=dict(address=(host, port), use_ipv6=False),
            state=[],
            source_address=None,
            # source_address=dict(address=(host, port), use_ipv6=False),
            cert=None,
            sni=host,
            ssl_established=True
        ))
        f = HTTPFlow(c, s)
        headers = Headers()

        req = HTTPRequest(
            "absolute",
            method,
            scheme,
            host,
            port,
            path,
            b"HTTP/1.1",
            headers,
            None,
            None,
            None,
            None)
        f.request = req
        return self.load_flow(f)

    def load_flow(self, f):
        """
        Loads a flow, and returns a new flow object.
        """
        if self.server and self.server.config.mode == "reverse":
            # Rewrite the request to point at the reverse-proxy upstream.
            f.request.host = self.server.config.upstream_server.address.host
            f.request.port = self.server.config.upstream_server.address.port
            # e.g. "https2http" -> "http": strip the client-facing scheme prefix.
            f.request.scheme = re.sub("^https?2", "", self.server.config.upstream_server.scheme)

        f.reply = controller.DummyReply()
        if f.request:
            self.handle_request(f)
        if f.response:
            self.handle_responseheaders(f)
            self.handle_response(f)
        if f.error:
            self.handle_error(f)
        return f

    def load_flows(self, fr):
        """
        Load flows from a FlowReader object. Returns the number of flows loaded.
        """
        cnt = 0
        for i in fr.stream():
            cnt += 1
            self.load_flow(i)
        return cnt

    def load_flows_file(self, path):
        """Load flows from a dump file; raises FlowReadError on I/O errors."""
        path = os.path.expanduser(path)
        try:
            # NOTE: file() is the Python 2 builtin alias for open().
            f = file(path, "rb")
            freader = FlowReader(f)
        except IOError as v:
            raise FlowReadError(v.strerror)
        return self.load_flows(freader)

    def process_new_request(self, f):
        """Apply sticky state, anticache/anticomp, and server playback to a request."""
        if self.stickycookie_state:
            self.stickycookie_state.handle_request(f)
        if self.stickyauth_state:
            self.stickyauth_state.handle_request(f)

        if self.anticache:
            f.request.anticache()
        if self.anticomp:
            f.request.anticomp()

        if self.server_playback:
            pb = self.do_server_playback(f)
            if not pb:
                # No recorded response matched: kill or pass through.
                if self.kill_nonreplay:
                    f.kill(self)
                else:
                    f.reply()

    def process_new_response(self, f):
        if self.stickycookie_state:
            self.stickycookie_state.handle_response(f)

    def replay_request(self, f, block=False, run_scripthooks=True):
        """
        Returns None if successful, or error message if not.
        """
        if f.live and run_scripthooks:
            return "Can't replay live request."
        if f.intercepted:
            return "Can't replay while intercepting..."
        if f.request.content == CONTENT_MISSING:
            return "Can't replay request with missing content..."
        if f.request:
            f.backup()
            f.request.is_replay = True
            if "Content-Length" in f.request.headers:
                # Recompute in case the content was modified since capture.
                f.request.headers["Content-Length"] = str(len(f.request.content))
            f.response = None
            f.error = None
            self.process_new_request(f)
            rt = RequestReplayThread(
                self.server.config,
                f,
                # Passing False disables script hook dispatch in the thread.
                self.masterq if run_scripthooks else False,
                self.should_exit
            )
            rt.start()  # pragma: no cover
            if block:
                rt.join()

    def handle_log(self, l):
        self.add_event(l.msg, l.level)
        l.reply()

    def handle_clientconnect(self, root_layer):
        self.run_script_hook("clientconnect", root_layer)
        root_layer.reply()

    def handle_clientdisconnect(self, root_layer):
        self.run_script_hook("clientdisconnect", root_layer)
        root_layer.reply()

    def handle_serverconnect(self, server_conn):
        self.run_script_hook("serverconnect", server_conn)
        server_conn.reply()

    def handle_serverdisconnect(self, server_conn):
        self.run_script_hook("serverdisconnect", server_conn)
        server_conn.reply()

    def handle_next_layer(self, top_layer):
        self.run_script_hook("next_layer", top_layer)
        top_layer.reply()

    def handle_error(self, f):
        self.state.update_flow(f)
        self.run_script_hook("error", f)
        if self.client_playback:
            self.client_playback.clear(f)
        f.reply()
        return f

    def handle_request(self, f):
        if f.live:
            # Serve registered WSGI apps directly, short-circuiting the proxy.
            app = self.apps.get(f.request)
            if app:
                err = app.serve(
                    f,
                    f.client_conn.wfile,
                    **{"mitmproxy.master": self}
                )
                if err:
                    self.add_event("Error in wsgi app. %s" % err, "error")
                f.reply(Kill)
                return
        if f not in self.state.flows:  # don't add again on replay
            self.state.add_flow(f)
        self.replacehooks.run(f)
        self.setheaders.run(f)
        self.run_script_hook("request", f)
        self.process_new_request(f)
        return f

    def handle_responseheaders(self, f):
        self.run_script_hook("responseheaders", f)
        try:
            if self.stream_large_bodies:
                self.stream_large_bodies.run(f, False)
        except HttpException:
            f.reply(Kill)
            return
        f.reply()
        return f

    def handle_response(self, f):
        self.state.update_flow(f)
        self.replacehooks.run(f)
        self.setheaders.run(f)
        self.run_script_hook("response", f)
        if self.client_playback:
            self.client_playback.clear(f)
        self.process_new_response(f)
        if self.stream:
            self.stream.add(f)
        return f

    def handle_intercept(self, f):
        self.state.update_flow(f)

    def handle_accept_intercept(self, f):
        self.state.update_flow(f)

    def shutdown(self):
        self.unload_scripts()
        controller.Master.shutdown(self)
        if self.stream:
            # Flush flows that never got a response into the stream before
            # closing it.
            for i in self.state.flows:
                if not i.response:
                    self.stream.add(i)
            self.stop_stream()

    def start_stream(self, fp, filt):
        self.stream = FilteredFlowWriter(fp, filt)

    def stop_stream(self):
        self.stream.fo.close()
        self.stream = None
def read_flows_from_paths(paths):
    """
    Given a list of filepaths, read all flows and return a list of them.
    From a performance perspective, streaming would be advisable -
    however, if there's an error with one of the files, we want it to be raised immediately.

    paths: list of filesystem paths; "~" is expanded in each.

    If an error occurs, a FlowReadError will be raised.
    """
    flows = []
    try:
        for path in paths:
            path = os.path.expanduser(path)
            # Use open() rather than the deprecated Python 2 `file` builtin;
            # open() behaves identically here and keeps the code portable.
            with open(path, "rb") as f:
                flows.extend(FlowReader(f).stream())
    except IOError as e:
        raise FlowReadError(e.strerror)
    return flows
class FlowWriter:

    """Serializes flows to a file-like object as tnetstrings."""

    def __init__(self, fo):
        self.fo = fo

    def add(self, flow):
        # Persist the flow's complete state dictionary.
        tnetstring.dump(flow.get_state(), self.fo)
class FlowReadError(Exception):

    """Raised when serialized flow data cannot be read or parsed."""

    @property
    def strerror(self):
        # Mirror the IOError/OSError attribute of the same name.
        return self.args[0]
class FlowReader:

    """Deserializes flows from a tnetstring dump file."""

    def __init__(self, fo):
        self.fo = fo

    def stream(self):
        """
        Yields Flow objects from the dump.
        """
        last_ok = 0
        try:
            while True:
                data = tnetstring.load(self.fo)
                if tuple(data["version"][:2]) != version.IVERSION[:2]:
                    pretty = ".".join(str(part) for part in data["version"])
                    raise FlowReadError(
                        "Incompatible serialized data version: %s" % pretty
                    )
                last_ok = self.fo.tell()
                yield HTTPFlow.from_state(data)
        except ValueError:
            # A ValueError at the position of the last complete record with
            # nothing left to read is a clean EOF; anything else is corrupt.
            at_eof = self.fo.tell() == last_ok and self.fo.read() == ''
            if not at_eof:
                raise FlowReadError("Invalid data format.")
class FilteredFlowWriter:

    """Like FlowWriter, but only writes flows matching a filter (or all if None)."""

    def __init__(self, fo, filt):
        self.fo = fo
        self.filt = filt

    def add(self, f):
        if self.filt and not f.match(self.filt):
            return
        tnetstring.dump(f.get_state(), self.fo)