diff --git a/examples/README b/examples/README index adfcd0f25..b4dec8e56 100644 --- a/examples/README +++ b/examples/README @@ -1,3 +1,7 @@ +Some inline scripts may require additional dependencies, which can be installed using +`pip install mitmproxy[examples]`. + + # inline script examples add_header.py Simple script that just adds a header to every request. change_upstream_proxy.py Dynamically change the upstream proxy diff --git a/examples/add_header.py b/examples/add_header.py index 0c0593d1f..cf1b53cc8 100644 --- a/examples/add_header.py +++ b/examples/add_header.py @@ -1,2 +1,2 @@ def response(context, flow): - flow.response.headers["newheader"] = ["foo"] + flow.response.headers["newheader"] = "foo" diff --git a/examples/change_upstream_proxy.py b/examples/change_upstream_proxy.py index 8f58e1f21..9c454897a 100644 --- a/examples/change_upstream_proxy.py +++ b/examples/change_upstream_proxy.py @@ -4,7 +4,6 @@ # Usage: mitmdump -U http://default-upstream-proxy.local:8080/ -s change_upstream_proxy.py # # If you want to change the target server, you should modify flow.request.host and flow.request.port -# flow.live.set_server should only be used by inline scripts to change the upstream proxy. def proxy_address(flow): @@ -22,13 +21,4 @@ def request(context, flow): return address = proxy_address(flow) if flow.live: - if flow.request.scheme == "http": - # For a normal HTTP request, we just change the proxy server and we're done! - if address != flow.live.server_conn.address: - flow.live.set_server(address, depth=1) - else: - # If we have CONNECTed (and thereby established "destination state"), the story is - # a bit more complex. Now we don't want to change the top level address (which is - # the connect destination) but the address below that. (Notice the `.via` and depth=2). - if address != flow.live.server_conn.via.address: - flow.live.set_server(address, depth=2) + flow.live.change_upstream_proxy_server(address) \ No newline at end of file diff --git a/examples/flowbasic b/examples/flowbasic index 6663dc468..78b9eff70 100755 --- a/examples/flowbasic +++ b/examples/flowbasic @@ -8,9 +8,8 @@ Note that request and response messages are not automatically replied to, so we need to implement handlers to do this. """ -import os -from libmproxy import flow, proxy -from libmproxy.proxy.server import ProxyServer +from libmproxy import flow +from libmproxy.proxy import ProxyServer, ProxyConfig class MyMaster(flow.FlowMaster): @@ -34,7 +33,7 @@ class MyMaster(flow.FlowMaster): return f -config = proxy.ProxyConfig( +config = ProxyConfig( port=8080, # use ~/.mitmproxy/mitmproxy-ca.pem as default CA file. cadir="~/.mitmproxy/" diff --git a/examples/har_extractor.py b/examples/har_extractor.py index f06efec39..bc784dc47 100644 --- a/examples/har_extractor.py +++ b/examples/har_extractor.py @@ -147,8 +147,8 @@ def response(context, flow): response_body_size = len(flow.response.content) response_body_decoded_size = len(flow.response.get_decoded_content()) response_body_compression = response_body_decoded_size - response_body_size - response_mime_type = flow.response.headers.get_first('Content-Type', '') - response_redirect_url = flow.response.headers.get_first('Location', '') + response_mime_type = flow.response.headers.get('Content-Type', '') + response_redirect_url = flow.response.headers.get('Location', '') entry = HAR.entries( { @@ -201,12 +201,12 @@ def response(context, flow): # Lookup the referer in the page_ref of context.HARLog to point this entries # pageref attribute to the right pages object, then set it as a new # reference to build a reference tree. - elif context.HARLog.get_page_ref(flow.request.headers.get('Referer', (None, ))[0]) is not None: + elif context.HARLog.get_page_ref(flow.request.headers.get('Referer')) is not None: entry['pageref'] = context.HARLog.get_page_ref( - flow.request.headers['Referer'][0] + flow.request.headers['Referer'] ) context.HARLog.set_page_ref( - flow.request.headers['Referer'][0], entry['pageref'] + flow.request.headers['Referer'], entry['pageref'] ) context.HARLog.add(entry) diff --git a/examples/iframe_injector.py b/examples/iframe_injector.py index b2fa2d26f..29de9b63f 100644 --- a/examples/iframe_injector.py +++ b/examples/iframe_injector.py @@ -1,7 +1,7 @@ # Usage: mitmdump -s "iframe_injector.py url" # (this script works best with --anticache) from bs4 import BeautifulSoup -from libmproxy.protocol.http import decoded +from libmproxy.models import decoded def start(context, argv): diff --git a/examples/modify_form.py b/examples/modify_form.py index 37ba2faca..3e9d15c00 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,7 +1,5 @@ - def request(context, flow): - if "application/x-www-form-urlencoded" in flow.request.headers[ - "content-type"]: + if "application/x-www-form-urlencoded" in flow.request.headers.get("content-type", ""): form = flow.request.get_form_urlencoded() form["mitmproxy"] = ["rocks"] flow.request.set_form_urlencoded(form) diff --git a/examples/modify_response_body.py b/examples/modify_response_body.py index 68d3d4abd..a35e1525f 100644 --- a/examples/modify_response_body.py +++ b/examples/modify_response_body.py @@ -1,6 +1,6 @@ # Usage: mitmdump -s "modify_response_body.py mitmproxy bananas" # (this script works best with --anticache) -from libmproxy.protocol.http import decoded +from libmproxy.models import decoded def start(context, argv): diff --git a/examples/read_dumpfile b/examples/read_dumpfile index eb1c93bbf..b329c0e1b 100755 --- a/examples/read_dumpfile +++ b/examples/read_dumpfile @@ -4,7 +4,6 @@ # from libmproxy import flow -import json import pprint import sys diff --git a/examples/redirect_requests.py b/examples/redirect_requests.py index 48512f1bb..ca24c42a8 100644 --- a/examples/redirect_requests.py +++ b/examples/redirect_requests.py @@ -1,10 +1,8 @@ -from libmproxy.protocol.http import HTTPResponse -from netlib.odict import ODictCaseless - """ This example shows two ways to redirect flows to other destinations. """ - +from libmproxy.models import HTTPResponse +from netlib.http import Headers def request(context, flow): # pretty_host(hostheader=True) takes the Host: header of the request into account, @@ -15,7 +13,7 @@ def request(context, flow): if flow.request.pretty_host(hostheader=True).endswith("example.com"): resp = HTTPResponse( [1, 1], 200, "OK", - ODictCaseless([["Content-Type", "text/html"]]), + Headers(Content_Type="text/html"), "helloworld") flow.reply(resp) diff --git a/examples/stickycookies b/examples/stickycookies index 67b31da1d..7e84f71cd 100755 --- a/examples/stickycookies +++ b/examples/stickycookies @@ -23,16 +23,16 @@ class StickyMaster(controller.Master): def handle_request(self, flow): hid = (flow.request.host, flow.request.port) - if flow.request.headers["cookie"]: - self.stickyhosts[hid] = flow.request.headers["cookie"] + if "cookie" in flow.request.headers: + self.stickyhosts[hid] = flow.request.headers.get_all("cookie") elif hid in self.stickyhosts: - flow.request.headers["cookie"] = self.stickyhosts[hid] + flow.request.headers.set_all("cookie", self.stickyhosts[hid]) flow.reply() def handle_response(self, flow): hid = (flow.request.host, flow.request.port) - if flow.response.headers["set-cookie"]: - self.stickyhosts[hid] = flow.response.headers["set-cookie"] + if "set-cookie" in flow.response.headers: + self.stickyhosts[hid] = flow.response.headers.get_all("set-cookie") flow.reply() diff --git a/examples/stream_modify.py b/examples/stream_modify.py index e3f1f3cf7..aa395c03f 100644 --- a/examples/stream_modify.py +++ b/examples/stream_modify.py @@ -11,11 +11,9 @@ Be aware that content replacement isn't trivial: def modify(chunks): """ chunks is a generator that can be used to iterate over all chunks. - Each chunk is a (prefix, content, suffix) tuple. - For example, in the case of chunked transfer encoding: ("3\r\n","foo","\r\n") """ - for prefix, content, suffix in chunks: - yield prefix, content.replace("foo", "bar"), suffix + for chunk in chunks: + yield chunk.replace("foo", "bar") def responseheaders(context, flow): diff --git a/examples/stub.py b/examples/stub.py index bd3e7cd02..516b71a53 100644 --- a/examples/stub.py +++ b/examples/stub.py @@ -18,14 +18,6 @@ def clientconnect(context, root_layer): context.log("clientconnect") -def serverconnect(context, server_connection): - """ - Called when the proxy initiates a connection to the target server. Note that a - connection can correspond to multiple HTTP requests - """ - context.log("serverconnect") - - def request(context, flow): """ Called when a client request has been received. @@ -33,6 +25,14 @@ def request(context, flow): context.log("request") +def serverconnect(context, server_conn): + """ + Called when the proxy initiates a connection to the target server. Note that a + connection can correspond to multiple HTTP requests + """ + context.log("serverconnect") + + def responseheaders(context, flow): """ Called when the response headers for a server response have been received, @@ -58,7 +58,7 @@ def error(context, flow): context.log("error") -def serverdisconnect(context, server_connection): +def serverdisconnect(context, server_conn): """ Called when the proxy closes the connection to the target server. """ diff --git a/examples/tls_passthrough.py b/examples/tls_passthrough.py new file mode 100644 index 000000000..7b4dec622 --- /dev/null +++ b/examples/tls_passthrough.py @@ -0,0 +1,136 @@ +""" +This inline script allows conditional TLS Interception based +on a user-defined strategy. + +Example: + + > mitmdump -s tls_passthrough.py + + 1. curl --proxy http://localhost:8080 https://example.com --insecure + // works - we'll also see the contents in mitmproxy + + 2. curl --proxy http://localhost:8080 https://example.com --insecure + // still works - we'll also see the contents in mitmproxy + + 3. curl --proxy http://localhost:8080 https://example.com + // fails with a certificate error, which we will also see in mitmproxy + + 4. curl --proxy http://localhost:8080 https://example.com + // works again, but mitmproxy does not intercept and we do *not* see the contents + +Authors: Maximilian Hils, Matthew Tuusberg +""" +from __future__ import (absolute_import, print_function, division) +import collections +import random + +from enum import Enum + +from libmproxy.exceptions import TlsException +from libmproxy.protocol import TlsLayer, RawTCPLayer + + +class InterceptionResult(Enum): + success = True + failure = False + skipped = None + + +class _TlsStrategy(object): + """ + Abstract base class for interception strategies. + """ + def __init__(self): + # A server_address -> interception results mapping + self.history = collections.defaultdict(lambda: collections.deque(maxlen=200)) + + def should_intercept(self, server_address): + """ + Returns: + True, if we should attempt to intercept the connection. + False, if we want to employ pass-through instead. + """ + raise NotImplementedError() + + def record_success(self, server_address): + self.history[server_address].append(InterceptionResult.success) + + def record_failure(self, server_address): + self.history[server_address].append(InterceptionResult.failure) + + def record_skipped(self, server_address): + self.history[server_address].append(InterceptionResult.skipped) + + +class ConservativeStrategy(_TlsStrategy): + """ + Conservative Interception Strategy - only intercept if there haven't been any failed attempts + in the history. + """ + + def should_intercept(self, server_address): + if InterceptionResult.failure in self.history[server_address]: + return False + return True + + +class ProbabilisticStrategy(_TlsStrategy): + """ + Fixed probability that we intercept a given connection. + """ + def __init__(self, p): + self.p = p + super(ProbabilisticStrategy, self).__init__() + + def should_intercept(self, server_address): + return random.uniform(0, 1) < self.p + + +class TlsFeedback(TlsLayer): + """ + Monkey-patch _establish_tls_with_client to get feedback if TLS could be established + successfully on the client connection (which may fail due to cert pinning). + """ + + def _establish_tls_with_client(self): + server_address = self.server_conn.address + tls_strategy = self.script_context.tls_strategy + + try: + super(TlsFeedback, self)._establish_tls_with_client() + except TlsException as e: + tls_strategy.record_failure(server_address) + raise e + else: + tls_strategy.record_success(server_address) + + +# inline script hooks below. + + +def start(context, argv): + if len(argv) == 2: + context.tls_strategy = ProbabilisticStrategy(float(argv[1])) + else: + context.tls_strategy = ConservativeStrategy() + + +def next_layer(context, next_layer): + """ + This hook does the actual magic - if the next layer is planned to be a TLS layer, + we check if we want to enter pass-through mode instead. + """ + if isinstance(next_layer, TlsLayer) and next_layer._client_tls: + server_address = next_layer.server_conn.address + + if context.tls_strategy.should_intercept(server_address): + # We try to intercept. + # Monkey-Patch the layer to get feedback from the TLSLayer if interception worked. + next_layer.__class__ = TlsFeedback + next_layer.script_context = context + else: + # We don't intercept - reply with a pass-through layer and add a "skipped" entry. + context.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info") + next_layer_replacement = RawTCPLayer(next_layer.ctx, logging=False) + next_layer.reply(next_layer_replacement) + context.tls_strategy.record_skipped(server_address) diff --git a/examples/upsidedownternet.py b/examples/upsidedownternet.py index a6de97e4c..f2e730475 100644 --- a/examples/upsidedownternet.py +++ b/examples/upsidedownternet.py @@ -1,10 +1,10 @@ import cStringIO from PIL import Image -from libmproxy.protocol.http import decoded +from libmproxy.models import decoded def response(context, flow): - if flow.response.headers.get_first("content-type", "").startswith("image"): + if flow.response.headers.get("content-type", "").startswith("image"): with decoded(flow.response): # automatically decode gzipped responses. try: s = cStringIO.StringIO(flow.response.content) @@ -12,6 +12,6 @@ def response(context, flow): s2 = cStringIO.StringIO() img.save(s2, "png") flow.response.content = s2.getvalue() - flow.response.headers["content-type"] = ["image/png"] + flow.response.headers["content-type"] = "image/png" except: # Unknown image types etc. pass diff --git a/libmproxy/cmdline.py b/libmproxy/cmdline.py index 7f6f69ef0..3779953f8 100644 --- a/libmproxy/cmdline.py +++ b/libmproxy/cmdline.py @@ -1,11 +1,11 @@ from __future__ import absolute_import import os import re + import configargparse + from netlib.tcp import Address, sslversion_choices - import netlib.utils - from . import filt, utils, version from .proxy import config @@ -358,6 +358,20 @@ def proxy_options(parser): action="store", type=int, dest="port", default=8080, help="Proxy service port." ) + http2 = group.add_mutually_exclusive_group() + http2.add_argument("--http2", action="store_true", dest="http2") + http2.add_argument("--no-http2", action="store_false", dest="http2", + help="Explicitly enable/disable experimental HTTP2 support. " + "Disabled by default. " + "Default value will change in a future version." + ) + rawtcp = group.add_mutually_exclusive_group() + rawtcp.add_argument("--raw-tcp", action="store_true", dest="rawtcp") + rawtcp.add_argument("--no-raw-tcp", action="store_false", dest="rawtcp", + help="Explicitly enable/disable experimental raw tcp support. " + "Disabled by default. " + "Default value will change in a future version." + ) def proxy_ssl_options(parser): diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py index c25f7267f..ae3dd61ee 100644 --- a/libmproxy/console/common.py +++ b/libmproxy/console/common.py @@ -415,9 +415,9 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2, resp_clen = contentdesc, roundtrip = roundtrip, )) - t = f.response.headers["content-type"] + t = f.response.headers.get("content-type") if t: - d["resp_ctype"] = t[0].split(";")[0] + d["resp_ctype"] = t.split(";")[0] else: d["resp_ctype"] = "" return flowcache.get( diff --git a/libmproxy/console/flowview.py b/libmproxy/console/flowview.py index 958ab1766..e33d4c43d 100644 --- a/libmproxy/console/flowview.py +++ b/libmproxy/console/flowview.py @@ -4,7 +4,7 @@ import sys import urwid from netlib import odict -from netlib.http.semantics import CONTENT_MISSING +from netlib.http.semantics import CONTENT_MISSING, Headers from . import common, grideditor, signals, searchable, tabs from . import flowdetailview @@ -182,7 +182,7 @@ class FlowView(tabs.Tabs): description, text_objects = cache.get( contentview.get_content_view, viewmode, - tuple(tuple(i) for i in conn.headers.lst), + conn.headers, conn.content, limit, isinstance(conn, HTTPRequest), @@ -200,7 +200,7 @@ class FlowView(tabs.Tabs): def conn_text(self, conn): if conn: txt = common.format_keyvals( - [(h + ":", v) for (h, v) in conn.headers.lst], + [(h + ":", v) for (h, v) in conn.headers.fields], key = "header", val = "text" ) @@ -285,8 +285,8 @@ class FlowView(tabs.Tabs): response.msg = msg signals.flow_change.send(self, flow = self.flow) - def set_headers(self, lst, conn): - conn.headers = odict.ODictCaseless(lst) + def set_headers(self, fields, conn): + conn.headers = Headers(fields) signals.flow_change.send(self, flow = self.flow) def set_query(self, lst, conn): @@ -331,7 +331,7 @@ class FlowView(tabs.Tabs): if not self.flow.response: self.flow.response = HTTPResponse( self.flow.request.httpversion, - 200, "OK", odict.ODictCaseless(), "" + 200, "OK", Headers(), "" ) self.flow.response.reply = controller.DummyReply() message = self.flow.response @@ -382,7 +382,7 @@ class FlowView(tabs.Tabs): self.master.view_grideditor( grideditor.HeaderEditor( self.master, - message.headers.lst, + message.headers.fields, self.set_headers, message ) @@ -617,8 +617,7 @@ class FlowView(tabs.Tabs): key = None elif key == "v": if conn.content: - t = conn.headers["content-type"] or [None] - t = t[0] + t = conn.headers.get("content-type") if "EDITOR" in os.environ or "PAGER" in os.environ: self.master.spawn_external_viewer(conn.content, t) else: @@ -627,7 +626,7 @@ class FlowView(tabs.Tabs): ) elif key == "z": self.flow.backup() - e = conn.headers.get_first("content-encoding", "identity") + e = conn.headers.get("content-encoding", "identity") if e != "identity": if not conn.decode(): signals.status_message.send( diff --git a/libmproxy/contentview.py b/libmproxy/contentview.py index 45c1f2f1d..a9b6cf959 100644 --- a/libmproxy/contentview.py +++ b/libmproxy/contentview.py @@ -9,14 +9,13 @@ import lxml.html import lxml.etree from PIL import Image from PIL.ExifTags import TAGS -import urwid import html2text import netlib.utils -from netlib import odict, encoding from . import utils from .contrib import jsbeautifier from .contrib.wbxml.ASCommandResponse import ASCommandResponse +from netlib import encoding try: import pyamf @@ -129,7 +128,7 @@ class ViewAuto(View): content_types = [] def __call__(self, hdrs, content, limit): - ctype = hdrs.get_first("content-type") + ctype = hdrs.get("content-type") if ctype: ct = netlib.utils.parse_content_type(ctype) if ctype else None ct = "%s/%s" % (ct[0], ct[1]) @@ -536,7 +535,7 @@ def get(name): return i -def get_content_view(viewmode, hdrItems, content, limit, is_request, log=None): +def get_content_view(viewmode, headers, content, limit, is_request, log=None): """ Returns: A (msg, body) tuple. @@ -551,16 +550,14 @@ def get_content_view(viewmode, hdrItems, content, limit, is_request, log=None): return "No content", "" msg = [] - hdrs = odict.ODictCaseless([list(i) for i in hdrItems]) - - enc = hdrs.get_first("content-encoding") + enc = headers.get("content-encoding") if enc and enc != "identity": decoded = encoding.decode(enc, content) if decoded: content = decoded msg.append("[decoded %s]" % enc) try: - ret = viewmode(hdrs, content, limit) + ret = viewmode(headers, content, limit) # Third-party viewers can fail in unexpected ways... except Exception: if log: @@ -569,7 +566,7 @@ def get_content_view(viewmode, hdrItems, content, limit, is_request, log=None): log(s, "error") ret = None if not ret: - ret = get("Raw")(hdrs, content, limit) + ret = get("Raw")(headers, content, limit) msg.append("Couldn't parse: falling back to Raw") else: msg.append(ret[0]) diff --git a/libmproxy/dump.py b/libmproxy/dump.py index bf4098035..17b47dd2b 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -174,7 +174,7 @@ class DumpMaster(flow.FlowMaster): def _print_message(self, message): if self.o.flow_detail >= 2: - print(self.indent(4, message.headers.format()), file=self.outfile) + print(self.indent(4, str(message.headers)), file=self.outfile) if self.o.flow_detail >= 3: if message.content == CONTENT_MISSING: print(self.indent(4, "(content missing)"), file=self.outfile) diff --git a/libmproxy/exceptions.py b/libmproxy/exceptions.py index f34d97078..0e11c136d 100644 --- a/libmproxy/exceptions.py +++ b/libmproxy/exceptions.py @@ -1,23 +1,42 @@ +""" +We try to be very hygienic regarding the exceptions we throw: +Every Exception mitmproxy raises shall be a subclass of ProxyException. + + +See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/ +""" from __future__ import (absolute_import, print_function, division) class ProxyException(Exception): """ Base class for all exceptions thrown by libmproxy. + + Args: + message: the error message + cause: (optional) an error object that caused this exception, e.g. an IOError. """ - def __init__(self, message, cause=None): + def __init__(self, message): """ :param message: Error Message - :param cause: Exception object that caused this exception to be thrown. """ super(ProxyException, self).__init__(message) - self.cause = cause class ProtocolException(ProxyException): pass +class TlsException(ProtocolException): + pass + + +class ClientHandshakeException(TlsException): + def __init__(self, message, server): + super(ClientHandshakeException, self).__init__(message) + self.server = server + + class Socks5Exception(ProtocolException): pass diff --git a/libmproxy/filt.py b/libmproxy/filt.py index cfd3a1bc5..7cd0f4dfb 100644 --- a/libmproxy/filt.py +++ b/libmproxy/filt.py @@ -35,7 +35,6 @@ from __future__ import absolute_import import re import sys import pyparsing as pp -from .models import decoded class _Token: @@ -78,17 +77,19 @@ class FResp(_Action): class _Rex(_Action): + flags = 0 + def __init__(self, expr): self.expr = expr try: - self.re = re.compile(self.expr) + self.re = re.compile(self.expr, self.flags) except: raise ValueError("Cannot compile expression.") def _check_content_type(expr, o): - val = o.headers["content-type"] - if val and re.search(expr, val[0]): + val = o.headers.get("content-type") + if val and re.search(expr, val): return True return False @@ -146,11 +147,12 @@ class FResponseContentType(_Rex): class FHead(_Rex): code = "h" help = "Header" + flags = re.MULTILINE def __call__(self, f): - if f.request.headers.match_re(self.expr): + if f.request and self.re.search(str(f.request.headers)): return True - elif f.response and f.response.headers.match_re(self.expr): + if f.response and self.re.search(str(f.response.headers)): return True return False @@ -158,18 +160,20 @@ class FHead(_Rex): class FHeadRequest(_Rex): code = "hq" help = "Request header" + flags = re.MULTILINE def __call__(self, f): - if f.request.headers.match_re(self.expr): + if f.request and self.re.search(str(f.request.headers)): return True class FHeadResponse(_Rex): code = "hs" help = "Response header" + flags = re.MULTILINE def __call__(self, f): - if f.response and f.response.headers.match_re(self.expr): + if f.response and self.re.search(str(f.response.headers)): return True @@ -179,13 +183,11 @@ class FBod(_Rex): def __call__(self, f): if f.request and f.request.content: - with decoded(f.request): - if re.search(self.expr, f.request.content): - return True + if self.re.search(f.request.get_decoded_content()): + return True if f.response and f.response.content: - with decoded(f.response): - if re.search(self.expr, f.response.content): - return True + if self.re.search(f.response.get_decoded_content()): + return True return False @@ -195,9 +197,8 @@ class FBodRequest(_Rex): def __call__(self, f): if f.request and f.request.content: - with decoded(f.request): - if re.search(self.expr, f.request.content): - return True + if self.re.search(f.request.get_decoded_content()): + return True class FBodResponse(_Rex): @@ -206,25 +207,26 @@ class FBodResponse(_Rex): def __call__(self, f): if f.response and f.response.content: - with decoded(f.response): - if re.search(self.expr, f.response.content): - return True + if self.re.search(f.response.get_decoded_content()): + return True class FMethod(_Rex): code = "m" help = "Method" + flags = re.IGNORECASE def __call__(self, f): - return bool(re.search(self.expr, f.request.method, re.IGNORECASE)) + return bool(self.re.search(f.request.method)) class FDomain(_Rex): code = "d" help = "Domain" + flags = re.IGNORECASE def __call__(self, f): - return bool(re.search(self.expr, f.request.host, re.IGNORECASE)) + return bool(self.re.search(f.request.host)) class FUrl(_Rex): @@ -239,21 +241,24 @@ class FUrl(_Rex): return klass(*toks) def __call__(self, f): - return re.search(self.expr, f.request.url) + return self.re.search(f.request.url) + class FSrc(_Rex): code = "src" help = "Match source address" def __call__(self, f): - return f.client_conn.address and re.search(self.expr, repr(f.client_conn.address)) + return f.client_conn.address and self.re.search(repr(f.client_conn.address)) + class FDst(_Rex): code = "dst" help = "Match destination address" def __call__(self, f): - return f.server_conn.address and re.search(self.expr, repr(f.server_conn.address)) + return f.server_conn.address and self.re.search(repr(f.server_conn.address)) + class _Int(_Action): def __init__(self, num): diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 5eac8da92..d037d36e4 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -11,8 +11,8 @@ import re import urlparse -from netlib import odict, wsgi -from netlib.http.semantics import CONTENT_MISSING +from netlib import wsgi +from netlib.http.semantics import CONTENT_MISSING, Headers import netlib.http from . import controller, tnetstring, filt, script, version from .onboarding import app @@ -45,7 +45,7 @@ class AppRegistry: if (request.host, request.port) in self.apps: return self.apps[(request.host, request.port)] if "host" in request.headers: - host = request.headers["host"][0] + host = request.headers["host"] return self.apps.get((host, request.port), None) @@ -144,15 +144,15 @@ class SetHeaders: for _, header, value, cpatt in self.lst: if cpatt(f): if f.response: - del f.response.headers[header] + f.response.headers.pop(header, None) else: - del f.request.headers[header] + f.request.headers.pop(header, None) for _, header, value, cpatt in self.lst: if cpatt(f): if f.response: - f.response.headers.add(header, value) + f.response.headers.fields.append((header, value)) else: - f.request.headers.add(header, value) + f.request.headers.fields.append((header, value)) class StreamLargeBodies(object): @@ -278,14 +278,11 @@ class ServerPlaybackState: key.append(p[1]) if self.headers: - hdrs = [] + headers = [] for i in self.headers: - v = r.headers[i] - # Slightly subtle: we need to convert everything to strings - # to prevent a mismatch between unicode/non-unicode. - v = [str(x) for x in v] - hdrs.append((i, v)) - key.append(hdrs) + v = r.headers.get(i) + headers.append((i, v)) + key.append(headers) return hashlib.sha256(repr(key)).digest() def next_flow(self, request): @@ -329,7 +326,7 @@ class StickyCookieState: return False def handle_response(self, f): - for i in f.response.headers["set-cookie"]: + for i in f.response.headers.get_all("set-cookie"): # FIXME: We now know that Cookie.py screws up some cookies with # valid RFC 822/1123 datetime specifications for expiry. Sigh. c = Cookie.SimpleCookie(str(i)) @@ -351,7 +348,7 @@ class StickyCookieState: l.append(self.jar[i].output(header="").strip()) if l: f.request.stickycookie = True - f.request.headers["cookie"] = l + f.request.headers.set_all("cookie",l) class StickyAuthState: @@ -836,7 +833,7 @@ class FlowMaster(controller.Master): ssl_established=True )) f = HTTPFlow(c, s) - headers = odict.ODictCaseless() + headers = Headers() req = HTTPRequest( "absolute", @@ -930,8 +927,7 @@ class FlowMaster(controller.Master): f.backup() f.request.is_replay = True if f.request.content: - f.request.headers[ - "Content-Length"] = [str(len(f.request.content))] + f.request.headers["Content-Length"] = str(len(f.request.content)) f.response = None f.error = None self.process_new_request(f) @@ -949,21 +945,25 @@ class FlowMaster(controller.Master): self.add_event(l.msg, l.level) l.reply() - def handle_clientconnect(self, cc): - self.run_script_hook("clientconnect", cc) - cc.reply() + def handle_clientconnect(self, root_layer): + self.run_script_hook("clientconnect", root_layer) + root_layer.reply() - def handle_clientdisconnect(self, r): - self.run_script_hook("clientdisconnect", r) - r.reply() + def handle_clientdisconnect(self, root_layer): + self.run_script_hook("clientdisconnect", root_layer) + root_layer.reply() - def handle_serverconnect(self, sc): - self.run_script_hook("serverconnect", sc) - sc.reply() + def handle_serverconnect(self, server_conn): + self.run_script_hook("serverconnect", server_conn) + server_conn.reply() - def handle_serverdisconnect(self, sc): - self.run_script_hook("serverdisconnect", sc) - sc.reply() + def handle_serverdisconnect(self, server_conn): + self.run_script_hook("serverdisconnect", server_conn) + server_conn.reply() + + def handle_next_layer(self, top_layer): + self.run_script_hook("next_layer", top_layer) + top_layer.reply() def handle_error(self, f): self.state.update_flow(f) diff --git a/libmproxy/models/http.py b/libmproxy/models/http.py index fb2f305bb..0d5e53b59 100644 --- a/libmproxy/models/http.py +++ b/libmproxy/models/http.py @@ -5,8 +5,8 @@ from email.utils import parsedate_tz, formatdate, mktime_tz import time from libmproxy import utils -from netlib import odict, encoding -from netlib.http import status_codes +from netlib import encoding +from netlib.http import status_codes, Headers from netlib.tcp import Address from netlib.http.semantics import Request, Response, CONTENT_MISSING from .. import version, stateobject @@ -16,7 +16,7 @@ from .flow import Flow class MessageMixin(stateobject.StateObject): _stateobject_attributes = dict( httpversion=tuple, - headers=odict.ODictCaseless, + headers=Headers, body=str, timestamp_start=float, timestamp_end=float @@ -40,7 +40,7 @@ class MessageMixin(stateobject.StateObject): header. Doesn't change the message iteself or its headers. """ - ce = self.headers.get_first("content-encoding") + ce = self.headers.get("content-encoding") if not self.body or ce not in encoding.ENCODINGS: return self.body return encoding.decode(ce, self.body) @@ -53,14 +53,14 @@ class MessageMixin(stateobject.StateObject): Returns True if decoding succeeded, False otherwise. """ - ce = self.headers.get_first("content-encoding") + ce = self.headers.get("content-encoding") if not self.body or ce not in encoding.ENCODINGS: return False data = encoding.decode(ce, self.body) if data is None: return False self.body = data - del self.headers["content-encoding"] + self.headers.pop("content-encoding", None) return True def encode(self, e): @@ -70,7 +70,7 @@ class MessageMixin(stateobject.StateObject): """ # FIXME: Error if there's an existing encoding header? self.body = encoding.encode(e, self.body) - self.headers["content-encoding"] = [e] + self.headers["content-encoding"] = e def copy(self): c = copy.copy(self) @@ -86,11 +86,18 @@ class MessageMixin(stateobject.StateObject): Returns the number of replacements made. """ with decoded(self): - self.body, c = utils.safe_subn( + self.body, count = utils.safe_subn( pattern, repl, self.body, *args, **kwargs ) - c += self.headers.replace(pattern, repl, *args, **kwargs) - return c + fields = [] + for name, value in self.headers.fields: + name, c = utils.safe_subn(pattern, repl, name, *args, **kwargs) + count += c + value, c = utils.safe_subn(pattern, repl, value, *args, **kwargs) + count += c + fields.append([name, value]) + self.headers.fields = fields + return count class HTTPRequest(MessageMixin, Request): @@ -115,7 +122,7 @@ class HTTPRequest(MessageMixin, Request): httpversion: HTTP version tuple, e.g. (1,1) - headers: odict.ODictCaseless object + headers: Headers object content: Content of the request, None, or CONTENT_MISSING if there is content associated, but not present. CONTENT_MISSING evaluates @@ -266,7 +273,7 @@ class HTTPResponse(MessageMixin, Response): msg: HTTP response message - headers: ODict Caseless object + headers: Headers object content: Content of the request, None, or CONTENT_MISSING if there is content associated, but not present. CONTENT_MISSING evaluates @@ -379,15 +386,15 @@ class HTTPResponse(MessageMixin, Response): ] for i in refresh_headers: if i in self.headers: - d = parsedate_tz(self.headers[i][0]) + d = parsedate_tz(self.headers[i]) if d: new = mktime_tz(d) + delta - self.headers[i] = [formatdate(new)] + self.headers[i] = formatdate(new) c = [] - for i in self.headers["set-cookie"]: + for i in self.headers.get_all("set-cookie"): c.append(self._refresh_cookie(i, delta)) if c: - self.headers["set-cookie"] = c + self.headers.set_all("set-cookie", c) class HTTPFlow(Flow): @@ -490,7 +497,7 @@ class decoded(object): def __init__(self, o): self.o = o - ce = o.headers.get_first("content-encoding") + ce = o.headers.get("content-encoding") if ce in encoding.ENCODINGS: self.ce = ce else: @@ -517,11 +524,12 @@ def make_error_response(status_code, message, headers=None): """.strip() % (status_code, response, message) if not headers: - headers = odict.ODictCaseless() - headers["Server"] = [version.NAMEVERSION] - headers["Connection"] = ["close"] - headers["Content-Length"] = [len(body)] - headers["Content-Type"] = ["text/html"] + headers = Headers( + Server=version.NAMEVERSION, + Connection="close", + Content_Length=str(len(body)), + Content_Type="text/html" + ) return HTTPResponse( (1, 1), # FIXME: Should be a string. @@ -536,15 +544,15 @@ def make_connect_request(address): address = Address.wrap(address) return HTTPRequest( "authority", "CONNECT", None, address.host, address.port, None, (1, 1), - odict.ODictCaseless(), "" + Headers(), "" ) def make_connect_response(httpversion): - headers = odict.ODictCaseless([ - ["Content-Length", "0"], - ["Proxy-Agent", version.NAMEVERSION] - ]) + headers = Headers( + Content_Length="0", + Proxy_Agent=version.NAMEVERSION + ) return HTTPResponse( httpversion, 200, diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py index c582592ba..35d59f287 100644 --- a/libmproxy/protocol/__init__.py +++ b/libmproxy/protocol/__init__.py @@ -1,11 +1,38 @@ +""" +In mitmproxy, protocols are implemented as a set of layers, which are composed on top each other. +The first layer is usually the proxy mode, e.g. transparent proxy or normal HTTP proxy. Next, +various protocol layers are stacked on top of each other - imagine WebSockets on top of an HTTP +Upgrade request. An actual mitmproxy connection may look as follows (outermost layer first): + + Transparent HTTP proxy, no TLS: + - TransparentProxy + - Http1Layer + - HttpLayer + + Regular proxy, CONNECT request with WebSockets over SSL: + - ReverseProxy + - Http1Layer + - HttpLayer + - TLSLayer + - WebsocketLayer (or TCPLayer) + +Every layer acts as a read-only context for its inner layers (see :py:class:`Layer`). To communicate +with an outer layer, a layer can use functions provided in the context. The next layer is always +determined by a call to :py:meth:`.next_layer() `, +which is provided by the root context. + +Another subtle design goal of this architecture is that upstream connections should be established +as late as possible; this makes server replay without any outgoing connections possible. +""" + from __future__ import (absolute_import, print_function, division) -from .base import Layer, ServerConnectionMixin, Log, Kill +from .base import Layer, ServerConnectionMixin, Kill from .http import Http1Layer, Http2Layer from .tls import TlsLayer, is_tls_record_magic from .rawtcp import RawTCPLayer __all__ = [ - "Layer", "ServerConnectionMixin", "Log", "Kill", + "Layer", "ServerConnectionMixin", "Kill", "Http1Layer", "Http2Layer", "TlsLayer", "is_tls_record_magic", "RawTCPLayer" diff --git a/libmproxy/protocol/base.py b/libmproxy/protocol/base.py index 40ec0536f..b92aeea15 100644 --- a/libmproxy/protocol/base.py +++ b/libmproxy/protocol/base.py @@ -1,38 +1,8 @@ -""" -mitmproxy protocol architecture - -In mitmproxy, protocols are implemented as a set of layers, which are composed on top each other. -For example, the following scenarios depict possible settings (lowest layer first): - -Transparent HTTP proxy, no SSL: - TransparentProxy - Http1Layer - HttpLayer - -Regular proxy, CONNECT request with WebSockets over SSL: - HttpProxy - Http1Layer - HttpLayer - SslLayer - WebsocketLayer (or TcpLayer) - -Automated protocol detection by peeking into the buffer: - TransparentProxy - TLSLayer - Http2Layer - HttpLayer - -Communication between layers is done as follows: - - lower layers provide context information to higher layers - - higher layers can call functions provided by lower layers, - which are propagated until they reach a suitable layer. - -Further goals: - - Connections should always be peekable to make automatic protocol detection work. - - Upstream connections should be established as late as possible; - inline scripts shall have a chance to handle everything locally. -""" from __future__ import (absolute_import, print_function, division) +import sys + +import six + from netlib import tcp from ..models import ServerConnection from ..exceptions import ProtocolException @@ -43,8 +13,8 @@ class _LayerCodeCompletion(object): Dummy class that provides type hinting in PyCharm, which simplifies development a lot. """ - def __init__(self, *args, **kwargs): # pragma: nocover - super(_LayerCodeCompletion, self).__init__(*args, **kwargs) + def __init__(self, **mixin_args): # pragma: nocover + super(_LayerCodeCompletion, self).__init__(**mixin_args) if True: return self.config = None @@ -55,43 +25,64 @@ class _LayerCodeCompletion(object): """@type: libmproxy.models.ServerConnection""" self.channel = None """@type: libmproxy.controller.Channel""" + self.ctx = None + """@type: libmproxy.protocol.Layer""" class Layer(_LayerCodeCompletion): - def __init__(self, ctx, *args, **kwargs): + """ + Base class for all layers. All other protocol layers should inherit from this class. + """ + + def __init__(self, ctx, **mixin_args): """ + Each layer usually passes itself to its child layers as a context. Properties of the + context are transparently mapped to the layer, so that the following works: + + .. code-block:: python + + root_layer = Layer(None) + root_layer.client_conn = 42 + sub_layer = Layer(root_layer) + print(sub_layer.client_conn) # 42 + + The root layer is passed a :py:class:`libmproxy.proxy.RootContext` object, + which provides access to :py:attr:`.client_conn `, + :py:attr:`.next_layer ` and other basic attributes. + Args: - ctx: The (read-only) higher layer. + ctx: The (read-only) parent layer / context. """ self.ctx = ctx - """@type: libmproxy.protocol.Layer""" - super(Layer, self).__init__(*args, **kwargs) + """ + The parent layer. + + :type: :py:class:`Layer` + """ + super(Layer, self).__init__(**mixin_args) def __call__(self): - """ - Logic of the layer. + """Logic of the layer. + + Returns: + Once the protocol has finished without exceptions. + Raises: - ProtocolException in case of protocol exceptions. + ~libmproxy.exceptions.ProtocolException: if an exception occurs. No other exceptions must be raised. """ raise NotImplementedError() def __getattr__(self, name): """ - Attributes not present on the current layer may exist on a higher layer. + Attributes not present on the current layer are looked up on the context. """ return getattr(self.ctx, name) - def log(self, msg, level, subs=()): - full_msg = [ - "{}: {}".format(repr(self.client_conn.address), msg) - ] - for i in subs: - full_msg.append(" -> " + i) - full_msg = "\n".join(full_msg) - self.channel.tell("log", Log(full_msg, level)) - @property def layers(self): + """ + List of all layers, including the current layer (``[self, self.ctx, self.ctx.ctx, ...]``) + """ return [self] + self.ctx.layers def __repr__(self): @@ -101,20 +92,28 @@ class Layer(_LayerCodeCompletion): class ServerConnectionMixin(object): """ Mixin that provides a layer with the capabilities to manage a server connection. + The server address can be passed in the constructor or set by calling :py:meth:`set_server`. + Subclasses are responsible for calling :py:meth:`disconnect` before returning. + + Recommended Usage: + + .. code-block:: python + + class MyLayer(Layer, ServerConnectionMixin): + def __call__(self): + try: + # Do something. + finally: + if self.server_conn: + self.disconnect() """ def __init__(self, server_address=None): super(ServerConnectionMixin, self).__init__() self.server_conn = ServerConnection(server_address) - self._check_self_connect() + self.__check_self_connect() - def reconnect(self): - address = self.server_conn.address - self._disconnect() - self.server_conn.address = address - self.connect() - - def _check_self_connect(self): + def __check_self_connect(self): """ We try to protect the proxy from _accidentally_ connecting to itself, e.g. because of a failed transparent lookup or an invalid configuration. @@ -131,31 +130,45 @@ class ServerConnectionMixin(object): "The proxy shall not connect to itself.".format(repr(address)) ) - def set_server(self, address, server_tls=None, sni=None, depth=1): - if depth == 1: - if self.server_conn: - self._disconnect() - self.log("Set new server address: " + repr(address), "debug") - self.server_conn.address = address - self._check_self_connect() - if server_tls: - raise ProtocolException( - "Cannot upgrade to TLS, no TLS layer on the protocol stack." - ) - else: - self.ctx.set_server(address, server_tls, sni, depth - 1) + def set_server(self, address, server_tls=None, sni=None): + """ + Sets a new server address. If there is an existing connection, it will be closed. - def _disconnect(self): + Raises: + ~libmproxy.exceptions.ProtocolException: + if ``server_tls`` is ``True``, but there was no TLS layer on the + protocol stack which could have processed this. + """ + if self.server_conn: + self.disconnect() + self.log("Set new server address: " + repr(address), "debug") + self.server_conn.address = address + self.__check_self_connect() + if server_tls: + raise ProtocolException( + "Cannot upgrade to TLS, no TLS layer on the protocol stack." + ) + + def disconnect(self): """ Deletes (and closes) an existing server connection. + Must not be called if there is no existing connection. """ self.log("serverdisconnect", "debug", [repr(self.server_conn.address)]) + address = self.server_conn.address self.server_conn.finish() self.server_conn.close() self.channel.tell("serverdisconnect", self.server_conn) - self.server_conn = ServerConnection(None) + self.server_conn = ServerConnection(address) def connect(self): + """ + Establishes a server connection. + Must not be called if there is an existing connection. + + Raises: + ~libmproxy.exceptions.ProtocolException: if the connection could not be established. + """ if not self.server_conn.address: raise ProtocolException("Cannot connect to server, no server address given.") self.log("serverconnect", "debug", [repr(self.server_conn.address)]) @@ -163,17 +176,18 @@ class ServerConnectionMixin(object): try: self.server_conn.connect() except tcp.NetLibError as e: - raise ProtocolException( - "Server connection to %s failed: %s" % (repr(self.server_conn.address), e), e) - - -class Log(object): - def __init__(self, msg, level="info"): - self.msg = msg - self.level = level + six.reraise( + ProtocolException, + ProtocolException( + "Server connection to {} failed: {}".format( + repr(self.server_conn.address), str(e) + ) + ), + sys.exc_info()[2] + ) class Kill(Exception): """ - Kill a connection. + Signal that both client and server connection(s) should be killed immediately. """ diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 7f57d17cd..3a4153201 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -1,14 +1,16 @@ from __future__ import (absolute_import, print_function, division) +import itertools +import sys + +import six from netlib import tcp -from netlib.http import http1, HttpErrorConnClosed, HttpError +from netlib.http import http1, HttpErrorConnClosed, HttpError, Headers from netlib.http.semantics import CONTENT_MISSING -from netlib import odict from netlib.tcp import NetLibError, Address from netlib.http.http1 import HTTP1Protocol from netlib.http.http2 import HTTP2Protocol -from netlib.http.http2.frame import WindowUpdateFrame - +from netlib.http.http2.frame import GoAwayFrame, PriorityFrame, WindowUpdateFrame from .. import utils from ..exceptions import InvalidCredentials, HttpException, ProtocolException from ..models import ( @@ -32,6 +34,9 @@ class _HttpLayer(Layer): def send_response(self, response): raise NotImplementedError() + def check_close_connection(self, flow): + raise NotImplementedError() + class _StreamingHttpLayer(_HttpLayer): supports_streaming = True @@ -43,12 +48,25 @@ class _StreamingHttpLayer(_HttpLayer): raise NotImplementedError() yield "this is a generator" # pragma: no cover + def read_response(self, request_method): + response = self.read_response_headers() + response.body = "".join( + self.read_response_body(response.headers, request_method, response.code) + ) + return response + def send_response_headers(self, response): raise NotImplementedError def send_response_body(self, response, chunks): raise NotImplementedError() + def send_response(self, response): + if response.body == CONTENT_MISSING: + raise HttpError(502, "Cannot assemble flow with CONTENT_MISSING") + self.send_response_headers(response) + self.send_response_body(response, [response.body]) + class Http1Layer(_StreamingHttpLayer): def __init__(self, ctx, mode): @@ -66,17 +84,6 @@ class Http1Layer(_StreamingHttpLayer): def send_request(self, request): self.server_conn.send(self.server_protocol.assemble(request)) - def read_response(self, request_method): - return HTTPResponse.from_protocol( - self.server_protocol, - request_method=request_method, - body_size_limit=self.config.body_size_limit, - include_body=True - ) - - def send_response(self, response): - self.client_conn.send(self.client_protocol.assemble(response)) - def read_response_headers(self): return HTTPResponse.from_protocol( self.server_protocol, @@ -102,25 +109,49 @@ class Http1Layer(_StreamingHttpLayer): response, preserve_transfer_encoding=True ) - self.client_conn.send(h + "\r\n") + self.client_conn.wfile.write(h + "\r\n") + self.client_conn.wfile.flush() def send_response_body(self, response, chunks): if self.client_protocol.has_chunked_encoding(response.headers): - chunks = ( - "%d\r\n%s\r\n" % (len(chunk), chunk) - for chunk in chunks + chunks = itertools.chain( + ( + "{:x}\r\n{}\r\n".format(len(chunk), chunk) + for chunk in chunks if chunk + ), + ("0\r\n\r\n",) ) for chunk in chunks: - self.client_conn.send(chunk) + self.client_conn.wfile.write(chunk) + self.client_conn.wfile.flush() + + def check_close_connection(self, flow): + close_connection = ( + http1.HTTP1Protocol.connection_close( + flow.request.httpversion, + flow.request.headers + ) or http1.HTTP1Protocol.connection_close( + flow.response.httpversion, + flow.response.headers + ) or http1.HTTP1Protocol.expected_http_body_size( + flow.response.headers, + False, + flow.request.method, + flow.response.code) == -1 + ) + if flow.request.form_in == "authority" and flow.response.code == 200: + # Workaround for + # https://github.com/mitmproxy/mitmproxy/issues/313: Some + # proxies (e.g. Charles) send a CONNECT response with HTTP/1.0 + # and no Content-Length header + + return False + return close_connection def connect(self): self.ctx.connect() self.server_protocol = HTTP1Protocol(self.server_conn) - def reconnect(self): - self.ctx.reconnect() - self.server_protocol = HTTP1Protocol(self.server_conn) - def set_server(self, *args, **kwargs): self.ctx.set_server(*args, **kwargs) self.server_protocol = HTTP1Protocol(self.server_conn) @@ -136,9 +167,9 @@ class Http2Layer(_HttpLayer): super(Http2Layer, self).__init__(ctx) self.mode = mode self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, - unhandled_frame_cb=self.handle_unexpected_frame) + unhandled_frame_cb=self.handle_unexpected_frame_from_client) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, - unhandled_frame_cb=self.handle_unexpected_frame) + unhandled_frame_cb=self.handle_unexpected_frame_from_server) def read_request(self): request = HTTPRequest.from_protocol( @@ -162,25 +193,24 @@ class Http2Layer(_HttpLayer): ) def send_response(self, message): - # TODO: implement flow control and WINDOW_UPDATE frames + # TODO: implement flow control to prevent client buffer filling up + # maintain a send buffer size, and read WindowUpdateFrames from client to increase the send buffer self.client_conn.send(self.client_protocol.assemble(message)) + def check_close_connection(self, flow): + # TODO: add a timer to disconnect after a 10 second timeout + return False + def connect(self): self.ctx.connect() self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, - unhandled_frame_cb=self.handle_unexpected_frame) - self.server_protocol.perform_connection_preface() - - def reconnect(self): - self.ctx.reconnect() - self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, - unhandled_frame_cb=self.handle_unexpected_frame) + unhandled_frame_cb=self.handle_unexpected_frame_from_server) self.server_protocol.perform_connection_preface() def set_server(self, *args, **kwargs): self.ctx.set_server(*args, **kwargs) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, - unhandled_frame_cb=self.handle_unexpected_frame) + unhandled_frame_cb=self.handle_unexpected_frame_from_server) self.server_protocol.perform_connection_preface() def __call__(self): @@ -188,7 +218,10 @@ class Http2Layer(_HttpLayer): layer = HttpLayer(self, self.mode) layer() - def handle_unexpected_frame(self, frame): + # terminate the connection + self.client_conn.send(GoAwayFrame().to_bytes()) + + def handle_unexpected_frame_from_client(self, frame): if isinstance(frame, WindowUpdateFrame): # Clients are sending WindowUpdate frames depending on their flow control algorithm. # Since we cannot predict these frames, and we do not need to respond to them, @@ -196,7 +229,34 @@ class Http2Layer(_HttpLayer): # Ideally we should keep track of our own flow control window and # stall transmission if the outgoing flow control buffer is full. return - self.log("Unexpected HTTP2 Frame: %s" % frame.human_readable(), "info") + if isinstance(frame, PriorityFrame): + # Clients are sending Priority frames depending on their implementation. + # The RFC does not clearly state when or which priority preferences should be set. + # Since we cannot predict these frames, and we do not need to respond to them, + # simply accept them, and hide them from the log. + # Ideally we should forward them to the server. + return + if isinstance(frame, GoAwayFrame): + # Client wants to terminate the connection, + # relay it to the server. + self.server_conn.send(frame.to_bytes()) + return + self.log("Unexpected HTTP2 frame from client: %s" % frame.human_readable(), "info") + + def handle_unexpected_frame_from_server(self, frame): + if isinstance(frame, WindowUpdateFrame): + # Servers are sending WindowUpdate frames depending on their flow control algorithm. + # Since we cannot predict these frames, and we do not need to respond to them, + # simply accept them, and hide them from the log. + # Ideally we should keep track of our own flow control window and + # stall transmission if the outgoing flow control buffer is full. + return + if isinstance(frame, GoAwayFrame): + # Server wants to terminate the connection, + # relay it to the client. + self.client_conn.send(frame.to_bytes()) + return + self.log("Unexpected HTTP2 frame from server: %s" % frame.human_readable(), "info") class ConnectServerConnection(object): @@ -245,20 +305,22 @@ class UpstreamConnectLayer(Layer): else: pass # swallow the message - def reconnect(self): - self.ctx.reconnect() - self._send_connect_request() + def change_upstream_proxy_server(self, address): + if address != self.server_conn.via.address: + self.ctx.set_server(address) - def set_server(self, address, server_tls=None, sni=None, depth=1): - if depth == 1: - if self.ctx.server_conn: - self.ctx.reconnect() - address = Address.wrap(address) - self.connect_request.host = address.host - self.connect_request.port = address.port - self.server_conn.address = address - else: - self.ctx.set_server(address, server_tls, sni, depth - 1) + def set_server(self, address, server_tls=None, sni=None): + if self.ctx.server_conn: + self.ctx.disconnect() + address = Address.wrap(address) + self.connect_request.host = address.host + self.connect_request.port = address.port + self.server_conn.address = address + + if server_tls: + raise ProtocolException( + "Cannot upgrade to TLS, no TLS layer on the protocol stack." + ) class HttpLayer(Layer): @@ -308,7 +370,13 @@ class HttpLayer(Layer): if self.check_close_connection(flow): return - # TODO: Implement HTTP Upgrade + # Handle 101 Switching Protocols + # It may be useful to pass additional args (such as the upgrade header) + # to next_layer in the future + if flow.response.status_code == 101: + layer = self.ctx.next_layer(self) + layer() + return # Upstream Proxy Mode: Handle CONNECT if flow.request.form_in == "authority" and flow.response.code == 200: @@ -327,12 +395,18 @@ class HttpLayer(Layer): except NetLibError: pass if isinstance(e, ProtocolException): - raise e + six.reraise(ProtocolException, e, sys.exc_info()[2]) else: - raise ProtocolException("Error in HTTP connection: %s" % repr(e), e) + six.reraise(ProtocolException, ProtocolException("Error in HTTP connection: %s" % repr(e)), sys.exc_info()[2]) finally: flow.live = False + def change_upstream_proxy_server(self, address): + # Make set_upstream_proxy_server always available, + # even if there's no UpstreamConnectLayer + if address != self.server_conn.address: + return self.set_server(address) + def handle_regular_mode_connect(self, request): self.set_server((request.host, request.port)) self.send_response(make_connect_response(request.httpversion)) @@ -343,36 +417,6 @@ class HttpLayer(Layer): layer = UpstreamConnectLayer(self, connect_request) layer() - def check_close_connection(self, flow): - """ - Checks if the connection should be closed depending on the HTTP - semantics. Returns True, if so. - """ - - # TODO: add logic for HTTP/2 - - close_connection = ( - http1.HTTP1Protocol.connection_close( - flow.request.httpversion, - flow.request.headers - ) or http1.HTTP1Protocol.connection_close( - flow.response.httpversion, - flow.response.headers - ) or http1.HTTP1Protocol.expected_http_body_size( - flow.response.headers, - False, - flow.request.method, - flow.response.code) == -1 - ) - if flow.request.form_in == "authority" and flow.response.code == 200: - # Workaround for - # https://github.com/mitmproxy/mitmproxy/issues/313: Some - # proxies (e.g. Charles) send a CONNECT response with HTTP/1.0 - # and no Content-Length header - - return False - return close_connection - def send_response_to_client(self, flow): if not (self.supports_streaming and flow.response.stream): # no streaming: @@ -420,7 +464,8 @@ class HttpLayer(Layer): # > server detects timeout, disconnects # > read (100-n)% of large request # > send large request upstream - self.reconnect() + self.disconnect() + self.connect() get_response() # call the appropriate script hook - this is an opportunity for an @@ -482,7 +527,7 @@ class HttpLayer(Layer): if self.mode == "regular" or self.mode == "transparent": # If there's an existing connection that doesn't match our expectations, kill it. - if address != self.server_conn.address or tls != self.server_conn.ssl_established: + if address != self.server_conn.address or tls != self.server_conn.tls_established: self.set_server(address, tls, address.host) # Establish connection is neccessary. if not self.server_conn: @@ -495,10 +540,12 @@ class HttpLayer(Layer): """ # This is a very ugly (untested) workaround to solve a very ugly problem. if self.server_conn and self.server_conn.tls_established and not ssl: - self.reconnect() + self.disconnect() + self.connect() elif ssl and not hasattr(self, "connected_to") or self.connected_to != address: if self.server_conn.tls_established: - self.reconnect() + self.disconnect() + self.connect() self.send_request(make_connect_request(address)) tls_layer = TlsLayer(self, False, True) @@ -536,10 +583,6 @@ class HttpLayer(Layer): self.send_response(make_error_response( 407, "Proxy Authentication Required", - odict.ODictCaseless( - [ - [k, v] for k, v in - self.config.authenticator.auth_challenge_headers().items() - ]) + Headers(**self.config.authenticator.auth_challenge_headers()) )) raise InvalidCredentials("Proxy Authentication Required") diff --git a/libmproxy/protocol/http_replay.py b/libmproxy/protocol/http_replay.py index 2759a0198..a9ee55069 100644 --- a/libmproxy/protocol/http_replay.py +++ b/libmproxy/protocol/http_replay.py @@ -6,7 +6,7 @@ from netlib.http.http1 import HTTP1Protocol from netlib.tcp import NetLibError from ..controller import Channel from ..models import Error, HTTPResponse, ServerConnection, make_connect_request -from .base import Log, Kill +from .base import Kill # TODO: Doesn't really belong into libmproxy.protocol... @@ -89,8 +89,9 @@ class RequestReplayThread(threading.Thread): if self.channel: self.channel.ask("error", self.flow) except Kill: - # KillSignal should only be raised if there's a channel in the + # Kill should only be raised if there's a channel in the # first place. + from ..proxy.root_context import Log self.channel.tell("log", Log("Connection killed", "info")) finally: r.form_out = form_out_backup diff --git a/libmproxy/protocol/rawtcp.py b/libmproxy/protocol/rawtcp.py index 864687734..9b155412a 100644 --- a/libmproxy/protocol/rawtcp.py +++ b/libmproxy/protocol/rawtcp.py @@ -1,10 +1,12 @@ from __future__ import (absolute_import, print_function, division) import socket import select +import six +import sys from OpenSSL import SSL -from netlib.tcp import NetLibError +from netlib.tcp import NetLibError, ssl_read_select from netlib.utils import cleanBin from ..exceptions import ProtocolException from .base import Layer @@ -28,7 +30,7 @@ class RawTCPLayer(Layer): try: while True: - r, _, _ = select.select(conns, [], [], 10) + r = ssl_read_select(conns, 10) for conn in r: dst = server if conn == client else client @@ -63,4 +65,8 @@ class RawTCPLayer(Layer): ) except (socket.error, NetLibError, SSL.Error) as e: - raise ProtocolException("TCP connection closed unexpectedly: {}".format(repr(e)), e) + six.reraise( + ProtocolException, + ProtocolException("TCP connection closed unexpectedly: {}".format(repr(e))), + sys.exc_info()[2] + ) diff --git a/libmproxy/protocol/tls.py b/libmproxy/protocol/tls.py index a8dc8bb2a..4f7c9300a 100644 --- a/libmproxy/protocol/tls.py +++ b/libmproxy/protocol/tls.py @@ -1,16 +1,210 @@ from __future__ import (absolute_import, print_function, division) import struct +import sys from construct import ConstructError +import six from netlib.tcp import NetLibError, NetLibInvalidCertificateError from netlib.http.http1 import HTTP1Protocol from ..contrib.tls._constructs import ClientHello -from ..exceptions import ProtocolException +from ..exceptions import ProtocolException, TlsException, ClientHandshakeException from .base import Layer + +# taken from https://testssl.sh/openssl-rfc.mappping.html +CIPHER_ID_NAME_MAP = { + 0x00: 'NULL-MD5', + 0x01: 'NULL-MD5', + 0x02: 'NULL-SHA', + 0x03: 'EXP-RC4-MD5', + 0x04: 'RC4-MD5', + 0x05: 'RC4-SHA', + 0x06: 'EXP-RC2-CBC-MD5', + 0x07: 'IDEA-CBC-SHA', + 0x08: 'EXP-DES-CBC-SHA', + 0x09: 'DES-CBC-SHA', + 0x0a: 'DES-CBC3-SHA', + 0x0b: 'EXP-DH-DSS-DES-CBC-SHA', + 0x0c: 'DH-DSS-DES-CBC-SHA', + 0x0d: 'DH-DSS-DES-CBC3-SHA', + 0x0e: 'EXP-DH-RSA-DES-CBC-SHA', + 0x0f: 'DH-RSA-DES-CBC-SHA', + 0x10: 'DH-RSA-DES-CBC3-SHA', + 0x11: 'EXP-EDH-DSS-DES-CBC-SHA', + 0x12: 'EDH-DSS-DES-CBC-SHA', + 0x13: 'EDH-DSS-DES-CBC3-SHA', + 0x14: 'EXP-EDH-RSA-DES-CBC-SHA', + 0x15: 'EDH-RSA-DES-CBC-SHA', + 0x16: 'EDH-RSA-DES-CBC3-SHA', + 0x17: 'EXP-ADH-RC4-MD5', + 0x18: 'ADH-RC4-MD5', + 0x19: 'EXP-ADH-DES-CBC-SHA', + 0x1a: 'ADH-DES-CBC-SHA', + 0x1b: 'ADH-DES-CBC3-SHA', + # 0x1c: , + # 0x1d: , + 0x1e: 'KRB5-DES-CBC-SHA', + 0x1f: 'KRB5-DES-CBC3-SHA', + 0x20: 'KRB5-RC4-SHA', + 0x21: 'KRB5-IDEA-CBC-SHA', + 0x22: 'KRB5-DES-CBC-MD5', + 0x23: 'KRB5-DES-CBC3-MD5', + 0x24: 'KRB5-RC4-MD5', + 0x25: 'KRB5-IDEA-CBC-MD5', + 0x26: 'EXP-KRB5-DES-CBC-SHA', + 0x27: 'EXP-KRB5-RC2-CBC-SHA', + 0x28: 'EXP-KRB5-RC4-SHA', + 0x29: 'EXP-KRB5-DES-CBC-MD5', + 0x2a: 'EXP-KRB5-RC2-CBC-MD5', + 0x2b: 'EXP-KRB5-RC4-MD5', + 0x2f: 'AES128-SHA', + 0x30: 'DH-DSS-AES128-SHA', + 0x31: 'DH-RSA-AES128-SHA', + 0x32: 'DHE-DSS-AES128-SHA', + 0x33: 'DHE-RSA-AES128-SHA', + 0x34: 'ADH-AES128-SHA', + 0x35: 'AES256-SHA', + 0x36: 'DH-DSS-AES256-SHA', + 0x37: 'DH-RSA-AES256-SHA', + 0x38: 'DHE-DSS-AES256-SHA', + 0x39: 'DHE-RSA-AES256-SHA', + 0x3a: 'ADH-AES256-SHA', + 0x3b: 'NULL-SHA256', + 0x3c: 'AES128-SHA256', + 0x3d: 'AES256-SHA256', + 0x3e: 'DH-DSS-AES128-SHA256', + 0x3f: 'DH-RSA-AES128-SHA256', + 0x40: 'DHE-DSS-AES128-SHA256', + 0x41: 'CAMELLIA128-SHA', + 0x42: 'DH-DSS-CAMELLIA128-SHA', + 0x43: 'DH-RSA-CAMELLIA128-SHA', + 0x44: 'DHE-DSS-CAMELLIA128-SHA', + 0x45: 'DHE-RSA-CAMELLIA128-SHA', + 0x46: 'ADH-CAMELLIA128-SHA', + 0x62: 'EXP1024-DES-CBC-SHA', + 0x63: 'EXP1024-DHE-DSS-DES-CBC-SHA', + 0x64: 'EXP1024-RC4-SHA', + 0x65: 'EXP1024-DHE-DSS-RC4-SHA', + 0x66: 'DHE-DSS-RC4-SHA', + 0x67: 'DHE-RSA-AES128-SHA256', + 0x68: 'DH-DSS-AES256-SHA256', + 0x69: 'DH-RSA-AES256-SHA256', + 0x6a: 'DHE-DSS-AES256-SHA256', + 0x6b: 'DHE-RSA-AES256-SHA256', + 0x6c: 'ADH-AES128-SHA256', + 0x6d: 'ADH-AES256-SHA256', + 0x80: 'GOST94-GOST89-GOST89', + 0x81: 'GOST2001-GOST89-GOST89', + 0x82: 'GOST94-NULL-GOST94', + 0x83: 'GOST2001-GOST89-GOST89', + 0x84: 'CAMELLIA256-SHA', + 0x85: 'DH-DSS-CAMELLIA256-SHA', + 0x86: 'DH-RSA-CAMELLIA256-SHA', + 0x87: 'DHE-DSS-CAMELLIA256-SHA', + 0x88: 'DHE-RSA-CAMELLIA256-SHA', + 0x89: 'ADH-CAMELLIA256-SHA', + 0x8a: 'PSK-RC4-SHA', + 0x8b: 'PSK-3DES-EDE-CBC-SHA', + 0x8c: 'PSK-AES128-CBC-SHA', + 0x8d: 'PSK-AES256-CBC-SHA', + # 0x8e: , + # 0x8f: , + # 0x90: , + # 0x91: , + # 0x92: , + # 0x93: , + # 0x94: , + # 0x95: , + 0x96: 'SEED-SHA', + 0x97: 'DH-DSS-SEED-SHA', + 0x98: 'DH-RSA-SEED-SHA', + 0x99: 'DHE-DSS-SEED-SHA', + 0x9a: 'DHE-RSA-SEED-SHA', + 0x9b: 'ADH-SEED-SHA', + 0x9c: 'AES128-GCM-SHA256', + 0x9d: 'AES256-GCM-SHA384', + 0x9e: 'DHE-RSA-AES128-GCM-SHA256', + 0x9f: 'DHE-RSA-AES256-GCM-SHA384', + 0xa0: 'DH-RSA-AES128-GCM-SHA256', + 0xa1: 'DH-RSA-AES256-GCM-SHA384', + 0xa2: 'DHE-DSS-AES128-GCM-SHA256', + 0xa3: 'DHE-DSS-AES256-GCM-SHA384', + 0xa4: 'DH-DSS-AES128-GCM-SHA256', + 0xa5: 'DH-DSS-AES256-GCM-SHA384', + 0xa6: 'ADH-AES128-GCM-SHA256', + 0xa7: 'ADH-AES256-GCM-SHA384', + 0x5600: 'TLS_FALLBACK_SCSV', + 0xc001: 'ECDH-ECDSA-NULL-SHA', + 0xc002: 'ECDH-ECDSA-RC4-SHA', + 0xc003: 'ECDH-ECDSA-DES-CBC3-SHA', + 0xc004: 'ECDH-ECDSA-AES128-SHA', + 0xc005: 'ECDH-ECDSA-AES256-SHA', + 0xc006: 'ECDHE-ECDSA-NULL-SHA', + 0xc007: 'ECDHE-ECDSA-RC4-SHA', + 0xc008: 'ECDHE-ECDSA-DES-CBC3-SHA', + 0xc009: 'ECDHE-ECDSA-AES128-SHA', + 0xc00a: 'ECDHE-ECDSA-AES256-SHA', + 0xc00b: 'ECDH-RSA-NULL-SHA', + 0xc00c: 'ECDH-RSA-RC4-SHA', + 0xc00d: 'ECDH-RSA-DES-CBC3-SHA', + 0xc00e: 'ECDH-RSA-AES128-SHA', + 0xc00f: 'ECDH-RSA-AES256-SHA', + 0xc010: 'ECDHE-RSA-NULL-SHA', + 0xc011: 'ECDHE-RSA-RC4-SHA', + 0xc012: 'ECDHE-RSA-DES-CBC3-SHA', + 0xc013: 'ECDHE-RSA-AES128-SHA', + 0xc014: 'ECDHE-RSA-AES256-SHA', + 0xc015: 'AECDH-NULL-SHA', + 0xc016: 'AECDH-RC4-SHA', + 0xc017: 'AECDH-DES-CBC3-SHA', + 0xc018: 'AECDH-AES128-SHA', + 0xc019: 'AECDH-AES256-SHA', + 0xc01a: 'SRP-3DES-EDE-CBC-SHA', + 0xc01b: 'SRP-RSA-3DES-EDE-CBC-SHA', + 0xc01c: 'SRP-DSS-3DES-EDE-CBC-SHA', + 0xc01d: 'SRP-AES-128-CBC-SHA', + 0xc01e: 'SRP-RSA-AES-128-CBC-SHA', + 0xc01f: 'SRP-DSS-AES-128-CBC-SHA', + 0xc020: 'SRP-AES-256-CBC-SHA', + 0xc021: 'SRP-RSA-AES-256-CBC-SHA', + 0xc022: 'SRP-DSS-AES-256-CBC-SHA', + 0xc023: 'ECDHE-ECDSA-AES128-SHA256', + 0xc024: 'ECDHE-ECDSA-AES256-SHA384', + 0xc025: 'ECDH-ECDSA-AES128-SHA256', + 0xc026: 'ECDH-ECDSA-AES256-SHA384', + 0xc027: 'ECDHE-RSA-AES128-SHA256', + 0xc028: 'ECDHE-RSA-AES256-SHA384', + 0xc029: 'ECDH-RSA-AES128-SHA256', + 0xc02a: 'ECDH-RSA-AES256-SHA384', + 0xc02b: 'ECDHE-ECDSA-AES128-GCM-SHA256', + 0xc02c: 'ECDHE-ECDSA-AES256-GCM-SHA384', + 0xc02d: 'ECDH-ECDSA-AES128-GCM-SHA256', + 0xc02e: 'ECDH-ECDSA-AES256-GCM-SHA384', + 0xc02f: 'ECDHE-RSA-AES128-GCM-SHA256', + 0xc030: 'ECDHE-RSA-AES256-GCM-SHA384', + 0xc031: 'ECDH-RSA-AES128-GCM-SHA256', + 0xc032: 'ECDH-RSA-AES256-GCM-SHA384', + 0xcc13: 'ECDHE-RSA-CHACHA20-POLY1305', + 0xcc14: 'ECDHE-ECDSA-CHACHA20-POLY1305', + 0xcc15: 'DHE-RSA-CHACHA20-POLY1305', + 0xff00: 'GOST-MD5', + 0xff01: 'GOST-GOST94', + 0xff02: 'GOST-GOST89MAC', + 0xff03: 'GOST-GOST89STREAM', + 0x010080: 'RC4-MD5', + 0x020080: 'EXP-RC4-MD5', + 0x030080: 'RC2-CBC-MD5', + 0x040080: 'EXP-RC2-CBC-MD5', + 0x050080: 'IDEA-CBC-MD5', + 0x060040: 'DES-CBC-MD5', + 0x0700c0: 'DES-CBC3-MD5', + 0x080080: 'RC4-64-MD5', +} + + def is_tls_record_magic(d): """ Returns: @@ -47,8 +241,8 @@ class TlsLayer(Layer): If so, we first connect to the server and then to the client. If not, we only connect to the client and do the server_ssl lazily on a Connect message. - An additional complexity is that establish ssl with the server may require a SNI value from the client. - In an ideal world, we'd do the following: + An additional complexity is that establish ssl with the server may require a SNI value from + the client. In an ideal world, we'd do the following: 1. Start the SSL handshake with the client 2. Check if the client sends a SNI. 3. Pause the client handshake, establish SSL with the server. @@ -100,11 +294,11 @@ class TlsLayer(Layer): while len(client_hello) < client_hello_size: record_header = self.client_conn.rfile.peek(offset + 5)[offset:] if not is_tls_record_magic(record_header) or len(record_header) != 5: - raise ProtocolException('Expected TLS record, got "%s" instead.' % record_header) + raise TlsException('Expected TLS record, got "%s" instead.' % record_header) record_size = struct.unpack("!H", record_header[3:])[0] + 5 record_body = self.client_conn.rfile.peek(offset + record_size)[offset + 5:] if len(record_body) != record_size - 5: - raise ProtocolException("Unexpected EOF in TLS handshake: %s" % record_body) + raise TlsException("Unexpected EOF in TLS handshake: %s" % record_body) client_hello += record_body offset += record_size client_hello_size = struct.unpack("!I", '\x00' + client_hello[1:4])[0] + 4 @@ -127,6 +321,8 @@ class TlsLayer(Layer): self.log("Raw Client Hello:\r\n:%s" % raw_client_hello.encode("hex"), "debug") return + self.client_ciphers = client_hello.cipher_suites.cipher_suites + for extension in client_hello.extensions: if extension.type == 0x00: if len(extension.server_names) != 1 or extension.server_names[0].type != 0: @@ -146,18 +342,11 @@ class TlsLayer(Layer): if self._server_tls and not self.server_conn.tls_established: self._establish_tls_with_server() - def reconnect(self): - self.ctx.reconnect() - if self._server_tls and not self.server_conn.tls_established: - self._establish_tls_with_server() - - def set_server(self, address, server_tls=None, sni=None, depth=1): - if depth == 1 and server_tls is not None: - self.ctx.set_server(address, None, None, 1) + def set_server(self, address, server_tls=None, sni=None): + if server_tls is not None: self._sni_from_server_change = sni self._server_tls = server_tls - else: - self.ctx.set_server(address, server_tls, sni, depth) + self.ctx.set_server(address, None, None) @property def sni_for_server_connection(self): @@ -201,7 +390,7 @@ class TlsLayer(Layer): self._establish_tls_with_client() except: pass - raise e + six.reraise(*sys.exc_info()) self._establish_tls_with_client() @@ -219,8 +408,22 @@ class TlsLayer(Layer): chain_file=chain_file, alpn_select_callback=self.__alpn_select_callback, ) + # Some TLS clients will not fail the handshake, + # but will immediately throw an "unexpected eof" error on the first read. + # The reason for this might be difficult to find, so we try to peek here to see if it + # raises ann error. + self.client_conn.rfile.peek(1) except NetLibError as e: - raise ProtocolException("Cannot establish TLS with client: %s" % repr(e), e) + six.reraise( + ClientHandshakeException, + ClientHandshakeException( + "Cannot establish TLS with client (sni: {sni}): {e}".format( + sni=self.client_sni, e=repr(e) + ), + self.client_sni or repr(self.server_conn.address) + ), + sys.exc_info()[2] + ) def _establish_tls_with_server(self): self.log("Establish TLS with server", "debug") @@ -230,9 +433,19 @@ class TlsLayer(Layer): # and mitmproxy would enter TCP passthrough mode, which we want to avoid. deprecated_http2_variant = lambda x: x.startswith("h2-") or x.startswith("spdy") if self.client_alpn_protocols: - alpn = filter(lambda x: not deprecated_http2_variant(x), self.client_alpn_protocols) + alpn = [x for x in self.client_alpn_protocols if not deprecated_http2_variant(x)] else: alpn = None + if alpn and "h2" in alpn and not self.config.http2 : + alpn.remove("h2") + + ciphers_server = self.config.ciphers_server + if not ciphers_server: + ciphers_server = [] + for id in self.client_ciphers: + if id in CIPHER_ID_NAME_MAP.keys(): + ciphers_server.append(CIPHER_ID_NAME_MAP[id]) + ciphers_server = ':'.join(ciphers_server) self.server_conn.establish_ssl( self.config.clientcerts, @@ -242,7 +455,7 @@ class TlsLayer(Layer): verify_options=self.config.openssl_verification_mode_server, ca_path=self.config.openssl_trusted_cadir_server, ca_pemfile=self.config.openssl_trusted_ca_server, - cipher_list=self.config.ciphers_server, + cipher_list=ciphers_server, alpn_protos=alpn, ) tls_cert_err = self.server_conn.ssl_verification_error @@ -259,17 +472,25 @@ class TlsLayer(Layer): (tls_cert_err['depth'], tls_cert_err['errno']), "error") self.log("Aborting connection attempt", "error") - raise ProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format( - address=repr(self.server_conn.address), - sni=self.sni_for_server_connection, - e=repr(e), - ), e) + six.reraise( + TlsException, + TlsException("Cannot establish TLS with {address} (sni: {sni}): {e}".format( + address=repr(self.server_conn.address), + sni=self.sni_for_server_connection, + e=repr(e), + )), + sys.exc_info()[2] + ) except NetLibError as e: - raise ProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format( - address=repr(self.server_conn.address), - sni=self.sni_for_server_connection, - e=repr(e), - ), e) + six.reraise( + TlsException, + TlsException("Cannot establish TLS with {address} (sni: {sni}): {e}".format( + address=repr(self.server_conn.address), + sni=self.sni_for_server_connection, + e=repr(e), + )), + sys.exc_info()[2] + ) self.log("ALPN selected by server: %s" % self.alpn_for_client_connection, "debug") @@ -294,5 +515,4 @@ class TlsLayer(Layer): if self._sni_from_server_change: sans.add(self._sni_from_server_change) - sans.discard(host) return self.config.certstore.get_cert(host, list(sans)) diff --git a/libmproxy/proxy/__init__.py b/libmproxy/proxy/__init__.py index d5297cb1d..be7f52073 100644 --- a/libmproxy/proxy/__init__.py +++ b/libmproxy/proxy/__init__.py @@ -2,8 +2,10 @@ from __future__ import (absolute_import, print_function, division) from .server import ProxyServer, DummyServer from .config import ProxyConfig +from .root_context import RootContext, Log __all__ = [ "ProxyServer", "DummyServer", "ProxyConfig", + "RootContext", "Log", ] diff --git a/libmproxy/proxy/config.py b/libmproxy/proxy/config.py index 2a1b84cb5..cd9eda5ab 100644 --- a/libmproxy/proxy/config.py +++ b/libmproxy/proxy/config.py @@ -54,6 +54,8 @@ class ProxyConfig: authenticator=None, ignore_hosts=tuple(), tcp_hosts=tuple(), + http2=False, + rawtcp=False, ciphers_client=None, ciphers_server=None, certs=tuple(), @@ -78,6 +80,8 @@ class ProxyConfig: self.check_ignore = HostMatcher(ignore_hosts) self.check_tcp = HostMatcher(tcp_hosts) + self.http2 = http2 + self.rawtcp = rawtcp self.authenticator = authenticator self.cadir = os.path.expanduser(cadir) self.certstore = certutils.CertStore.from_store( @@ -183,6 +187,8 @@ def process_proxy_options(parser, options): upstream_server=upstream_server, ignore_hosts=options.ignore_hosts, tcp_hosts=options.tcp_hosts, + http2=options.http2, + rawtcp=options.rawtcp, authenticator=authenticator, ciphers_client=options.ciphers_client, ciphers_server=options.ciphers_server, @@ -192,4 +198,4 @@ def process_proxy_options(parser, options): ssl_verify_upstream_cert=options.ssl_verify_upstream_cert, ssl_verify_upstream_trusted_cadir=options.ssl_verify_upstream_trusted_cadir, ssl_verify_upstream_trusted_ca=options.ssl_verify_upstream_trusted_ca - ) \ No newline at end of file + ) diff --git a/libmproxy/proxy/modes/http_proxy.py b/libmproxy/proxy/modes/http_proxy.py index 90c54cc6c..c7502c24f 100644 --- a/libmproxy/proxy/modes/http_proxy.py +++ b/libmproxy/proxy/modes/http_proxy.py @@ -10,7 +10,7 @@ class HttpProxy(Layer, ServerConnectionMixin): layer() finally: if self.server_conn: - self._disconnect() + self.disconnect() class HttpUpstreamProxy(Layer, ServerConnectionMixin): @@ -23,4 +23,4 @@ class HttpUpstreamProxy(Layer, ServerConnectionMixin): layer() finally: if self.server_conn: - self._disconnect() + self.disconnect() diff --git a/libmproxy/proxy/modes/reverse_proxy.py b/libmproxy/proxy/modes/reverse_proxy.py index b57ac5eb1..28f4e6f85 100644 --- a/libmproxy/proxy/modes/reverse_proxy.py +++ b/libmproxy/proxy/modes/reverse_proxy.py @@ -14,4 +14,4 @@ class ReverseProxy(Layer, ServerConnectionMixin): layer() finally: if self.server_conn: - self._disconnect() + self.disconnect() diff --git a/libmproxy/proxy/modes/socks_proxy.py b/libmproxy/proxy/modes/socks_proxy.py index ebaf939ea..545c38d65 100644 --- a/libmproxy/proxy/modes/socks_proxy.py +++ b/libmproxy/proxy/modes/socks_proxy.py @@ -48,7 +48,7 @@ class Socks5Proxy(Layer, ServerConnectionMixin): self.client_conn.wfile.flush() except (socks.SocksError, NetLibError) as e: - raise Socks5Exception("SOCKS5 mode failure: %s" % repr(e), e) + raise Socks5Exception("SOCKS5 mode failure: %s" % repr(e)) self.server_conn.address = connect_request.addr @@ -57,4 +57,4 @@ class Socks5Proxy(Layer, ServerConnectionMixin): layer() finally: if self.server_conn: - self._disconnect() + self.disconnect() diff --git a/libmproxy/proxy/modes/transparent_proxy.py b/libmproxy/proxy/modes/transparent_proxy.py index 96ad86c41..da1d46326 100644 --- a/libmproxy/proxy/modes/transparent_proxy.py +++ b/libmproxy/proxy/modes/transparent_proxy.py @@ -14,11 +14,11 @@ class TransparentProxy(Layer, ServerConnectionMixin): try: self.server_conn.address = self.resolver.original_addr(self.client_conn.connection) except Exception as e: - raise ProtocolException("Transparent mode failure: %s" % repr(e), e) + raise ProtocolException("Transparent mode failure: %s" % repr(e)) layer = self.ctx.next_layer(self) try: layer() finally: if self.server_conn: - self._disconnect() + self.disconnect() diff --git a/libmproxy/proxy/root_context.py b/libmproxy/proxy/root_context.py index 35909612c..54bea1db4 100644 --- a/libmproxy/proxy/root_context.py +++ b/libmproxy/proxy/root_context.py @@ -1,8 +1,13 @@ from __future__ import (absolute_import, print_function, division) +import string +import sys +import six + +from libmproxy.exceptions import ProtocolException from netlib.http.http1 import HTTP1Protocol from netlib.http.http2 import HTTP2Protocol - +from netlib.tcp import NetLibError from ..protocol import ( RawTCPLayer, TlsLayer, Http1Layer, Http2Layer, is_tls_record_magic, ServerConnectionMixin ) @@ -11,31 +16,47 @@ from .modes import HttpProxy, HttpUpstreamProxy, ReverseProxy class RootContext(object): """ - The outmost context provided to the root layer. - As a consequence, every layer has .client_conn, .channel, .next_layer() and .config. + The outermost context provided to the root layer. + As a consequence, every layer has access to methods and attributes defined here. + + Attributes: + client_conn: + The :py:class:`client connection `. + channel: + A :py:class:`~libmproxy.controller.Channel` to communicate with the FlowMaster. + Provides :py:meth:`.ask() ` and + :py:meth:`.tell() ` methods. + config: + The :py:class:`proxy server's configuration ` """ def __init__(self, client_conn, config, channel): - self.client_conn = client_conn # Client Connection - self.channel = channel # provides .ask() method to communicate with FlowMaster - self.config = config # Proxy Configuration + self.client_conn = client_conn + self.channel = channel + self.config = config def next_layer(self, top_layer): """ This function determines the next layer in the protocol stack. Arguments: - top_layer: the current top layer. + top_layer: the current innermost layer. Returns: The next layer """ + layer = self._next_layer(top_layer) + return self.channel.ask("next_layer", layer) + def _next_layer(self, top_layer): # 1. Check for --ignore. if self.config.check_ignore(top_layer.server_conn.address): return RawTCPLayer(top_layer, logging=False) - d = top_layer.client_conn.rfile.peek(3) + try: + d = top_layer.client_conn.rfile.peek(3) + except NetLibError as e: + six.reraise(ProtocolException, ProtocolException(str(e)), sys.exc_info()[2]) client_tls = is_tls_record_magic(d) # 2. Always insert a TLS layer, even if there's neither client nor server tls. @@ -69,21 +90,30 @@ class RootContext(object): if alpn == HTTP1Protocol.ALPN_PROTO_HTTP1: return Http1Layer(top_layer, 'transparent') - # 6. Assume HTTP1 by default + # 6. Check for raw tcp mode + is_ascii = ( + len(d) == 3 and + # better be safe here and don't expect uppercase... + all(x in string.ascii_letters for x in d) + ) + if self.config.rawtcp and not is_ascii: + return RawTCPLayer(top_layer) + + # 7. Assume HTTP1 by default return Http1Layer(top_layer, 'transparent') - # In a future version, we want to implement TCP passthrough as the last fallback, - # but we don't have the UI part ready for that. - # - # d = top_layer.client_conn.rfile.peek(3) - # is_ascii = ( - # len(d) == 3 and - # # better be safe here and don't expect uppercase... - # all(x in string.ascii_letters for x in d) - # ) - # # TODO: This could block if there are not enough bytes available? - # d = top_layer.client_conn.rfile.peek(len(HTTP2Protocol.CLIENT_CONNECTION_PREFACE)) - # is_http2_magic = (d == HTTP2Protocol.CLIENT_CONNECTION_PREFACE) + def log(self, msg, level, subs=()): + """ + Send a log message to the master. + """ + + full_msg = [ + "{}: {}".format(repr(self.client_conn.address), msg) + ] + for i in subs: + full_msg.append(" -> " + i) + full_msg = "\n".join(full_msg) + self.channel.tell("log", Log(full_msg, level)) @property def layers(self): @@ -91,3 +121,9 @@ class RootContext(object): def __repr__(self): return "RootContext" + + +class Log(object): + def __init__(self, msg, level="info"): + self.msg = msg + self.level = level diff --git a/libmproxy/proxy/server.py b/libmproxy/proxy/server.py index e9e8df092..88448172a 100644 --- a/libmproxy/proxy/server.py +++ b/libmproxy/proxy/server.py @@ -3,15 +3,16 @@ from __future__ import (absolute_import, print_function, division) import traceback import sys import socket +import six from netlib import tcp from netlib.http.http1 import HTTP1Protocol from netlib.tcp import NetLibError -from ..exceptions import ProtocolException, ServerException -from ..protocol import Log, Kill +from ..exceptions import ProtocolException, ServerException, ClientHandshakeException +from ..protocol import Kill from ..models import ClientConnection, make_error_response from .modes import HttpUpstreamProxy, HttpProxy, ReverseProxy, TransparentProxy, Socks5Proxy -from .root_context import RootContext +from .root_context import RootContext, Log class DummyServer: @@ -39,7 +40,11 @@ class ProxyServer(tcp.TCPServer): try: super(ProxyServer, self).__init__((config.host, config.port)) except socket.error as e: - raise ServerException('Error starting proxy server: ' + repr(e), e) + six.reraise( + ServerException, + ServerException('Error starting proxy server: ' + repr(e)), + sys.exc_info()[2] + ) self.channel = None def start_slave(self, klass, channel): @@ -116,7 +121,18 @@ class ConnectionHandler(object): except Kill: self.log("Connection killed", "info") except ProtocolException as e: - self.log(e, "info") + + if isinstance(e, ClientHandshakeException): + self.log( + "Client Handshake failed. " + "The client may not trust the proxy's certificate for {}.".format(e.server), + "error" + ) + self.log(repr(e), "debug") + else: + self.log(repr(e), "error") + + self.log(traceback.format_exc(), "debug") # If an error propagates to the topmost level, # we send an HTTP error response, which is both # understandable by HTTP clients and humans. @@ -137,4 +153,4 @@ class ConnectionHandler(object): def log(self, msg, level): msg = "{}: {}".format(repr(self.client_conn.address), msg) - self.channel.tell("log", Log(msg, level)) \ No newline at end of file + self.channel.tell("log", Log(msg, level)) diff --git a/libmproxy/script.py b/libmproxy/script.py index e13f0e2b6..b4ecfbbfc 100644 --- a/libmproxy/script.py +++ b/libmproxy/script.py @@ -95,8 +95,8 @@ class Script: """ if self.ns is not None: self.unload() - ns = {} script_dir = os.path.dirname(os.path.abspath(self.args[0])) + ns = {'__file__': os.path.abspath(self.args[0])} sys.path.append(script_dir) try: execfile(self.args[0], ns, ns) @@ -179,7 +179,8 @@ def concurrent(fn): "error", "clientconnect", "serverconnect", - "clientdisconnect"): + "clientdisconnect", + "next_layer"): def _concurrent(ctx, obj): _handle_concurrent_reply(fn, obj, ctx, obj) diff --git a/libmproxy/web/app.py b/libmproxy/web/app.py index d6082ee24..2517e7ad2 100644 --- a/libmproxy/web/app.py +++ b/libmproxy/web/app.py @@ -27,8 +27,7 @@ class RequestHandler(tornado.web.RequestHandler): @property def json(self): - if not self.request.headers.get( - "Content-Type").startswith("application/json"): + if not self.request.headers.get("Content-Type").startswith("application/json"): return None return json.loads(self.request.body) @@ -186,12 +185,12 @@ class FlowContent(RequestHandler): if not message.content: raise APIError(400, "No content.") - content_encoding = message.headers.get_first("Content-Encoding", None) + content_encoding = message.headers.get("Content-Encoding", None) if content_encoding: content_encoding = re.sub(r"[^\w]", "", content_encoding) self.set_header("Content-Encoding", content_encoding) - original_cd = message.headers.get_first("Content-Disposition", None) + original_cd = message.headers.get("Content-Disposition", None) filename = None if original_cd: filename = re.search("filename=([\w\" \.\-\(\)]+)", original_cd) diff --git a/setup.py b/setup.py index e28033ad1..896d02480 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,8 @@ setup( "Topic :: Internet", "Topic :: Internet :: WWW/HTTP", "Topic :: Internet :: Proxy Servers", - "Topic :: Software Development :: Testing"], + "Topic :: Software Development :: Testing" + ], packages=find_packages(), include_package_data=True, entry_points={ @@ -94,8 +95,13 @@ setup( 'contentviews': [ "pyamf>=0.6.1", "protobuf>=2.5.0", - "cssutils>=1.0"], + "cssutils>=1.0" + ], 'examples': [ "pytz", "harparser", - "beautifulsoup4"]}) + "beautifulsoup4", + "enum34" + ] + } +) diff --git a/test/test_console_contentview.py b/test/test_console_contentview.py index 50c3f7668..d44a3cf4a 100644 --- a/test/test_console_contentview.py +++ b/test/test_console_contentview.py @@ -1,11 +1,13 @@ import os from nose.plugins.skip import SkipTest +from netlib.http import Headers + if os.name == "nt": raise SkipTest("Skipped on Windows.") import sys import netlib.utils -from netlib import odict, encoding +from netlib import encoding import libmproxy.contentview as cv from libmproxy import utils, flow @@ -33,34 +35,28 @@ class TestContentView: def test_view_auto(self): v = cv.ViewAuto() f = v( - odict.ODictCaseless(), + Headers(), "foo", 1000 ) assert f[0] == "Raw" f = v( - odict.ODictCaseless( - [["content-type", "text/html"]], - ), + Headers(content_type="text/html"), "", 1000 ) assert f[0] == "HTML" f = v( - odict.ODictCaseless( - [["content-type", "text/flibble"]], - ), + Headers(content_type="text/flibble"), "foo", 1000 ) assert f[0] == "Raw" f = v( - odict.ODictCaseless( - [["content-type", "text/flibble"]], - ), + Headers(content_type="text/flibble"), "", 1000 ) @@ -168,28 +164,22 @@ Content-Disposition: form-data; name="submit-name" Larry --AaB03x """.strip() - h = odict.ODictCaseless( - [("Content-Type", "multipart/form-data; boundary=AaB03x")] - ) + h = Headers(content_type="multipart/form-data; boundary=AaB03x") assert view(h, v, 1000) - h = odict.ODictCaseless() + h = Headers() assert not view(h, v, 1000) - h = odict.ODictCaseless( - [("Content-Type", "multipart/form-data")] - ) + h = Headers(content_type="multipart/form-data") assert not view(h, v, 1000) - h = odict.ODictCaseless( - [("Content-Type", "unparseable")] - ) + h = Headers(content_type="unparseable") assert not view(h, v, 1000) def test_get_content_view(self): r = cv.get_content_view( cv.get("Raw"), - [["content-type", "application/json"]], + Headers(content_type="application/json"), "[1, 2, 3]", 1000, False @@ -198,7 +188,7 @@ Larry r = cv.get_content_view( cv.get("Auto"), - [["content-type", "application/json"]], + Headers(content_type="application/json"), "[1, 2, 3]", 1000, False @@ -207,7 +197,7 @@ Larry r = cv.get_content_view( cv.get("Auto"), - [["content-type", "application/json"]], + Headers(content_type="application/json"), "[1, 2", 1000, False @@ -216,7 +206,7 @@ Larry r = cv.get_content_view( cv.get("AMF"), - [], + Headers(), "[1, 2", 1000, False @@ -225,10 +215,10 @@ Larry r = cv.get_content_view( cv.get("Auto"), - [ - ["content-type", "application/json"], - ["content-encoding", "gzip"] - ], + Headers( + content_type="application/json", + content_encoding="gzip" + ), encoding.encode('gzip', "[1, 2, 3]"), 1000, False @@ -238,10 +228,10 @@ Larry r = cv.get_content_view( cv.get("XML"), - [ - ["content-type", "application/json"], - ["content-encoding", "gzip"] - ], + Headers( + content_type="application/json", + content_encoding="gzip" + ), encoding.encode('gzip', "[1, 2, 3]"), 1000, False diff --git a/test/test_dump.py b/test/test_dump.py index a0ad6cb4c..c76f555f2 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -6,7 +6,7 @@ import netlib.tutils from netlib.http.semantics import CONTENT_MISSING from libmproxy import dump, flow -from libmproxy.protocol import Log +from libmproxy.proxy import Log import tutils import mock @@ -145,7 +145,7 @@ class TestDumpMaster: o = dump.Options(setheaders=[(".*", "one", "two")]) m = dump.DumpMaster(None, o, outfile=cs) f = self._cycle(m, "content") - assert f.request.headers["one"] == ["two"] + assert f.request.headers["one"] == "two" def test_basic(self): for i in (1, 2, 3): diff --git a/test/test_filt.py b/test/test_filt.py index aeec24857..76e107103 100644 --- a/test/test_filt.py +++ b/test/test_filt.py @@ -1,8 +1,8 @@ import cStringIO -from netlib import odict from libmproxy import filt, flow from libmproxy.protocol import http from libmproxy.models import Error +from netlib.http import Headers import tutils @@ -76,8 +76,7 @@ class TestParsing: class TestMatching: def req(self): - headers = odict.ODictCaseless() - headers["header"] = ["qvalue"] + headers = Headers(header="qvalue") req = http.HTTPRequest( "absolute", "GET", @@ -98,8 +97,7 @@ class TestMatching: def resp(self): f = self.req() - headers = odict.ODictCaseless() - headers["header_response"] = ["svalue"] + headers = Headers([["header_response", "svalue"]]) f.response = http.HTTPResponse( (1, 1), @@ -123,7 +121,7 @@ class TestMatching: def test_asset(self): s = self.resp() assert not self.q("~a", s) - s.response.headers["content-type"] = ["text/javascript"] + s.response.headers["content-type"] = "text/javascript" assert self.q("~a", s) def test_fcontenttype(self): @@ -132,16 +130,16 @@ class TestMatching: assert not self.q("~t content", q) assert not self.q("~t content", s) - q.request.headers["content-type"] = ["text/json"] + q.request.headers["content-type"] = "text/json" assert self.q("~t json", q) assert self.q("~tq json", q) assert not self.q("~ts json", q) - s.response.headers["content-type"] = ["text/json"] + s.response.headers["content-type"] = "text/json" assert self.q("~t json", s) del s.response.headers["content-type"] - s.request.headers["content-type"] = ["text/json"] + s.request.headers["content-type"] = "text/json" assert self.q("~t json", s) assert self.q("~tq json", s) assert not self.q("~ts json", s) diff --git a/test/test_flow.py b/test/test_flow.py index 9cce26b35..c93beca40 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -8,7 +8,7 @@ import mock import netlib.utils from netlib import odict -from netlib.http.semantics import CONTENT_MISSING, HDR_FORM_URLENCODED +from netlib.http.semantics import CONTENT_MISSING, HDR_FORM_URLENCODED, Headers from libmproxy import filt, protocol, controller, tnetstring, flow from libmproxy.models import Error, Flow, HTTPRequest, HTTPResponse, HTTPFlow, decoded from libmproxy.proxy.config import HostMatcher @@ -34,7 +34,7 @@ def test_app_registry(): r.host = "domain2" r.port = 80 assert not ar.get(r) - r.headers["host"] = ["domain"] + r.headers["host"] = "domain" assert ar.get(r) @@ -42,7 +42,7 @@ class TestStickyCookieState: def _response(self, cookie, host): s = flow.StickyCookieState(filt.parse(".*")) f = tutils.tflow(req=netlib.tutils.treq(host=host, port=80), resp=True) - f.response.headers["Set-Cookie"] = [cookie] + f.response.headers["Set-Cookie"] = cookie s.handle_response(f) return s, f @@ -75,13 +75,13 @@ class TestStickyAuthState: def test_handle_response(self): s = flow.StickyAuthState(filt.parse(".*")) f = tutils.tflow(resp=True) - f.request.headers["authorization"] = ["foo"] + f.request.headers["authorization"] = "foo" s.handle_request(f) assert "address" in s.hosts f = tutils.tflow(resp=True) s.handle_request(f) - assert f.request.headers["authorization"] == ["foo"] + assert f.request.headers["authorization"] == "foo" class TestClientPlaybackState: @@ -133,7 +133,7 @@ class TestServerPlaybackState: assert s._hash(r) assert s._hash(r) == s._hash(r2) - r.request.headers["foo"] = ["bar"] + r.request.headers["foo"] = "bar" assert s._hash(r) == s._hash(r2) r.request.path = "voing" assert s._hash(r) != s._hash(r2) @@ -153,12 +153,12 @@ class TestServerPlaybackState: None, False) r = tutils.tflow(resp=True) - r.request.headers["foo"] = ["bar"] + r.request.headers["foo"] = "bar" r2 = tutils.tflow(resp=True) assert not s._hash(r) == s._hash(r2) - r2.request.headers["foo"] = ["bar"] + r2.request.headers["foo"] = "bar" assert s._hash(r) == s._hash(r2) - r2.request.headers["oink"] = ["bar"] + r2.request.headers["oink"] = "bar" assert s._hash(r) == s._hash(r2) r = tutils.tflow(resp=True) @@ -167,10 +167,10 @@ class TestServerPlaybackState: def test_load(self): r = tutils.tflow(resp=True) - r.request.headers["key"] = ["one"] + r.request.headers["key"] = "one" r2 = tutils.tflow(resp=True) - r2.request.headers["key"] = ["two"] + r2.request.headers["key"] = "two" s = flow.ServerPlaybackState( None, [ @@ -179,21 +179,21 @@ class TestServerPlaybackState: assert len(s.fmap.keys()) == 1 n = s.next_flow(r) - assert n.request.headers["key"] == ["one"] + assert n.request.headers["key"] == "one" assert s.count() == 1 n = s.next_flow(r) - assert n.request.headers["key"] == ["two"] + assert n.request.headers["key"] == "two" assert s.count() == 0 assert not s.next_flow(r) def test_load_with_nopop(self): r = tutils.tflow(resp=True) - r.request.headers["key"] = ["one"] + r.request.headers["key"] = "one" r2 = tutils.tflow(resp=True) - r2.request.headers["key"] = ["two"] + r2.request.headers["key"] = "two" s = flow.ServerPlaybackState( None, [ @@ -224,12 +224,10 @@ class TestServerPlaybackState: None, [], False, False, None, False, [ "param1", "param2"], False) r = tutils.tflow(resp=True) - r.request.headers[ - "Content-Type"] = ["application/x-www-form-urlencoded"] + r.request.headers["Content-Type"] = "application/x-www-form-urlencoded" r.request.content = "paramx=x¶m1=1" r2 = tutils.tflow(resp=True) - r2.request.headers[ - "Content-Type"] = ["application/x-www-form-urlencoded"] + r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded" r2.request.content = "paramx=x¶m1=1" # same parameters assert s._hash(r) == s._hash(r2) @@ -254,10 +252,10 @@ class TestServerPlaybackState: None, [], False, False, None, False, [ "param1", "param2"], False) r = tutils.tflow(resp=True) - r.request.headers["Content-Type"] = ["application/json"] + r.request.headers["Content-Type"] = "application/json" r.request.content = '{"param1":"1"}' r2 = tutils.tflow(resp=True) - r2.request.headers["Content-Type"] = ["application/json"] + r2.request.headers["Content-Type"] = "application/json" r2.request.content = '{"param1":"1"}' # same content assert s._hash(r) == s._hash(r2) @@ -271,12 +269,10 @@ class TestServerPlaybackState: None, [], False, False, None, True, [ "param1", "param2"], False) r = tutils.tflow(resp=True) - r.request.headers[ - "Content-Type"] = ["application/x-www-form-urlencoded"] + r.request.headers["Content-Type"] = "application/x-www-form-urlencoded" r.request.content = "paramx=y" r2 = tutils.tflow(resp=True) - r2.request.headers[ - "Content-Type"] = ["application/x-www-form-urlencoded"] + r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded" r2.request.content = "paramx=x" # same parameters assert s._hash(r) == s._hash(r2) @@ -460,17 +456,17 @@ class TestFlow: def test_replace(self): f = tutils.tflow(resp=True) - f.request.headers["foo"] = ["foo"] + f.request.headers["foo"] = "foo" f.request.content = "afoob" - f.response.headers["foo"] = ["foo"] + f.response.headers["foo"] = "foo" f.response.content = "afoob" assert f.replace("foo", "bar") == 6 - assert f.request.headers["bar"] == ["bar"] + assert f.request.headers["bar"] == "bar" assert f.request.content == "abarb" - assert f.response.headers["bar"] == ["bar"] + assert f.response.headers["bar"] == "bar" assert f.response.content == "abarb" def test_replace_encoded(self): @@ -938,14 +934,14 @@ class TestFlowMaster: fm.set_stickycookie(".*") f = tutils.tflow(resp=True) - f.response.headers["set-cookie"] = ["foo=bar"] + f.response.headers["set-cookie"] = "foo=bar" fm.handle_request(f) fm.handle_response(f) assert fm.stickycookie_state.jar assert not "cookie" in f.request.headers f = f.copy() fm.handle_request(f) - assert f.request.headers["cookie"] == ["foo=bar"] + assert f.request.headers["cookie"] == "foo=bar" def test_stickyauth(self): s = flow.State() @@ -958,14 +954,14 @@ class TestFlowMaster: fm.set_stickyauth(".*") f = tutils.tflow(resp=True) - f.request.headers["authorization"] = ["foo"] + f.request.headers["authorization"] = "foo" fm.handle_request(f) f = tutils.tflow(resp=True) assert fm.stickyauth_state.hosts assert not "authorization" in f.request.headers fm.handle_request(f) - assert f.request.headers["authorization"] == ["foo"] + assert f.request.headers["authorization"] == "foo" def test_stream(self): with tutils.tmpdir() as tdir: @@ -1022,7 +1018,7 @@ class TestRequest: assert r.url == "https://address:22/path" assert r.pretty_url(True) == "https://address:22/path" - r.headers["Host"] = ["foo.com"] + r.headers["Host"] = "foo.com" assert r.pretty_url(False) == "https://address:22/path" assert r.pretty_url(True) == "https://foo.com:22/path" @@ -1048,19 +1044,17 @@ class TestRequest: def test_getset_form_urlencoded(self): d = odict.ODict([("one", "two"), ("three", "four")]) r = HTTPRequest.wrap(netlib.tutils.treq(content=netlib.utils.urlencode(d.lst))) - r.headers["content-type"] = [HDR_FORM_URLENCODED] + r.headers["content-type"] = HDR_FORM_URLENCODED assert r.get_form_urlencoded() == d d = odict.ODict([("x", "y")]) r.set_form_urlencoded(d) assert r.get_form_urlencoded() == d - r.headers["content-type"] = ["foo"] + r.headers["content-type"] = "foo" assert not r.get_form_urlencoded() def test_getset_query(self): - h = odict.ODictCaseless() - r = HTTPRequest.wrap(netlib.tutils.treq()) r.path = "/foo?x=y&a=b" q = r.get_query() @@ -1083,11 +1077,10 @@ class TestRequest: assert r.get_query() == qv def test_anticache(self): - h = odict.ODictCaseless() r = HTTPRequest.wrap(netlib.tutils.treq()) - r.headers = h - h["if-modified-since"] = ["test"] - h["if-none-match"] = ["test"] + r.headers = Headers() + r.headers["if-modified-since"] = "test" + r.headers["if-none-match"] = "test" r.anticache() assert not "if-modified-since" in r.headers assert not "if-none-match" in r.headers @@ -1095,25 +1088,29 @@ class TestRequest: def test_replace(self): r = HTTPRequest.wrap(netlib.tutils.treq()) r.path = "path/foo" - r.headers["Foo"] = ["fOo"] + r.headers["Foo"] = "fOo" r.content = "afoob" assert r.replace("foo(?i)", "boo") == 4 assert r.path == "path/boo" assert not "foo" in r.content - assert r.headers["boo"] == ["boo"] + assert r.headers["boo"] == "boo" def test_constrain_encoding(self): r = HTTPRequest.wrap(netlib.tutils.treq()) - r.headers["accept-encoding"] = ["gzip", "oink"] + r.headers["accept-encoding"] = "gzip, oink" + r.constrain_encoding() + assert "oink" not in r.headers["accept-encoding"] + + r.headers.set_all("accept-encoding", ["gzip", "oink"]) r.constrain_encoding() assert "oink" not in r.headers["accept-encoding"] def test_decodeencode(self): r = HTTPRequest.wrap(netlib.tutils.treq()) - r.headers["content-encoding"] = ["identity"] + r.headers["content-encoding"] = "identity" r.content = "falafel" r.decode() - assert not r.headers["content-encoding"] + assert "content-encoding" not in r.headers assert r.content == "falafel" r = HTTPRequest.wrap(netlib.tutils.treq()) @@ -1121,26 +1118,26 @@ class TestRequest: assert not r.decode() r = HTTPRequest.wrap(netlib.tutils.treq()) - r.headers["content-encoding"] = ["identity"] + r.headers["content-encoding"] = "identity" r.content = "falafel" r.encode("identity") - assert r.headers["content-encoding"] == ["identity"] + assert r.headers["content-encoding"] == "identity" assert r.content == "falafel" r = HTTPRequest.wrap(netlib.tutils.treq()) - r.headers["content-encoding"] = ["identity"] + r.headers["content-encoding"] = "identity" r.content = "falafel" r.encode("gzip") - assert r.headers["content-encoding"] == ["gzip"] + assert r.headers["content-encoding"] == "gzip" assert r.content != "falafel" r.decode() - assert not r.headers["content-encoding"] + assert "content-encoding" not in r.headers assert r.content == "falafel" def test_get_decoded_content(self): r = HTTPRequest.wrap(netlib.tutils.treq()) r.content = None - r.headers["content-encoding"] = ["identity"] + r.headers["content-encoding"] = "identity" assert r.get_decoded_content() == None r.content = "falafel" @@ -1148,11 +1145,9 @@ class TestRequest: assert r.get_decoded_content() == "falafel" def test_get_content_type(self): - h = odict.ODictCaseless() - h["Content-Type"] = ["text/plain"] resp = HTTPResponse.wrap(netlib.tutils.tresp()) - resp.headers = h - assert resp.headers.get_first("content-type") == "text/plain" + resp.headers = Headers(content_type="text/plain") + assert resp.headers["content-type"] == "text/plain" class TestResponse: @@ -1165,19 +1160,18 @@ class TestResponse: def test_refresh(self): r = HTTPResponse.wrap(netlib.tutils.tresp()) n = time.time() - r.headers["date"] = [email.utils.formatdate(n)] + r.headers["date"] = email.utils.formatdate(n) pre = r.headers["date"] r.refresh(n) assert pre == r.headers["date"] r.refresh(n + 60) - d = email.utils.parsedate_tz(r.headers["date"][0]) + d = email.utils.parsedate_tz(r.headers["date"]) d = email.utils.mktime_tz(d) # Weird that this is not exact... assert abs(60 - (d - n)) <= 1 - r.headers[ - "set-cookie"] = ["MOO=BAR; Expires=Tue, 08-Mar-2011 00:20:38 GMT; Path=foo.com; Secure"] + r.headers["set-cookie"] = "MOO=BAR; Expires=Tue, 08-Mar-2011 00:20:38 GMT; Path=foo.com; Secure" r.refresh() def test_refresh_cookie(self): @@ -1192,47 +1186,45 @@ class TestResponse: def test_replace(self): r = HTTPResponse.wrap(netlib.tutils.tresp()) - r.headers["Foo"] = ["fOo"] + r.headers["Foo"] = "fOo" r.content = "afoob" assert r.replace("foo(?i)", "boo") == 3 assert not "foo" in r.content - assert r.headers["boo"] == ["boo"] + assert r.headers["boo"] == "boo" def test_decodeencode(self): r = HTTPResponse.wrap(netlib.tutils.tresp()) - r.headers["content-encoding"] = ["identity"] + r.headers["content-encoding"] = "identity" r.content = "falafel" assert r.decode() - assert not r.headers["content-encoding"] + assert "content-encoding" not in r.headers assert r.content == "falafel" r = HTTPResponse.wrap(netlib.tutils.tresp()) - r.headers["content-encoding"] = ["identity"] + r.headers["content-encoding"] = "identity" r.content = "falafel" r.encode("identity") - assert r.headers["content-encoding"] == ["identity"] + assert r.headers["content-encoding"] == "identity" assert r.content == "falafel" r = HTTPResponse.wrap(netlib.tutils.tresp()) - r.headers["content-encoding"] = ["identity"] + r.headers["content-encoding"] = "identity" r.content = "falafel" r.encode("gzip") - assert r.headers["content-encoding"] == ["gzip"] + assert r.headers["content-encoding"] == "gzip" assert r.content != "falafel" assert r.decode() - assert not r.headers["content-encoding"] + assert "content-encoding" not in r.headers assert r.content == "falafel" - r.headers["content-encoding"] = ["gzip"] + r.headers["content-encoding"] = "gzip" assert not r.decode() assert r.content == "falafel" def test_get_content_type(self): - h = odict.ODictCaseless() - h["Content-Type"] = ["text/plain"] resp = HTTPResponse.wrap(netlib.tutils.tresp()) - resp.headers = h - assert resp.headers.get_first("content-type") == "text/plain" + resp.headers = Headers(content_type="text/plain") + assert resp.headers["content-type"] == "text/plain" class TestError: @@ -1276,12 +1268,12 @@ class TestClientConnection: def test_decoded(): r = HTTPRequest.wrap(netlib.tutils.treq()) assert r.content == "content" - assert not r.headers["content-encoding"] + assert "content-encoding" not in r.headers r.encode("gzip") assert r.headers["content-encoding"] assert r.content != "content" with decoded(r): - assert not r.headers["content-encoding"] + assert "content-encoding" not in r.headers assert r.content == "content" assert r.headers["content-encoding"] assert r.content != "content" @@ -1378,18 +1370,18 @@ def test_setheaders(): h.add("~s", "one", "two") h.add("~s", "one", "three") f = tutils.tflow(resp=True) - f.request.headers["one"] = ["xxx"] - f.response.headers["one"] = ["xxx"] + f.request.headers["one"] = "xxx" + f.response.headers["one"] = "xxx" h.run(f) - assert f.request.headers["one"] == ["xxx"] - assert f.response.headers["one"] == ["two", "three"] + assert f.request.headers["one"] == "xxx" + assert f.response.headers.get_all("one") == ["two", "three"] h.clear() h.add("~q", "one", "two") h.add("~q", "one", "three") f = tutils.tflow() - f.request.headers["one"] = ["xxx"] + f.request.headers["one"] = "xxx" h.run(f) - assert f.request.headers["one"] == ["two", "three"] + assert f.request.headers.get_all("one") == ["two", "three"] assert not h.add("~", "foo", "bar") diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index cd0f77fa7..f53d43cfa 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -5,7 +5,6 @@ from mock import MagicMock from libmproxy.protocol.http import * import netlib.http -from netlib import odict from netlib.http import http1 from netlib.http.semantics import CONTENT_MISSING diff --git a/test/test_server.py b/test/test_server.py index a1259b7fc..492587916 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -623,8 +623,7 @@ class MasterRedirectRequest(tservers.TestMaster): def handle_response(self, f): f.response.content = str(f.client_conn.address.port) - f.response.headers[ - "server-conn-id"] = [str(f.server_conn.source_address.port)] + f.response.headers["server-conn-id"] = str(f.server_conn.source_address.port) super(MasterRedirectRequest, self).handle_response(f) @@ -712,7 +711,7 @@ class TestStreamRequest(tservers.HTTPProxTest): connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect(("127.0.0.1", self.proxy.port)) fconn = connection.makefile() - spec = '200:h"Transfer-Encoding"="chunked":r:b"4\\r\\nthis\\r\\n7\\r\\nisatest\\r\\n0\\r\\n\\r\\n"' + spec = '200:h"Transfer-Encoding"="chunked":r:b"4\\r\\nthis\\r\\n11\\r\\nisatest__reachhex\\r\\n0\\r\\n\\r\\n"' connection.send( "GET %s/p/%s HTTP/1.1\r\n" % (self.server.urlbase, spec)) @@ -721,13 +720,13 @@ class TestStreamRequest(tservers.HTTPProxTest): protocol = http.http1.HTTP1Protocol(rfile=fconn) resp = protocol.read_response("GET", None, include_body=False) - assert resp.headers["Transfer-Encoding"][0] == 'chunked' + assert resp.headers["Transfer-Encoding"] == 'chunked' assert resp.status_code == 200 chunks = list(protocol.read_http_body_chunked( resp.headers, None, "GET", 200, False )) - assert chunks == ["this", "isatest", ""] + assert chunks == ["this", "isatest__reachhex"] connection.close() @@ -743,7 +742,7 @@ class TestFakeResponse(tservers.HTTPProxTest): def test_fake(self): f = self.pathod("200") - assert "header_response" in f.headers.keys() + assert "header_response" in f.headers class TestServerConnect(tservers.HTTPProxTest): diff --git a/test/test_utils.py b/test/test_utils.py index 0cda23b40..d2bd97e14 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,6 +1,5 @@ import json from libmproxy import utils -from netlib import odict import tutils utils.CERT_SLEEP_TIME = 0