Reorganise netlib imports according to Google Style Guide

Aldo Cortesi 2016-06-01 11:12:10 +12:00
parent be64445364
commit 44fdcb4b82
21 changed files with 172 additions and 170 deletions
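The change is mechanical and repeated across every file below: relative imports and from-imports of individual names are replaced by absolute imports of whole modules, so each name is qualified by the module that owns it at the call site. A minimal sketch of the convention, using the netlib.exceptions module that appears in the diffs; the helper function itself is illustrative and not part of the commit:

# Before: a relative from-import pulls the bare name into the local namespace.
# from ..exceptions import HttpException
# raise HttpException("Cannot assemble flow with missing content")

# After (Google Python Style Guide): import the module by its absolute path
# and qualify the name where it is used, which keeps its origin visible.
from netlib import exceptions

def require_content(content):
    # Hypothetical helper, not from the commit: reject messages without a body.
    if content is None:
        raise exceptions.HttpException("Cannot assemble flow with missing content")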

View File

@@ -12,7 +12,7 @@ from pyasn1.codec.der.decoder import decode
 from pyasn1.error import PyAsn1Error
 import OpenSSL
-from . import basetypes
+from netlib import basetypes
 # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815

View File

@@ -1,9 +1,9 @@
 from __future__ import absolute_import, print_function, division
-from .request import Request
-from .response import Response
-from .headers import Headers, parse_content_type
-from .message import decoded
-from . import http1, http2, status_codes, multipart
+from netlib.http.request import Request
+from netlib.http.response import Response
+from netlib.http.headers import Headers, parse_content_type
+from netlib.http.message import decoded
+from netlib.http import http1, http2, status_codes, multipart
 __all__ = [
 "Request",

View File

@@ -1,5 +1,5 @@
 from __future__ import (absolute_import, print_function, division)
-from argparse import Action, ArgumentTypeError
+import argparse
 import binascii
@@ -124,7 +124,7 @@ class PassManSingleUser(PassMan):
 return self.username == username and self.password == password_token
-class AuthAction(Action):
+class AuthAction(argparse.Action):
 """
 Helper class to allow seamless integration int argparse. Example usage:
@@ -148,7 +148,7 @@ class SingleuserAuthAction(AuthAction):
 def getPasswordManager(self, s):
 if len(s.split(':')) != 2:
-raise ArgumentTypeError(
+raise argparse.ArgumentTypeError(
 "Invalid single-user specification. Please use the format username:password"
 )
 username, password = s.split(':')

View File

@@ -1,8 +1,8 @@
 import collections
 import re
-from email.utils import parsedate_tz, formatdate, mktime_tz
-from netlib.multidict import ImmutableMultiDict
+import email.utils
+from netlib import multidict
 """
 A flexible module for cookie parsing and manipulation.
@@ -167,7 +167,7 @@ def parse_set_cookie_headers(headers):
 return ret
-class CookieAttrs(ImmutableMultiDict):
+class CookieAttrs(multidict.ImmutableMultiDict):
 @staticmethod
 def _kconv(key):
 return key.lower()
@@ -243,10 +243,10 @@ def refresh_set_cookie_header(c, delta):
 raise ValueError("Invalid Cookie")
 if "expires" in attrs:
-e = parsedate_tz(attrs["expires"])
+e = email.utils.parsedate_tz(attrs["expires"])
 if e:
-f = mktime_tz(e) + delta
-attrs = attrs.with_set_all("expires", [formatdate(f)])
+f = email.utils.mktime_tz(e) + delta
+attrs = attrs.with_set_all("expires", [email.utils.formatdate(f)])
 else:
 # This can happen when the expires tag is invalid.
 # reddit.com sends a an expires tag like this: "Thu, 31 Dec

View File

@@ -3,8 +3,8 @@ from __future__ import absolute_import, print_function, division
 import re
 import six
-from ..multidict import MultiDict
-from ..utils import always_bytes
+from netlib import multidict
+from netlib import utils
 # See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
@@ -20,10 +20,10 @@ else:
 return x.decode("utf-8", "surrogateescape")
 def _always_bytes(x):
-return always_bytes(x, "utf-8", "surrogateescape")
-class Headers(MultiDict):
+return utils.always_bytes(x, "utf-8", "surrogateescape")
+class Headers(multidict.MultiDict):
 """
 Header class which allows both convenient access to individual headers as well as
 direct access to the underlying raw data. Provides a full dictionary interface.

View File

@@ -1,12 +1,12 @@
 from __future__ import absolute_import, print_function, division
-from ... import utils
-from ...exceptions import HttpException
+from netlib import utils
+from netlib import exceptions
 def assemble_request(request):
 if request.content is None:
-raise HttpException("Cannot assemble flow with missing content")
+raise exceptions.HttpException("Cannot assemble flow with missing content")
 head = assemble_request_head(request)
 body = b"".join(assemble_body(request.data.headers, [request.data.content]))
 return head + body
@@ -20,7 +20,7 @@ def assemble_request_head(request):
 def assemble_response(response):
 if response.content is None:
-raise HttpException("Cannot assemble flow with missing content")
+raise exceptions.HttpException("Cannot assemble flow with missing content")
 head = assemble_response_head(response)
 body = b"".join(assemble_body(response.data.headers, [response.data.content]))
 return head + body

View File

@@ -3,10 +3,12 @@ import time
 import sys
 import re
-from ... import utils
-from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect
-from .. import Request, Response, Headers
-from .. import url
+from netlib.http import request
+from netlib.http import response
+from netlib.http import headers
+from netlib.http import url
+from netlib import utils
+from netlib import exceptions
 def get_header_tokens(headers, key):
@@ -40,9 +42,9 @@ def read_request_head(rfile):
 The HTTP request object (without body)
 Raises:
-HttpReadDisconnect: No bytes can be read from rfile.
-HttpSyntaxException: The input is malformed HTTP.
-HttpException: Any other error occured.
+exceptions.HttpReadDisconnect: No bytes can be read from rfile.
+exceptions.HttpSyntaxException: The input is malformed HTTP.
+exceptions.HttpException: Any other error occured.
 """
 timestamp_start = time.time()
 if hasattr(rfile, "reset_timestamps"):
@@ -55,7 +57,7 @@ def read_request_head(rfile):
 # more accurate timestamp_start
 timestamp_start = rfile.first_byte_timestamp
-return Request(
+return request.Request(
 form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
 )
@@ -79,9 +81,9 @@ def read_response_head(rfile):
 The HTTP request object (without body)
 Raises:
-HttpReadDisconnect: No bytes can be read from rfile.
-HttpSyntaxException: The input is malformed HTTP.
-HttpException: Any other error occured.
+exceptions.HttpReadDisconnect: No bytes can be read from rfile.
+exceptions.HttpSyntaxException: The input is malformed HTTP.
+exceptions.HttpException: Any other error occured.
 """
 timestamp_start = time.time()
@@ -95,7 +97,7 @@ def read_response_head(rfile):
 # more accurate timestamp_start
 timestamp_start = rfile.first_byte_timestamp
-return Response(http_version, status_code, message, headers, None, timestamp_start)
+return response.Response(http_version, status_code, message, headers, None, timestamp_start)
 def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
@@ -112,7 +114,7 @@ def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
 A generator that yields byte chunks of the content.
 Raises:
-HttpException, if an error occurs
+exceptions.HttpException, if an error occurs
 Caveats:
 max_chunk_size is not considered if the transfer encoding is chunked.
@@ -127,7 +129,7 @@ def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
 yield x
 elif expected_size >= 0:
 if limit is not None and expected_size > limit:
-raise HttpException(
+raise exceptions.HttpException(
 "HTTP Body too large. "
 "Limit is {}, content length was advertised as {}".format(limit, expected_size)
 )
@@ -136,7 +138,7 @@ def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
 chunk_size = min(bytes_left, max_chunk_size)
 content = rfile.read(chunk_size)
 if len(content) < chunk_size:
-raise HttpException("Unexpected EOF")
+raise exceptions.HttpException("Unexpected EOF")
 yield content
 bytes_left -= chunk_size
 else:
@@ -150,7 +152,7 @@ def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
 bytes_left -= chunk_size
 not_done = rfile.read(1)
 if not_done:
-raise HttpException("HTTP body too large. Limit is {}.".format(limit))
+raise exceptions.HttpException("HTTP body too large. Limit is {}.".format(limit))
 def connection_close(http_version, headers):
@@ -180,7 +182,7 @@ def expected_http_body_size(request, response=None):
 - -1, if all data should be read until end of stream.
 Raises:
-HttpSyntaxException, if the content length header is invalid
+exceptions.HttpSyntaxException, if the content length header is invalid
 """
 # Determine response size according to
 # http://tools.ietf.org/html/rfc7230#section-3.3
@@ -215,7 +217,7 @@ def expected_http_body_size(request, response=None):
 raise ValueError()
 return size
 except ValueError:
-raise HttpSyntaxException("Unparseable Content Length")
+raise exceptions.HttpSyntaxException("Unparseable Content Length")
 if is_request:
 return 0
 return -1
@@ -227,19 +229,19 @@ def _get_first_line(rfile):
 if line == b"\r\n" or line == b"\n":
 # Possible leftover from previous message
 line = rfile.readline()
-except TcpDisconnect:
-raise HttpReadDisconnect("Remote disconnected")
+except exceptions.TcpDisconnect:
+raise exceptions.HttpReadDisconnect("Remote disconnected")
 if not line:
-raise HttpReadDisconnect("Remote disconnected")
+raise exceptions.HttpReadDisconnect("Remote disconnected")
 return line.strip()
 def _read_request_line(rfile):
 try:
 line = _get_first_line(rfile)
-except HttpReadDisconnect:
+except exceptions.HttpReadDisconnect:
 # We want to provide a better error message.
-raise HttpReadDisconnect("Client disconnected")
+raise exceptions.HttpReadDisconnect("Client disconnected")
 try:
 method, path, http_version = line.split(b" ")
@@ -257,7 +259,7 @@ def _read_request_line(rfile):
 _check_http_version(http_version)
 except ValueError:
-raise HttpSyntaxException("Bad HTTP request line: {}".format(line))
+raise exceptions.HttpSyntaxException("Bad HTTP request line: {}".format(line))
 return form, method, scheme, host, port, path, http_version
@@ -276,7 +278,7 @@ def _parse_authority_form(hostport):
 if not utils.is_valid_host(host) or not utils.is_valid_port(port):
 raise ValueError()
 except ValueError:
-raise HttpSyntaxException("Invalid host specification: {}".format(hostport))
+raise exceptions.HttpSyntaxException("Invalid host specification: {}".format(hostport))
 return host, port
@@ -284,9 +286,9 @@ def _parse_authority_form(hostport):
 def _read_response_line(rfile):
 try:
 line = _get_first_line(rfile)
-except HttpReadDisconnect:
+except exceptions.HttpReadDisconnect:
 # We want to provide a better error message.
-raise HttpReadDisconnect("Server disconnected")
+raise exceptions.HttpReadDisconnect("Server disconnected")
 try:
@@ -299,14 +301,14 @@ def _read_response_line(rfile):
 _check_http_version(http_version)
 except ValueError:
-raise HttpSyntaxException("Bad HTTP response line: {}".format(line))
+raise exceptions.HttpSyntaxException("Bad HTTP response line: {}".format(line))
 return http_version, status_code, message
 def _check_http_version(http_version):
 if not re.match(br"^HTTP/\d\.\d$", http_version):
-raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version))
+raise exceptions.HttpSyntaxException("Unknown HTTP version: {}".format(http_version))
 def _read_headers(rfile):
@@ -318,7 +320,7 @@ def _read_headers(rfile):
 A headers object
 Raises:
-HttpSyntaxException
+exceptions.HttpSyntaxException
 """
 ret = []
 while True:
@@ -327,7 +329,7 @@ def _read_headers(rfile):
 break
 if line[0] in b" \t":
 if not ret:
-raise HttpSyntaxException("Invalid headers")
+raise exceptions.HttpSyntaxException("Invalid headers")
 # continued header
 ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip())
 else:
@@ -338,8 +340,8 @@ def _read_headers(rfile):
 raise ValueError()
 ret.append((name, value))
 except ValueError:
-raise HttpSyntaxException("Invalid headers")
-return Headers(ret)
+raise exceptions.HttpSyntaxException("Invalid headers")
+return headers.Headers(ret)
 def _read_chunked(rfile, limit=sys.maxsize):
@@ -354,22 +356,22 @@ def _read_chunked(rfile, limit=sys.maxsize):
 while True:
 line = rfile.readline(128)
 if line == b"":
-raise HttpException("Connection closed prematurely")
+raise exceptions.HttpException("Connection closed prematurely")
 if line != b"\r\n" and line != b"\n":
 try:
 length = int(line, 16)
 except ValueError:
-raise HttpSyntaxException("Invalid chunked encoding length: {}".format(line))
+raise exceptions.HttpSyntaxException("Invalid chunked encoding length: {}".format(line))
 total += length
 if total > limit:
-raise HttpException(
+raise exceptions.HttpException(
 "HTTP Body too large. Limit is {}, "
 "chunked content longer than {}".format(limit, total)
 )
 chunk = rfile.read(length)
 suffix = rfile.readline(5)
 if suffix != b"\r\n":
-raise HttpSyntaxException("Malformed chunked body")
+raise exceptions.HttpSyntaxException("Malformed chunked body")
 if length == 0:
 return
 yield chunk

View File

@@ -5,9 +5,12 @@ import time
 import hyperframe.frame
 from hpack.hpack import Encoder, Decoder
-from ... import utils
-from .. import Headers, Response, Request, url
-from . import framereader
+from netlib import utils
+from netlib.http import url
+import netlib.http.headers
+import netlib.http.response
+import netlib.http.request
+from netlib.http.http2 import framereader
 class TCPHandler(object):
@@ -128,7 +131,7 @@ class HTTP2Protocol(object):
 port = 80 if scheme == 'http' else 443
 port = int(port)
-request = Request(
+request = netlib.http.request.Request(
 first_line_format,
 method.encode('ascii'),
 scheme.encode('ascii'),
@@ -176,7 +179,7 @@ class HTTP2Protocol(object):
 else:
 timestamp_end = None
-response = Response(
+response = netlib.http.response.Response(
 b"HTTP/2.0",
 int(headers.get(':status', 502)),
 b'',
@@ -190,15 +193,15 @@ class HTTP2Protocol(object):
 return response
 def assemble(self, message):
-if isinstance(message, Request):
+if isinstance(message, netlib.http.request.Request):
 return self.assemble_request(message)
-elif isinstance(message, Response):
+elif isinstance(message, netlib.http.response.Response):
 return self.assemble_response(message)
 else:
 raise ValueError("HTTP message not supported.")
 def assemble_request(self, request):
-assert isinstance(request, Request)
+assert isinstance(request, netlib.http.request.Request)
 authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
 if self.tcp_handler.address.port != 443:
@@ -222,7 +225,7 @@ class HTTP2Protocol(object):
 self._create_body(request.body, stream_id)))
 def assemble_response(self, response):
-assert isinstance(response, Response)
+assert isinstance(response, netlib.http.response.Response)
 headers = response.headers.copy()
@@ -422,7 +425,7 @@ class HTTP2Protocol(object):
 else:
 self._handle_unexpected_frame(frm)
-headers = Headers(
+headers = netlib.http.headers.Headers(
 (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)
 )

View File

@@ -4,8 +4,8 @@ import warnings
 import six
-from .. import encoding, utils, basetypes
-from . import headers
+from netlib import encoding, utils, basetypes
+from netlib.http import headers
 if six.PY2: # pragma: no cover
 def _native(x):

View File

@@ -1,6 +1,6 @@
 import re
-from . import headers
+from netlib.http import headers
 def decode(hdrs, content):

View File

@@ -5,14 +5,14 @@ import re
 import six
 from six.moves import urllib
+from netlib import encoding
+from netlib import multidict
 from netlib import utils
+import netlib.http.url
 from netlib.http import multipart
-from . import cookies
-from .. import encoding
-from ..multidict import MultiDictView
-from .headers import Headers
-from .message import Message, _native, _always_bytes, MessageData
+from netlib.http import cookies
+from netlib.http import headers as nheaders
+from netlib.http import message
+import netlib.http.url
 # This regex extracts & splits the host header into host and port.
 # Handles the edge case of IPv6 addresses containing colons.
@@ -20,11 +20,11 @@ from .message import Message, _native, _always_bytes, MessageData
 host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$")
-class RequestData(MessageData):
+class RequestData(message.MessageData):
 def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=(), content=None,
 timestamp_start=None, timestamp_end=None):
-if not isinstance(headers, Headers):
-headers = Headers(headers)
+if not isinstance(headers, nheaders.Headers):
+headers = nheaders.Headers(headers)
 self.first_line_format = first_line_format
 self.method = method
@@ -39,7 +39,7 @@ class RequestData(MessageData):
 self.timestamp_end = timestamp_end
-class Request(Message):
+class Request(message.Message):
 """
 An HTTP request.
 """
@@ -91,22 +91,22 @@ class Request(Message):
 """
 HTTP request method, e.g. "GET".
 """
-return _native(self.data.method).upper()
+return message._native(self.data.method).upper()
 @method.setter
 def method(self, method):
-self.data.method = _always_bytes(method)
+self.data.method = message._always_bytes(method)
 @property
 def scheme(self):
 """
 HTTP request scheme, which should be "http" or "https".
 """
-return _native(self.data.scheme)
+return message._native(self.data.scheme)
 @scheme.setter
 def scheme(self, scheme):
-self.data.scheme = _always_bytes(scheme)
+self.data.scheme = message._always_bytes(scheme)
 @property
 def host(self):
@@ -168,11 +168,11 @@ class Request(Message):
 if self.data.path is None:
 return None
 else:
-return _native(self.data.path)
+return message._native(self.data.path)
 @path.setter
 def path(self, path):
-self.data.path = _always_bytes(path)
+self.data.path = message._always_bytes(path)
 @property
 def url(self):
@@ -225,11 +225,11 @@ class Request(Message):
 @property
 def query(self):
-# type: () -> MultiDictView
+# type: () -> multidict.MultiDictView
 """
 The request query string as an :py:class:`MultiDictView` object.
 """
-return MultiDictView(
+return multidict.MultiDictView(
 self._get_query,
 self._set_query
 )
@@ -250,13 +250,13 @@ class Request(Message):
 @property
 def cookies(self):
-# type: () -> MultiDictView
+# type: () -> multidict.MultiDictView
 """
 The request cookies.
-An empty :py:class:`MultiDictView` object if the cookie monster ate them all.
+An empty :py:class:`multidict.MultiDictView` object if the cookie monster ate them all.
 """
-return MultiDictView(
+return multidict.MultiDictView(
 self._get_cookies,
 self._set_cookies
 )
@@ -329,11 +329,11 @@ class Request(Message):
 @property
 def urlencoded_form(self):
 """
-The URL-encoded form data as an :py:class:`MultiDictView` object.
-An empty MultiDictView if the content-type indicates non-form data
+The URL-encoded form data as an :py:class:`multidict.MultiDictView` object.
+An empty multidict.MultiDictView if the content-type indicates non-form data
 or the content could not be parsed.
 """
-return MultiDictView(
+return multidict.MultiDictView(
 self._get_urlencoded_form,
 self._set_urlencoded_form
 )
@@ -362,7 +362,7 @@ class Request(Message):
 The multipart form data as an :py:class:`MultipartFormDict` object.
 None if the content-type indicates non-form data.
 """
-return MultiDictView(
+return multidict.MultiDictView(
 self._get_multipart_form,
 self._set_multipart_form
 )

View File

@@ -3,18 +3,18 @@ from __future__ import absolute_import, print_function, division
 from email.utils import parsedate_tz, formatdate, mktime_tz
 import time
-from . import cookies
-from .headers import Headers
-from .message import Message, _native, _always_bytes, MessageData
-from ..multidict import MultiDictView
-from .. import human
+from netlib.http import cookies
+from netlib.http import headers as nheaders
+from netlib.http import message
+from netlib import multidict
+from netlib import human
-class ResponseData(MessageData):
+class ResponseData(message.MessageData):
 def __init__(self, http_version, status_code, reason=None, headers=(), content=None,
 timestamp_start=None, timestamp_end=None):
-if not isinstance(headers, Headers):
-headers = Headers(headers)
+if not isinstance(headers, nheaders.Headers):
+headers = nheaders.Headers(headers)
 self.http_version = http_version
 self.status_code = status_code
@@ -25,7 +25,7 @@ class ResponseData(MessageData):
 self.timestamp_end = timestamp_end
-class Response(Message):
+class Response(message.Message):
 """
 An HTTP response.
 """
@@ -63,17 +63,17 @@ class Response(Message):
 HTTP Reason Phrase, e.g. "Not Found".
 This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase.
 """
-return _native(self.data.reason)
+return message._native(self.data.reason)
 @reason.setter
 def reason(self, reason):
-self.data.reason = _always_bytes(reason)
+self.data.reason = message._always_bytes(reason)
 @property
 def cookies(self):
-# type: () -> MultiDictView
+# type: () -> multidict.MultiDictView
 """
-The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are
+The response cookies. A possibly empty :py:class:`multidict.MultiDictView`, where the keys are
 cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is
 an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly)
 are indicated by a Null value.
@@ -81,7 +81,7 @@ class Response(Message):
 Caveats:
 Updating the attr
 """
-return MultiDictView(
+return multidict.MultiDictView(
 self._get_cookies,
 self._set_cookies
 )

View File

@@ -1,7 +1,7 @@
 import six
 from six.moves import urllib
-from .. import utils
+from netlib import utils
 # PY2 workaround

View File

@@ -9,7 +9,7 @@ except ImportError: # pragma: no cover
 from collections import MutableMapping # Workaround for Python < 3.3
 import six
-from . import basetypes
+from netlib import basetypes
 @six.add_metaclass(ABCMeta)

View File

@@ -3,7 +3,7 @@ import copy
 import six
-from . import basetypes, utils
+from netlib import basetypes, utils
 class ODict(basetypes.Serializable):

View File

@@ -2,7 +2,8 @@ from __future__ import (absolute_import, print_function, division)
 import struct
 import array
 import ipaddress
-from . import tcp, utils
+from netlib import tcp, utils
 class SocksError(Exception):

View File

@@ -16,13 +16,10 @@ import six
 import OpenSSL
 from OpenSSL import SSL
-from . import certutils, version_check, basetypes
+from netlib import certutils, version_check, basetypes, exceptions
 # This is a rather hackish way to make sure that
 # the latest version of pyOpenSSL is actually installed.
-from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
-TcpTimeout, TcpDisconnect, TcpException
 version_check.check_pyopenssl_version()
 if six.PY2:
@@ -162,17 +159,17 @@ class Writer(_FileLike):
 def flush(self):
 """
-May raise TcpDisconnect
+May raise exceptions.TcpDisconnect
 """
 if hasattr(self.o, "flush"):
 try:
 self.o.flush()
 except (socket.error, IOError) as v:
-raise TcpDisconnect(str(v))
+raise exceptions.TcpDisconnect(str(v))
 def write(self, v):
 """
-May raise TcpDisconnect
+May raise exceptions.TcpDisconnect
 """
 if v:
 self.first_byte_timestamp = self.first_byte_timestamp or time.time()
@@ -185,7 +182,7 @@ class Writer(_FileLike):
 self.add_log(v[:r])
 return r
 except (SSL.Error, socket.error) as e:
-raise TcpDisconnect(str(e))
+raise exceptions.TcpDisconnect(str(e))
 class Reader(_FileLike):
@@ -216,17 +213,17 @@ class Reader(_FileLike):
 time.sleep(0.1)
 continue
 else:
-raise TcpTimeout()
+raise exceptions.TcpTimeout()
 except socket.timeout:
-raise TcpTimeout()
+raise exceptions.TcpTimeout()
 except socket.error as e:
-raise TcpDisconnect(str(e))
+raise exceptions.TcpDisconnect(str(e))
 except SSL.SysCallError as e:
 if e.args == (-1, 'Unexpected EOF'):
 break
-raise TlsException(str(e))
+raise exceptions.TlsException(str(e))
 except SSL.Error as e:
-raise TlsException(str(e))
+raise exceptions.TlsException(str(e))
 self.first_byte_timestamp = self.first_byte_timestamp or time.time()
 if not data:
 break
@@ -260,9 +257,9 @@ class Reader(_FileLike):
 result = self.read(length)
 if length != -1 and len(result) != length:
 if not result:
-raise TcpDisconnect()
+raise exceptions.TcpDisconnect()
 else:
-raise TcpReadIncomplete(
+raise exceptions.TcpReadIncomplete(
 "Expected %s bytes, got %s" % (length, len(result))
 )
 return result
@@ -275,7 +272,7 @@ class Reader(_FileLike):
 Up to the next N bytes if peeking is successful.
 Raises:
-TcpException if there was an error with the socket
+exceptions.TcpException if there was an error with the socket
 TlsException if there was an error with pyOpenSSL.
 NotImplementedError if the underlying file object is not a [pyOpenSSL] socket
 """
@@ -283,7 +280,7 @@ class Reader(_FileLike):
 try:
 return self.o._sock.recv(length, socket.MSG_PEEK)
 except socket.error as e:
-raise TcpException(repr(e))
+raise exceptions.TcpException(repr(e))
 elif isinstance(self.o, SSL.Connection):
 try:
 if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15):
@@ -297,7 +294,7 @@ class Reader(_FileLike):
 self.o._raise_ssl_error(self.o._ssl, result)
 return SSL._ffi.buffer(buf, result)[:]
 except SSL.Error as e:
-six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2])
+six.reraise(exceptions.TlsException, exceptions.TlsException(str(e)), sys.exc_info()[2])
 else:
 raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
@@ -490,7 +487,7 @@ class _Connection(object):
 try:
 self.wfile.flush()
 self.wfile.close()
-except TcpDisconnect:
+except exceptions.TcpDisconnect:
 pass
 self.rfile.close()
@@ -554,7 +551,7 @@ class _Connection(object):
 # TODO: maybe change this to with newer pyOpenSSL APIs
 context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
 except SSL.Error as v:
-raise TlsException("SSL cipher specification error: %s" % str(v))
+raise exceptions.TlsException("SSL cipher specification error: %s" % str(v))
 # SSLKEYLOGFILE
 if log_ssl_key:
@@ -575,7 +572,7 @@ class _Connection(object):
 elif alpn_select_callback is not None and alpn_select is None:
 context.set_alpn_select_callback(alpn_select_callback)
 elif alpn_select_callback is not None and alpn_select is not None:
-raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
+raise exceptions.TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
 return context
@@ -632,7 +629,7 @@ class TCPClient(_Connection):
 context.use_privatekey_file(cert)
 context.use_certificate_file(cert)
 except SSL.Error as v:
-raise TlsException("SSL client certificate error: %s" % str(v))
+raise exceptions.TlsException("SSL client certificate error: %s" % str(v))
 return context
 def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
@@ -646,7 +643,7 @@ class TCPClient(_Connection):
 """
 verification_mode = sslctx_kwargs.get('verify_options', None)
 if verification_mode == SSL.VERIFY_PEER and not sni:
-raise TlsException("Cannot validate certificate hostname without SNI")
+raise exceptions.TlsException("Cannot validate certificate hostname without SNI")
 context = self.create_ssl_context(
 alpn_protos=alpn_protos,
@@ -661,14 +658,14 @@ class TCPClient(_Connection):
 self.connection.do_handshake()
 except SSL.Error as v:
 if self.ssl_verification_error:
-raise InvalidCertificateException("SSL handshake error: %s" % repr(v))
+raise exceptions.InvalidCertificateException("SSL handshake error: %s" % repr(v))
 else:
-raise TlsException("SSL handshake error: %s" % repr(v))
+raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
 else:
 # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on
 # certificate validation failure
 if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error is not None:
-raise InvalidCertificateException("SSL handshake error: certificate verify failed")
+raise exceptions.InvalidCertificateException("SSL handshake error: certificate verify failed")
 self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
@@ -691,7 +688,7 @@ class TCPClient(_Connection):
 except (ValueError, ssl_match_hostname.CertificateError) as e:
 self.ssl_verification_error = dict(depth=0, errno="Invalid Hostname")
 if verification_mode == SSL.VERIFY_PEER:
-raise InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e)))
+raise exceptions.InvalidCertificateException("Presented certificate for {} is not valid: {}".format(sni, str(e)))
 self.ssl_established = True
 self.rfile.set_descriptor(self.connection)
@@ -705,7 +702,7 @@ class TCPClient(_Connection):
 connection.connect(self.address())
 self.source_address = Address(connection.getsockname())
 except (socket.error, IOError) as err:
-raise TcpException(
+raise exceptions.TcpException(
 'Error connecting to "%s": %s' %
 (self.address.host, err))
 self.connection = connection
@@ -818,7 +815,7 @@ class BaseHandler(_Connection):
 try:
 self.connection.do_handshake()
 except SSL.Error as v:
-raise TlsException("SSL handshake error: %s" % repr(v))
+raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
 self.ssl_established = True
 self.rfile.set_descriptor(self.connection)
 self.wfile.set_descriptor(self.connection)

View File

@@ -7,8 +7,7 @@ from contextlib import contextmanager
 import six
 import sys
-from . import utils, tcp
-from .http import Request, Response, Headers
+from netlib import utils, tcp, http
 def treader(bytes):
@@ -107,11 +106,11 @@ def treq(**kwargs):
 port=22,
 path=b"/path",
 http_version=b"HTTP/1.1",
-headers=Headers(((b"header", b"qvalue"), (b"content-length", b"7"))),
+headers=http.Headers(((b"header", b"qvalue"), (b"content-length", b"7"))),
 content=b"content"
 )
 default.update(kwargs)
-return Request(**default)
+return http.Request(**default)
 def tresp(**kwargs):
@@ -123,10 +122,10 @@ def tresp(**kwargs):
 http_version=b"HTTP/1.1",
 status_code=200,
 reason=b"OK",
-headers=Headers(((b"header-response", b"svalue"), (b"content-length", b"7"))),
+headers=http.Headers(((b"header-response", b"svalue"), (b"content-length", b"7"))),
 content=b"message",
 timestamp_start=time.time(),
 timestamp_end=time.time(),
 )
 default.update(kwargs)
-return Response(**default)
+return http.Response(**default)

View File

@@ -6,10 +6,10 @@ import warnings
 import six
-from .protocol import Masker
 from netlib import tcp
 from netlib import utils
 from netlib import human
+from netlib.websockets import protocol
 MAX_16_BIT_INT = (1 << 16)
@@ -267,7 +267,7 @@ class Frame(object):
 """
 b = bytes(self.header)
 if self.header.masking_key:
-b += Masker(self.header.masking_key)(self.payload)
+b += protocol.Masker(self.header.masking_key)(self.payload)
 else:
 b += self.payload
 return b
@@ -296,7 +296,7 @@ class Frame(object):
 payload = fp.safe_read(header.payload_length)
 if header.mask == 1 and header.masking_key:
-payload = Masker(header.masking_key)(payload)
+payload = protocol.Masker(header.masking_key)(payload)
 return cls(
 payload,

View File

@@ -19,7 +19,8 @@ import hashlib
 import os
 import six
-from ..http import Headers
+from netlib import http
 websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
 VERSION = "13"
@@ -72,11 +73,11 @@ class WebsocketsProtocol(object):
 specified, it is generated, and can be found in sec-websocket-key in
 the returned header set.
-Returns an instance of Headers
+Returns an instance of http.Headers
 """
 if not key:
 key = base64.b64encode(os.urandom(16)).decode('ascii')
-return Headers(
+return http.Headers(
 sec_websocket_key=key,
 sec_websocket_version=version,
 connection="Upgrade",
@@ -88,7 +89,7 @@ class WebsocketsProtocol(object):
 """
 The server response is a valid HTTP 101 response.
 """
-return Headers(
+return http.Headers(
 sec_websocket_accept=self.create_server_nonce(key),
 connection="Upgrade",
 upgrade="websocket"

View File

@@ -6,8 +6,7 @@ import six
 from io import BytesIO
 from six.moves import urllib
-from netlib.utils import always_bytes, native
-from . import http, tcp
+from netlib import http, tcp, utils
 class ClientConn(object):
@@ -55,38 +54,38 @@ class WSGIAdaptor(object):
 self.app, self.domain, self.port, self.sversion = app, domain, port, sversion
 def make_environ(self, flow, errsoc, **extra):
-path = native(flow.request.path, "latin-1")
+path = utils.native(flow.request.path, "latin-1")
 if '?' in path:
-path_info, query = native(path, "latin-1").split('?', 1)
+path_info, query = utils.native(path, "latin-1").split('?', 1)
 else:
 path_info = path
 query = ''
 environ = {
 'wsgi.version': (1, 0),
-'wsgi.url_scheme': native(flow.request.scheme, "latin-1"),
+'wsgi.url_scheme': utils.native(flow.request.scheme, "latin-1"),
 'wsgi.input': BytesIO(flow.request.content or b""),
 'wsgi.errors': errsoc,
 'wsgi.multithread': True,
 'wsgi.multiprocess': False,
 'wsgi.run_once': False,
 'SERVER_SOFTWARE': self.sversion,
-'REQUEST_METHOD': native(flow.request.method, "latin-1"),
+'REQUEST_METHOD': utils.native(flow.request.method, "latin-1"),
 'SCRIPT_NAME': '',
 'PATH_INFO': urllib.parse.unquote(path_info),
 'QUERY_STRING': query,
-'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', ''), "latin-1"),
-'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', ''), "latin-1"),
+'CONTENT_TYPE': utils.native(flow.request.headers.get('Content-Type', ''), "latin-1"),
+'CONTENT_LENGTH': utils.native(flow.request.headers.get('Content-Length', ''), "latin-1"),
 'SERVER_NAME': self.domain,
 'SERVER_PORT': str(self.port),
-'SERVER_PROTOCOL': native(flow.request.http_version, "latin-1"),
+'SERVER_PROTOCOL': utils.native(flow.request.http_version, "latin-1"),
 }
 environ.update(extra)
 if flow.client_conn.address:
-environ["REMOTE_ADDR"] = native(flow.client_conn.address.host, "latin-1")
+environ["REMOTE_ADDR"] = utils.native(flow.client_conn.address.host, "latin-1")
 environ["REMOTE_PORT"] = flow.client_conn.address.port
 for key, value in flow.request.headers.items():
-key = 'HTTP_' + native(key, "latin-1").upper().replace('-', '_')
+key = 'HTTP_' + utils.native(key, "latin-1").upper().replace('-', '_')
 if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
 environ[key] = value
 return environ
@@ -140,7 +139,7 @@ class WSGIAdaptor(object):
 elif state["status"]:
 raise AssertionError('Response already started')
 state["status"] = status
-state["headers"] = http.Headers([[always_bytes(k), always_bytes(v)] for k, v in headers])
+state["headers"] = http.Headers([[utils.always_bytes(k), utils.always_bytes(v)] for k, v in headers])
 if exc_info:
 self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2]))
 state["headers_sent"] = True