Merge pull request #2147 from mhils/connection-ids

Add client/server connection ids
This commit is contained in:
Aldo Cortesi 2017-03-15 09:20:16 +13:00 committed by GitHub
commit 6e7ba84017
8 changed files with 93 additions and 8 deletions

View File

@ -1,6 +1,7 @@
import time import time
import os import os
import uuid
from mitmproxy import stateobject from mitmproxy import stateobject
from mitmproxy import certs from mitmproxy import certs
@ -41,6 +42,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
self.clientcert = None self.clientcert = None
self.ssl_established = None self.ssl_established = None
self.id = str(uuid.uuid4())
self.mitmcert = None self.mitmcert = None
self.timestamp_start = time.time() self.timestamp_start = time.time()
self.timestamp_end = None self.timestamp_end = None
@ -73,6 +75,14 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
port=self.address[1], port=self.address[1],
) )
def __eq__(self, other):
if isinstance(other, ClientConnection):
return self.id == other.id
return False
def __hash__(self):
return hash(self.id)
@property @property
def tls_established(self): def tls_established(self):
return self.ssl_established return self.ssl_established
@ -82,6 +92,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
self.ssl_established = value self.ssl_established = value
_stateobject_attributes = dict( _stateobject_attributes = dict(
id=str,
address=tuple, address=tuple,
ssl_established=bool, ssl_established=bool,
clientcert=certs.SSLCert, clientcert=certs.SSLCert,
@ -110,6 +121,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
@classmethod @classmethod
def make_dummy(cls, address): def make_dummy(cls, address):
return cls.from_state(dict( return cls.from_state(dict(
id=str(uuid.uuid4()),
address=address, address=address,
clientcert=None, clientcert=None,
mitmcert=None, mitmcert=None,
@ -165,6 +177,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
def __init__(self, address, source_address=None, spoof_source_address=None): def __init__(self, address, source_address=None, spoof_source_address=None):
tcp.TCPClient.__init__(self, address, source_address, spoof_source_address) tcp.TCPClient.__init__(self, address, source_address, spoof_source_address)
self.id = str(uuid.uuid4())
self.alpn_proto_negotiated = None self.alpn_proto_negotiated = None
self.tls_version = None self.tls_version = None
self.via = None self.via = None
@ -196,6 +209,14 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
port=self.address[1], port=self.address[1],
) )
def __eq__(self, other):
if isinstance(other, ServerConnection):
return self.id == other.id
return False
def __hash__(self):
return hash(self.id)
@property @property
def tls_established(self): def tls_established(self):
return self.ssl_established return self.ssl_established
@ -205,6 +226,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
self.ssl_established = value self.ssl_established = value
_stateobject_attributes = dict( _stateobject_attributes = dict(
id=str,
address=tuple, address=tuple,
ip_address=tuple, ip_address=tuple,
source_address=tuple, source_address=tuple,
@ -228,6 +250,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
@classmethod @classmethod
def make_dummy(cls, address): def make_dummy(cls, address):
return cls.from_state(dict( return cls.from_state(dict(
id=str(uuid.uuid4()),
address=address, address=address,
ip_address=address, ip_address=address,
cert=None, cert=None,

View File

@ -112,7 +112,6 @@ class Flow(stateobject.StateObject):
def copy(self): def copy(self):
f = super().copy() f = super().copy()
f.id = str(uuid.uuid4())
f.live = False f.live = False
return f return f

View File

@ -1,7 +1,7 @@
""" """
This module handles the import of mitmproxy flows generated by old versions. This module handles the import of mitmproxy flows generated by old versions.
""" """
import uuid
from typing import Any, Dict from typing import Any, Dict
from mitmproxy import version from mitmproxy import version
@ -113,6 +113,25 @@ def convert_300_4(data):
return data return data
client_connections = {}
server_connections = {}
def convert_4_5(data):
data["version"] = 5
client_conn_key = (
data["client_conn"]["timestamp_start"],
*data["client_conn"]["address"]
)
server_conn_key = (
data["server_conn"]["timestamp_start"],
*data["server_conn"]["source_address"]
)
data["client_conn"]["id"] = client_connections.setdefault(client_conn_key, str(uuid.uuid4()))
data["server_conn"]["id"] = server_connections.setdefault(server_conn_key, str(uuid.uuid4()))
return data
def _convert_dict_keys(o: Any) -> Any: def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict): if isinstance(o, dict):
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
@ -164,6 +183,7 @@ converters = {
(1, 0): convert_100_200, (1, 0): convert_100_200,
(2, 0): convert_200_300, (2, 0): convert_200_300,
(3, 0): convert_300_4, (3, 0): convert_300_4,
4: convert_4_5,
} }

View File

@ -1,4 +1,5 @@
import io import io
import uuid
from mitmproxy.net import websockets from mitmproxy.net import websockets
from mitmproxy.test import tutils from mitmproxy.test import tutils
@ -146,6 +147,7 @@ def tclient_conn():
@return: mitmproxy.proxy.connection.ClientConnection @return: mitmproxy.proxy.connection.ClientConnection
""" """
c = connections.ClientConnection.from_state(dict( c = connections.ClientConnection.from_state(dict(
id=str(uuid.uuid4()),
address=("address", 22), address=("address", 22),
clientcert=None, clientcert=None,
mitmcert=None, mitmcert=None,
@ -169,6 +171,7 @@ def tserver_conn():
@return: mitmproxy.proxy.connection.ServerConnection @return: mitmproxy.proxy.connection.ServerConnection
""" """
c = connections.ServerConnection.from_state(dict( c = connections.ServerConnection.from_state(dict(
id=str(uuid.uuid4()),
address=("address", 22), address=("address", 22),
source_address=("address", 22), source_address=("address", 22),
ip_address=None, ip_address=None,

View File

@ -1,4 +1,5 @@
import abc import abc
import uuid
class Serializable(metaclass=abc.ABCMeta): class Serializable(metaclass=abc.ABCMeta):
@ -29,4 +30,7 @@ class Serializable(metaclass=abc.ABCMeta):
raise NotImplementedError() raise NotImplementedError()
def copy(self): def copy(self):
return self.from_state(self.get_state()) state = self.get_state()
if isinstance(state, dict) and "id" in state:
state["id"] = str(uuid.uuid4())
return self.from_state(state)

View File

@ -5,7 +5,7 @@ MITMPROXY = "mitmproxy " + VERSION
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one # Serialization format version. This is displayed nowhere, it just needs to be incremented by one
# for each change the the file format. # for each change the the file format.
FLOW_FORMAT_VERSION = 4 FLOW_FORMAT_VERSION = 5
if __name__ == "__main__": if __name__ == "__main__":
print(VERSION) print(VERSION)

View File

@ -66,8 +66,18 @@ class TestClientConnection:
assert c.timestamp_start == 42 assert c.timestamp_start == 42
c3 = c.copy() c3 = c.copy()
assert c3.get_state() != c.get_state()
c.id = c3.id = "foo"
assert c3.get_state() == c.get_state() assert c3.get_state() == c.get_state()
def test_eq(self):
c = tflow.tclient_conn()
c2 = c.copy()
assert c == c
assert c != c2
assert c != 42
assert hash(c) != hash(c2)
class TestServerConnection: class TestServerConnection:
@ -147,6 +157,21 @@ class TestServerConnection:
with pytest.raises(ValueError, matches='sni must be str, not '): with pytest.raises(ValueError, matches='sni must be str, not '):
c.establish_ssl(None, b'foobar') c.establish_ssl(None, b'foobar')
def test_state(self):
c = tflow.tserver_conn()
c2 = c.copy()
assert c2.get_state() != c.get_state()
c.id = c2.id = "foo"
assert c2.get_state() == c.get_state()
def test_eq(self):
c = tflow.tserver_conn()
c2 = c.copy()
assert c == c
assert c != c2
assert c != 42
assert hash(c) != hash(c2)
class TestClientConnectionTLS: class TestClientConnectionTLS:

View File

@ -1,3 +1,5 @@
import copy
from mitmproxy.types import serializable from mitmproxy.types import serializable
@ -6,17 +8,17 @@ class SerializableDummy(serializable.Serializable):
self.i = i self.i = i
def get_state(self): def get_state(self):
return self.i return copy.copy(self.i)
def set_state(self, i): def set_state(self, i):
self.i = i self.i = i
def from_state(self, state): @classmethod
return type(self)(state) def from_state(cls, state):
return cls(state)
class TestSerializable: class TestSerializable:
def test_copy(self): def test_copy(self):
a = SerializableDummy(42) a = SerializableDummy(42)
assert a.i == 42 assert a.i == 42
@ -26,3 +28,12 @@ class TestSerializable:
a.set_state(1) a.set_state(1)
assert a.i == 1 assert a.i == 1
assert b.i == 42 assert b.i == 42
def test_copy_id(self):
a = SerializableDummy({
"id": "foo",
"foo": 42
})
b = a.copy()
assert a.get_state()["id"] != b.get_state()["id"]
assert a.get_state()["foo"] == b.get_state()["foo"]