asyncio: adjust readfile.py addon for async

This commit is contained in:
Aldo Cortesi 2018-04-14 10:24:03 +12:00
parent 660aa87a24
commit 214498f01c
2 changed files with 17 additions and 14 deletions

View File

@ -1,3 +1,4 @@
import asyncio
import os.path
import sys
import typing
@ -17,12 +18,12 @@ class ReadFile:
"Read flows from file."
)
def load_flows(self, fo: typing.IO[bytes]) -> int:
async def load_flows(self, fo: typing.IO[bytes]) -> int:
cnt = 0
freader = io.FlowReader(fo)
try:
for flow in freader.stream():
ctx.master.load_flow(flow)
await ctx.master.load_flow(flow)
cnt += 1
except (IOError, exceptions.FlowReadException) as e:
if cnt:
@ -33,29 +34,32 @@ class ReadFile:
else:
return cnt
def load_flows_from_path(self, path: str) -> int:
async def load_flows_from_path(self, path: str) -> int:
path = os.path.expanduser(path)
try:
with open(path, "rb") as f:
return self.load_flows(f)
return await self.load_flows(f)
except IOError as e:
ctx.log.error("Cannot load flows: {}".format(e))
raise exceptions.FlowReadException(str(e)) from e
async def doread(self, rfile):
try:
await self.load_flows_from_path(ctx.options.rfile)
except exceptions.FlowReadException as e:
raise exceptions.OptionsError(e) from e
finally:
ctx.master.addons.trigger("processing_complete")
def running(self):
if ctx.options.rfile:
try:
self.load_flows_from_path(ctx.options.rfile)
except exceptions.FlowReadException as e:
raise exceptions.OptionsError(e) from e
finally:
ctx.master.addons.trigger("processing_complete")
asyncio.get_event_loop().create_task(self.doread(ctx.options.rfile))
class ReadFileStdin(ReadFile):
"""Support the special case of "-" for reading from stdin"""
def load_flows_from_path(self, path: str) -> int:
async def load_flows_from_path(self, path: str) -> int:
if path == "-":
return self.load_flows(sys.stdin.buffer)
return await self.load_flows(sys.stdin.buffer)
else:
return super().load_flows_from_path(path)
return await super().load_flows_from_path(path)

View File

@ -150,7 +150,6 @@ def mitmdump(args=None): # pragma: no cover
if args.filter_args:
v = " ".join(args.filter_args)
return dict(
view_filter=v,
save_stream_filter=v,
)
return {}