mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
cleanup code with autopep8
run the following command: $ autopep8 -i -r -a -a .
This commit is contained in:
parent
f7b75ba8c2
commit
e3d390e036
@ -1,12 +1,15 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
import os, ssl, time, datetime
|
||||
import os
|
||||
import ssl
|
||||
import time
|
||||
import datetime
|
||||
import itertools
|
||||
from pyasn1.type import univ, constraint, char, namedtype, tag
|
||||
from pyasn1.codec.der.decoder import decode
|
||||
from pyasn1.error import PyAsn1Error
|
||||
import OpenSSL
|
||||
|
||||
DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5
|
||||
DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5
|
||||
# Generated with "openssl dhparam". It's too slow to generate this on startup.
|
||||
DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS-----
|
||||
MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5
|
||||
@ -14,31 +17,32 @@ zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK
|
||||
1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC
|
||||
-----END DH PARAMETERS-----"""
|
||||
|
||||
|
||||
def create_ca(o, cn, exp):
|
||||
key = OpenSSL.crypto.PKey()
|
||||
key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024)
|
||||
cert = OpenSSL.crypto.X509()
|
||||
cert.set_serial_number(int(time.time()*10000))
|
||||
cert.set_serial_number(int(time.time() * 10000))
|
||||
cert.set_version(2)
|
||||
cert.get_subject().CN = cn
|
||||
cert.get_subject().O = o
|
||||
cert.gmtime_adj_notBefore(-3600*48)
|
||||
cert.gmtime_adj_notBefore(-3600 * 48)
|
||||
cert.gmtime_adj_notAfter(exp)
|
||||
cert.set_issuer(cert.get_subject())
|
||||
cert.set_pubkey(key)
|
||||
cert.add_extensions([
|
||||
OpenSSL.crypto.X509Extension("basicConstraints", True,
|
||||
"CA:TRUE"),
|
||||
OpenSSL.crypto.X509Extension("nsCertType", False,
|
||||
"sslCA"),
|
||||
OpenSSL.crypto.X509Extension("extendedKeyUsage", False,
|
||||
"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC"
|
||||
),
|
||||
OpenSSL.crypto.X509Extension("keyUsage", True,
|
||||
"keyCertSign, cRLSign"),
|
||||
OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash",
|
||||
subject=cert),
|
||||
])
|
||||
OpenSSL.crypto.X509Extension("basicConstraints", True,
|
||||
"CA:TRUE"),
|
||||
OpenSSL.crypto.X509Extension("nsCertType", False,
|
||||
"sslCA"),
|
||||
OpenSSL.crypto.X509Extension("extendedKeyUsage", False,
|
||||
"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC"
|
||||
),
|
||||
OpenSSL.crypto.X509Extension("keyUsage", True,
|
||||
"keyCertSign, cRLSign"),
|
||||
OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash",
|
||||
subject=cert),
|
||||
])
|
||||
cert.sign(key, "sha1")
|
||||
return key, cert
|
||||
|
||||
@ -56,15 +60,15 @@ def dummy_cert(privkey, cacert, commonname, sans):
|
||||
"""
|
||||
ss = []
|
||||
for i in sans:
|
||||
ss.append("DNS: %s"%i)
|
||||
ss.append("DNS: %s" % i)
|
||||
ss = ", ".join(ss)
|
||||
|
||||
cert = OpenSSL.crypto.X509()
|
||||
cert.gmtime_adj_notBefore(-3600*48)
|
||||
cert.gmtime_adj_notBefore(-3600 * 48)
|
||||
cert.gmtime_adj_notAfter(DEFAULT_EXP)
|
||||
cert.set_issuer(cacert.get_subject())
|
||||
cert.get_subject().CN = commonname
|
||||
cert.set_serial_number(int(time.time()*10000))
|
||||
cert.set_serial_number(int(time.time() * 10000))
|
||||
if ss:
|
||||
cert.set_version(2)
|
||||
cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)])
|
||||
@ -114,6 +118,7 @@ def dummy_cert(privkey, cacert, commonname, sans):
|
||||
|
||||
|
||||
class CertStoreEntry(object):
|
||||
|
||||
def __init__(self, cert, privatekey, chain_file):
|
||||
self.cert = cert
|
||||
self.privatekey = privatekey
|
||||
@ -121,9 +126,11 @@ class CertStoreEntry(object):
|
||||
|
||||
|
||||
class CertStore(object):
|
||||
|
||||
"""
|
||||
Implements an in-memory certificate store.
|
||||
"""
|
||||
|
||||
def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None):
|
||||
self.default_privatekey = default_privatekey
|
||||
self.default_ca = default_ca
|
||||
@ -144,11 +151,11 @@ class CertStore(object):
|
||||
if bio != OpenSSL.SSL._ffi.NULL:
|
||||
bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free)
|
||||
dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams(
|
||||
bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL
|
||||
)
|
||||
bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL
|
||||
)
|
||||
dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free)
|
||||
return dh
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_store(cls, path, basename):
|
||||
ca_path = os.path.join(path, basename + "-ca.pem")
|
||||
@ -277,8 +284,8 @@ class _GeneralName(univ.Choice):
|
||||
# other types.
|
||||
componentType = namedtype.NamedTypes(
|
||||
namedtype.NamedType('dNSName', char.IA5String().subtype(
|
||||
implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2)
|
||||
)
|
||||
implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@ -289,6 +296,7 @@ class _GeneralNames(univ.SequenceOf):
|
||||
|
||||
|
||||
class SSLCert(object):
|
||||
|
||||
def __init__(self, cert):
|
||||
"""
|
||||
Returns a (common name, [subject alternative names]) tuple.
|
||||
|
@ -5,8 +5,11 @@ import struct
|
||||
import io
|
||||
|
||||
from .. import utils, odict, tcp
|
||||
from functools import reduce
|
||||
|
||||
|
||||
class Frame(object):
|
||||
|
||||
"""
|
||||
Baseclass Frame
|
||||
contains header
|
||||
@ -53,6 +56,7 @@ class Frame(object):
|
||||
def __eq__(self, other):
|
||||
return self.to_bytes() == other.to_bytes()
|
||||
|
||||
|
||||
class DataFrame(Frame):
|
||||
TYPE = 0x0
|
||||
VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED]
|
||||
@ -89,11 +93,13 @@ class DataFrame(Frame):
|
||||
|
||||
return b
|
||||
|
||||
|
||||
class HeadersFrame(Frame):
|
||||
TYPE = 0x1
|
||||
VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY]
|
||||
|
||||
def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', pad_length=0, exclusive=False, stream_dependency=0x0, weight=0):
|
||||
def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'',
|
||||
pad_length=0, exclusive=False, stream_dependency=0x0, weight=0):
|
||||
super(HeadersFrame, self).__init__(length, flags, stream_id)
|
||||
self.header_block_fragment = header_block_fragment
|
||||
self.pad_length = pad_length
|
||||
@ -137,6 +143,7 @@ class HeadersFrame(Frame):
|
||||
|
||||
return b
|
||||
|
||||
|
||||
class PriorityFrame(Frame):
|
||||
TYPE = 0x2
|
||||
VALID_FLAGS = []
|
||||
@ -166,6 +173,7 @@ class PriorityFrame(Frame):
|
||||
|
||||
return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight)
|
||||
|
||||
|
||||
class RstStreamFrame(Frame):
|
||||
TYPE = 0x3
|
||||
VALID_FLAGS = []
|
||||
@ -186,18 +194,19 @@ class RstStreamFrame(Frame):
|
||||
|
||||
return struct.pack('!L', self.error_code)
|
||||
|
||||
|
||||
class SettingsFrame(Frame):
|
||||
TYPE = 0x4
|
||||
VALID_FLAGS = [Frame.FLAG_ACK]
|
||||
|
||||
SETTINGS = utils.BiDi(
|
||||
SETTINGS_HEADER_TABLE_SIZE = 0x1,
|
||||
SETTINGS_ENABLE_PUSH = 0x2,
|
||||
SETTINGS_MAX_CONCURRENT_STREAMS = 0x3,
|
||||
SETTINGS_INITIAL_WINDOW_SIZE = 0x4,
|
||||
SETTINGS_MAX_FRAME_SIZE = 0x5,
|
||||
SETTINGS_MAX_HEADER_LIST_SIZE = 0x6,
|
||||
)
|
||||
SETTINGS_HEADER_TABLE_SIZE=0x1,
|
||||
SETTINGS_ENABLE_PUSH=0x2,
|
||||
SETTINGS_MAX_CONCURRENT_STREAMS=0x3,
|
||||
SETTINGS_INITIAL_WINDOW_SIZE=0x4,
|
||||
SETTINGS_MAX_FRAME_SIZE=0x5,
|
||||
SETTINGS_MAX_HEADER_LIST_SIZE=0x6,
|
||||
)
|
||||
|
||||
def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings={}):
|
||||
super(SettingsFrame, self).__init__(length, flags, stream_id)
|
||||
@ -208,7 +217,7 @@ class SettingsFrame(Frame):
|
||||
f = self(length=length, flags=flags, stream_id=stream_id)
|
||||
|
||||
for i in xrange(0, len(payload), 6):
|
||||
identifier, value = struct.unpack("!HL", payload[i:i+6])
|
||||
identifier, value = struct.unpack("!HL", payload[i:i + 6])
|
||||
f.settings[identifier] = value
|
||||
|
||||
return f
|
||||
@ -223,6 +232,7 @@ class SettingsFrame(Frame):
|
||||
|
||||
return b
|
||||
|
||||
|
||||
class PushPromiseFrame(Frame):
|
||||
TYPE = 0x5
|
||||
VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED]
|
||||
@ -267,6 +277,7 @@ class PushPromiseFrame(Frame):
|
||||
|
||||
return b
|
||||
|
||||
|
||||
class PingFrame(Frame):
|
||||
TYPE = 0x6
|
||||
VALID_FLAGS = [Frame.FLAG_ACK]
|
||||
@ -289,6 +300,7 @@ class PingFrame(Frame):
|
||||
b += b'\0' * (8 - len(b))
|
||||
return b
|
||||
|
||||
|
||||
class GoAwayFrame(Frame):
|
||||
TYPE = 0x7
|
||||
VALID_FLAGS = []
|
||||
@ -317,6 +329,7 @@ class GoAwayFrame(Frame):
|
||||
b += bytes(self.data)
|
||||
return b
|
||||
|
||||
|
||||
class WindowUpdateFrame(Frame):
|
||||
TYPE = 0x8
|
||||
VALID_FLAGS = []
|
||||
@ -335,11 +348,12 @@ class WindowUpdateFrame(Frame):
|
||||
return f
|
||||
|
||||
def payload_bytes(self):
|
||||
if self.window_size_increment <= 0 or self.window_size_increment >= 2**31:
|
||||
if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31:
|
||||
raise ValueError('Window Szie Increment MUST be greater than 0 and less than 2^31.')
|
||||
|
||||
return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF)
|
||||
|
||||
|
||||
class ContinuationFrame(Frame):
|
||||
TYPE = 0x9
|
||||
VALID_FLAGS = [Frame.FLAG_END_HEADERS]
|
||||
|
@ -8,18 +8,18 @@ import io
|
||||
CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'
|
||||
|
||||
ERROR_CODES = utils.BiDi(
|
||||
NO_ERROR = 0x0,
|
||||
PROTOCOL_ERROR = 0x1,
|
||||
INTERNAL_ERROR = 0x2,
|
||||
FLOW_CONTROL_ERROR = 0x3,
|
||||
SETTINGS_TIMEOUT = 0x4,
|
||||
STREAM_CLOSED = 0x5,
|
||||
FRAME_SIZE_ERROR = 0x6,
|
||||
REFUSED_STREAM = 0x7,
|
||||
CANCEL = 0x8,
|
||||
COMPRESSION_ERROR = 0x9,
|
||||
CONNECT_ERROR = 0xa,
|
||||
ENHANCE_YOUR_CALM = 0xb,
|
||||
INADEQUATE_SECURITY = 0xc,
|
||||
HTTP_1_1_REQUIRED = 0xd
|
||||
)
|
||||
NO_ERROR=0x0,
|
||||
PROTOCOL_ERROR=0x1,
|
||||
INTERNAL_ERROR=0x2,
|
||||
FLOW_CONTROL_ERROR=0x3,
|
||||
SETTINGS_TIMEOUT=0x4,
|
||||
STREAM_CLOSED=0x5,
|
||||
FRAME_SIZE_ERROR=0x6,
|
||||
REFUSED_STREAM=0x7,
|
||||
CANCEL=0x8,
|
||||
COMPRESSION_ERROR=0x9,
|
||||
CONNECT_ERROR=0xa,
|
||||
ENHANCE_YOUR_CALM=0xb,
|
||||
INADEQUATE_SECURITY=0xc,
|
||||
HTTP_1_1_REQUIRED=0xd
|
||||
)
|
||||
|
@ -8,6 +8,7 @@ from . import odict, utils, tcp, http_status
|
||||
|
||||
|
||||
class HttpError(Exception):
|
||||
|
||||
def __init__(self, code, message):
|
||||
super(HttpError, self).__init__(message)
|
||||
self.code = code
|
||||
@ -95,7 +96,7 @@ def read_headers(fp):
|
||||
"""
|
||||
ret = []
|
||||
name = ''
|
||||
while 1:
|
||||
while True:
|
||||
line = fp.readline()
|
||||
if not line or line == '\r\n' or line == '\n':
|
||||
break
|
||||
@ -337,7 +338,7 @@ def read_http_body_chunked(
|
||||
otherwise
|
||||
"""
|
||||
if max_chunk_size is None:
|
||||
max_chunk_size = limit or sys.maxint
|
||||
max_chunk_size = limit or sys.maxsize
|
||||
|
||||
expected_size = expected_http_body_size(
|
||||
headers, is_request, request_method, response_code
|
||||
@ -399,10 +400,10 @@ def expected_http_body_size(headers, is_request, request_method, response_code):
|
||||
request_method = request_method.upper()
|
||||
|
||||
if (not is_request and (
|
||||
request_method == "HEAD" or
|
||||
(request_method == "CONNECT" and response_code == 200) or
|
||||
response_code in [204, 304] or
|
||||
100 <= response_code <= 199)):
|
||||
request_method == "HEAD" or
|
||||
(request_method == "CONNECT" and response_code == 200) or
|
||||
response_code in [204, 304] or
|
||||
100 <= response_code <= 199)):
|
||||
return 0
|
||||
if has_chunked_encoding(headers):
|
||||
return None
|
||||
|
@ -4,9 +4,11 @@ from . import http
|
||||
|
||||
|
||||
class NullProxyAuth(object):
|
||||
|
||||
"""
|
||||
No proxy auth at all (returns empty challange headers)
|
||||
"""
|
||||
|
||||
def __init__(self, password_manager):
|
||||
self.password_manager = password_manager
|
||||
|
||||
@ -48,7 +50,7 @@ class BasicProxyAuth(NullProxyAuth):
|
||||
if not parts:
|
||||
return False
|
||||
scheme, username, password = parts
|
||||
if scheme.lower()!='basic':
|
||||
if scheme.lower() != 'basic':
|
||||
return False
|
||||
if not self.password_manager.test(username, password):
|
||||
return False
|
||||
@ -56,18 +58,21 @@ class BasicProxyAuth(NullProxyAuth):
|
||||
return True
|
||||
|
||||
def auth_challenge_headers(self):
|
||||
return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm}
|
||||
return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm}
|
||||
|
||||
|
||||
class PassMan(object):
|
||||
|
||||
def test(self, username, password_token):
|
||||
return False
|
||||
|
||||
|
||||
class PassManNonAnon(PassMan):
|
||||
|
||||
"""
|
||||
Ensure the user specifies a username, accept any password.
|
||||
"""
|
||||
|
||||
def test(self, username, password_token):
|
||||
if username:
|
||||
return True
|
||||
@ -75,9 +80,11 @@ class PassManNonAnon(PassMan):
|
||||
|
||||
|
||||
class PassManHtpasswd(PassMan):
|
||||
|
||||
"""
|
||||
Read usernames and passwords from an htpasswd file
|
||||
"""
|
||||
|
||||
def __init__(self, path):
|
||||
"""
|
||||
Raises ValueError if htpasswd file is invalid.
|
||||
@ -90,14 +97,16 @@ class PassManHtpasswd(PassMan):
|
||||
|
||||
|
||||
class PassManSingleUser(PassMan):
|
||||
|
||||
def __init__(self, username, password):
|
||||
self.username, self.password = username, password
|
||||
|
||||
def test(self, username, password_token):
|
||||
return self.username==username and self.password==password_token
|
||||
return self.username == username and self.password == password_token
|
||||
|
||||
|
||||
class AuthAction(Action):
|
||||
|
||||
"""
|
||||
Helper class to allow seamless integration int argparse. Example usage:
|
||||
parser.add_argument(
|
||||
@ -106,16 +115,18 @@ class AuthAction(Action):
|
||||
help="Allow access to any user long as a credentials are specified."
|
||||
)
|
||||
"""
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
passman = self.getPasswordManager(values)
|
||||
authenticator = BasicProxyAuth(passman, "mitmproxy")
|
||||
setattr(namespace, self.dest, authenticator)
|
||||
|
||||
def getPasswordManager(self, s): # pragma: nocover
|
||||
def getPasswordManager(self, s): # pragma: nocover
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SingleuserAuthAction(AuthAction):
|
||||
|
||||
def getPasswordManager(self, s):
|
||||
if len(s.split(':')) != 2:
|
||||
raise ArgumentTypeError(
|
||||
@ -126,11 +137,12 @@ class SingleuserAuthAction(AuthAction):
|
||||
|
||||
|
||||
class NonanonymousAuthAction(AuthAction):
|
||||
|
||||
def getPasswordManager(self, s):
|
||||
return PassManNonAnon()
|
||||
|
||||
|
||||
class HtpasswdAuthAction(AuthAction):
|
||||
|
||||
def getPasswordManager(self, s):
|
||||
return PassManHtpasswd(s)
|
||||
|
||||
|
@ -96,7 +96,7 @@ def _read_pairs(s, off=0, specials=()):
|
||||
specials: a lower-cased list of keys that may contain commas
|
||||
"""
|
||||
vals = []
|
||||
while 1:
|
||||
while True:
|
||||
lhs, off = _read_token(s, off)
|
||||
lhs = lhs.lstrip()
|
||||
if lhs:
|
||||
@ -135,15 +135,15 @@ def _format_pairs(lst, specials=(), sep="; "):
|
||||
else:
|
||||
if k.lower() not in specials and _has_special(v):
|
||||
v = ESCAPE.sub(r"\\\1", v)
|
||||
v = '"%s"'%v
|
||||
vals.append("%s=%s"%(k, v))
|
||||
v = '"%s"' % v
|
||||
vals.append("%s=%s" % (k, v))
|
||||
return sep.join(vals)
|
||||
|
||||
|
||||
def _format_set_cookie_pairs(lst):
|
||||
return _format_pairs(
|
||||
lst,
|
||||
specials = ("expires", "path")
|
||||
specials=("expires", "path")
|
||||
)
|
||||
|
||||
|
||||
@ -154,7 +154,7 @@ def _parse_set_cookie_pairs(s):
|
||||
"""
|
||||
pairs, off = _read_pairs(
|
||||
s,
|
||||
specials = ("expires", "path")
|
||||
specials=("expires", "path")
|
||||
)
|
||||
return pairs
|
||||
|
||||
|
@ -1,51 +1,51 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
|
||||
CONTINUE = 100
|
||||
SWITCHING = 101
|
||||
OK = 200
|
||||
CREATED = 201
|
||||
ACCEPTED = 202
|
||||
NON_AUTHORITATIVE_INFORMATION = 203
|
||||
NO_CONTENT = 204
|
||||
RESET_CONTENT = 205
|
||||
PARTIAL_CONTENT = 206
|
||||
MULTI_STATUS = 207
|
||||
CONTINUE = 100
|
||||
SWITCHING = 101
|
||||
OK = 200
|
||||
CREATED = 201
|
||||
ACCEPTED = 202
|
||||
NON_AUTHORITATIVE_INFORMATION = 203
|
||||
NO_CONTENT = 204
|
||||
RESET_CONTENT = 205
|
||||
PARTIAL_CONTENT = 206
|
||||
MULTI_STATUS = 207
|
||||
|
||||
MULTIPLE_CHOICE = 300
|
||||
MOVED_PERMANENTLY = 301
|
||||
FOUND = 302
|
||||
SEE_OTHER = 303
|
||||
NOT_MODIFIED = 304
|
||||
USE_PROXY = 305
|
||||
TEMPORARY_REDIRECT = 307
|
||||
MULTIPLE_CHOICE = 300
|
||||
MOVED_PERMANENTLY = 301
|
||||
FOUND = 302
|
||||
SEE_OTHER = 303
|
||||
NOT_MODIFIED = 304
|
||||
USE_PROXY = 305
|
||||
TEMPORARY_REDIRECT = 307
|
||||
|
||||
BAD_REQUEST = 400
|
||||
UNAUTHORIZED = 401
|
||||
PAYMENT_REQUIRED = 402
|
||||
FORBIDDEN = 403
|
||||
NOT_FOUND = 404
|
||||
NOT_ALLOWED = 405
|
||||
NOT_ACCEPTABLE = 406
|
||||
PROXY_AUTH_REQUIRED = 407
|
||||
REQUEST_TIMEOUT = 408
|
||||
CONFLICT = 409
|
||||
GONE = 410
|
||||
LENGTH_REQUIRED = 411
|
||||
PRECONDITION_FAILED = 412
|
||||
REQUEST_ENTITY_TOO_LARGE = 413
|
||||
REQUEST_URI_TOO_LONG = 414
|
||||
UNSUPPORTED_MEDIA_TYPE = 415
|
||||
BAD_REQUEST = 400
|
||||
UNAUTHORIZED = 401
|
||||
PAYMENT_REQUIRED = 402
|
||||
FORBIDDEN = 403
|
||||
NOT_FOUND = 404
|
||||
NOT_ALLOWED = 405
|
||||
NOT_ACCEPTABLE = 406
|
||||
PROXY_AUTH_REQUIRED = 407
|
||||
REQUEST_TIMEOUT = 408
|
||||
CONFLICT = 409
|
||||
GONE = 410
|
||||
LENGTH_REQUIRED = 411
|
||||
PRECONDITION_FAILED = 412
|
||||
REQUEST_ENTITY_TOO_LARGE = 413
|
||||
REQUEST_URI_TOO_LONG = 414
|
||||
UNSUPPORTED_MEDIA_TYPE = 415
|
||||
REQUESTED_RANGE_NOT_SATISFIABLE = 416
|
||||
EXPECTATION_FAILED = 417
|
||||
EXPECTATION_FAILED = 417
|
||||
|
||||
INTERNAL_SERVER_ERROR = 500
|
||||
NOT_IMPLEMENTED = 501
|
||||
BAD_GATEWAY = 502
|
||||
SERVICE_UNAVAILABLE = 503
|
||||
GATEWAY_TIMEOUT = 504
|
||||
HTTP_VERSION_NOT_SUPPORTED = 505
|
||||
INSUFFICIENT_STORAGE_SPACE = 507
|
||||
NOT_EXTENDED = 510
|
||||
INTERNAL_SERVER_ERROR = 500
|
||||
NOT_IMPLEMENTED = 501
|
||||
BAD_GATEWAY = 502
|
||||
SERVICE_UNAVAILABLE = 503
|
||||
GATEWAY_TIMEOUT = 504
|
||||
HTTP_VERSION_NOT_SUPPORTED = 505
|
||||
INSUFFICIENT_STORAGE_SPACE = 507
|
||||
NOT_EXTENDED = 510
|
||||
|
||||
RESPONSES = {
|
||||
# 100
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
import re, copy
|
||||
import re
|
||||
import copy
|
||||
|
||||
|
||||
def safe_subn(pattern, repl, target, *args, **kwargs):
|
||||
@ -12,10 +13,12 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
|
||||
|
||||
|
||||
class ODict(object):
|
||||
|
||||
"""
|
||||
A dictionary-like object for managing ordered (key, value) data. Think
|
||||
about it as a convenient interface to a list of (key, value) tuples.
|
||||
"""
|
||||
|
||||
def __init__(self, lst=None):
|
||||
self.lst = lst or []
|
||||
|
||||
@ -157,7 +160,7 @@ class ODict(object):
|
||||
"key: value"
|
||||
"""
|
||||
for k, v in self.lst:
|
||||
s = "%s: %s"%(k, v)
|
||||
s = "%s: %s" % (k, v)
|
||||
if re.search(expr, s):
|
||||
return True
|
||||
return False
|
||||
@ -192,11 +195,12 @@ class ODict(object):
|
||||
return klass([list(i) for i in state])
|
||||
|
||||
|
||||
|
||||
class ODictCaseless(ODict):
|
||||
|
||||
"""
|
||||
A variant of ODict with "caseless" keys. This version _preserves_ key
|
||||
case, but does not consider case when setting or getting items.
|
||||
"""
|
||||
|
||||
def _kconv(self, s):
|
||||
return s.lower()
|
||||
|
@ -6,49 +6,50 @@ from . import tcp, utils
|
||||
|
||||
|
||||
class SocksError(Exception):
|
||||
|
||||
def __init__(self, code, message):
|
||||
super(SocksError, self).__init__(message)
|
||||
self.code = code
|
||||
|
||||
|
||||
VERSION = utils.BiDi(
|
||||
SOCKS4 = 0x04,
|
||||
SOCKS5 = 0x05
|
||||
SOCKS4=0x04,
|
||||
SOCKS5=0x05
|
||||
)
|
||||
|
||||
|
||||
CMD = utils.BiDi(
|
||||
CONNECT = 0x01,
|
||||
BIND = 0x02,
|
||||
UDP_ASSOCIATE = 0x03
|
||||
CONNECT=0x01,
|
||||
BIND=0x02,
|
||||
UDP_ASSOCIATE=0x03
|
||||
)
|
||||
|
||||
|
||||
ATYP = utils.BiDi(
|
||||
IPV4_ADDRESS = 0x01,
|
||||
DOMAINNAME = 0x03,
|
||||
IPV6_ADDRESS = 0x04
|
||||
IPV4_ADDRESS=0x01,
|
||||
DOMAINNAME=0x03,
|
||||
IPV6_ADDRESS=0x04
|
||||
)
|
||||
|
||||
|
||||
REP = utils.BiDi(
|
||||
SUCCEEDED = 0x00,
|
||||
GENERAL_SOCKS_SERVER_FAILURE = 0x01,
|
||||
CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02,
|
||||
NETWORK_UNREACHABLE = 0x03,
|
||||
HOST_UNREACHABLE = 0x04,
|
||||
CONNECTION_REFUSED = 0x05,
|
||||
TTL_EXPIRED = 0x06,
|
||||
COMMAND_NOT_SUPPORTED = 0x07,
|
||||
ADDRESS_TYPE_NOT_SUPPORTED = 0x08,
|
||||
SUCCEEDED=0x00,
|
||||
GENERAL_SOCKS_SERVER_FAILURE=0x01,
|
||||
CONNECTION_NOT_ALLOWED_BY_RULESET=0x02,
|
||||
NETWORK_UNREACHABLE=0x03,
|
||||
HOST_UNREACHABLE=0x04,
|
||||
CONNECTION_REFUSED=0x05,
|
||||
TTL_EXPIRED=0x06,
|
||||
COMMAND_NOT_SUPPORTED=0x07,
|
||||
ADDRESS_TYPE_NOT_SUPPORTED=0x08,
|
||||
)
|
||||
|
||||
|
||||
METHOD = utils.BiDi(
|
||||
NO_AUTHENTICATION_REQUIRED = 0x00,
|
||||
GSSAPI = 0x01,
|
||||
USERNAME_PASSWORD = 0x02,
|
||||
NO_ACCEPTABLE_METHODS = 0xFF
|
||||
NO_AUTHENTICATION_REQUIRED=0x00,
|
||||
GSSAPI=0x01,
|
||||
USERNAME_PASSWORD=0x02,
|
||||
NO_ACCEPTABLE_METHODS=0xFF
|
||||
)
|
||||
|
||||
|
||||
|
@ -22,14 +22,28 @@ OP_NO_SSLv2 = SSL.OP_NO_SSLv2
|
||||
OP_NO_SSLv3 = SSL.OP_NO_SSLv3
|
||||
|
||||
|
||||
class NetLibError(Exception): pass
|
||||
class NetLibDisconnect(NetLibError): pass
|
||||
class NetLibIncomplete(NetLibError): pass
|
||||
class NetLibTimeout(NetLibError): pass
|
||||
class NetLibSSLError(NetLibError): pass
|
||||
class NetLibError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NetLibDisconnect(NetLibError):
|
||||
pass
|
||||
|
||||
|
||||
class NetLibIncomplete(NetLibError):
|
||||
pass
|
||||
|
||||
|
||||
class NetLibTimeout(NetLibError):
|
||||
pass
|
||||
|
||||
|
||||
class NetLibSSLError(NetLibError):
|
||||
pass
|
||||
|
||||
|
||||
class SSLKeyLogger(object):
|
||||
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.f = None
|
||||
@ -67,6 +81,7 @@ log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or
|
||||
|
||||
class _FileLike(object):
|
||||
BLOCKSIZE = 1024 * 32
|
||||
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
self._log = None
|
||||
@ -112,6 +127,7 @@ class _FileLike(object):
|
||||
|
||||
|
||||
class Writer(_FileLike):
|
||||
|
||||
def flush(self):
|
||||
"""
|
||||
May raise NetLibDisconnect
|
||||
@ -119,7 +135,7 @@ class Writer(_FileLike):
|
||||
if hasattr(self.o, "flush"):
|
||||
try:
|
||||
self.o.flush()
|
||||
except (socket.error, IOError), v:
|
||||
except (socket.error, IOError) as v:
|
||||
raise NetLibDisconnect(str(v))
|
||||
|
||||
def write(self, v):
|
||||
@ -135,11 +151,12 @@ class Writer(_FileLike):
|
||||
r = self.o.write(v)
|
||||
self.add_log(v[:r])
|
||||
return r
|
||||
except (SSL.Error, socket.error) as e:
|
||||
except (SSL.Error, socket.error) as e:
|
||||
raise NetLibDisconnect(str(e))
|
||||
|
||||
|
||||
class Reader(_FileLike):
|
||||
|
||||
def read(self, length):
|
||||
"""
|
||||
If length is -1, we read until connection closes.
|
||||
@ -180,7 +197,7 @@ class Reader(_FileLike):
|
||||
self.add_log(result)
|
||||
return result
|
||||
|
||||
def readline(self, size = None):
|
||||
def readline(self, size=None):
|
||||
result = ''
|
||||
bytes_read = 0
|
||||
while True:
|
||||
@ -204,16 +221,18 @@ class Reader(_FileLike):
|
||||
result = self.read(length)
|
||||
if length != -1 and len(result) != length:
|
||||
raise NetLibIncomplete(
|
||||
"Expected %s bytes, got %s"%(length, len(result))
|
||||
"Expected %s bytes, got %s" % (length, len(result))
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class Address(object):
|
||||
|
||||
"""
|
||||
This class wraps an IPv4/IPv6 tuple to provide named attributes and
|
||||
ipv6 information.
|
||||
"""
|
||||
|
||||
def __init__(self, address, use_ipv6=False):
|
||||
self.address = tuple(address)
|
||||
self.use_ipv6 = use_ipv6
|
||||
@ -304,6 +323,7 @@ def close_socket(sock):
|
||||
|
||||
|
||||
class _Connection(object):
|
||||
|
||||
def get_current_cipher(self):
|
||||
if not self.ssl_established:
|
||||
return None
|
||||
@ -319,7 +339,7 @@ class _Connection(object):
|
||||
# (We call _FileLike.set_descriptor(conn))
|
||||
# Closing the socket is not our task, therefore we don't call close
|
||||
# then.
|
||||
if type(self.connection) != SSL.Connection:
|
||||
if not isinstance(self.connection, SSL.Connection):
|
||||
if not getattr(self.wfile, "closed", False):
|
||||
try:
|
||||
self.wfile.flush()
|
||||
@ -337,6 +357,7 @@ class _Connection(object):
|
||||
"""
|
||||
Creates an SSL Context.
|
||||
"""
|
||||
|
||||
def _create_ssl_context(self,
|
||||
method=SSLv23_METHOD,
|
||||
options=(OP_NO_SSLv2 | OP_NO_SSLv3),
|
||||
@ -362,8 +383,8 @@ class _Connection(object):
|
||||
if cipher_list:
|
||||
try:
|
||||
context.set_cipher_list(cipher_list)
|
||||
except SSL.Error, v:
|
||||
raise NetLibError("SSL cipher specification error: %s"%str(v))
|
||||
except SSL.Error as v:
|
||||
raise NetLibError("SSL cipher specification error: %s" % str(v))
|
||||
|
||||
# SSLKEYLOGFILE
|
||||
if log_ssl_key:
|
||||
@ -380,7 +401,7 @@ class TCPClient(_Connection):
|
||||
# Make sure to close the real socket, not the SSL proxy.
|
||||
# OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
|
||||
# it tries to renegotiate...
|
||||
if type(self.connection) == SSL.Connection:
|
||||
if isinstance(self.connection, SSL.Connection):
|
||||
close_socket(self.connection._socket)
|
||||
else:
|
||||
close_socket(self.connection)
|
||||
@ -400,8 +421,8 @@ class TCPClient(_Connection):
|
||||
try:
|
||||
context.use_privatekey_file(cert)
|
||||
context.use_certificate_file(cert)
|
||||
except SSL.Error, v:
|
||||
raise NetLibError("SSL client certificate error: %s"%str(v))
|
||||
except SSL.Error as v:
|
||||
raise NetLibError("SSL client certificate error: %s" % str(v))
|
||||
return context
|
||||
|
||||
def convert_to_ssl(self, sni=None, **sslctx_kwargs):
|
||||
@ -418,8 +439,8 @@ class TCPClient(_Connection):
|
||||
self.connection.set_connect_state()
|
||||
try:
|
||||
self.connection.do_handshake()
|
||||
except SSL.Error, v:
|
||||
raise NetLibError("SSL handshake error: %s"%repr(v))
|
||||
except SSL.Error as v:
|
||||
raise NetLibError("SSL handshake error: %s" % repr(v))
|
||||
self.ssl_established = True
|
||||
self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
|
||||
self.rfile.set_descriptor(self.connection)
|
||||
@ -435,7 +456,7 @@ class TCPClient(_Connection):
|
||||
self.source_address = Address(connection.getsockname())
|
||||
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
|
||||
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
|
||||
except (socket.error, IOError), err:
|
||||
except (socket.error, IOError) as err:
|
||||
raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err))
|
||||
self.connection = connection
|
||||
|
||||
@ -447,6 +468,7 @@ class TCPClient(_Connection):
|
||||
|
||||
|
||||
class BaseHandler(_Connection):
|
||||
|
||||
"""
|
||||
The instantiator is expected to call the handle() and finish() methods.
|
||||
|
||||
@ -531,8 +553,8 @@ class BaseHandler(_Connection):
|
||||
self.connection.set_accept_state()
|
||||
try:
|
||||
self.connection.do_handshake()
|
||||
except SSL.Error, v:
|
||||
raise NetLibError("SSL handshake error: %s"%repr(v))
|
||||
except SSL.Error as v:
|
||||
raise NetLibError("SSL handshake error: %s" % repr(v))
|
||||
self.ssl_established = True
|
||||
self.rfile.set_descriptor(self.connection)
|
||||
self.wfile.set_descriptor(self.connection)
|
||||
|
@ -1,9 +1,13 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
import threading, Queue, cStringIO
|
||||
import threading
|
||||
import Queue
|
||||
import cStringIO
|
||||
import OpenSSL
|
||||
from . import tcp, certutils
|
||||
|
||||
|
||||
class ServerThread(threading.Thread):
|
||||
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
threading.Thread.__init__(self)
|
||||
@ -19,6 +23,7 @@ class ServerTestBase(object):
|
||||
ssl = None
|
||||
handler = None
|
||||
addr = ("localhost", 0)
|
||||
|
||||
@classmethod
|
||||
def setupAll(cls):
|
||||
cls.q = Queue.Queue()
|
||||
@ -41,10 +46,11 @@ class ServerTestBase(object):
|
||||
|
||||
|
||||
class TServer(tcp.TCPServer):
|
||||
|
||||
def __init__(self, ssl, q, handler_klass, addr):
|
||||
"""
|
||||
ssl: A dictionary of SSL parameters:
|
||||
|
||||
|
||||
cert, key, request_client_cert, cipher_list,
|
||||
dhparams, v3_only
|
||||
"""
|
||||
@ -70,13 +76,13 @@ class TServer(tcp.TCPServer):
|
||||
options = None
|
||||
h.convert_to_ssl(
|
||||
cert, key,
|
||||
method = method,
|
||||
options = options,
|
||||
handle_sni = getattr(h, "handle_sni", None),
|
||||
request_client_cert = self.ssl["request_client_cert"],
|
||||
cipher_list = self.ssl.get("cipher_list", None),
|
||||
dhparams = self.ssl.get("dhparams", None),
|
||||
chain_file = self.ssl.get("chain_file", None)
|
||||
method=method,
|
||||
options=options,
|
||||
handle_sni=getattr(h, "handle_sni", None),
|
||||
request_client_cert=self.ssl["request_client_cert"],
|
||||
cipher_list=self.ssl.get("cipher_list", None),
|
||||
dhparams=self.ssl.get("dhparams", None),
|
||||
chain_file=self.ssl.get("chain_file", None)
|
||||
)
|
||||
h.handle()
|
||||
h.finish()
|
||||
|
@ -68,6 +68,7 @@ def getbit(byte, offset):
|
||||
|
||||
|
||||
class BiDi:
|
||||
|
||||
"""
|
||||
A wee utility class for keeping bi-directional mappings, like field
|
||||
constants in protocols. Names are attributes on the object, dict-like
|
||||
@ -77,6 +78,7 @@ class BiDi:
|
||||
assert CONST.a == 1
|
||||
assert CONST.get_name(1) == "a"
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.names = kwargs
|
||||
self.values = {}
|
||||
@ -96,15 +98,15 @@ class BiDi:
|
||||
|
||||
def pretty_size(size):
|
||||
suffixes = [
|
||||
("B", 2**10),
|
||||
("kB", 2**20),
|
||||
("MB", 2**30),
|
||||
("B", 2 ** 10),
|
||||
("kB", 2 ** 20),
|
||||
("MB", 2 ** 30),
|
||||
]
|
||||
for suf, lim in suffixes:
|
||||
if size >= lim:
|
||||
continue
|
||||
else:
|
||||
x = round(size/float(lim/2**10), 2)
|
||||
x = round(size / float(lim / 2 ** 10), 2)
|
||||
if x == int(x):
|
||||
x = int(x)
|
||||
return str(x) + suf
|
||||
|
@ -26,16 +26,17 @@ MAX_64_BIT_INT = (1 << 64)
|
||||
|
||||
|
||||
OPCODE = utils.BiDi(
|
||||
CONTINUE = 0x00,
|
||||
TEXT = 0x01,
|
||||
BINARY = 0x02,
|
||||
CLOSE = 0x08,
|
||||
PING = 0x09,
|
||||
PONG = 0x0a
|
||||
CONTINUE=0x00,
|
||||
TEXT=0x01,
|
||||
BINARY=0x02,
|
||||
CLOSE=0x08,
|
||||
PING=0x09,
|
||||
PONG=0x0a
|
||||
)
|
||||
|
||||
|
||||
class Masker:
|
||||
|
||||
"""
|
||||
Data sent from the server must be masked to prevent malicious clients
|
||||
from sending data over the wire in predictable patterns
|
||||
@ -43,6 +44,7 @@ class Masker:
|
||||
Servers do not have to mask data they send to the client.
|
||||
https://tools.ietf.org/html/rfc6455#section-5.3
|
||||
"""
|
||||
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
self.masks = [utils.bytes_to_int(byte) for byte in key]
|
||||
@ -128,17 +130,18 @@ DEFAULT = object()
|
||||
|
||||
|
||||
class FrameHeader:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
opcode = OPCODE.TEXT,
|
||||
payload_length = 0,
|
||||
fin = False,
|
||||
rsv1 = False,
|
||||
rsv2 = False,
|
||||
rsv3 = False,
|
||||
masking_key = DEFAULT,
|
||||
mask = DEFAULT,
|
||||
length_code = DEFAULT
|
||||
opcode=OPCODE.TEXT,
|
||||
payload_length=0,
|
||||
fin=False,
|
||||
rsv1=False,
|
||||
rsv2=False,
|
||||
rsv3=False,
|
||||
masking_key=DEFAULT,
|
||||
mask=DEFAULT,
|
||||
length_code=DEFAULT
|
||||
):
|
||||
if not 0 <= opcode < 2 ** 4:
|
||||
raise ValueError("opcode must be 0-16")
|
||||
@ -182,9 +185,9 @@ class FrameHeader:
|
||||
if flags:
|
||||
vals.extend([":", "|".join(flags)])
|
||||
if self.masking_key:
|
||||
vals.append(":key=%s"%repr(self.masking_key))
|
||||
vals.append(":key=%s" % repr(self.masking_key))
|
||||
if self.payload_length:
|
||||
vals.append(" %s"%utils.pretty_size(self.payload_length))
|
||||
vals.append(" %s" % utils.pretty_size(self.payload_length))
|
||||
return "".join(vals)
|
||||
|
||||
def to_bytes(self):
|
||||
@ -246,15 +249,15 @@ class FrameHeader:
|
||||
masking_key = None
|
||||
|
||||
return klass(
|
||||
fin = fin,
|
||||
rsv1 = rsv1,
|
||||
rsv2 = rsv2,
|
||||
rsv3 = rsv3,
|
||||
opcode = opcode,
|
||||
mask = mask_bit,
|
||||
length_code = length_code,
|
||||
payload_length = payload_length,
|
||||
masking_key = masking_key,
|
||||
fin=fin,
|
||||
rsv1=rsv1,
|
||||
rsv2=rsv2,
|
||||
rsv3=rsv3,
|
||||
opcode=opcode,
|
||||
mask=mask_bit,
|
||||
length_code=length_code,
|
||||
payload_length=payload_length,
|
||||
masking_key=masking_key,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
@ -262,6 +265,7 @@ class FrameHeader:
|
||||
|
||||
|
||||
class Frame(object):
|
||||
|
||||
"""
|
||||
Represents one websockets frame.
|
||||
Constructor takes human readable forms of the frame components
|
||||
@ -287,13 +291,14 @@ class Frame(object):
|
||||
| Payload Data continued ... |
|
||||
+---------------------------------------------------------------+
|
||||
"""
|
||||
def __init__(self, payload = "", **kwargs):
|
||||
|
||||
def __init__(self, payload="", **kwargs):
|
||||
self.payload = payload
|
||||
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
|
||||
self.header = FrameHeader(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def default(cls, message, from_client = False):
|
||||
def default(cls, message, from_client=False):
|
||||
"""
|
||||
Construct a basic websocket frame from some default values.
|
||||
Creates a non-fragmented text frame.
|
||||
@ -307,10 +312,10 @@ class Frame(object):
|
||||
|
||||
return cls(
|
||||
message,
|
||||
fin = 1, # final frame
|
||||
opcode = OPCODE.TEXT, # text
|
||||
mask = mask_bit,
|
||||
masking_key = masking_key,
|
||||
fin=1, # final frame
|
||||
opcode=OPCODE.TEXT, # text
|
||||
mask=mask_bit,
|
||||
masking_key=masking_key,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -356,15 +361,15 @@ class Frame(object):
|
||||
|
||||
return cls(
|
||||
payload,
|
||||
fin = header.fin,
|
||||
opcode = header.opcode,
|
||||
mask = header.mask,
|
||||
payload_length = header.payload_length,
|
||||
masking_key = header.masking_key,
|
||||
rsv1 = header.rsv1,
|
||||
rsv2 = header.rsv2,
|
||||
rsv3 = header.rsv3,
|
||||
length_code = header.length_code
|
||||
fin=header.fin,
|
||||
opcode=header.opcode,
|
||||
mask=header.mask,
|
||||
payload_length=header.payload_length,
|
||||
masking_key=header.masking_key,
|
||||
rsv1=header.rsv1,
|
||||
rsv2=header.rsv2,
|
||||
rsv3=header.rsv3,
|
||||
length_code=header.length_code
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
|
@ -7,17 +7,20 @@ from . import odict, tcp
|
||||
|
||||
|
||||
class ClientConn(object):
|
||||
|
||||
def __init__(self, address):
|
||||
self.address = tcp.Address.wrap(address)
|
||||
|
||||
|
||||
class Flow(object):
|
||||
|
||||
def __init__(self, address, request):
|
||||
self.client_conn = ClientConn(address)
|
||||
self.request = request
|
||||
|
||||
|
||||
class Request(object):
|
||||
|
||||
def __init__(self, scheme, method, path, headers, content):
|
||||
self.scheme, self.method, self.path = scheme, method, path
|
||||
self.headers, self.content = headers, content
|
||||
@ -42,6 +45,7 @@ def date_time_string():
|
||||
|
||||
|
||||
class WSGIAdaptor(object):
|
||||
|
||||
def __init__(self, app, domain, port, sversion):
|
||||
self.app, self.domain, self.port, self.sversion = app, domain, port, sversion
|
||||
|
||||
@ -52,24 +56,24 @@ class WSGIAdaptor(object):
|
||||
path_info = flow.request.path
|
||||
query = ''
|
||||
environ = {
|
||||
'wsgi.version': (1, 0),
|
||||
'wsgi.url_scheme': flow.request.scheme,
|
||||
'wsgi.input': cStringIO.StringIO(flow.request.content),
|
||||
'wsgi.errors': errsoc,
|
||||
'wsgi.multithread': True,
|
||||
'wsgi.multiprocess': False,
|
||||
'wsgi.run_once': False,
|
||||
'SERVER_SOFTWARE': self.sversion,
|
||||
'REQUEST_METHOD': flow.request.method,
|
||||
'SCRIPT_NAME': '',
|
||||
'PATH_INFO': urllib.unquote(path_info),
|
||||
'QUERY_STRING': query,
|
||||
'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0],
|
||||
'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0],
|
||||
'SERVER_NAME': self.domain,
|
||||
'SERVER_PORT': str(self.port),
|
||||
'wsgi.version': (1, 0),
|
||||
'wsgi.url_scheme': flow.request.scheme,
|
||||
'wsgi.input': cStringIO.StringIO(flow.request.content),
|
||||
'wsgi.errors': errsoc,
|
||||
'wsgi.multithread': True,
|
||||
'wsgi.multiprocess': False,
|
||||
'wsgi.run_once': False,
|
||||
'SERVER_SOFTWARE': self.sversion,
|
||||
'REQUEST_METHOD': flow.request.method,
|
||||
'SCRIPT_NAME': '',
|
||||
'PATH_INFO': urllib.unquote(path_info),
|
||||
'QUERY_STRING': query,
|
||||
'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0],
|
||||
'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0],
|
||||
'SERVER_NAME': self.domain,
|
||||
'SERVER_PORT': str(self.port),
|
||||
# FIXME: We need to pick up the protocol read from the request.
|
||||
'SERVER_PROTOCOL': "HTTP/1.1",
|
||||
'SERVER_PROTOCOL': "HTTP/1.1",
|
||||
}
|
||||
environ.update(extra)
|
||||
if flow.client_conn.address:
|
||||
@ -91,25 +95,25 @@ class WSGIAdaptor(object):
|
||||
<h1>Internal Server Error</h1>
|
||||
<pre>%s"</pre>
|
||||
</html>
|
||||
"""%s
|
||||
""" % s
|
||||
if not headers_sent:
|
||||
soc.write("HTTP/1.1 500 Internal Server Error\r\n")
|
||||
soc.write("Content-Type: text/html\r\n")
|
||||
soc.write("Content-Length: %s\r\n"%len(c))
|
||||
soc.write("Content-Length: %s\r\n" % len(c))
|
||||
soc.write("\r\n")
|
||||
soc.write(c)
|
||||
|
||||
def serve(self, request, soc, **env):
|
||||
state = dict(
|
||||
response_started = False,
|
||||
headers_sent = False,
|
||||
status = None,
|
||||
headers = None
|
||||
response_started=False,
|
||||
headers_sent=False,
|
||||
status=None,
|
||||
headers=None
|
||||
)
|
||||
|
||||
def write(data):
|
||||
if not state["headers_sent"]:
|
||||
soc.write("HTTP/1.1 %s\r\n"%state["status"])
|
||||
soc.write("HTTP/1.1 %s\r\n" % state["status"])
|
||||
h = state["headers"]
|
||||
if 'server' not in h:
|
||||
h["Server"] = [self.sversion]
|
||||
|
7
setup.cfg
Normal file
7
setup.cfg
Normal file
@ -0,0 +1,7 @@
|
||||
[flake8]
|
||||
max-line-length = 160
|
||||
max-complexity = 15
|
||||
|
||||
[pep8]
|
||||
max-line-length = 160
|
||||
max-complexity = 15
|
@ -4,20 +4,22 @@ import tutils
|
||||
from nose.tools import assert_equal
|
||||
|
||||
|
||||
|
||||
# TODO test stream association if valid or not
|
||||
|
||||
def test_invalid_flags():
|
||||
tutils.raises(ValueError, DataFrame, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar')
|
||||
|
||||
|
||||
def test_frame_equality():
|
||||
a = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar')
|
||||
b = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar')
|
||||
assert_equal(a, b)
|
||||
|
||||
|
||||
def test_too_large_frames():
|
||||
DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567)
|
||||
|
||||
|
||||
def test_data_frame_to_bytes():
|
||||
f = DataFrame(6, Frame.FLAG_END_STREAM, 0x1234567, 'foobar')
|
||||
assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172')
|
||||
@ -28,6 +30,7 @@ def test_data_frame_to_bytes():
|
||||
f = DataFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar')
|
||||
tutils.raises(ValueError, f.to_bytes)
|
||||
|
||||
|
||||
def test_data_frame_from_bytes():
|
||||
f = Frame.from_bytes('000006000101234567666f6f626172'.decode('hex'))
|
||||
assert isinstance(f, DataFrame)
|
||||
@ -45,6 +48,7 @@ def test_data_frame_from_bytes():
|
||||
assert_equal(f.stream_id, 0x1234567)
|
||||
assert_equal(f.payload, 'foobar')
|
||||
|
||||
|
||||
def test_headers_frame_to_bytes():
|
||||
f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x1234567, 'foobar')
|
||||
assert_equal(f.to_bytes().encode('hex'), '000006010001234567666f6f626172')
|
||||
@ -55,15 +59,18 @@ def test_headers_frame_to_bytes():
|
||||
f = HeadersFrame(10, HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', exclusive=True, stream_dependency=0x7654321, weight=42)
|
||||
assert_equal(f.to_bytes().encode('hex'), '00000b012001234567876543212a666f6f626172')
|
||||
|
||||
f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', pad_length=3, exclusive=True, stream_dependency=0x7654321, weight=42)
|
||||
f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567,
|
||||
'foobar', pad_length=3, exclusive=True, stream_dependency=0x7654321, weight=42)
|
||||
assert_equal(f.to_bytes().encode('hex'), '00000f01280123456703876543212a666f6f626172000000')
|
||||
|
||||
f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar', pad_length=3, exclusive=False, stream_dependency=0x7654321, weight=42)
|
||||
f = HeadersFrame(14, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY, 0x1234567, 'foobar',
|
||||
pad_length=3, exclusive=False, stream_dependency=0x7654321, weight=42)
|
||||
assert_equal(f.to_bytes().encode('hex'), '00000f01280123456703076543212a666f6f626172000000')
|
||||
|
||||
f = HeadersFrame(6, Frame.FLAG_NO_FLAGS, 0x0, 'foobar')
|
||||
tutils.raises(ValueError, f.to_bytes)
|
||||
|
||||
|
||||
def test_headers_frame_from_bytes():
|
||||
f = Frame.from_bytes('000006010001234567666f6f626172'.decode('hex'))
|
||||
assert isinstance(f, HeadersFrame)
|
||||
@ -114,6 +121,7 @@ def test_headers_frame_from_bytes():
|
||||
assert_equal(f.stream_dependency, 0x7654321)
|
||||
assert_equal(f.weight, 42)
|
||||
|
||||
|
||||
def test_priority_frame_to_bytes():
|
||||
f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, exclusive=True, stream_dependency=0x7654321, weight=42)
|
||||
assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a')
|
||||
@ -127,6 +135,7 @@ def test_priority_frame_to_bytes():
|
||||
f = PriorityFrame(5, Frame.FLAG_NO_FLAGS, 0x1234567, stream_dependency=0x0)
|
||||
tutils.raises(ValueError, f.to_bytes)
|
||||
|
||||
|
||||
def test_priority_frame_from_bytes():
|
||||
f = Frame.from_bytes('000005020001234567876543212a'.decode('hex'))
|
||||
assert isinstance(f, PriorityFrame)
|
||||
@ -148,6 +157,7 @@ def test_priority_frame_from_bytes():
|
||||
assert_equal(f.stream_dependency, 0x7654321)
|
||||
assert_equal(f.weight, 21)
|
||||
|
||||
|
||||
def test_rst_stream_frame_to_bytes():
|
||||
f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, error_code=0x7654321)
|
||||
assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321')
|
||||
@ -155,6 +165,7 @@ def test_rst_stream_frame_to_bytes():
|
||||
f = RstStreamFrame(4, Frame.FLAG_NO_FLAGS, 0x0)
|
||||
tutils.raises(ValueError, f.to_bytes)
|
||||
|
||||
|
||||
def test_rst_stream_frame_from_bytes():
|
||||
f = Frame.from_bytes('00000403000123456707654321'.decode('hex'))
|
||||
assert isinstance(f, RstStreamFrame)
|
||||
@ -164,6 +175,7 @@ def test_rst_stream_frame_from_bytes():
|
||||
assert_equal(f.stream_id, 0x1234567)
|
||||
assert_equal(f.error_code, 0x07654321)
|
||||
|
||||
|
||||
def test_settings_frame_to_bytes():
|
||||
f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x0)
|
||||
assert_equal(f.to_bytes().encode('hex'), '000000040000000000')
|
||||
@ -174,12 +186,14 @@ def test_settings_frame_to_bytes():
|
||||
f = SettingsFrame(6, SettingsFrame.FLAG_ACK, 0x0, settings={SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1})
|
||||
assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001')
|
||||
|
||||
f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678})
|
||||
f = SettingsFrame(12, Frame.FLAG_NO_FLAGS, 0x0, settings={
|
||||
SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678})
|
||||
assert_equal(f.to_bytes().encode('hex'), '00000c040000000000000200000001000312345678')
|
||||
|
||||
f = SettingsFrame(0, Frame.FLAG_NO_FLAGS, 0x1234567)
|
||||
tutils.raises(ValueError, f.to_bytes)
|
||||
|
||||
|
||||
def test_settings_frame_from_bytes():
|
||||
f = Frame.from_bytes('000000040000000000'.decode('hex'))
|
||||
assert isinstance(f, SettingsFrame)
|
||||
@ -214,6 +228,7 @@ def test_settings_frame_from_bytes():
|
||||
assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1)
|
||||
assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], 0x12345678)
|
||||
|
||||
|
||||
def test_push_promise_frame_to_bytes():
|
||||
f = PushPromiseFrame(10, Frame.FLAG_NO_FLAGS, 0x1234567, 0x7654321, 'foobar')
|
||||
assert_equal(f.to_bytes().encode('hex'), '00000a05000123456707654321666f6f626172')
|
||||
@ -227,6 +242,7 @@ def test_push_promise_frame_to_bytes():
|
||||
f = PushPromiseFrame(4, Frame.FLAG_NO_FLAGS, 0x1234567, 0x0)
|
||||
tutils.raises(ValueError, f.to_bytes)
|
||||
|
||||
|
||||
def test_push_promise_frame_from_bytes():
|
||||
f = Frame.from_bytes('00000a05000123456707654321666f6f626172'.decode('hex'))
|
||||
assert isinstance(f, PushPromiseFrame)
|
||||
@ -244,6 +260,7 @@ def test_push_promise_frame_from_bytes():
|
||||
assert_equal(f.stream_id, 0x1234567)
|
||||
assert_equal(f.header_block_fragment, 'foobar')
|
||||
|
||||
|
||||
def test_ping_frame_to_bytes():
|
||||
f = PingFrame(8, PingFrame.FLAG_ACK, 0x0, payload=b'foobar')
|
||||
assert_equal(f.to_bytes().encode('hex'), '000008060100000000666f6f6261720000')
|
||||
@ -254,6 +271,7 @@ def test_ping_frame_to_bytes():
|
||||
f = PingFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567)
|
||||
tutils.raises(ValueError, f.to_bytes)
|
||||
|
||||
|
||||
def test_ping_frame_from_bytes():
|
||||
f = Frame.from_bytes('000008060100000000666f6f6261720000'.decode('hex'))
|
||||
assert isinstance(f, PingFrame)
|
||||
@ -271,6 +289,7 @@ def test_ping_frame_from_bytes():
|
||||
assert_equal(f.stream_id, 0x0)
|
||||
assert_equal(f.payload, b'foobarde')
|
||||
|
||||
|
||||
def test_goaway_frame_to_bytes():
|
||||
f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x0, last_stream=0x1234567, error_code=0x87654321, data=b'')
|
||||
assert_equal(f.to_bytes().encode('hex'), '0000080700000000000123456787654321')
|
||||
@ -281,6 +300,7 @@ def test_goaway_frame_to_bytes():
|
||||
f = GoAwayFrame(8, Frame.FLAG_NO_FLAGS, 0x1234567, last_stream=0x1234567, error_code=0x87654321)
|
||||
tutils.raises(ValueError, f.to_bytes)
|
||||
|
||||
|
||||
def test_goaway_frame_from_bytes():
|
||||
f = Frame.from_bytes('0000080700000000000123456787654321'.decode('hex'))
|
||||
assert isinstance(f, GoAwayFrame)
|
||||
@ -302,6 +322,7 @@ def test_goaway_frame_from_bytes():
|
||||
assert_equal(f.error_code, 0x87654321)
|
||||
assert_equal(f.data, b'foobar')
|
||||
|
||||
|
||||
def test_window_update_frame_to_bytes():
|
||||
f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0x1234567)
|
||||
assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567')
|
||||
@ -315,6 +336,7 @@ def test_window_update_frame_to_bytes():
|
||||
f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0)
|
||||
tutils.raises(ValueError, f.to_bytes)
|
||||
|
||||
|
||||
def test_window_update_frame_from_bytes():
|
||||
f = Frame.from_bytes('00000408000000000001234567'.decode('hex'))
|
||||
assert isinstance(f, WindowUpdateFrame)
|
||||
@ -324,6 +346,7 @@ def test_window_update_frame_from_bytes():
|
||||
assert_equal(f.stream_id, 0x0)
|
||||
assert_equal(f.window_size_increment, 0x1234567)
|
||||
|
||||
|
||||
def test_continuation_frame_to_bytes():
|
||||
f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x1234567, 'foobar')
|
||||
assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172')
|
||||
@ -331,6 +354,7 @@ def test_continuation_frame_to_bytes():
|
||||
f = ContinuationFrame(6, ContinuationFrame.FLAG_END_HEADERS, 0x0, 'foobar')
|
||||
tutils.raises(ValueError, f.to_bytes)
|
||||
|
||||
|
||||
def test_continuation_frame_from_bytes():
|
||||
f = Frame.from_bytes('000006090401234567666f6f626172'.decode('hex'))
|
||||
assert isinstance(f, ContinuationFrame)
|
||||
|
@ -34,6 +34,7 @@ import tutils
|
||||
|
||||
|
||||
class TestCertStore:
|
||||
|
||||
def test_create_explicit(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca = certutils.CertStore.from_store(d, "test")
|
||||
@ -102,6 +103,7 @@ class TestCertStore:
|
||||
|
||||
|
||||
class TestDummyCert:
|
||||
|
||||
def test_with_ca(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca = certutils.CertStore.from_store(d, "test")
|
||||
@ -115,6 +117,7 @@ class TestDummyCert:
|
||||
|
||||
|
||||
class TestSSLCert:
|
||||
|
||||
def test_simple(self):
|
||||
with open(tutils.test_data.path("data/text_cert"), "rb") as f:
|
||||
d = f.read()
|
||||
@ -152,5 +155,3 @@ class TestSSLCert:
|
||||
d = f.read()
|
||||
s = certutils.SSLCert.from_der(d)
|
||||
assert s.cn
|
||||
|
||||
|
||||
|
@ -230,6 +230,7 @@ def test_parse_init_http():
|
||||
|
||||
|
||||
class TestReadHeaders:
|
||||
|
||||
def _read(self, data, verbatim=False):
|
||||
if not verbatim:
|
||||
data = textwrap.dedent(data)
|
||||
@ -277,6 +278,7 @@ class TestReadHeaders:
|
||||
|
||||
|
||||
class NoContentLengthHTTPHandler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n")
|
||||
self.wfile.flush()
|
||||
@ -297,7 +299,7 @@ def test_read_response():
|
||||
data = textwrap.dedent(data)
|
||||
r = cStringIO.StringIO(data)
|
||||
return http.read_response(
|
||||
r, method, limit, include_body = include_body
|
||||
r, method, limit, include_body=include_body
|
||||
)
|
||||
|
||||
tutils.raises("server disconnect", tst, "", "GET", None)
|
||||
|
@ -1,9 +1,12 @@
|
||||
import binascii, cStringIO
|
||||
import binascii
|
||||
import cStringIO
|
||||
from netlib import odict, http_auth, http
|
||||
import mock
|
||||
import tutils
|
||||
|
||||
|
||||
class TestPassManNonAnon:
|
||||
|
||||
def test_simple(self):
|
||||
p = http_auth.PassManNonAnon()
|
||||
assert not p.test("", "")
|
||||
@ -11,6 +14,7 @@ class TestPassManNonAnon:
|
||||
|
||||
|
||||
class TestPassManHtpasswd:
|
||||
|
||||
def test_file_errors(self):
|
||||
tutils.raises("malformed htpasswd file", http_auth.PassManHtpasswd, tutils.test_data.path("data/server.crt"))
|
||||
|
||||
@ -27,6 +31,7 @@ class TestPassManHtpasswd:
|
||||
|
||||
|
||||
class TestPassManSingleUser:
|
||||
|
||||
def test_simple(self):
|
||||
pm = http_auth.PassManSingleUser("test", "test")
|
||||
assert pm.test("test", "test")
|
||||
@ -35,6 +40,7 @@ class TestPassManSingleUser:
|
||||
|
||||
|
||||
class TestNullProxyAuth:
|
||||
|
||||
def test_simple(self):
|
||||
na = http_auth.NullProxyAuth(http_auth.PassManNonAnon())
|
||||
assert not na.auth_challenge_headers()
|
||||
@ -43,6 +49,7 @@ class TestNullProxyAuth:
|
||||
|
||||
|
||||
class TestBasicProxyAuth:
|
||||
|
||||
def test_simple(self):
|
||||
ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test")
|
||||
h = odict.ODictCaseless()
|
||||
@ -60,7 +67,6 @@ class TestBasicProxyAuth:
|
||||
ba.clean(hdrs)
|
||||
assert not ba.AUTH_HEADER in hdrs
|
||||
|
||||
|
||||
hdrs[ba.AUTH_HEADER] = [""]
|
||||
assert not ba.authenticate(hdrs)
|
||||
|
||||
@ -77,25 +83,27 @@ class TestBasicProxyAuth:
|
||||
assert not ba.authenticate(hdrs)
|
||||
|
||||
|
||||
class Bunch: pass
|
||||
class Bunch:
|
||||
pass
|
||||
|
||||
|
||||
class TestAuthAction:
|
||||
|
||||
def test_nonanonymous(self):
|
||||
m = Bunch()
|
||||
aa = http_auth.NonanonymousAuthAction(None, "authenticator")
|
||||
aa(None, m, None, None)
|
||||
assert m.authenticator
|
||||
assert m.authenticator
|
||||
|
||||
def test_singleuser(self):
|
||||
m = Bunch()
|
||||
aa = http_auth.SingleuserAuthAction(None, "authenticator")
|
||||
aa(None, m, "foo:bar", None)
|
||||
assert m.authenticator
|
||||
assert m.authenticator
|
||||
tutils.raises("invalid", aa, None, m, "foo", None)
|
||||
|
||||
def test_httppasswd(self):
|
||||
m = Bunch()
|
||||
aa = http_auth.HtpasswdAuthAction(None, "authenticator")
|
||||
aa(None, m, tutils.test_data.path("data/htpasswd"), None)
|
||||
assert m.authenticator
|
||||
assert m.authenticator
|
||||
|
@ -4,4 +4,3 @@ from netlib import http_uastrings
|
||||
def test_get_shortcut():
|
||||
assert http_uastrings.get_by_shortcut("c")[0] == "chrome"
|
||||
assert not http_uastrings.get_by_shortcut("_")
|
||||
|
||||
|
@ -3,6 +3,7 @@ import tutils
|
||||
|
||||
|
||||
class TestODict:
|
||||
|
||||
def setUp(self):
|
||||
self.od = odict.ODict()
|
||||
|
||||
@ -106,13 +107,13 @@ class TestODict:
|
||||
def test_get(self):
|
||||
self.od.add("one", "two")
|
||||
assert self.od.get("one") == ["two"]
|
||||
assert self.od.get("two") == None
|
||||
assert self.od.get("two") is None
|
||||
|
||||
def test_get_first(self):
|
||||
self.od.add("one", "two")
|
||||
self.od.add("one", "three")
|
||||
assert self.od.get_first("one") == "two"
|
||||
assert self.od.get_first("two") == None
|
||||
assert self.od.get_first("two") is None
|
||||
|
||||
def test_extend(self):
|
||||
a = odict.ODict([["a", "b"], ["c", "d"]])
|
||||
@ -121,7 +122,9 @@ class TestODict:
|
||||
assert len(a) == 4
|
||||
assert a["a"] == ["b", "b"]
|
||||
|
||||
|
||||
class TestODictCaseless:
|
||||
|
||||
def setUp(self):
|
||||
self.od = odict.ODictCaseless()
|
||||
|
||||
|
190
test/test_tcp.py
190
test/test_tcp.py
@ -1,4 +1,8 @@
|
||||
import cStringIO, Queue, time, socket, random
|
||||
import cStringIO
|
||||
import Queue
|
||||
import time
|
||||
import socket
|
||||
import random
|
||||
import os
|
||||
from netlib import tcp, certutils, test, certffi
|
||||
import threading
|
||||
@ -6,8 +10,10 @@ import mock
|
||||
import tutils
|
||||
from OpenSSL import SSL
|
||||
|
||||
|
||||
class EchoHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
@ -19,19 +25,22 @@ class EchoHandler(tcp.BaseHandler):
|
||||
|
||||
class ClientCipherListHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write("%s"%self.connection.get_cipher_list())
|
||||
self.wfile.write("%s" % self.connection.get_cipher_list())
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class HangHandler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
while 1:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
class TestServer(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
@ -51,7 +60,9 @@ class TestServer(test.ServerTestBase):
|
||||
|
||||
|
||||
class TestServerBind(test.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(str(self.connection.getpeername()))
|
||||
self.wfile.flush()
|
||||
@ -65,7 +76,7 @@ class TestServerBind(test.ServerTestBase):
|
||||
c.connect()
|
||||
assert c.rfile.readline() == str(("127.0.0.1", random_port))
|
||||
return
|
||||
except tcp.NetLibError: # port probably already in use
|
||||
except tcp.NetLibError: # port probably already in use
|
||||
pass
|
||||
|
||||
|
||||
@ -84,6 +95,7 @@ class TestServerIPv6(test.ServerTestBase):
|
||||
|
||||
class TestEcho(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
@ -94,16 +106,19 @@ class TestEcho(test.ServerTestBase):
|
||||
|
||||
|
||||
class HardDisconnectHandler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.connection.close()
|
||||
|
||||
|
||||
class TestFinishFail(test.ServerTestBase):
|
||||
|
||||
"""
|
||||
This tests a difficult-to-trigger exception in the .finish() method of
|
||||
the handler.
|
||||
"""
|
||||
handler = EchoHandler
|
||||
|
||||
def test_disconnect_in_finish(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -115,13 +130,14 @@ class TestFinishFail(test.ServerTestBase):
|
||||
class TestServerSSL(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False,
|
||||
cipher_list = "AES256-SHA",
|
||||
chain_file=tutils.test_data.path("data/server.crt")
|
||||
)
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False,
|
||||
cipher_list="AES256-SHA",
|
||||
chain_file=tutils.test_data.path("data/server.crt")
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -144,11 +160,12 @@ class TestServerSSL(test.ServerTestBase):
|
||||
class TestSSLv3Only(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = True
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=True
|
||||
)
|
||||
|
||||
def test_failure(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -156,20 +173,23 @@ class TestSSLv3Only(test.ServerTestBase):
|
||||
|
||||
|
||||
class TestSSLClientCert(test.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write("%s\n"%self.clientcert.serial)
|
||||
self.wfile.write("%s\n" % self.clientcert.serial)
|
||||
self.wfile.flush()
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = True,
|
||||
v3_only = False
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=True,
|
||||
v3_only=False
|
||||
)
|
||||
|
||||
def test_clientcert(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -187,8 +207,10 @@ class TestSSLClientCert(test.ServerTestBase):
|
||||
|
||||
|
||||
class TestSNI(test.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
@ -197,11 +219,12 @@ class TestSNI(test.ServerTestBase):
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -213,12 +236,13 @@ class TestSNI(test.ServerTestBase):
|
||||
class TestServerCipherList(test.ServerTestBase):
|
||||
handler = ClientCipherListHandler
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False,
|
||||
cipher_list = 'RC4-SHA'
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False,
|
||||
cipher_list='RC4-SHA'
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -227,18 +251,21 @@ class TestServerCipherList(test.ServerTestBase):
|
||||
|
||||
|
||||
class TestServerCurrentCipher(test.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write("%s"%str(self.get_current_cipher()))
|
||||
self.wfile.write("%s" % str(self.get_current_cipher()))
|
||||
self.wfile.flush()
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False,
|
||||
cipher_list = 'RC4-SHA'
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False,
|
||||
cipher_list='RC4-SHA'
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -249,12 +276,13 @@ class TestServerCurrentCipher(test.ServerTestBase):
|
||||
class TestServerCipherListError(test.ServerTestBase):
|
||||
handler = ClientCipherListHandler
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False,
|
||||
cipher_list = 'bogus'
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False,
|
||||
cipher_list='bogus'
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -264,12 +292,13 @@ class TestServerCipherListError(test.ServerTestBase):
|
||||
class TestClientCipherListError(test.ServerTestBase):
|
||||
handler = ClientCipherListHandler
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False,
|
||||
cipher_list = 'RC4-SHA'
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False,
|
||||
cipher_list='RC4-SHA'
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -277,15 +306,18 @@ class TestClientCipherListError(test.ServerTestBase):
|
||||
|
||||
|
||||
class TestSSLDisconnect(test.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.finish()
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -300,11 +332,12 @@ class TestSSLDisconnect(test.ServerTestBase):
|
||||
class TestSSLHardDisconnect(test.ServerTestBase):
|
||||
handler = HardDisconnectHandler
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -316,6 +349,7 @@ class TestSSLHardDisconnect(test.ServerTestBase):
|
||||
|
||||
|
||||
class TestDisconnect(test.ServerTestBase):
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -326,7 +360,9 @@ class TestDisconnect(test.ServerTestBase):
|
||||
|
||||
|
||||
class TestServerTimeOut(test.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.timeout = False
|
||||
self.settimeout(0.01)
|
||||
@ -344,6 +380,7 @@ class TestServerTimeOut(test.ServerTestBase):
|
||||
|
||||
class TestTimeOut(test.ServerTestBase):
|
||||
handler = HangHandler
|
||||
|
||||
def test_timeout(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -355,11 +392,12 @@ class TestTimeOut(test.ServerTestBase):
|
||||
class TestSSLTimeOut(test.ServerTestBase):
|
||||
handler = HangHandler
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False
|
||||
)
|
||||
|
||||
def test_timeout_client(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -371,15 +409,16 @@ class TestSSLTimeOut(test.ServerTestBase):
|
||||
class TestDHParams(test.ServerTestBase):
|
||||
handler = HangHandler
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False,
|
||||
dhparams = certutils.CertStore.load_dhparam(
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False,
|
||||
dhparams=certutils.CertStore.load_dhparam(
|
||||
tutils.test_data.path("data/dhparam.pem"),
|
||||
),
|
||||
cipher_list = "DHE-RSA-AES256-SHA"
|
||||
cipher_list="DHE-RSA-AES256-SHA"
|
||||
)
|
||||
|
||||
def test_dhparams(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
@ -395,7 +434,9 @@ class TestDHParams(test.ServerTestBase):
|
||||
|
||||
|
||||
class TestPrivkeyGen(test.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca1 = certutils.CertStore.from_store(d, "test2")
|
||||
@ -411,7 +452,9 @@ class TestPrivkeyGen(test.ServerTestBase):
|
||||
|
||||
|
||||
class TestPrivkeyGenNoFlags(test.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca1 = certutils.CertStore.from_store(d, "test2")
|
||||
@ -426,14 +469,15 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase):
|
||||
tutils.raises("sslv3 alert handshake failure", c.convert_to_ssl)
|
||||
|
||||
|
||||
|
||||
class TestTCPClient:
|
||||
|
||||
def test_conerr(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
tutils.raises(tcp.NetLibError, c.connect)
|
||||
|
||||
|
||||
class TestFileLike:
|
||||
|
||||
def test_blocksize(self):
|
||||
s = cStringIO.StringIO("1234567890abcdefghijklmnopqrstuvwxyz")
|
||||
s = tcp.Reader(s)
|
||||
@ -460,7 +504,7 @@ class TestFileLike:
|
||||
assert s.readline(3) == "foo"
|
||||
|
||||
def test_limitless(self):
|
||||
s = cStringIO.StringIO("f"*(50*1024))
|
||||
s = cStringIO.StringIO("f" * (50 * 1024))
|
||||
s = tcp.Reader(s)
|
||||
ret = s.read(-1)
|
||||
assert len(ret) == 50 * 1024
|
||||
@ -551,7 +595,9 @@ class TestFileLike:
|
||||
s = tcp.Reader(o)
|
||||
tutils.raises(tcp.NetLibDisconnect, s.readline, 10)
|
||||
|
||||
|
||||
class TestAddress:
|
||||
|
||||
def test_simple(self):
|
||||
a = tcp.Address("localhost", True)
|
||||
assert a.use_ipv6
|
||||
@ -566,12 +612,12 @@ class TestAddress:
|
||||
class TestSSLKeyLogger(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False,
|
||||
cipher_list = "AES256-SHA"
|
||||
)
|
||||
cert=tutils.test_data.path("data/server.crt"),
|
||||
key=tutils.test_data.path("data/server.key"),
|
||||
request_client_cert=False,
|
||||
v3_only=False,
|
||||
cipher_list="AES256-SHA"
|
||||
)
|
||||
|
||||
def test_log(self):
|
||||
testval = "echo!\n"
|
||||
@ -597,4 +643,4 @@ class TestSSLKeyLogger(test.ServerTestBase):
|
||||
|
||||
def test_create_logfun(self):
|
||||
assert isinstance(tcp.SSLKeyLogger.create_logfun("test"), tcp.SSLKeyLogger)
|
||||
assert not tcp.SSLKeyLogger.create_logfun(False)
|
||||
assert not tcp.SSLKeyLogger.create_logfun(False)
|
||||
|
@ -12,7 +12,7 @@ def test_bidi():
|
||||
|
||||
|
||||
def test_hexdump():
|
||||
assert utils.hexdump("one\0"*10)
|
||||
assert utils.hexdump("one\0" * 10)
|
||||
|
||||
|
||||
def test_cleanBin():
|
||||
@ -25,5 +25,5 @@ def test_cleanBin():
|
||||
def test_pretty_size():
|
||||
assert utils.pretty_size(100) == "100B"
|
||||
assert utils.pretty_size(1024) == "1kB"
|
||||
assert utils.pretty_size(1024 + (1024/2.0)) == "1.5kB"
|
||||
assert utils.pretty_size(1024*1024) == "1MB"
|
||||
assert utils.pretty_size(1024 + (1024 / 2.0)) == "1.5kB"
|
||||
assert utils.pretty_size(1024 * 1024) == "1MB"
|
||||
|
@ -7,6 +7,7 @@ import tutils
|
||||
|
||||
|
||||
class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
|
||||
def __init__(self, connection, address, server):
|
||||
super(WebSocketsEchoHandler, self).__init__(
|
||||
connection, address, server
|
||||
@ -25,7 +26,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
self.on_message(frame.payload)
|
||||
|
||||
def send_message(self, message):
|
||||
frame = websockets.Frame.default(message, from_client = False)
|
||||
frame = websockets.Frame.default(message, from_client=False)
|
||||
frame.to_file(self.wfile)
|
||||
|
||||
def handshake(self):
|
||||
@ -44,6 +45,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
|
||||
|
||||
class WebSocketsClient(tcp.TCPClient):
|
||||
|
||||
def __init__(self, address, source_address=None):
|
||||
super(WebSocketsClient, self).__init__(address, source_address)
|
||||
self.client_nonce = None
|
||||
@ -68,14 +70,14 @@ class WebSocketsClient(tcp.TCPClient):
|
||||
return websockets.Frame.from_file(self.rfile).payload
|
||||
|
||||
def send_message(self, message):
|
||||
frame = websockets.Frame.default(message, from_client = True)
|
||||
frame = websockets.Frame.default(message, from_client=True)
|
||||
frame.to_file(self.wfile)
|
||||
|
||||
|
||||
class TestWebSockets(test.ServerTestBase):
|
||||
handler = WebSocketsEchoHandler
|
||||
|
||||
def random_bytes(self, n = 100):
|
||||
def random_bytes(self, n=100):
|
||||
return os.urandom(n)
|
||||
|
||||
def echo(self, msg):
|
||||
@ -105,8 +107,8 @@ class TestWebSockets(test.ServerTestBase):
|
||||
default builder should always generate valid frames
|
||||
"""
|
||||
msg = self.random_bytes()
|
||||
client_frame = websockets.Frame.default(msg, from_client = True)
|
||||
server_frame = websockets.Frame.default(msg, from_client = False)
|
||||
client_frame = websockets.Frame.default(msg, from_client=True)
|
||||
server_frame = websockets.Frame.default(msg, from_client=False)
|
||||
|
||||
def test_serialization_bijection(self):
|
||||
"""
|
||||
@ -140,6 +142,7 @@ class TestWebSockets(test.ServerTestBase):
|
||||
|
||||
|
||||
class BadHandshakeHandler(WebSocketsEchoHandler):
|
||||
|
||||
def handshake(self):
|
||||
client_hs = http.read_request(self.rfile)
|
||||
websockets.check_client_handshake(client_hs.headers)
|
||||
@ -152,6 +155,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler):
|
||||
|
||||
|
||||
class TestBadHandshake(test.ServerTestBase):
|
||||
|
||||
"""
|
||||
Ensure that the client disconnects if the server handshake is malformed
|
||||
"""
|
||||
@ -165,6 +169,7 @@ class TestBadHandshake(test.ServerTestBase):
|
||||
|
||||
|
||||
class TestFrameHeader:
|
||||
|
||||
def test_roundtrip(self):
|
||||
def round(*args, **kwargs):
|
||||
f = websockets.FrameHeader(*args, **kwargs)
|
||||
@ -216,6 +221,7 @@ class TestFrameHeader:
|
||||
|
||||
|
||||
class TestFrame:
|
||||
|
||||
def test_roundtrip(self):
|
||||
def round(*args, **kwargs):
|
||||
f = websockets.Frame(*args, **kwargs)
|
||||
@ -240,7 +246,7 @@ def test_masker():
|
||||
["fourf"],
|
||||
["fourfive"],
|
||||
["a", "aasdfasdfa", "asdf"],
|
||||
["a"*50, "aasdfasdfa", "asdf"],
|
||||
["a" * 50, "aasdfasdfa", "asdf"],
|
||||
]
|
||||
for i in tests:
|
||||
m = websockets.Masker("abcd")
|
||||
|
@ -1,4 +1,5 @@
|
||||
import cStringIO, sys
|
||||
import cStringIO
|
||||
import sys
|
||||
from netlib import wsgi, odict
|
||||
|
||||
|
||||
@ -10,6 +11,7 @@ def tflow():
|
||||
|
||||
|
||||
class TestApp:
|
||||
|
||||
def __init__(self):
|
||||
self.called = False
|
||||
|
||||
@ -22,6 +24,7 @@ class TestApp:
|
||||
|
||||
|
||||
class TestWSGI:
|
||||
|
||||
def test_make_environ(self):
|
||||
w = wsgi.WSGIAdaptor(None, "foo", 80, "version")
|
||||
tf = tflow()
|
||||
|
@ -44,14 +44,14 @@ def raises(exc, obj, *args, **kwargs):
|
||||
:kwargs Arguments to be passed to the callable.
|
||||
"""
|
||||
try:
|
||||
ret = apply(obj, args, kwargs)
|
||||
except Exception, v:
|
||||
ret = obj(*args, **kwargs)
|
||||
except Exception as v:
|
||||
if isinstance(exc, basestring):
|
||||
if exc.lower() in str(v).lower():
|
||||
return
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Expected %s, but caught %s"%(
|
||||
"Expected %s, but caught %s" % (
|
||||
repr(str(exc)), v
|
||||
)
|
||||
)
|
||||
@ -60,7 +60,7 @@ def raises(exc, obj, *args, **kwargs):
|
||||
return
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Expected %s, but caught %s %s"%(
|
||||
"Expected %s, but caught %s %s" % (
|
||||
exc.__name__, v.__class__.__name__, str(v)
|
||||
)
|
||||
)
|
||||
|
@ -22,6 +22,6 @@ else:
|
||||
cert = get_remote_cert(sys.argv[1], port, sni)
|
||||
print "CN:", cert.cn
|
||||
if cert.altnames:
|
||||
print "SANs:",
|
||||
print "SANs:",
|
||||
for i in cert.altnames:
|
||||
print "\t", i
|
||||
|
Loading…
Reference in New Issue
Block a user