fix bugs, fix tests

This commit is contained in:
Maximilian Hils 2015-08-29 20:53:25 +02:00
parent 63844df343
commit a7058e2a3c
5 changed files with 69 additions and 85 deletions

View File

@ -199,11 +199,12 @@ class StatusBar(urwid.WidgetWrap):
r.append("[%s]" % (":".join(opts)))
if self.master.server.config.mode in ["reverse", "upstream"]:
dst = self.master.server.config.mode.dst
scheme = "https" if dst[0] else "http"
if dst[1] != dst[0]:
scheme += "2https" if dst[1] else "http"
r.append("[dest:%s]" % utils.unparse_url(scheme, *dst[2:]))
dst = self.master.server.config.upstream_server
r.append("[dest:%s]" % netlib.utils.unparse_url(
dst.scheme,
dst.address.host,
dst.address.port
))
if self.master.scripts:
r.append("[")
r.append(("heading_key", "s"))

View File

@ -40,6 +40,7 @@ class _HttpLayer(Layer):
def send_response(self, response):
raise NotImplementedError()
class _StreamingHttpLayer(_HttpLayer):
supports_streaming = True
@ -58,7 +59,6 @@ class _StreamingHttpLayer(_HttpLayer):
class Http1Layer(_StreamingHttpLayer):
def __init__(self, ctx, mode):
super(Http1Layer, self).__init__(ctx)
self.mode = mode
@ -105,12 +105,12 @@ class Http1Layer(_StreamingHttpLayer):
def send_response_headers(self, response):
h = self.client_protocol._assemble_response_first_line(response)
self.client_conn.wfile.write(h+"\r\n")
self.client_conn.wfile.write(h + "\r\n")
h = self.client_protocol._assemble_response_headers(
response,
preserve_transfer_encoding=True
)
self.client_conn.send(h+"\r\n")
self.client_conn.send(h + "\r\n")
def send_response_body(self, response, chunks):
if self.client_protocol.has_chunked_encoding(response.headers):
@ -142,8 +142,10 @@ class Http2Layer(_HttpLayer):
def __init__(self, ctx, mode):
super(Http2Layer, self).__init__(ctx)
self.mode = mode
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True,
unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
unhandled_frame_cb=self.handle_unexpected_frame)
def read_request(self):
request = HTTPRequest.from_protocol(
@ -172,17 +174,20 @@ class Http2Layer(_HttpLayer):
def connect(self):
self.ctx.connect()
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol.perform_connection_preface()
def reconnect(self):
self.ctx.reconnect()
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol.perform_connection_preface()
def set_server(self, *args, **kwargs):
self.ctx.set_server(*args, **kwargs)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol.perform_connection_preface()
def __call__(self):
@ -264,7 +269,10 @@ class UpstreamConnectLayer(Layer):
def __init__(self, ctx, connect_request):
super(UpstreamConnectLayer, self).__init__(ctx)
self.connect_request = connect_request
self.server_conn = ConnectServerConnection((connect_request.host, connect_request.port), self.ctx)
self.server_conn = ConnectServerConnection(
(connect_request.host, connect_request.port),
self.ctx
)
def __call__(self):
layer = self.ctx.next_layer(self)
@ -280,6 +288,9 @@ class UpstreamConnectLayer(Layer):
def reconnect(self):
self.ctx.reconnect()
self.send_request(self.connect_request)
resp = self.read_response("CONNECT")
if resp.code != 200:
raise ProtocolException("Reconnect: Upstream server refuses CONNECT request")
def set_server(self, address, server_tls=None, sni=None, depth=1):
if depth == 1:
@ -290,7 +301,7 @@ class UpstreamConnectLayer(Layer):
self.connect_request.port = address.port
self.server_conn.address = address
else:
self.ctx.set_server(address, server_tls, sni, depth-1)
self.ctx.set_server(address, server_tls, sni, depth - 1)
class HttpLayer(Layer):
@ -413,10 +424,10 @@ class HttpLayer(Layer):
# First send the headers and then transfer the response incrementally
self.send_response_headers(flow.response)
chunks = self.read_response_body(
flow.response.headers,
flow.request.method,
flow.response.code,
max_chunk_size=4096
flow.response.headers,
flow.request.method,
flow.response.code,
max_chunk_size=4096
)
if callable(flow.response.stream):
chunks = flow.response.stream(chunks)
@ -521,7 +532,8 @@ class HttpLayer(Layer):
# If there's not TlsLayer below which could catch the exception,
# TLS will not be established.
if tls and not self.server_conn.tls_established:
raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.")
raise ProtocolException(
"Cannot upgrade to SSL, no TLS layer on the protocol stack.")
else:
if not self.server_conn:
self.connect()
@ -542,7 +554,8 @@ class HttpLayer(Layer):
def validate_request(self, request):
if request.form_in == "absolute" and request.scheme != "http":
self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
self.send_response(
make_error_response(400, "Invalid request scheme: %s" % request.scheme))
raise HttpException("Invalid request scheme: %s" % request.scheme)
expected_request_forms = {
@ -570,7 +583,11 @@ class HttpLayer(Layer):
self.send_response(make_error_response(
407,
"Proxy Authentication Required",
odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()])
odict.ODictCaseless(
[
[k, v] for k, v in
self.config.authenticator.auth_challenge_headers().items()
])
))
raise InvalidCredentials("Proxy Authentication Required")
@ -614,6 +631,9 @@ class RequestReplayThread(threading.Thread):
if r.scheme == "https":
connect_request = make_connect_request((r.host, r.port))
server.send(protocol.assemble(connect_request))
resp = protocol.read_response("CONNECT")
if resp.code != 200:
raise HttpError(502, "Upstream server refuses CONNECT request")
server.establish_ssl(
self.config.clientcerts,
sni=self.flow.server_conn.sni

View File

@ -1,9 +1,8 @@
import argparse
from libmproxy import cmdline
from libmproxy.proxy import ProxyConfig, process_proxy_options
from libmproxy.proxy.connection import ServerConnection
from libmproxy.proxy.primitives import ProxyError
from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler
from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler2
import tutils
from libpathod import test
from netlib import http, tcp
@ -175,8 +174,10 @@ class TestDummyServer:
class TestConnectionHandler:
def test_fatal_error(self):
config = mock.Mock()
config.mode.get_upstream_server.side_effect = RuntimeError
c = ConnectionHandler(
root_layer = mock.Mock()
root_layer.side_effect = RuntimeError
config.mode.return_value = root_layer
c = ConnectionHandler2(
config,
mock.MagicMock(),
("127.0.0.1",

View File

@ -68,7 +68,7 @@ class CommonMixin:
# SSL with the upstream proxy.
rt = self.master.replay_request(l, block=True)
assert not rt
if isinstance(self, tservers.HTTPUpstreamProxTest) and not self.ssl:
if isinstance(self, tservers.HTTPUpstreamProxTest):
assert l.response.code == 502
else:
assert l.error
@ -506,7 +506,7 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin, TcpMixin):
p = pathoc.Pathoc(("localhost", self.proxy.port), fp=None)
p.connect()
r = p.request("get:/")
assert r.status_code == 400
assert r.status_code == 502
class TestProxy(tservers.HTTPProxTest):
@ -724,9 +724,9 @@ class TestStreamRequest(tservers.HTTPProxTest):
assert resp.headers["Transfer-Encoding"][0] == 'chunked'
assert resp.status_code == 200
chunks = list(
content for _, content, _ in protocol.read_http_body_chunked(
resp.headers, None, "GET", 200, False))
chunks = list(protocol.read_http_body_chunked(
resp.headers, None, "GET", 200, False
))
assert chunks == ["this", "isatest", ""]
connection.close()
@ -959,6 +959,9 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest):
p = self.pathoc()
req = p.request("get:'/p/418:b\"content\"'")
assert req.content == "content"
assert req.status_code == 418
assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request
# CONNECT, failing request,
assert self.chain[0].tmaster.state.flow_count() == 4
@ -967,8 +970,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest):
assert self.chain[1].tmaster.state.flow_count() == 2
# (doesn't store (repeated) CONNECTs from chain[0]
# as it is a regular proxy)
assert req.content == "content"
assert req.status_code == 418
assert not self.chain[1].tmaster.state.flows[0].response # killed
assert self.chain[1].tmaster.state.flows[1].response

View File

@ -181,22 +181,24 @@ class TResolver:
def original_addr(self, sock):
return ("127.0.0.1", self.port)
class TransparentProxTest(ProxTestBase):
ssl = None
resolver = TResolver
@classmethod
@mock.patch("libmproxy.platform.resolver")
def setupAll(cls, _):
def setupAll(cls):
super(TransparentProxTest, cls).setupAll()
if cls.ssl:
ports = [cls.server.port, cls.server2.port]
else:
ports = []
cls.config.mode = TransparentProxyMode(
cls.resolver(cls.server.port),
ports)
cls._resolver = mock.patch(
"libmproxy.platform.resolver",
new=lambda: cls.resolver(cls.server.port)
)
cls._resolver.start()
@classmethod
def teardownAll(cls):
cls._resolver.stop()
super(TransparentProxTest, cls).teardownAll()
@classmethod
def get_proxy_config(cls):
@ -270,48 +272,6 @@ class SocksModeTest(HTTPProxTest):
d["mode"] = "socks5"
return d
class SpoofModeTest(ProxTestBase):
ssl = None
@classmethod
def get_proxy_config(cls):
d = ProxTestBase.get_proxy_config()
d["upstream_server"] = None
d["mode"] = "spoof"
return d
def pathoc(self, sni=None):
"""
Returns a connected Pathoc instance.
"""
p = libpathod.pathoc.Pathoc(
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
)
p.connect()
return p
class SSLSpoofModeTest(ProxTestBase):
ssl = True
@classmethod
def get_proxy_config(cls):
d = ProxTestBase.get_proxy_config()
d["upstream_server"] = None
d["mode"] = "sslspoof"
d["spoofed_ssl_port"] = 443
return d
def pathoc(self, sni=None):
"""
Returns a connected Pathoc instance.
"""
p = libpathod.pathoc.Pathoc(
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
)
p.connect()
return p
class ChainProxTest(ProxTestBase):
"""