add connection ids

This commit is contained in:
Maximilian Hils 2017-03-14 01:36:36 +01:00
parent e29cd7f5b7
commit 375680a3be
5 changed files with 82 additions and 2 deletions

View File

@ -1,6 +1,7 @@
import time import time
import os import os
import uuid
from mitmproxy import stateobject from mitmproxy import stateobject
from mitmproxy import certs from mitmproxy import certs
@ -41,6 +42,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
self.clientcert = None self.clientcert = None
self.ssl_established = None self.ssl_established = None
self.id = str(uuid.uuid4())
self.mitmcert = None self.mitmcert = None
self.timestamp_start = time.time() self.timestamp_start = time.time()
self.timestamp_end = None self.timestamp_end = None
@ -73,6 +75,19 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
port=self.address[1], port=self.address[1],
) )
def __eq__(self, other):
if isinstance(other, ClientConnection):
return self.id == other.id
return False
def __hash__(self):
return hash(self.id)
def copy(self):
f = super().copy()
f.id = str(uuid.uuid4())
return f
@property @property
def tls_established(self): def tls_established(self):
return self.ssl_established return self.ssl_established
@ -82,6 +97,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
self.ssl_established = value self.ssl_established = value
_stateobject_attributes = dict( _stateobject_attributes = dict(
id=str,
address=tuple, address=tuple,
ssl_established=bool, ssl_established=bool,
clientcert=certs.SSLCert, clientcert=certs.SSLCert,
@ -110,6 +126,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
@classmethod @classmethod
def make_dummy(cls, address): def make_dummy(cls, address):
return cls.from_state(dict( return cls.from_state(dict(
id=str(uuid.uuid4()),
address=address, address=address,
clientcert=None, clientcert=None,
mitmcert=None, mitmcert=None,
@ -165,6 +182,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
def __init__(self, address, source_address=None, spoof_source_address=None): def __init__(self, address, source_address=None, spoof_source_address=None):
tcp.TCPClient.__init__(self, address, source_address, spoof_source_address) tcp.TCPClient.__init__(self, address, source_address, spoof_source_address)
self.id = str(uuid.uuid4())
self.alpn_proto_negotiated = None self.alpn_proto_negotiated = None
self.tls_version = None self.tls_version = None
self.via = None self.via = None
@ -196,6 +214,19 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
port=self.address[1], port=self.address[1],
) )
def __eq__(self, other):
if isinstance(other, ServerConnection):
return self.id == other.id
return False
def __hash__(self):
return hash(self.id)
def copy(self):
f = super().copy()
f.id = str(uuid.uuid4())
return f
@property @property
def tls_established(self): def tls_established(self):
return self.ssl_established return self.ssl_established
@ -205,6 +236,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
self.ssl_established = value self.ssl_established = value
_stateobject_attributes = dict( _stateobject_attributes = dict(
id=str,
address=tuple, address=tuple,
ip_address=tuple, ip_address=tuple,
source_address=tuple, source_address=tuple,
@ -228,6 +260,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
@classmethod @classmethod
def make_dummy(cls, address): def make_dummy(cls, address):
return cls.from_state(dict( return cls.from_state(dict(
id=str(uuid.uuid4()),
address=address, address=address,
ip_address=address, ip_address=address,
cert=None, cert=None,

View File

@ -1,7 +1,7 @@
""" """
This module handles the import of mitmproxy flows generated by old versions. This module handles the import of mitmproxy flows generated by old versions.
""" """
import uuid
from typing import Any, Dict from typing import Any, Dict
from mitmproxy import version from mitmproxy import version
@ -113,6 +113,25 @@ def convert_300_4(data):
return data return data
client_connections = {}
server_connections = {}
def convert_4_5(data):
data["version"] = 5
client_conn_key = (
data["client_conn"]["timestamp_start"],
*data["client_conn"]["address"]
)
server_conn_key = (
data["server_conn"]["timestamp_start"],
*data["server_conn"]["source_address"]
)
data["client_conn"]["id"] = client_connections.setdefault(client_conn_key, str(uuid.uuid4()))
data["server_conn"]["id"] = server_connections.setdefault(server_conn_key, str(uuid.uuid4()))
return data
def _convert_dict_keys(o: Any) -> Any: def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict): if isinstance(o, dict):
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
@ -164,6 +183,7 @@ converters = {
(1, 0): convert_100_200, (1, 0): convert_100_200,
(2, 0): convert_200_300, (2, 0): convert_200_300,
(3, 0): convert_300_4, (3, 0): convert_300_4,
4: convert_4_5,
} }

View File

@ -1,4 +1,5 @@
import io import io
import uuid
from mitmproxy.net import websockets from mitmproxy.net import websockets
from mitmproxy.test import tutils from mitmproxy.test import tutils
@ -146,6 +147,7 @@ def tclient_conn():
@return: mitmproxy.proxy.connection.ClientConnection @return: mitmproxy.proxy.connection.ClientConnection
""" """
c = connections.ClientConnection.from_state(dict( c = connections.ClientConnection.from_state(dict(
id=str(uuid.uuid4()),
address=("address", 22), address=("address", 22),
clientcert=None, clientcert=None,
mitmcert=None, mitmcert=None,
@ -169,6 +171,7 @@ def tserver_conn():
@return: mitmproxy.proxy.connection.ServerConnection @return: mitmproxy.proxy.connection.ServerConnection
""" """
c = connections.ServerConnection.from_state(dict( c = connections.ServerConnection.from_state(dict(
id=str(uuid.uuid4()),
address=("address", 22), address=("address", 22),
source_address=("address", 22), source_address=("address", 22),
ip_address=None, ip_address=None,

View File

@ -5,7 +5,7 @@ MITMPROXY = "mitmproxy " + VERSION
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one # Serialization format version. This is displayed nowhere, it just needs to be incremented by one
# for each change the the file format. # for each change the the file format.
FLOW_FORMAT_VERSION = 4 FLOW_FORMAT_VERSION = 5
if __name__ == "__main__": if __name__ == "__main__":
print(VERSION) print(VERSION)

View File

@ -66,8 +66,17 @@ class TestClientConnection:
assert c.timestamp_start == 42 assert c.timestamp_start == 42
c3 = c.copy() c3 = c.copy()
assert c3.get_state() != c.get_state()
c.id = c3.id = "foo"
assert c3.get_state() == c.get_state() assert c3.get_state() == c.get_state()
def test_eq(self):
c = tflow.tclient_conn()
c2 = c.copy()
assert c == c
assert c != c2
assert c != 42
assert hash(c) != hash(c2)
class TestServerConnection: class TestServerConnection:
@ -147,6 +156,21 @@ class TestServerConnection:
with pytest.raises(ValueError, matches='sni must be str, not '): with pytest.raises(ValueError, matches='sni must be str, not '):
c.establish_ssl(None, b'foobar') c.establish_ssl(None, b'foobar')
def test_state(self):
c = tflow.tserver_conn()
c2 = c.copy()
assert c2.get_state() != c.get_state()
c.id = c2.id = "foo"
assert c2.get_state() == c.get_state()
def test_eq(self):
c = tflow.tserver_conn()
c2 = c.copy()
assert c == c
assert c != c2
assert c != 42
assert hash(c) != hash(c2)
class TestClientConnectionTLS: class TestClientConnectionTLS: