Consolidate reading flows from file, use in mitmweb.

This commit is contained in:
Aldo Cortesi 2015-01-02 13:26:22 +13:00
parent 4d01e22f26
commit 1b5f5021dc
6 changed files with 60 additions and 30 deletions

View File

@ -599,13 +599,20 @@ class ConsoleMaster(flow.FlowMaster):
self.view_flowlist() self.view_flowlist()
self.server.start_slave(controller.Slave, controller.Channel(self.masterq, self.should_exit)) self.server.start_slave(
controller.Slave,
controller.Channel(self.masterq, self.should_exit)
)
if self.options.rfile: if self.options.rfile:
ret = self.load_flows(self.options.rfile) ret = self.load_flows_path(self.options.rfile)
if ret and self.state.flow_count(): if ret and self.state.flow_count():
self.add_event("File truncated or corrupted. Loaded as many flows as possible.","error") self.add_event(
elif not self.state.flow_count(): "File truncated or corrupted. "
"Loaded as many flows as possible.",
"error"
)
elif ret and not self.state.flow_count():
self.shutdown() self.shutdown()
print >> sys.stderr, "Could not load file:", ret print >> sys.stderr, "Could not load file:", ret
sys.exit(1) sys.exit(1)
@ -700,23 +707,16 @@ class ConsoleMaster(flow.FlowMaster):
def load_flows_callback(self, path): def load_flows_callback(self, path):
if not path: if not path:
return return
ret = self.load_flows(path) ret = self.load_flows_path(path)
return ret or "Flows loaded from %s"%path return ret or "Flows loaded from %s"%path
def load_flows(self, path): def load_flows_path(self, path):
self.state.last_saveload = path self.state.last_saveload = path
path = os.path.expanduser(path)
try:
f = file(path, "rb")
fr = flow.FlowReader(f)
except IOError, v:
return v.strerror
reterr = None reterr = None
try: try:
flow.FlowMaster.load_flows(self, fr) flow.FlowMaster.load_flows_file(self, path)
except flow.FlowReadError, v: except flow.FlowReadError, v:
reterr = v.strerror reterr = str(v)
f.close()
if self.flow_list_walker: if self.flow_list_walker:
self.sync_list_view() self.sync_list_view()
return reterr return reterr

View File

@ -134,16 +134,11 @@ class DumpMaster(flow.FlowMaster):
raise DumpError(err) raise DumpError(err)
if options.rfile: if options.rfile:
path = os.path.expanduser(options.rfile)
try: try:
f = file(path, "rb") self.load_flows_file(options.rfile)
freader = flow.FlowReader(f)
except IOError, v:
raise DumpError(v.strerror)
try:
self.load_flows(freader)
except flow.FlowReadError, v: except flow.FlowReadError, v:
self.add_event("Flow file corrupted. Stopped loading.", "error") self.add_event("Flow file corrupted.", "error")
raise DumpError(v)
if self.o.app: if self.o.app:
self.start_app(self.o.app_host, self.o.app_port) self.start_app(self.o.app_host, self.o.app_port)

View File

@ -6,6 +6,7 @@ from abc import abstractmethod, ABCMeta
import hashlib import hashlib
import Cookie import Cookie
import cookielib import cookielib
import os
import re import re
from netlib import odict, wsgi from netlib import odict, wsgi
import netlib.http import netlib.http
@ -785,8 +786,20 @@ class FlowMaster(controller.Master):
""" """
Load flows from a FlowReader object. Load flows from a FlowReader object.
""" """
cnt = 0
for i in fr.stream(): for i in fr.stream():
cnt += 1
self.load_flow(i) self.load_flow(i)
return cnt
def load_flows_file(self, path):
path = os.path.expanduser(path)
try:
f = file(path, "rb")
freader = FlowReader(f)
except IOError, v:
raise FlowReadError(v.strerror)
return self.load_flows(freader)
def process_new_request(self, f): def process_new_request(self, f):
if self.stickycookie_state: if self.stickycookie_state:
@ -961,7 +974,9 @@ class FlowReader:
data = tnetstring.load(self.fo) data = tnetstring.load(self.fo)
if tuple(data["version"][:2]) != version.IVERSION[:2]: if tuple(data["version"][:2]) != version.IVERSION[:2]:
v = ".".join(str(i) for i in data["version"]) v = ".".join(str(i) for i in data["version"])
raise FlowReadError("Incompatible serialized data version: %s" % v) raise FlowReadError(
"Incompatible serialized data version: %s" % v
)
off = self.fo.tell() off = self.fo.tell()
yield handle.protocols[data["type"]]["flow"].from_state(data) yield handle.protocols[data["type"]]["flow"].from_state(data)
except ValueError, v: except ValueError, v:

View File

@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function
import collections import collections
import tornado.ioloop import tornado.ioloop
import tornado.httpserver import tornado.httpserver
import os
from .. import controller, flow from .. import controller, flow
from . import app from . import app
@ -124,6 +125,14 @@ class WebMaster(flow.FlowMaster):
self.options = options self.options = options
super(WebMaster, self).__init__(server, WebState()) super(WebMaster, self).__init__(server, WebState())
self.app = app.Application(self, self.options.wdebug) self.app = app.Application(self, self.options.wdebug)
if options.rfile:
try:
print(self.load_flows_file(options.rfile))
except flow.FlowReadError, v:
self.add_event(
"Could not read flow file: %s"%v,
"error"
)
def tick(self): def tick(self):
flow.FlowMaster.tick(self, self.masterq, timeout=0) flow.FlowMaster.tick(self, self.masterq, timeout=0)

View File

@ -18,9 +18,12 @@ class RequestHandler(tornado.web.RequestHandler):
self.set_header("X-Frame-Options", "DENY") self.set_header("X-Frame-Options", "DENY")
self.add_header("X-XSS-Protection", "1; mode=block") self.add_header("X-XSS-Protection", "1; mode=block")
self.add_header("X-Content-Type-Options", "nosniff") self.add_header("X-Content-Type-Options", "nosniff")
self.add_header("Content-Security-Policy", "default-src 'self'; " self.add_header(
"connect-src 'self' ws://* ; " "Content-Security-Policy",
"style-src 'self' 'unsafe-inline'") "default-src 'self'; "
"connect-src 'self' ws://* ; "
"style-src 'self' 'unsafe-inline'"
)
@property @property
def state(self): def state(self):

View File

@ -99,15 +99,23 @@ class TestDumpMaster:
with tutils.tmpdir() as t: with tutils.tmpdir() as t:
p = os.path.join(t, "read") p = os.path.join(t, "read")
self._flowfile(p) self._flowfile(p)
assert "GET" in self._dummy_cycle(0, None, "", flow_detail=1, rfile=p) assert "GET" in self._dummy_cycle(
0,
None,
"",
flow_detail=1,
rfile=p
)
tutils.raises( tutils.raises(
dump.DumpError, self._dummy_cycle, dump.DumpError, self._dummy_cycle,
0, None, "", verbosity=1, rfile="/nonexistent" 0, None, "", verbosity=1, rfile="/nonexistent"
) )
tutils.raises(
dump.DumpError, self._dummy_cycle,
0, None, "", verbosity=1, rfile="test_dump.py"
)
# We now just ignore errors
self._dummy_cycle(0, None, "", verbosity=1, rfile=tutils.test_data.path("test_dump.py"))
def test_options(self): def test_options(self):
o = dump.Options(verbosity = 2) o = dump.Options(verbosity = 2)