Merge remote-tracking branch 'origin/master' into pr-2120

Conflicts:
	test/mitmproxy/addons/test_replace.py
This commit is contained in:
Maximilian Hils 2017-03-12 23:33:49 +01:00
commit 05e11547f5
33 changed files with 395 additions and 667 deletions

View File

@ -10,9 +10,6 @@ git:
matrix:
fast_finish: true
allow_failures:
- language: node_js
node_js: "node"
include:
- python: 3.5
env: TOXENV=lint

View File

@ -4,7 +4,7 @@ from mitmproxy.addons import check_alpn
from mitmproxy.addons import check_ca
from mitmproxy.addons import clientplayback
from mitmproxy.addons import core_option_validation
from mitmproxy.addons import disable_h2c_upgrade
from mitmproxy.addons import disable_h2c
from mitmproxy.addons import onboarding
from mitmproxy.addons import proxyauth
from mitmproxy.addons import replace
@ -26,7 +26,7 @@ def default_addons():
check_alpn.CheckALPN(),
check_ca.CheckCA(),
clientplayback.ClientPlayback(),
disable_h2c_upgrade.DisableH2CleartextUpgrade(),
disable_h2c.DisableH2C(),
onboarding.Onboarding(),
proxyauth.ProxyAuth(),
replace.Replace(),

View File

@ -0,0 +1,41 @@
import mitmproxy
class DisableH2C:
"""
We currently only support HTTP/2 over a TLS connection.
Some clients try to upgrade a connection from HTTP/1.1 to h2c. We need to
remove those headers to avoid protocol errors if one endpoints suddenly
starts sending HTTP/2 frames.
Some clients might use HTTP/2 Prior Knowledge to directly initiate a session
by sending the connection preface. We just kill those flows.
"""
def configure(self, options, updated):
pass
def process_flow(self, f):
if f.request.headers.get('upgrade', '') == 'h2c':
mitmproxy.ctx.log.warn("HTTP/2 cleartext connections (h2c upgrade requests) are currently not supported.")
del f.request.headers['upgrade']
if 'connection' in f.request.headers:
del f.request.headers['connection']
if 'http2-settings' in f.request.headers:
del f.request.headers['http2-settings']
is_connection_preface = (
f.request.method == 'PRI' and
f.request.path == '*' and
f.request.http_version == 'HTTP/2.0'
)
if is_connection_preface:
f.kill()
mitmproxy.ctx.log.warn("Initiating HTTP/2 connections with prior knowledge are currently not supported.")
# Handlers
def request(self, f):
self.process_flow(f)

View File

@ -1,21 +0,0 @@
class DisableH2CleartextUpgrade:
"""
We currently only support HTTP/2 over a TLS connection. Some clients try
to upgrade a connection from HTTP/1.1 to h2c, so we need to remove those
headers to avoid protocol errors if one endpoints suddenly starts sending
HTTP/2 frames.
"""
def process_flow(self, f):
if f.request.headers.get('upgrade', '') == 'h2c':
del f.request.headers['upgrade']
if 'connection' in f.request.headers:
del f.request.headers['connection']
if 'http2-settings' in f.request.headers:
del f.request.headers['http2-settings']
# Handlers
def request(self, f):
self.process_flow(f)

View File

@ -1,7 +1,7 @@
import binascii
import weakref
from typing import Optional
from typing import Set # noqa
from typing import MutableMapping # noqa
from typing import Tuple
import passlib.apache
@ -46,7 +46,7 @@ class ProxyAuth:
self.htpasswd = None
self.singleuser = None
self.mode = None
self.authenticated = weakref.WeakSet() # type: Set[connections.ClientConnection]
self.authenticated = weakref.WeakKeyDictionary() # type: MutableMapping[connections.ClientConnection, Tuple[str, str]]
"""Contains all connections that are permanently authenticated after an HTTP CONNECT"""
def enabled(self) -> bool:
@ -153,11 +153,12 @@ class ProxyAuth:
def http_connect(self, f: http.HTTPFlow) -> None:
if self.enabled():
if self.authenticate(f):
self.authenticated.add(f.client_conn)
self.authenticated[f.client_conn] = f.metadata["proxyauth"]
def requestheaders(self, f: http.HTTPFlow) -> None:
if self.enabled():
# Is this connection authenticated by a previous HTTP CONNECT?
if f.client_conn in self.authenticated:
f.metadata["proxyauth"] = self.authenticated[f.client_conn]
return
self.authenticate(f)

View File

@ -78,7 +78,7 @@ class Flow(stateobject.StateObject):
self._backup = None # type: typing.Optional[Flow]
self.reply = None # type: typing.Optional[controller.Reply]
self.marked = False # type: bool
self.metadata = dict() # type: typing.Dict[str, str]
self.metadata = dict() # type: typing.Dict[str, typing.Any]
_stateobject_attributes = dict(
id=str,

View File

@ -140,8 +140,8 @@ class WebSocketLayer(base.Layer):
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
self.flow.metadata['websocket_handshake'] = self.handshake_flow
self.handshake_flow.metadata['websocket_flow'] = self.flow
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
self.channel.ask("websocket_start", self.flow)
client = self.client_conn.connection

View File

@ -39,6 +39,14 @@ class StateObject(serializable.Serializable):
state[attr] = val.get_state()
elif _is_list(cls):
state[attr] = [x.get_state() for x in val]
elif isinstance(val, dict):
s = {}
for k, v in val.items():
if hasattr(v, "get_state"):
s[k] = v.get_state()
else:
s[k] = v
state[attr] = s
else:
state[attr] = val
return state

View File

@ -70,6 +70,7 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
handshake_flow.response = resp
f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow)
handshake_flow.metadata['websocket_flow'] = f
if messages is True:
messages = [

View File

@ -1,9 +1,5 @@
from io import BytesIO
import tempfile
import os
import time
import shutil
from contextlib import contextmanager
from io import BytesIO
from mitmproxy.utils import data
from mitmproxy.net import tcp
@ -13,18 +9,6 @@ from mitmproxy.net import http
test_data = data.Data(__name__).push("../../test/")
@contextmanager
def tmpdir(*args, **kwargs):
orig_workdir = os.getcwd()
temp_workdir = tempfile.mkdtemp(*args, **kwargs)
os.chdir(temp_workdir)
yield temp_workdir
os.chdir(orig_workdir)
shutil.rmtree(temp_workdir)
def treader(bytes):
"""
Construct a tcp.Read object from bytes.

View File

@ -2,7 +2,6 @@ import time
from typing import List, Optional
from mitmproxy import flow
from mitmproxy.http import HTTPFlow
from mitmproxy.net import websockets
from mitmproxy.types import serializable
from mitmproxy.utils import strutils
@ -44,6 +43,22 @@ class WebSocketFlow(flow.Flow):
self.close_code = '(status code missing)'
self.close_message = '(message missing)'
self.close_reason = 'unknown status code'
if handshake_flow:
self.client_key = websockets.get_client_key(handshake_flow.request.headers)
self.client_protocol = websockets.get_protocol(handshake_flow.request.headers)
self.client_extensions = websockets.get_extensions(handshake_flow.request.headers)
self.server_accept = websockets.get_server_accept(handshake_flow.response.headers)
self.server_protocol = websockets.get_protocol(handshake_flow.response.headers)
self.server_extensions = websockets.get_extensions(handshake_flow.response.headers)
else:
self.client_key = ''
self.client_protocol = ''
self.client_extensions = ''
self.server_accept = ''
self.server_protocol = ''
self.server_extensions = ''
self.handshake_flow = handshake_flow
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
@ -53,7 +68,15 @@ class WebSocketFlow(flow.Flow):
close_code=str,
close_message=str,
close_reason=str,
handshake_flow=HTTPFlow,
client_key=str,
client_protocol=str,
client_extensions=str,
server_accept=str,
server_protocol=str,
server_extensions=str,
# Do not include handshake_flow, to prevent recursive serialization!
# Since mitmproxy-console currently only displays HTTPFlows,
# dumping the handshake_flow will include the WebSocketFlow too.
)
@classmethod
@ -65,30 +88,6 @@ class WebSocketFlow(flow.Flow):
def __repr__(self):
return "<WebSocketFlow ({} messages)>".format(len(self.messages))
@property
def client_key(self):
return websockets.get_client_key(self.handshake_flow.request.headers)
@property
def client_protocol(self):
return websockets.get_protocol(self.handshake_flow.request.headers)
@property
def client_extensions(self):
return websockets.get_extensions(self.handshake_flow.request.headers)
@property
def server_accept(self):
return websockets.get_server_accept(self.handshake_flow.response.headers)
@property
def server_protocol(self):
return websockets.get_protocol(self.handshake_flow.response.headers)
@property
def server_extensions(self):
return websockets.get_extensions(self.handshake_flow.response.headers)
def message_info(self, message: WebSocketMessage) -> str:
return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format(
type=message.type,

View File

@ -1,9 +1,7 @@
import os
import pytest
from unittest import mock
from mitmproxy.test import tflow
from mitmproxy.test import tutils
from mitmproxy import io
from mitmproxy import exceptions
@ -49,14 +47,13 @@ class TestClientPlayback:
cp.tick()
assert cp.current_thread is None
def test_configure(self):
def test_configure(self, tmpdir):
cp = clientplayback.ClientPlayback()
with taddons.context() as tctx:
with tutils.tmpdir() as td:
path = os.path.join(td, "flows")
tdump(path, [tflow.tflow()])
tctx.configure(cp, client_replay=[path])
tctx.configure(cp, client_replay=[])
tctx.configure(cp)
with pytest.raises(exceptions.OptionsError):
tctx.configure(cp, client_replay=["nonexistent"])
path = str(tmpdir.join("flows"))
tdump(path, [tflow.tflow()])
tctx.configure(cp, client_replay=[path])
tctx.configure(cp, client_replay=[])
tctx.configure(cp)
with pytest.raises(exceptions.OptionsError):
tctx.configure(cp, client_replay=["nonexistent"])

View File

@ -0,0 +1,39 @@
import io
from mitmproxy import http
from mitmproxy.addons import disable_h2c
from mitmproxy.net.http import http1
from mitmproxy.exceptions import Kill
from mitmproxy.test import tflow
from mitmproxy.test import taddons
class TestDisableH2CleartextUpgrade:
def test_upgrade(self):
with taddons.context() as tctx:
a = disable_h2c.DisableH2C()
tctx.configure(a)
f = tflow.tflow()
f.request.headers['upgrade'] = 'h2c'
f.request.headers['connection'] = 'foo'
f.request.headers['http2-settings'] = 'bar'
a.request(f)
assert 'upgrade' not in f.request.headers
assert 'connection' not in f.request.headers
assert 'http2-settings' not in f.request.headers
def test_prior_knowledge(self):
with taddons.context() as tctx:
a = disable_h2c.DisableH2C()
tctx.configure(a)
b = io.BytesIO(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
f = tflow.tflow()
f.request = http.HTTPRequest.wrap(http1.read_request(b))
f.reply.handle()
f.intercept()
a.request(f)
assert not f.killable
assert f.reply.value == Kill

View File

@ -1,17 +0,0 @@
from mitmproxy.addons import disable_h2c_upgrade
from mitmproxy.test import tflow
class TestTermLog:
def test_simple(self):
a = disable_h2c_upgrade.DisableH2CleartextUpgrade()
f = tflow.tflow()
f.request.headers['upgrade'] = 'h2c'
f.request.headers['connection'] = 'foo'
f.request.headers['http2-settings'] = 'bar'
a.request(f)
assert 'upgrade' not in f.request.headers
assert 'connection' not in f.request.headers
assert 'http2-settings' not in f.request.headers

View File

@ -173,3 +173,4 @@ def test_handlers():
f2 = tflow.tflow(client_conn=f.client_conn)
up.requestheaders(f2)
assert not f2.response
assert f2.metadata["proxyauth"] == ('test', 'test')

View File

@ -68,13 +68,12 @@ class TestParseCommand:
with pytest.raises(ValueError):
script.parse_command(" ")
def test_no_script_file(self):
def test_no_script_file(self, tmpdir):
with pytest.raises(Exception, match="not found"):
script.parse_command("notfound")
with tutils.tmpdir() as dir:
with pytest.raises(Exception, match="Not a file"):
script.parse_command(dir)
with pytest.raises(Exception, match="Not a file"):
script.parse_command(str(tmpdir))
def test_parse_args(self):
with utils.chdir(tutils.test_data.dirname):
@ -128,21 +127,19 @@ class TestScript:
recf = sc.ns.call_log[0]
assert recf[1] == "request"
def test_reload(self):
def test_reload(self, tmpdir):
with taddons.context() as tctx:
with tutils.tmpdir():
with open("foo.py", "w"):
pass
sc = script.Script("foo.py")
tctx.configure(sc)
for _ in range(100):
with open("foo.py", "a") as f:
f.write(".")
sc.tick()
time.sleep(0.1)
if tctx.master.event_log:
return
raise AssertionError("Change event not detected.")
f = tmpdir.join("foo.py")
f.ensure(file=True)
sc = script.Script(str(f))
tctx.configure(sc)
for _ in range(100):
f.write(".")
sc.tick()
time.sleep(0.1)
if tctx.master.event_log:
return
raise AssertionError("Change event not detected.")
def test_exception(self):
with taddons.context() as tctx:

View File

@ -1,10 +1,8 @@
import os
import urllib
import pytest
from mitmproxy.test import tutils
from mitmproxy.test import tflow
from mitmproxy.test import taddons
from mitmproxy.test import tflow
import mitmproxy.test.tutils
from mitmproxy.addons import serverplayback
@ -19,15 +17,14 @@ def tdump(path, flows):
w.add(i)
def test_config():
def test_config(tmpdir):
s = serverplayback.ServerPlayback()
with tutils.tmpdir() as p:
with taddons.context() as tctx:
fpath = os.path.join(p, "flows")
tdump(fpath, [tflow.tflow(resp=True)])
tctx.configure(s, server_replay=[fpath])
with pytest.raises(exceptions.OptionsError):
tctx.configure(s, server_replay=[p])
with taddons.context() as tctx:
fpath = str(tmpdir.join("flows"))
tdump(fpath, [tflow.tflow(resp=True)])
tctx.configure(s, server_replay=[fpath])
with pytest.raises(exceptions.OptionsError):
tctx.configure(s, server_replay=[str(tmpdir)])
def test_tick():

View File

@ -1,9 +1,7 @@
import os.path
import pytest
from mitmproxy.test import tflow
from mitmproxy.test import tutils
from mitmproxy.test import taddons
from mitmproxy.test import tflow
from mitmproxy import io
from mitmproxy import exceptions
@ -11,19 +9,17 @@ from mitmproxy import options
from mitmproxy.addons import streamfile
def test_configure():
def test_configure(tmpdir):
sa = streamfile.StreamFile()
with taddons.context(options=options.Options()) as tctx:
with tutils.tmpdir() as tdir:
p = os.path.join(tdir, "foo")
with pytest.raises(exceptions.OptionsError):
tctx.configure(sa, streamfile=tdir)
with pytest.raises(Exception, match="Invalid filter"):
tctx.configure(sa, streamfile=p, filtstr="~~")
tctx.configure(sa, filtstr="foo")
assert sa.filt
tctx.configure(sa, filtstr=None)
assert not sa.filt
with pytest.raises(exceptions.OptionsError):
tctx.configure(sa, streamfile=str(tmpdir))
with pytest.raises(Exception, match="Invalid filter"):
tctx.configure(sa, streamfile=str(tmpdir.join("foo")), filtstr="~~")
tctx.configure(sa, filtstr="foo")
assert sa.filt
tctx.configure(sa, filtstr=None)
assert not sa.filt
def rd(p):
@ -31,36 +27,34 @@ def rd(p):
return list(x.stream())
def test_tcp():
def test_tcp(tmpdir):
sa = streamfile.StreamFile()
with taddons.context() as tctx:
with tutils.tmpdir() as tdir:
p = os.path.join(tdir, "foo")
tctx.configure(sa, streamfile=p)
p = str(tmpdir.join("foo"))
tctx.configure(sa, streamfile=p)
tt = tflow.ttcpflow()
sa.tcp_start(tt)
sa.tcp_end(tt)
tctx.configure(sa, streamfile=None)
assert rd(p)
tt = tflow.ttcpflow()
sa.tcp_start(tt)
sa.tcp_end(tt)
tctx.configure(sa, streamfile=None)
assert rd(p)
def test_simple():
def test_simple(tmpdir):
sa = streamfile.StreamFile()
with taddons.context() as tctx:
with tutils.tmpdir() as tdir:
p = os.path.join(tdir, "foo")
p = str(tmpdir.join("foo"))
tctx.configure(sa, streamfile=p)
tctx.configure(sa, streamfile=p)
f = tflow.tflow(resp=True)
sa.request(f)
sa.response(f)
tctx.configure(sa, streamfile=None)
assert rd(p)[0].response
f = tflow.tflow(resp=True)
sa.request(f)
sa.response(f)
tctx.configure(sa, streamfile=None)
assert rd(p)[0].response
tctx.configure(sa, streamfile="+" + p)
f = tflow.tflow()
sa.request(f)
tctx.configure(sa, streamfile=None)
assert not rd(p)[1].response
tctx.configure(sa, streamfile="+" + p)
f = tflow.tflow()
sa.request(f)
tctx.configure(sa, streamfile=None)
assert not rd(p)[1].response

View File

@ -11,8 +11,8 @@ from OpenSSL import SSL
from mitmproxy import certs
from mitmproxy.net import tcp
from mitmproxy.test import tutils
from mitmproxy import exceptions
from mitmproxy.test import tutils
from . import tservers
from ...conftest import requires_alpn
@ -783,25 +783,24 @@ class TestSSLKeyLogger(tservers.ServerTestBase):
cipher_list="AES256-SHA"
)
def test_log(self):
def test_log(self, tmpdir):
testval = b"echo!\n"
_logfun = tcp.log_ssl_key
with tutils.tmpdir() as d:
logfile = os.path.join(d, "foo", "bar", "logfile")
tcp.log_ssl_key = tcp.SSLKeyLogger(logfile)
logfile = str(tmpdir.join("foo", "bar", "logfile"))
tcp.log_ssl_key = tcp.SSLKeyLogger(logfile)
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl()
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
c.finish()
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl()
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
c.finish()
tcp.log_ssl_key.close()
with open(logfile, "rb") as f:
assert f.read().count(b"CLIENT_RANDOM") == 2
tcp.log_ssl_key.close()
with open(logfile, "rb") as f:
assert f.read().count(b"CLIENT_RANDOM") == 2
tcp.log_ssl_key = _logfun

View File

@ -34,118 +34,106 @@ from mitmproxy.test import tutils
class TestCertStore:
def test_create_explicit(self):
with tutils.tmpdir() as d:
ca = certs.CertStore.from_store(d, "test")
assert ca.get_cert(b"foo", [])
def test_create_explicit(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test")
assert ca.get_cert(b"foo", [])
ca2 = certs.CertStore.from_store(d, "test")
assert ca2.get_cert(b"foo", [])
ca2 = certs.CertStore.from_store(str(tmpdir), "test")
assert ca2.get_cert(b"foo", [])
assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
def test_create_no_common_name(self):
with tutils.tmpdir() as d:
ca = certs.CertStore.from_store(d, "test")
assert ca.get_cert(None, [])[0].cn is None
def test_create_no_common_name(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test")
assert ca.get_cert(None, [])[0].cn is None
def test_create_tmp(self):
with tutils.tmpdir() as d:
ca = certs.CertStore.from_store(d, "test")
assert ca.get_cert(b"foo.com", [])
assert ca.get_cert(b"foo.com", [])
assert ca.get_cert(b"*.foo.com", [])
def test_create_tmp(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test")
assert ca.get_cert(b"foo.com", [])
assert ca.get_cert(b"foo.com", [])
assert ca.get_cert(b"*.foo.com", [])
r = ca.get_cert(b"*.foo.com", [])
assert r[1] == ca.default_privatekey
r = ca.get_cert(b"*.foo.com", [])
assert r[1] == ca.default_privatekey
def test_sans(self):
with tutils.tmpdir() as d:
ca = certs.CertStore.from_store(d, "test")
c1 = ca.get_cert(b"foo.com", [b"*.bar.com"])
ca.get_cert(b"foo.bar.com", [])
# assert c1 == c2
c3 = ca.get_cert(b"bar.com", [])
assert not c1 == c3
def test_sans(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test")
c1 = ca.get_cert(b"foo.com", [b"*.bar.com"])
ca.get_cert(b"foo.bar.com", [])
# assert c1 == c2
c3 = ca.get_cert(b"bar.com", [])
assert not c1 == c3
def test_sans_change(self):
with tutils.tmpdir() as d:
ca = certs.CertStore.from_store(d, "test")
ca.get_cert(b"foo.com", [b"*.bar.com"])
cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"])
assert b"*.baz.com" in cert.altnames
def test_sans_change(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test")
ca.get_cert(b"foo.com", [b"*.bar.com"])
cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"])
assert b"*.baz.com" in cert.altnames
def test_expire(self):
with tutils.tmpdir() as d:
ca = certs.CertStore.from_store(d, "test")
ca.STORE_CAP = 3
ca.get_cert(b"one.com", [])
ca.get_cert(b"two.com", [])
ca.get_cert(b"three.com", [])
def test_expire(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test")
ca.STORE_CAP = 3
ca.get_cert(b"one.com", [])
ca.get_cert(b"two.com", [])
ca.get_cert(b"three.com", [])
assert (b"one.com", ()) in ca.certs
assert (b"two.com", ()) in ca.certs
assert (b"three.com", ()) in ca.certs
assert (b"one.com", ()) in ca.certs
assert (b"two.com", ()) in ca.certs
assert (b"three.com", ()) in ca.certs
ca.get_cert(b"one.com", [])
ca.get_cert(b"one.com", [])
assert (b"one.com", ()) in ca.certs
assert (b"two.com", ()) in ca.certs
assert (b"three.com", ()) in ca.certs
assert (b"one.com", ()) in ca.certs
assert (b"two.com", ()) in ca.certs
assert (b"three.com", ()) in ca.certs
ca.get_cert(b"four.com", [])
ca.get_cert(b"four.com", [])
assert (b"one.com", ()) not in ca.certs
assert (b"two.com", ()) in ca.certs
assert (b"three.com", ()) in ca.certs
assert (b"four.com", ()) in ca.certs
assert (b"one.com", ()) not in ca.certs
assert (b"two.com", ()) in ca.certs
assert (b"three.com", ()) in ca.certs
assert (b"four.com", ()) in ca.certs
def test_overrides(self):
with tutils.tmpdir() as d:
ca1 = certs.CertStore.from_store(os.path.join(d, "ca1"), "test")
ca2 = certs.CertStore.from_store(os.path.join(d, "ca2"), "test")
assert not ca1.default_ca.get_serial_number(
) == ca2.default_ca.get_serial_number()
def test_overrides(self, tmpdir):
ca1 = certs.CertStore.from_store(str(tmpdir.join("ca1")), "test")
ca2 = certs.CertStore.from_store(str(tmpdir.join("ca2")), "test")
assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
dcp = os.path.join(d, "dc")
f = open(dcp, "wb")
f.write(dc[0].to_pem())
f.close()
ca1.add_cert_file(b"foo.com", dcp)
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
dcp = tmpdir.join("dc")
dcp.write(dc[0].to_pem())
ca1.add_cert_file(b"foo.com", str(dcp))
ret = ca1.get_cert(b"foo.com", [])
assert ret[0].serial == dc[0].serial
ret = ca1.get_cert(b"foo.com", [])
assert ret[0].serial == dc[0].serial
def test_create_dhparams(self):
with tutils.tmpdir() as d:
filename = os.path.join(d, "dhparam.pem")
certs.CertStore.load_dhparam(filename)
assert os.path.exists(filename)
def test_create_dhparams(self, tmpdir):
filename = str(tmpdir.join("dhparam.pem"))
certs.CertStore.load_dhparam(filename)
assert os.path.exists(filename)
class TestDummyCert:
def test_with_ca(self):
with tutils.tmpdir() as d:
ca = certs.CertStore.from_store(d, "test")
r = certs.dummy_cert(
ca.default_privatekey,
ca.default_ca,
b"foo.com",
[b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"]
)
assert r.cn == b"foo.com"
assert r.altnames == [b'one.com', b'two.com', b'*.three.com']
def test_with_ca(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test")
r = certs.dummy_cert(
ca.default_privatekey,
ca.default_ca,
b"foo.com",
[b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"]
)
assert r.cn == b"foo.com"
assert r.altnames == [b'one.com', b'two.com', b'*.three.com']
r = certs.dummy_cert(
ca.default_privatekey,
ca.default_ca,
None,
[]
)
assert r.cn is None
assert r.altnames == []
r = certs.dummy_cert(
ca.default_privatekey,
ca.default_ca,
None,
[]
)
assert r.cn is None
assert r.altnames == []
class TestSSLCert:

View File

@ -1,5 +1,4 @@
import json
import os
import shlex
import pytest
@ -142,30 +141,26 @@ class TestHARDump:
with pytest.raises(ScriptError):
tscript("complex/har_dump.py")
def test_simple(self):
with tutils.tmpdir() as tdir:
path = os.path.join(tdir, "somefile")
def test_simple(self, tmpdir):
path = str(tmpdir.join("somefile"))
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
m.addons.invoke(m, "response", self.flow())
m.addons.remove(sc)
with open(path, "r") as inp:
har = json.load(inp)
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
m.addons.invoke(m, "response", self.flow())
m.addons.remove(sc)
with open(path, "r") as inp:
har = json.load(inp)
assert len(har["log"]["entries"]) == 1
def test_base64(self):
with tutils.tmpdir() as tdir:
path = os.path.join(tdir, "somefile")
def test_base64(self, tmpdir):
path = str(tmpdir.join("somefile"))
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
m.addons.invoke(m, "response", self.flow(resp_content=b"foo" + b"\xFF" * 10))
m.addons.remove(sc)
with open(path, "r") as inp:
har = json.load(inp)
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
m.addons.invoke(m, "response", self.flow(resp_content=b"foo" + b"\xFF" * 10))
m.addons.remove(sc)
with open(path, "r") as inp:
har = json.load(inp)
assert har["log"]["entries"][0]["response"]["content"]["encoding"] == "base64"
def test_format_cookies(self):
@ -187,7 +182,7 @@ class TestHARDump:
f = format_cookies([("n", "v", CA([("expires", "Mon, 24-Aug-2037 00:00:00 GMT")]))])[0]
assert f['expires']
def test_binary(self):
def test_binary(self, tmpdir):
f = self.flow()
f.request.method = "POST"
@ -196,14 +191,12 @@ class TestHARDump:
f.response.headers["random-junk"] = bytes(range(256))
f.response.content = bytes(range(256))
with tutils.tmpdir() as tdir:
path = os.path.join(tdir, "somefile")
path = str(tmpdir.join("somefile"))
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
m.addons.invoke(m, "response", f)
m.addons.remove(sc)
with open(path, "r") as inp:
har = json.load(inp)
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
m.addons.invoke(m, "response", f)
m.addons.remove(sc)
with open(path, "r") as inp:
har = json.load(inp)
assert len(har["log"]["entries"]) == 1

View File

@ -1,5 +1,4 @@
import copy
import os
import pytest
import typing
import argparse
@ -7,7 +6,6 @@ import argparse
from mitmproxy import options
from mitmproxy import optmanager
from mitmproxy import exceptions
from mitmproxy.test import tutils
class TO(optmanager.OptManager):
@ -238,25 +236,24 @@ def test_serialize_defaults():
assert o.serialize(None, defaults=True)
def test_saving():
def test_saving(tmpdir):
o = TD2()
o.three = "set"
with tutils.tmpdir() as tdir:
dst = os.path.join(tdir, "conf")
o.save(dst, defaults=True)
dst = str(tmpdir.join("conf"))
o.save(dst, defaults=True)
o2 = TD2()
o2.load_paths(dst)
o2.three = "foo"
o2.save(dst, defaults=True)
o2 = TD2()
o2.load_paths(dst)
o2.three = "foo"
o2.save(dst, defaults=True)
o.load_paths(dst)
assert o.three == "foo"
with open(dst, 'a') as f:
f.write("foobar: '123'")
with pytest.raises(exceptions.OptionsError, matches=''):
o.load_paths(dst)
assert o.three == "foo"
with open(dst, 'a') as f:
f.write("foobar: '123'")
with pytest.raises(exceptions.OptionsError, matches=''):
o.load_paths(dst)
def test_merge():

View File

@ -26,10 +26,12 @@ class Container(StateObject):
def __init__(self):
self.child = None
self.children = None
self.dictionary = None
_stateobject_attributes = dict(
child=Child,
children=List[Child],
dictionary=dict,
)
@classmethod
@ -62,12 +64,30 @@ def test_container_list():
a.children = [Child(42), Child(44)]
assert a.get_state() == {
"child": None,
"children": [{"x": 42}, {"x": 44}]
"children": [{"x": 42}, {"x": 44}],
"dictionary": None,
}
copy = a.copy()
assert len(copy.children) == 2
assert copy.children is not a.children
assert copy.children[0] is not a.children[0]
assert Container.from_state(a.get_state())
def test_container_dict():
a = Container()
a.dictionary = dict()
a.dictionary['foo'] = 'bar'
a.dictionary['bar'] = Child(44)
assert a.get_state() == {
"child": None,
"children": None,
"dictionary": {'bar': {'x': 44}, 'foo': 'bar'},
}
copy = a.copy()
assert len(copy.dictionary) == 2
assert copy.dictionary is not a.dictionary
assert copy.dictionary['bar'] is not a.dictionary['bar']
def test_too_much_state():

View File

@ -1,5 +1,7 @@
import io
import pytest
from mitmproxy.contrib import tnetstring
from mitmproxy import flowfilter
from mitmproxy.test import tflow
@ -14,8 +16,6 @@ class TestWebSocketFlow:
b = f2.get_state()
del a["id"]
del b["id"]
del a["handshake_flow"]["id"]
del b["handshake_flow"]["id"]
assert a == b
assert not f == f2
assert f is not f2
@ -60,3 +60,14 @@ class TestWebSocketFlow:
assert 'WebSocketFlow' in repr(f)
assert 'binary message: ' in repr(f.messages[0])
assert 'text message: ' in repr(f.messages[1])
def test_serialize(self):
b = io.BytesIO()
d = tflow.twebsocketflow().get_state()
tnetstring.dump(d, b)
assert b.getvalue()
b = io.BytesIO()
d = tflow.twebsocketflow().handshake_flow.get_state()
tnetstring.dump(d, b)
assert b.getvalue()

View File

@ -1,4 +1,3 @@
import os
import pytest
from unittest import mock
@ -9,7 +8,6 @@ from mitmproxy import controller
from mitmproxy import options
from mitmproxy.tools import dump
from mitmproxy.test import tutils
from .. import tservers
@ -19,18 +17,17 @@ class TestDumpMaster(tservers.MasterTest):
m = dump.DumpMaster(o, proxy.DummyServer(), with_termlog=False, with_dumper=False)
return m
def test_read(self):
with tutils.tmpdir() as t:
p = os.path.join(t, "read")
self.flowfile(p)
self.dummy_cycle(
self.mkmaster(None, rfile=p),
1, b"",
)
with pytest.raises(exceptions.OptionsError):
self.mkmaster(None, rfile="/nonexistent")
with pytest.raises(exceptions.OptionsError):
self.mkmaster(None, rfile="test_dump.py")
def test_read(self, tmpdir):
p = str(tmpdir.join("read"))
self.flowfile(p)
self.dummy_cycle(
self.mkmaster(None, rfile=p),
1, b"",
)
with pytest.raises(exceptions.OptionsError):
self.mkmaster(None, rfile="/nonexistent")
with pytest.raises(exceptions.OptionsError):
self.mkmaster(None, rfile="test_dump.py")
def test_has_error(self):
m = self.mkmaster(None)

View File

@ -1,11 +1,8 @@
import os
import pytest
from pathod import language
from pathod.language import base, exceptions
from mitmproxy.test import tutils
def parse_request(s):
return language.parse_pathoc(s).next()
@ -137,24 +134,22 @@ class TestTokValueFile:
v = base.TokValue.parseString("<path")[0]
assert v.path == "path"
def test_access_control(self):
def test_access_control(self, tmpdir):
v = base.TokValue.parseString("<path")[0]
with tutils.tmpdir() as t:
p = os.path.join(t, "path")
with open(p, "wb") as f:
f.write(b"x" * 10000)
f = tmpdir.join("path")
f.write(b"x" * 10000)
assert v.get_generator(language.Settings(staticdir=t))
assert v.get_generator(language.Settings(staticdir=str(tmpdir)))
v = base.TokValue.parseString("<path2")[0]
with pytest.raises(exceptions.FileAccessDenied):
v.get_generator(language.Settings(staticdir=t))
with pytest.raises(Exception, match="access disabled"):
v.get_generator(language.Settings())
v = base.TokValue.parseString("<path2")[0]
with pytest.raises(exceptions.FileAccessDenied):
v.get_generator(language.Settings(staticdir=str(tmpdir)))
with pytest.raises(Exception, match="access disabled"):
v.get_generator(language.Settings())
v = base.TokValue.parseString("</outside")[0]
with pytest.raises(Exception, match="outside"):
v.get_generator(language.Settings(staticdir=t))
v = base.TokValue.parseString("</outside")[0]
with pytest.raises(Exception, match="outside"):
v.get_generator(language.Settings(staticdir=str(tmpdir)))
def test_spec(self):
v = base.TokValue.parseString("<'one two'")[0]

View File

@ -1,7 +1,4 @@
import os
from pathod.language import generators
from mitmproxy.test import tutils
def test_randomgenerator():
@ -15,23 +12,20 @@ def test_randomgenerator():
assert len(g[1000:1001]) == 0
def test_filegenerator():
with tutils.tmpdir() as t:
path = os.path.join(t, "foo")
f = open(path, "wb")
f.write(b"x" * 10000)
f.close()
g = generators.FileGenerator(path)
assert len(g) == 10000
assert g[0] == b"x"
assert g[-1] == b"x"
assert g[0:5] == b"xxxxx"
assert len(g[1:10]) == 9
assert len(g[10000:10001]) == 0
assert repr(g)
# remove all references to FileGenerator instance to close the file
# handle.
del g
def test_filegenerator(tmpdir):
f = tmpdir.join("foo")
f.write(b"x" * 10000)
g = generators.FileGenerator(str(f))
assert len(g) == 10000
assert g[0] == b"x"
assert g[-1] == b"x"
assert g[0:5] == b"xxxxx"
assert len(g[1:10]) == 9
assert len(g[10000:10001]) == 0
assert repr(g)
# remove all references to FileGenerator instance to close the file
# handle.
del g
def test_transform_generator():

View File

@ -1,67 +0,0 @@
jest.unmock('../../ducks/flows')
jest.unmock('../../ducks/flowView')
jest.unmock('../../ducks/utils/view')
jest.unmock('../../ducks/utils/list')
jest.unmock('./tutils')
import { createStore } from './tutils'
import flows, * as flowActions from '../../ducks/flows'
import flowView, * as flowViewActions from '../../ducks/flowView'
function testStore() {
let store = createStore({
flows,
flowView
})
for (let i of [1, 2, 3, 4]) {
store.dispatch(
flowActions.addFlow({ id: i })
)
}
return store
}
describe('select relative', () => {
function testSelect(start, relative, result) {
const store = testStore()
store.dispatch(flowActions.select(start))
expect(store.getState().flows.selected).toEqual(start ? [start] : [])
store.dispatch(flowViewActions.selectRelative(relative))
expect(store.getState().flows.selected).toEqual([result])
}
describe('previous', () => {
it('should select the previous flow', () => {
testSelect(3, -1, 2)
})
it('should not changed when first flow is selected', () => {
testSelect(1, -1, 1)
})
it('should select first flow if no flow is selected', () => {
testSelect(undefined, -1, 1)
})
})
describe('next', () => {
it('should select the next flow', () => {
testSelect(2, 1, 3)
})
it('should not changed when last flow is selected', () => {
testSelect(4, 1, 4)
})
it('should select last flow if no flow is selected', () => {
testSelect(undefined, 1, 4)
})
})
})

View File

@ -1,13 +1,14 @@
jest.unmock('../../ducks/flows');
import reduceFlows, * as flowActions from '../../ducks/flows'
import * as storeActions from '../../ducks/utils/store'
describe('select flow', () => {
let state = reduceFlows(undefined, {})
for (let i of [1, 2, 3, 4]) {
state = reduceFlows(state, flowActions.addFlow({ id: i }))
state = reduceFlows(state, storeActions.add({ id: i }))
}
it('should be possible to select a single flow', () => {

View File

@ -8,7 +8,8 @@ import reducer, {
setContentViewDescription,
setShowFullContent,
setContent,
updateEdit
updateEdit,
stopEdit
} from '../../../ducks/ui/flow'
import { select, updateFlow } from '../../../ducks/flows'
@ -65,12 +66,12 @@ describe('flow reducer', () => {
it('should not change the state when a flow is updated which is not selected', () => {
let modifiedFlow = {id: 1}
let updatedFlow = {id: 0}
expect(reducer({modifiedFlow}, updateFlow(updatedFlow)).modifiedFlow).toEqual(modifiedFlow)
expect(reducer({modifiedFlow}, stopEdit(updatedFlow, modifiedFlow)).modifiedFlow).toEqual(modifiedFlow)
})
it('should stop editing when the selected flow is updated', () => {
it('should stop editing when the selected flow is updated', () => {
let modifiedFlow = {id: 1}
let updatedFlow = {id: 1}
expect(reducer({modifiedFlow}, updateFlow(updatedFlow)).modifiedFlow).toBeFalsy()
expect(reducer({modifiedFlow}, stopEdit(updatedFlow, modifiedFlow)).modifiedFlow).toBeFalsy()
})
})

View File

@ -1,64 +0,0 @@
jest.unmock('lodash')
jest.unmock('../../../ducks/utils/list')
import reduce, * as list from '../../../ducks/utils/list'
import _ from 'lodash'
describe('list reduce', () => {
it('should add item', () => {
const state = createState([
{ id: 1 },
{ id: 2 }
])
const result = createState([
{ id: 1 },
{ id: 2 },
{ id: 3 }
])
expect(reduce(state, list.add({ id: 3 }))).toEqual(result)
})
it('should update item', () => {
const state = createState([
{ id: 1, val: 1 },
{ id: 2, val: 2 }
])
const result = createState([
{ id: 1, val: 1 },
{ id: 2, val: 3 }
])
expect(reduce(state, list.update({ id: 2, val: 3 }))).toEqual(result)
})
it('should remove item', () => {
const state = createState([
{ id: 1 },
{ id: 2 }
])
const result = createState([
{ id: 1 }
])
result.byId[2] = result.indexOf[2] = null
expect(reduce(state, list.remove(2))).toEqual(result)
})
it('should replace all items', () => {
const state = createState([
{ id: 1 },
{ id: 2 }
])
const result = createState([
{ id: 1 }
])
expect(reduce(state, list.receive([{ id: 1 }]))).toEqual(result)
})
})
function createState(items) {
return {
data: items,
byId: _.fromPairs(items.map((item, index) => [item.id, item])),
indexOf: _.fromPairs(items.map((item, index) => [item.id, index]))
}
}

View File

@ -1,156 +0,0 @@
jest.unmock('../../../ducks/utils/view')
jest.unmock('lodash')
import reduce, * as view from '../../../ducks/utils/view'
import _ from 'lodash'
describe('view reduce', () => {
it('should filter items', () => {
const state = createState([
{ id: 1 },
{ id: 2 }
])
const result = createState([
{ id: 1 }
])
expect(reduce(state, view.updateFilter(state.data, item => item.id === 1))).toEqual(result)
})
it('should sort items', () => {
const state = createState([
{ id: 1 },
{ id: 2 }
])
const result = createState([
{ id: 2 },
{ id: 1 }
])
expect(reduce(state, view.updateSort((a, b) => b.id - a.id))).toEqual(result)
})
it('should add item', () => {
const state = createState([
{ id: 1 },
{ id: 2 }
])
const result = createState([
{ id: 1 },
{ id: 2 },
{ id: 3 }
])
expect(reduce(state, view.add({ id: 3 }))).toEqual(result)
})
it('should add item in place', () => {
const state = createState([
{ id: 1 }
])
const result = createState([
{ id: 3 },
{ id: 1 }
])
expect(reduce(state, view.add({ id: 3 }, undefined, (a, b) => b.id - a.id))).toEqual(result)
})
it('should filter added item', () => {
const state = createState([
{ id: 1 }
])
const result = createState([
{ id: 1 }
])
expect(reduce(state, view.add({ id: 3 }, i => i.id === 1))).toEqual(result)
})
it('should update item', () => {
const state = createState([
{ id: 1, val: 1 },
{ id: 2, val: 2 },
{ id: 3, val: 3 }
])
const result = createState([
{ id: 1, val: 1 },
{ id: 2, val: 3 },
{ id: 3, val: 3 }
])
expect(reduce(state, view.update({ id: 2, val: 3 }))).toEqual(result)
})
it('should sort updated item', () => {
const state = createState([
{ id: 1, val: 1 },
{ id: 2, val: 2 }
])
const result = createState([
{ id: 2, val: 3 },
{ id: 1, val: 1 }
])
expect(reduce(state, view.update({ id: 2, val: 3 }, undefined, (a, b) => b.id - a.id))).toEqual(result)
})
it('should filter updated item', () => {
const state = createState([
{ id: 1, val: 1 },
{ id: 2, val: 2 }
])
const result = createState([
{ id: 1, val: 1 }
])
result.indexOf[2] = null
expect(reduce(state, view.update({ id: 2, val: 3 }, i => i.id === i.val))).toEqual(result)
})
it('should remove item', () => {
const state = createState([
{ id: 1 },
{ id: 2 }
])
const result = createState([
{ id: 1 }
])
result.indexOf[2] = null
expect(reduce(state, view.remove(2))).toEqual(result)
})
it('should replace items', () => {
const state = createState([
{ id: 1 },
{ id: 2 }
])
const result = createState([
{ id: 1 }
])
expect(reduce(state, view.receive([{ id: 1 }]))).toEqual(result)
})
it('should sort received items', () => {
const state = createState([
{ id: 1 },
{ id: 2 }
])
const result = createState([
{ id: 2 },
{ id: 1 }
])
expect(reduce(state, view.receive([{ id: 1 }, { id: 2 }], undefined, (a, b) => b.id - a.id))).toEqual(result)
})
it('should filter received', () => {
const state = createState([
{ id: 1 },
{ id: 2 }
])
const result = createState([
{ id: 1 }
])
expect(reduce(state, view.receive([{ id: 1 }, { id: 2 }], i => i.id === 1))).toEqual(result)
})
})
function createState(items) {
return {
data: items,
indexOf: _.fromPairs(items.map((item, index) => [item.id, index]))
}
}

View File

@ -60,7 +60,7 @@ export default function reducer(state = defaultState, action) {
// There is no explicit "stop edit" event.
// We stop editing when we receive an update for
// the currently edited flow from the server
if (action.data.id === state.modifiedFlow.id) {
if (action.flow.id === state.modifiedFlow.id) {
return {
...state,
modifiedFlow: false,
@ -145,9 +145,10 @@ export function setShowFullContent() {
}
export function setContent(content){
return { type: SET_CONTENT, content}
return { type: SET_CONTENT, content }
}
export function stopEdit(flow, modifiedFlow) {
return flowsActions.update(flow, getDiff(flow, modifiedFlow))
let diff = getDiff(flow, modifiedFlow)
return {type: flowsActions.UPDATE, flow, diff }
}