move code from mitmproxy to netlib

This commit is contained in:
Thomas Kriechbaumer 2015-07-27 09:36:50 +02:00
parent fb48217224
commit 827fe824d9
5 changed files with 167 additions and 34 deletions

View File

@ -4,6 +4,7 @@ import collections
import string
import sys
import urlparse
import time
from netlib import odict, utils, tcp, http
from .. import status_codes
@ -17,10 +18,7 @@ class TCPHandler(object):
class HTTP1Protocol(object):
def __init__(self, tcp_handler=None, rfile=None, wfile=None):
if tcp_handler:
self.tcp_handler = tcp_handler
else:
self.tcp_handler = TCPHandler(rfile, wfile)
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
@ -39,6 +37,10 @@ class HTTP1Protocol(object):
Raises:
HttpError: If the input is invalid.
"""
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
httpversion, host, port, scheme, method, path, headers, body = (
None, None, None, None, None, None, None, None)
@ -106,6 +108,12 @@ class HTTP1Protocol(object):
True
)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
timestamp_end = time.time()
return http.Request(
form_in,
method,
@ -115,7 +123,9 @@ class HTTP1Protocol(object):
path,
httpversion,
headers,
body
body,
timestamp_start,
timestamp_end,
)
@ -124,12 +134,15 @@ class HTTP1Protocol(object):
Returns an http.Response
By default, both response header and body are read.
If include_body=False is specified, content may be one of the
If include_body=False is specified, body may be one of the
following:
- None, if the response is technically allowed to have a response body
- "", if the response must not have a response body (e.g. it's a
response to a HEAD request)
"""
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
line = self.tcp_handler.rfile.readline()
# Possible leftover from previous message
@ -149,7 +162,7 @@ class HTTP1Protocol(object):
raise HttpError(502, "Invalid headers.")
if include_body:
content = self.read_http_body(
body = self.read_http_body(
headers,
body_size_limit,
request_method,
@ -157,10 +170,29 @@ class HTTP1Protocol(object):
False
)
else:
# if include_body==False then a None content means the body should be
# if include_body==False then a None body means the body should be
# read separately
content = None
return http.Response(httpversion, code, msg, headers, content)
body = None
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
if include_body:
timestamp_end = time.time()
else:
timestamp_end = None
return http.Response(
httpversion,
code,
msg,
headers,
body,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
def read_headers(self):

View File

@ -1,11 +1,18 @@
from __future__ import (absolute_import, print_function, division)
import itertools
import time
from hpack.hpack import Encoder, Decoder
from netlib import http, utils, odict
from . import frame
class TCPHandler(object):
def __init__(self, rfile, wfile=None):
self.rfile = rfile
self.wfile = wfile
class HTTP2Protocol(object):
ERROR_CODES = utils.BiDi(
@ -31,16 +38,26 @@ class HTTP2Protocol(object):
ALPN_PROTO_H2 = 'h2'
def __init__(self, tcp_handler, is_server=False, dump_frames=False):
self.tcp_handler = tcp_handler
def __init__(
self,
tcp_handler=None,
rfile=None,
wfile=None,
is_server=False,
dump_frames=False,
encoder=None,
decoder=None,
):
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
self.is_server = is_server
self.dump_frames = dump_frames
self.encoder = encoder or Encoder()
self.decoder = decoder or Decoder()
self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
self.current_stream_id = None
self.encoder = Encoder()
self.decoder = Decoder()
self.connection_preface_performed = False
self.dump_frames = dump_frames
def check_alpn(self):
alp = self.tcp_handler.get_alpn_proto_negotiated()
@ -186,29 +203,68 @@ class HTTP2Protocol(object):
self._create_headers(headers, stream_id, end_stream=(body is None)),
self._create_body(body, stream_id)))
def read_response(self, *args):
stream_id, headers, body = self._receive_transmission()
def read_response(self, request_method_='', body_size_limit_=None, include_body=True):
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
status = headers[':status'][0]
response = http.Response("HTTP/2", status, "", headers, body)
stream_id, headers, body = self._receive_transmission(include_body)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
if include_body:
timestamp_end = time.time()
else:
timestamp_end = None
response = http.Response(
(2, 0),
headers[':status'][0],
"",
headers,
body,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
response.stream_id = stream_id
return response
def read_request(self):
stream_id, headers, body = self._receive_transmission()
def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False):
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(include_body)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
timestamp_end = time.time()
form_in = ""
method = headers.get(':method', [''])[0]
scheme = headers.get(':scheme', [''])[0]
host = headers.get(':host', [''])[0]
port = '' # TODO: parse port number?
path = headers.get(':path', [''])[0]
request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body)
request = http.Request(
"",
headers.get_first(':method', ['']),
headers.get_first(':scheme', ['']),
headers.get_first(':host', ['']),
port,
headers.get_first(':path', ['']),
(2, 0),
headers,
body,
timestamp_start,
timestamp_end,
)
request.stream_id = stream_id
return request
def _receive_transmission(self):
def _receive_transmission(self, include_body=True):
body_expected = True
stream_id = 0

View File

@ -20,7 +20,11 @@ class Request(object):
httpversion,
headers,
body,
timestamp_start=None,
timestamp_end=None,
):
assert isinstance(headers, odict.ODictCaseless) or not headers
self.form_in = form_in
self.method = method
self.scheme = scheme
@ -30,17 +34,30 @@ class Request(object):
self.httpversion = httpversion
self.headers = headers
self.body = body
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
def __eq__(self, other):
return self.__dict__ == other.__dict__
try:
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
return self_d == other_d
except:
return False
def __repr__(self):
return "Request(%s - %s, %s)" % (self.method, self.host, self.path)
@property
def content(self):
# TODO: remove deprecated getter
return self.body
@content.setter
def content(self, content):
# TODO: remove deprecated setter
self.body = content
class EmptyRequest(Request):
def __init__(self):
@ -67,24 +84,52 @@ class Response(object):
headers,
body,
sslinfo=None,
timestamp_start=None,
timestamp_end=None,
):
assert isinstance(headers, odict.ODictCaseless) or not headers
self.httpversion = httpversion
self.status_code = status_code
self.msg = msg
self.headers = headers
self.body = body
self.sslinfo = sslinfo
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
def __eq__(self, other):
return self.__dict__ == other.__dict__
try:
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
return self_d == other_d
except:
return False
def __repr__(self):
return "Response(%s - %s)" % (self.status_code, self.msg)
@property
def content(self):
# TODO: remove deprecated getter
return self.body
@content.setter
def content(self, content):
# TODO: remove deprecated setter
self.body = content
@property
def code(self):
# TODO: remove deprecated getter
return self.status_code
@code.setter
def code(self, code):
# TODO: remove deprecated setter
self.status_code = code
def is_valid_port(port):
if not 0 <= port <= 65535:

View File

@ -297,10 +297,10 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase):
def test_read_response():
def tst(data, method, limit, include_body=True):
def tst(data, method, body_size_limit, include_body=True):
data = textwrap.dedent(data)
return mock_protocol(data).read_response(
method, limit, include_body=include_body
method, body_size_limit, include_body=include_body
)
tutils.raises("server disconnect", tst, "", "GET", None)

View File

@ -253,7 +253,7 @@ class TestReadResponse(tservers.ServerTestBase):
resp = protocol.read_response()
assert resp.httpversion == "HTTP/2"
assert resp.httpversion == (2, 0)
assert resp.status_code == "200"
assert resp.msg == ""
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
@ -279,7 +279,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
resp = protocol.read_response()
assert resp.stream_id
assert resp.httpversion == "HTTP/2"
assert resp.httpversion == (2, 0)
assert resp.status_code == "200"
assert resp.msg == ""
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]