fix inline script redirects

This commit is contained in:
Maximilian Hils 2015-08-27 17:35:53 +02:00
parent 515c024448
commit 83decd6771
5 changed files with 45 additions and 34 deletions

View File

@ -8,6 +8,7 @@ import Cookie
import cookielib import cookielib
import os import os
import re import re
from libmproxy.protocol2.http import RequestReplayThread
from netlib import odict, wsgi, tcp from netlib import odict, wsgi, tcp
from netlib.http.semantics import CONTENT_MISSING from netlib.http.semantics import CONTENT_MISSING
@ -934,7 +935,7 @@ class FlowMaster(controller.Master):
f.response = None f.response = None
f.error = None f.error = None
self.process_new_request(f) self.process_new_request(f)
rt = http.RequestReplayThread( rt = RequestReplayThread(
self.server.config, self.server.config,
f, f,
self.masterq if run_scripthooks else False, self.masterq if run_scripthooks else False,

View File

@ -212,10 +212,11 @@ class UpstreamConnectLayer(Layer):
self.ctx.reconnect() self.ctx.reconnect()
self.send_to_server(self.connect_request) self.send_to_server(self.connect_request)
def set_server(self, address, server_tls, sni, depth=1): def set_server(self, address, server_tls=None, sni=None, depth=1):
if depth == 1: if depth == 1:
if self.ctx.server_conn: if self.ctx.server_conn:
self.ctx.reconnect() self.ctx.reconnect()
address = Address.wrap(address)
self.connect_request.host = address.host self.connect_request.host = address.host
self.connect_request.port = address.port self.connect_request.port = address.port
self.server_conn.address = address self.server_conn.address = address
@ -227,11 +228,16 @@ class HttpLayer(Layer):
def __init__(self, ctx, mode): def __init__(self, ctx, mode):
super(HttpLayer, self).__init__(ctx) super(HttpLayer, self).__init__(ctx)
self.mode = mode self.mode = mode
self.__original_server_conn = None
"Contains the original destination in transparent mode, which needs to be restored"
"if an inline script modified the target server for a single http request"
def __call__(self): def __call__(self):
if self.mode == "transparent":
self.__original_server_conn = self.server_conn
while True: while True:
try: try:
flow = HTTPFlow(self.client_conn, self.server_conn, live=True) flow = HTTPFlow(self.client_conn, self.server_conn, live=self)
try: try:
request = self.read_from_client() request = self.read_from_client()
@ -288,7 +294,7 @@ class HttpLayer(Layer):
flow.live = False flow.live = False
def handle_regular_mode_connect(self, request): def handle_regular_mode_connect(self, request):
self.set_server((request.host, request.port), False, None) self.set_server((request.host, request.port))
self.send_to_client(make_connect_response(request.httpversion)) self.send_to_client(make_connect_response(request.httpversion))
layer = self.ctx.next_layer(self) layer = self.ctx.next_layer(self)
layer() layer()
@ -433,11 +439,10 @@ class HttpLayer(Layer):
if flow.request.form_in == "authority": if flow.request.form_in == "authority":
flow.request.scheme = "http" # pseudo value flow.request.scheme = "http" # pseudo value
else: else:
flow.request.host = self.ctx.server_conn.address.host flow.request.host = self.__original_server_conn.address.host
flow.request.port = self.ctx.server_conn.address.port flow.request.port = self.__original_server_conn.address.port
flow.request.scheme = "https" if self.server_conn.tls_established else "http" flow.request.scheme = "https" if self.__original_server_conn.tls_established else "http"
# TODO: Expose .set_server functionality to inline scripts
request_reply = self.channel.ask("request", flow) request_reply = self.channel.ask("request", flow)
if request_reply is None or request_reply == KILL: if request_reply is None or request_reply == KILL:
raise Kill() raise Kill()

View File

@ -112,7 +112,7 @@ class ServerConnectionMixin(object):
self.server_conn.address = address self.server_conn.address = address
self.connect() self.connect()
def set_server(self, address, server_tls, sni, depth=1): def set_server(self, address, server_tls=None, sni=None, depth=1):
if depth == 1: if depth == 1:
if self.server_conn: if self.server_conn:
self._disconnect() self._disconnect()

View File

@ -110,9 +110,9 @@ class TlsLayer(Layer):
if self._server_tls and not self.server_conn.tls_established: if self._server_tls and not self.server_conn.tls_established:
self._establish_tls_with_server() self._establish_tls_with_server()
def set_server(self, address, server_tls, sni, depth=1): def set_server(self, address, server_tls=None, sni=None, depth=1):
self.ctx.set_server(address, server_tls, sni, depth) self.ctx.set_server(address, server_tls, sni, depth)
if server_tls is not None: if depth == 1 and server_tls is not None:
self._sni_from_server_change = sni self._sni_from_server_change = sni
self._server_tls = server_tls self._server_tls = server_tls

View File

@ -1,6 +1,7 @@
import socket import socket
import time import time
from OpenSSL import SSL from OpenSSL import SSL
from netlib.tcp import Address
import netlib.tutils import netlib.tutils
from netlib import tcp, http, socks from netlib import tcp, http, socks
@ -655,63 +656,67 @@ class MasterRedirectRequest(tservers.TestMaster):
redirect_port = None # Set by TestRedirectRequest redirect_port = None # Set by TestRedirectRequest
def handle_request(self, f): def handle_request(self, f):
request = f.request if f.request.path == "/p/201":
if request.path == "/p/201":
addr = f.live.c.server_conn.address # This part should have no impact, but it should not cause any exceptions.
assert f.live.change_server( addr = f.live.server_conn.address
("127.0.0.1", self.redirect_port), ssl=False) addr2 = Address(("127.0.0.1", self.redirect_port))
assert not f.live.change_server( f.live.set_server(addr2)
("127.0.0.1", self.redirect_port), ssl=False) f.live.connect()
tutils.raises( f.live.set_server(addr)
"SSL handshake error", f.live.connect()
f.live.change_server,
("127.0.0.1", # This is the actual redirection.
self.redirect_port), f.request.port = self.redirect_port
ssl=True) super(MasterRedirectRequest, self).handle_request(f)
assert f.live.change_server(addr, ssl=False)
request.url = "http://127.0.0.1:%s/p/201" % self.redirect_port
tservers.TestMaster.handle_request(self, f)
def handle_response(self, f): def handle_response(self, f):
f.response.content = str(f.client_conn.address.port) f.response.content = str(f.client_conn.address.port)
f.response.headers[ f.response.headers[
"server-conn-id"] = [str(f.server_conn.source_address.port)] "server-conn-id"] = [str(f.server_conn.source_address.port)]
tservers.TestMaster.handle_response(self, f) super(MasterRedirectRequest, self).handle_response(f)
class TestRedirectRequest(tservers.HTTPProxTest): class TestRedirectRequest(tservers.HTTPProxTest):
masterclass = MasterRedirectRequest masterclass = MasterRedirectRequest
ssl = True
def test_redirect(self): def test_redirect(self):
"""
Imagine a single HTTPS connection with three requests:
1. First request should pass through unmodified
2. Second request will be redirected to a different host by an inline script
3. Third request should pass through unmodified
This test verifies that the original destination is restored for the third request.
"""
self.master.redirect_port = self.server2.port self.master.redirect_port = self.server2.port
p = self.pathoc() p = self.pathoc()
self.server.clear_log() self.server.clear_log()
self.server2.clear_log() self.server2.clear_log()
r1 = p.request("get:'%s/p/200'" % self.server.urlbase) r1 = p.request("get:'/p/200'")
assert r1.status_code == 200 assert r1.status_code == 200
assert self.server.last_log() assert self.server.last_log()
assert not self.server2.last_log() assert not self.server2.last_log()
self.server.clear_log() self.server.clear_log()
self.server2.clear_log() self.server2.clear_log()
r2 = p.request("get:'%s/p/201'" % self.server.urlbase) r2 = p.request("get:'/p/201'")
assert r2.status_code == 201 assert r2.status_code == 201
assert not self.server.last_log() assert not self.server.last_log()
assert self.server2.last_log() assert self.server2.last_log()
self.server.clear_log() self.server.clear_log()
self.server2.clear_log() self.server2.clear_log()
r3 = p.request("get:'%s/p/202'" % self.server.urlbase) r3 = p.request("get:'/p/202'")
assert r3.status_code == 202 assert r3.status_code == 202
assert self.server.last_log() assert self.server.last_log()
assert not self.server2.last_log() assert not self.server2.last_log()
assert r1.content == r2.content == r3.content assert r1.content == r2.content == r3.content
assert r1.headers.get_first(
"server-conn-id") == r3.headers.get_first("server-conn-id")
# Make sure that we actually use the same connection in this test case
class MasterStreamRequest(tservers.TestMaster): class MasterStreamRequest(tservers.TestMaster):