Merge pull request #3056 from cortesi/readfile

readfile fixes
This commit is contained in:
Aldo Cortesi 2018-04-15 10:03:30 +12:00 committed by GitHub
commit 4e126c0fba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 102 additions and 57 deletions

View File

@ -1,9 +1,11 @@
import asyncio
import os.path import os.path
import sys import sys
import typing import typing
from mitmproxy import ctx from mitmproxy import ctx
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import flowfilter
from mitmproxy import io from mitmproxy import io
@ -11,18 +13,38 @@ class ReadFile:
""" """
An addon that handles reading from file on startup. An addon that handles reading from file on startup.
""" """
def __init__(self):
self.filter = None
def load(self, loader): def load(self, loader):
loader.add_option( loader.add_option(
"rfile", typing.Optional[str], None, "rfile", typing.Optional[str], None,
"Read flows from file." "Read flows from file."
) )
loader.add_option(
"readfile_filter", typing.Optional[str], None,
"Read only matching flows."
)
def load_flows(self, fo: typing.IO[bytes]) -> int: def configure(self, updated):
if "readfile_filter" in updated:
filt = None
if ctx.options.readfile_filter:
filt = flowfilter.parse(ctx.options.readfile_filter)
if not filt:
raise exceptions.OptionsError(
"Invalid readfile filter: %s" % ctx.options.readfile_filter
)
self.filter = filt
async def load_flows(self, fo: typing.IO[bytes]) -> int:
cnt = 0 cnt = 0
freader = io.FlowReader(fo) freader = io.FlowReader(fo)
try: try:
for flow in freader.stream(): for flow in freader.stream():
ctx.master.load_flow(flow) if self.filter and not self.filter(flow):
continue
await ctx.master.load_flow(flow)
cnt += 1 cnt += 1
except (IOError, exceptions.FlowReadException) as e: except (IOError, exceptions.FlowReadException) as e:
if cnt: if cnt:
@ -33,29 +55,34 @@ class ReadFile:
else: else:
return cnt return cnt
def load_flows_from_path(self, path: str) -> int: async def load_flows_from_path(self, path: str) -> int:
path = os.path.expanduser(path) path = os.path.expanduser(path)
try: try:
with open(path, "rb") as f: with open(path, "rb") as f:
return self.load_flows(f) return await self.load_flows(f)
except IOError as e: except IOError as e:
ctx.log.error("Cannot load flows: {}".format(e)) ctx.log.error("Cannot load flows: {}".format(e))
raise exceptions.FlowReadException(str(e)) from e raise exceptions.FlowReadException(str(e)) from e
async def doread(self, rfile):
try:
await self.load_flows_from_path(ctx.options.rfile)
except exceptions.FlowReadException as e:
raise exceptions.OptionsError(e) from e
finally:
ctx.master.addons.trigger("processing_complete")
def running(self): def running(self):
if ctx.options.rfile: if ctx.options.rfile:
try: asyncio.get_event_loop().create_task(self.doread(ctx.options.rfile))
self.load_flows_from_path(ctx.options.rfile)
except exceptions.FlowReadException as e:
raise exceptions.OptionsError(e) from e
finally:
ctx.master.addons.trigger("processing_complete")
class ReadFileStdin(ReadFile): class ReadFileStdin(ReadFile):
"""Support the special case of "-" for reading from stdin""" """Support the special case of "-" for reading from stdin"""
def load_flows_from_path(self, path: str) -> int: async def load_flows_from_path(self, path: str) -> int:
if path == "-": if path == "-": # pragma: no cover
return self.load_flows(sys.stdin.buffer) # Need to think about how to test this. This function is scheduled
# onto the event loop, where a sys.stdin mock has no effect.
return await self.load_flows(sys.stdin.buffer)
else: else:
return super().load_flows_from_path(path) return await super().load_flows_from_path(path)

View File

@ -150,8 +150,9 @@ def mitmdump(args=None): # pragma: no cover
if args.filter_args: if args.filter_args:
v = " ".join(args.filter_args) v = " ".join(args.filter_args)
return dict( return dict(
view_filter=v,
save_stream_filter=v, save_stream_filter=v,
readfile_filter=v,
dumper_filter=v,
) )
return {} return {}

View File

@ -85,6 +85,7 @@ setup(
"pydivert>=2.0.3,<2.2", "pydivert>=2.0.3,<2.2",
], ],
'dev': [ 'dev': [
"asynctest>=0.12.0",
"flake8>=3.5, <3.6", "flake8>=3.5, <3.6",
"Flask>=0.10.1, <0.13", "Flask>=0.10.1, <0.13",
"mypy>=0.580,<0.581", "mypy>=0.580,<0.581",

View File

@ -1,7 +1,9 @@
import asyncio
import io import io
from unittest import mock from unittest import mock
import pytest import pytest
import asynctest
import mitmproxy.io import mitmproxy.io
from mitmproxy import exceptions from mitmproxy import exceptions
@ -38,69 +40,83 @@ def corrupt_data():
class TestReadFile: class TestReadFile:
@mock.patch('mitmproxy.master.Master.load_flow') def test_configure(self):
def test_configure(self, mck, tmpdir, data, corrupt_data): rf = readfile.ReadFile()
with taddons.context() as tctx:
tctx.configure(rf, readfile_filter="~q")
with pytest.raises(Exception, match="Invalid readfile filter"):
tctx.configure(rf, readfile_filter="~~")
@pytest.mark.asyncio
async def test_read(self, tmpdir, data, corrupt_data):
rf = readfile.ReadFile() rf = readfile.ReadFile()
with taddons.context(rf) as tctx: with taddons.context(rf) as tctx:
tf = tmpdir.join("tfile") tf = tmpdir.join("tfile")
tf.write(data.getvalue()) with asynctest.patch('mitmproxy.master.Master.load_flow') as mck:
tctx.configure(rf, rfile=str(tf)) tf.write(data.getvalue())
assert not mck.called tctx.configure(
rf.running() rf,
assert mck.called rfile = str(tf),
readfile_filter = ".*"
)
assert not mck.awaited
rf.running()
await asyncio.sleep(0)
assert mck.awaited
tf.write(corrupt_data.getvalue()) tf.write(corrupt_data.getvalue())
tctx.configure(rf, rfile=str(tf)) tctx.configure(rf, rfile=str(tf))
with pytest.raises(exceptions.OptionsError): rf.running()
rf.running() assert await tctx.master.await_log("corrupted")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_corrupt(self, corrupt_data): async def test_corrupt(self, corrupt_data):
rf = readfile.ReadFile() rf = readfile.ReadFile()
with taddons.context(rf) as tctx: with taddons.context(rf) as tctx:
with mock.patch('mitmproxy.master.Master.load_flow') as mck: with pytest.raises(exceptions.FlowReadException):
with pytest.raises(exceptions.FlowReadException): await rf.load_flows(io.BytesIO(b"qibble"))
rf.load_flows(io.BytesIO(b"qibble"))
assert not mck.called
tctx.master.clear() tctx.master.clear()
with pytest.raises(exceptions.FlowReadException): with pytest.raises(exceptions.FlowReadException):
rf.load_flows(corrupt_data) await rf.load_flows(corrupt_data)
assert await tctx.master.await_log("file corrupted") assert await tctx.master.await_log("file corrupted")
assert mck.called
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_nonexisting_file(self): async def test_nonexistent_file(self):
rf = readfile.ReadFile() rf = readfile.ReadFile()
with taddons.context(rf) as tctx: with taddons.context(rf) as tctx:
with pytest.raises(exceptions.FlowReadException): with pytest.raises(exceptions.FlowReadException):
rf.load_flows_from_path("nonexistent") await rf.load_flows_from_path("nonexistent")
assert await tctx.master.await_log("nonexistent") assert await tctx.master.await_log("nonexistent")
class TestReadFileStdin: class TestReadFileStdin:
@mock.patch('mitmproxy.master.Master.load_flow') @asynctest.patch('sys.stdin')
@mock.patch('sys.stdin') @pytest.mark.asyncio
def test_stdin(self, stdin, load_flow, data, corrupt_data): async def test_stdin(self, stdin, data, corrupt_data):
rf = readfile.ReadFileStdin()
with taddons.context(rf) as tctx:
stdin.buffer = data
tctx.configure(rf, rfile="-")
assert not load_flow.called
rf.running()
assert load_flow.called
stdin.buffer = corrupt_data
tctx.configure(rf, rfile="-")
with pytest.raises(exceptions.OptionsError):
rf.running()
@mock.patch('mitmproxy.master.Master.load_flow')
def test_normal(self, load_flow, tmpdir, data):
rf = readfile.ReadFileStdin() rf = readfile.ReadFileStdin()
with taddons.context(rf): with taddons.context(rf):
tfile = tmpdir.join("tfile") with asynctest.patch('mitmproxy.master.Master.load_flow') as mck:
tfile.write(data.getvalue()) stdin.buffer = data
rf.load_flows_from_path(str(tfile)) assert not mck.awaited
assert load_flow.called await rf.load_flows(stdin.buffer)
assert mck.awaited
stdin.buffer = corrupt_data
with pytest.raises(exceptions.FlowReadException):
await rf.load_flows(stdin.buffer)
@pytest.mark.asyncio
@mock.patch('mitmproxy.master.Master.load_flow')
async def test_normal(self, load_flow, tmpdir, data):
rf = readfile.ReadFileStdin()
with taddons.context(rf) as tctx:
tf = tmpdir.join("tfile")
with asynctest.patch('mitmproxy.master.Master.load_flow') as mck:
tf.write(data.getvalue())
tctx.configure(rf, rfile=str(tf))
assert not mck.awaited
rf.running()
await asyncio.sleep(0)
assert mck.awaited