Merge branch 'master' of github.com:cortesi/mitmproxy

This commit is contained in:
Aldo Cortesi 2016-06-01 09:58:01 +12:00
commit a061e45877
45 changed files with 622 additions and 614 deletions

View File

@ -6,8 +6,9 @@ import base64
import configargparse
from netlib.tcp import Address, sslversion_choices
import netlib.utils
from . import filt, utils, version
import netlib.http.url
from netlib import human
from . import filt, version
from .proxy import config
APP_HOST = "mitm.it"
@ -105,7 +106,7 @@ def parse_setheader(s):
def parse_server_spec(url):
try:
p = netlib.utils.parse_url(url)
p = netlib.http.url.parse(url)
if p[0] not in ("http", "https"):
raise ValueError()
except ValueError:
@ -135,7 +136,9 @@ def get_common_options(options):
if options.stickyauth_filt:
stickyauth = options.stickyauth_filt
stream_large_bodies = utils.parse_size(options.stream_large_bodies)
stream_large_bodies = options.stream_large_bodies
if stream_large_bodies:
stream_large_bodies = human.parse_size(stream_large_bodies)
reps = []
for i in options.replace:

View File

@ -4,7 +4,8 @@ import urwid
import urwid.util
import os
import netlib.utils
import netlib
from netlib import human
from .. import utils
from .. import flow
@ -419,7 +420,7 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False):
)
if f.response:
if f.response.content:
contentdesc = netlib.utils.pretty_size(len(f.response.content))
contentdesc = human.pretty_size(len(f.response.content))
elif f.response.content is None:
contentdesc = "[content missing]"
else:
@ -427,7 +428,7 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False):
duration = 0
if f.response.timestamp_end and f.request.timestamp_start:
duration = f.response.timestamp_end - f.request.timestamp_start
roundtrip = utils.pretty_duration(duration)
roundtrip = human.pretty_duration(duration)
d.update(dict(
resp_code = f.response.status_code,

View File

@ -1,7 +1,7 @@
from __future__ import absolute_import
import urwid
import netlib.utils
import netlib.http.url
from . import common, signals
@ -343,7 +343,7 @@ class FlowListBox(urwid.ListBox):
)
def new_request(self, url, method):
parts = netlib.utils.parse_url(str(url))
parts = netlib.http.url.parse(str(url))
if not parts:
signals.status_message.send(message="Invalid Url")
return

View File

@ -3,6 +3,7 @@ import os.path
import urwid
import netlib.utils
from netlib import human
from . import pathedit, signals, common
@ -193,7 +194,7 @@ class StatusBar(urwid.WidgetWrap):
opts.append("following")
if self.master.stream_large_bodies:
opts.append(
"stream:%s" % netlib.utils.pretty_size(
"stream:%s" % human.pretty_size(
self.master.stream_large_bodies.max_size
)
)
@ -203,7 +204,7 @@ class StatusBar(urwid.WidgetWrap):
if self.master.server.config.mode in ["reverse", "upstream"]:
dst = self.master.server.config.upstream_server
r.append("[dest:%s]" % netlib.utils.unparse_url(
r.append("[dest:%s]" % netlib.utils.unparse(
dst.scheme,
dst.address.host,
dst.address.port

View File

@ -27,7 +27,9 @@ import html2text
import six
from netlib.odict import ODict
from netlib import encoding
from netlib.utils import clean_bin, hexdump, urldecode, multipartdecode, parse_content_type
import netlib.http.headers
from netlib.http import url, multipart
from netlib.utils import clean_bin, hexdump
from . import utils
from .exceptions import ContentViewException
from .contrib import jsbeautifier
@ -120,7 +122,7 @@ class ViewAuto(View):
headers = metadata.get("headers", {})
ctype = headers.get("content-type")
if data and ctype:
ct = parse_content_type(ctype) if ctype else None
ct = netlib.http.headers.parse_content_type(ctype) if ctype else None
ct = "%s/%s" % (ct[0], ct[1])
if ct in content_types_map:
return content_types_map[ct][0](data, **metadata)
@ -257,7 +259,7 @@ class ViewURLEncoded(View):
content_types = ["application/x-www-form-urlencoded"]
def __call__(self, data, **metadata):
d = urldecode(data)
d = url.decode(data)
return "URLEncoded form", format_dict(ODict(d))
@ -274,7 +276,7 @@ class ViewMultipart(View):
def __call__(self, data, **metadata):
headers = metadata.get("headers", {})
v = multipartdecode(headers, data)
v = multipart.decode(headers, data)
if v:
return "Multipart form", self._format(v)

View File

@ -4,8 +4,8 @@ import sys
import click
import itertools
from netlib import tcp
from netlib.utils import bytes_to_escaped_str, pretty_size
from netlib import tcp, human
from netlib.utils import bytes_to_escaped_str
from . import flow, filt, contentviews, controller
from .exceptions import ContentViewException, FlowReadException, ScriptException
@ -287,7 +287,7 @@ class DumpMaster(flow.FlowMaster):
if flow.response.content is None:
size = "(content missing)"
else:
size = pretty_size(len(flow.response.content))
size = human.pretty_size(len(flow.response.content))
size = click.style(size, bold=True)
arrows = click.style("<<", bold=True)

View File

@ -5,7 +5,7 @@ from textwrap import dedent
from six.moves.urllib.parse import quote, quote_plus
import netlib.http
from netlib.utils import parse_content_type
import netlib.http.headers
def curl_command(flow):
@ -88,7 +88,7 @@ def raw_request(flow):
def is_json(headers, content):
if headers:
ct = parse_content_type(headers.get("content-type", ""))
ct = netlib.http.headers.parse_content_type(headers.get("content-type", ""))
if ct and "%s/%s" % (ct[0], ct[1]) == "application/json":
try:
return json.loads(content)

View File

@ -1,11 +1,11 @@
import time
from typing import List
from netlib.utils import Serializable
import netlib.basetypes
from .flow import Flow
class TCPMessage(Serializable):
class TCPMessage(netlib.basetypes.Serializable):
def __init__(self, from_client, content, timestamp=None):
self.content = content

View File

@ -14,7 +14,8 @@ from hyperframe.frame import PriorityFrame
from netlib.tcp import ssl_read_select
from netlib.exceptions import HttpException
from netlib.http import Headers
from netlib.utils import http2_read_raw_frame, parse_url
from netlib.http.http2 import framereader
import netlib.http.url
from .base import Layer
from .http import _HttpTransmissionLayer, HttpLayer
@ -233,7 +234,7 @@ class Http2Layer(Layer):
with source_conn.h2.lock:
try:
raw_frame = b''.join(http2_read_raw_frame(source_conn.rfile))
raw_frame = b''.join(framereader.http2_read_raw_frame(source_conn.rfile))
except:
# read frame failed: connection closed
self._kill_all_streams()
@ -306,6 +307,9 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
method = self.request_headers.get(':method', 'GET')
scheme = self.request_headers.get(':scheme', 'https')
path = self.request_headers.get(':path', '/')
self.request_headers.clear(":method")
self.request_headers.clear(":scheme")
self.request_headers.clear(":path")
host = None
port = None
@ -316,7 +320,7 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
else: # pragma: no cover
first_line_format = "absolute"
# FIXME: verify if path or :host contains what we need
scheme, host, port, _ = parse_url(path)
scheme, host, port, _ = netlib.http.url.parse(path)
if authority:
host, _, port = authority.partition(':')
@ -362,10 +366,15 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
headers = message.headers.copy()
headers.insert(0, ":path", message.path)
headers.insert(0, ":method", message.method)
headers.insert(0, ":scheme", message.scheme)
self.server_conn.h2.safe_send_headers(
self.is_zombie,
self.server_stream_id,
message.headers
headers
)
self.server_conn.h2.safe_send_body(
self.is_zombie,
@ -379,12 +388,14 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
self.response_arrived.wait()
status_code = int(self.response_headers.get(':status', 502))
headers = self.response_headers.copy()
headers.clear(":status")
return HTTPResponse(
http_version=b"HTTP/2.0",
status_code=status_code,
reason='',
headers=self.response_headers,
headers=headers,
content=None,
timestamp_start=self.timestamp_start,
timestamp_end=self.timestamp_end,
@ -404,10 +415,12 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
raise Http2ProtocolException("Zombie Stream")
def send_response_headers(self, response):
headers = response.headers.copy()
headers.insert(0, ":status", str(response.status_code))
self.client_conn.h2.safe_send_headers(
self.is_zombie,
self.client_stream_id,
response.headers
headers
)
if self.zombie: # pragma: no cover
raise Http2ProtocolException("Zombie Stream")

View File

@ -6,11 +6,11 @@ import re
import six
from OpenSSL import SSL
from netlib import certutils, tcp
from netlib import certutils, tcp, human
from netlib.http import authentication
from netlib.tcp import Address, sslversion_choices
from .. import utils, platform
from .. import platform
CONF_BASENAME = "mitmproxy"
CA_DIR = "~/.mitmproxy"
@ -125,7 +125,9 @@ class ProxyConfig:
def process_proxy_options(parser, options):
body_size_limit = utils.parse_size(options.body_size_limit)
body_size_limit = options.body_size_limit
if body_size_limit:
body_size_limit = human.parse_size(body_size_limit)
c = 0
mode, upstream_server, upstream_auth = "regular", None, None

View File

@ -3,7 +3,7 @@ from __future__ import absolute_import
import six
from typing import List, Any
from netlib.utils import Serializable
import netlib.basetypes
def _is_list(cls):
@ -13,7 +13,7 @@ def _is_list(cls):
return issubclass(cls, List) or is_list_bugfix
class StateObject(Serializable):
class StateObject(netlib.basetypes.Serializable):
"""
An object with serializable state.

View File

@ -58,20 +58,6 @@ def pretty_json(s):
return json.dumps(p, sort_keys=True, indent=4)
def pretty_duration(secs):
formatters = [
(100, "{:.0f}s"),
(10, "{:2.1f}s"),
(1, "{:1.2f}s"),
]
for limit, formatter in formatters:
if secs >= limit:
return formatter.format(secs)
# less than 1 sec
return "{:.0f}ms".format(secs * 1000)
pkg_data = netlib.utils.Data(__name__)
@ -117,32 +103,3 @@ def clean_hanging_newline(t):
if t and t[-1] == "\n":
return t[:-1]
return t
def parse_size(s):
"""
Parses a size specification. Valid specifications are:
123: bytes
123k: kilobytes
123m: megabytes
123g: gigabytes
"""
if not s:
return None
mult = None
if s[-1].lower() == "k":
mult = 1024**1
elif s[-1].lower() == "m":
mult = 1024**2
elif s[-1].lower() == "g":
mult = 1024**3
if mult:
s = s[:-1]
else:
mult = 1
try:
return int(s) * mult
except ValueError:
raise ValueError("Invalid size specification: %s" % s)

34
netlib/basetypes.py Normal file
View File

@ -0,0 +1,34 @@
import six
import abc
@six.add_metaclass(abc.ABCMeta)
class Serializable(object):
    """
    Abstract Base Class that defines an API to save an object's state and restore it later on.

    Subclasses must implement from_state/get_state/set_state; copy() is
    derived from those three primitives.
    """

    @classmethod
    @abc.abstractmethod
    def from_state(cls, state):
        """
        Create a new object from the given state.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_state(self):
        """
        Retrieve object state.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def set_state(self, state):
        """
        Set object state to the given state.
        """
        raise NotImplementedError()

    def copy(self):
        # Round-trip through the serialized state: yields an independent
        # object to the extent that get_state() snapshots all mutable state.
        return self.from_state(self.get_state())

View File

@ -12,7 +12,7 @@ from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error
import OpenSSL
from .utils import Serializable
from . import basetypes
# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
@ -364,7 +364,7 @@ class _GeneralNames(univ.SequenceOf):
constraint.ValueSizeConstraint(1, 1024)
class SSLCert(Serializable):
class SSLCert(basetypes.Serializable):
def __init__(self, cert):
"""

View File

@ -175,3 +175,30 @@ class Headers(MultiDict):
fields.append([name, value])
self.fields = fields
return replacements
def parse_content_type(c):
    """
    Parse a Content-Type header value.

    Returns:
        A (type, subtype, parameters) tuple, where type and subtype are
        lower-cased strings and parameters is a dict of stripped key/value
        strings, e.g.::

            parse_content_type("text/html; charset=UTF-8")
            -> ("text", "html", {"charset": "UTF-8"})

        Returns None when the value does not have a "type/subtype" shape.
    """
    media, has_params, raw_params = c.partition(";")
    ctype, slash, subtype = media.partition("/")
    if not slash:
        return None
    params = {}
    if has_params:
        for piece in raw_params.split(";"):
            key, eq, value = piece.partition("=")
            # Parameters without "=" are silently ignored.
            if eq:
                params[key.strip()] = value.strip()
    return ctype.lower(), subtype.lower(), params

View File

@ -6,6 +6,19 @@ import re
from ... import utils
from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect
from .. import Request, Response, Headers
from .. import url
def get_header_tokens(headers, key):
    """
    Retrieve all tokens for a header key.

    A number of headers (e.g. Connection) follow a pattern where each header
    line can contain comma-separated tokens, and the header may be set
    multiple times. Returns the stripped tokens, or [] if the key is absent.
    """
    raw = headers.get(key)
    if raw is None:
        return []
    return [token.strip() for token in raw.split(",")]
def read_request(rfile, body_size_limit=None):
@ -147,7 +160,7 @@ def connection_close(http_version, headers):
"""
# At first, check if we have an explicit Connection header.
if "connection" in headers:
tokens = utils.get_header_tokens(headers, "connection")
tokens = get_header_tokens(headers, "connection")
if "close" in tokens:
return True
elif "keep-alive" in tokens:
@ -240,7 +253,7 @@ def _read_request_line(rfile):
scheme, path = None, None
else:
form = "absolute"
scheme, host, port, path = utils.parse_url(path)
scheme, host, port, path = url.parse(path)
_check_http_version(http_version)
except ValueError:

View File

@ -2,11 +2,12 @@ from __future__ import (absolute_import, print_function, division)
import itertools
import time
import hyperframe.frame
from hpack.hpack import Encoder, Decoder
from ... import utils
from .. import Headers, Response, Request
from hyperframe import frame
from .. import Headers, Response, Request, url
from . import framereader
class TCPHandler(object):
@ -38,12 +39,12 @@ class HTTP2Protocol(object):
CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'
HTTP2_DEFAULT_SETTINGS = {
frame.SettingsFrame.HEADER_TABLE_SIZE: 4096,
frame.SettingsFrame.ENABLE_PUSH: 1,
frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None,
frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14,
frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None,
hyperframe.frame.SettingsFrame.HEADER_TABLE_SIZE: 4096,
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 1,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None,
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14,
hyperframe.frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None,
}
def __init__(
@ -98,6 +99,11 @@ class HTTP2Protocol(object):
method = headers.get(':method', 'GET')
scheme = headers.get(':scheme', 'https')
path = headers.get(':path', '/')
headers.clear(":method")
headers.clear(":scheme")
headers.clear(":path")
host = None
port = None
@ -112,7 +118,7 @@ class HTTP2Protocol(object):
else:
first_line_format = "absolute"
# FIXME: verify if path or :host contains what we need
scheme, host, port, _ = utils.parse_url(path)
scheme, host, port, _ = url.parse(path)
scheme = scheme.decode('ascii')
host = host.decode('ascii')
@ -202,12 +208,9 @@ class HTTP2Protocol(object):
if ':authority' not in headers:
headers.insert(0, b':authority', authority.encode('ascii'))
if ':scheme' not in headers:
headers.insert(0, b':scheme', request.scheme.encode('ascii'))
if ':path' not in headers:
headers.insert(0, b':path', request.path.encode('ascii'))
if ':method' not in headers:
headers.insert(0, b':method', request.method.encode('ascii'))
headers.insert(0, b':scheme', request.scheme.encode('ascii'))
headers.insert(0, b':path', request.path.encode('ascii'))
headers.insert(0, b':method', request.method.encode('ascii'))
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
@ -251,9 +254,9 @@ class HTTP2Protocol(object):
magic = self.tcp_handler.rfile.safe_read(magic_length)
assert magic == self.CLIENT_CONNECTION_PREFACE
frm = frame.SettingsFrame(settings={
frame.SettingsFrame.ENABLE_PUSH: 0,
frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1,
frm = hyperframe.frame.SettingsFrame(settings={
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 0,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1,
})
self.send_frame(frm, hide=True)
self._receive_settings(hide=True)
@ -264,7 +267,7 @@ class HTTP2Protocol(object):
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
self.send_frame(frame.SettingsFrame(), hide=True)
self.send_frame(hyperframe.frame.SettingsFrame(), hide=True)
self._receive_settings(hide=True) # server announces own settings
self._receive_settings(hide=True) # server acks my settings
@ -277,18 +280,18 @@ class HTTP2Protocol(object):
def read_frame(self, hide=False):
while True:
frm = utils.http2_read_frame(self.tcp_handler.rfile)
frm = framereader.http2_read_frame(self.tcp_handler.rfile)
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
if isinstance(frm, frame.PingFrame):
raw_bytes = frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
if isinstance(frm, hyperframe.frame.PingFrame):
raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
continue
if isinstance(frm, frame.SettingsFrame) and 'ACK' not in frm.flags:
if isinstance(frm, hyperframe.frame.SettingsFrame) and 'ACK' not in frm.flags:
self._apply_settings(frm.settings, hide)
if isinstance(frm, frame.DataFrame) and frm.flow_controlled_length > 0:
if isinstance(frm, hyperframe.frame.DataFrame) and frm.flow_controlled_length > 0:
self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length)
return frm
@ -300,7 +303,7 @@ class HTTP2Protocol(object):
return True
def _handle_unexpected_frame(self, frm):
if isinstance(frm, frame.SettingsFrame):
if isinstance(frm, hyperframe.frame.SettingsFrame):
return
if self.unhandled_frame_cb:
self.unhandled_frame_cb(frm)
@ -308,7 +311,7 @@ class HTTP2Protocol(object):
def _receive_settings(self, hide=False):
while True:
frm = self.read_frame(hide)
if isinstance(frm, frame.SettingsFrame):
if isinstance(frm, hyperframe.frame.SettingsFrame):
break
else:
self._handle_unexpected_frame(frm)
@ -332,26 +335,26 @@ class HTTP2Protocol(object):
old_value = '-'
self.http2_settings[setting] = value
frm = frame.SettingsFrame(flags=['ACK'])
frm = hyperframe.frame.SettingsFrame(flags=['ACK'])
self.send_frame(frm, hide)
def _update_flow_control_window(self, stream_id, increment):
frm = frame.WindowUpdateFrame(stream_id=0, window_increment=increment)
frm = hyperframe.frame.WindowUpdateFrame(stream_id=0, window_increment=increment)
self.send_frame(frm)
frm = frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment)
frm = hyperframe.frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment)
self.send_frame(frm)
def _create_headers(self, headers, stream_id, end_stream=True):
def frame_cls(chunks):
for i in chunks:
if i == 0:
yield frame.HeadersFrame, i
yield hyperframe.frame.HeadersFrame, i
else:
yield frame.ContinuationFrame, i
yield hyperframe.frame.ContinuationFrame, i
header_block_fragment = self.encoder.encode(headers.fields)
chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE]
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(header_block_fragment), chunk_size)
frms = [frm_cls(
flags=[],
@ -372,9 +375,9 @@ class HTTP2Protocol(object):
if body is None or len(body) == 0:
return b''
chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE]
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(body), chunk_size)
frms = [frame.DataFrame(
frms = [hyperframe.frame.DataFrame(
flags=[],
stream_id=stream_id,
data=body[i:i + chunk_size]) for i in chunks]
@ -398,7 +401,7 @@ class HTTP2Protocol(object):
while True:
frm = self.read_frame()
if (
(isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and
(isinstance(frm, hyperframe.frame.HeadersFrame) or isinstance(frm, hyperframe.frame.ContinuationFrame)) and
(stream_id is None or frm.stream_id == stream_id)
):
stream_id = frm.stream_id
@ -412,7 +415,7 @@ class HTTP2Protocol(object):
while body_expected:
frm = self.read_frame()
if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id:
if isinstance(frm, hyperframe.frame.DataFrame) and frm.stream_id == stream_id:
body += frm.data
if 'END_STREAM' in frm.flags:
break

View File

@ -0,0 +1,21 @@
import codecs
import hyperframe
def http2_read_raw_frame(rfile):
    """
    Read one raw HTTP/2 frame from rfile.

    Returns:
        A [header, payload] list of bytes (9-byte frame header, then body).

    Raises:
        ValueError if the length prefix looks like an HTTP/1.1 request
        instead of an HTTP/2 frame.
    """
    header = rfile.safe_read(9)
    payload_length = int(codecs.encode(header[:3], 'hex_codec'), 16)

    # 4740180 == 0x485454 == b"HTT": the start of "HTTP/1.1 ...", i.e. the
    # peer is not speaking HTTP/2 at all.
    if payload_length == 4740180:
        raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20))

    payload = rfile.safe_read(payload_length)
    return [header, payload]
def http2_read_frame(rfile):
    """
    Read one HTTP/2 frame from rfile and parse it into a hyperframe Frame.
    """
    header, payload = http2_read_raw_frame(rfile)
    parsed, _length = hyperframe.frame.Frame.parse_frame_header(header)
    # parse_body accepts any buffer; memoryview avoids copying the payload.
    parsed.parse_body(memoryview(payload))
    return parsed

View File

@ -4,9 +4,8 @@ import warnings
import six
from .headers import Headers
from .. import encoding, utils
from ..utils import always_bytes
from .. import encoding, utils, basetypes
from . import headers
if six.PY2: # pragma: no cover
def _native(x):
@ -20,10 +19,10 @@ else:
return x.decode("utf-8", "surrogateescape")
def _always_bytes(x):
return always_bytes(x, "utf-8", "surrogateescape")
return utils.always_bytes(x, "utf-8", "surrogateescape")
class MessageData(utils.Serializable):
class MessageData(basetypes.Serializable):
def __eq__(self, other):
if isinstance(other, MessageData):
return self.__dict__ == other.__dict__
@ -38,7 +37,7 @@ class MessageData(utils.Serializable):
def set_state(self, state):
for k, v in state.items():
if k == "headers":
v = Headers.from_state(v)
v = headers.Headers.from_state(v)
setattr(self, k, v)
def get_state(self):
@ -48,11 +47,11 @@ class MessageData(utils.Serializable):
@classmethod
def from_state(cls, state):
state["headers"] = Headers.from_state(state["headers"])
state["headers"] = headers.Headers.from_state(state["headers"])
return cls(**state)
class Message(utils.Serializable):
class Message(basetypes.Serializable):
def __eq__(self, other):
if isinstance(other, Message):
return self.data == other.data
@ -72,7 +71,7 @@ class Message(utils.Serializable):
@classmethod
def from_state(cls, state):
state["headers"] = Headers.from_state(state["headers"])
state["headers"] = headers.Headers.from_state(state["headers"])
return cls(**state)
@property

32
netlib/http/multipart.py Normal file
View File

@ -0,0 +1,32 @@
import re
from . import headers
def decode(hdrs, content):
    """
    Takes a multipart boundary encoded string and returns list of (key, value) tuples.

    Returns [] when the content-type header is missing, unparseable, or
    carries no usable ASCII "boundary" parameter.
    """
    v = hdrs.get("content-type")
    if v:
        v = headers.parse_content_type(v)
        if not v:
            return []
        try:
            # parse_content_type returns (type, subtype, params); the
            # boundary lives in the params dict.
            boundary = v[2]["boundary"].encode("ascii")
        except (KeyError, UnicodeError):
            return []

        rx = re.compile(br'\bname="([^"]+)"')
        r = []

        for i in content.split(b"--" + boundary):
            parts = i.splitlines()
            # A real part has more than one line and is not the trailing
            # "--" close delimiter.
            if len(parts) > 1 and parts[0][0:2] != b"--":
                match = rx.search(parts[1])
                if match:
                    key = match.group(1)
                    # Body starts one line past the first blank line that
                    # separates the part headers from the part body.
                    # NOTE(review): splitlines() drops line endings, so a
                    # multi-line body is joined without separators — confirm
                    # that is intended for callers.
                    value = b"".join(parts[3 + parts[2:].index(b""):])
                    r.append((key, value))
        return r
    return []

View File

@ -6,7 +6,9 @@ import six
from six.moves import urllib
from netlib import utils
from netlib.http import cookies
import netlib.http.url
from netlib.http import multipart
from . import cookies
from .. import encoding
from ..multidict import MultiDictView
from .headers import Headers
@ -179,11 +181,11 @@ class Request(Message):
"""
if self.first_line_format == "authority":
return "%s:%d" % (self.host, self.port)
return utils.unparse_url(self.scheme, self.host, self.port, self.path)
return netlib.http.url.unparse(self.scheme, self.host, self.port, self.path)
@url.setter
def url(self, url):
self.scheme, self.host, self.port, self.path = utils.parse_url(url)
self.scheme, self.host, self.port, self.path = netlib.http.url.parse(url)
def _parse_host_header(self):
"""Extract the host and port from Host header"""
@ -219,7 +221,7 @@ class Request(Message):
"""
if self.first_line_format == "authority":
return "%s:%d" % (self.pretty_host, self.port)
return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path)
return netlib.http.url.unparse(self.scheme, self.pretty_host, self.port, self.path)
@property
def query(self):
@ -234,12 +236,12 @@ class Request(Message):
def _get_query(self):
_, _, _, _, query, _ = urllib.parse.urlparse(self.url)
return tuple(utils.urldecode(query))
return tuple(netlib.http.url.decode(query))
def _set_query(self, value):
query = utils.urlencode(value)
query = netlib.http.url.encode(value)
scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
_, _, _, self.path = utils.parse_url(
_, _, _, self.path = netlib.http.url.parse(
urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
@query.setter
@ -287,7 +289,7 @@ class Request(Message):
components = map(lambda x: urllib.parse.quote(x, safe=""), components)
path = "/" + "/".join(components)
scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url)
_, _, _, self.path = utils.parse_url(
_, _, _, self.path = netlib.http.url.parse(
urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
def anticache(self):
@ -339,7 +341,7 @@ class Request(Message):
def _get_urlencoded_form(self):
is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower()
if is_valid_content_type:
return tuple(utils.urldecode(self.content))
return tuple(netlib.http.url.decode(self.content))
return ()
def _set_urlencoded_form(self, value):
@ -348,7 +350,7 @@ class Request(Message):
This will overwrite the existing content if there is one.
"""
self.headers["content-type"] = "application/x-www-form-urlencoded"
self.content = utils.urlencode(value)
self.content = netlib.http.url.encode(value)
@urlencoded_form.setter
def urlencoded_form(self, value):
@ -368,7 +370,7 @@ class Request(Message):
def _get_multipart_form(self):
is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower()
if is_valid_content_type:
return utils.multipartdecode(self.headers, self.content)
return multipart.decode(self.headers, self.content)
return ()
def _set_multipart_form(self, value):

View File

@ -7,7 +7,7 @@ from . import cookies
from .headers import Headers
from .message import Message, _native, _always_bytes, MessageData
from ..multidict import MultiDictView
from .. import utils
from .. import human
class ResponseData(MessageData):
@ -36,7 +36,7 @@ class Response(Message):
if self.content:
details = "{}, {}".format(
self.headers.get("content-type", "unknown content type"),
utils.pretty_size(len(self.content))
human.pretty_size(len(self.content))
)
else:
details = "no content"

96
netlib/http/url.py Normal file
View File

@ -0,0 +1,96 @@
import six
from six.moves import urllib
from .. import utils
# PY2 workaround
def decode_parse_result(result, enc):
    """
    PY2 compatibility shim: decode a urlparse result into its text form
    using the given encoding.
    """
    if hasattr(result, "decode"):
        return result.decode(enc)
    # Py2's ParseResult has no .decode; decode each component individually.
    return urllib.parse.ParseResult(*(component.decode(enc) for component in result))
# PY2 workaround
def encode_parse_result(result, enc):
    """
    PY2 compatibility shim: encode a urlparse result into its bytes form
    using the given encoding.
    """
    if hasattr(result, "encode"):
        return result.encode(enc)
    # Py2's ParseResult has no .encode; encode each component individually.
    return urllib.parse.ParseResult(*(component.encode(enc) for component in result))
def parse(url):
    """
    URL-parsing function that checks that
        - port is an integer 0-65535
        - host is a valid IDNA-encoded hostname with no null-bytes
        - path is valid ASCII

    Args:
        A URL (as bytes or as unicode)

    Returns:
        A (scheme, host, port, path) tuple

    Raises:
        ValueError, if the URL is not properly formatted.
    """
    parsed = urllib.parse.urlparse(url)

    if not parsed.hostname:
        raise ValueError("No hostname given")

    if isinstance(url, six.binary_type):
        host = parsed.hostname
        # this should not raise a ValueError,
        # but we try to be very forgiving here and accept just everything.
        # decode_parse_result(parsed, "ascii")
    else:
        # Unicode input: IDNA-encode the host and convert the whole parse
        # result to bytes so both branches continue with bytes components.
        host = parsed.hostname.encode("idna")
        parsed = encode_parse_result(parsed, "ascii")

    port = parsed.port
    if not port:
        # `parsed` is bytes-based in both branches at this point, so the
        # scheme comparison must be against b"https".
        port = 443 if parsed.scheme == b"https" else 80

    full_path = urllib.parse.urlunparse(
        (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment)
    )
    if not full_path.startswith(b"/"):
        full_path = b"/" + full_path

    if not utils.is_valid_host(host):
        raise ValueError("Invalid Host")
    if not utils.is_valid_port(port):
        raise ValueError("Invalid Port")

    return parsed.scheme, host, port, full_path
def unparse(scheme, host, port, path=""):
    """
    Returns a URL string, constructed from the specified components.

    Args:
        All args must be str.
    """
    # "*" is the asterisk-form request target (e.g. "OPTIONS *") and has no
    # path component in URL form.
    if path == "*":
        path = ""
    return "{}://{}{}".format(scheme, utils.hostport(scheme, host, port), path)
def encode(s):
    """
    Takes a list of (key, value) tuples and returns a urlencoded string.
    """
    # urlencode requires real tuples; accept any two-element iterables.
    pairs = [tuple(field) for field in s]
    return urllib.parse.urlencode(pairs, False)
def decode(s):
    """
    Takes a urlencoded string and returns a list of (key, value) tuples.
    """
    # keep_blank_values: "a=" must survive as ("a", "") rather than vanish.
    pairs = urllib.parse.parse_qsl(s, keep_blank_values=True)
    return pairs

50
netlib/human.py Normal file
View File

@ -0,0 +1,50 @@
SIZE_TABLE = [
    ("b", 1024 ** 0),
    ("k", 1024 ** 1),
    ("m", 1024 ** 2),
    ("g", 1024 ** 3),
    ("t", 1024 ** 4),
]

SIZE_UNITS = dict(SIZE_TABLE)


def pretty_size(size):
    """
    Convert a byte count to a short human-readable string, e.g.
    1024 -> "1k", 1536 -> "1.5k".
    """
    # Walk the table from the largest unit down and use the first unit that
    # fits. Fix: the previous pairwise scan never used the last ("t") entry,
    # so sizes >= 1 TB fell through and were mislabelled in bytes.
    for suffix, lim in reversed(SIZE_TABLE):
        if size >= lim:
            x = round(size / lim, 2)
            if x == int(x):  # drop a trailing ".0"
                x = int(x)
            return str(x) + suffix
    # size < 1 (zero or negative): report in bytes.
    return "%s%s" % (size, SIZE_TABLE[0][0])


def parse_size(s):
    """
    Parse a size specification such as "100", "100k" or "2m" into a byte
    count.

    Raises:
        ValueError, if the specification is invalid.
    """
    try:
        return int(s)
    except ValueError:
        pass
    for unit, multiplier in SIZE_TABLE:
        if s.endswith(unit):
            try:
                return int(s[:-1]) * multiplier
            except ValueError:
                break
    raise ValueError("Invalid size specification.")


def pretty_duration(secs):
    """
    Format a duration in seconds as a short string, e.g. "150s", "15.0s",
    "2.50s", "500ms".
    """
    formatters = [
        (100, "{:.0f}s"),
        (10, "{:2.1f}s"),
        (1, "{:1.2f}s"),
    ]

    for limit, formatter in formatters:
        if secs >= limit:
            return formatter.format(secs)
    # less than 1 sec
    return "{:.0f}ms".format(secs * 1000)

View File

@ -9,12 +9,11 @@ except ImportError: # pragma: no cover
from collections import MutableMapping # Workaround for Python < 3.3
import six
from .utils import Serializable
from . import basetypes
@six.add_metaclass(ABCMeta)
class _MultiDict(MutableMapping, Serializable):
class _MultiDict(MutableMapping, basetypes.Serializable):
def __repr__(self):
fields = (
repr(field)
@ -171,6 +170,14 @@ class _MultiDict(MutableMapping, Serializable):
else:
return super(_MultiDict, self).items()
def clear(self, key):
"""
Removes all items with the specified key, and does not raise an
exception if the key does not exist.
"""
if key in self:
del self[key]
def to_dict(self):
"""
Get the MultiDict as a plain Python dict.

View File

@ -3,10 +3,10 @@ import copy
import six
from .utils import Serializable, safe_subn
from . import basetypes, utils
class ODict(Serializable):
class ODict(basetypes.Serializable):
"""
A dictionary-like object for managing ordered (key, value) data. Think
@ -139,9 +139,9 @@ class ODict(Serializable):
"""
new, count = [], 0
for k, v in self.lst:
k, c = safe_subn(pattern, repl, k, *args, **kwargs)
k, c = utils.safe_subn(pattern, repl, k, *args, **kwargs)
count += c
v, c = safe_subn(pattern, repl, v, *args, **kwargs)
v, c = utils.safe_subn(pattern, repl, v, *args, **kwargs)
count += c
new.append([k, v])
self.lst = new

View File

@ -16,7 +16,7 @@ import six
import OpenSSL
from OpenSSL import SSL
from . import certutils, version_check, utils
from . import certutils, version_check, basetypes
# This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed.
@ -302,7 +302,7 @@ class Reader(_FileLike):
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
class Address(utils.Serializable):
class Address(basetypes.Serializable):
"""
This class wraps an IPv4/IPv6 tuple to provide named attributes and

View File

@ -3,47 +3,11 @@ import os.path
import re
import codecs
import unicodedata
from abc import ABCMeta, abstractmethod
import importlib
import inspect
import six
from six.moves import urllib
import hyperframe
@six.add_metaclass(ABCMeta)
class Serializable(object):
"""
Abstract Base Class that defines an API to save an object's state and restore it later on.
"""
@classmethod
@abstractmethod
def from_state(cls, state):
"""
Create a new object from the given state.
"""
raise NotImplementedError()
@abstractmethod
def get_state(self):
"""
Retrieve object state.
"""
raise NotImplementedError()
@abstractmethod
def set_state(self, state):
"""
Set object state to the given state.
"""
raise NotImplementedError()
def copy(self):
return self.from_state(self.get_state())
def always_bytes(unicode_or_bytes, *encode_args):
if isinstance(unicode_or_bytes, six.text_type):
@ -69,14 +33,6 @@ def native(s, *encoding_opts):
return s
def isascii(bytes):
try:
bytes.decode("ascii")
except ValueError:
return False
return True
def clean_bin(s, keep_spacing=True):
"""
Cleans binary data to make it safe to display.
@ -161,22 +117,6 @@ class BiDi(object):
return self.values.get(n, default)
def pretty_size(size):
suffixes = [
("B", 2 ** 10),
("kB", 2 ** 20),
("MB", 2 ** 30),
]
for suf, lim in suffixes:
if size >= lim:
continue
else:
x = round(size / float(lim / 2 ** 10), 2)
if x == int(x):
x = int(x)
return str(x) + suf
class Data(object):
def __init__(self, name):
@ -222,83 +162,6 @@ def is_valid_port(port):
return 0 <= port <= 65535
# PY2 workaround
def decode_parse_result(result, enc):
if hasattr(result, "decode"):
return result.decode(enc)
else:
return urllib.parse.ParseResult(*[x.decode(enc) for x in result])
# PY2 workaround
def encode_parse_result(result, enc):
if hasattr(result, "encode"):
return result.encode(enc)
else:
return urllib.parse.ParseResult(*[x.encode(enc) for x in result])
def parse_url(url):
"""
URL-parsing function that checks that
- port is an integer 0-65535
- host is a valid IDNA-encoded hostname with no null-bytes
- path is valid ASCII
Args:
A URL (as bytes or as unicode)
Returns:
A (scheme, host, port, path) tuple
Raises:
ValueError, if the URL is not properly formatted.
"""
parsed = urllib.parse.urlparse(url)
if not parsed.hostname:
raise ValueError("No hostname given")
if isinstance(url, six.binary_type):
host = parsed.hostname
# this should not raise a ValueError,
# but we try to be very forgiving here and accept just everything.
# decode_parse_result(parsed, "ascii")
else:
host = parsed.hostname.encode("idna")
parsed = encode_parse_result(parsed, "ascii")
port = parsed.port
if not port:
port = 443 if parsed.scheme == b"https" else 80
full_path = urllib.parse.urlunparse(
(b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment)
)
if not full_path.startswith(b"/"):
full_path = b"/" + full_path
if not is_valid_host(host):
raise ValueError("Invalid Host")
if not is_valid_port(port):
raise ValueError("Invalid Port")
return parsed.scheme, host, port, full_path
def get_header_tokens(headers, key):
"""
Retrieve all tokens for a header key. A number of different headers
follow a pattern where each header line can containe comma-separated
tokens, and headers can be set multiple times.
"""
if key not in headers:
return []
tokens = headers[key].split(",")
return [token.strip() for token in tokens]
def hostport(scheme, host, port):
"""
Returns the host component, with a port specifcation if needed.
@ -312,107 +175,6 @@ def hostport(scheme, host, port):
return "%s:%d" % (host, port)
def unparse_url(scheme, host, port, path=""):
"""
Returns a URL string, constructed from the specified components.
Args:
All args must be str.
"""
if path == "*":
path = ""
return "%s://%s%s" % (scheme, hostport(scheme, host, port), path)
def urlencode(s):
"""
Takes a list of (key, value) tuples and returns a urlencoded string.
"""
s = [tuple(i) for i in s]
return urllib.parse.urlencode(s, False)
def urldecode(s):
"""
Takes a urlencoded string and returns a list of (key, value) tuples.
"""
return urllib.parse.parse_qsl(s, keep_blank_values=True)
def parse_content_type(c):
"""
A simple parser for content-type values. Returns a (type, subtype,
parameters) tuple, where type and subtype are strings, and parameters
is a dict. If the string could not be parsed, return None.
E.g. the following string:
text/html; charset=UTF-8
Returns:
("text", "html", {"charset": "UTF-8"})
"""
parts = c.split(";", 1)
ts = parts[0].split("/", 1)
if len(ts) != 2:
return None
d = {}
if len(parts) == 2:
for i in parts[1].split(";"):
clause = i.split("=", 1)
if len(clause) == 2:
d[clause[0].strip()] = clause[1].strip()
return ts[0].lower(), ts[1].lower(), d
def multipartdecode(headers, content):
"""
Takes a multipart boundary encoded string and returns list of (key, value) tuples.
"""
v = headers.get("content-type")
if v:
v = parse_content_type(v)
if not v:
return []
try:
boundary = v[2]["boundary"].encode("ascii")
except (KeyError, UnicodeError):
return []
rx = re.compile(br'\bname="([^"]+)"')
r = []
for i in content.split(b"--" + boundary):
parts = i.splitlines()
if len(parts) > 1 and parts[0][0:2] != b"--":
match = rx.search(parts[1])
if match:
key = match.group(1)
value = b"".join(parts[3 + parts[2:].index(b""):])
r.append((key, value))
return r
return []
def http2_read_raw_frame(rfile):
header = rfile.safe_read(9)
length = int(codecs.encode(header[:3], 'hex_codec'), 16)
if length == 4740180:
raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20))
body = rfile.safe_read(length)
return [header, body]
def http2_read_frame(rfile):
header, body = http2_read_raw_frame(rfile)
frame, length = hyperframe.frame.Frame.parse_frame_header(header)
frame.parse_body(memoryview(body))
return frame
def safe_subn(pattern, repl, target, *args, **kwargs):
"""
There are Unicode conversion problems with re.subn. We try to smooth

View File

@ -9,6 +9,7 @@ import six
from .protocol import Masker
from netlib import tcp
from netlib import utils
from netlib import human
MAX_16_BIT_INT = (1 << 16)
@ -98,7 +99,7 @@ class FrameHeader(object):
if self.masking_key:
vals.append(":key=%s" % repr(self.masking_key))
if self.payload_length:
vals.append(" %s" % utils.pretty_size(self.payload_length))
vals.append(" %s" % human.pretty_size(self.payload_length))
return "".join(vals)
def human_readable(self):

View File

@ -5,8 +5,8 @@ import pyparsing as pp
from six.moves import reduce
from netlib.utils import escaped_str_to_bytes, bytes_to_escaped_str
from netlib import human
from .. import utils
from . import generators, exceptions
@ -158,7 +158,7 @@ class TokValueGenerate(Token):
self.usize, self.unit, self.datatype = usize, unit, datatype
def bytes(self):
return self.usize * utils.SIZE_UNITS[self.unit]
return self.usize * human.SIZE_UNITS[self.unit]
def get_generator(self, settings_):
return generators.RandomGenerator(self.datatype, self.bytes())
@ -173,7 +173,7 @@ class TokValueGenerate(Token):
u = reduce(
operator.or_,
[pp.Literal(i) for i in utils.SIZE_UNITS.keys()]
[pp.Literal(i) for i in human.SIZE_UNITS.keys()]
).leaveWhitespace()
e = e + pp.Optional(u, default=None)

View File

@ -4,7 +4,7 @@ import os
import os.path
import re
from netlib import tcp
from netlib import tcp, human
from . import pathod, version, utils
@ -205,7 +205,7 @@ def args_pathod(argv, stdout_=sys.stdout, stderr_=sys.stderr):
sizelimit = None
if args.sizelimit:
try:
sizelimit = utils.parse_size(args.sizelimit)
sizelimit = human.parse_size(args.sizelimit)
except ValueError as v:
return parser.error(v)
args.sizelimit = sizelimit

View File

@ -5,15 +5,6 @@ import netlib.utils
from netlib.utils import bytes_to_escaped_str
SIZE_UNITS = dict(
b=1024 ** 0,
k=1024 ** 1,
m=1024 ** 2,
g=1024 ** 3,
t=1024 ** 4,
)
class MemBool(object):
"""
@ -28,20 +19,6 @@ class MemBool(object):
return bool(v)
def parse_size(s):
try:
return int(s)
except ValueError:
pass
for i in SIZE_UNITS.keys():
if s.endswith(i):
try:
return int(s[:-1]) * SIZE_UNITS[i]
except ValueError:
break
raise ValueError("Invalid size specification.")
def parse_anchor_spec(s):
"""
Return a tuple, or None on error.

View File

@ -1,8 +1,8 @@
from mitmproxy.exceptions import ContentViewException
from netlib.http import Headers
from netlib.odict import ODict
import netlib.utils
from netlib import encoding
from netlib.http import url
import mitmproxy.contentviews as cv
from . import tutils
@ -60,10 +60,10 @@ class TestContentView:
assert f[0] == "Query"
def test_view_urlencoded(self):
d = netlib.utils.urlencode([("one", "two"), ("three", "four")])
d = url.encode([("one", "two"), ("three", "four")])
v = cv.ViewURLEncoded()
assert v(d)
d = netlib.utils.urlencode([("adsfa", "")])
d = url.encode([("adsfa", "")])
v = cv.ViewURLEncoded()
assert v(d)

View File

@ -13,7 +13,7 @@ from mitmproxy.cmdline import APP_HOST, APP_PORT
import netlib
from ..netlib import tservers as netlib_tservers
from netlib.utils import http2_read_raw_frame
from netlib.http.http2 import framereader
from . import tservers
@ -48,7 +48,7 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
done = False
while not done:
try:
raw = b''.join(http2_read_raw_frame(self.rfile))
raw = b''.join(framereader.http2_read_raw_frame(self.rfile))
events = h2_conn.receive_data(raw)
except:
break
@ -200,7 +200,7 @@ class TestSimple(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
events = h2_conn.receive_data(b''.join(framereader.http2_read_raw_frame(client.rfile)))
except:
break
client.wfile.write(h2_conn.data_to_send())
@ -270,7 +270,7 @@ class TestWithBodies(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
events = h2_conn.receive_data(b''.join(framereader.http2_read_raw_frame(client.rfile)))
except:
break
client.wfile.write(h2_conn.data_to_send())
@ -362,7 +362,7 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
responses = 0
while not done:
try:
raw = b''.join(http2_read_raw_frame(client.rfile))
raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except:
break
@ -412,7 +412,7 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
responses = 0
while not done:
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
events = h2_conn.receive_data(b''.join(framereader.http2_read_raw_frame(client.rfile)))
except:
break
client.wfile.write(h2_conn.data_to_send())
@ -479,7 +479,7 @@ class TestConnectionLost(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
try:
raw = b''.join(http2_read_raw_frame(client.rfile))
raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
h2_conn.receive_data(raw)
except:
break

View File

@ -43,21 +43,6 @@ def test_pretty_json():
assert not utils.pretty_json("moo")
def test_pretty_duration():
assert utils.pretty_duration(0.00001) == "0ms"
assert utils.pretty_duration(0.0001) == "0ms"
assert utils.pretty_duration(0.001) == "1ms"
assert utils.pretty_duration(0.01) == "10ms"
assert utils.pretty_duration(0.1) == "100ms"
assert utils.pretty_duration(1) == "1.00s"
assert utils.pretty_duration(10) == "10.0s"
assert utils.pretty_duration(100) == "100s"
assert utils.pretty_duration(1000) == "1000s"
assert utils.pretty_duration(10000) == "10000s"
assert utils.pretty_duration(1.123) == "1.12s"
assert utils.pretty_duration(0.123) == "123ms"
def test_LRUCache():
cache = utils.LRUCache(2)
@ -89,13 +74,3 @@ def test_LRUCache():
assert len(cache.cacheList) == 2
assert len(cache.cache) == 2
def test_parse_size():
assert not utils.parse_size("")
assert utils.parse_size("1") == 1
assert utils.parse_size("1k") == 1024
assert utils.parse_size("1m") == 1024**2
assert utils.parse_size("1g") == 1024**3
tutils.raises(ValueError, utils.parse_size, "1f")
tutils.raises(ValueError, utils.parse_size, "ak")

View File

@ -7,11 +7,22 @@ from netlib.http.http1.read import (
read_request, read_response, read_request_head,
read_response_head, read_body, connection_close, expected_http_body_size, _get_first_line,
_read_request_line, _parse_authority_form, _read_response_line, _check_http_version,
_read_headers, _read_chunked
_read_headers, _read_chunked, get_header_tokens
)
from netlib.tutils import treq, tresp, raises
def test_get_header_tokens():
headers = Headers()
assert get_header_tokens(headers, "foo") == []
headers["foo"] = "bar"
assert get_header_tokens(headers, "foo") == ["bar"]
headers["foo"] = "bar, voing"
assert get_header_tokens(headers, "foo") == ["bar", "voing"]
headers.set_all("foo", ["bar, voing", "oink"])
assert get_header_tokens(headers, "foo") == ["bar", "voing", "oink"]
def test_read_request():
rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip")
r = read_request(rfile)

View File

@ -1,12 +1,12 @@
import mock
import codecs
from hyperframe import frame
from netlib import tcp, http, utils
import hyperframe
from netlib import tcp, http
from netlib.tutils import raises
from netlib.exceptions import TcpDisconnect
from netlib.http.http2.connections import HTTP2Protocol, TCPHandler
from netlib.http.http2 import framereader
from ... import tservers
@ -111,11 +111,11 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
self.wfile.flush()
# check empty settings frame
raw = utils.http2_read_raw_frame(self.rfile)
raw = framereader.http2_read_raw_frame(self.rfile)
assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec')
# check settings acknowledgement
raw = utils.http2_read_raw_frame(self.rfile)
raw = framereader.http2_read_raw_frame(self.rfile)
assert raw == codecs.decode('000000040100000000', 'hex_codec')
# send settings acknowledgement
@ -214,19 +214,19 @@ class TestApplySettings(tservers.ServerTestBase):
protocol = HTTP2Protocol(c)
protocol._apply_settings({
frame.SettingsFrame.ENABLE_PUSH: 'foo',
frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar',
frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef',
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo',
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar',
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef',
})
assert c.rfile.safe_read(2) == b"OK"
assert protocol.http2_settings[
frame.SettingsFrame.ENABLE_PUSH] == 'foo'
hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo'
assert protocol.http2_settings[
frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar'
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar'
assert protocol.http2_settings[
frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef'
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef'
class TestCreateHeaders(object):
@ -258,7 +258,7 @@ class TestCreateHeaders(object):
(b'server', b'version')])
protocol = HTTP2Protocol(self.c)
protocol.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] = 8
protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 8
bytes = protocol._create_headers(headers, 1, end_stream=True)
assert len(bytes) == 3
assert bytes[0] == codecs.decode('000008010100000001828487408294e783', 'hex_codec')
@ -281,7 +281,7 @@ class TestCreateBody(object):
def test_create_body_multiple_frames(self):
protocol = HTTP2Protocol(self.c)
protocol.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] = 5
protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 5
bytes = protocol._create_body(b'foobarmehm42', 1)
assert len(bytes) == 3
assert bytes[0] == codecs.decode('000005000000000001666f6f6261', 'hex_codec')
@ -312,7 +312,10 @@ class TestReadRequest(tservers.ServerTestBase):
req = protocol.read_request(NotImplemented)
assert req.stream_id
assert req.headers.fields == ((b':method', b'GET'), (b':path', b'/'), (b':scheme', b'https'))
assert req.headers.fields == ()
assert req.method == "GET"
assert req.path == "/"
assert req.scheme == "https"
assert req.content == b'foobar'

View File

@ -1,4 +1,5 @@
from netlib.http import Headers
from netlib.http.headers import parse_content_type
from netlib.tutils import raises
@ -72,3 +73,12 @@ class TestHeaders(object):
replacements = headers.replace(r"Host: ", "X-Host ")
assert replacements == 0
assert headers["Host"] == "example.com"
def test_parse_content_type():
p = parse_content_type
assert p("text/html") == ("text", "html", {})
assert p("text") is None
v = p("text/html; charset=UTF-8")
assert v == ('text', 'html', {'charset': 'UTF-8'})

View File

@ -0,0 +1,24 @@
from netlib.http import Headers
from netlib.http import multipart
def test_decode():
boundary = 'somefancyboundary'
headers = Headers(
content_type='multipart/form-data; boundary=' + boundary
)
content = (
"--{0}\n"
"Content-Disposition: form-data; name=\"field1\"\n\n"
"value1\n"
"--{0}\n"
"Content-Disposition: form-data; name=\"field2\"\n\n"
"value2\n"
"--{0}--".format(boundary).encode()
)
form = multipart.decode(headers, content)
assert len(form) == 2
assert form[0] == (b"field1", b"value1")
assert form[1] == (b"field2", b"value2")

View File

@ -24,7 +24,7 @@ class TestResponseCore(object):
"""
def test_repr(self):
response = tresp()
assert repr(response) == "Response(200 OK, unknown content type, 7B)"
assert repr(response) == "Response(200 OK, unknown content type, 7b)"
response.content = None
assert repr(response) == "Response(200 OK, no content)"

View File

@ -0,0 +1,66 @@
from netlib import tutils
from netlib.http import url
def test_parse():
with tutils.raises(ValueError):
url.parse("")
s, h, po, pa = url.parse(b"http://foo.com:8888/test")
assert s == b"http"
assert h == b"foo.com"
assert po == 8888
assert pa == b"/test"
s, h, po, pa = url.parse("http://foo/bar")
assert s == b"http"
assert h == b"foo"
assert po == 80
assert pa == b"/bar"
s, h, po, pa = url.parse(b"http://user:pass@foo/bar")
assert s == b"http"
assert h == b"foo"
assert po == 80
assert pa == b"/bar"
s, h, po, pa = url.parse(b"http://foo")
assert pa == b"/"
s, h, po, pa = url.parse(b"https://foo")
assert po == 443
with tutils.raises(ValueError):
url.parse(b"https://foo:bar")
# Invalid IDNA
with tutils.raises(ValueError):
url.parse("http://\xfafoo")
# Invalid PATH
with tutils.raises(ValueError):
url.parse("http:/\xc6/localhost:56121")
# Null byte in host
with tutils.raises(ValueError):
url.parse("http://foo\0")
# Port out of range
_, _, port, _ = url.parse("http://foo:999999")
assert port == 80
# Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt
with tutils.raises(ValueError):
url.parse('http://lo[calhost')
def test_unparse():
assert url.unparse("http", "foo.com", 99, "") == "http://foo.com:99"
assert url.unparse("http", "foo.com", 80, "/bar") == "http://foo.com/bar"
assert url.unparse("https", "foo.com", 80, "") == "https://foo.com:80"
assert url.unparse("https", "foo.com", 443, "") == "https://foo.com"
def test_urlencode():
assert url.encode([('foo', 'bar')])
def test_urldecode():
s = "one=two&three=four"
assert len(url.decode(s)) == 2

View File

@ -0,0 +1,28 @@
from netlib import basetypes
class SerializableDummy(basetypes.Serializable):
def __init__(self, i):
self.i = i
def get_state(self):
return self.i
def set_state(self, i):
self.i = i
def from_state(self, state):
return type(self)(state)
class TestSerializable:
def test_copy(self):
a = SerializableDummy(42)
assert a.i == 42
b = a.copy()
assert b.i == 42
a.set_state(1)
assert a.i == 1
assert b.i == 42

36
test/netlib/test_human.py Normal file
View File

@ -0,0 +1,36 @@
from netlib import human, tutils
def test_parse_size():
assert human.parse_size("0") == 0
assert human.parse_size("0b") == 0
assert human.parse_size("1") == 1
assert human.parse_size("1k") == 1024
assert human.parse_size("1m") == 1024**2
assert human.parse_size("1g") == 1024**3
tutils.raises(ValueError, human.parse_size, "1f")
tutils.raises(ValueError, human.parse_size, "ak")
def test_pretty_size():
assert human.pretty_size(0) == "0b"
assert human.pretty_size(100) == "100b"
assert human.pretty_size(1024) == "1k"
assert human.pretty_size(1024 + (1024 / 2.0)) == "1.5k"
assert human.pretty_size(1024 * 1024) == "1m"
assert human.pretty_size(10 * 1024 * 1024) == "10m"
def test_pretty_duration():
assert human.pretty_duration(0.00001) == "0ms"
assert human.pretty_duration(0.0001) == "0ms"
assert human.pretty_duration(0.001) == "1ms"
assert human.pretty_duration(0.01) == "10ms"
assert human.pretty_duration(0.1) == "100ms"
assert human.pretty_duration(1) == "1.00s"
assert human.pretty_duration(10) == "10.0s"
assert human.pretty_duration(100) == "100s"
assert human.pretty_duration(1000) == "1000s"
assert human.pretty_duration(10000) == "10000s"
assert human.pretty_duration(1.123) == "1.12s"
assert human.pretty_duration(0.123) == "123ms"

View File

@ -1,7 +1,6 @@
# coding=utf-8
from netlib import utils, tutils
from netlib.http import Headers
def test_bidi():
@ -31,146 +30,6 @@ def test_clean_bin():
assert utils.clean_bin(u"\u2605") == u"\u2605"
def test_pretty_size():
assert utils.pretty_size(100) == "100B"
assert utils.pretty_size(1024) == "1kB"
assert utils.pretty_size(1024 + (1024 / 2.0)) == "1.5kB"
assert utils.pretty_size(1024 * 1024) == "1MB"
def test_parse_url():
with tutils.raises(ValueError):
utils.parse_url("")
s, h, po, pa = utils.parse_url(b"http://foo.com:8888/test")
assert s == b"http"
assert h == b"foo.com"
assert po == 8888
assert pa == b"/test"
s, h, po, pa = utils.parse_url("http://foo/bar")
assert s == b"http"
assert h == b"foo"
assert po == 80
assert pa == b"/bar"
s, h, po, pa = utils.parse_url(b"http://user:pass@foo/bar")
assert s == b"http"
assert h == b"foo"
assert po == 80
assert pa == b"/bar"
s, h, po, pa = utils.parse_url(b"http://foo")
assert pa == b"/"
s, h, po, pa = utils.parse_url(b"https://foo")
assert po == 443
with tutils.raises(ValueError):
utils.parse_url(b"https://foo:bar")
# Invalid IDNA
with tutils.raises(ValueError):
utils.parse_url("http://\xfafoo")
# Invalid PATH
with tutils.raises(ValueError):
utils.parse_url("http:/\xc6/localhost:56121")
# Null byte in host
with tutils.raises(ValueError):
utils.parse_url("http://foo\0")
# Port out of range
_, _, port, _ = utils.parse_url("http://foo:999999")
assert port == 80
# Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt
with tutils.raises(ValueError):
utils.parse_url('http://lo[calhost')
def test_unparse_url():
assert utils.unparse_url("http", "foo.com", 99, "") == "http://foo.com:99"
assert utils.unparse_url("http", "foo.com", 80, "/bar") == "http://foo.com/bar"
assert utils.unparse_url("https", "foo.com", 80, "") == "https://foo.com:80"
assert utils.unparse_url("https", "foo.com", 443, "") == "https://foo.com"
def test_urlencode():
assert utils.urlencode([('foo', 'bar')])
def test_urldecode():
s = "one=two&three=four"
assert len(utils.urldecode(s)) == 2
def test_get_header_tokens():
headers = Headers()
assert utils.get_header_tokens(headers, "foo") == []
headers["foo"] = "bar"
assert utils.get_header_tokens(headers, "foo") == ["bar"]
headers["foo"] = "bar, voing"
assert utils.get_header_tokens(headers, "foo") == ["bar", "voing"]
headers.set_all("foo", ["bar, voing", "oink"])
assert utils.get_header_tokens(headers, "foo") == ["bar", "voing", "oink"]
def test_multipartdecode():
boundary = 'somefancyboundary'
headers = Headers(
content_type='multipart/form-data; boundary=' + boundary
)
content = (
"--{0}\n"
"Content-Disposition: form-data; name=\"field1\"\n\n"
"value1\n"
"--{0}\n"
"Content-Disposition: form-data; name=\"field2\"\n\n"
"value2\n"
"--{0}--".format(boundary).encode()
)
form = utils.multipartdecode(headers, content)
assert len(form) == 2
assert form[0] == (b"field1", b"value1")
assert form[1] == (b"field2", b"value2")
def test_parse_content_type():
p = utils.parse_content_type
assert p("text/html") == ("text", "html", {})
assert p("text") is None
v = p("text/html; charset=UTF-8")
assert v == ('text', 'html', {'charset': 'UTF-8'})
class SerializableDummy(utils.Serializable):
def __init__(self, i):
self.i = i
def get_state(self):
return self.i
def set_state(self, i):
self.i = i
def from_state(self, state):
return type(self)(state)
class TestSerializable:
def test_copy(self):
a = SerializableDummy(42)
assert a.i == 42
b = a.copy()
assert b.i == 42
a.set_state(1)
assert a.i == 1
assert b.i == 42
def test_safe_subn():
assert utils.safe_subn("foo", u"bar", "\xc2foo")

View File

@ -13,13 +13,6 @@ def test_membool():
assert m.v == 2
def test_parse_size():
assert utils.parse_size("100") == 100
assert utils.parse_size("100k") == 100 * 1024
tutils.raises("invalid size spec", utils.parse_size, "foo")
tutils.raises("invalid size spec", utils.parse_size, "100kk")
def test_parse_anchor_spec():
assert utils.parse_anchor_spec("foo=200") == ("foo", "200")
assert utils.parse_anchor_spec("foo") is None