mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 14:58:38 +00:00
Merge remote-tracking branch 'origin/master' into pr-2120
Conflicts: test/mitmproxy/addons/test_replace.py
This commit is contained in:
commit
05e11547f5
@ -10,9 +10,6 @@ git:
|
||||
|
||||
matrix:
|
||||
fast_finish: true
|
||||
allow_failures:
|
||||
- language: node_js
|
||||
node_js: "node"
|
||||
include:
|
||||
- python: 3.5
|
||||
env: TOXENV=lint
|
||||
|
@ -4,7 +4,7 @@ from mitmproxy.addons import check_alpn
|
||||
from mitmproxy.addons import check_ca
|
||||
from mitmproxy.addons import clientplayback
|
||||
from mitmproxy.addons import core_option_validation
|
||||
from mitmproxy.addons import disable_h2c_upgrade
|
||||
from mitmproxy.addons import disable_h2c
|
||||
from mitmproxy.addons import onboarding
|
||||
from mitmproxy.addons import proxyauth
|
||||
from mitmproxy.addons import replace
|
||||
@ -26,7 +26,7 @@ def default_addons():
|
||||
check_alpn.CheckALPN(),
|
||||
check_ca.CheckCA(),
|
||||
clientplayback.ClientPlayback(),
|
||||
disable_h2c_upgrade.DisableH2CleartextUpgrade(),
|
||||
disable_h2c.DisableH2C(),
|
||||
onboarding.Onboarding(),
|
||||
proxyauth.ProxyAuth(),
|
||||
replace.Replace(),
|
||||
|
41
mitmproxy/addons/disable_h2c.py
Normal file
41
mitmproxy/addons/disable_h2c.py
Normal file
@ -0,0 +1,41 @@
|
||||
import mitmproxy
|
||||
|
||||
|
||||
class DisableH2C:
|
||||
|
||||
"""
|
||||
We currently only support HTTP/2 over a TLS connection.
|
||||
|
||||
Some clients try to upgrade a connection from HTTP/1.1 to h2c. We need to
|
||||
remove those headers to avoid protocol errors if one endpoints suddenly
|
||||
starts sending HTTP/2 frames.
|
||||
|
||||
Some clients might use HTTP/2 Prior Knowledge to directly initiate a session
|
||||
by sending the connection preface. We just kill those flows.
|
||||
"""
|
||||
|
||||
def configure(self, options, updated):
|
||||
pass
|
||||
|
||||
def process_flow(self, f):
|
||||
if f.request.headers.get('upgrade', '') == 'h2c':
|
||||
mitmproxy.ctx.log.warn("HTTP/2 cleartext connections (h2c upgrade requests) are currently not supported.")
|
||||
del f.request.headers['upgrade']
|
||||
if 'connection' in f.request.headers:
|
||||
del f.request.headers['connection']
|
||||
if 'http2-settings' in f.request.headers:
|
||||
del f.request.headers['http2-settings']
|
||||
|
||||
is_connection_preface = (
|
||||
f.request.method == 'PRI' and
|
||||
f.request.path == '*' and
|
||||
f.request.http_version == 'HTTP/2.0'
|
||||
)
|
||||
if is_connection_preface:
|
||||
f.kill()
|
||||
mitmproxy.ctx.log.warn("Initiating HTTP/2 connections with prior knowledge are currently not supported.")
|
||||
|
||||
# Handlers
|
||||
|
||||
def request(self, f):
|
||||
self.process_flow(f)
|
@ -1,21 +0,0 @@
|
||||
class DisableH2CleartextUpgrade:
|
||||
|
||||
"""
|
||||
We currently only support HTTP/2 over a TLS connection. Some clients try
|
||||
to upgrade a connection from HTTP/1.1 to h2c, so we need to remove those
|
||||
headers to avoid protocol errors if one endpoints suddenly starts sending
|
||||
HTTP/2 frames.
|
||||
"""
|
||||
|
||||
def process_flow(self, f):
|
||||
if f.request.headers.get('upgrade', '') == 'h2c':
|
||||
del f.request.headers['upgrade']
|
||||
if 'connection' in f.request.headers:
|
||||
del f.request.headers['connection']
|
||||
if 'http2-settings' in f.request.headers:
|
||||
del f.request.headers['http2-settings']
|
||||
|
||||
# Handlers
|
||||
|
||||
def request(self, f):
|
||||
self.process_flow(f)
|
@ -1,7 +1,7 @@
|
||||
import binascii
|
||||
import weakref
|
||||
from typing import Optional
|
||||
from typing import Set # noqa
|
||||
from typing import MutableMapping # noqa
|
||||
from typing import Tuple
|
||||
|
||||
import passlib.apache
|
||||
@ -46,7 +46,7 @@ class ProxyAuth:
|
||||
self.htpasswd = None
|
||||
self.singleuser = None
|
||||
self.mode = None
|
||||
self.authenticated = weakref.WeakSet() # type: Set[connections.ClientConnection]
|
||||
self.authenticated = weakref.WeakKeyDictionary() # type: MutableMapping[connections.ClientConnection, Tuple[str, str]]
|
||||
"""Contains all connections that are permanently authenticated after an HTTP CONNECT"""
|
||||
|
||||
def enabled(self) -> bool:
|
||||
@ -153,11 +153,12 @@ class ProxyAuth:
|
||||
def http_connect(self, f: http.HTTPFlow) -> None:
|
||||
if self.enabled():
|
||||
if self.authenticate(f):
|
||||
self.authenticated.add(f.client_conn)
|
||||
self.authenticated[f.client_conn] = f.metadata["proxyauth"]
|
||||
|
||||
def requestheaders(self, f: http.HTTPFlow) -> None:
|
||||
if self.enabled():
|
||||
# Is this connection authenticated by a previous HTTP CONNECT?
|
||||
if f.client_conn in self.authenticated:
|
||||
f.metadata["proxyauth"] = self.authenticated[f.client_conn]
|
||||
return
|
||||
self.authenticate(f)
|
||||
|
@ -78,7 +78,7 @@ class Flow(stateobject.StateObject):
|
||||
self._backup = None # type: typing.Optional[Flow]
|
||||
self.reply = None # type: typing.Optional[controller.Reply]
|
||||
self.marked = False # type: bool
|
||||
self.metadata = dict() # type: typing.Dict[str, str]
|
||||
self.metadata = dict() # type: typing.Dict[str, typing.Any]
|
||||
|
||||
_stateobject_attributes = dict(
|
||||
id=str,
|
||||
|
@ -140,8 +140,8 @@ class WebSocketLayer(base.Layer):
|
||||
|
||||
def __call__(self):
|
||||
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
|
||||
self.flow.metadata['websocket_handshake'] = self.handshake_flow
|
||||
self.handshake_flow.metadata['websocket_flow'] = self.flow
|
||||
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
|
||||
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
|
||||
self.channel.ask("websocket_start", self.flow)
|
||||
|
||||
client = self.client_conn.connection
|
||||
|
@ -39,6 +39,14 @@ class StateObject(serializable.Serializable):
|
||||
state[attr] = val.get_state()
|
||||
elif _is_list(cls):
|
||||
state[attr] = [x.get_state() for x in val]
|
||||
elif isinstance(val, dict):
|
||||
s = {}
|
||||
for k, v in val.items():
|
||||
if hasattr(v, "get_state"):
|
||||
s[k] = v.get_state()
|
||||
else:
|
||||
s[k] = v
|
||||
state[attr] = s
|
||||
else:
|
||||
state[attr] = val
|
||||
return state
|
||||
|
@ -70,6 +70,7 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
|
||||
handshake_flow.response = resp
|
||||
|
||||
f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow)
|
||||
handshake_flow.metadata['websocket_flow'] = f
|
||||
|
||||
if messages is True:
|
||||
messages = [
|
||||
|
@ -1,9 +1,5 @@
|
||||
from io import BytesIO
|
||||
import tempfile
|
||||
import os
|
||||
import time
|
||||
import shutil
|
||||
from contextlib import contextmanager
|
||||
from io import BytesIO
|
||||
|
||||
from mitmproxy.utils import data
|
||||
from mitmproxy.net import tcp
|
||||
@ -13,18 +9,6 @@ from mitmproxy.net import http
|
||||
test_data = data.Data(__name__).push("../../test/")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tmpdir(*args, **kwargs):
|
||||
orig_workdir = os.getcwd()
|
||||
temp_workdir = tempfile.mkdtemp(*args, **kwargs)
|
||||
os.chdir(temp_workdir)
|
||||
|
||||
yield temp_workdir
|
||||
|
||||
os.chdir(orig_workdir)
|
||||
shutil.rmtree(temp_workdir)
|
||||
|
||||
|
||||
def treader(bytes):
|
||||
"""
|
||||
Construct a tcp.Read object from bytes.
|
||||
|
@ -2,7 +2,6 @@ import time
|
||||
from typing import List, Optional
|
||||
|
||||
from mitmproxy import flow
|
||||
from mitmproxy.http import HTTPFlow
|
||||
from mitmproxy.net import websockets
|
||||
from mitmproxy.types import serializable
|
||||
from mitmproxy.utils import strutils
|
||||
@ -44,6 +43,22 @@ class WebSocketFlow(flow.Flow):
|
||||
self.close_code = '(status code missing)'
|
||||
self.close_message = '(message missing)'
|
||||
self.close_reason = 'unknown status code'
|
||||
|
||||
if handshake_flow:
|
||||
self.client_key = websockets.get_client_key(handshake_flow.request.headers)
|
||||
self.client_protocol = websockets.get_protocol(handshake_flow.request.headers)
|
||||
self.client_extensions = websockets.get_extensions(handshake_flow.request.headers)
|
||||
self.server_accept = websockets.get_server_accept(handshake_flow.response.headers)
|
||||
self.server_protocol = websockets.get_protocol(handshake_flow.response.headers)
|
||||
self.server_extensions = websockets.get_extensions(handshake_flow.response.headers)
|
||||
else:
|
||||
self.client_key = ''
|
||||
self.client_protocol = ''
|
||||
self.client_extensions = ''
|
||||
self.server_accept = ''
|
||||
self.server_protocol = ''
|
||||
self.server_extensions = ''
|
||||
|
||||
self.handshake_flow = handshake_flow
|
||||
|
||||
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
|
||||
@ -53,7 +68,15 @@ class WebSocketFlow(flow.Flow):
|
||||
close_code=str,
|
||||
close_message=str,
|
||||
close_reason=str,
|
||||
handshake_flow=HTTPFlow,
|
||||
client_key=str,
|
||||
client_protocol=str,
|
||||
client_extensions=str,
|
||||
server_accept=str,
|
||||
server_protocol=str,
|
||||
server_extensions=str,
|
||||
# Do not include handshake_flow, to prevent recursive serialization!
|
||||
# Since mitmproxy-console currently only displays HTTPFlows,
|
||||
# dumping the handshake_flow will include the WebSocketFlow too.
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -65,30 +88,6 @@ class WebSocketFlow(flow.Flow):
|
||||
def __repr__(self):
|
||||
return "<WebSocketFlow ({} messages)>".format(len(self.messages))
|
||||
|
||||
@property
|
||||
def client_key(self):
|
||||
return websockets.get_client_key(self.handshake_flow.request.headers)
|
||||
|
||||
@property
|
||||
def client_protocol(self):
|
||||
return websockets.get_protocol(self.handshake_flow.request.headers)
|
||||
|
||||
@property
|
||||
def client_extensions(self):
|
||||
return websockets.get_extensions(self.handshake_flow.request.headers)
|
||||
|
||||
@property
|
||||
def server_accept(self):
|
||||
return websockets.get_server_accept(self.handshake_flow.response.headers)
|
||||
|
||||
@property
|
||||
def server_protocol(self):
|
||||
return websockets.get_protocol(self.handshake_flow.response.headers)
|
||||
|
||||
@property
|
||||
def server_extensions(self):
|
||||
return websockets.get_extensions(self.handshake_flow.response.headers)
|
||||
|
||||
def message_info(self, message: WebSocketMessage) -> str:
|
||||
return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format(
|
||||
type=message.type,
|
||||
|
@ -1,9 +1,7 @@
|
||||
import os
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from mitmproxy.test import tflow
|
||||
from mitmproxy.test import tutils
|
||||
from mitmproxy import io
|
||||
from mitmproxy import exceptions
|
||||
|
||||
@ -49,14 +47,13 @@ class TestClientPlayback:
|
||||
cp.tick()
|
||||
assert cp.current_thread is None
|
||||
|
||||
def test_configure(self):
|
||||
def test_configure(self, tmpdir):
|
||||
cp = clientplayback.ClientPlayback()
|
||||
with taddons.context() as tctx:
|
||||
with tutils.tmpdir() as td:
|
||||
path = os.path.join(td, "flows")
|
||||
tdump(path, [tflow.tflow()])
|
||||
tctx.configure(cp, client_replay=[path])
|
||||
tctx.configure(cp, client_replay=[])
|
||||
tctx.configure(cp)
|
||||
with pytest.raises(exceptions.OptionsError):
|
||||
tctx.configure(cp, client_replay=["nonexistent"])
|
||||
path = str(tmpdir.join("flows"))
|
||||
tdump(path, [tflow.tflow()])
|
||||
tctx.configure(cp, client_replay=[path])
|
||||
tctx.configure(cp, client_replay=[])
|
||||
tctx.configure(cp)
|
||||
with pytest.raises(exceptions.OptionsError):
|
||||
tctx.configure(cp, client_replay=["nonexistent"])
|
||||
|
39
test/mitmproxy/addons/test_disable_h2c.py
Normal file
39
test/mitmproxy/addons/test_disable_h2c.py
Normal file
@ -0,0 +1,39 @@
|
||||
import io
|
||||
from mitmproxy import http
|
||||
from mitmproxy.addons import disable_h2c
|
||||
from mitmproxy.net.http import http1
|
||||
from mitmproxy.exceptions import Kill
|
||||
from mitmproxy.test import tflow
|
||||
from mitmproxy.test import taddons
|
||||
|
||||
|
||||
class TestDisableH2CleartextUpgrade:
|
||||
def test_upgrade(self):
|
||||
with taddons.context() as tctx:
|
||||
a = disable_h2c.DisableH2C()
|
||||
tctx.configure(a)
|
||||
|
||||
f = tflow.tflow()
|
||||
f.request.headers['upgrade'] = 'h2c'
|
||||
f.request.headers['connection'] = 'foo'
|
||||
f.request.headers['http2-settings'] = 'bar'
|
||||
|
||||
a.request(f)
|
||||
assert 'upgrade' not in f.request.headers
|
||||
assert 'connection' not in f.request.headers
|
||||
assert 'http2-settings' not in f.request.headers
|
||||
|
||||
def test_prior_knowledge(self):
|
||||
with taddons.context() as tctx:
|
||||
a = disable_h2c.DisableH2C()
|
||||
tctx.configure(a)
|
||||
|
||||
b = io.BytesIO(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
|
||||
f = tflow.tflow()
|
||||
f.request = http.HTTPRequest.wrap(http1.read_request(b))
|
||||
f.reply.handle()
|
||||
f.intercept()
|
||||
|
||||
a.request(f)
|
||||
assert not f.killable
|
||||
assert f.reply.value == Kill
|
@ -1,17 +0,0 @@
|
||||
from mitmproxy.addons import disable_h2c_upgrade
|
||||
from mitmproxy.test import tflow
|
||||
|
||||
|
||||
class TestTermLog:
|
||||
def test_simple(self):
|
||||
a = disable_h2c_upgrade.DisableH2CleartextUpgrade()
|
||||
|
||||
f = tflow.tflow()
|
||||
f.request.headers['upgrade'] = 'h2c'
|
||||
f.request.headers['connection'] = 'foo'
|
||||
f.request.headers['http2-settings'] = 'bar'
|
||||
|
||||
a.request(f)
|
||||
assert 'upgrade' not in f.request.headers
|
||||
assert 'connection' not in f.request.headers
|
||||
assert 'http2-settings' not in f.request.headers
|
@ -173,3 +173,4 @@ def test_handlers():
|
||||
f2 = tflow.tflow(client_conn=f.client_conn)
|
||||
up.requestheaders(f2)
|
||||
assert not f2.response
|
||||
assert f2.metadata["proxyauth"] == ('test', 'test')
|
||||
|
@ -68,13 +68,12 @@ class TestParseCommand:
|
||||
with pytest.raises(ValueError):
|
||||
script.parse_command(" ")
|
||||
|
||||
def test_no_script_file(self):
|
||||
def test_no_script_file(self, tmpdir):
|
||||
with pytest.raises(Exception, match="not found"):
|
||||
script.parse_command("notfound")
|
||||
|
||||
with tutils.tmpdir() as dir:
|
||||
with pytest.raises(Exception, match="Not a file"):
|
||||
script.parse_command(dir)
|
||||
with pytest.raises(Exception, match="Not a file"):
|
||||
script.parse_command(str(tmpdir))
|
||||
|
||||
def test_parse_args(self):
|
||||
with utils.chdir(tutils.test_data.dirname):
|
||||
@ -128,21 +127,19 @@ class TestScript:
|
||||
recf = sc.ns.call_log[0]
|
||||
assert recf[1] == "request"
|
||||
|
||||
def test_reload(self):
|
||||
def test_reload(self, tmpdir):
|
||||
with taddons.context() as tctx:
|
||||
with tutils.tmpdir():
|
||||
with open("foo.py", "w"):
|
||||
pass
|
||||
sc = script.Script("foo.py")
|
||||
tctx.configure(sc)
|
||||
for _ in range(100):
|
||||
with open("foo.py", "a") as f:
|
||||
f.write(".")
|
||||
sc.tick()
|
||||
time.sleep(0.1)
|
||||
if tctx.master.event_log:
|
||||
return
|
||||
raise AssertionError("Change event not detected.")
|
||||
f = tmpdir.join("foo.py")
|
||||
f.ensure(file=True)
|
||||
sc = script.Script(str(f))
|
||||
tctx.configure(sc)
|
||||
for _ in range(100):
|
||||
f.write(".")
|
||||
sc.tick()
|
||||
time.sleep(0.1)
|
||||
if tctx.master.event_log:
|
||||
return
|
||||
raise AssertionError("Change event not detected.")
|
||||
|
||||
def test_exception(self):
|
||||
with taddons.context() as tctx:
|
||||
|
@ -1,10 +1,8 @@
|
||||
import os
|
||||
import urllib
|
||||
import pytest
|
||||
|
||||
from mitmproxy.test import tutils
|
||||
from mitmproxy.test import tflow
|
||||
from mitmproxy.test import taddons
|
||||
from mitmproxy.test import tflow
|
||||
|
||||
import mitmproxy.test.tutils
|
||||
from mitmproxy.addons import serverplayback
|
||||
@ -19,15 +17,14 @@ def tdump(path, flows):
|
||||
w.add(i)
|
||||
|
||||
|
||||
def test_config():
|
||||
def test_config(tmpdir):
|
||||
s = serverplayback.ServerPlayback()
|
||||
with tutils.tmpdir() as p:
|
||||
with taddons.context() as tctx:
|
||||
fpath = os.path.join(p, "flows")
|
||||
tdump(fpath, [tflow.tflow(resp=True)])
|
||||
tctx.configure(s, server_replay=[fpath])
|
||||
with pytest.raises(exceptions.OptionsError):
|
||||
tctx.configure(s, server_replay=[p])
|
||||
with taddons.context() as tctx:
|
||||
fpath = str(tmpdir.join("flows"))
|
||||
tdump(fpath, [tflow.tflow(resp=True)])
|
||||
tctx.configure(s, server_replay=[fpath])
|
||||
with pytest.raises(exceptions.OptionsError):
|
||||
tctx.configure(s, server_replay=[str(tmpdir)])
|
||||
|
||||
|
||||
def test_tick():
|
||||
|
@ -1,9 +1,7 @@
|
||||
import os.path
|
||||
import pytest
|
||||
|
||||
from mitmproxy.test import tflow
|
||||
from mitmproxy.test import tutils
|
||||
from mitmproxy.test import taddons
|
||||
from mitmproxy.test import tflow
|
||||
|
||||
from mitmproxy import io
|
||||
from mitmproxy import exceptions
|
||||
@ -11,19 +9,17 @@ from mitmproxy import options
|
||||
from mitmproxy.addons import streamfile
|
||||
|
||||
|
||||
def test_configure():
|
||||
def test_configure(tmpdir):
|
||||
sa = streamfile.StreamFile()
|
||||
with taddons.context(options=options.Options()) as tctx:
|
||||
with tutils.tmpdir() as tdir:
|
||||
p = os.path.join(tdir, "foo")
|
||||
with pytest.raises(exceptions.OptionsError):
|
||||
tctx.configure(sa, streamfile=tdir)
|
||||
with pytest.raises(Exception, match="Invalid filter"):
|
||||
tctx.configure(sa, streamfile=p, filtstr="~~")
|
||||
tctx.configure(sa, filtstr="foo")
|
||||
assert sa.filt
|
||||
tctx.configure(sa, filtstr=None)
|
||||
assert not sa.filt
|
||||
with pytest.raises(exceptions.OptionsError):
|
||||
tctx.configure(sa, streamfile=str(tmpdir))
|
||||
with pytest.raises(Exception, match="Invalid filter"):
|
||||
tctx.configure(sa, streamfile=str(tmpdir.join("foo")), filtstr="~~")
|
||||
tctx.configure(sa, filtstr="foo")
|
||||
assert sa.filt
|
||||
tctx.configure(sa, filtstr=None)
|
||||
assert not sa.filt
|
||||
|
||||
|
||||
def rd(p):
|
||||
@ -31,36 +27,34 @@ def rd(p):
|
||||
return list(x.stream())
|
||||
|
||||
|
||||
def test_tcp():
|
||||
def test_tcp(tmpdir):
|
||||
sa = streamfile.StreamFile()
|
||||
with taddons.context() as tctx:
|
||||
with tutils.tmpdir() as tdir:
|
||||
p = os.path.join(tdir, "foo")
|
||||
tctx.configure(sa, streamfile=p)
|
||||
p = str(tmpdir.join("foo"))
|
||||
tctx.configure(sa, streamfile=p)
|
||||
|
||||
tt = tflow.ttcpflow()
|
||||
sa.tcp_start(tt)
|
||||
sa.tcp_end(tt)
|
||||
tctx.configure(sa, streamfile=None)
|
||||
assert rd(p)
|
||||
tt = tflow.ttcpflow()
|
||||
sa.tcp_start(tt)
|
||||
sa.tcp_end(tt)
|
||||
tctx.configure(sa, streamfile=None)
|
||||
assert rd(p)
|
||||
|
||||
|
||||
def test_simple():
|
||||
def test_simple(tmpdir):
|
||||
sa = streamfile.StreamFile()
|
||||
with taddons.context() as tctx:
|
||||
with tutils.tmpdir() as tdir:
|
||||
p = os.path.join(tdir, "foo")
|
||||
p = str(tmpdir.join("foo"))
|
||||
|
||||
tctx.configure(sa, streamfile=p)
|
||||
tctx.configure(sa, streamfile=p)
|
||||
|
||||
f = tflow.tflow(resp=True)
|
||||
sa.request(f)
|
||||
sa.response(f)
|
||||
tctx.configure(sa, streamfile=None)
|
||||
assert rd(p)[0].response
|
||||
f = tflow.tflow(resp=True)
|
||||
sa.request(f)
|
||||
sa.response(f)
|
||||
tctx.configure(sa, streamfile=None)
|
||||
assert rd(p)[0].response
|
||||
|
||||
tctx.configure(sa, streamfile="+" + p)
|
||||
f = tflow.tflow()
|
||||
sa.request(f)
|
||||
tctx.configure(sa, streamfile=None)
|
||||
assert not rd(p)[1].response
|
||||
tctx.configure(sa, streamfile="+" + p)
|
||||
f = tflow.tflow()
|
||||
sa.request(f)
|
||||
tctx.configure(sa, streamfile=None)
|
||||
assert not rd(p)[1].response
|
||||
|
@ -11,8 +11,8 @@ from OpenSSL import SSL
|
||||
|
||||
from mitmproxy import certs
|
||||
from mitmproxy.net import tcp
|
||||
from mitmproxy.test import tutils
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy.test import tutils
|
||||
|
||||
from . import tservers
|
||||
from ...conftest import requires_alpn
|
||||
@ -783,25 +783,24 @@ class TestSSLKeyLogger(tservers.ServerTestBase):
|
||||
cipher_list="AES256-SHA"
|
||||
)
|
||||
|
||||
def test_log(self):
|
||||
def test_log(self, tmpdir):
|
||||
testval = b"echo!\n"
|
||||
_logfun = tcp.log_ssl_key
|
||||
|
||||
with tutils.tmpdir() as d:
|
||||
logfile = os.path.join(d, "foo", "bar", "logfile")
|
||||
tcp.log_ssl_key = tcp.SSLKeyLogger(logfile)
|
||||
logfile = str(tmpdir.join("foo", "bar", "logfile"))
|
||||
tcp.log_ssl_key = tcp.SSLKeyLogger(logfile)
|
||||
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_ssl()
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
c.finish()
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_ssl()
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
c.finish()
|
||||
|
||||
tcp.log_ssl_key.close()
|
||||
with open(logfile, "rb") as f:
|
||||
assert f.read().count(b"CLIENT_RANDOM") == 2
|
||||
tcp.log_ssl_key.close()
|
||||
with open(logfile, "rb") as f:
|
||||
assert f.read().count(b"CLIENT_RANDOM") == 2
|
||||
|
||||
tcp.log_ssl_key = _logfun
|
||||
|
||||
|
@ -34,118 +34,106 @@ from mitmproxy.test import tutils
|
||||
|
||||
class TestCertStore:
|
||||
|
||||
def test_create_explicit(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca = certs.CertStore.from_store(d, "test")
|
||||
assert ca.get_cert(b"foo", [])
|
||||
def test_create_explicit(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test")
|
||||
assert ca.get_cert(b"foo", [])
|
||||
|
||||
ca2 = certs.CertStore.from_store(d, "test")
|
||||
assert ca2.get_cert(b"foo", [])
|
||||
ca2 = certs.CertStore.from_store(str(tmpdir), "test")
|
||||
assert ca2.get_cert(b"foo", [])
|
||||
|
||||
assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
|
||||
assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
|
||||
|
||||
def test_create_no_common_name(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca = certs.CertStore.from_store(d, "test")
|
||||
assert ca.get_cert(None, [])[0].cn is None
|
||||
def test_create_no_common_name(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test")
|
||||
assert ca.get_cert(None, [])[0].cn is None
|
||||
|
||||
def test_create_tmp(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca = certs.CertStore.from_store(d, "test")
|
||||
assert ca.get_cert(b"foo.com", [])
|
||||
assert ca.get_cert(b"foo.com", [])
|
||||
assert ca.get_cert(b"*.foo.com", [])
|
||||
def test_create_tmp(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test")
|
||||
assert ca.get_cert(b"foo.com", [])
|
||||
assert ca.get_cert(b"foo.com", [])
|
||||
assert ca.get_cert(b"*.foo.com", [])
|
||||
|
||||
r = ca.get_cert(b"*.foo.com", [])
|
||||
assert r[1] == ca.default_privatekey
|
||||
r = ca.get_cert(b"*.foo.com", [])
|
||||
assert r[1] == ca.default_privatekey
|
||||
|
||||
def test_sans(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca = certs.CertStore.from_store(d, "test")
|
||||
c1 = ca.get_cert(b"foo.com", [b"*.bar.com"])
|
||||
ca.get_cert(b"foo.bar.com", [])
|
||||
# assert c1 == c2
|
||||
c3 = ca.get_cert(b"bar.com", [])
|
||||
assert not c1 == c3
|
||||
def test_sans(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test")
|
||||
c1 = ca.get_cert(b"foo.com", [b"*.bar.com"])
|
||||
ca.get_cert(b"foo.bar.com", [])
|
||||
# assert c1 == c2
|
||||
c3 = ca.get_cert(b"bar.com", [])
|
||||
assert not c1 == c3
|
||||
|
||||
def test_sans_change(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca = certs.CertStore.from_store(d, "test")
|
||||
ca.get_cert(b"foo.com", [b"*.bar.com"])
|
||||
cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"])
|
||||
assert b"*.baz.com" in cert.altnames
|
||||
def test_sans_change(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test")
|
||||
ca.get_cert(b"foo.com", [b"*.bar.com"])
|
||||
cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"])
|
||||
assert b"*.baz.com" in cert.altnames
|
||||
|
||||
def test_expire(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca = certs.CertStore.from_store(d, "test")
|
||||
ca.STORE_CAP = 3
|
||||
ca.get_cert(b"one.com", [])
|
||||
ca.get_cert(b"two.com", [])
|
||||
ca.get_cert(b"three.com", [])
|
||||
def test_expire(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test")
|
||||
ca.STORE_CAP = 3
|
||||
ca.get_cert(b"one.com", [])
|
||||
ca.get_cert(b"two.com", [])
|
||||
ca.get_cert(b"three.com", [])
|
||||
|
||||
assert (b"one.com", ()) in ca.certs
|
||||
assert (b"two.com", ()) in ca.certs
|
||||
assert (b"three.com", ()) in ca.certs
|
||||
assert (b"one.com", ()) in ca.certs
|
||||
assert (b"two.com", ()) in ca.certs
|
||||
assert (b"three.com", ()) in ca.certs
|
||||
|
||||
ca.get_cert(b"one.com", [])
|
||||
ca.get_cert(b"one.com", [])
|
||||
|
||||
assert (b"one.com", ()) in ca.certs
|
||||
assert (b"two.com", ()) in ca.certs
|
||||
assert (b"three.com", ()) in ca.certs
|
||||
assert (b"one.com", ()) in ca.certs
|
||||
assert (b"two.com", ()) in ca.certs
|
||||
assert (b"three.com", ()) in ca.certs
|
||||
|
||||
ca.get_cert(b"four.com", [])
|
||||
ca.get_cert(b"four.com", [])
|
||||
|
||||
assert (b"one.com", ()) not in ca.certs
|
||||
assert (b"two.com", ()) in ca.certs
|
||||
assert (b"three.com", ()) in ca.certs
|
||||
assert (b"four.com", ()) in ca.certs
|
||||
assert (b"one.com", ()) not in ca.certs
|
||||
assert (b"two.com", ()) in ca.certs
|
||||
assert (b"three.com", ()) in ca.certs
|
||||
assert (b"four.com", ()) in ca.certs
|
||||
|
||||
def test_overrides(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca1 = certs.CertStore.from_store(os.path.join(d, "ca1"), "test")
|
||||
ca2 = certs.CertStore.from_store(os.path.join(d, "ca2"), "test")
|
||||
assert not ca1.default_ca.get_serial_number(
|
||||
) == ca2.default_ca.get_serial_number()
|
||||
def test_overrides(self, tmpdir):
|
||||
ca1 = certs.CertStore.from_store(str(tmpdir.join("ca1")), "test")
|
||||
ca2 = certs.CertStore.from_store(str(tmpdir.join("ca2")), "test")
|
||||
assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
|
||||
|
||||
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
|
||||
dcp = os.path.join(d, "dc")
|
||||
f = open(dcp, "wb")
|
||||
f.write(dc[0].to_pem())
|
||||
f.close()
|
||||
ca1.add_cert_file(b"foo.com", dcp)
|
||||
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
|
||||
dcp = tmpdir.join("dc")
|
||||
dcp.write(dc[0].to_pem())
|
||||
ca1.add_cert_file(b"foo.com", str(dcp))
|
||||
|
||||
ret = ca1.get_cert(b"foo.com", [])
|
||||
assert ret[0].serial == dc[0].serial
|
||||
ret = ca1.get_cert(b"foo.com", [])
|
||||
assert ret[0].serial == dc[0].serial
|
||||
|
||||
def test_create_dhparams(self):
|
||||
with tutils.tmpdir() as d:
|
||||
filename = os.path.join(d, "dhparam.pem")
|
||||
certs.CertStore.load_dhparam(filename)
|
||||
assert os.path.exists(filename)
|
||||
def test_create_dhparams(self, tmpdir):
|
||||
filename = str(tmpdir.join("dhparam.pem"))
|
||||
certs.CertStore.load_dhparam(filename)
|
||||
assert os.path.exists(filename)
|
||||
|
||||
|
||||
class TestDummyCert:
|
||||
|
||||
def test_with_ca(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca = certs.CertStore.from_store(d, "test")
|
||||
r = certs.dummy_cert(
|
||||
ca.default_privatekey,
|
||||
ca.default_ca,
|
||||
b"foo.com",
|
||||
[b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"]
|
||||
)
|
||||
assert r.cn == b"foo.com"
|
||||
assert r.altnames == [b'one.com', b'two.com', b'*.three.com']
|
||||
def test_with_ca(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test")
|
||||
r = certs.dummy_cert(
|
||||
ca.default_privatekey,
|
||||
ca.default_ca,
|
||||
b"foo.com",
|
||||
[b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"]
|
||||
)
|
||||
assert r.cn == b"foo.com"
|
||||
assert r.altnames == [b'one.com', b'two.com', b'*.three.com']
|
||||
|
||||
r = certs.dummy_cert(
|
||||
ca.default_privatekey,
|
||||
ca.default_ca,
|
||||
None,
|
||||
[]
|
||||
)
|
||||
assert r.cn is None
|
||||
assert r.altnames == []
|
||||
r = certs.dummy_cert(
|
||||
ca.default_privatekey,
|
||||
ca.default_ca,
|
||||
None,
|
||||
[]
|
||||
)
|
||||
assert r.cn is None
|
||||
assert r.altnames == []
|
||||
|
||||
|
||||
class TestSSLCert:
|
||||
|
@ -1,5 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import pytest
|
||||
|
||||
@ -142,30 +141,26 @@ class TestHARDump:
|
||||
with pytest.raises(ScriptError):
|
||||
tscript("complex/har_dump.py")
|
||||
|
||||
def test_simple(self):
|
||||
with tutils.tmpdir() as tdir:
|
||||
path = os.path.join(tdir, "somefile")
|
||||
def test_simple(self, tmpdir):
|
||||
path = str(tmpdir.join("somefile"))
|
||||
|
||||
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
|
||||
m.addons.invoke(m, "response", self.flow())
|
||||
m.addons.remove(sc)
|
||||
|
||||
with open(path, "r") as inp:
|
||||
har = json.load(inp)
|
||||
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
|
||||
m.addons.invoke(m, "response", self.flow())
|
||||
m.addons.remove(sc)
|
||||
|
||||
with open(path, "r") as inp:
|
||||
har = json.load(inp)
|
||||
assert len(har["log"]["entries"]) == 1
|
||||
|
||||
def test_base64(self):
|
||||
with tutils.tmpdir() as tdir:
|
||||
path = os.path.join(tdir, "somefile")
|
||||
def test_base64(self, tmpdir):
|
||||
path = str(tmpdir.join("somefile"))
|
||||
|
||||
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
|
||||
m.addons.invoke(m, "response", self.flow(resp_content=b"foo" + b"\xFF" * 10))
|
||||
m.addons.remove(sc)
|
||||
|
||||
with open(path, "r") as inp:
|
||||
har = json.load(inp)
|
||||
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
|
||||
m.addons.invoke(m, "response", self.flow(resp_content=b"foo" + b"\xFF" * 10))
|
||||
m.addons.remove(sc)
|
||||
|
||||
with open(path, "r") as inp:
|
||||
har = json.load(inp)
|
||||
assert har["log"]["entries"][0]["response"]["content"]["encoding"] == "base64"
|
||||
|
||||
def test_format_cookies(self):
|
||||
@ -187,7 +182,7 @@ class TestHARDump:
|
||||
f = format_cookies([("n", "v", CA([("expires", "Mon, 24-Aug-2037 00:00:00 GMT")]))])[0]
|
||||
assert f['expires']
|
||||
|
||||
def test_binary(self):
|
||||
def test_binary(self, tmpdir):
|
||||
|
||||
f = self.flow()
|
||||
f.request.method = "POST"
|
||||
@ -196,14 +191,12 @@ class TestHARDump:
|
||||
f.response.headers["random-junk"] = bytes(range(256))
|
||||
f.response.content = bytes(range(256))
|
||||
|
||||
with tutils.tmpdir() as tdir:
|
||||
path = os.path.join(tdir, "somefile")
|
||||
path = str(tmpdir.join("somefile"))
|
||||
|
||||
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
|
||||
m.addons.invoke(m, "response", f)
|
||||
m.addons.remove(sc)
|
||||
|
||||
with open(path, "r") as inp:
|
||||
har = json.load(inp)
|
||||
m, sc = tscript("complex/har_dump.py", shlex.quote(path))
|
||||
m.addons.invoke(m, "response", f)
|
||||
m.addons.remove(sc)
|
||||
|
||||
with open(path, "r") as inp:
|
||||
har = json.load(inp)
|
||||
assert len(har["log"]["entries"]) == 1
|
||||
|
@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import os
|
||||
import pytest
|
||||
import typing
|
||||
import argparse
|
||||
@ -7,7 +6,6 @@ import argparse
|
||||
from mitmproxy import options
|
||||
from mitmproxy import optmanager
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy.test import tutils
|
||||
|
||||
|
||||
class TO(optmanager.OptManager):
|
||||
@ -238,25 +236,24 @@ def test_serialize_defaults():
|
||||
assert o.serialize(None, defaults=True)
|
||||
|
||||
|
||||
def test_saving():
|
||||
def test_saving(tmpdir):
|
||||
o = TD2()
|
||||
o.three = "set"
|
||||
with tutils.tmpdir() as tdir:
|
||||
dst = os.path.join(tdir, "conf")
|
||||
o.save(dst, defaults=True)
|
||||
dst = str(tmpdir.join("conf"))
|
||||
o.save(dst, defaults=True)
|
||||
|
||||
o2 = TD2()
|
||||
o2.load_paths(dst)
|
||||
o2.three = "foo"
|
||||
o2.save(dst, defaults=True)
|
||||
o2 = TD2()
|
||||
o2.load_paths(dst)
|
||||
o2.three = "foo"
|
||||
o2.save(dst, defaults=True)
|
||||
|
||||
o.load_paths(dst)
|
||||
assert o.three == "foo"
|
||||
|
||||
with open(dst, 'a') as f:
|
||||
f.write("foobar: '123'")
|
||||
with pytest.raises(exceptions.OptionsError, matches=''):
|
||||
o.load_paths(dst)
|
||||
assert o.three == "foo"
|
||||
|
||||
with open(dst, 'a') as f:
|
||||
f.write("foobar: '123'")
|
||||
with pytest.raises(exceptions.OptionsError, matches=''):
|
||||
o.load_paths(dst)
|
||||
|
||||
|
||||
def test_merge():
|
||||
|
@ -26,10 +26,12 @@ class Container(StateObject):
|
||||
def __init__(self):
|
||||
self.child = None
|
||||
self.children = None
|
||||
self.dictionary = None
|
||||
|
||||
_stateobject_attributes = dict(
|
||||
child=Child,
|
||||
children=List[Child],
|
||||
dictionary=dict,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -62,12 +64,30 @@ def test_container_list():
|
||||
a.children = [Child(42), Child(44)]
|
||||
assert a.get_state() == {
|
||||
"child": None,
|
||||
"children": [{"x": 42}, {"x": 44}]
|
||||
"children": [{"x": 42}, {"x": 44}],
|
||||
"dictionary": None,
|
||||
}
|
||||
copy = a.copy()
|
||||
assert len(copy.children) == 2
|
||||
assert copy.children is not a.children
|
||||
assert copy.children[0] is not a.children[0]
|
||||
assert Container.from_state(a.get_state())
|
||||
|
||||
|
||||
def test_container_dict():
|
||||
a = Container()
|
||||
a.dictionary = dict()
|
||||
a.dictionary['foo'] = 'bar'
|
||||
a.dictionary['bar'] = Child(44)
|
||||
assert a.get_state() == {
|
||||
"child": None,
|
||||
"children": None,
|
||||
"dictionary": {'bar': {'x': 44}, 'foo': 'bar'},
|
||||
}
|
||||
copy = a.copy()
|
||||
assert len(copy.dictionary) == 2
|
||||
assert copy.dictionary is not a.dictionary
|
||||
assert copy.dictionary['bar'] is not a.dictionary['bar']
|
||||
|
||||
|
||||
def test_too_much_state():
|
||||
|
@ -1,5 +1,7 @@
|
||||
import io
|
||||
import pytest
|
||||
|
||||
from mitmproxy.contrib import tnetstring
|
||||
from mitmproxy import flowfilter
|
||||
from mitmproxy.test import tflow
|
||||
|
||||
@ -14,8 +16,6 @@ class TestWebSocketFlow:
|
||||
b = f2.get_state()
|
||||
del a["id"]
|
||||
del b["id"]
|
||||
del a["handshake_flow"]["id"]
|
||||
del b["handshake_flow"]["id"]
|
||||
assert a == b
|
||||
assert not f == f2
|
||||
assert f is not f2
|
||||
@ -60,3 +60,14 @@ class TestWebSocketFlow:
|
||||
assert 'WebSocketFlow' in repr(f)
|
||||
assert 'binary message: ' in repr(f.messages[0])
|
||||
assert 'text message: ' in repr(f.messages[1])
|
||||
|
||||
def test_serialize(self):
|
||||
b = io.BytesIO()
|
||||
d = tflow.twebsocketflow().get_state()
|
||||
tnetstring.dump(d, b)
|
||||
assert b.getvalue()
|
||||
|
||||
b = io.BytesIO()
|
||||
d = tflow.twebsocketflow().handshake_flow.get_state()
|
||||
tnetstring.dump(d, b)
|
||||
assert b.getvalue()
|
||||
|
@ -1,4 +1,3 @@
|
||||
import os
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
@ -9,7 +8,6 @@ from mitmproxy import controller
|
||||
from mitmproxy import options
|
||||
from mitmproxy.tools import dump
|
||||
|
||||
from mitmproxy.test import tutils
|
||||
from .. import tservers
|
||||
|
||||
|
||||
@ -19,18 +17,17 @@ class TestDumpMaster(tservers.MasterTest):
|
||||
m = dump.DumpMaster(o, proxy.DummyServer(), with_termlog=False, with_dumper=False)
|
||||
return m
|
||||
|
||||
def test_read(self):
|
||||
with tutils.tmpdir() as t:
|
||||
p = os.path.join(t, "read")
|
||||
self.flowfile(p)
|
||||
self.dummy_cycle(
|
||||
self.mkmaster(None, rfile=p),
|
||||
1, b"",
|
||||
)
|
||||
with pytest.raises(exceptions.OptionsError):
|
||||
self.mkmaster(None, rfile="/nonexistent")
|
||||
with pytest.raises(exceptions.OptionsError):
|
||||
self.mkmaster(None, rfile="test_dump.py")
|
||||
def test_read(self, tmpdir):
|
||||
p = str(tmpdir.join("read"))
|
||||
self.flowfile(p)
|
||||
self.dummy_cycle(
|
||||
self.mkmaster(None, rfile=p),
|
||||
1, b"",
|
||||
)
|
||||
with pytest.raises(exceptions.OptionsError):
|
||||
self.mkmaster(None, rfile="/nonexistent")
|
||||
with pytest.raises(exceptions.OptionsError):
|
||||
self.mkmaster(None, rfile="test_dump.py")
|
||||
|
||||
def test_has_error(self):
|
||||
m = self.mkmaster(None)
|
||||
|
@ -1,11 +1,8 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from pathod import language
|
||||
from pathod.language import base, exceptions
|
||||
|
||||
from mitmproxy.test import tutils
|
||||
|
||||
|
||||
def parse_request(s):
|
||||
return language.parse_pathoc(s).next()
|
||||
@ -137,24 +134,22 @@ class TestTokValueFile:
|
||||
v = base.TokValue.parseString("<path")[0]
|
||||
assert v.path == "path"
|
||||
|
||||
def test_access_control(self):
|
||||
def test_access_control(self, tmpdir):
|
||||
v = base.TokValue.parseString("<path")[0]
|
||||
with tutils.tmpdir() as t:
|
||||
p = os.path.join(t, "path")
|
||||
with open(p, "wb") as f:
|
||||
f.write(b"x" * 10000)
|
||||
f = tmpdir.join("path")
|
||||
f.write(b"x" * 10000)
|
||||
|
||||
assert v.get_generator(language.Settings(staticdir=t))
|
||||
assert v.get_generator(language.Settings(staticdir=str(tmpdir)))
|
||||
|
||||
v = base.TokValue.parseString("<path2")[0]
|
||||
with pytest.raises(exceptions.FileAccessDenied):
|
||||
v.get_generator(language.Settings(staticdir=t))
|
||||
with pytest.raises(Exception, match="access disabled"):
|
||||
v.get_generator(language.Settings())
|
||||
v = base.TokValue.parseString("<path2")[0]
|
||||
with pytest.raises(exceptions.FileAccessDenied):
|
||||
v.get_generator(language.Settings(staticdir=str(tmpdir)))
|
||||
with pytest.raises(Exception, match="access disabled"):
|
||||
v.get_generator(language.Settings())
|
||||
|
||||
v = base.TokValue.parseString("</outside")[0]
|
||||
with pytest.raises(Exception, match="outside"):
|
||||
v.get_generator(language.Settings(staticdir=t))
|
||||
v = base.TokValue.parseString("</outside")[0]
|
||||
with pytest.raises(Exception, match="outside"):
|
||||
v.get_generator(language.Settings(staticdir=str(tmpdir)))
|
||||
|
||||
def test_spec(self):
|
||||
v = base.TokValue.parseString("<'one two'")[0]
|
||||
|
@ -1,7 +1,4 @@
|
||||
import os
|
||||
|
||||
from pathod.language import generators
|
||||
from mitmproxy.test import tutils
|
||||
|
||||
|
||||
def test_randomgenerator():
|
||||
@ -15,23 +12,20 @@ def test_randomgenerator():
|
||||
assert len(g[1000:1001]) == 0
|
||||
|
||||
|
||||
def test_filegenerator():
|
||||
with tutils.tmpdir() as t:
|
||||
path = os.path.join(t, "foo")
|
||||
f = open(path, "wb")
|
||||
f.write(b"x" * 10000)
|
||||
f.close()
|
||||
g = generators.FileGenerator(path)
|
||||
assert len(g) == 10000
|
||||
assert g[0] == b"x"
|
||||
assert g[-1] == b"x"
|
||||
assert g[0:5] == b"xxxxx"
|
||||
assert len(g[1:10]) == 9
|
||||
assert len(g[10000:10001]) == 0
|
||||
assert repr(g)
|
||||
# remove all references to FileGenerator instance to close the file
|
||||
# handle.
|
||||
del g
|
||||
def test_filegenerator(tmpdir):
|
||||
f = tmpdir.join("foo")
|
||||
f.write(b"x" * 10000)
|
||||
g = generators.FileGenerator(str(f))
|
||||
assert len(g) == 10000
|
||||
assert g[0] == b"x"
|
||||
assert g[-1] == b"x"
|
||||
assert g[0:5] == b"xxxxx"
|
||||
assert len(g[1:10]) == 9
|
||||
assert len(g[10000:10001]) == 0
|
||||
assert repr(g)
|
||||
# remove all references to FileGenerator instance to close the file
|
||||
# handle.
|
||||
del g
|
||||
|
||||
|
||||
def test_transform_generator():
|
||||
|
@ -1,67 +0,0 @@
|
||||
jest.unmock('../../ducks/flows')
|
||||
jest.unmock('../../ducks/flowView')
|
||||
jest.unmock('../../ducks/utils/view')
|
||||
jest.unmock('../../ducks/utils/list')
|
||||
jest.unmock('./tutils')
|
||||
|
||||
import { createStore } from './tutils'
|
||||
|
||||
import flows, * as flowActions from '../../ducks/flows'
|
||||
import flowView, * as flowViewActions from '../../ducks/flowView'
|
||||
|
||||
|
||||
function testStore() {
|
||||
let store = createStore({
|
||||
flows,
|
||||
flowView
|
||||
})
|
||||
for (let i of [1, 2, 3, 4]) {
|
||||
store.dispatch(
|
||||
flowActions.addFlow({ id: i })
|
||||
)
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
describe('select relative', () => {
|
||||
|
||||
function testSelect(start, relative, result) {
|
||||
const store = testStore()
|
||||
store.dispatch(flowActions.select(start))
|
||||
expect(store.getState().flows.selected).toEqual(start ? [start] : [])
|
||||
store.dispatch(flowViewActions.selectRelative(relative))
|
||||
expect(store.getState().flows.selected).toEqual([result])
|
||||
}
|
||||
|
||||
describe('previous', () => {
|
||||
|
||||
it('should select the previous flow', () => {
|
||||
testSelect(3, -1, 2)
|
||||
})
|
||||
|
||||
it('should not changed when first flow is selected', () => {
|
||||
testSelect(1, -1, 1)
|
||||
})
|
||||
|
||||
it('should select first flow if no flow is selected', () => {
|
||||
testSelect(undefined, -1, 1)
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
describe('next', () => {
|
||||
|
||||
it('should select the next flow', () => {
|
||||
testSelect(2, 1, 3)
|
||||
})
|
||||
|
||||
it('should not changed when last flow is selected', () => {
|
||||
testSelect(4, 1, 4)
|
||||
})
|
||||
|
||||
it('should select last flow if no flow is selected', () => {
|
||||
testSelect(undefined, 1, 4)
|
||||
})
|
||||
|
||||
})
|
||||
})
|
@ -1,13 +1,14 @@
|
||||
jest.unmock('../../ducks/flows');
|
||||
|
||||
import reduceFlows, * as flowActions from '../../ducks/flows'
|
||||
import * as storeActions from '../../ducks/utils/store'
|
||||
|
||||
|
||||
describe('select flow', () => {
|
||||
|
||||
let state = reduceFlows(undefined, {})
|
||||
for (let i of [1, 2, 3, 4]) {
|
||||
state = reduceFlows(state, flowActions.addFlow({ id: i }))
|
||||
state = reduceFlows(state, storeActions.add({ id: i }))
|
||||
}
|
||||
|
||||
it('should be possible to select a single flow', () => {
|
||||
|
@ -8,7 +8,8 @@ import reducer, {
|
||||
setContentViewDescription,
|
||||
setShowFullContent,
|
||||
setContent,
|
||||
updateEdit
|
||||
updateEdit,
|
||||
stopEdit
|
||||
} from '../../../ducks/ui/flow'
|
||||
|
||||
import { select, updateFlow } from '../../../ducks/flows'
|
||||
@ -65,12 +66,12 @@ describe('flow reducer', () => {
|
||||
it('should not change the state when a flow is updated which is not selected', () => {
|
||||
let modifiedFlow = {id: 1}
|
||||
let updatedFlow = {id: 0}
|
||||
expect(reducer({modifiedFlow}, updateFlow(updatedFlow)).modifiedFlow).toEqual(modifiedFlow)
|
||||
expect(reducer({modifiedFlow}, stopEdit(updatedFlow, modifiedFlow)).modifiedFlow).toEqual(modifiedFlow)
|
||||
})
|
||||
|
||||
it('should stop editing when the selected flow is updated', () => {
|
||||
it('should stop editing when the selected flow is updated', () => {
|
||||
let modifiedFlow = {id: 1}
|
||||
let updatedFlow = {id: 1}
|
||||
expect(reducer({modifiedFlow}, updateFlow(updatedFlow)).modifiedFlow).toBeFalsy()
|
||||
expect(reducer({modifiedFlow}, stopEdit(updatedFlow, modifiedFlow)).modifiedFlow).toBeFalsy()
|
||||
})
|
||||
})
|
||||
|
@ -1,64 +0,0 @@
|
||||
jest.unmock('lodash')
|
||||
jest.unmock('../../../ducks/utils/list')
|
||||
|
||||
import reduce, * as list from '../../../ducks/utils/list'
|
||||
import _ from 'lodash'
|
||||
|
||||
describe('list reduce', () => {
|
||||
|
||||
it('should add item', () => {
|
||||
const state = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 },
|
||||
{ id: 3 }
|
||||
])
|
||||
expect(reduce(state, list.add({ id: 3 }))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should update item', () => {
|
||||
const state = createState([
|
||||
{ id: 1, val: 1 },
|
||||
{ id: 2, val: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1, val: 1 },
|
||||
{ id: 2, val: 3 }
|
||||
])
|
||||
expect(reduce(state, list.update({ id: 2, val: 3 }))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should remove item', () => {
|
||||
const state = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1 }
|
||||
])
|
||||
result.byId[2] = result.indexOf[2] = null
|
||||
expect(reduce(state, list.remove(2))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should replace all items', () => {
|
||||
const state = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1 }
|
||||
])
|
||||
expect(reduce(state, list.receive([{ id: 1 }]))).toEqual(result)
|
||||
})
|
||||
})
|
||||
|
||||
function createState(items) {
|
||||
return {
|
||||
data: items,
|
||||
byId: _.fromPairs(items.map((item, index) => [item.id, item])),
|
||||
indexOf: _.fromPairs(items.map((item, index) => [item.id, index]))
|
||||
}
|
||||
}
|
@ -1,156 +0,0 @@
|
||||
jest.unmock('../../../ducks/utils/view')
|
||||
jest.unmock('lodash')
|
||||
|
||||
import reduce, * as view from '../../../ducks/utils/view'
|
||||
import _ from 'lodash'
|
||||
|
||||
describe('view reduce', () => {
|
||||
|
||||
it('should filter items', () => {
|
||||
const state = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1 }
|
||||
])
|
||||
expect(reduce(state, view.updateFilter(state.data, item => item.id === 1))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should sort items', () => {
|
||||
const state = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 2 },
|
||||
{ id: 1 }
|
||||
])
|
||||
expect(reduce(state, view.updateSort((a, b) => b.id - a.id))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should add item', () => {
|
||||
const state = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 },
|
||||
{ id: 3 }
|
||||
])
|
||||
expect(reduce(state, view.add({ id: 3 }))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should add item in place', () => {
|
||||
const state = createState([
|
||||
{ id: 1 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 3 },
|
||||
{ id: 1 }
|
||||
])
|
||||
expect(reduce(state, view.add({ id: 3 }, undefined, (a, b) => b.id - a.id))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should filter added item', () => {
|
||||
const state = createState([
|
||||
{ id: 1 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1 }
|
||||
])
|
||||
expect(reduce(state, view.add({ id: 3 }, i => i.id === 1))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should update item', () => {
|
||||
const state = createState([
|
||||
{ id: 1, val: 1 },
|
||||
{ id: 2, val: 2 },
|
||||
{ id: 3, val: 3 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1, val: 1 },
|
||||
{ id: 2, val: 3 },
|
||||
{ id: 3, val: 3 }
|
||||
])
|
||||
expect(reduce(state, view.update({ id: 2, val: 3 }))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should sort updated item', () => {
|
||||
const state = createState([
|
||||
{ id: 1, val: 1 },
|
||||
{ id: 2, val: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 2, val: 3 },
|
||||
{ id: 1, val: 1 }
|
||||
])
|
||||
expect(reduce(state, view.update({ id: 2, val: 3 }, undefined, (a, b) => b.id - a.id))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should filter updated item', () => {
|
||||
const state = createState([
|
||||
{ id: 1, val: 1 },
|
||||
{ id: 2, val: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1, val: 1 }
|
||||
])
|
||||
result.indexOf[2] = null
|
||||
expect(reduce(state, view.update({ id: 2, val: 3 }, i => i.id === i.val))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should remove item', () => {
|
||||
const state = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1 }
|
||||
])
|
||||
result.indexOf[2] = null
|
||||
expect(reduce(state, view.remove(2))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should replace items', () => {
|
||||
const state = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1 }
|
||||
])
|
||||
expect(reduce(state, view.receive([{ id: 1 }]))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should sort received items', () => {
|
||||
const state = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 2 },
|
||||
{ id: 1 }
|
||||
])
|
||||
expect(reduce(state, view.receive([{ id: 1 }, { id: 2 }], undefined, (a, b) => b.id - a.id))).toEqual(result)
|
||||
})
|
||||
|
||||
it('should filter received', () => {
|
||||
const state = createState([
|
||||
{ id: 1 },
|
||||
{ id: 2 }
|
||||
])
|
||||
const result = createState([
|
||||
{ id: 1 }
|
||||
])
|
||||
expect(reduce(state, view.receive([{ id: 1 }, { id: 2 }], i => i.id === 1))).toEqual(result)
|
||||
})
|
||||
})
|
||||
|
||||
function createState(items) {
|
||||
return {
|
||||
data: items,
|
||||
indexOf: _.fromPairs(items.map((item, index) => [item.id, index]))
|
||||
}
|
||||
}
|
@ -60,7 +60,7 @@ export default function reducer(state = defaultState, action) {
|
||||
// There is no explicit "stop edit" event.
|
||||
// We stop editing when we receive an update for
|
||||
// the currently edited flow from the server
|
||||
if (action.data.id === state.modifiedFlow.id) {
|
||||
if (action.flow.id === state.modifiedFlow.id) {
|
||||
return {
|
||||
...state,
|
||||
modifiedFlow: false,
|
||||
@ -145,9 +145,10 @@ export function setShowFullContent() {
|
||||
}
|
||||
|
||||
export function setContent(content){
|
||||
return { type: SET_CONTENT, content}
|
||||
return { type: SET_CONTENT, content }
|
||||
}
|
||||
|
||||
export function stopEdit(flow, modifiedFlow) {
|
||||
return flowsActions.update(flow, getDiff(flow, modifiedFlow))
|
||||
let diff = getDiff(flow, modifiedFlow)
|
||||
return {type: flowsActions.UPDATE, flow, diff }
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user