mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 14:58:38 +00:00
move code from mitmproxy to netlib
This commit is contained in:
parent
fb48217224
commit
827fe824d9
@ -4,6 +4,7 @@ import collections
|
||||
import string
|
||||
import sys
|
||||
import urlparse
|
||||
import time
|
||||
|
||||
from netlib import odict, utils, tcp, http
|
||||
from .. import status_codes
|
||||
@ -17,10 +18,7 @@ class TCPHandler(object):
|
||||
class HTTP1Protocol(object):
|
||||
|
||||
def __init__(self, tcp_handler=None, rfile=None, wfile=None):
|
||||
if tcp_handler:
|
||||
self.tcp_handler = tcp_handler
|
||||
else:
|
||||
self.tcp_handler = TCPHandler(rfile, wfile)
|
||||
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
|
||||
|
||||
|
||||
def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
|
||||
@ -39,6 +37,10 @@ class HTTP1Protocol(object):
|
||||
Raises:
|
||||
HttpError: If the input is invalid.
|
||||
"""
|
||||
timestamp_start = time.time()
|
||||
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
|
||||
self.tcp_handler.rfile.reset_timestamps()
|
||||
|
||||
httpversion, host, port, scheme, method, path, headers, body = (
|
||||
None, None, None, None, None, None, None, None)
|
||||
|
||||
@ -106,6 +108,12 @@ class HTTP1Protocol(object):
|
||||
True
|
||||
)
|
||||
|
||||
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
|
||||
# more accurate timestamp_start
|
||||
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
|
||||
|
||||
timestamp_end = time.time()
|
||||
|
||||
return http.Request(
|
||||
form_in,
|
||||
method,
|
||||
@ -115,7 +123,9 @@ class HTTP1Protocol(object):
|
||||
path,
|
||||
httpversion,
|
||||
headers,
|
||||
body
|
||||
body,
|
||||
timestamp_start,
|
||||
timestamp_end,
|
||||
)
|
||||
|
||||
|
||||
@ -124,12 +134,15 @@ class HTTP1Protocol(object):
|
||||
Returns an http.Response
|
||||
|
||||
By default, both response header and body are read.
|
||||
If include_body=False is specified, content may be one of the
|
||||
If include_body=False is specified, body may be one of the
|
||||
following:
|
||||
- None, if the response is technically allowed to have a response body
|
||||
- "", if the response must not have a response body (e.g. it's a
|
||||
response to a HEAD request)
|
||||
"""
|
||||
timestamp_start = time.time()
|
||||
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
|
||||
self.tcp_handler.rfile.reset_timestamps()
|
||||
|
||||
line = self.tcp_handler.rfile.readline()
|
||||
# Possible leftover from previous message
|
||||
@ -149,7 +162,7 @@ class HTTP1Protocol(object):
|
||||
raise HttpError(502, "Invalid headers.")
|
||||
|
||||
if include_body:
|
||||
content = self.read_http_body(
|
||||
body = self.read_http_body(
|
||||
headers,
|
||||
body_size_limit,
|
||||
request_method,
|
||||
@ -157,10 +170,29 @@ class HTTP1Protocol(object):
|
||||
False
|
||||
)
|
||||
else:
|
||||
# if include_body==False then a None content means the body should be
|
||||
# if include_body==False then a None body means the body should be
|
||||
# read separately
|
||||
content = None
|
||||
return http.Response(httpversion, code, msg, headers, content)
|
||||
body = None
|
||||
|
||||
|
||||
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
|
||||
# more accurate timestamp_start
|
||||
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
|
||||
|
||||
if include_body:
|
||||
timestamp_end = time.time()
|
||||
else:
|
||||
timestamp_end = None
|
||||
|
||||
return http.Response(
|
||||
httpversion,
|
||||
code,
|
||||
msg,
|
||||
headers,
|
||||
body,
|
||||
timestamp_start=timestamp_start,
|
||||
timestamp_end=timestamp_end,
|
||||
)
|
||||
|
||||
|
||||
def read_headers(self):
|
||||
|
@ -1,11 +1,18 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
import itertools
|
||||
import time
|
||||
|
||||
from hpack.hpack import Encoder, Decoder
|
||||
from netlib import http, utils, odict
|
||||
from . import frame
|
||||
|
||||
|
||||
class TCPHandler(object):
|
||||
def __init__(self, rfile, wfile=None):
|
||||
self.rfile = rfile
|
||||
self.wfile = wfile
|
||||
|
||||
|
||||
class HTTP2Protocol(object):
|
||||
|
||||
ERROR_CODES = utils.BiDi(
|
||||
@ -31,16 +38,26 @@ class HTTP2Protocol(object):
|
||||
|
||||
ALPN_PROTO_H2 = 'h2'
|
||||
|
||||
def __init__(self, tcp_handler, is_server=False, dump_frames=False):
|
||||
self.tcp_handler = tcp_handler
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tcp_handler=None,
|
||||
rfile=None,
|
||||
wfile=None,
|
||||
is_server=False,
|
||||
dump_frames=False,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
):
|
||||
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
|
||||
self.is_server = is_server
|
||||
self.dump_frames = dump_frames
|
||||
self.encoder = encoder or Encoder()
|
||||
self.decoder = decoder or Decoder()
|
||||
|
||||
self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
|
||||
self.current_stream_id = None
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
self.connection_preface_performed = False
|
||||
self.dump_frames = dump_frames
|
||||
|
||||
def check_alpn(self):
|
||||
alp = self.tcp_handler.get_alpn_proto_negotiated()
|
||||
@ -186,29 +203,68 @@ class HTTP2Protocol(object):
|
||||
self._create_headers(headers, stream_id, end_stream=(body is None)),
|
||||
self._create_body(body, stream_id)))
|
||||
|
||||
def read_response(self, *args):
|
||||
stream_id, headers, body = self._receive_transmission()
|
||||
def read_response(self, request_method_='', body_size_limit_=None, include_body=True):
|
||||
timestamp_start = time.time()
|
||||
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
|
||||
self.tcp_handler.rfile.reset_timestamps()
|
||||
|
||||
status = headers[':status'][0]
|
||||
response = http.Response("HTTP/2", status, "", headers, body)
|
||||
stream_id, headers, body = self._receive_transmission(include_body)
|
||||
|
||||
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
|
||||
# more accurate timestamp_start
|
||||
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
|
||||
|
||||
if include_body:
|
||||
timestamp_end = time.time()
|
||||
else:
|
||||
timestamp_end = None
|
||||
|
||||
response = http.Response(
|
||||
(2, 0),
|
||||
headers[':status'][0],
|
||||
"",
|
||||
headers,
|
||||
body,
|
||||
timestamp_start=timestamp_start,
|
||||
timestamp_end=timestamp_end,
|
||||
)
|
||||
response.stream_id = stream_id
|
||||
|
||||
return response
|
||||
|
||||
def read_request(self):
|
||||
stream_id, headers, body = self._receive_transmission()
|
||||
def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False):
|
||||
timestamp_start = time.time()
|
||||
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
|
||||
self.tcp_handler.rfile.reset_timestamps()
|
||||
|
||||
stream_id, headers, body = self._receive_transmission(include_body)
|
||||
|
||||
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
|
||||
# more accurate timestamp_start
|
||||
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
|
||||
|
||||
timestamp_end = time.time()
|
||||
|
||||
form_in = ""
|
||||
method = headers.get(':method', [''])[0]
|
||||
scheme = headers.get(':scheme', [''])[0]
|
||||
host = headers.get(':host', [''])[0]
|
||||
port = '' # TODO: parse port number?
|
||||
path = headers.get(':path', [''])[0]
|
||||
|
||||
request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body)
|
||||
request = http.Request(
|
||||
"",
|
||||
headers.get_first(':method', ['']),
|
||||
headers.get_first(':scheme', ['']),
|
||||
headers.get_first(':host', ['']),
|
||||
port,
|
||||
headers.get_first(':path', ['']),
|
||||
(2, 0),
|
||||
headers,
|
||||
body,
|
||||
timestamp_start,
|
||||
timestamp_end,
|
||||
)
|
||||
request.stream_id = stream_id
|
||||
|
||||
return request
|
||||
|
||||
def _receive_transmission(self):
|
||||
def _receive_transmission(self, include_body=True):
|
||||
body_expected = True
|
||||
|
||||
stream_id = 0
|
||||
|
@ -20,7 +20,11 @@ class Request(object):
|
||||
httpversion,
|
||||
headers,
|
||||
body,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
):
|
||||
assert isinstance(headers, odict.ODictCaseless) or not headers
|
||||
|
||||
self.form_in = form_in
|
||||
self.method = method
|
||||
self.scheme = scheme
|
||||
@ -30,17 +34,30 @@ class Request(object):
|
||||
self.httpversion = httpversion
|
||||
self.headers = headers
|
||||
self.body = body
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__dict__ == other.__dict__
|
||||
try:
|
||||
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
|
||||
other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
|
||||
return self_d == other_d
|
||||
except:
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return "Request(%s - %s, %s)" % (self.method, self.host, self.path)
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
# TODO: remove deprecated getter
|
||||
return self.body
|
||||
|
||||
@content.setter
|
||||
def content(self, content):
|
||||
# TODO: remove deprecated setter
|
||||
self.body = content
|
||||
|
||||
|
||||
class EmptyRequest(Request):
|
||||
def __init__(self):
|
||||
@ -67,24 +84,52 @@ class Response(object):
|
||||
headers,
|
||||
body,
|
||||
sslinfo=None,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
):
|
||||
assert isinstance(headers, odict.ODictCaseless) or not headers
|
||||
|
||||
self.httpversion = httpversion
|
||||
self.status_code = status_code
|
||||
self.msg = msg
|
||||
self.headers = headers
|
||||
self.body = body
|
||||
self.sslinfo = sslinfo
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__dict__ == other.__dict__
|
||||
try:
|
||||
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
|
||||
other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
|
||||
return self_d == other_d
|
||||
except:
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return "Response(%s - %s)" % (self.status_code, self.msg)
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
# TODO: remove deprecated getter
|
||||
return self.body
|
||||
|
||||
@content.setter
|
||||
def content(self, content):
|
||||
# TODO: remove deprecated setter
|
||||
self.body = content
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
# TODO: remove deprecated getter
|
||||
return self.status_code
|
||||
|
||||
@code.setter
|
||||
def code(self, code):
|
||||
# TODO: remove deprecated setter
|
||||
self.status_code = code
|
||||
|
||||
|
||||
|
||||
def is_valid_port(port):
|
||||
if not 0 <= port <= 65535:
|
||||
|
@ -297,10 +297,10 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase):
|
||||
|
||||
|
||||
def test_read_response():
|
||||
def tst(data, method, limit, include_body=True):
|
||||
def tst(data, method, body_size_limit, include_body=True):
|
||||
data = textwrap.dedent(data)
|
||||
return mock_protocol(data).read_response(
|
||||
method, limit, include_body=include_body
|
||||
method, body_size_limit, include_body=include_body
|
||||
)
|
||||
|
||||
tutils.raises("server disconnect", tst, "", "GET", None)
|
||||
|
@ -253,7 +253,7 @@ class TestReadResponse(tservers.ServerTestBase):
|
||||
|
||||
resp = protocol.read_response()
|
||||
|
||||
assert resp.httpversion == "HTTP/2"
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == "200"
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
@ -279,7 +279,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
|
||||
resp = protocol.read_response()
|
||||
|
||||
assert resp.stream_id
|
||||
assert resp.httpversion == "HTTP/2"
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == "200"
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
|
Loading…
Reference in New Issue
Block a user