This commit is contained in:
Thomas Kriechbaumer 2017-12-17 17:44:36 +00:00
parent 1a7ce384da
commit 8d836d251e
8 changed files with 63 additions and 13 deletions

View File

@ -74,6 +74,15 @@ class Save:
self.stream.add(flow) self.stream.add(flow)
self.active_flows.discard(flow) self.active_flows.discard(flow)
def websocket_start(self, flow):
if self.stream:
self.active_flows.add(flow)
def websocket_end(self, flow):
if self.stream:
self.stream.add(flow)
self.active_flows.discard(flow)
def response(self, flow): def response(self, flow):
if self.stream: if self.stream:
self.stream.add(flow) self.stream.add(flow)

View File

View File

@ -331,6 +331,7 @@ class FDomain(_Rex):
self.re.search(f.request.pretty_host) self.re.search(f.request.pretty_host)
) )
class FUrl(_Rex): class FUrl(_Rex):
code = "u" code = "u"
help = "URL" help = "URL"

View File

@ -9,6 +9,7 @@ from mitmproxy import eventsequence
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import command from mitmproxy import command
from mitmproxy import http from mitmproxy import http
from mitmproxy import websocket
from mitmproxy import log from mitmproxy import log
from mitmproxy.net import server_spec from mitmproxy.net import server_spec
from mitmproxy.proxy.protocol import http_replay from mitmproxy.proxy.protocol import http_replay
@ -41,6 +42,7 @@ class Master:
self.should_exit = threading.Event() self.should_exit = threading.Event()
self._server = None self._server = None
self.first_tick = True self.first_tick = True
self.waiting_flows = []
@property @property
def server(self): def server(self):
@ -117,15 +119,28 @@ class Master:
self.should_exit.set() self.should_exit.set()
self.addons.trigger("done") self.addons.trigger("done")
def _change_reverse_host(self, f):
if self.options.mode.startswith("reverse:"):
_, upstream_spec = server_spec.parse_with_mode(self.options.mode)
f.request.host, f.request.port = upstream_spec.address
f.request.scheme = upstream_spec.scheme
def load_flow(self, f): def load_flow(self, f):
""" """
Loads a flow Loads a flow and links websocket & handshake flows
""" """
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
if self.options.mode.startswith("reverse:"): self._change_reverse_host(f)
_, upstream_spec = server_spec.parse_with_mode(self.options.mode) if 'websocket' in f.metadata:
f.request.host, f.request.port = upstream_spec.address self.waiting_flows.append(f)
f.request.scheme = upstream_spec.scheme
if isinstance(f, websocket.WebSocketFlow):
hf = [hf for hf in self.waiting_flows if hf.id == f.metadata['websocket_handshake']][0]
f.handshake_flow = hf
self.waiting_flows.remove(hf)
self._change_reverse_host(f.handshake_flow)
f.reply = controller.DummyReply() f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f): for e, o in eventsequence.iterate(f):
self.addons.handle_lifecycle(e, o) self.addons.handle_lifecycle(e, o)

View File

@ -321,6 +321,7 @@ class HttpLayer(base.Layer):
try: try:
if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers): if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
f.metadata['websocket'] = True
# We only support RFC6455 with WebSocket version 13 # We only support RFC6455 with WebSocket version 13
# allow inline scripts to manipulate the client handshake # allow inline scripts to manipulate the client handshake
self.channel.ask("websocket_handshake", f) self.channel.ask("websocket_handshake", f)

View File

@ -1,10 +1,11 @@
import socket import socket
from OpenSSL import SSL from OpenSSL import SSL
from mitmproxy.contrib import wsproto
from mitmproxy.contrib.wsproto import events from mitmproxy.contrib.wsproto import events
from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection
from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate
from mitmproxy.contrib.wsproto.frame_protocol import Opcode
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import flow from mitmproxy import flow
@ -93,11 +94,14 @@ class WebSocketLayer(base.Layer):
if event.message_finished: if event.message_finished:
original_chunk_sizes = [len(f) for f in fb] original_chunk_sizes = [len(f) for f in fb]
message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY
if message_type == Opcode.TEXT: if isinstance(event, events.TextReceived):
message_type = wsproto.frame_protocol.Opcode.TEXT
payload = ''.join(fb) payload = ''.join(fb)
else: else:
message_type = wsproto.frame_protocol.Opcode.BINARY
payload = b''.join(fb) payload = b''.join(fb)
fb.clear() fb.clear()
websocket_message = WebSocketMessage(message_type, not is_server, payload) websocket_message = WebSocketMessage(message_type, not is_server, payload)

View File

@ -1,6 +1,8 @@
import time import time
from typing import List, Optional from typing import List, Optional
from mitmproxy.contrib import wsproto
from mitmproxy import flow from mitmproxy import flow
from mitmproxy.net import websockets from mitmproxy.net import websockets
from mitmproxy.coretypes import serializable from mitmproxy.coretypes import serializable
@ -11,7 +13,7 @@ class WebSocketMessage(serializable.Serializable):
def __init__( def __init__(
self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None
) -> None: ) -> None:
self.type = type self.type = wsproto.frame_protocol.Opcode(type)
self.from_client = from_client self.from_client = from_client
self.content = content self.content = content
self.timestamp = timestamp or int(time.time()) # type: int self.timestamp = timestamp or int(time.time()) # type: int
@ -21,13 +23,14 @@ class WebSocketMessage(serializable.Serializable):
return cls(*state) return cls(*state)
def get_state(self): def get_state(self):
return self.type, self.from_client, self.content, self.timestamp return int(self.type), self.from_client, self.content, self.timestamp
def set_state(self, state): def set_state(self, state):
self.type, self.from_client, self.content, self.timestamp = state self.type, self.from_client, self.content, self.timestamp = state
self.type = wsproto.frame_protocol.Opcode(self.type) # replace enum with bare int
def __repr__(self): def __repr__(self):
if self.type == websockets.OPCODE.TEXT: if self.type == wsproto.frame_protocol.Opcode.TEXT:
return "text message: {}".format(repr(self.content)) return "text message: {}".format(repr(self.content))
else: else:
return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content))
@ -42,7 +45,7 @@ class WebSocketFlow(flow.Flow):
super().__init__("websocket", client_conn, server_conn, live) super().__init__("websocket", client_conn, server_conn, live)
self.messages = [] # type: List[WebSocketMessage] self.messages = [] # type: List[WebSocketMessage]
self.close_sender = 'client' self.close_sender = 'client'
self.close_code = '(status code missing)' self.close_code = wsproto.frame_protocol.CloseReason.NORMAL_CLOSURE
self.close_message = '(message missing)' self.close_message = '(message missing)'
self.close_reason = 'unknown status code' self.close_reason = 'unknown status code'
self.stream = False self.stream = False
@ -69,7 +72,7 @@ class WebSocketFlow(flow.Flow):
_stateobject_attributes.update(dict( _stateobject_attributes.update(dict(
messages=List[WebSocketMessage], messages=List[WebSocketMessage],
close_sender=str, close_sender=str,
close_code=str, close_code=int,
close_message=str, close_message=str,
close_reason=str, close_reason=str,
client_key=str, client_key=str,
@ -83,6 +86,11 @@ class WebSocketFlow(flow.Flow):
# dumping the handshake_flow will include the WebSocketFlow too. # dumping the handshake_flow will include the WebSocketFlow too.
)) ))
def get_state(self):
d = super().get_state()
d['close_code'] = int(d['close_code']) # replace enum with bare int
return d
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
f = cls(None, None, None) f = cls(None, None, None)

View File

@ -44,6 +44,18 @@ def test_tcp(tmpdir):
assert rd(p) assert rd(p)
def test_websocket(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)
f = tflow.twebsocketflow()
sa.websocket_start(f)
tctx.configure(sa, save_stream_file=None)
assert rd(p)
def test_save_command(tmpdir): def test_save_command(tmpdir):
sa = save.Save() sa = save.Save()
with taddons.context() as tctx: with taddons.context() as tctx: