asyncio: adjust readfile.py addon for async

This commit is contained in:
Aldo Cortesi 2018-04-14 10:24:03 +12:00
parent 660aa87a24
commit 214498f01c
2 changed files with 17 additions and 14 deletions

View File

@ -1,3 +1,4 @@
import asyncio
import os.path import os.path
import sys import sys
import typing import typing
@ -17,12 +18,12 @@ class ReadFile:
"Read flows from file." "Read flows from file."
) )
def load_flows(self, fo: typing.IO[bytes]) -> int: async def load_flows(self, fo: typing.IO[bytes]) -> int:
cnt = 0 cnt = 0
freader = io.FlowReader(fo) freader = io.FlowReader(fo)
try: try:
for flow in freader.stream(): for flow in freader.stream():
ctx.master.load_flow(flow) await ctx.master.load_flow(flow)
cnt += 1 cnt += 1
except (IOError, exceptions.FlowReadException) as e: except (IOError, exceptions.FlowReadException) as e:
if cnt: if cnt:
@ -33,29 +34,32 @@ class ReadFile:
else: else:
return cnt return cnt
def load_flows_from_path(self, path: str) -> int: async def load_flows_from_path(self, path: str) -> int:
path = os.path.expanduser(path) path = os.path.expanduser(path)
try: try:
with open(path, "rb") as f: with open(path, "rb") as f:
return self.load_flows(f) return await self.load_flows(f)
except IOError as e: except IOError as e:
ctx.log.error("Cannot load flows: {}".format(e)) ctx.log.error("Cannot load flows: {}".format(e))
raise exceptions.FlowReadException(str(e)) from e raise exceptions.FlowReadException(str(e)) from e
async def doread(self, rfile):
try:
await self.load_flows_from_path(ctx.options.rfile)
except exceptions.FlowReadException as e:
raise exceptions.OptionsError(e) from e
finally:
ctx.master.addons.trigger("processing_complete")
def running(self): def running(self):
if ctx.options.rfile: if ctx.options.rfile:
try: asyncio.get_event_loop().create_task(self.doread(ctx.options.rfile))
self.load_flows_from_path(ctx.options.rfile)
except exceptions.FlowReadException as e:
raise exceptions.OptionsError(e) from e
finally:
ctx.master.addons.trigger("processing_complete")
class ReadFileStdin(ReadFile): class ReadFileStdin(ReadFile):
"""Support the special case of "-" for reading from stdin""" """Support the special case of "-" for reading from stdin"""
def load_flows_from_path(self, path: str) -> int: async def load_flows_from_path(self, path: str) -> int:
if path == "-": if path == "-":
return self.load_flows(sys.stdin.buffer) return await self.load_flows(sys.stdin.buffer)
else: else:
return super().load_flows_from_path(path) return await super().load_flows_from_path(path)

View File

@ -150,7 +150,6 @@ def mitmdump(args=None): # pragma: no cover
if args.filter_args: if args.filter_args:
v = " ".join(args.filter_args) v = " ".join(args.filter_args)
return dict( return dict(
view_filter=v,
save_stream_filter=v, save_stream_filter=v,
) )
return {} return {}