merge smurfix/fix2, add serverconnect hook docs, adjust tests

This commit is contained in:
Maximilian Hils 2013-12-16 22:10:06 +01:00
commit e643759ef6
9 changed files with 60 additions and 21 deletions

View File

@ -36,6 +36,11 @@ Called when a client initiates a connection to the proxy. Note that
a connection can correspond to multiple HTTP requests. a connection can correspond to multiple HTTP requests.
### serverconnect(ScriptContext, ServerConnection)
Called when the proxy initiates a connection to the target server. Note that
a connection can correspond to multiple HTTP requests.
### request(ScriptContext, Flow) ### request(ScriptContext, Flow)
Called when a client request has been received. The __Flow__ object is Called when a client request has been received. The __Flow__ object is

View File

@ -14,6 +14,13 @@ def clientconnect(ctx, client_connect):
""" """
ctx.log("clientconnect") ctx.log("clientconnect")
def serverconnect(ctx, server_connection):
"""
Called when the proxy initiates a connection to the target server. Note that a
connection can correspond to multiple HTTP requests
"""
ctx.log("serverconnect")
def request(ctx, flow): def request(ctx, flow):
""" """
Called when a client request has been received. Called when a client request has been received.

View File

@ -133,18 +133,25 @@ class ProxyHandler(tcp.BaseHandler):
self.server_conn = None self.server_conn = None
tcp.BaseHandler.__init__(self, connection, client_address, server) tcp.BaseHandler.__init__(self, connection, client_address, server)
def get_server_connection(self, cc, scheme, host, port, sni): def get_server_connection(self, cc, scheme, host, port, sni, request=None):
""" """
When SNI is in play, this means we have an SSL-encrypted When SNI is in play, this means we have an SSL-encrypted
connection, which means that the entire handler is dedicated to a connection, which means that the entire handler is dedicated to a
single server connection - no multiplexing. If this assumption ever single server connection - no multiplexing. If this assumption ever
breaks, we'll have to do something different with the SNI host breaks, we'll have to do something different with the SNI host
variable on the handler object. variable on the handler object.
`conn_info` holds the initial connection's parameters, as the
hook might change them. Also, the hook might require an initial
request to figure out connection settings; in this case it can
set require_request, which will cause the connection to be
re-opened after the client's request arrives.
""" """
sc = self.server_conn sc = self.server_conn
if not sni: if not sni:
sni = host sni = host
if sc and (scheme, host, port, sni) != (sc.scheme, sc.host, sc.port, sc.sni): conn_info = (scheme, host, port, sni)
if sc and (conn_info != sc.conn_info or (request and sc.require_request)):
sc.terminate() sc.terminate()
self.server_conn = None self.server_conn = None
self.log( self.log(
@ -159,6 +166,13 @@ class ProxyHandler(tcp.BaseHandler):
if not self.server_conn: if not self.server_conn:
try: try:
self.server_conn = ServerConnection(self.config, scheme, host, port, sni) self.server_conn = ServerConnection(self.config, scheme, host, port, sni)
# Additional attributes, used if the server_connect hook
# needs to change parameters
self.server_conn.request = request
self.server_conn.require_request = False
self.server_conn.conn_info = conn_info
self.channel.ask(self.server_conn) self.channel.ask(self.server_conn)
self.server_conn.connect() self.server_conn.connect()
except tcp.NetLibError, v: except tcp.NetLibError, v:
@ -223,7 +237,7 @@ class ProxyHandler(tcp.BaseHandler):
# the case, we want to reconnect without sending an error # the case, we want to reconnect without sending an error
# to the client. # to the client.
while 1: while 1:
sc = self.get_server_connection(cc, scheme, host, port, self.sni) sc = self.get_server_connection(cc, scheme, host, port, self.sni, request=request)
sc.send(request) sc.send(request)
if sc.requestcount == 1: # add timestamps only for first request (others are not directly affected) if sc.requestcount == 1: # add timestamps only for first request (others are not directly affected)
request.tcp_setup_timestamp = sc.tcp_setup_timestamp request.tcp_setup_timestamp = sc.tcp_setup_timestamp

View File

@ -78,7 +78,7 @@ def concurrent(fn):
r = getattr(flow, fn.func_name) r = getattr(flow, fn.func_name)
_handle_concurrent_reply(fn, r, [ctx, flow]) _handle_concurrent_reply(fn, r, [ctx, flow])
return _concurrent return _concurrent
elif fn.func_name in ["clientconnect", "clientdisconnect", "serverconnect"]: elif fn.func_name in ["clientconnect", "serverconnect", "clientdisconnect"]:
def _concurrent(ctx, conn): def _concurrent(ctx, conn):
_handle_concurrent_reply(fn, conn, [ctx, conn]) _handle_concurrent_reply(fn, conn, [ctx, conn])
return _concurrent return _concurrent

View File

@ -3,6 +3,10 @@ def clientconnect(ctx, cc):
ctx.log("XCLIENTCONNECT") ctx.log("XCLIENTCONNECT")
log.append("clientconnect") log.append("clientconnect")
def serverconnect(ctx, cc):
ctx.log("XSERVERCONNECT")
log.append("serverconnect")
def request(ctx, r): def request(ctx, r):
ctx.log("XREQUEST") ctx.log("XREQUEST")
log.append("request") log.append("request")

View File

@ -1,6 +1,17 @@
import time import time
from libmproxy.script import concurrent from libmproxy.script import concurrent
@concurrent
def clientconnect(context, cc):
context.log("clientconnect")
@concurrent
def serverconnect(context, sc):
context.log("serverconnect")
@concurrent @concurrent
def request(context, flow): def request(context, flow):
time.sleep(0.1) time.sleep(0.1)
@ -16,16 +27,6 @@ def error(context, err):
context.log("error") context.log("error")
@concurrent
def clientconnect(context, cc):
context.log("clientconnect")
@concurrent @concurrent
def clientdisconnect(context, dc): def clientdisconnect(context, dc):
context.log("clientdisconnect") context.log("clientdisconnect")
@concurrent
def serverconnect(context, sc):
context.log("serverconnect")

View File

@ -30,6 +30,9 @@ class TestDumpMaster:
resp = tutils.tresp(req) resp = tutils.tresp(req)
resp.content = content resp.content = content
m.handle_clientconnect(cc) m.handle_clientconnect(cc)
sc = proxy.ServerConnection(m.o, req.scheme, req.host, req.port, None)
sc.reply = mock.MagicMock()
m.handle_serverconnection(sc)
m.handle_request(req) m.handle_request(req)
f = m.handle_response(resp) f = m.handle_response(resp)
cd = flow.ClientDisconnect(cc) cd = flow.ClientDisconnect(cc)
@ -153,6 +156,7 @@ class TestDumpMaster:
scripts=[[tutils.test_data.path("scripts/all.py")]], verbosity=0, eventlog=True scripts=[[tutils.test_data.path("scripts/all.py")]], verbosity=0, eventlog=True
) )
assert "XCLIENTCONNECT" in ret assert "XCLIENTCONNECT" in ret
assert "XSERVERCONNECT" in ret
assert "XREQUEST" in ret assert "XREQUEST" in ret
assert "XRESPONSE" in ret assert "XRESPONSE" in ret
assert "XCLIENTDISCONNECT" in ret assert "XCLIENTDISCONNECT" in ret

View File

@ -1,7 +1,7 @@
import Queue, time, os.path import Queue, time, os.path
from cStringIO import StringIO from cStringIO import StringIO
import email.utils import email.utils
from libmproxy import filt, flow, controller, utils, tnetstring from libmproxy import filt, flow, controller, utils, tnetstring, proxy
import tutils import tutils
@ -575,6 +575,10 @@ class TestFlowMaster:
req = tutils.treq() req = tutils.treq()
fm.handle_clientconnect(req.client_conn) fm.handle_clientconnect(req.client_conn)
assert fm.scripts[0].ns["log"][-1] == "clientconnect" assert fm.scripts[0].ns["log"][-1] == "clientconnect"
sc = proxy.ServerConnection(None, req.scheme, req.host, req.port, None)
sc.reply = controller.DummyReply()
fm.handle_serverconnection(sc)
assert fm.scripts[0].ns["log"][-1] == "serverconnect"
f = fm.handle_request(req) f = fm.handle_request(req)
assert fm.scripts[0].ns["log"][-1] == "request" assert fm.scripts[0].ns["log"][-1] == "request"
resp = tutils.tresp(req) resp = tutils.tresp(req)

View File

@ -103,11 +103,11 @@ class TestScript:
f.error = tutils.terr(f.request) f.error = tutils.terr(f.request)
f.reply = f.request.reply f.reply = f.request.reply
print s.run("response", f) s.run("clientconnect", f)
print s.run("error", f) s.run("serverconnect", f)
print s.run("clientconnect", f) s.run("response", f)
print s.run("clientdisconnect", f) s.run("error", f)
print s.run("serverconnect", f) s.run("clientdisconnect", f)
time.sleep(0.1) time.sleep(0.1)
assert ctx.count == 5 assert ctx.count == 5