mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
fix header assembly, other improvements
This commit is contained in:
parent
efdb25ef68
commit
24fc8ff292
@ -1,12 +1,12 @@
|
||||
from libmproxy import flow
|
||||
from libmproxy.utils import timestamp
|
||||
from netlib import http, utils, tcp
|
||||
import libmproxy.utils
|
||||
from netlib import http, tcp
|
||||
import netlib.utils
|
||||
from netlib.odict import ODictCaseless
|
||||
|
||||
KILL = 0 # FIXME: Remove duplication with proxy module
|
||||
LEGACY = True
|
||||
|
||||
#FIXME: Combine with ProxyError?
|
||||
class ProtocolError(Exception):
|
||||
def __init__(self, code, msg, headers=None):
|
||||
self.code, self.msg, self.headers = code, msg, headers
|
||||
@ -30,10 +30,6 @@ def handle_messages(conntype, connection_handler):
|
||||
_handle("messages", conntype, connection_handler)
|
||||
|
||||
|
||||
def handle_error(conntype, connection_handler, e):
|
||||
_handle("error", conntype, connection_handler, e)
|
||||
|
||||
|
||||
class ConnectionTypeChange(Exception):
|
||||
pass
|
||||
|
||||
@ -56,7 +52,20 @@ class HTTPFlow(Flow):
|
||||
self.request, self.response = request, response
|
||||
|
||||
|
||||
class HTTPResponse(object):
|
||||
class HTTPMessage(object):
|
||||
def _assemble_headers(self):
|
||||
headers = self.headers.copy()
|
||||
libmproxy.utils.del_all(headers,
|
||||
["proxy-connection",
|
||||
"transfer-encoding"])
|
||||
if self.content:
|
||||
headers["Content-Length"] = [str(len(self.content))]
|
||||
elif 'Transfer-Encoding' in self.headers: # content-length for e.g. chuncked transfer-encoding with no content
|
||||
headers["Content-Length"] = ["0"]
|
||||
|
||||
return str(headers)
|
||||
|
||||
class HTTPResponse(HTTPMessage):
|
||||
def __init__(self, http_version, code, msg, headers, content, timestamp_start, timestamp_end):
|
||||
self.http_version = http_version
|
||||
self.code = code
|
||||
@ -75,7 +84,7 @@ class HTTPResponse(object):
|
||||
|
||||
def _assemble(self):
|
||||
response_line = 'HTTP/%s.%s %s %s'%(self.http_version[0], self.http_version[1], self.code, self.msg)
|
||||
return '%s\r\n%s\r\n%s' % (response_line, str(self.headers), self.content)
|
||||
return '%s\r\n%s\r\n%s' % (response_line, self._assemble_headers(), self.content)
|
||||
|
||||
@classmethod
|
||||
def from_stream(cls, rfile, request_method, include_content=True, body_size_limit=None):
|
||||
@ -85,15 +94,16 @@ class HTTPResponse(object):
|
||||
if not include_content:
|
||||
raise NotImplementedError
|
||||
|
||||
timestamp_start = timestamp()
|
||||
timestamp_start = libmproxy.utils.timestamp()
|
||||
http_version, code, msg, headers, content = http.read_response(
|
||||
rfile,
|
||||
request_method,
|
||||
body_size_limit)
|
||||
timestamp_end = timestamp()
|
||||
timestamp_end = libmproxy.utils.timestamp()
|
||||
return HTTPResponse(http_version, code, msg, headers, content, timestamp_start, timestamp_end)
|
||||
|
||||
class HTTPRequest(object):
|
||||
|
||||
class HTTPRequest(HTTPMessage):
|
||||
def __init__(self, form_in, method, scheme, host, port, path, http_version, headers, content,
|
||||
timestamp_start, timestamp_end, form_out=None, ip=None):
|
||||
self.form_in = form_in
|
||||
@ -109,7 +119,7 @@ class HTTPRequest(object):
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
self.form_out = form_out or self.form_in
|
||||
self.ip = ip # resolved ip address
|
||||
self.ip = ip # resolved ip address
|
||||
assert isinstance(headers, ODictCaseless)
|
||||
|
||||
#FIXME: Remove, legacy
|
||||
@ -123,9 +133,14 @@ class HTTPRequest(object):
|
||||
elif self.form_out == "authority":
|
||||
request_line = '%s %s:%s HTTP/%s.%s' % (self.method, self.host, self.port,
|
||||
self.http_version[0], self.http_version[1])
|
||||
elif self.form_out == "absolute":
|
||||
request_line = '%s %s://%s:%s%s HTTP/%s.%s' % \
|
||||
(self.method, self.scheme, self.host, self.port, self.path,
|
||||
self.http_version[0], self.http_version[1])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return '%s\r\n%s\r\n%s' % (request_line, str(self.headers), self.content)
|
||||
raise http.HttpError(400, "Invalid request form")
|
||||
|
||||
return '%s\r\n%s\r\n%s' % (request_line, self._assemble_headers(), self.content)
|
||||
|
||||
@classmethod
|
||||
def from_stream(cls, rfile, include_content=True, body_size_limit=None):
|
||||
@ -135,7 +150,7 @@ class HTTPRequest(object):
|
||||
http_version, host, port, scheme, method, path, headers, content, timestamp_start, timestamp_end \
|
||||
= None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
timestamp_start = timestamp()
|
||||
timestamp_start = libmproxy.utils.timestamp()
|
||||
request_line = HTTPHandler.get_line(rfile)
|
||||
|
||||
request_line_parts = http.parse_init(request_line)
|
||||
@ -147,7 +162,7 @@ class HTTPRequest(object):
|
||||
form_in = "asterisk"
|
||||
elif path.startswith("/"):
|
||||
form_in = "origin"
|
||||
if not utils.isascii(path):
|
||||
if not netlib.utils.isascii(path):
|
||||
raise ProtocolError(400, "Bad HTTP request line: %s"%repr(request_line))
|
||||
elif method.upper() == 'CONNECT':
|
||||
form_in = "authority"
|
||||
@ -168,7 +183,7 @@ class HTTPRequest(object):
|
||||
|
||||
if include_content:
|
||||
content = http.read_http_body(rfile, headers, body_size_limit, True)
|
||||
timestamp_end = timestamp()
|
||||
timestamp_end = libmproxy.utils.timestamp()
|
||||
|
||||
return HTTPRequest(form_in, method, scheme, host, port, path, http_version, headers, content,
|
||||
timestamp_start, timestamp_end)
|
||||
@ -182,48 +197,60 @@ class HTTPHandler(ProtocolHandler):
|
||||
self.c.close = True
|
||||
|
||||
def handle_error(self, e):
|
||||
raise e # FIXME: Proper error handling
|
||||
raise e # FIXME: Proper error handling
|
||||
|
||||
def handle_request(self):
|
||||
try:
|
||||
flow = HTTPFlow(self.c.client_conn, self.c.server_conn, timestamp(), None, None, None)
|
||||
flow = HTTPFlow(self.c.client_conn, self.c.server_conn, libmproxy.utils.timestamp(), None, None, None)
|
||||
flow.request = self.read_request()
|
||||
request_reply = self.c.channel.ask("request" if LEGACY else "httprequest", flow.request)
|
||||
|
||||
request_reply = self.c.channel.ask("request" if LEGACY else "httprequest", flow.request if LEGACY else flow)
|
||||
if request_reply is None or request_reply == KILL:
|
||||
return False
|
||||
return False
|
||||
|
||||
if isinstance(request_reply, HTTPResponse):
|
||||
flow.response = request_reply
|
||||
else:
|
||||
flow.request = request_reply
|
||||
raw = flow.request._assemble()
|
||||
self.c.server_conn.wfile.write(raw)
|
||||
self.c.server_conn.wfile.flush()
|
||||
flow.response = self.read_response(flow)
|
||||
response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse", flow.response)
|
||||
|
||||
response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse",
|
||||
flow.response if LEGACY else flow)
|
||||
if response_reply is None or response_reply == KILL:
|
||||
return False
|
||||
else:
|
||||
raw = flow.response._assemble()
|
||||
self.c.client_conn.wfile.write(raw)
|
||||
self.c.client_conn.wfile.flush()
|
||||
|
||||
raw = flow.response._assemble()
|
||||
self.c.client_conn.wfile.write(raw)
|
||||
self.c.client_conn.wfile.flush()
|
||||
flow.timestamp_end = libmproxy.utils.timestamp()
|
||||
|
||||
if (http.connection_close(flow.request.http_version, flow.request.headers) or
|
||||
http.connection_close(flow.response.http_version, flow.response.headers)):
|
||||
return False
|
||||
|
||||
flow.timestamp_end = timestamp()
|
||||
if flow.request.form_in == "authority":
|
||||
self.ssl_upgrade()
|
||||
return flow
|
||||
except tcp.NetLibDisconnect, e:
|
||||
except ProtocolError, http.HttpError:
|
||||
raise NotImplementedError
|
||||
# FIXME: Implement error handling
|
||||
return False
|
||||
|
||||
def ssl_upgrade(self):
|
||||
self.c.mode = "transparent"
|
||||
self.c.determine_conntype()
|
||||
self.c.establish_ssl(server=True, client=True)
|
||||
raise ConnectionTypeChange
|
||||
|
||||
def read_request(self):
|
||||
request = HTTPRequest.from_stream(self.c.client_conn.rfile, body_size_limit=self.c.config.body_size_limit)
|
||||
|
||||
if self.c.mode == "regular":
|
||||
self.authenticate(request)
|
||||
if request.form_in == "authority" and self.c.client_conn.ssl_established:
|
||||
raise ProtocolError(502, "Must not CONNECT on SSL connection")
|
||||
raise ProtocolError(502, "Must not CONNECT on already encrypted connection")
|
||||
|
||||
# If we have a CONNECT request, we might need to intercept
|
||||
if request.form_in == "authority":
|
||||
@ -236,11 +263,7 @@ class HTTPHandler(ProtocolHandler):
|
||||
'\r\n'
|
||||
)
|
||||
self.c.client_conn.wfile.flush()
|
||||
|
||||
self.c.establish_ssl(server=True, client=True)
|
||||
self.c.mode = "transparent"
|
||||
self.c.determine_conntype()
|
||||
raise ConnectionTypeChange
|
||||
self.ssl_upgrade()
|
||||
|
||||
if self.c.mode == "regular":
|
||||
if request.form_in == "authority":
|
||||
@ -258,7 +281,8 @@ class HTTPHandler(ProtocolHandler):
|
||||
return request
|
||||
|
||||
def read_response(self, flow):
|
||||
return HTTPResponse.from_stream(self.c.server_conn.rfile, flow.request.method, body_size_limit=self.c.config.body_size_limit)
|
||||
return HTTPResponse.from_stream(self.c.server_conn.rfile, flow.request.method,
|
||||
body_size_limit=self.c.config.body_size_limit)
|
||||
|
||||
def authenticate(self, request):
|
||||
if self.c.config.authenticator:
|
||||
@ -278,7 +302,7 @@ class HTTPHandler(ProtocolHandler):
|
||||
Get a line, possibly preceded by a blank.
|
||||
"""
|
||||
line = fp.readline()
|
||||
if line == "\r\n" or line == "\n": # Possible leftover from previous message
|
||||
if line == "\r\n" or line == "\n": # Possible leftover from previous message
|
||||
line = fp.readline()
|
||||
if line == "":
|
||||
raise tcp.NetLibDisconnect
|
||||
|
@ -184,9 +184,8 @@ class ConnectionHandler:
|
||||
continue
|
||||
|
||||
self.del_server_connection()
|
||||
except (ProxyError, protocol.ProtocolError), e:
|
||||
except ProxyError, e:
|
||||
self.log(str(e))
|
||||
protocol.handle_error(self.conntype, self, e)
|
||||
# FIXME: We need to persist errors
|
||||
|
||||
self.log("disconnect")
|
||||
@ -223,8 +222,13 @@ class ConnectionHandler:
|
||||
self.channel.tell("serverconnect", self)
|
||||
|
||||
def establish_ssl(self, client, server):
|
||||
"""
|
||||
Establishes SSL on the existing connection(s) to the server or the client,
|
||||
as specified by the parameters. If the target server is on the pass-through list,
|
||||
the conntype attribute will be changed and no the SSL connection won't be wrapped.
|
||||
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
|
||||
"""
|
||||
# TODO: Implement SSL pass-through handling and change conntype
|
||||
|
||||
if self.server_conn.host == "ycombinator.com":
|
||||
self.conntype = "tcp"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user