Merge remote-tracking branch 'origin/master' into flow_editing_v2

This commit is contained in:
Maximilian Hils 2016-07-25 15:16:16 -07:00
commit 79ebcb046e
73 changed files with 931 additions and 739 deletions

View File

@ -20,10 +20,10 @@ matrix:
include:
- python: 3.5
env: TOXENV=lint
# - os: osx
# osx_image: xcode7.3
# language: generic
# env: TOXENV=py35
- os: osx
osx_image: xcode7.3
language: generic
env: TOXENV=py35
- python: 3.5
env: TOXENV=py35
- python: 3.5

View File

@ -1,18 +1,20 @@
# This scripts demonstrates how to use mitmproxy's filter pattern in inline scripts.
# This scripts demonstrates how to use mitmproxy's filter pattern in scripts.
# Usage: mitmdump -s "filt.py FILTER"
import sys
from mitmproxy import filt
state = {}
class Filter:
def __init__(self, spec):
self.filter = filt.parse(spec)
def response(self, flow):
if flow.match(self.filter):
print("Flow matches filter:")
print(flow)
def start():
if len(sys.argv) != 2:
raise ValueError("Usage: -s 'filt.py FILTER'")
state["filter"] = filt.parse(sys.argv[1])
def response(flow):
if flow.match(state["filter"]):
print("Flow matches filter:")
print(flow)
return Filter(sys.argv[1])

View File

@ -3,20 +3,21 @@ import sys
from mitmproxy.flow import FlowWriter
state = {}
class Writer:
def __init__(self, path):
if path == "-":
f = sys.stdout
else:
f = open(path, "wb")
self.w = FlowWriter(f)
def response(self, flow):
if random.choice([True, False]):
self.w.add(flow)
def start():
if len(sys.argv) != 2:
raise ValueError('Usage: -s "flowriter.py filename"')
if sys.argv[1] == "-":
f = sys.stdout
else:
f = open(sys.argv[1], "wb")
state["flow_writer"] = FlowWriter(f)
def response(flow):
if random.choice([True, False]):
state["flow_writer"].add(flow)
return Writer(sys.argv[1])

View File

@ -3,26 +3,27 @@
import sys
from bs4 import BeautifulSoup
iframe_url = None
class Injector:
def __init__(self, iframe_url):
self.iframe_url = iframe_url
def response(self, flow):
if flow.request.host in self.iframe_url:
return
html = BeautifulSoup(flow.response.content, "lxml")
if html.body:
iframe = html.new_tag(
"iframe",
src=self.iframe_url,
frameborder=0,
height=0,
width=0)
html.body.insert(0, iframe)
flow.response.content = str(html).encode("utf8")
def start():
if len(sys.argv) != 2:
raise ValueError('Usage: -s "iframe_injector.py url"')
global iframe_url
iframe_url = sys.argv[1]
def response(flow):
if flow.request.host in iframe_url:
return
html = BeautifulSoup(flow.response.content, "lxml")
if html.body:
iframe = html.new_tag(
"iframe",
src=iframe_url,
frameborder=0,
height=0,
width=0)
html.body.insert(0, iframe)
flow.response.content = str(html).encode("utf8")
return Injector(sys.argv[1])

19
examples/remote_debug.py Normal file
View File

@ -0,0 +1,19 @@
"""
This script enables remote debugging of the mitmproxy *UI* with PyCharm.
For general debugging purposes, it is easier to just debug mitmdump within PyCharm.
Usage:
- pip install pydevd on the mitmproxy machine
- Open the Run/Debug Configuration dialog box in PyCharm, and select the Python Remote Debug configuration type.
- Debugging works in the way that mitmproxy connects to the debug server on startup.
Specify host and port that mitmproxy can use to reach your PyCharm instance on startup.
- Adjust this inline script accordingly.
- Start debug server in PyCharm
- Set breakpoints
- Start mitmproxy -s remote_debug.py
"""
def start():
import pydevd
pydevd.settrace("localhost", port=5678, stdoutToServer=True, stderrToServer=True)

View File

@ -11,7 +11,7 @@ def start():
mitmproxy.ctx.log("start")
def configure(options):
def configure(options, updated):
"""
Called once on script startup before any other events, and whenever options changes.
"""

View File

@ -13,16 +13,23 @@ class Addons(object):
self.master = master
master.options.changed.connect(self.options_update)
def options_update(self, options):
def options_update(self, options, updated):
for i in self.chain:
with self.master.handlecontext():
i.configure(options)
i.configure(options, updated)
def add(self, *addons):
def add(self, options, *addons):
if not addons:
raise ValueError("No adons specified.")
self.chain.extend(addons)
for i in addons:
self.invoke_with_context(i, "start")
self.invoke_with_context(i, "configure", self.master.options)
self.invoke_with_context(
i,
"configure",
self.master.options,
self.master.options.keys()
)
def remove(self, addon):
self.chain = [i for i in self.chain if i is not addon]

View File

@ -5,7 +5,7 @@ class AntiCache:
def __init__(self):
self.enabled = False
def configure(self, options):
def configure(self, options, updated):
self.enabled = options.anticache
def request(self, flow):

View File

@ -5,7 +5,7 @@ class AntiComp:
def __init__(self):
self.enabled = False
def configure(self, options):
def configure(self, options, updated):
self.enabled = options.anticomp
def request(self, flow):

View File

@ -5,6 +5,8 @@ import traceback
import click
import typing # noqa
from mitmproxy import contentviews
from mitmproxy import ctx
from mitmproxy import exceptions
@ -19,12 +21,25 @@ def indent(n, text):
return "\n".join(pad + i for i in l)
class Dumper():
class Dumper(object):
def __init__(self):
self.filter = None
self.flow_detail = None
self.outfp = None
self.showhost = None
self.filter = None # type: filt.TFilter
self.flow_detail = None # type: int
self.outfp = None # type: typing.io.TextIO
self.showhost = None # type: bool
def configure(self, options, updated):
if options.filtstr:
self.filter = filt.parse(options.filtstr)
if not self.filter:
raise exceptions.OptionsError(
"Invalid filter expression: %s" % options.filtstr
)
else:
self.filter = None
self.flow_detail = options.flow_detail
self.outfp = options.tfile
self.showhost = options.showhost
def echo(self, text, ident=None, **style):
if ident:
@ -59,7 +74,7 @@ class Dumper():
self.echo("")
try:
type, lines = contentviews.get_content_view(
_, lines = contentviews.get_content_view(
contentviews.get("Auto"),
content,
headers=getattr(message, "headers", None)
@ -67,7 +82,7 @@ class Dumper():
except exceptions.ContentViewException:
s = "Content viewer failed: \n" + traceback.format_exc()
ctx.log.debug(s)
type, lines = contentviews.get_content_view(
_, lines = contentviews.get_content_view(
contentviews.get("Raw"),
content,
headers=getattr(message, "headers", None)
@ -114,9 +129,8 @@ class Dumper():
if flow.client_conn:
client = click.style(
strutils.escape_control_characters(
flow.client_conn.address.host
),
bold=True
repr(flow.client_conn.address)
)
)
elif flow.request.is_replay:
client = click.style("[replay]", fg="yellow", bold=True)
@ -139,17 +153,23 @@ class Dumper():
url = flow.request.url
url = click.style(strutils.escape_control_characters(url), bold=True)
httpversion = ""
http_version = ""
if flow.request.http_version not in ("HTTP/1.1", "HTTP/1.0"):
# We hide "normal" HTTP 1.
httpversion = " " + flow.request.http_version
http_version = " " + flow.request.http_version
line = "{stickycookie}{client} {method} {url}{httpversion}".format(
stickycookie=stickycookie,
if self.flow_detail >= 2:
linebreak = "\n "
else:
linebreak = ""
line = "{client}: {linebreak}{stickycookie}{method} {url}{http_version}".format(
client=client,
stickycookie=stickycookie,
linebreak=linebreak,
method=method,
url=url,
httpversion=httpversion
http_version=http_version
)
self.echo(line)
@ -185,9 +205,14 @@ class Dumper():
size = human.pretty_size(len(flow.response.raw_content))
size = click.style(size, bold=True)
arrows = click.style(" <<", bold=True)
arrows = click.style(" <<", bold=True)
if self.flow_detail == 1:
# This aligns the HTTP response code with the HTTP request method:
# 127.0.0.1:59519: GET http://example.com/
# << 304 Not Modified 0b
arrows = " " * (len(repr(flow.client_conn.address)) - 2) + arrows
line = "{replay} {arrows} {code} {reason} {size}".format(
line = "{replay}{arrows} {code} {reason} {size}".format(
replay=replay,
arrows=arrows,
code=code,
@ -211,25 +236,12 @@ class Dumper():
def match(self, f):
if self.flow_detail == 0:
return False
if not self.filt:
if not self.filter:
return True
elif f.match(self.filt):
elif f.match(self.filter):
return True
return False
def configure(self, options):
if options.filtstr:
self.filt = filt.parse(options.filtstr)
if not self.filt:
raise exceptions.OptionsError(
"Invalid filter expression: %s" % options.filtstr
)
else:
self.filt = None
self.flow_detail = options.flow_detail
self.outfp = options.tfile
self.showhost = options.showhost
def response(self, f):
if self.match(f):
self.echo_flow(f)
@ -239,8 +251,7 @@ class Dumper():
self.echo_flow(f)
def tcp_message(self, f):
# FIXME: Filter should be applied here
if self.options.flow_detail == 0:
if not self.match(f):
return
message = f.messages[-1]
direction = "->" if message.from_client else "<-"

View File

@ -19,7 +19,7 @@ class FileStreamer:
self.stream = io.FilteredFlowWriter(f, filt)
self.active_flows = set()
def configure(self, options):
def configure(self, options, updated):
# We're already streaming - stop the previous stream and restart
if self.stream:
self.done()

View File

@ -8,7 +8,7 @@ class Replace:
def __init__(self):
self.lst = []
def configure(self, options):
def configure(self, options, updated):
"""
.replacements is a list of tuples (fpat, rex, s):

View File

@ -16,6 +16,19 @@ import watchdog.events
from watchdog.observers import polling
class NS:
def __init__(self, ns):
self.__dict__["ns"] = ns
def __getattr__(self, key):
if key not in self.ns:
raise AttributeError("No such element: %s", key)
return self.ns[key]
def __setattr__(self, key, value):
self.__dict__["ns"][key] = value
def parse_command(command):
"""
Returns a (path, args) tuple.
@ -74,18 +87,27 @@ def load_script(path, args):
ns = {'__file__': os.path.abspath(path)}
with scriptenv(path, args):
exec(code, ns, ns)
return ns
return NS(ns)
class ReloadHandler(watchdog.events.FileSystemEventHandler):
def __init__(self, callback):
self.callback = callback
def filter(self, event):
if event.is_directory:
return False
if os.path.basename(event.src_path).startswith("."):
return False
return True
def on_modified(self, event):
self.callback()
if self.filter(event):
self.callback()
def on_created(self, event):
self.callback()
if self.filter(event):
self.callback()
class Script:
@ -118,29 +140,35 @@ class Script:
# It's possible for ns to be un-initialised if we failed during
# configure
if self.ns is not None and not self.dead:
func = self.ns.get(name)
func = getattr(self.ns, name, None)
if func:
with scriptenv(self.path, self.args):
func(*args, **kwargs)
return func(*args, **kwargs)
def reload(self):
self.should_reload.set()
def load_script(self):
self.ns = load_script(self.path, self.args)
ret = self.run("start")
if ret:
self.ns = ret
self.run("start")
def tick(self):
if self.should_reload.is_set():
self.should_reload.clear()
ctx.log.info("Reloading script: %s" % self.name)
self.ns = load_script(self.path, self.args)
self.start()
self.configure(self.last_options)
self.configure(self.last_options, self.last_options.keys())
else:
self.run("tick")
def start(self):
self.ns = load_script(self.path, self.args)
self.run("start")
self.load_script()
def configure(self, options):
def configure(self, options, updated):
self.last_options = options
if not self.observer:
self.observer = polling.PollingObserver()
@ -150,7 +178,7 @@ class Script:
os.path.dirname(self.path) or "."
)
self.observer.start()
self.run("configure", options)
self.run("configure", options, updated)
def done(self):
self.run("done")
@ -161,26 +189,27 @@ class ScriptLoader():
"""
An addon that manages loading scripts from options.
"""
def configure(self, options):
for s in options.scripts:
if options.scripts.count(s) > 1:
raise exceptions.OptionsError("Duplicate script: %s" % s)
def configure(self, options, updated):
if "scripts" in updated:
for s in options.scripts:
if options.scripts.count(s) > 1:
raise exceptions.OptionsError("Duplicate script: %s" % s)
for a in ctx.master.addons.chain[:]:
if isinstance(a, Script) and a.name not in options.scripts:
ctx.log.info("Un-loading script: %s" % a.name)
ctx.master.addons.remove(a)
for a in ctx.master.addons.chain[:]:
if isinstance(a, Script) and a.name not in options.scripts:
ctx.log.info("Un-loading script: %s" % a.name)
ctx.master.addons.remove(a)
current = {}
for a in ctx.master.addons.chain[:]:
if isinstance(a, Script):
current[a.name] = a
ctx.master.addons.chain.remove(a)
current = {}
for a in ctx.master.addons.chain[:]:
if isinstance(a, Script):
current[a.name] = a
ctx.master.addons.chain.remove(a)
for s in options.scripts:
if s in current:
ctx.master.addons.chain.append(current[s])
else:
ctx.log.info("Loading script: %s" % s)
sc = Script(s)
ctx.master.addons.add(sc)
for s in options.scripts:
if s in current:
ctx.master.addons.chain.append(current[s])
else:
ctx.log.info("Loading script: %s" % s)
sc = Script(s)
ctx.master.addons.add(options, sc)

View File

@ -6,7 +6,7 @@ class SetHeaders:
def __init__(self):
self.lst = []
def configure(self, options):
def configure(self, options, updated):
"""
options.setheaders is a tuple of (fpatt, header, value)

View File

@ -10,7 +10,7 @@ class StickyAuth:
self.flt = None
self.hosts = {}
def configure(self, options):
def configure(self, options, updated):
if options.stickyauth:
flt = filt.parse(options.stickyauth)
if not flt:

View File

@ -32,7 +32,7 @@ class StickyCookie:
self.jar = collections.defaultdict(dict)
self.flt = None
def configure(self, options):
def configure(self, options, updated):
if options.stickycookie:
flt = filt.parse(options.stickycookie)
if not flt:

View File

@ -134,7 +134,11 @@ def save_data(path, data):
if not path:
return
try:
with open(path, "wb") as f:
if isinstance(data, bytes):
mode = "wb"
else:
mode = "w"
with open(path, mode) as f:
f.write(data)
except IOError as v:
signals.status_message.send(message=v.strerror)
@ -193,10 +197,9 @@ def ask_scope_and_callback(flow, cb, *args):
def copy_to_clipboard_or_prompt(data):
# pyperclip calls encode('utf-8') on data to be copied without checking.
# if data are already encoded that way UnicodeDecodeError is thrown.
toclip = ""
try:
toclip = data.decode('utf-8')
except (UnicodeDecodeError):
if isinstance(data, bytes):
toclip = data.decode("utf8", "replace")
else:
toclip = data
try:
@ -216,7 +219,7 @@ def copy_to_clipboard_or_prompt(data):
def format_flow_data(key, scope, flow):
data = ""
data = b""
if scope in ("q", "b"):
request = flow.request.copy()
request.decode(strict=False)
@ -230,7 +233,7 @@ def format_flow_data(key, scope, flow):
raise ValueError("Unknown key: {}".format(key))
if scope == "b" and flow.request.raw_content and flow.response:
# Add padding between request and response
data += "\r\n" * 2
data += b"\r\n" * 2
if scope in ("s", "b") and flow.response:
response = flow.response.copy()
response.decode(strict=False)
@ -293,7 +296,7 @@ def ask_save_body(scope, flow):
)
elif scope == "b" and request_has_content and response_has_content:
ask_save_path(
(flow.request.get_content(strict=False) + "\n" +
(flow.request.get_content(strict=False) + b"\n" +
flow.response.get_content(strict=False)),
"Save request & response content to"
)
@ -407,7 +410,7 @@ def raw_format_flow(f, focus, extended):
return urwid.Pile(pile)
def format_flow(f, focus, extended=False, hostheader=False, marked=False):
def format_flow(f, focus, extended=False, hostheader=False):
d = dict(
intercepted = f.intercepted,
acked = f.reply.acked,
@ -420,7 +423,7 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False):
err_msg = f.error.msg if f.error else None,
marked = marked,
marked = f.marked,
)
if f.response:
if f.response.raw_content:

View File

@ -120,23 +120,17 @@ class ConnectionItem(urwid.WidgetWrap):
self.flow,
self.f,
hostheader = self.master.options.showhost,
marked=self.state.flow_marked(self.flow)
)
def selectable(self):
return True
def save_flows_prompt(self, k):
if k == "a":
if k == "l":
signals.status_prompt_path.send(
prompt = "Save all flows to",
prompt = "Save listed flows to",
callback = self.master.save_flows
)
elif k == "m":
signals.status_prompt_path.send(
prompt = "Save marked flows to",
callback = self.master.save_marked_flows
)
else:
signals.status_prompt_path.send(
prompt = "Save this flow to",
@ -188,17 +182,16 @@ class ConnectionItem(urwid.WidgetWrap):
self.flow.accept_intercept(self.master)
signals.flowlist_change.send(self)
elif key == "d":
self.flow.kill(self.master)
if not self.flow.reply.acked:
self.flow.kill(self.master)
self.state.delete_flow(self.flow)
signals.flowlist_change.send(self)
elif key == "D":
f = self.master.duplicate_flow(self.flow)
self.master.view_flow(f)
self.master.state.set_focus_flow(f)
signals.flowlist_change.send(self)
elif key == "m":
if self.state.flow_marked(self.flow):
self.state.set_flow_marked(self.flow, False)
else:
self.state.set_flow_marked(self.flow, True)
self.flow.marked = not self.flow.marked
signals.flowlist_change.send(self)
elif key == "M":
if self.state.mark_filter:
@ -233,7 +226,7 @@ class ConnectionItem(urwid.WidgetWrap):
)
elif key == "U":
for f in self.state.flows:
self.state.set_flow_marked(f, False)
f.marked = False
signals.flowlist_change.send(self)
elif key == "V":
if not self.flow.modified():
@ -247,14 +240,14 @@ class ConnectionItem(urwid.WidgetWrap):
self,
prompt = "Save",
keys = (
("all flows", "a"),
("listed flows", "l"),
("this flow", "t"),
("marked flows", "m"),
),
callback = self.save_flows_prompt,
)
elif key == "X":
self.flow.kill(self.master)
if not self.flow.reply.acked:
self.flow.kill(self.master)
elif key == "enter":
if self.flow.request:
self.master.view_flow(self.flow)
@ -356,7 +349,8 @@ class FlowListBox(urwid.ListBox):
return
scheme, host, port, path = parts
f = self.master.create_request(method, scheme, host, port, path)
self.master.view_flow(f)
self.master.state.set_focus_flow(f)
signals.flowlist_change.send(self)
def keypress(self, size, key):
key = common.shortcuts(key)

View File

@ -6,6 +6,7 @@ import sys
import traceback
import urwid
from typing import Optional, Union # noqa
from mitmproxy import contentviews
from mitmproxy import controller
@ -38,7 +39,7 @@ def _mkhelp():
("d", "delete flow"),
("e", "edit request/response"),
("f", "load full body data"),
("m", "change body display mode for this entity"),
("m", "change body display mode for this entity\n(default mode can be changed in the options)"),
(None,
common.highlight_key("automatic", "a") +
[("text", ": automatic detection")]
@ -75,7 +76,6 @@ def _mkhelp():
common.highlight_key("xml", "x") +
[("text", ": XML")]
),
("M", "change default body display mode"),
("E", "export flow to file"),
("r", "replay request"),
("V", "revert changes to request"),
@ -105,7 +105,8 @@ footer = [
class FlowViewHeader(urwid.WidgetWrap):
def __init__(self, master, f):
self.master, self.flow = master, f
self.master = master # type: "mitmproxy.console.master.ConsoleMaster"
self.flow = f # type: models.HTTPFlow
self._w = common.format_flow(
f,
False,
@ -135,14 +136,15 @@ class FlowView(tabs.Tabs):
def __init__(self, master, state, flow, tab_offset):
self.master, self.state, self.flow = master, state, flow
tabs.Tabs.__init__(self,
[
(self.tab_request, self.view_request),
(self.tab_response, self.view_response),
(self.tab_details, self.view_details),
],
tab_offset
)
super(FlowView, self).__init__(
[
(self.tab_request, self.view_request),
(self.tab_response, self.view_response),
(self.tab_details, self.view_details),
],
tab_offset
)
self.show()
self.last_displayed_body = None
signals.flow_change.connect(self.sig_flow_change)
@ -189,15 +191,21 @@ class FlowView(tabs.Tabs):
limit = sys.maxsize
else:
limit = contentviews.VIEW_CUTOFF
flow_modify_cache_invalidation = hash((
message.raw_content,
message.headers.fields,
getattr(message, "path", None),
))
return cache.get(
self._get_content_view,
# We move message into this partial function as it is not hashable.
lambda *args: self._get_content_view(message, *args),
viewmode,
message,
limit,
message # Cache invalidation
flow_modify_cache_invalidation
)
def _get_content_view(self, viewmode, message, max_lines, _):
def _get_content_view(self, message, viewmode, max_lines, _):
try:
content = message.content
@ -396,7 +404,7 @@ class FlowView(tabs.Tabs):
if not self.flow.response:
self.flow.response = models.HTTPResponse(
self.flow.request.http_version,
200, "OK", Headers(), ""
200, b"OK", Headers(), b""
)
self.flow.response.reply = controller.DummyReply()
message = self.flow.response
@ -524,30 +532,24 @@ class FlowView(tabs.Tabs):
)
signals.flow_change.send(self, flow = self.flow)
def delete_body(self, t):
if self.tab_offset == TAB_REQ:
self.flow.request.content = None
else:
self.flow.response.content = None
signals.flow_change.send(self, flow = self.flow)
def keypress(self, size, key):
conn = None # type: Optional[Union[models.HTTPRequest, models.HTTPResponse]]
if self.tab_offset == TAB_REQ:
conn = self.flow.request
elif self.tab_offset == TAB_RESP:
conn = self.flow.response
key = super(self.__class__, self).keypress(size, key)
# Special case: Space moves over to the next flow.
# We need to catch that before applying common.shortcuts()
if key == " ":
self.view_next_flow(self.flow)
return
key = common.shortcuts(key)
if self.tab_offset == TAB_REQ:
conn = self.flow.request
elif self.tab_offset == TAB_RESP:
conn = self.flow.response
else:
conn = None
if key in ("up", "down", "page up", "page down"):
# Why doesn't this just work??
# Pass scroll events to the wrapped widget
self._w.keypress(size, key)
elif key == "a":
self.flow.accept_intercept(self.master)
@ -563,10 +565,12 @@ class FlowView(tabs.Tabs):
else:
self.view_next_flow(self.flow)
f = self.flow
f.kill(self.master)
if not f.reply.acked:
f.kill(self.master)
self.state.delete_flow(f)
elif key == "D":
f = self.master.duplicate_flow(self.flow)
signals.pop_view_state.send(self)
self.master.view_flow(f)
signals.status_message.send(message="Duplicated.")
elif key == "p":
@ -577,12 +581,12 @@ class FlowView(tabs.Tabs):
signals.status_message.send(message=r)
signals.flow_change.send(self, flow = self.flow)
elif key == "V":
if not self.flow.modified():
if self.flow.modified():
self.state.revert(self.flow)
signals.flow_change.send(self, flow = self.flow)
signals.status_message.send(message="Reverted.")
else:
signals.status_message.send(message="Flow not modified.")
return
self.state.revert(self.flow)
signals.flow_change.send(self, flow = self.flow)
signals.status_message.send(message="Reverted.")
elif key == "W":
signals.status_prompt_path.send(
prompt = "Save this flow",
@ -595,133 +599,128 @@ class FlowView(tabs.Tabs):
callback = self.master.run_script_once,
args = (self.flow,)
)
if not conn and key in set(list("befgmxvzEC")):
elif key == "e":
if self.tab_offset == TAB_REQ:
signals.status_prompt_onekey.send(
prompt="Edit request",
keys=(
("cookies", "c"),
("query", "q"),
("path", "p"),
("url", "u"),
("header", "h"),
("form", "f"),
("raw body", "r"),
("method", "m"),
),
callback=self.edit
)
elif self.tab_offset == TAB_RESP:
signals.status_prompt_onekey.send(
prompt="Edit response",
keys=(
("cookies", "c"),
("code", "o"),
("message", "m"),
("header", "h"),
("raw body", "r"),
),
callback=self.edit
)
else:
signals.status_message.send(
message="Tab to the request or response",
expire=1
)
elif key in set("bfgmxvzEC") and not conn:
signals.status_message.send(
message = "Tab to the request or response",
expire = 1
)
elif conn:
if key == "b":
if self.tab_offset == TAB_REQ:
common.ask_save_body(
"q", self.master, self.state, self.flow
)
return
elif key == "b":
if self.tab_offset == TAB_REQ:
common.ask_save_body("q", self.flow)
else:
common.ask_save_body("s", self.flow)
elif key == "f":
signals.status_message.send(message="Loading all body data...")
self.state.add_flow_setting(
self.flow,
(self.tab_offset, "fullcontents"),
True
)
signals.flow_change.send(self, flow = self.flow)
signals.status_message.send(message="")
elif key == "m":
p = list(contentviews.view_prompts)
p.insert(0, ("Clear", "C"))
signals.status_prompt_onekey.send(
self,
prompt = "Display mode",
keys = p,
callback = self.change_this_display_mode
)
elif key == "E":
if self.tab_offset == TAB_REQ:
scope = "q"
else:
scope = "s"
signals.status_prompt_onekey.send(
self,
prompt = "Export to file",
keys = [(e[0], e[1]) for e in export.EXPORTERS],
callback = common.export_to_clip_or_file,
args = (scope, self.flow, common.ask_save_path)
)
elif key == "C":
if self.tab_offset == TAB_REQ:
scope = "q"
else:
scope = "s"
signals.status_prompt_onekey.send(
self,
prompt = "Export to clipboard",
keys = [(e[0], e[1]) for e in export.EXPORTERS],
callback = common.export_to_clip_or_file,
args = (scope, self.flow, common.copy_to_clipboard_or_prompt)
)
elif key == "x":
conn.content = None
signals.flow_change.send(self, flow=self.flow)
elif key == "v":
if conn.raw_content:
t = conn.headers.get("content-type")
if "EDITOR" in os.environ or "PAGER" in os.environ:
self.master.spawn_external_viewer(conn.get_content(strict=False), t)
else:
common.ask_save_body(
"s", self.master, self.state, self.flow
signals.status_message.send(
message = "Error! Set $EDITOR or $PAGER."
)
elif key == "e":
if self.tab_offset == TAB_REQ:
signals.status_prompt_onekey.send(
prompt = "Edit request",
keys = (
("cookies", "c"),
("query", "q"),
("path", "p"),
("url", "u"),
("header", "h"),
("form", "f"),
("raw body", "r"),
("method", "m"),
),
callback = self.edit
elif key == "z":
self.flow.backup()
e = conn.headers.get("content-encoding", "identity")
if e != "identity":
try:
conn.decode()
except ValueError:
signals.status_message.send(
message = "Could not decode - invalid data?"
)
else:
signals.status_prompt_onekey.send(
prompt = "Edit response",
keys = (
("cookies", "c"),
("code", "o"),
("message", "m"),
("header", "h"),
("raw body", "r"),
),
callback = self.edit
)
key = None
elif key == "f":
signals.status_message.send(message="Loading all body data...")
self.state.add_flow_setting(
self.flow,
(self.tab_offset, "fullcontents"),
True
)
signals.flow_change.send(self, flow = self.flow)
signals.status_message.send(message="")
elif key == "m":
p = list(contentviews.view_prompts)
p.insert(0, ("Clear", "C"))
else:
signals.status_prompt_onekey.send(
self,
prompt = "Display mode",
keys = p,
callback = self.change_this_display_mode
)
key = None
elif key == "E":
if self.tab_offset == TAB_REQ:
scope = "q"
else:
scope = "s"
signals.status_prompt_onekey.send(
self,
prompt = "Export to file",
keys = [(e[0], e[1]) for e in export.EXPORTERS],
callback = common.export_to_clip_or_file,
args = (scope, self.flow, common.ask_save_path)
)
elif key == "C":
if self.tab_offset == TAB_REQ:
scope = "q"
else:
scope = "s"
signals.status_prompt_onekey.send(
self,
prompt = "Export to clipboard",
keys = [(e[0], e[1]) for e in export.EXPORTERS],
callback = common.export_to_clip_or_file,
args = (scope, self.flow, common.copy_to_clipboard_or_prompt)
)
elif key == "x":
signals.status_prompt_onekey.send(
prompt = "Delete body",
prompt = "Select encoding: ",
keys = (
("completely", "c"),
("mark as missing", "m"),
("gzip", "z"),
("deflate", "d"),
),
callback = self.delete_body
callback = self.encode_callback,
args = (conn,)
)
key = None
elif key == "v":
if conn.raw_content:
t = conn.headers.get("content-type")
if "EDITOR" in os.environ or "PAGER" in os.environ:
self.master.spawn_external_viewer(conn.get_content(strict=False), t)
else:
signals.status_message.send(
message = "Error! Set $EDITOR or $PAGER."
)
elif key == "z":
self.flow.backup()
e = conn.headers.get("content-encoding", "identity")
if e != "identity":
if not conn.decode():
signals.status_message.send(
message = "Could not decode - invalid data?"
)
else:
signals.status_prompt_onekey.send(
prompt = "Select encoding: ",
keys = (
("gzip", "z"),
("deflate", "d"),
),
callback = self.encode_callback,
args = (conn,)
)
signals.flow_change.send(self, flow = self.flow)
return key
signals.flow_change.send(self, flow = self.flow)
else:
# Key is not handled here.
return key
def encode_callback(self, key, conn):
encoding_map = {

View File

@ -1,5 +1,7 @@
from __future__ import absolute_import, print_function, division
import platform
import urwid
from mitmproxy import filt
@ -9,7 +11,7 @@ from mitmproxy.console import signals
from netlib import version
footer = [
("heading", 'mitmproxy v%s ' % version.VERSION),
("heading", 'mitmproxy {} (Python {}) '.format(version.VERSION, platform.python_version())),
('heading_key', "q"), ":back ",
]

View File

@ -34,6 +34,7 @@ from mitmproxy.console import palettes
from mitmproxy.console import signals
from mitmproxy.console import statusbar
from mitmproxy.console import window
from mitmproxy.filt import FMarked
from netlib import tcp, strutils
EVENTLOG_SIZE = 500
@ -48,7 +49,7 @@ class ConsoleState(flow.State):
self.default_body_view = contentviews.get("Auto")
self.flowsettings = weakref.WeakKeyDictionary()
self.last_search = None
self.last_filter = None
self.last_filter = ""
self.mark_filter = False
def __setattr__(self, name, value):
@ -66,7 +67,6 @@ class ConsoleState(flow.State):
def add_flow(self, f):
super(ConsoleState, self).add_flow(f)
self.update_focus()
self.set_flow_marked(f, False)
return f
def update_flow(self, f):
@ -86,10 +86,10 @@ class ConsoleState(flow.State):
def set_focus(self, idx):
if self.view:
if idx >= len(self.view):
idx = len(self.view) - 1
elif idx < 0:
if idx is None or idx < 0:
idx = 0
elif idx >= len(self.view):
idx = len(self.view) - 1
self.focus = idx
else:
self.focus = None
@ -123,48 +123,71 @@ class ConsoleState(flow.State):
self.set_focus(self.focus)
return ret
def filter_marked(self, m):
def actual_func(x):
if x.id in m:
return True
return False
return actual_func
def get_nearest_matching_flow(self, flow, filt):
fidx = self.view.index(flow)
dist = 1
fprev = fnext = True
while fprev or fnext:
fprev, _ = self.get_from_pos(fidx - dist)
fnext, _ = self.get_from_pos(fidx + dist)
if fprev and fprev.match(filt):
return fprev
elif fnext and fnext.match(filt):
return fnext
dist += 1
return None
def enable_marked_filter(self):
marked_flows = [f for f in self.flows if f.marked]
if not marked_flows:
return
marked_filter = "~%s" % FMarked.code
# Save Focus
last_focus, _ = self.get_focus()
nearest_marked = self.get_nearest_matching_flow(last_focus, marked_filter)
self.last_filter = self.limit_txt
marked_flows = []
for f in self.flows:
if self.flow_marked(f):
marked_flows.append(f.id)
if len(marked_flows) > 0:
f = self.filter_marked(marked_flows)
self.view._close()
self.view = flow.FlowView(self.flows, f)
self.focus = 0
self.set_focus(self.focus)
self.mark_filter = True
self.set_limit(marked_filter)
# Restore Focus
if last_focus.marked:
self.set_focus_flow(last_focus)
else:
self.set_focus_flow(nearest_marked)
self.mark_filter = True
def disable_marked_filter(self):
if self.last_filter is None:
self.view = flow.FlowView(self.flows, None)
marked_filter = "~%s" % FMarked.code
# Save Focus
last_focus, _ = self.get_focus()
nearest_marked = self.get_nearest_matching_flow(last_focus, marked_filter)
self.set_limit(self.last_filter)
self.last_filter = ""
# Restore Focus
if last_focus.marked:
self.set_focus_flow(last_focus)
else:
self.set_limit(self.last_filter)
self.focus = 0
self.set_focus(self.focus)
self.last_filter = None
self.set_focus_flow(nearest_marked)
self.mark_filter = False
def clear(self):
marked_flows = []
for f in self.flows:
if self.flow_marked(f):
marked_flows.append(f)
marked_flows = [f for f in self.view if f.marked]
super(ConsoleState, self).clear()
for f in marked_flows:
self.add_flow(f)
self.set_flow_marked(f, True)
f.marked = True
if len(self.flows.views) == 0:
self.focus = None
@ -172,12 +195,6 @@ class ConsoleState(flow.State):
self.focus = 0
self.set_focus(self.focus)
def flow_marked(self, flow):
return self.get_flow_setting(flow, "marked", False)
def set_flow_marked(self, flow, marked):
self.add_flow_setting(flow, "marked", marked)
class Options(mitmproxy.options.Options):
def __init__(
@ -242,7 +259,7 @@ class ConsoleMaster(flow.FlowMaster):
signals.pop_view_state.connect(self.sig_pop_view_state)
signals.push_view_state.connect(self.sig_push_view_state)
signals.sig_add_log.connect(self.sig_add_log)
self.addons.add(*builtins.default_addons())
self.addons.add(options, *builtins.default_addons())
def __setattr__(self, name, value):
self.__dict__[name] = value
@ -254,10 +271,6 @@ class ConsoleMaster(flow.FlowMaster):
expire=1
)
def load_script(self, command, use_reloader=True):
# We default to using the reloader in the console ui.
return super(ConsoleMaster, self).load_script(command, use_reloader)
def sig_add_log(self, sender, e, level):
if self.options.verbosity < utils.log_tier(level):
return
@ -352,7 +365,7 @@ class ConsoleMaster(flow.FlowMaster):
try:
return flow.read_flows_from_paths(path)
except exceptions.FlowReadException as e:
signals.status_message.send(message=e.strerror)
signals.status_message.send(message=str(e))
def client_playback_path(self, path):
if not isinstance(path, list):
@ -619,13 +632,6 @@ class ConsoleMaster(flow.FlowMaster):
def save_flows(self, path):
return self._write_flows(path, self.state.view)
def save_marked_flows(self, path):
marked_flows = []
for f in self.state.view:
if self.state.flow_marked(f):
marked_flows.append(f)
return self._write_flows(path, marked_flows)
def load_flows_callback(self, path):
if not path:
return
@ -748,10 +754,3 @@ class ConsoleMaster(flow.FlowMaster):
direction=direction,
), "info")
self.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
@controller.handler
def script_change(self, script):
if super(ConsoleMaster, self).script_change(script):
signals.status_message.send(message='"{}" reloaded.'.format(script.path))
else:
signals.status_message.send(message='Error reloading "{}".'.format(script.path))

View File

@ -140,7 +140,7 @@ class Options(urwid.WidgetWrap):
)
self.master.loop.widget.footer.update("")
signals.update_settings.connect(self.sig_update_settings)
master.options.changed.connect(self.sig_update_settings)
master.options.changed.connect(lambda sender, updated: self.sig_update_settings(sender))
def sig_update_settings(self, sender):
self.lb.walker._modified()

View File

@ -78,9 +78,9 @@ class Searchable(urwid.ListBox):
return
# Start search at focus + 1
if backwards:
rng = xrange(len(self.body) - 1, -1, -1)
rng = range(len(self.body) - 1, -1, -1)
else:
rng = xrange(1, len(self.body) + 1)
rng = range(1, len(self.body) + 1)
for i in rng:
off = (self.focus_position + i) % len(self.body)
w = self.body[off]

View File

@ -124,7 +124,7 @@ class StatusBar(urwid.WidgetWrap):
super(StatusBar, self).__init__(urwid.Pile([self.ib, self.master.ab]))
signals.update_settings.connect(self.sig_update_settings)
signals.flowlist_change.connect(self.sig_update_settings)
master.options.changed.connect(self.sig_update_settings)
master.options.changed.connect(lambda sender, updated: self.sig_update_settings(sender))
self.redraw()
def sig_update_settings(self, sender):
@ -171,10 +171,6 @@ class StatusBar(urwid.WidgetWrap):
r.append("[")
r.append(("heading_key", "l"))
r.append(":%s]" % self.master.state.limit_txt)
if self.master.state.mark_filter:
r.append("[")
r.append(("heading_key", "Marked Flows"))
r.append("]")
if self.master.options.stickycookie:
r.append("[")
r.append(("heading_key", "t"))

View File

@ -25,7 +25,7 @@ class Tab(urwid.WidgetWrap):
class Tabs(urwid.WidgetWrap):
def __init__(self, tabs, tab_offset=0):
urwid.WidgetWrap.__init__(self, "")
super(Tabs, self).__init__("")
self.tab_offset = tab_offset
self.tabs = tabs
self.show()

View File

@ -20,6 +20,8 @@ import logging
import subprocess
import sys
from typing import Mapping # noqa
import html2text
import lxml.etree
import lxml.html
@ -76,6 +78,7 @@ def pretty_json(s):
def format_dict(d):
# type: (Mapping[Union[str,bytes], Union[str,bytes]]) -> Generator[Tuple[Union[str,bytes], Union[str,bytes]]]
"""
Helper function that transforms the given dictionary into a list of
("key", key )
@ -85,7 +88,7 @@ def format_dict(d):
max_key_len = max(len(k) for k in d.keys())
max_key_len = min(max_key_len, KEY_MAX)
for key, value in d.items():
key += ":"
key += b":" if isinstance(key, bytes) else u":"
key = key.ljust(max_key_len + 2)
yield [
("header", key),
@ -106,12 +109,16 @@ class View(object):
prompt = ()
content_types = []
def __call__(self, data, **metadata):
def __call__(
self,
data, # type: bytes
**metadata
):
"""
Transform raw data into human-readable output.
Args:
data: the data to decode/format as bytes.
data: the data to decode/format.
metadata: optional keyword-only arguments for metadata. Implementations must not
rely on a given argument being present.
@ -278,6 +285,10 @@ class ViewURLEncoded(View):
content_types = ["application/x-www-form-urlencoded"]
def __call__(self, data, **metadata):
try:
data = data.decode("ascii", "strict")
except ValueError:
return None
d = url.decode(data)
return "URLEncoded form", format_dict(multidict.MultiDict(d))

View File

@ -37,8 +37,6 @@ Events = frozenset([
"configure",
"done",
"tick",
"script_change",
])

View File

@ -1,4 +1,4 @@
from typing import Callable # noqa
master = None # type: "mitmproxy.flow.FlowMaster"
log = None # type: Callable[[str], None]
log = None # type: "mitmproxy.controller.Log"

View File

@ -42,8 +42,8 @@ class DumpMaster(flow.FlowMaster):
def __init__(self, server, options):
flow.FlowMaster.__init__(self, options, server, flow.State())
self.has_errored = False
self.addons.add(*builtins.default_addons())
self.addons.add(dumper.Dumper())
self.addons.add(options, *builtins.default_addons())
self.addons.add(options, dumper.Dumper())
# This line is just for type hinting
self.options = self.options # type: Options
self.replay_ignore_params = options.replay_ignore_params

View File

@ -39,9 +39,12 @@ import functools
from mitmproxy.models.http import HTTPFlow
from mitmproxy.models.tcp import TCPFlow
from mitmproxy.models.flow import Flow
from netlib import strutils
import pyparsing as pp
from typing import Callable
def only(*types):
@ -80,6 +83,14 @@ class FErr(_Action):
return True if f.error else False
class FMarked(_Action):
code = "marked"
help = "Match marked flows"
def __call__(self, f):
return f.marked
class FHTTP(_Action):
code = "http"
help = "Match HTTP flows"
@ -398,6 +409,7 @@ filt_unary = [
FAsset,
FErr,
FHTTP,
FMarked,
FReq,
FResp,
FTCP,
@ -471,7 +483,11 @@ def _make():
bnf = _make()
TFilter = Callable[[Flow], bool]
def parse(s):
# type: (str) -> TFilter
try:
filt = bnf.parseString(s, parseAll=True)[0]
filt.pattern = s

View File

@ -60,6 +60,7 @@ def convert_017_018(data):
data = convert_unicode(data)
data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address")
data["marked"] = False
data["version"] = (0, 18)
return data

View File

@ -8,6 +8,8 @@ from mitmproxy import stateobject
from mitmproxy.models.connections import ClientConnection
from mitmproxy.models.connections import ServerConnection
import six
from netlib import version
from typing import Optional # noqa
@ -79,6 +81,7 @@ class Flow(stateobject.StateObject):
self.intercepted = False # type: bool
self._backup = None # type: Optional[Flow]
self.reply = None
self.marked = False # type: bool
_stateobject_attributes = dict(
id=str,
@ -86,7 +89,8 @@ class Flow(stateobject.StateObject):
client_conn=ClientConnection,
server_conn=ServerConnection,
type=str,
intercepted=bool
intercepted=bool,
marked=bool,
)
def get_state(self):
@ -173,3 +177,21 @@ class Flow(stateobject.StateObject):
self.intercepted = False
self.reply.ack()
master.handle_accept_intercept(self)
def match(self, f):
"""
Match this flow against a compiled filter expression. Returns True
if matched, False if not.
If f is a string, it will be compiled as a filter expression. If
the expression is invalid, ValueError is raised.
"""
if isinstance(f, six.string_types):
from .. import filt
f = filt.parse(f)
if not f:
raise ValueError("Invalid filter expression.")
if f:
return f(self)
return True

View File

@ -2,7 +2,6 @@ from __future__ import absolute_import, print_function, division
import cgi
import warnings
import six
from mitmproxy.models.flow import Flow
from netlib import version
@ -211,24 +210,6 @@ class HTTPFlow(Flow):
f.response = self.response.copy()
return f
def match(self, f):
"""
Match this flow against a compiled filter expression. Returns True
if matched, False if not.
If f is a string, it will be compiled as a filter expression. If
the expression is invalid, ValueError is raised.
"""
if isinstance(f, six.string_types):
from .. import filt
f = filt.parse(f)
if not f:
raise ValueError("Invalid filter expression.")
if f:
return f(self)
return True
def replace(self, pattern, repl, *args, **kwargs):
"""
Replaces a regular expression pattern with repl in both request and

View File

@ -7,8 +7,6 @@ from typing import List
import netlib.basetypes
from mitmproxy.models.flow import Flow
import six
class TCPMessage(netlib.basetypes.Serializable):
@ -55,22 +53,3 @@ class TCPFlow(Flow):
def __repr__(self):
return "<TCPFlow ({} messages)>".format(len(self.messages))
def match(self, f):
"""
Match this flow against a compiled filter expression. Returns True
if matched, False if not.
If f is a string, it will be compiled as a filter expression. If
the expression is invalid, ValueError is raised.
"""
if isinstance(f, six.string_types):
from .. import filt
f = filt.parse(f)
if not f:
raise ValueError("Invalid filter expression.")
if f:
return f(self)
return True

View File

@ -35,7 +35,7 @@ class OptManager(object):
self.__dict__["_initialized"] = True
@contextlib.contextmanager
def rollback(self):
def rollback(self, updated):
old = self._opts.copy()
try:
yield
@ -44,7 +44,7 @@ class OptManager(object):
self.errored.send(self, exc=e)
# Rollback
self.__dict__["_opts"] = old
self.changed.send(self)
self.changed.send(self, updated=updated)
def __eq__(self, other):
return self._opts == other._opts
@ -62,22 +62,22 @@ class OptManager(object):
if not self._initialized:
self._opts[attr] = value
return
if attr not in self._opts:
raise KeyError("No such option: %s" % attr)
with self.rollback():
self._opts[attr] = value
self.changed.send(self)
self.update(**{attr: value})
def keys(self):
return set(self._opts.keys())
def get(self, k, d=None):
return self._opts.get(k, d)
def update(self, **kwargs):
updated = set(kwargs.keys())
for k in kwargs:
if k not in self._opts:
raise KeyError("No such option: %s" % k)
with self.rollback():
with self.rollback(updated):
self._opts.update(kwargs)
self.changed.send(self)
self.changed.send(self, updated=updated)
def setter(self, attr):
"""

View File

@ -584,6 +584,8 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
except exceptions.ProtocolException as e: # pragma: no cover
self.log(repr(e), "info")
self.log(traceback.format_exc(), "debug")
except exceptions.Kill:
self.log("Connection killed", "info")
if not self.zombie:
self.zombie = time.time()

View File

@ -79,10 +79,10 @@ class ProxyConfig:
self.certstore = None
self.clientcerts = None
self.openssl_verification_mode_server = None
self.configure(options)
self.configure(options, set(options.keys()))
options.changed.connect(self.configure)
def configure(self, options):
def configure(self, options, updated):
conflict = all(
[
options.add_upstream_certs_to_client_chain,

View File

@ -234,7 +234,8 @@ class AcceptFlow(RequestHandler):
class FlowHandler(RequestHandler):
def delete(self, flow_id):
self.flow.kill(self.master)
if not self.flow.reply.acked:
self.flow.kill(self.master)
self.state.delete_flow(self.flow)
def put(self, flow_id):

View File

@ -136,7 +136,7 @@ class WebMaster(flow.FlowMaster):
def __init__(self, server, options):
super(WebMaster, self).__init__(options, server, WebState())
self.addons.add(*builtins.default_addons())
self.addons.add(options, *builtins.default_addons())
self.app = app.Application(
self, self.options.wdebug, self.options.wauthenticator
)

View File

@ -4,6 +4,7 @@ Utility functions for decoding response bodies.
from __future__ import absolute_import
import codecs
import collections
from io import BytesIO
import gzip
import zlib
@ -11,7 +12,15 @@ import zlib
from typing import Union # noqa
def decode(obj, encoding, errors='strict'):
# We have a shared single-element cache for encoding and decoding.
# This is quite useful in practice, e.g.
# flow.request.content = flow.request.content.replace(b"foo", b"bar")
# does not require an .encode() call if content does not contain b"foo"
CachedDecode = collections.namedtuple("CachedDecode", "encoded encoding errors decoded")
_cache = CachedDecode(None, None, None, None)
def decode(encoded, encoding, errors='strict'):
# type: (Union[str, bytes], str, str) -> Union[str, bytes]
"""
Decode the given input object
@ -22,20 +31,32 @@ def decode(obj, encoding, errors='strict'):
Raises:
ValueError, if decoding fails.
"""
global _cache
cached = (
isinstance(encoded, bytes) and
_cache.encoded == encoded and
_cache.encoding == encoding and
_cache.errors == errors
)
if cached:
return _cache.decoded
try:
try:
return custom_decode[encoding](obj)
decoded = custom_decode[encoding](encoded)
except KeyError:
return codecs.decode(obj, encoding, errors)
decoded = codecs.decode(encoded, encoding, errors)
if encoding in ("gzip", "deflate"):
_cache = CachedDecode(encoded, encoding, errors, decoded)
return decoded
except Exception as e:
raise ValueError("{} when decoding {} with {}".format(
type(e).__name__,
repr(obj)[:10],
repr(encoded)[:10],
repr(encoding),
))
def encode(obj, encoding, errors='strict'):
def encode(decoded, encoding, errors='strict'):
# type: (Union[str, bytes], str, str) -> Union[str, bytes]
"""
Encode the given input object
@ -46,15 +67,27 @@ def encode(obj, encoding, errors='strict'):
Raises:
ValueError, if encoding fails.
"""
global _cache
cached = (
isinstance(decoded, bytes) and
_cache.decoded == decoded and
_cache.encoding == encoding and
_cache.errors == errors
)
if cached:
return _cache.encoded
try:
try:
return custom_encode[encoding](obj)
encoded = custom_encode[encoding](decoded)
except KeyError:
return codecs.encode(obj, encoding, errors)
encoded = codecs.encode(decoded, encoding, errors)
if encoding in ("gzip", "deflate"):
_cache = CachedDecode(encoded, encoding, errors, decoded)
return encoded
except Exception as e:
raise ValueError("{} when encoding {} with {}".format(
type(e).__name__,
repr(obj)[:10],
repr(decoded)[:10],
repr(encoding),
))

View File

@ -32,9 +32,6 @@ class MessageData(basetypes.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(frozenset(self.__dict__.items()))
def set_state(self, state):
for k, v in state.items():
if k == "headers":
@ -52,23 +49,7 @@ class MessageData(basetypes.Serializable):
return cls(**state)
class CachedDecode(object):
__slots__ = ["encoded", "encoding", "strict", "decoded"]
def __init__(self, object, encoding, strict, decoded):
self.encoded = object
self.encoding = encoding
self.strict = strict
self.decoded = decoded
no_cached_decode = CachedDecode(None, None, None, None)
class Message(basetypes.Serializable):
def __init__(self):
self._content_cache = no_cached_decode # type: CachedDecode
self._text_cache = no_cached_decode # type: CachedDecode
def __eq__(self, other):
if isinstance(other, Message):
return self.data == other.data
@ -77,9 +58,6 @@ class Message(basetypes.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self.data) ^ 1
def get_state(self):
return self.data.get_state()
@ -132,25 +110,15 @@ class Message(basetypes.Serializable):
if self.raw_content is None:
return None
ce = self.headers.get("content-encoding")
cached = (
self._content_cache.encoded == self.raw_content and
(self._content_cache.strict or not strict) and
self._content_cache.encoding == ce
)
if not cached:
is_strict = True
if ce:
try:
decoded = encoding.decode(self.raw_content, ce)
except ValueError:
if strict:
raise
is_strict = False
decoded = self.raw_content
else:
decoded = self.raw_content
self._content_cache = CachedDecode(self.raw_content, ce, is_strict, decoded)
return self._content_cache.decoded
if ce:
try:
return encoding.decode(self.raw_content, ce)
except ValueError:
if strict:
raise
return self.raw_content
else:
return self.raw_content
def set_content(self, value):
if value is None:
@ -163,22 +131,13 @@ class Message(basetypes.Serializable):
.format(type(value).__name__)
)
ce = self.headers.get("content-encoding")
cached = (
self._content_cache.decoded == value and
self._content_cache.encoding == ce and
self._content_cache.strict
)
if not cached:
try:
encoded = encoding.encode(value, ce or "identity")
except ValueError:
# So we have an invalid content-encoding?
# Let's remove it!
del self.headers["content-encoding"]
ce = None
encoded = value
self._content_cache = CachedDecode(encoded, ce, True, value)
self.raw_content = self._content_cache.encoded
try:
self.raw_content = encoding.encode(value, ce or "identity")
except ValueError:
# So we have an invalid content-encoding?
# Let's remove it!
del self.headers["content-encoding"]
self.raw_content = value
self.headers["content-length"] = str(len(self.raw_content))
content = property(get_content, set_content)
@ -250,22 +209,12 @@ class Message(basetypes.Serializable):
enc = self._guess_encoding()
content = self.get_content(strict)
cached = (
self._text_cache.encoded == content and
(self._text_cache.strict or not strict) and
self._text_cache.encoding == enc
)
if not cached:
is_strict = self._content_cache.strict
try:
decoded = encoding.decode(content, enc)
except ValueError:
if strict:
raise
is_strict = False
decoded = self.content.decode("utf8", "replace" if six.PY2 else "surrogateescape")
self._text_cache = CachedDecode(content, enc, is_strict, decoded)
return self._text_cache.decoded
try:
return encoding.decode(content, enc)
except ValueError:
if strict:
raise
return content.decode("utf8", "replace" if six.PY2 else "surrogateescape")
def set_text(self, text):
if text is None:
@ -273,23 +222,15 @@ class Message(basetypes.Serializable):
return
enc = self._guess_encoding()
cached = (
self._text_cache.decoded == text and
self._text_cache.encoding == enc and
self._text_cache.strict
)
if not cached:
try:
encoded = encoding.encode(text, enc)
except ValueError:
# Fall back to UTF-8 and update the content-type header.
ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
ct[2]["charset"] = "utf-8"
self.headers["content-type"] = headers.assemble_content_type(*ct)
enc = "utf8"
encoded = text.encode(enc, "replace" if six.PY2 else "surrogateescape")
self._text_cache = CachedDecode(encoded, enc, True, text)
self.content = self._text_cache.encoded
try:
self.content = encoding.encode(text, enc)
except ValueError:
# Fall back to UTF-8 and update the content-type header.
ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
ct[2]["charset"] = "utf-8"
self.headers["content-type"] = headers.assemble_content_type(*ct)
enc = "utf8"
self.content = text.encode(enc, "replace" if six.PY2 else "surrogateescape")
text = property(get_text, set_text)

View File

@ -253,14 +253,13 @@ class Request(message.Message):
)
def _get_query(self):
_, _, _, _, query, _ = urllib.parse.urlparse(self.url)
query = urllib.parse.urlparse(self.url).query
return tuple(netlib.http.url.decode(query))
def _set_query(self, value):
query = netlib.http.url.encode(value)
scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
_, _, _, self.path = netlib.http.url.parse(
urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
def _set_query(self, query_data):
query = netlib.http.url.encode(query_data)
_, _, path, params, _, fragment = urllib.parse.urlparse(self.url)
self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment])
@query.setter
def query(self, value):
@ -296,19 +295,18 @@ class Request(message.Message):
The URL's path components as a tuple of strings.
Components are unquoted.
"""
_, _, path, _, _, _ = urllib.parse.urlparse(self.url)
path = urllib.parse.urlparse(self.url).path
# This needs to be a tuple so that it's immutable.
# Otherwise, this would fail silently:
# request.path_components.append("foo")
return tuple(urllib.parse.unquote(i) for i in path.split("/") if i)
return tuple(netlib.http.url.unquote(i) for i in path.split("/") if i)
@path_components.setter
def path_components(self, components):
components = map(lambda x: urllib.parse.quote(x, safe=""), components)
components = map(lambda x: netlib.http.url.quote(x, safe=""), components)
path = "/" + "/".join(components)
scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url)
_, _, _, self.path = netlib.http.url.parse(
urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
_, _, _, params, query, fragment = urllib.parse.urlparse(self.url)
self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment])
def anticache(self):
"""
@ -365,13 +363,13 @@ class Request(message.Message):
pass
return ()
def _set_urlencoded_form(self, value):
def _set_urlencoded_form(self, form_data):
"""
Sets the body to the URL-encoded form data, and adds the appropriate content-type header.
This will overwrite the existing content if there is one.
"""
self.headers["content-type"] = "application/x-www-form-urlencoded"
self.content = netlib.http.url.encode(value).encode()
self.content = netlib.http.url.encode(form_data).encode()
@urlencoded_form.setter
def urlencoded_form(self, value):

View File

@ -82,18 +82,51 @@ def unparse(scheme, host, port, path=""):
def encode(s):
# type: Sequence[Tuple[str,str]] -> str
"""
Takes a list of (key, value) tuples and returns a urlencoded string.
"""
s = [tuple(i) for i in s]
return urllib.parse.urlencode(s, False)
if six.PY2:
return urllib.parse.urlencode(s, False)
else:
return urllib.parse.urlencode(s, False, errors="surrogateescape")
def decode(s):
"""
Takes a urlencoded string and returns a list of (key, value) tuples.
Takes a urlencoded string and returns a list of surrogate-escaped (key, value) tuples.
"""
return urllib.parse.parse_qsl(s, keep_blank_values=True)
if six.PY2:
return urllib.parse.parse_qsl(s, keep_blank_values=True)
else:
return urllib.parse.parse_qsl(s, keep_blank_values=True, errors='surrogateescape')
def quote(b, safe="/"):
"""
Returns:
An ascii-encodable str.
"""
# type: (str) -> str
if six.PY2:
return urllib.parse.quote(b, safe=safe)
else:
return urllib.parse.quote(b, safe=safe, errors="surrogateescape")
def unquote(s):
"""
Args:
s: A surrogate-escaped str
Returns:
A surrogate-escaped str
"""
# type: (str) -> str
if six.PY2:
return urllib.parse.unquote(s)
else:
return urllib.parse.unquote(s, errors="surrogateescape")
def hostport(scheme, host, port):

View File

@ -79,9 +79,6 @@ class _MultiDict(MutableMapping, basetypes.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self.fields)
def get_all(self, key):
"""
Return the list of all values for a given key.
@ -241,6 +238,9 @@ class ImmutableMultiDict(MultiDict):
__delitem__ = set_all = insert = _immutable
def __hash__(self):
return hash(self.fields)
def with_delitem(self, key):
"""
Returns:

View File

@ -51,8 +51,7 @@ else:
def escape_control_characters(text, keep_spacing=True):
"""
Replace all unicode C1 control characters from the given text with their respective control pictures.
For example, a null byte is replaced with the unicode character "\u2400".
Replace all unicode C1 control characters from the given text with a single "."
Args:
keep_spacing: If True, tabs and newlines will not be replaced.
@ -99,6 +98,9 @@ def bytes_to_escaped_str(data, keep_spacing=False):
def escaped_str_to_bytes(data):
"""
Take an escaped string and return the unescaped bytes equivalent.
Raises:
ValueError, if the escape sequence is invalid.
"""
if not isinstance(data, six.string_types):
if six.PY2:

View File

@ -8,9 +8,10 @@ from mitmproxy import options
class TestAntiCache(mastertest.MasterTest):
def test_simple(self):
s = state.State()
m = master.FlowMaster(options.Options(anticache = True), None, s)
o = options.Options(anticache = True)
m = master.FlowMaster(o, None, s)
sa = anticache.AntiCache()
m.addons.add(sa)
m.addons.add(o, sa)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)

View File

@ -8,9 +8,10 @@ from mitmproxy import options
class TestAntiComp(mastertest.MasterTest):
def test_simple(self):
s = state.State()
m = master.FlowMaster(options.Options(anticomp = True), None, s)
o = options.Options(anticomp = True)
m = master.FlowMaster(o, None, s)
sa = anticomp.AntiComp()
m.addons.add(sa)
m.addons.add(o, sa)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)

View File

@ -15,26 +15,27 @@ class TestDumper(mastertest.MasterTest):
d = dumper.Dumper()
sio = StringIO()
d.configure(dump.Options(tfile = sio, flow_detail = 0))
updated = set(["tfile", "flow_detail"])
d.configure(dump.Options(tfile = sio, flow_detail = 0), updated)
d.response(tutils.tflow())
assert not sio.getvalue()
d.configure(dump.Options(tfile = sio, flow_detail = 4))
d.configure(dump.Options(tfile = sio, flow_detail = 4), updated)
d.response(tutils.tflow())
assert sio.getvalue()
sio = StringIO()
d.configure(dump.Options(tfile = sio, flow_detail = 4))
d.configure(dump.Options(tfile = sio, flow_detail = 4), updated)
d.response(tutils.tflow(resp=True))
assert "<<" in sio.getvalue()
sio = StringIO()
d.configure(dump.Options(tfile = sio, flow_detail = 4))
d.configure(dump.Options(tfile = sio, flow_detail = 4), updated)
d.response(tutils.tflow(err=True))
assert "<<" in sio.getvalue()
sio = StringIO()
d.configure(dump.Options(tfile = sio, flow_detail = 4))
d.configure(dump.Options(tfile = sio, flow_detail = 4), updated)
flow = tutils.tflow()
flow.request = netlib.tutils.treq()
flow.request.stickycookie = True
@ -47,7 +48,7 @@ class TestDumper(mastertest.MasterTest):
assert sio.getvalue()
sio = StringIO()
d.configure(dump.Options(tfile = sio, flow_detail = 4))
d.configure(dump.Options(tfile = sio, flow_detail = 4), updated)
flow = tutils.tflow(resp=netlib.tutils.tresp(content=b"{"))
flow.response.headers["content-type"] = "application/json"
flow.response.status_code = 400
@ -55,7 +56,7 @@ class TestDumper(mastertest.MasterTest):
assert sio.getvalue()
sio = StringIO()
d.configure(dump.Options(tfile = sio))
d.configure(dump.Options(tfile = sio), updated)
flow = tutils.tflow()
flow.request.content = None
flow.response = models.HTTPResponse.wrap(netlib.tutils.tresp())
@ -72,15 +73,13 @@ class TestContentView(mastertest.MasterTest):
s = state.State()
sio = StringIO()
m = mastertest.RecordingMaster(
dump.Options(
flow_detail=4,
verbosity=3,
tfile=sio,
),
None, s
o = dump.Options(
flow_detail=4,
verbosity=3,
tfile=sio,
)
m = mastertest.RecordingMaster(o, None, s)
d = dumper.Dumper()
m.addons.add(d)
m.addons.add(o, d)
self.invoke(m, "response", tutils.tflow())
assert "Content viewer failed" in m.event_log[0][1]

View File

@ -20,16 +20,13 @@ class TestStream(mastertest.MasterTest):
return list(r.stream())
s = state.State()
m = master.FlowMaster(
options.Options(
outfile = (p, "wb")
),
None,
s
o = options.Options(
outfile = (p, "wb")
)
m = master.FlowMaster(o, None, s)
sa = filestreamer.FileStreamer()
m.addons.add(sa)
m.addons.add(o, sa)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
self.invoke(m, "response", f)
@ -39,7 +36,7 @@ class TestStream(mastertest.MasterTest):
m.options.outfile = (p, "ab")
m.addons.add(sa)
m.addons.add(o, sa)
f = tutils.tflow()
self.invoke(m, "request", f)
m.addons.remove(sa)

View File

@ -8,38 +8,38 @@ from mitmproxy import options
class TestReplace(mastertest.MasterTest):
def test_configure(self):
r = replace.Replace()
updated = set(["replacements"])
r.configure(options.Options(
replacements=[("one", "two", "three")]
))
), updated)
tutils.raises(
"invalid filter pattern",
r.configure,
options.Options(
replacements=[("~b", "two", "three")]
)
),
updated
)
tutils.raises(
"invalid regular expression",
r.configure,
options.Options(
replacements=[("foo", "+", "three")]
)
),
updated
)
def test_simple(self):
s = state.State()
m = master.FlowMaster(
options.Options(
replacements = [
("~q", "foo", "bar"),
("~s", "foo", "bar"),
]
),
None,
s
o = options.Options(
replacements = [
("~q", "foo", "bar"),
("~s", "foo", "bar"),
]
)
m = master.FlowMaster(o, None, s)
sa = replace.Replace()
m.addons.add(sa)
m.addons.add(o, sa)
f = tutils.tflow()
f.request.content = b"foo"

View File

@ -48,39 +48,41 @@ def test_load_script():
"data/addonscripts/recorder.py"
), []
)
assert ns["configure"]
assert ns.start
class TestScript(mastertest.MasterTest):
def test_simple(self):
s = state.State()
m = master.FlowMaster(options.Options(), None, s)
o = options.Options()
m = master.FlowMaster(o, None, s)
sc = script.Script(
tutils.test_data.path(
"data/addonscripts/recorder.py"
)
)
m.addons.add(sc)
assert sc.ns["call_log"] == [
m.addons.add(o, sc)
assert sc.ns.call_log == [
("solo", "start", (), {}),
("solo", "configure", (options.Options(),), {})
("solo", "configure", (o, o.keys()), {})
]
sc.ns["call_log"] = []
sc.ns.call_log = []
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
recf = sc.ns["call_log"][0]
recf = sc.ns.call_log[0]
assert recf[1] == "request"
def test_reload(self):
s = state.State()
m = mastertest.RecordingMaster(options.Options(), None, s)
o = options.Options()
m = mastertest.RecordingMaster(o, None, s)
with tutils.tmpdir():
with open("foo.py", "w"):
pass
sc = script.Script("foo.py")
m.addons.add(sc)
m.addons.add(o, sc)
for _ in range(100):
with open("foo.py", "a") as f:
@ -93,19 +95,22 @@ class TestScript(mastertest.MasterTest):
def test_exception(self):
s = state.State()
m = mastertest.RecordingMaster(options.Options(), None, s)
o = options.Options()
m = mastertest.RecordingMaster(o, None, s)
sc = script.Script(
tutils.test_data.path("data/addonscripts/error.py")
)
m.addons.add(sc)
m.addons.add(o, sc)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
assert m.event_log[0][0] == "error"
def test_duplicate_flow(self):
s = state.State()
fm = master.FlowMaster(None, None, s)
o = options.Options()
fm = master.FlowMaster(o, None, s)
fm.addons.add(
o,
script.Script(
tutils.test_data.path("data/addonscripts/duplicate_flow.py")
)
@ -116,6 +121,20 @@ class TestScript(mastertest.MasterTest):
assert not fm.state.view[0].request.is_replay
assert fm.state.view[1].request.is_replay
def test_addon(self):
s = state.State()
o = options.Options()
m = master.FlowMaster(o, None, s)
sc = script.Script(
tutils.test_data.path(
"data/addonscripts/addon.py"
)
)
m.addons.add(o, sc)
assert sc.ns.event_log == [
'scriptstart', 'addonstart', 'addonconfigure'
]
class TestScriptLoader(mastertest.MasterTest):
def test_simple(self):
@ -123,7 +142,7 @@ class TestScriptLoader(mastertest.MasterTest):
o = options.Options(scripts=[])
m = master.FlowMaster(o, None, s)
sc = script.ScriptLoader()
m.addons.add(sc)
m.addons.add(o, sc)
assert len(m.addons) == 1
o.update(
scripts = [
@ -139,7 +158,7 @@ class TestScriptLoader(mastertest.MasterTest):
o = options.Options(scripts=["one", "one"])
m = master.FlowMaster(o, None, s)
sc = script.ScriptLoader()
tutils.raises(exceptions.OptionsError, m.addons.add, sc)
tutils.raises(exceptions.OptionsError, m.addons.add, o, sc)
def test_order(self):
rec = tutils.test_data.path("data/addonscripts/recorder.py")
@ -154,7 +173,7 @@ class TestScriptLoader(mastertest.MasterTest):
)
m = mastertest.RecordingMaster(o, None, s)
sc = script.ScriptLoader()
m.addons.add(sc)
m.addons.add(o, sc)
debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"]
assert debug == [

View File

@ -8,19 +8,20 @@ from mitmproxy import options
class TestSetHeaders(mastertest.MasterTest):
def mkmaster(self, **opts):
s = state.State()
m = mastertest.RecordingMaster(options.Options(**opts), None, s)
o = options.Options(**opts)
m = mastertest.RecordingMaster(o, None, s)
sh = setheaders.SetHeaders()
m.addons.add(sh)
m.addons.add(o, sh)
return m, sh
def test_configure(self):
sh = setheaders.SetHeaders()
o = options.Options(
setheaders = [("~b", "one", "two")]
)
tutils.raises(
"invalid setheader filter pattern",
sh.configure,
options.Options(
setheaders = [("~b", "one", "two")]
)
sh.configure, o, o.keys()
)
def test_setheaders(self):

View File

@ -8,9 +8,10 @@ from mitmproxy import options
class TestStickyAuth(mastertest.MasterTest):
def test_simple(self):
s = state.State()
m = master.FlowMaster(options.Options(stickyauth = ".*"), None, s)
o = options.Options(stickyauth = ".*")
m = master.FlowMaster(o, None, s)
sa = stickyauth.StickyAuth()
m.addons.add(sa)
m.addons.add(o, sa)
f = tutils.tflow(resp=True)
f.request.headers["authorization"] = "foo"

View File

@ -14,22 +14,23 @@ def test_domain_match():
class TestStickyCookie(mastertest.MasterTest):
def mk(self):
s = state.State()
m = master.FlowMaster(options.Options(stickycookie = ".*"), None, s)
o = options.Options(stickycookie = ".*")
m = master.FlowMaster(o, None, s)
sc = stickycookie.StickyCookie()
m.addons.add(sc)
m.addons.add(o, sc)
return s, m, sc
def test_config(self):
sc = stickycookie.StickyCookie()
o = options.Options(stickycookie = "~b")
tutils.raises(
"invalid filter",
sc.configure,
options.Options(stickycookie = "~b")
sc.configure, o, o.keys()
)
def test_simple(self):
s, m, sc = self.mk()
m.addons.add(sc)
m.addons.add(m.options, sc)
f = tutils.tflow(resp=True)
f.response.headers["set-cookie"] = "foo=bar"

View File

@ -0,0 +1,22 @@
event_log = []
class Addon:
@property
def event_log(self):
return event_log
def start(self):
event_log.append("addonstart")
def configure(self, options, updated):
event_log.append("addonconfigure")
def configure(options, updated):
event_log.append("addonconfigure")
def start():
event_log.append("scriptstart")
return Addon()

View File

@ -2,24 +2,24 @@ from mitmproxy import controller
from mitmproxy import ctx
import sys
call_log = []
if len(sys.argv) > 1:
name = sys.argv[1]
else:
name = "solo"
class CallLogger:
call_log = []
# Keep a log of all possible event calls
evts = list(controller.Events) + ["configure"]
for i in evts:
def mkprox():
evt = i
def __init__(self, name = "solo"):
self.name = name
def prox(*args, **kwargs):
lg = (name, evt, args, kwargs)
if evt != "log":
ctx.log.info(str(lg))
call_log.append(lg)
ctx.log.debug("%s %s" % (name, evt))
return prox
globals()[i] = mkprox()
def __getattr__(self, attr):
if attr in controller.Events:
def prox(*args, **kwargs):
lg = (self.name, attr, args, kwargs)
if attr != "log":
ctx.log.info(str(lg))
self.call_log.append(lg)
ctx.log.debug("%s %s" % (self.name, attr))
return prox
raise AttributeError
def start():
return CallLogger(*sys.argv[1:])

Binary file not shown.

View File

@ -23,7 +23,7 @@ class TestConcurrent(mastertest.MasterTest):
"data/addonscripts/concurrent_decorator.py"
)
)
m.addons.add(sc)
m.addons.add(m.options, sc)
f1, f2 = tutils.tflow(), tutils.tflow()
self.invoke(m, "request", f1)
self.invoke(m, "request", f2)

View File

@ -13,8 +13,9 @@ class TAddon:
def test_simple():
m = controller.Master(options.Options())
o = options.Options()
m = controller.Master(o)
a = addons.Addons(m)
a.add(TAddon("one"))
a.add(o, TAddon("one"))
assert a.has_addon("one")
assert not a.has_addon("two")

View File

@ -59,10 +59,10 @@ class TestContentView:
assert f[0] == "Query"
def test_view_urlencoded(self):
d = url.encode([("one", "two"), ("three", "four")])
d = url.encode([("one", "two"), ("three", "four")]).encode()
v = cv.ViewURLEncoded()
assert v(d)
d = url.encode([("adsfa", "")])
d = url.encode([("adsfa", "")]).encode()
v = cv.ViewURLEncoded()
assert v(d)

View File

@ -27,10 +27,11 @@ class RaiseMaster(master.FlowMaster):
def tscript(cmd, args=""):
o = options.Options()
cmd = example_dir.path(cmd) + " " + args
m = RaiseMaster(options.Options(), None, state.State())
m = RaiseMaster(o, None, state.State())
sc = script.Script(cmd)
m.addons.add(sc)
m.addons.add(o, sc)
return m, sc

View File

@ -615,6 +615,7 @@ class TestSerialize:
def test_roundtrip(self):
sio = io.BytesIO()
f = tutils.tflow()
f.marked = True
f.request.content = bytes(bytearray(range(256)))
w = flow.FlowWriter(sio)
w.add(f)
@ -627,6 +628,7 @@ class TestSerialize:
f2 = l[0]
assert f2.get_state() == f.get_state()
assert f2.request == f.request
assert f2.marked
def test_load_flows(self):
r = self._treader()

View File

@ -15,6 +15,8 @@ class TO(optmanager.OptManager):
def test_options():
o = TO(two="three")
assert o.keys() == set(["one", "two"])
assert o.one is None
assert o.two == "three"
o.one = "one"
@ -29,7 +31,7 @@ def test_options():
rec = []
def sub(opts):
def sub(opts, updated):
rec.append(copy.copy(opts))
o.changed.connect(sub)
@ -68,7 +70,7 @@ def test_rollback():
rec = []
def sub(opts):
def sub(opts, updated):
rec.append(copy.copy(opts))
recerr = []
@ -76,7 +78,7 @@ def test_rollback():
def errsub(opts, **kwargs):
recerr.append(kwargs)
def err(opts):
def err(opts, updated):
if opts.one == "ten":
raise exceptions.OptionsError()

View File

@ -30,7 +30,7 @@ logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING)
requires_alpn = pytest.mark.skipif(
not netlib.tcp.HAS_ALPN,
reason="requires OpenSSL with ALPN support")
reason='requires OpenSSL with ALPN support')
class _Http2ServerBase(netlib_tservers.ServerTestBase):
@ -80,7 +80,7 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
print(traceback.format_exc())
break
def handle_server_event(self, h2_conn, rfile, wfile):
def handle_server_event(self, event, h2_conn, rfile, wfile):
raise NotImplementedError()
@ -88,7 +88,6 @@ class _Http2TestBase(object):
@classmethod
def setup_class(cls):
cls.masteroptions = options.Options()
opts = cls.get_options()
cls.config = ProxyConfig(opts)
@ -145,12 +144,14 @@ class _Http2TestBase(object):
wfile,
h2_conn,
stream_id=1,
headers=[],
headers=None,
body=b'',
end_stream=None,
priority_exclusive=None,
priority_depends_on=None,
priority_weight=None):
if headers is None:
headers = []
if end_stream is None:
end_stream = (len(body) == 0)
@ -172,12 +173,12 @@ class _Http2TestBase(object):
class _Http2Test(_Http2TestBase, _Http2ServerBase):
@classmethod
def setup_class(self):
def setup_class(cls):
_Http2TestBase.setup_class()
_Http2ServerBase.setup_class()
@classmethod
def teardown_class(self):
def teardown_class(cls):
_Http2TestBase.teardown_class()
_Http2ServerBase.teardown_class()
@ -187,7 +188,7 @@ class TestSimple(_Http2Test):
request_body_buffer = b''
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
@ -214,7 +215,7 @@ class TestSimple(_Http2Test):
wfile.write(h2_conn.data_to_send())
wfile.flush()
elif isinstance(event, h2.events.DataReceived):
self.request_body_buffer += event.data
cls.request_body_buffer += event.data
return True
def test_simple(self):
@ -225,7 +226,7 @@ class TestSimple(_Http2Test):
client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@ -269,7 +270,7 @@ class TestSimple(_Http2Test):
class TestRequestWithPriority(_Http2Test):
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
@ -301,14 +302,14 @@ class TestRequestWithPriority(_Http2Test):
client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
],
priority_exclusive = True,
priority_depends_on = 42424242,
priority_weight = 42,
priority_exclusive=True,
priority_depends_on=42424242,
priority_weight=42,
)
done = False
@ -343,7 +344,7 @@ class TestRequestWithPriority(_Http2Test):
client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@ -381,11 +382,11 @@ class TestPriority(_Http2Test):
priority_data = None
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.PriorityUpdated):
self.priority_data = (event.exclusive, event.depends_on, event.weight)
cls.priority_data = (event.exclusive, event.depends_on, event.weight)
elif isinstance(event, h2.events.RequestReceived):
import warnings
with warnings.catch_warnings():
@ -415,7 +416,7 @@ class TestPriority(_Http2Test):
client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@ -451,11 +452,11 @@ class TestPriorityWithExistingStream(_Http2Test):
priority_data = []
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.PriorityUpdated):
self.priority_data.append((event.exclusive, event.depends_on, event.weight))
cls.priority_data.append((event.exclusive, event.depends_on, event.weight))
elif isinstance(event, h2.events.RequestReceived):
assert not event.priority_updated
@ -486,7 +487,7 @@ class TestPriorityWithExistingStream(_Http2Test):
client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@ -527,7 +528,7 @@ class TestPriorityWithExistingStream(_Http2Test):
class TestStreamResetFromServer(_Http2Test):
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
@ -543,7 +544,7 @@ class TestStreamResetFromServer(_Http2Test):
client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@ -578,7 +579,7 @@ class TestStreamResetFromServer(_Http2Test):
class TestBodySizeLimit(_Http2Test):
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
return True
@ -592,7 +593,7 @@ class TestBodySizeLimit(_Http2Test):
client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@ -627,7 +628,7 @@ class TestBodySizeLimit(_Http2Test):
class TestPushPromise(_Http2Test):
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
@ -637,14 +638,14 @@ class TestPushPromise(_Http2Test):
h2_conn.send_headers(1, [(':status', '200')])
h2_conn.push_stream(1, 2, [
(':authority', "127.0.0.1:%s" % self.port),
(':authority', "127.0.0.1:{}".format(cls.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/pushed_stream_foo'),
('foo', 'bar')
])
h2_conn.push_stream(1, 4, [
(':authority', "127.0.0.1:%s" % self.port),
(':authority', "127.0.0.1:{}".format(cls.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/pushed_stream_bar'),
@ -675,7 +676,7 @@ class TestPushPromise(_Http2Test):
client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@ -728,7 +729,7 @@ class TestPushPromise(_Http2Test):
client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@ -780,7 +781,7 @@ class TestPushPromise(_Http2Test):
class TestConnectionLost(_Http2Test):
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.RequestReceived):
h2_conn.send_headers(1, [(':status', '200')])
wfile.write(h2_conn.data_to_send())
@ -791,7 +792,7 @@ class TestConnectionLost(_Http2Test):
client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@ -822,12 +823,12 @@ class TestConnectionLost(_Http2Test):
class TestMaxConcurrentStreams(_Http2Test):
@classmethod
def setup_class(self):
def setup_class(cls):
_Http2TestBase.setup_class()
_Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2})
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
@ -848,7 +849,7 @@ class TestMaxConcurrentStreams(_Http2Test):
# this will exceed MAX_CONCURRENT_STREAMS on the server connection
# and cause mitmproxy to throttle stream creation to the server
self._send_request(client.wfile, h2_conn, stream_id=id, headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@ -883,7 +884,7 @@ class TestMaxConcurrentStreams(_Http2Test):
class TestConnectionTerminated(_Http2Test):
@classmethod
def handle_server_event(self, event, h2_conn, rfile, wfile):
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.RequestReceived):
h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=b'foobar')
wfile.write(h2_conn.data_to_send())
@ -894,7 +895,7 @@ class TestConnectionTerminated(_Http2Test):
client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, headers=[
(':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),

View File

@ -291,7 +291,7 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):
s = script.Script(
tutils.test_data.path("data/addonscripts/stream_modify.py")
)
self.master.addons.add(s)
self.master.addons.add(self.master.options, s)
d = self.pathod('200:b"foo"')
assert d.content == b"bar"
self.master.addons.remove(s)
@ -523,7 +523,7 @@ class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin):
s = script.Script(
tutils.test_data.path("data/addonscripts/tcp_stream_modify.py")
)
self.master.addons.add(s)
self.master.addons.add(self.master.options, s)
self._tcpproxy_on()
d = self.pathod('200:b"foo"')
self._tcpproxy_off()

View File

@ -34,7 +34,7 @@ class TestMaster(flow.FlowMaster):
s = ProxyServer(config)
state = flow.State()
flow.FlowMaster.__init__(self, opts, s, state)
self.addons.add(*builtins.default_addons())
self.addons.add(opts, *builtins.default_addons())
self.apps.add(testapp, "testapp", 80)
self.apps.add(errapp, "errapp", 80)
self.clear_log()

View File

@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function, division
import mock
import six
from netlib.tutils import tresp
@ -71,10 +70,6 @@ class TestMessage(object):
assert resp != 0
def test_hash(self):
resp = tresp()
assert hash(resp)
def test_serializable(self):
resp = tresp()
resp2 = http.Response.from_state(resp.get_state())
@ -117,14 +112,6 @@ class TestMessageContentEncoding(object):
assert r.content == b"message"
assert r.raw_content != b"message"
r.raw_content = b"foo"
with mock.patch("netlib.encoding.decode") as e:
assert r.content
assert e.call_count == 1
e.reset_mock()
assert r.content
assert e.call_count == 0
def test_modify(self):
r = tresp()
assert "content-encoding" not in r.headers
@ -135,13 +122,6 @@ class TestMessageContentEncoding(object):
r.decode()
assert r.raw_content == b"foo"
r.encode("identity")
with mock.patch("netlib.encoding.encode") as e:
r.content = b"foo"
assert e.call_count == 0
r.content = b"bar"
assert e.call_count == 1
with tutils.raises(TypeError):
r.content = u"foo"
@ -216,15 +196,6 @@ class TestMessageText(object):
r.headers["content-type"] = "text/html; charset=utf8"
assert r.text == u"ü"
r.encode("identity")
r.raw_content = b"foo"
with mock.patch("netlib.encoding.decode") as e:
assert r.text
assert e.call_count == 2
e.reset_mock()
assert r.text
assert e.call_count == 0
def test_guess_json(self):
r = tresp(content=b'"\xc3\xbc"')
r.headers["content-type"] = "application/json"
@ -249,14 +220,6 @@ class TestMessageText(object):
assert r.raw_content == b"\xc3\xbc"
assert r.headers["content-length"] == "2"
r.encode("identity")
with mock.patch("netlib.encoding.encode") as e:
e.return_value = b""
r.text = u"ü"
assert e.call_count == 0
r.text = u"ä"
assert e.call_count == 2
def test_unknown_ce(self):
r = tresp()
r.headers["content-type"] = "text/html; charset=wtf"

View File

@ -1,3 +1,4 @@
import six
from netlib import tutils
from netlib.http import url
@ -57,10 +58,49 @@ def test_unparse():
assert url.unparse("https", "foo.com", 443, "") == "https://foo.com"
def test_urlencode():
if six.PY2:
surrogates = bytes(bytearray(range(256)))
else:
surrogates = bytes(range(256)).decode("utf8", "surrogateescape")
surrogates_quoted = (
'%00%01%02%03%04%05%06%07%08%09%0A%0B%0C%0D%0E%0F'
'%10%11%12%13%14%15%16%17%18%19%1A%1B%1C%1D%1E%1F'
'%20%21%22%23%24%25%26%27%28%29%2A%2B%2C-./'
'0123456789%3A%3B%3C%3D%3E%3F'
'%40ABCDEFGHIJKLMNO'
'PQRSTUVWXYZ%5B%5C%5D%5E_'
'%60abcdefghijklmno'
'pqrstuvwxyz%7B%7C%7D%7E%7F'
'%80%81%82%83%84%85%86%87%88%89%8A%8B%8C%8D%8E%8F'
'%90%91%92%93%94%95%96%97%98%99%9A%9B%9C%9D%9E%9F'
'%A0%A1%A2%A3%A4%A5%A6%A7%A8%A9%AA%AB%AC%AD%AE%AF'
'%B0%B1%B2%B3%B4%B5%B6%B7%B8%B9%BA%BB%BC%BD%BE%BF'
'%C0%C1%C2%C3%C4%C5%C6%C7%C8%C9%CA%CB%CC%CD%CE%CF'
'%D0%D1%D2%D3%D4%D5%D6%D7%D8%D9%DA%DB%DC%DD%DE%DF'
'%E0%E1%E2%E3%E4%E5%E6%E7%E8%E9%EA%EB%EC%ED%EE%EF'
'%F0%F1%F2%F3%F4%F5%F6%F7%F8%F9%FA%FB%FC%FD%FE%FF'
)
def test_encode():
assert url.encode([('foo', 'bar')])
assert url.encode([('foo', surrogates)])
def test_urldecode():
def test_decode():
s = "one=two&three=four"
assert len(url.decode(s)) == 2
assert url.decode(surrogates)
def test_quote():
assert url.quote("foo") == "foo"
assert url.quote("foo bar") == "foo%20bar"
assert url.quote(surrogates) == surrogates_quoted
def test_unquote():
assert url.unquote("foo") == "foo"
assert url.unquote("foo%20bar") == "foo bar"
assert url.unquote(surrogates_quoted) == surrogates

View File

@ -1,3 +1,4 @@
import mock
from netlib import encoding, tutils
@ -37,3 +38,32 @@ def test_deflate():
)
with tutils.raises(ValueError):
encoding.decode(b"bogus", "deflate")
def test_cache():
decode_gzip = mock.MagicMock()
decode_gzip.return_value = b"decoded"
encode_gzip = mock.MagicMock()
encode_gzip.return_value = b"encoded"
with mock.patch.dict(encoding.custom_decode, gzip=decode_gzip):
with mock.patch.dict(encoding.custom_encode, gzip=encode_gzip):
assert encoding.decode(b"encoded", "gzip") == b"decoded"
assert decode_gzip.call_count == 1
# should be cached
assert encoding.decode(b"encoded", "gzip") == b"decoded"
assert decode_gzip.call_count == 1
# the other way around as well
assert encoding.encode(b"decoded", "gzip") == b"encoded"
assert encode_gzip.call_count == 0
# different encoding
decode_gzip.return_value = b"bar"
assert encoding.encode(b"decoded", "deflate") != b"decoded"
assert encode_gzip.call_count == 0
# This is not in the cache anymore
assert encoding.encode(b"decoded", "gzip") == b"encoded"
assert encode_gzip.call_count == 1

View File

@ -45,7 +45,7 @@ class TestMultiDict(object):
assert md["foo"] == "bar"
with tutils.raises(KeyError):
md["bar"]
assert md["bar"]
md_multi = TMultiDict(
[("foo", "a"), ("foo", "b")]
@ -101,6 +101,15 @@ class TestMultiDict(object):
assert TMultiDict() != self._multi()
assert TMultiDict() != 42
def test_hash(self):
"""
If a class defines mutable objects and implements an __eq__() method,
it should not implement __hash__(), since the implementation of hashable
collections requires that a key's hash value is immutable.
"""
with tutils.raises(TypeError):
assert hash(TMultiDict())
def test_get_all(self):
md = self._multi()
assert md.get_all("foo") == ["bar"]
@ -197,6 +206,9 @@ class TestImmutableMultiDict(object):
with tutils.raises(TypeError):
md.add("foo", "bar")
def test_hash(self):
assert hash(TImmutableMultiDict())
def test_with_delitem(self):
md = TImmutableMultiDict([("foo", "bar")])
assert md.with_delitem("foo").fields == ()

View File

@ -1,6 +1,7 @@
import React, { PropTypes } from 'react'
import classnames from 'classnames'
import columns from './FlowColumns'
import { pure } from '../../utils'
FlowRow.propTypes = {
onSelect: PropTypes.func.isRequired,
@ -9,7 +10,7 @@ FlowRow.propTypes = {
selected: PropTypes.bool,
}
export default function FlowRow({ flow, selected, highlighted, onSelect }) {
function FlowRow({ flow, selected, highlighted, onSelect }) {
const className = classnames({
'selected': selected,
'highlighted': highlighted,
@ -19,10 +20,12 @@ export default function FlowRow({ flow, selected, highlighted, onSelect }) {
})
return (
<tr className={className} onClick={() => onSelect(flow)}>
<tr className={className} onClick={() => onSelect(flow.id)}>
{columns.map(Column => (
<Column key={Column.name} flow={flow}/>
))}
</tr>
)
}
export default pure(FlowRow)

View File

@ -22,7 +22,7 @@ class MainView extends Component {
flows={flows}
selected={selectedFlow}
highlight={highlight}
onSelect={flow => this.props.selectFlow(flow.id)}
onSelect={this.props.selectFlow}
/>
{selectedFlow && [
<Splitter key="splitter"/>,

View File

@ -1,7 +1,9 @@
import _ from "lodash";
import _ from 'lodash'
import React from 'react'
import shallowEqual from 'shallowequal'
window._ = _;
window.React = require("react");
window.React = React;
export var Key = {
UP: 38,
@ -106,15 +108,27 @@ fetchApi.put = (url, json, options) => fetchApi(
}
)
export function getDiff(obj1, obj2) {
let result = {...obj2};
for(let key in obj1) {
if(_.isEqual(obj2[key], obj1[key]))
result[key] = undefined;
result[key] = undefined
else if(!(Array.isArray(obj2[key]) && Array.isArray(obj1[key])) &&
typeof obj2[key] == 'object' && typeof obj1[key] == 'object')
result[key] = getDiff(obj1[key], obj2[key]);
result[key] = getDiff(obj1[key], obj2[key])
}
return result
}
export const pure = renderFn => class extends React.Component {
static displayName = renderFn.name
shouldComponentUpdate(nextProps) {
console.log(!shallowEqual(this.props, nextProps))
return !shallowEqual(this.props, nextProps)
}
render() {
return renderFn(this.props)
}
return result;
}