Merge pull request #734 from Kriechi/proxy-refactor-cb

move read methods to lower HTTP layer
This commit is contained in:
Maximilian Hils 2015-08-19 18:15:49 +02:00
commit 721bd1c136

View File

@ -28,6 +28,20 @@ class Http1Layer(Layer):
self.client_protocol = HTTP1Protocol(self.client_conn) self.client_protocol = HTTP1Protocol(self.client_conn)
self.server_protocol = HTTP1Protocol(self.server_conn) self.server_protocol = HTTP1Protocol(self.server_conn)
def read_from_client(self):
return HTTPRequest.from_protocol(
self.client_protocol,
body_size_limit=self.config.body_size_limit
)
def read_from_server(self, method):
return HTTPResponse.from_protocol(
self.server_protocol,
method,
body_size_limit=self.config.body_size_limit,
include_body=False,
)
def send_to_client(self, message): def send_to_client(self, message):
self.client_conn.send(self.client_protocol.assemble(message)) self.client_conn.send(self.client_protocol.assemble(message))
@ -57,6 +71,20 @@ class Http2Layer(Layer):
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True) self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False)
def read_from_client(self):
return HTTPRequest.from_protocol(
self.client_protocol,
body_size_limit=self.config.body_size_limit
)
def read_from_server(self, method):
return HTTPResponse.from_protocol(
self.server_protocol,
method,
body_size_limit=self.config.body_size_limit,
include_body=False,
)
def send_to_client(self, message): def send_to_client(self, message):
# TODO: implement flow control and WINDOW_UPDATE frames # TODO: implement flow control and WINDOW_UPDATE frames
self.client_conn.send(self.client_protocol.assemble(message)) self.client_conn.send(self.client_protocol.assemble(message))
@ -67,15 +95,18 @@ class Http2Layer(Layer):
def connect(self): def connect(self):
self.ctx.connect() self.ctx.connect()
self.server_protocol = HTTP2Protocol(self.server_conn) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False)
self.server_protocol.perform_connection_preface()
def reconnect(self): def reconnect(self):
self.ctx.reconnect() self.ctx.reconnect()
self.server_protocol = HTTP2Protocol(self.server_conn) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False)
self.server_protocol.perform_connection_preface()
def set_server(self, *args, **kwargs): def set_server(self, *args, **kwargs):
self.ctx.set_server(*args, **kwargs) self.ctx.set_server(*args, **kwargs)
self.server_protocol = HTTP2Protocol(self.server_conn) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False)
self.server_protocol.perform_connection_preface()
def __call__(self): def __call__(self):
self.server_protocol.perform_connection_preface() self.server_protocol.perform_connection_preface()
@ -192,10 +223,7 @@ class HttpLayer(Layer):
flow = HTTPFlow(self.client_conn, self.server_conn, live=True) flow = HTTPFlow(self.client_conn, self.server_conn, live=True)
try: try:
request = HTTPRequest.from_protocol( request = self.read_from_client()
self.client_protocol,
body_size_limit=self.config.body_size_limit
)
except tcp.NetLibError: except tcp.NetLibError:
# don't throw an error for disconnects that happen # don't throw an error for disconnects that happen
# before/between requests. # before/between requests.
@ -320,13 +348,7 @@ class HttpLayer(Layer):
def get_response_from_server(self, flow): def get_response_from_server(self, flow):
def get_response(): def get_response():
self.send_to_server(flow.request) self.send_to_server(flow.request)
# Only get the headers at first... flow.response = self.read_from_server(flow.request.method)
flow.response = HTTPResponse.from_protocol(
self.server_protocol,
flow.request.method,
body_size_limit=self.config.body_size_limit,
include_body=False,
)
try: try:
get_response() get_response()
@ -408,9 +430,9 @@ class HttpLayer(Layer):
return return
def establish_server_connection(self, flow): def establish_server_connection(self, flow):
address = tcp.Address((flow.request.host, flow.request.port)) address = tcp.Address((flow.request.host, flow.request.port))
tls = (flow.request.scheme == "https") tls = (flow.request.scheme == "https")
if self.mode == "regular" or self.mode == "transparent": if self.mode == "regular" or self.mode == "transparent":
# If there's an existing connection that doesn't match our expectations, kill it. # If there's an existing connection that doesn't match our expectations, kill it.
if address != self.server_conn.address or tls != self.server_conn.ssl_established: if address != self.server_conn.address or tls != self.server_conn.ssl_established:
@ -424,7 +446,6 @@ class HttpLayer(Layer):
# TLS will not be established. # TLS will not be established.
if tls and not self.server_conn.tls_established: if tls and not self.server_conn.tls_established:
raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.") raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.")
else: else:
if not self.server_conn: if not self.server_conn:
self.connect() self.connect()