fix bugs, fix tests

This commit is contained in:
Maximilian Hils 2015-08-29 20:53:25 +02:00
parent 63844df343
commit a7058e2a3c
5 changed files with 69 additions and 85 deletions

View File

@ -199,11 +199,12 @@ class StatusBar(urwid.WidgetWrap):
r.append("[%s]" % (":".join(opts))) r.append("[%s]" % (":".join(opts)))
if self.master.server.config.mode in ["reverse", "upstream"]: if self.master.server.config.mode in ["reverse", "upstream"]:
dst = self.master.server.config.mode.dst dst = self.master.server.config.upstream_server
scheme = "https" if dst[0] else "http" r.append("[dest:%s]" % netlib.utils.unparse_url(
if dst[1] != dst[0]: dst.scheme,
scheme += "2https" if dst[1] else "http" dst.address.host,
r.append("[dest:%s]" % utils.unparse_url(scheme, *dst[2:])) dst.address.port
))
if self.master.scripts: if self.master.scripts:
r.append("[") r.append("[")
r.append(("heading_key", "s")) r.append(("heading_key", "s"))

View File

@ -40,6 +40,7 @@ class _HttpLayer(Layer):
def send_response(self, response): def send_response(self, response):
raise NotImplementedError() raise NotImplementedError()
class _StreamingHttpLayer(_HttpLayer): class _StreamingHttpLayer(_HttpLayer):
supports_streaming = True supports_streaming = True
@ -58,7 +59,6 @@ class _StreamingHttpLayer(_HttpLayer):
class Http1Layer(_StreamingHttpLayer): class Http1Layer(_StreamingHttpLayer):
def __init__(self, ctx, mode): def __init__(self, ctx, mode):
super(Http1Layer, self).__init__(ctx) super(Http1Layer, self).__init__(ctx)
self.mode = mode self.mode = mode
@ -105,12 +105,12 @@ class Http1Layer(_StreamingHttpLayer):
def send_response_headers(self, response): def send_response_headers(self, response):
h = self.client_protocol._assemble_response_first_line(response) h = self.client_protocol._assemble_response_first_line(response)
self.client_conn.wfile.write(h+"\r\n") self.client_conn.wfile.write(h + "\r\n")
h = self.client_protocol._assemble_response_headers( h = self.client_protocol._assemble_response_headers(
response, response,
preserve_transfer_encoding=True preserve_transfer_encoding=True
) )
self.client_conn.send(h+"\r\n") self.client_conn.send(h + "\r\n")
def send_response_body(self, response, chunks): def send_response_body(self, response, chunks):
if self.client_protocol.has_chunked_encoding(response.headers): if self.client_protocol.has_chunked_encoding(response.headers):
@ -142,8 +142,10 @@ class Http2Layer(_HttpLayer):
def __init__(self, ctx, mode): def __init__(self, ctx, mode):
super(Http2Layer, self).__init__(ctx) super(Http2Layer, self).__init__(ctx)
self.mode = mode self.mode = mode
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame) self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True,
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
unhandled_frame_cb=self.handle_unexpected_frame)
def read_request(self): def read_request(self):
request = HTTPRequest.from_protocol( request = HTTPRequest.from_protocol(
@ -172,17 +174,20 @@ class Http2Layer(_HttpLayer):
def connect(self): def connect(self):
self.ctx.connect() self.ctx.connect()
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol.perform_connection_preface() self.server_protocol.perform_connection_preface()
def reconnect(self): def reconnect(self):
self.ctx.reconnect() self.ctx.reconnect()
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol.perform_connection_preface() self.server_protocol.perform_connection_preface()
def set_server(self, *args, **kwargs): def set_server(self, *args, **kwargs):
self.ctx.set_server(*args, **kwargs) self.ctx.set_server(*args, **kwargs)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol.perform_connection_preface() self.server_protocol.perform_connection_preface()
def __call__(self): def __call__(self):
@ -264,7 +269,10 @@ class UpstreamConnectLayer(Layer):
def __init__(self, ctx, connect_request): def __init__(self, ctx, connect_request):
super(UpstreamConnectLayer, self).__init__(ctx) super(UpstreamConnectLayer, self).__init__(ctx)
self.connect_request = connect_request self.connect_request = connect_request
self.server_conn = ConnectServerConnection((connect_request.host, connect_request.port), self.ctx) self.server_conn = ConnectServerConnection(
(connect_request.host, connect_request.port),
self.ctx
)
def __call__(self): def __call__(self):
layer = self.ctx.next_layer(self) layer = self.ctx.next_layer(self)
@ -280,6 +288,9 @@ class UpstreamConnectLayer(Layer):
def reconnect(self): def reconnect(self):
self.ctx.reconnect() self.ctx.reconnect()
self.send_request(self.connect_request) self.send_request(self.connect_request)
resp = self.read_response("CONNECT")
if resp.code != 200:
raise ProtocolException("Reconnect: Upstream server refuses CONNECT request")
def set_server(self, address, server_tls=None, sni=None, depth=1): def set_server(self, address, server_tls=None, sni=None, depth=1):
if depth == 1: if depth == 1:
@ -290,7 +301,7 @@ class UpstreamConnectLayer(Layer):
self.connect_request.port = address.port self.connect_request.port = address.port
self.server_conn.address = address self.server_conn.address = address
else: else:
self.ctx.set_server(address, server_tls, sni, depth-1) self.ctx.set_server(address, server_tls, sni, depth - 1)
class HttpLayer(Layer): class HttpLayer(Layer):
@ -413,10 +424,10 @@ class HttpLayer(Layer):
# First send the headers and then transfer the response incrementally # First send the headers and then transfer the response incrementally
self.send_response_headers(flow.response) self.send_response_headers(flow.response)
chunks = self.read_response_body( chunks = self.read_response_body(
flow.response.headers, flow.response.headers,
flow.request.method, flow.request.method,
flow.response.code, flow.response.code,
max_chunk_size=4096 max_chunk_size=4096
) )
if callable(flow.response.stream): if callable(flow.response.stream):
chunks = flow.response.stream(chunks) chunks = flow.response.stream(chunks)
@ -521,7 +532,8 @@ class HttpLayer(Layer):
# If there's not TlsLayer below which could catch the exception, # If there's not TlsLayer below which could catch the exception,
# TLS will not be established. # TLS will not be established.
if tls and not self.server_conn.tls_established: if tls and not self.server_conn.tls_established:
raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.") raise ProtocolException(
"Cannot upgrade to SSL, no TLS layer on the protocol stack.")
else: else:
if not self.server_conn: if not self.server_conn:
self.connect() self.connect()
@ -542,7 +554,8 @@ class HttpLayer(Layer):
def validate_request(self, request): def validate_request(self, request):
if request.form_in == "absolute" and request.scheme != "http": if request.form_in == "absolute" and request.scheme != "http":
self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme)) self.send_response(
make_error_response(400, "Invalid request scheme: %s" % request.scheme))
raise HttpException("Invalid request scheme: %s" % request.scheme) raise HttpException("Invalid request scheme: %s" % request.scheme)
expected_request_forms = { expected_request_forms = {
@ -570,7 +583,11 @@ class HttpLayer(Layer):
self.send_response(make_error_response( self.send_response(make_error_response(
407, 407,
"Proxy Authentication Required", "Proxy Authentication Required",
odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()]) odict.ODictCaseless(
[
[k, v] for k, v in
self.config.authenticator.auth_challenge_headers().items()
])
)) ))
raise InvalidCredentials("Proxy Authentication Required") raise InvalidCredentials("Proxy Authentication Required")
@ -614,6 +631,9 @@ class RequestReplayThread(threading.Thread):
if r.scheme == "https": if r.scheme == "https":
connect_request = make_connect_request((r.host, r.port)) connect_request = make_connect_request((r.host, r.port))
server.send(protocol.assemble(connect_request)) server.send(protocol.assemble(connect_request))
resp = protocol.read_response("CONNECT")
if resp.code != 200:
raise HttpError(502, "Upstream server refuses CONNECT request")
server.establish_ssl( server.establish_ssl(
self.config.clientcerts, self.config.clientcerts,
sni=self.flow.server_conn.sni sni=self.flow.server_conn.sni

View File

@ -1,9 +1,8 @@
import argparse
from libmproxy import cmdline from libmproxy import cmdline
from libmproxy.proxy import ProxyConfig, process_proxy_options from libmproxy.proxy import ProxyConfig, process_proxy_options
from libmproxy.proxy.connection import ServerConnection from libmproxy.proxy.connection import ServerConnection
from libmproxy.proxy.primitives import ProxyError from libmproxy.proxy.primitives import ProxyError
from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler2
import tutils import tutils
from libpathod import test from libpathod import test
from netlib import http, tcp from netlib import http, tcp
@ -175,8 +174,10 @@ class TestDummyServer:
class TestConnectionHandler: class TestConnectionHandler:
def test_fatal_error(self): def test_fatal_error(self):
config = mock.Mock() config = mock.Mock()
config.mode.get_upstream_server.side_effect = RuntimeError root_layer = mock.Mock()
c = ConnectionHandler( root_layer.side_effect = RuntimeError
config.mode.return_value = root_layer
c = ConnectionHandler2(
config, config,
mock.MagicMock(), mock.MagicMock(),
("127.0.0.1", ("127.0.0.1",

View File

@ -68,7 +68,7 @@ class CommonMixin:
# SSL with the upstream proxy. # SSL with the upstream proxy.
rt = self.master.replay_request(l, block=True) rt = self.master.replay_request(l, block=True)
assert not rt assert not rt
if isinstance(self, tservers.HTTPUpstreamProxTest) and not self.ssl: if isinstance(self, tservers.HTTPUpstreamProxTest):
assert l.response.code == 502 assert l.response.code == 502
else: else:
assert l.error assert l.error
@ -506,7 +506,7 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin, TcpMixin):
p = pathoc.Pathoc(("localhost", self.proxy.port), fp=None) p = pathoc.Pathoc(("localhost", self.proxy.port), fp=None)
p.connect() p.connect()
r = p.request("get:/") r = p.request("get:/")
assert r.status_code == 400 assert r.status_code == 502
class TestProxy(tservers.HTTPProxTest): class TestProxy(tservers.HTTPProxTest):
@ -724,9 +724,9 @@ class TestStreamRequest(tservers.HTTPProxTest):
assert resp.headers["Transfer-Encoding"][0] == 'chunked' assert resp.headers["Transfer-Encoding"][0] == 'chunked'
assert resp.status_code == 200 assert resp.status_code == 200
chunks = list( chunks = list(protocol.read_http_body_chunked(
content for _, content, _ in protocol.read_http_body_chunked( resp.headers, None, "GET", 200, False
resp.headers, None, "GET", 200, False)) ))
assert chunks == ["this", "isatest", ""] assert chunks == ["this", "isatest", ""]
connection.close() connection.close()
@ -959,6 +959,9 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest):
p = self.pathoc() p = self.pathoc()
req = p.request("get:'/p/418:b\"content\"'") req = p.request("get:'/p/418:b\"content\"'")
assert req.content == "content"
assert req.status_code == 418
assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request
# CONNECT, failing request, # CONNECT, failing request,
assert self.chain[0].tmaster.state.flow_count() == 4 assert self.chain[0].tmaster.state.flow_count() == 4
@ -967,8 +970,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest):
assert self.chain[1].tmaster.state.flow_count() == 2 assert self.chain[1].tmaster.state.flow_count() == 2
# (doesn't store (repeated) CONNECTs from chain[0] # (doesn't store (repeated) CONNECTs from chain[0]
# as it is a regular proxy) # as it is a regular proxy)
assert req.content == "content"
assert req.status_code == 418
assert not self.chain[1].tmaster.state.flows[0].response # killed assert not self.chain[1].tmaster.state.flows[0].response # killed
assert self.chain[1].tmaster.state.flows[1].response assert self.chain[1].tmaster.state.flows[1].response

View File

@ -181,22 +181,24 @@ class TResolver:
def original_addr(self, sock): def original_addr(self, sock):
return ("127.0.0.1", self.port) return ("127.0.0.1", self.port)
class TransparentProxTest(ProxTestBase): class TransparentProxTest(ProxTestBase):
ssl = None ssl = None
resolver = TResolver resolver = TResolver
@classmethod @classmethod
@mock.patch("libmproxy.platform.resolver") def setupAll(cls):
def setupAll(cls, _):
super(TransparentProxTest, cls).setupAll() super(TransparentProxTest, cls).setupAll()
if cls.ssl:
ports = [cls.server.port, cls.server2.port] cls._resolver = mock.patch(
else: "libmproxy.platform.resolver",
ports = [] new=lambda: cls.resolver(cls.server.port)
cls.config.mode = TransparentProxyMode( )
cls.resolver(cls.server.port), cls._resolver.start()
ports)
@classmethod
def teardownAll(cls):
cls._resolver.stop()
super(TransparentProxTest, cls).teardownAll()
@classmethod @classmethod
def get_proxy_config(cls): def get_proxy_config(cls):
@ -270,48 +272,6 @@ class SocksModeTest(HTTPProxTest):
d["mode"] = "socks5" d["mode"] = "socks5"
return d return d
class SpoofModeTest(ProxTestBase):
ssl = None
@classmethod
def get_proxy_config(cls):
d = ProxTestBase.get_proxy_config()
d["upstream_server"] = None
d["mode"] = "spoof"
return d
def pathoc(self, sni=None):
"""
Returns a connected Pathoc instance.
"""
p = libpathod.pathoc.Pathoc(
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
)
p.connect()
return p
class SSLSpoofModeTest(ProxTestBase):
ssl = True
@classmethod
def get_proxy_config(cls):
d = ProxTestBase.get_proxy_config()
d["upstream_server"] = None
d["mode"] = "sslspoof"
d["spoofed_ssl_port"] = 443
return d
def pathoc(self, sni=None):
"""
Returns a connected Pathoc instance.
"""
p = libpathod.pathoc.Pathoc(
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
)
p.connect()
return p
class ChainProxTest(ProxTestBase): class ChainProxTest(ProxTestBase):
""" """