diff --git a/mitmproxy/protocol/base.py b/mitmproxy/protocol/base.py index b2caeb16a..da6e83563 100644 --- a/mitmproxy/protocol/base.py +++ b/mitmproxy/protocol/base.py @@ -163,8 +163,11 @@ class ServerConnectionMixin(object): self.server_conn.close() self.channel.tell("serverdisconnect", self.server_conn) - self.server_conn = models.ServerConnection(address, - (self.server_conn.source_address.host, 0), self.config.options.spoof_source_address) + self.server_conn = models.ServerConnection( + address, + (self.server_conn.source_address.host, 0), + self.config.options.spoof_source_address + ) def connect(self): """ diff --git a/netlib/exceptions.py b/netlib/exceptions.py index 376514096..dec79c22a 100644 --- a/netlib/exceptions.py +++ b/netlib/exceptions.py @@ -58,7 +58,3 @@ class InvalidCertificateException(TlsException): class Timeout(TcpException): pass - - -class ProtocolException(NetlibException): - pass diff --git a/netlib/tcp.py b/netlib/tcp.py index 2c55de85b..eea104252 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -484,14 +484,12 @@ class _Connection(object): if not isinstance(self.connection, SSL.Connection): if not getattr(self.wfile, "closed", False): try: - if self.wfile: - self.wfile.flush() - self.wfile.close() + self.wfile.flush() + self.wfile.close() except exceptions.TcpDisconnect: pass - if self.rfile: - self.rfile.close() + self.rfile.close() else: try: self.connection.shutdown() @@ -731,11 +729,7 @@ class TCPClient(_Connection): def connect(self): try: - # Allow the socket to be manipulated by using the server_conn stub. - if not self.connection: - connection = socket.socket(self.address.family, socket.SOCK_STREAM) - else: - connection = self.connection + connection = socket.socket(self.address.family, socket.SOCK_STREAM) if self.spoof_source_address: try: @@ -744,7 +738,8 @@ class TCPClient(_Connection): connection.setsockopt(socket.SOL_IP, 19, 1) except socket.error as e: raise exceptions.TcpException( - "Failed to spoof the source address: " + e.strerror) + "Failed to spoof the source address: " + e.strerror + ) if self.source_address: connection.bind(self.source_address()) connection.connect(self.address())