fix minor bugs, add py.test compatibility

This commit is contained in:
Maximilian Hils 2015-09-21 02:26:47 +02:00
parent 6d27901b6f
commit 88375ad64a
6 changed files with 32 additions and 74 deletions

View File

@ -2,9 +2,10 @@
branch = True branch = True
[report] [report]
omit = *contrib*, *tnetstring*, *platform*, *console*, *main.py show_missing = True
include = *libmproxy* include = *libmproxy*
exclude_lines = exclude_lines =
pragma: nocover pragma: nocover
pragma: no cover pragma: no cover
raise NotImplementedError() raise NotImplementedError()
omit = *contrib*, *tnetstring*, *platform*, *console*, *main.py

View File

@ -8,6 +8,7 @@ from libmproxy import utils
from netlib import encoding from netlib import encoding
from netlib.http import status_codes, Headers, Request, Response, CONTENT_MISSING from netlib.http import status_codes, Headers, Request, Response, CONTENT_MISSING
from netlib.tcp import Address from netlib.tcp import Address
from netlib.utils import native
from .. import version, stateobject from .. import version, stateobject
from .flow import Flow from .flow import Flow
@ -497,6 +498,8 @@ class decoded(object):
def __init__(self, o): def __init__(self, o):
self.o = o self.o = o
ce = o.headers.get("content-encoding") ce = o.headers.get("content-encoding")
if ce:
ce = native(ce, "ascii", "ignore")
if ce in encoding.ENCODINGS: if ce in encoding.ENCODINGS:
self.ce = ce self.ce = ce
else: else:
@ -504,7 +507,8 @@ class decoded(object):
def __enter__(self): def __enter__(self):
if self.ce: if self.ce:
self.o.decode() if not self.o.decode():
self.ce = None
def __exit__(self, type, value, tb): def __exit__(self, type, value, tb):
if self.ce: if self.ce:

View File

@ -13,14 +13,10 @@ from netlib import http, tcp
from netlib.http import http1 from netlib.http import http1
class TestServerConnection: class TestServerConnection(object):
def setUp(self):
self.d = test.Daemon()
def tearDown(self):
self.d.shutdown()
def test_simple(self): def test_simple(self):
self.d = test.Daemon()
sc = ServerConnection((self.d.IFACE, self.d.port)) sc = ServerConnection((self.d.IFACE, self.d.port))
sc.connect() sc.connect()
f = tutils.tflow() f = tutils.tflow()
@ -35,14 +31,17 @@ class TestServerConnection:
assert self.d.last_log() assert self.d.last_log()
sc.finish() sc.finish()
self.d.shutdown()
def test_terminate_error(self): def test_terminate_error(self):
self.d = test.Daemon()
sc = ServerConnection((self.d.IFACE, self.d.port)) sc = ServerConnection((self.d.IFACE, self.d.port))
sc.connect() sc.connect()
sc.connection = mock.Mock() sc.connection = mock.Mock()
sc.connection.recv = mock.Mock(return_value=False) sc.connection.recv = mock.Mock(return_value=False)
sc.connection.flush = mock.Mock(side_effect=TcpDisconnect) sc.connection.flush = mock.Mock(side_effect=TcpDisconnect)
sc.finish() sc.finish()
self.d.shutdown()
def test_repr(self): def test_repr(self):
sc = tutils.tserver_conn() sc = tutils.tserver_conn()

View File

@ -158,7 +158,7 @@ class TcpMixin:
def _tcpproxy_off(self): def _tcpproxy_off(self):
assert hasattr(self, "_tcpproxy_backup") assert hasattr(self, "_tcpproxy_backup")
self.config.check_ignore = self._tcpproxy_backup self.config.check_tcp = self._tcpproxy_backup
del self._tcpproxy_backup del self._tcpproxy_backup
def test_tcp(self): def test_tcp(self):

View File

@ -89,7 +89,7 @@ class ProxTestBase(object):
masterclass = TestMaster masterclass = TestMaster
@classmethod @classmethod
def setupAll(cls): def setup_class(cls):
cls.server = libpathod.test.Daemon( cls.server = libpathod.test.Daemon(
ssl=cls.ssl, ssl=cls.ssl,
ssloptions=cls.ssloptions) ssloptions=cls.ssloptions)
@ -105,13 +105,15 @@ class ProxTestBase(object):
cls.proxy.start() cls.proxy.start()
@classmethod @classmethod
def teardownAll(cls): def teardown_class(cls):
shutil.rmtree(cls.cadir) # perf: we want to run tests in parallell
# should this ever cause an error, travis should catch it.
# shutil.rmtree(cls.cadir)
cls.proxy.shutdown() cls.proxy.shutdown()
cls.server.shutdown() cls.server.shutdown()
cls.server2.shutdown() cls.server2.shutdown()
def setUp(self): def setup(self):
self.master.clear_log() self.master.clear_log()
self.master.state.clear() self.master.state.clear()
self.server.clear_log() self.server.clear_log()
@ -185,8 +187,8 @@ class TransparentProxTest(ProxTestBase):
resolver = TResolver resolver = TResolver
@classmethod @classmethod
def setupAll(cls): def setup_class(cls):
super(TransparentProxTest, cls).setupAll() super(TransparentProxTest, cls).setup_class()
cls._resolver = mock.patch( cls._resolver = mock.patch(
"libmproxy.platform.resolver", "libmproxy.platform.resolver",
@ -195,9 +197,9 @@ class TransparentProxTest(ProxTestBase):
cls._resolver.start() cls._resolver.start()
@classmethod @classmethod
def teardownAll(cls): def teardown_class(cls):
cls._resolver.stop() cls._resolver.stop()
super(TransparentProxTest, cls).teardownAll() super(TransparentProxTest, cls).teardown_class()
@classmethod @classmethod
def get_proxy_config(cls): def get_proxy_config(cls):
@ -283,9 +285,9 @@ class ChainProxTest(ProxTestBase):
n = 2 n = 2
@classmethod @classmethod
def setupAll(cls): def setup_class(cls):
cls.chain = [] cls.chain = []
super(ChainProxTest, cls).setupAll() super(ChainProxTest, cls).setup_class()
for _ in range(cls.n): for _ in range(cls.n):
config = ProxyConfig(**cls.get_proxy_config()) config = ProxyConfig(**cls.get_proxy_config())
tmaster = cls.masterclass(config) tmaster = cls.masterclass(config)
@ -298,13 +300,13 @@ class ChainProxTest(ProxTestBase):
**cls.get_proxy_config()) **cls.get_proxy_config())
@classmethod @classmethod
def teardownAll(cls): def teardown_class(cls):
super(ChainProxTest, cls).teardownAll() super(ChainProxTest, cls).teardown_class()
for proxy in cls.chain: for proxy in cls.chain:
proxy.shutdown() proxy.shutdown()
def setUp(self): def setup(self):
super(ChainProxTest, self).setUp() super(ChainProxTest, self).setup()
for proxy in self.chain: for proxy in self.chain:
proxy.tmaster.clear_log() proxy.tmaster.clear_log()
proxy.tmaster.state.clear() proxy.tmaster.state.clear()

View File

@ -18,7 +18,7 @@ from libmproxy.console.flowview import FlowView
from libmproxy.console import ConsoleState from libmproxy.console import ConsoleState
def _SkipWindows(): def _SkipWindows(*args):
raise SkipTest("Skipped on Windows.") raise SkipTest("Skipped on Windows.")
@ -96,18 +96,6 @@ def terr(content="error"):
return err return err
def tflowview(request_contents=None):
m = Mock()
cs = ConsoleState()
if request_contents is None:
flow = tflow()
else:
flow = tflow(req=netlib.tutils.treq(body=request_contents))
fv = FlowView(m, cs, flow)
return fv
def get_body_line(last_displayed_body, line_nb): def get_body_line(last_displayed_body, line_nb):
return last_displayed_body.contents()[line_nb + 2] return last_displayed_body.contents()[line_nb + 2]
@ -134,43 +122,7 @@ class MockParser(argparse.ArgumentParser):
raise Exception(message) raise Exception(message)
def raises(exc, obj, *args, **kwargs): raises = netlib.tutils.raises
"""
Assert that a callable raises a specified exception.
:exc An exception class or a string. If a class, assert that an
exception of this type is raised. If a string, assert that the string
occurs in the string representation of the exception, based on a
case-insenstivie match.
:obj A callable object.
:args Arguments to be passsed to the callable.
:kwargs Arguments to be passed to the callable.
"""
try:
obj(*args, **kwargs)
except Exception as v:
if isinstance(exc, basestring):
if exc.lower() in str(v).lower():
return
else:
raise AssertionError(
"Expected %s, but caught %s" % (
repr(str(exc)), v
)
)
else:
if isinstance(v, exc):
return
else:
raise AssertionError(
"Expected %s, but caught %s %s" % (
exc.__name__, v.__class__.__name__, str(v)
)
)
raise AssertionError("No exception raised.")
@contextmanager @contextmanager