From 375680a3be47b7dd7b94ebd376978d9e4d90abcd Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 14 Mar 2017 01:36:36 +0100 Subject: [PATCH] add connection ids --- mitmproxy/connections.py | 33 ++++++++++++++++++++++++++++++ mitmproxy/io_compat.py | 22 +++++++++++++++++++- mitmproxy/test/tflow.py | 3 +++ mitmproxy/version.py | 2 +- test/mitmproxy/test_connections.py | 24 ++++++++++++++++++++++ 5 files changed, 82 insertions(+), 2 deletions(-) diff --git a/mitmproxy/connections.py b/mitmproxy/connections.py index 9359b67db..5a3cb69e8 100644 --- a/mitmproxy/connections.py +++ b/mitmproxy/connections.py @@ -1,6 +1,7 @@ import time import os +import uuid from mitmproxy import stateobject from mitmproxy import certs @@ -41,6 +42,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): self.clientcert = None self.ssl_established = None + self.id = str(uuid.uuid4()) self.mitmcert = None self.timestamp_start = time.time() self.timestamp_end = None @@ -73,6 +75,19 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): port=self.address[1], ) + def __eq__(self, other): + if isinstance(other, ClientConnection): + return self.id == other.id + return False + + def __hash__(self): + return hash(self.id) + + def copy(self): + f = super().copy() + f.id = str(uuid.uuid4()) + return f + @property def tls_established(self): return self.ssl_established @@ -82,6 +97,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): self.ssl_established = value _stateobject_attributes = dict( + id=str, address=tuple, ssl_established=bool, clientcert=certs.SSLCert, @@ -110,6 +126,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): @classmethod def make_dummy(cls, address): return cls.from_state(dict( + id=str(uuid.uuid4()), address=address, clientcert=None, mitmcert=None, @@ -165,6 +182,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): def __init__(self, address, source_address=None, spoof_source_address=None): tcp.TCPClient.__init__(self, address, source_address, spoof_source_address) + self.id = str(uuid.uuid4()) self.alpn_proto_negotiated = None self.tls_version = None self.via = None @@ -196,6 +214,19 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): port=self.address[1], ) + def __eq__(self, other): + if isinstance(other, ServerConnection): + return self.id == other.id + return False + + def __hash__(self): + return hash(self.id) + + def copy(self): + f = super().copy() + f.id = str(uuid.uuid4()) + return f + @property def tls_established(self): return self.ssl_established @@ -205,6 +236,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): self.ssl_established = value _stateobject_attributes = dict( + id=str, address=tuple, ip_address=tuple, source_address=tuple, @@ -228,6 +260,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): @classmethod def make_dummy(cls, address): return cls.from_state(dict( + id=str(uuid.uuid4()), address=address, ip_address=address, cert=None, diff --git a/mitmproxy/io_compat.py b/mitmproxy/io_compat.py index 8fa7b77c9..7d839ffd4 100644 --- a/mitmproxy/io_compat.py +++ b/mitmproxy/io_compat.py @@ -1,7 +1,7 @@ """ This module handles the import of mitmproxy flows generated by old versions. """ - +import uuid from typing import Any, Dict from mitmproxy import version @@ -113,6 +113,25 @@ def convert_300_4(data): return data +client_connections = {} +server_connections = {} + + +def convert_4_5(data): + data["version"] = 5 + client_conn_key = ( + data["client_conn"]["timestamp_start"], + *data["client_conn"]["address"] + ) + server_conn_key = ( + data["server_conn"]["timestamp_start"], + *data["server_conn"]["source_address"] + ) + data["client_conn"]["id"] = client_connections.setdefault(client_conn_key, str(uuid.uuid4())) + data["server_conn"]["id"] = server_connections.setdefault(server_conn_key, str(uuid.uuid4())) + return data + + def _convert_dict_keys(o: Any) -> Any: if isinstance(o, dict): return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} @@ -164,6 +183,7 @@ converters = { (1, 0): convert_100_200, (2, 0): convert_200_300, (3, 0): convert_300_4, + 4: convert_4_5, } diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index 7fbe17271..270021cbf 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -1,4 +1,5 @@ import io +import uuid from mitmproxy.net import websockets from mitmproxy.test import tutils @@ -146,6 +147,7 @@ def tclient_conn(): @return: mitmproxy.proxy.connection.ClientConnection """ c = connections.ClientConnection.from_state(dict( + id=str(uuid.uuid4()), address=("address", 22), clientcert=None, mitmcert=None, @@ -169,6 +171,7 @@ def tserver_conn(): @return: mitmproxy.proxy.connection.ServerConnection """ c = connections.ServerConnection.from_state(dict( + id=str(uuid.uuid4()), address=("address", 22), source_address=("address", 22), ip_address=None, diff --git a/mitmproxy/version.py b/mitmproxy/version.py index 2882d1fbb..006ec868c 100644 --- a/mitmproxy/version.py +++ b/mitmproxy/version.py @@ -5,7 +5,7 @@ MITMPROXY = "mitmproxy " + VERSION # Serialization format version. This is displayed nowhere, it just needs to be incremented by one # for each change the the file format. -FLOW_FORMAT_VERSION = 4 +FLOW_FORMAT_VERSION = 5 if __name__ == "__main__": print(VERSION) diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index 0083f57cc..57fdd8c79 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -66,8 +66,17 @@ class TestClientConnection: assert c.timestamp_start == 42 c3 = c.copy() + assert c3.get_state() != c.get_state() + c.id = c3.id = "foo" assert c3.get_state() == c.get_state() + def test_eq(self): + c = tflow.tclient_conn() + c2 = c.copy() + assert c == c + assert c != c2 + assert c != 42 + assert hash(c) != hash(c2) class TestServerConnection: @@ -147,6 +156,21 @@ class TestServerConnection: with pytest.raises(ValueError, matches='sni must be str, not '): c.establish_ssl(None, b'foobar') + def test_state(self): + c = tflow.tserver_conn() + c2 = c.copy() + assert c2.get_state() != c.get_state() + c.id = c2.id = "foo" + assert c2.get_state() == c.get_state() + + def test_eq(self): + c = tflow.tserver_conn() + c2 = c.copy() + assert c == c + assert c != c2 + assert c != 42 + assert hash(c) != hash(c2) + class TestClientConnectionTLS: