Merge branch 'master' into on-issues

This commit is contained in:
Nikhil Soni 2017-03-03 12:58:44 +05:30 committed by GitHub
commit 0081d9b828
56 changed files with 1674 additions and 680 deletions

View File

@ -42,7 +42,7 @@ iOS
See http://jasdev.me/intercepting-ios-traffic See http://jasdev.me/intercepting-ios-traffic
and http://web.archive.org/web/20150920082614/http://kb.mit.edu/confluence/pages/viewpage.action?pageId=152600377 and https://web.archive.org/web/20150920082614/http://kb.mit.edu/confluence/pages/viewpage.action?pageId=152600377
iOS Simulator iOS Simulator
^^^^^^^^^^^^^ ^^^^^^^^^^^^^
@ -52,7 +52,7 @@ See https://github.com/ADVTOOLS/ADVTrustStore#how-to-use-advtruststore
Java Java
^^^^ ^^^^
See http://docs.oracle.com/cd/E19906-01/820-4916/geygn/index.html See https://docs.oracle.com/cd/E19906-01/820-4916/geygn/index.html
Android/Android Simulator Android/Android Simulator
^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^
@ -62,7 +62,7 @@ See http://wiki.cacert.org/FAQ/ImportRootCert#Android_Phones_.26_Tablets
Windows Windows
^^^^^^^ ^^^^^^^
See http://windows.microsoft.com/en-ca/windows/import-export-certificates-private-keys#1TC=windows-7 See https://web.archive.org/web/20160612045445/http://windows.microsoft.com/en-ca/windows/import-export-certificates-private-keys#1TC=windows-7
Windows (automated) Windows (automated)
^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^
@ -79,7 +79,7 @@ See https://support.apple.com/kb/PH7297?locale=en_US
Ubuntu/Debian Ubuntu/Debian
^^^^^^^^^^^^^ ^^^^^^^^^^^^^
See http://askubuntu.com/questions/73287/how-do-i-install-a-root-certificate/94861#94861 See https://askubuntu.com/questions/73287/how-do-i-install-a-root-certificate/94861#94861
Mozilla Firefox Mozilla Firefox
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^
@ -89,7 +89,7 @@ See https://wiki.mozilla.org/MozillaRootCertificate#Mozilla_Firefox
Chrome on Linux Chrome on Linux
^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^
See https://code.google.com/p/chromium/wiki/LinuxCertManagement See https://stackoverflow.com/a/15076602/198996
The mitmproxy certificate authority The mitmproxy certificate authority
@ -205,4 +205,4 @@ directory and uses this as the client cert.
.. _Certificate Pinning: http://security.stackexchange.com/questions/29988/what-is-certificate-pinning/ .. _Certificate Pinning: https://security.stackexchange.com/questions/29988/what-is-certificate-pinning/

View File

@ -10,7 +10,7 @@ import zlib
import os import os
from datetime import datetime from datetime import datetime
import pytz from datetime import timezone
import mitmproxy import mitmproxy
@ -89,7 +89,7 @@ def response(flow):
# Timings set to -1 will be ignored as per spec. # Timings set to -1 will be ignored as per spec.
full_time = sum(v for v in timings.values() if v > -1) full_time = sum(v for v in timings.values() if v > -1)
started_date_time = format_datetime(datetime.utcfromtimestamp(flow.request.timestamp_start)) started_date_time = datetime.fromtimestamp(flow.request.timestamp_start, timezone.utc).isoformat()
# Response body size and encoding # Response body size and encoding
response_body_size = len(flow.response.raw_content) response_body_size = len(flow.response.raw_content)
@ -173,10 +173,6 @@ def done():
mitmproxy.ctx.log("HAR dump finished (wrote %s bytes to file)" % len(json_dump)) mitmproxy.ctx.log("HAR dump finished (wrote %s bytes to file)" % len(json_dump))
def format_datetime(dt):
return dt.replace(tzinfo=pytz.timezone("UTC")).isoformat()
def format_cookies(cookie_list): def format_cookies(cookie_list):
rv = [] rv = []
@ -198,7 +194,7 @@ def format_cookies(cookie_list):
# Expiration time needs to be formatted # Expiration time needs to be formatted
expire_ts = cookies.get_expiration_ts(attrs) expire_ts = cookies.get_expiration_ts(attrs)
if expire_ts is not None: if expire_ts is not None:
cookie_har["expires"] = format_datetime(datetime.fromtimestamp(expire_ts)) cookie_har["expires"] = datetime.fromtimestamp(expire_ts, timezone.utc).isoformat()
rv.append(cookie_har) rv.append(cookie_har)

407
examples/complex/xss_scanner.py Executable file
View File

@ -0,0 +1,407 @@
"""
__ __ _____ _____ _____
\ \ / // ____/ ____| / ____|
\ V /| (___| (___ | (___ ___ __ _ _ __ _ __ ___ _ __
> < \___ \\___ \ \___ \ / __/ _` | '_ \| '_ \ / _ \ '__|
/ . \ ____) |___) | ____) | (_| (_| | | | | | | | __/ |
/_/ \_\_____/_____/ |_____/ \___\__,_|_| |_|_| |_|\___|_|
This script automatically scans all visited webpages for XSS and SQLi vulnerabilities.
Usage: mitmproxy -s xss_scanner.py
This script scans for vulnerabilities by injecting a fuzzing payload (see PAYLOAD below) into 4 different places
and examining the HTML to look for XSS and SQLi injection vulnerabilities. The XSS scanning functionality works by
looking to see whether it is possible to inject HTML based off of of where the payload appears in the page and what
characters are escaped. In addition, it also looks for any script tags that load javascript from unclaimed domains.
The SQLi scanning functionality works by using regular expressions to look for errors from a number of different
common databases. Since it is only looking for errors, it will not find blind SQLi vulnerabilities.
The 4 places it injects the payload into are:
1. URLs (e.g. https://example.com/ -> https://example.com/PAYLOAD/)
2. Queries (e.g. https://example.com/index.html?a=b -> https://example.com/index.html?a=PAYLOAD)
3. Referers (e.g. The referer changes from https://example.com to PAYLOAD)
4. User Agents (e.g. The UA changes from Chrome to PAYLOAD)
Reports from this script show up in the event log (viewable by pressing e) and formatted like:
===== XSS Found ====
XSS URL: http://daviddworken.com/vulnerableUA.php
Injection Point: User Agent
Suggested Exploit: <script>alert(0)</script>
Line: 1029zxcs'd"ao<ac>so[sb]po(pc)se;sl/bsl\eq=3847asd
"""
from mitmproxy import ctx
from socket import gaierror, gethostbyname
from urllib.parse import urlparse
import requests
import re
from html.parser import HTMLParser
from mitmproxy import http
from typing import Dict, Union, Tuple, Optional, List, NamedTuple
# The actual payload is put between a frontWall and a backWall to make it easy
# to locate the payload with regular expressions
FRONT_WALL = b"1029zxc"
BACK_WALL = b"3847asd"
PAYLOAD = b"""s'd"ao<ac>so[sb]po(pc)se;sl/bsl\\eq="""
FULL_PAYLOAD = FRONT_WALL + PAYLOAD + BACK_WALL
# A XSSData is a named tuple with the following fields:
# - url -> str
# - injection_point -> str
# - exploit -> str
# - line -> str
XSSData = NamedTuple('XSSData', [('url', str),
('injection_point', str),
('exploit', str),
('line', str)])
# A SQLiData is named tuple with the following fields:
# - url -> str
# - injection_point -> str
# - regex -> str
# - dbms -> str
SQLiData = NamedTuple('SQLiData', [('url', str),
('injection_point', str),
('regex', str),
('dbms', str)])
VulnData = Tuple[Optional[XSSData], Optional[SQLiData]]
Cookies = Dict[str, str]
def get_cookies(flow: http.HTTPFlow) -> Cookies:
""" Return a dict going from cookie names to cookie values
- Note that it includes both the cookies sent in the original request and
the cookies sent by the server """
return {name: value for name, value in flow.request.cookies.fields}
def find_unclaimed_URLs(body: Union[str, bytes], requestUrl: bytes) -> None:
""" Look for unclaimed URLs in script tags and log them if found"""
class ScriptURLExtractor(HTMLParser):
script_URLs = []
def handle_starttag(self, tag, attrs):
if tag == "script" and "src" in [name for name, value in attrs]:
for name, value in attrs:
if name == "src":
self.script_URLs.append(value)
parser = ScriptURLExtractor()
try:
parser.feed(body)
except TypeError:
parser.feed(body.decode('utf-8'))
for url in parser.script_URLs:
parser = urlparse(url)
domain = parser.netloc
try:
gethostbyname(domain)
except gaierror:
ctx.log.error("XSS found in %s due to unclaimed URL \"%s\" in script tag." % (requestUrl, url))
def test_end_of_URL_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData:
""" Test the given URL for XSS via injection onto the end of the URL and
log the XSS if found """
parsed_URL = urlparse(request_URL)
path = parsed_URL.path
if path != "" and path[-1] != "/": # ensure the path ends in a /
path += "/"
path += FULL_PAYLOAD.decode('utf-8') # the path must be a string while the payload is bytes
url = parsed_URL._replace(path=path).geturl()
body = requests.get(url, cookies=cookies).text.lower()
xss_info = get_XSS_data(body, url, "End of URL")
sqli_info = get_SQLi_data(body, original_body, url, "End of URL")
return xss_info, sqli_info
def test_referer_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData:
""" Test the given URL for XSS via injection into the referer and
log the XSS if found """
body = requests.get(request_URL, headers={'referer': FULL_PAYLOAD}, cookies=cookies).text.lower()
xss_info = get_XSS_data(body, request_URL, "Referer")
sqli_info = get_SQLi_data(body, original_body, request_URL, "Referer")
return xss_info, sqli_info
def test_user_agent_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData:
""" Test the given URL for XSS via injection into the user agent and
log the XSS if found """
body = requests.get(request_URL, headers={'User-Agent': FULL_PAYLOAD}, cookies=cookies).text.lower()
xss_info = get_XSS_data(body, request_URL, "User Agent")
sqli_info = get_SQLi_data(body, original_body, request_URL, "User Agent")
return xss_info, sqli_info
def test_query_injection(original_body: str, request_URL: str, cookies: Cookies):
""" Test the given URL for XSS via injection into URL queries and
log the XSS if found """
parsed_URL = urlparse(request_URL)
query_string = parsed_URL.query
# queries is a list of parameters where each parameter is set to the payload
queries = [query.split("=")[0] + "=" + FULL_PAYLOAD.decode('utf-8') for query in query_string.split("&")]
new_query_string = "&".join(queries)
new_URL = parsed_URL._replace(query=new_query_string).geturl()
body = requests.get(new_URL, cookies=cookies).text.lower()
xss_info = get_XSS_data(body, new_URL, "Query")
sqli_info = get_SQLi_data(body, original_body, new_URL, "Query")
return xss_info, sqli_info
def log_XSS_data(xss_info: Optional[XSSData]) -> None:
""" Log information about the given XSS to mitmproxy """
# If it is None, then there is no info to log
if not xss_info:
return
ctx.log.error("===== XSS Found ====")
ctx.log.error("XSS URL: %s" % xss_info.url)
ctx.log.error("Injection Point: %s" % xss_info.injection_point)
ctx.log.error("Suggested Exploit: %s" % xss_info.exploit)
ctx.log.error("Line: %s" % xss_info.line)
def log_SQLi_data(sqli_info: Optional[SQLiData]) -> None:
""" Log information about the given SQLi to mitmproxy """
if not sqli_info:
return
ctx.log.error("===== SQLi Found =====")
ctx.log.error("SQLi URL: %s" % sqli_info.url.decode('utf-8'))
ctx.log.error("Injection Point: %s" % sqli_info.injection_point.decode('utf-8'))
ctx.log.error("Regex used: %s" % sqli_info.regex.decode('utf-8'))
ctx.log.error("Suspected DBMS: %s" % sqli_info.dbms.decode('utf-8'))
def get_SQLi_data(new_body: str, original_body: str, request_URL: str, injection_point: str) -> Optional[SQLiData]:
""" Return a SQLiDict if there is a SQLi otherwise return None
String String URL String -> (SQLiDict or None) """
# Regexes taken from Damn Small SQLi Scanner: https://github.com/stamparm/DSSS/blob/master/dsss.py#L17
DBMS_ERRORS = {
"MySQL": (r"SQL syntax.*MySQL", r"Warning.*mysql_.*", r"valid MySQL result", r"MySqlClient\."),
"PostgreSQL": (r"PostgreSQL.*ERROR", r"Warning.*\Wpg_.*", r"valid PostgreSQL result", r"Npgsql\."),
"Microsoft SQL Server": (r"Driver.* SQL[\-\_\ ]*Server", r"OLE DB.* SQL Server", r"(\W|\A)SQL Server.*Driver",
r"Warning.*mssql_.*", r"(\W|\A)SQL Server.*[0-9a-fA-F]{8}",
r"(?s)Exception.*\WSystem\.Data\.SqlClient\.", r"(?s)Exception.*\WRoadhouse\.Cms\."),
"Microsoft Access": (r"Microsoft Access Driver", r"JET Database Engine", r"Access Database Engine"),
"Oracle": (r"\bORA-[0-9][0-9][0-9][0-9]", r"Oracle error", r"Oracle.*Driver", r"Warning.*\Woci_.*", r"Warning.*\Wora_.*"),
"IBM DB2": (r"CLI Driver.*DB2", r"DB2 SQL error", r"\bdb2_\w+\("),
"SQLite": (r"SQLite/JDBCDriver", r"SQLite.Exception", r"System.Data.SQLite.SQLiteException", r"Warning.*sqlite_.*",
r"Warning.*SQLite3::", r"\[SQLITE_ERROR\]"),
"Sybase": (r"(?i)Warning.*sybase.*", r"Sybase message", r"Sybase.*Server message.*"),
}
for dbms, regexes in DBMS_ERRORS.items():
for regex in regexes:
if re.search(regex, new_body) and not re.search(regex, original_body):
return SQLiData(request_URL,
injection_point,
regex,
dbms)
# A qc is either ' or "
def inside_quote(qc: str, substring: bytes, text_index: int, body: bytes) -> bool:
""" Whether the Numberth occurence of the first string in the second
string is inside quotes as defined by the supplied QuoteChar """
substring = substring.decode('utf-8')
body = body.decode('utf-8')
num_substrings_found = 0
in_quote = False
for index, char in enumerate(body):
# Whether the next chunk of len(substring) chars is the substring
next_part_is_substring = (
(not (index + len(substring) > len(body))) and
(body[index:index + len(substring)] == substring)
)
# Whether this char is escaped with a \
is_not_escaped = (
(index - 1 < 0 or index - 1 > len(body)) or
(body[index - 1] != "\\")
)
if char == qc and is_not_escaped:
in_quote = not in_quote
if next_part_is_substring:
if num_substrings_found == text_index:
return in_quote
num_substrings_found += 1
return False
def paths_to_text(html: str, str: str) -> List[str]:
""" Return list of Paths to a given str in the given HTML tree
- Note that it does a BFS """
def remove_last_occurence_of_sub_string(str: str, substr: str):
""" Delete the last occurence of substr from str
String String -> String
"""
index = str.rfind(substr)
return str[:index] + str[index + len(substr):]
class PathHTMLParser(HTMLParser):
currentPath = ""
paths = []
def handle_starttag(self, tag, attrs):
self.currentPath += ("/" + tag)
def handle_endtag(self, tag):
self.currentPath = remove_last_occurence_of_sub_string(self.currentPath, "/" + tag)
def handle_data(self, data):
if str in data:
self.paths.append(self.currentPath)
parser = PathHTMLParser()
parser.feed(html)
return parser.paths
def get_XSS_data(body: str, request_URL: str, injection_point: str) -> Optional[XSSData]:
""" Return a XSSDict if there is a XSS otherwise return None """
def in_script(text, index, body) -> bool:
""" Whether the Numberth occurence of the first string in the second
string is inside a script tag """
paths = paths_to_text(body.decode('utf-8'), text.decode("utf-8"))
try:
path = paths[index]
return "script" in path
except IndexError:
return False
def in_HTML(text: bytes, index: int, body: bytes) -> bool:
""" Whether the Numberth occurence of the first string in the second
string is inside the HTML but not inside a script tag or part of
a HTML attribute"""
# if there is a < then lxml will interpret that as a tag, so only search for the stuff before it
text = text.split(b"<")[0]
paths = paths_to_text(body.decode('utf-8'), text.decode("utf-8"))
try:
path = paths[index]
return "script" not in path
except IndexError:
return False
def inject_javascript_handler(html: str) -> bool:
""" Whether you can inject a Javascript:alert(0) as a link """
class injectJSHandlerHTMLParser(HTMLParser):
injectJSHandler = False
def handle_starttag(self, tag, attrs):
for name, value in attrs:
if name == "href" and value.startswith(FRONT_WALL.decode('utf-8')):
self.injectJSHandler = True
parser = injectJSHandlerHTMLParser()
parser.feed(html)
return parser.injectJSHandler
# Only convert the body to bytes if needed
if isinstance(body, str):
body = bytes(body, 'utf-8')
# Regex for between 24 and 72 (aka 24*3) characters encapsulated by the walls
regex = re.compile(b"""%s.{24,72}?%s""" % (FRONT_WALL, BACK_WALL))
matches = regex.findall(body)
for index, match in enumerate(matches):
# Where the string is injected into the HTML
in_script = in_script(match, index, body)
in_HTML = in_HTML(match, index, body)
in_tag = not in_script and not in_HTML
in_single_quotes = inside_quote("'", match, index, body)
in_double_quotes = inside_quote('"', match, index, body)
# Whether you can inject:
inject_open_angle = b"ao<ac" in match # open angle brackets
inject_close_angle = b"ac>so" in match # close angle brackets
inject_single_quotes = b"s'd" in match # single quotes
inject_double_quotes = b'd"ao' in match # double quotes
inject_slash = b"sl/bsl" in match # forward slashes
inject_semi = b"se;sl" in match # semicolons
inject_equals = b"eq=" in match # equals sign
if in_script and inject_slash and inject_open_angle and inject_close_angle: # e.g. <script>PAYLOAD</script>
return XSSData(request_URL,
injection_point,
'</script><script>alert(0)</script><script>',
match.decode('utf-8'))
elif in_script and in_single_quotes and inject_single_quotes and inject_semi: # e.g. <script>t='PAYLOAD';</script>
return XSSData(request_URL,
injection_point,
"';alert(0);g='",
match.decode('utf-8'))
elif in_script and in_double_quotes and inject_double_quotes and inject_semi: # e.g. <script>t="PAYLOAD";</script>
return XSSData(request_URL,
injection_point,
'";alert(0);g="',
match.decode('utf-8'))
elif in_tag and in_single_quotes and inject_single_quotes and inject_open_angle and inject_close_angle and inject_slash:
# e.g. <a href='PAYLOAD'>Test</a>
return XSSData(request_URL,
injection_point,
"'><script>alert(0)</script>",
match.decode('utf-8'))
elif in_tag and in_double_quotes and inject_double_quotes and inject_open_angle and inject_close_angle and inject_slash:
# e.g. <a href="PAYLOAD">Test</a>
return XSSData(request_URL,
injection_point,
'"><script>alert(0)</script>',
match.decode('utf-8'))
elif in_tag and not in_double_quotes and not in_single_quotes and inject_open_angle and inject_close_angle and inject_slash:
# e.g. <a href=PAYLOAD>Test</a>
return XSSData(request_URL,
injection_point,
'><script>alert(0)</script>',
match.decode('utf-8'))
elif inject_javascript_handler(body.decode('utf-8')): # e.g. <html><a href=PAYLOAD>Test</a>
return XSSData(request_URL,
injection_point,
'Javascript:alert(0)',
match.decode('utf-8'))
elif in_tag and in_double_quotes and inject_double_quotes and inject_equals: # e.g. <a href="PAYLOAD">Test</a>
return XSSData(request_URL,
injection_point,
'" onmouseover="alert(0)" t="',
match.decode('utf-8'))
elif in_tag and in_single_quotes and inject_single_quotes and inject_equals: # e.g. <a href='PAYLOAD'>Test</a>
return XSSData(request_URL,
injection_point,
"' onmouseover='alert(0)' t='",
match.decode('utf-8'))
elif in_tag and not in_single_quotes and not in_double_quotes and inject_equals: # e.g. <a href=PAYLOAD>Test</a>
return XSSData(request_URL,
injection_point,
" onmouseover=alert(0) t=",
match.decode('utf-8'))
elif in_HTML and not in_script and inject_open_angle and inject_close_angle and inject_slash: # e.g. <html>PAYLOAD</html>
return XSSData(request_URL,
injection_point,
'<script>alert(0)</script>',
match.decode('utf-8'))
else:
return None
# response is mitmproxy's entry point
def response(flow: http.HTTPFlow) -> None:
cookiesDict = get_cookies(flow)
# Example: http://xss.guru/unclaimedScriptTag.html
find_unclaimed_URLs(flow.response.content, flow.request.url)
results = test_end_of_URL_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict)
log_XSS_data(results[0])
log_SQLi_data(results[1])
# Example: https://daviddworken.com/vulnerableReferer.php
results = test_referer_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict)
log_XSS_data(results[0])
log_SQLi_data(results[1])
# Example: https://daviddworken.com/vulnerableUA.php
results = test_user_agent_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict)
log_XSS_data(results[0])
log_SQLi_data(results[1])
if "?" in flow.request.url:
# Example: https://daviddworken.com/vulnerable.php?name=
results = test_query_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict)
log_XSS_data(results[0])
log_SQLi_data(results[1])

View File

@ -93,9 +93,9 @@ def dummy_cert(privkey, cacert, commonname, sans):
try: try:
ipaddress.ip_address(i.decode("ascii")) ipaddress.ip_address(i.decode("ascii"))
except ValueError: except ValueError:
ss.append(b"DNS: %s" % i) ss.append(b"DNS:%s" % i)
else: else:
ss.append(b"IP: %s" % i) ss.append(b"IP:%s" % i)
ss = b", ".join(ss) ss = b", ".join(ss)
cert = OpenSSL.crypto.X509() cert = OpenSSL.crypto.X509()
@ -356,14 +356,14 @@ class CertStore:
class _GeneralName(univ.Choice): class _GeneralName(univ.Choice):
# We are only interested in dNSNames. We use a default handler to ignore # We only care about dNSName and iPAddress
# other types.
# TODO: We should also handle iPAddresses.
componentType = namedtype.NamedTypes( componentType = namedtype.NamedTypes(
namedtype.NamedType('dNSName', char.IA5String().subtype( namedtype.NamedType('dNSName', char.IA5String().subtype(
implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2)
) )),
), namedtype.NamedType('iPAddress', univ.OctetString().subtype(
implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 7)
)),
) )
@ -477,5 +477,10 @@ class SSLCert(serializable.Serializable):
except PyAsn1Error: except PyAsn1Error:
continue continue
for i in dec[0]: for i in dec[0]:
altnames.append(i[0].asOctets()) if i[0] is None and isinstance(i[1], univ.OctetString) and not isinstance(i[1], char.IA5String):
# This would give back the IP address: b'.'.join([str(e).encode() for e in i[1].asNumbers()])
continue
else:
e = i[0].asOctets()
altnames.append(e)
return altnames return altnames

View File

@ -54,24 +54,35 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
return bool(self.connection) and not self.finished return bool(self.connection) and not self.finished
def __repr__(self): def __repr__(self):
if self.ssl_established:
tls = "[{}] ".format(self.tls_version)
else:
tls = ""
if self.alpn_proto_negotiated: if self.alpn_proto_negotiated:
alpn = "[ALPN: {}] ".format( alpn = "[ALPN: {}] ".format(
strutils.bytes_to_escaped_str(self.alpn_proto_negotiated) strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)
) )
else: else:
alpn = "" alpn = ""
return "<ClientConnection: {ssl}{alpn}{address}>".format(
ssl="[ssl] " if self.ssl_established else "", return "<ClientConnection: {tls}{alpn}{host}:{port}>".format(
tls=tls,
alpn=alpn, alpn=alpn,
address=repr(self.address) host=self.address[0],
port=self.address[1],
) )
@property @property
def tls_established(self): def tls_established(self):
return self.ssl_established return self.ssl_established
@tls_established.setter
def tls_established(self, value):
self.ssl_established = value
_stateobject_attributes = dict( _stateobject_attributes = dict(
address=tcp.Address, address=tuple,
ssl_established=bool, ssl_established=bool,
clientcert=certs.SSLCert, clientcert=certs.SSLCert,
mitmcert=certs.SSLCert, mitmcert=certs.SSLCert,
@ -99,7 +110,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
@classmethod @classmethod
def make_dummy(cls, address): def make_dummy(cls, address):
return cls.from_state(dict( return cls.from_state(dict(
address=dict(address=address, use_ipv6=False), address=address,
clientcert=None, clientcert=None,
mitmcert=None, mitmcert=None,
ssl_established=False, ssl_established=False,
@ -143,6 +154,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
cert: The certificate presented by the remote during the TLS handshake cert: The certificate presented by the remote during the TLS handshake
sni: Server Name Indication sent by the proxy during the TLS handshake sni: Server Name Indication sent by the proxy during the TLS handshake
alpn_proto_negotiated: The negotiated application protocol alpn_proto_negotiated: The negotiated application protocol
tls_version: TLS version
via: The underlying server connection (e.g. the connection to the upstream proxy in upstream proxy mode) via: The underlying server connection (e.g. the connection to the upstream proxy in upstream proxy mode)
timestamp_start: Connection start timestamp timestamp_start: Connection start timestamp
timestamp_tcp_setup: TCP ACK received timestamp timestamp_tcp_setup: TCP ACK received timestamp
@ -154,6 +166,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
tcp.TCPClient.__init__(self, address, source_address, spoof_source_address) tcp.TCPClient.__init__(self, address, source_address, spoof_source_address)
self.alpn_proto_negotiated = None self.alpn_proto_negotiated = None
self.tls_version = None
self.via = None self.via = None
self.timestamp_start = None self.timestamp_start = None
self.timestamp_end = None self.timestamp_end = None
@ -165,35 +178,41 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
def __repr__(self): def __repr__(self):
if self.ssl_established and self.sni: if self.ssl_established and self.sni:
ssl = "[ssl: {0}] ".format(self.sni) tls = "[{}: {}] ".format(self.tls_version or "TLS", self.sni)
elif self.ssl_established: elif self.ssl_established:
ssl = "[ssl] " tls = "[{}] ".format(self.tls_version or "TLS")
else: else:
ssl = "" tls = ""
if self.alpn_proto_negotiated: if self.alpn_proto_negotiated:
alpn = "[ALPN: {}] ".format( alpn = "[ALPN: {}] ".format(
strutils.bytes_to_escaped_str(self.alpn_proto_negotiated) strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)
) )
else: else:
alpn = "" alpn = ""
return "<ServerConnection: {ssl}{alpn}{address}>".format( return "<ServerConnection: {tls}{alpn}{host}:{port}>".format(
ssl=ssl, tls=tls,
alpn=alpn, alpn=alpn,
address=repr(self.address) host=self.address[0],
port=self.address[1],
) )
@property @property
def tls_established(self): def tls_established(self):
return self.ssl_established return self.ssl_established
@tls_established.setter
def tls_established(self, value):
self.ssl_established = value
_stateobject_attributes = dict( _stateobject_attributes = dict(
address=tcp.Address, address=tuple,
ip_address=tcp.Address, ip_address=tuple,
source_address=tcp.Address, source_address=tuple,
ssl_established=bool, ssl_established=bool,
cert=certs.SSLCert, cert=certs.SSLCert,
sni=str, sni=str,
alpn_proto_negotiated=bytes, alpn_proto_negotiated=bytes,
tls_version=str,
timestamp_start=float, timestamp_start=float,
timestamp_tcp_setup=float, timestamp_tcp_setup=float,
timestamp_ssl_setup=float, timestamp_ssl_setup=float,
@ -209,12 +228,13 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
@classmethod @classmethod
def make_dummy(cls, address): def make_dummy(cls, address):
return cls.from_state(dict( return cls.from_state(dict(
address=dict(address=address, use_ipv6=False), address=address,
ip_address=dict(address=address, use_ipv6=False), ip_address=address,
cert=None, cert=None,
sni=None, sni=None,
alpn_proto_negotiated=None, alpn_proto_negotiated=None,
source_address=dict(address=('', 0), use_ipv6=False), tls_version=None,
source_address=('', 0),
ssl_established=False, ssl_established=False,
timestamp_start=None, timestamp_start=None,
timestamp_tcp_setup=None, timestamp_tcp_setup=None,
@ -244,13 +264,14 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
else: else:
path = os.path.join( path = os.path.join(
clientcerts, clientcerts,
self.address.host.encode("idna").decode()) + ".pem" self.address[0].encode("idna").decode()) + ".pem"
if os.path.exists(path): if os.path.exists(path):
clientcert = path clientcert = path
self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs) self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs)
self.sni = sni self.sni = sni
self.alpn_proto_negotiated = self.get_alpn_proto_negotiated() self.alpn_proto_negotiated = self.get_alpn_proto_negotiated()
self.tls_version = self.connection.get_protocol_version_name()
self.timestamp_ssl_setup = time.time() self.timestamp_ssl_setup = time.time()
def finish(self): def finish(self):

View File

@ -348,7 +348,10 @@ class FSrc(_Rex):
is_binary = False is_binary = False
def __call__(self, f): def __call__(self, f):
return f.client_conn.address and self.re.search(repr(f.client_conn.address)) if not f.client_conn or not f.client_conn.address:
return False
r = "{}:{}".format(f.client_conn.address[0], f.client_conn.address[1])
return f.client_conn.address and self.re.search(r)
class FDst(_Rex): class FDst(_Rex):
@ -357,7 +360,10 @@ class FDst(_Rex):
is_binary = False is_binary = False
def __call__(self, f): def __call__(self, f):
return f.server_conn.address and self.re.search(repr(f.server_conn.address)) if not f.server_conn or not f.server_conn.address:
return False
r = "{}:{}".format(f.server_conn.address[0], f.server_conn.address[1])
return f.server_conn.address and self.re.search(r)
class _Int(_Action): class _Int(_Action):
@ -425,6 +431,7 @@ filter_unary = [
FReq, FReq,
FResp, FResp,
FTCP, FTCP,
FWebSocket,
] ]
filter_rex = [ filter_rex = [
FBod, FBod,

View File

@ -5,7 +5,6 @@ from mitmproxy import flow
from mitmproxy.net import http from mitmproxy.net import http
from mitmproxy import version from mitmproxy import version
from mitmproxy.net import tcp
from mitmproxy import connections # noqa from mitmproxy import connections # noqa
@ -245,9 +244,8 @@ def make_error_response(
def make_connect_request(address): def make_connect_request(address):
address = tcp.Address.wrap(address)
return HTTPRequest( return HTTPRequest(
"authority", b"CONNECT", None, address.host, address.port, None, b"HTTP/1.1", "authority", b"CONNECT", None, address[0], address[1], None, b"HTTP/1.1",
http.Headers(), b"" http.Headers(), b""
) )

View File

@ -88,12 +88,20 @@ def convert_019_100(data):
def convert_100_200(data): def convert_100_200(data):
data["version"] = (2, 0, 0) data["version"] = (2, 0, 0)
data["client_conn"]["address"] = data["client_conn"]["address"]["address"]
data["server_conn"]["address"] = data["server_conn"]["address"]["address"]
data["server_conn"]["source_address"] = data["server_conn"]["source_address"]["address"]
if data["server_conn"]["ip_address"]:
data["server_conn"]["ip_address"] = data["server_conn"]["ip_address"]["address"]
return data return data
def convert_200_300(data): def convert_200_300(data):
data["version"] = (3, 0, 0) data["version"] = (3, 0, 0)
data["client_conn"]["mitmcert"] = None data["client_conn"]["mitmcert"] = None
data["server_conn"]["tls_version"] = None
if data["server_conn"]["via"]:
data["server_conn"]["via"]["tls_version"] = None
return data return data

View File

@ -149,8 +149,8 @@ class Master:
""" """
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
if self.server and self.options.mode == "reverse": if self.server and self.options.mode == "reverse":
f.request.host = self.server.config.upstream_server.address.host f.request.host = self.server.config.upstream_server.address[0]
f.request.port = self.server.config.upstream_server.address.port f.request.port = self.server.config.upstream_server.address[1]
f.request.scheme = self.server.config.upstream_server.scheme f.request.scheme = self.server.config.upstream_server.scheme
f.reply = controller.DummyReply() f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f): for e, o in eventsequence.iterate(f):

View File

@ -1,3 +1,4 @@
import ipaddress
import re import re
# Allow underscore in host name # Allow underscore in host name
@ -6,17 +7,26 @@ _label_valid = re.compile(b"(?!-)[A-Z\d\-_]{1,63}(?<!-)$", re.IGNORECASE)
def is_valid_host(host: bytes) -> bool: def is_valid_host(host: bytes) -> bool:
""" """
Checks if a hostname is valid. Checks if the passed bytes are a valid DNS hostname or an IPv4/IPv6 address.
""" """
try: try:
host.decode("idna") host.decode("idna")
except ValueError: except ValueError:
return False return False
# RFC1035: 255 bytes or less.
if len(host) > 255: if len(host) > 255:
return False return False
if host and host[-1:] == b".": if host and host[-1:] == b".":
host = host[:-1] host = host[:-1]
return all(_label_valid.match(x) for x in host.split(b".")) # DNS hostname
if all(_label_valid.match(x) for x in host.split(b".")):
return True
# IPv4/IPv6 address
try:
ipaddress.ip_address(host.decode('idna'))
return True
except ValueError:
return False
def is_valid_port(port): def is_valid_port(port):

View File

@ -2,7 +2,6 @@ import struct
import array import array
import ipaddress import ipaddress
from mitmproxy.net import tcp
from mitmproxy.net import check from mitmproxy.net import check
from mitmproxy.types import bidi from mitmproxy.types import bidi
@ -179,7 +178,7 @@ class Message:
self.ver = ver self.ver = ver
self.msg = msg self.msg = msg
self.atyp = atyp self.atyp = atyp
self.addr = tcp.Address.wrap(addr) self.addr = addr
def assert_socks5(self): def assert_socks5(self):
if self.ver != VERSION.SOCKS5: if self.ver != VERSION.SOCKS5:
@ -199,37 +198,34 @@ class Message:
if atyp == ATYP.IPV4_ADDRESS: if atyp == ATYP.IPV4_ADDRESS:
# We use tnoa here as ntop is not commonly available on Windows. # We use tnoa here as ntop is not commonly available on Windows.
host = ipaddress.IPv4Address(f.safe_read(4)).compressed host = ipaddress.IPv4Address(f.safe_read(4)).compressed
use_ipv6 = False
elif atyp == ATYP.IPV6_ADDRESS: elif atyp == ATYP.IPV6_ADDRESS:
host = ipaddress.IPv6Address(f.safe_read(16)).compressed host = ipaddress.IPv6Address(f.safe_read(16)).compressed
use_ipv6 = True
elif atyp == ATYP.DOMAINNAME: elif atyp == ATYP.DOMAINNAME:
length, = struct.unpack("!B", f.safe_read(1)) length, = struct.unpack("!B", f.safe_read(1))
host = f.safe_read(length) host = f.safe_read(length)
if not check.is_valid_host(host): if not check.is_valid_host(host):
raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host) raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host)
host = host.decode("idna") host = host.decode("idna")
use_ipv6 = False
else: else:
raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED,
"Socks Request: Unknown ATYP: %s" % atyp) "Socks Request: Unknown ATYP: %s" % atyp)
port, = struct.unpack("!H", f.safe_read(2)) port, = struct.unpack("!H", f.safe_read(2))
addr = tcp.Address((host, port), use_ipv6=use_ipv6) addr = (host, port)
return cls(ver, msg, atyp, addr) return cls(ver, msg, atyp, addr)
def to_file(self, f): def to_file(self, f):
f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp))
if self.atyp == ATYP.IPV4_ADDRESS: if self.atyp == ATYP.IPV4_ADDRESS:
f.write(ipaddress.IPv4Address(self.addr.host).packed) f.write(ipaddress.IPv4Address(self.addr[0]).packed)
elif self.atyp == ATYP.IPV6_ADDRESS: elif self.atyp == ATYP.IPV6_ADDRESS:
f.write(ipaddress.IPv6Address(self.addr.host).packed) f.write(ipaddress.IPv6Address(self.addr[0]).packed)
elif self.atyp == ATYP.DOMAINNAME: elif self.atyp == ATYP.DOMAINNAME:
f.write(struct.pack("!B", len(self.addr.host))) f.write(struct.pack("!B", len(self.addr[0])))
f.write(self.addr.host.encode("idna")) f.write(self.addr[0].encode("idna"))
else: else:
raise SocksError( raise SocksError(
REP.ADDRESS_TYPE_NOT_SUPPORTED, REP.ADDRESS_TYPE_NOT_SUPPORTED,
"Unknown ATYP: %s" % self.atyp "Unknown ATYP: %s" % self.atyp
) )
f.write(struct.pack("!H", self.addr.port)) f.write(struct.pack("!H", self.addr[1]))

View File

@ -19,7 +19,6 @@ from OpenSSL import SSL
from mitmproxy import certs from mitmproxy import certs
from mitmproxy.utils import version_check from mitmproxy.utils import version_check
from mitmproxy.types import serializable
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy.types import basethread from mitmproxy.types import basethread
@ -29,6 +28,10 @@ version_check.check_pyopenssl_version()
socket_fileobject = socket.SocketIO socket_fileobject = socket.SocketIO
# workaround for https://bugs.python.org/issue29515
# Python 3.5 and 3.6 for Windows is missing a constant
IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41)
EINTR = 4 EINTR = 4
HAS_ALPN = SSL._lib.Cryptography_HAS_ALPN HAS_ALPN = SSL._lib.Cryptography_HAS_ALPN
@ -299,73 +302,6 @@ class Reader(_FileLike):
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
class Address(serializable.Serializable):
"""
This class wraps an IPv4/IPv6 tuple to provide named attributes and
ipv6 information.
"""
def __init__(self, address, use_ipv6=False):
self.address = tuple(address)
self.use_ipv6 = use_ipv6
def get_state(self):
return {
"address": self.address,
"use_ipv6": self.use_ipv6
}
def set_state(self, state):
self.address = state["address"]
self.use_ipv6 = state["use_ipv6"]
@classmethod
def from_state(cls, state):
return Address(**state)
@classmethod
def wrap(cls, t):
if isinstance(t, cls):
return t
else:
return cls(t)
def __call__(self):
return self.address
@property
def host(self):
return self.address[0]
@property
def port(self):
return self.address[1]
@property
def use_ipv6(self):
return self.family == socket.AF_INET6
@use_ipv6.setter
def use_ipv6(self, b):
self.family = socket.AF_INET6 if b else socket.AF_INET
def __repr__(self):
return "{}:{}".format(self.host, self.port)
def __eq__(self, other):
if not other:
return False
other = Address.wrap(other)
return (self.address, self.family) == (other.address, other.family)
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash(self.address) ^ 42 # different hash than the tuple alone.
def ssl_read_select(rlist, timeout): def ssl_read_select(rlist, timeout):
""" """
This is a wrapper around select.select() which also works for SSL.Connections This is a wrapper around select.select() which also works for SSL.Connections
@ -452,7 +388,7 @@ class _Connection:
def __init__(self, connection): def __init__(self, connection):
if connection: if connection:
self.connection = connection self.connection = connection
self.ip_address = Address(connection.getpeername()) self.ip_address = connection.getpeername()
self._makefile() self._makefile()
else: else:
self.connection = None self.connection = None
@ -629,28 +565,6 @@ class TCPClient(_Connection):
self.sni = None self.sni = None
self.spoof_source_address = spoof_source_address self.spoof_source_address = spoof_source_address
@property
def address(self):
return self.__address
@address.setter
def address(self, address):
if address:
self.__address = Address.wrap(address)
else:
self.__address = None
@property
def source_address(self):
return self.__source_address
@source_address.setter
def source_address(self, source_address):
if source_address:
self.__source_address = Address.wrap(source_address)
else:
self.__source_address = None
def close(self): def close(self):
# Make sure to close the real socket, not the SSL proxy. # Make sure to close the real socket, not the SSL proxy.
# OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
@ -741,34 +655,57 @@ class TCPClient(_Connection):
self.rfile.set_descriptor(self.connection) self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection)
def makesocket(self): def makesocket(self, family, type, proto):
# some parties (cuckoo sandbox) need to hook this # some parties (cuckoo sandbox) need to hook this
return socket.socket(self.address.family, socket.SOCK_STREAM) return socket.socket(family, type, proto)
def create_connection(self, timeout=None):
# Based on the official socket.create_connection implementation of Python 3.6.
# https://github.com/python/cpython/blob/3cc5817cfaf5663645f4ee447eaed603d2ad290a/Lib/socket.py
err = None
for res in socket.getaddrinfo(self.address[0], self.address[1], 0, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
sock = None
try:
sock = self.makesocket(af, socktype, proto)
if timeout:
sock.settimeout(timeout)
if self.source_address:
sock.bind(self.source_address)
if self.spoof_source_address:
try:
if not sock.getsockopt(socket.SOL_IP, socket.IP_TRANSPARENT):
sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1)
except Exception as e:
# socket.IP_TRANSPARENT might not be available on every OS and Python version
raise exceptions.TcpException(
"Failed to spoof the source address: " + e.strerror
)
sock.connect(sa)
return sock
except socket.error as _:
err = _
if sock is not None:
sock.close()
if err is not None:
raise err
else:
raise socket.error("getaddrinfo returns an empty list")
def connect(self): def connect(self):
try: try:
connection = self.makesocket() connection = self.create_connection()
if self.spoof_source_address:
try:
# 19 is `IP_TRANSPARENT`, which is only available on Python 3.3+ on some OSes
if not connection.getsockopt(socket.SOL_IP, 19):
connection.setsockopt(socket.SOL_IP, 19, 1)
except socket.error as e:
raise exceptions.TcpException(
"Failed to spoof the source address: " + e.strerror
)
if self.source_address:
connection.bind(self.source_address())
connection.connect(self.address())
self.source_address = Address(connection.getsockname())
except (socket.error, IOError) as err: except (socket.error, IOError) as err:
raise exceptions.TcpException( raise exceptions.TcpException(
'Error connecting to "%s": %s' % 'Error connecting to "%s": %s' %
(self.address.host, err) (self.address[0], err)
) )
self.connection = connection self.connection = connection
self.ip_address = Address(connection.getpeername()) self.source_address = connection.getsockname()
self.ip_address = connection.getpeername()
self._makefile() self._makefile()
return ConnectionCloser(self) return ConnectionCloser(self)
@ -793,7 +730,7 @@ class BaseHandler(_Connection):
def __init__(self, connection, address, server): def __init__(self, connection, address, server):
super().__init__(connection) super().__init__(connection)
self.address = Address.wrap(address) self.address = address
self.server = server self.server = server
self.clientcert = None self.clientcert = None
@ -915,19 +852,36 @@ class TCPServer:
request_queue_size = 20 request_queue_size = 20
def __init__(self, address): def __init__(self, address):
self.address = Address.wrap(address) self.address = address
self.__is_shut_down = threading.Event() self.__is_shut_down = threading.Event()
self.__shutdown_request = False self.__shutdown_request = False
self.socket = socket.socket(self.address.family, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if self.address == 'localhost':
self.socket.bind(self.address()) raise socket.error("Binding to 'localhost' is prohibited. Please use '::1' or '127.0.0.1' directly.")
self.address = Address.wrap(self.socket.getsockname())
try:
# First try to bind an IPv6 socket, with possible IPv4 if the OS supports it.
# This allows us to accept connections for ::1 and 127.0.0.1 on the same socket.
# Only works if self.address == ""
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
self.socket.bind(self.address)
except socket.error:
self.socket = None
if not self.socket:
# Binding to an IPv6 socket failed, lets fall back to IPv4.
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(self.address)
self.address = self.socket.getsockname()
self.socket.listen(self.request_queue_size) self.socket.listen(self.request_queue_size)
self.handler_counter = Counter() self.handler_counter = Counter()
def connection_thread(self, connection, client_address): def connection_thread(self, connection, client_address):
with self.handler_counter: with self.handler_counter:
client_address = Address(client_address)
try: try:
self.handle_client_connection(connection, client_address) self.handle_client_connection(connection, client_address)
except: except:
@ -954,8 +908,8 @@ class TCPServer:
self.__class__.__name__, self.__class__.__name__,
client_address[0], client_address[0],
client_address[1], client_address[1],
self.address.host, self.address[0],
self.address.port self.address[1],
), ),
target=self.connection_thread, target=self.connection_thread,
args=(connection, client_address), args=(connection, client_address),
@ -964,7 +918,7 @@ class TCPServer:
try: try:
t.start() t.start()
except threading.ThreadError: except threading.ThreadError:
self.handle_error(connection, Address(client_address)) self.handle_error(connection, client_address)
connection.close() connection.close()
finally: finally:
self.__shutdown_request = False self.__shutdown_request = False

View File

@ -4,14 +4,13 @@ import urllib
import io import io
from mitmproxy.net import http from mitmproxy.net import http
from mitmproxy.net import tcp
from mitmproxy.utils import strutils from mitmproxy.utils import strutils
class ClientConn: class ClientConn:
def __init__(self, address): def __init__(self, address):
self.address = tcp.Address.wrap(address) self.address = address
class Flow: class Flow:
@ -84,8 +83,8 @@ class WSGIAdaptor:
} }
environ.update(extra) environ.update(extra)
if flow.client_conn.address: if flow.client_conn.address:
environ["REMOTE_ADDR"] = strutils.always_str(flow.client_conn.address.host, "latin-1") environ["REMOTE_ADDR"] = strutils.always_str(flow.client_conn.address[0], "latin-1")
environ["REMOTE_PORT"] = flow.client_conn.address.port environ["REMOTE_PORT"] = flow.client_conn.address[1]
for key, value in flow.request.headers.items(): for key, value in flow.request.headers.items():
key = 'HTTP_' + strutils.always_str(key, "latin-1").upper().replace('-', '_') key = 'HTTP_' + strutils.always_str(key, "latin-1").upper().replace('-', '_')

View File

@ -23,8 +23,7 @@ class HostMatcher:
def __call__(self, address): def __call__(self, address):
if not address: if not address:
return False return False
address = tcp.Address.wrap(address) host = "%s:%s" % address
host = "%s:%s" % (address.host, address.port)
if any(rex.search(host) for rex in self.regexes): if any(rex.search(host) for rex in self.regexes):
return True return True
else: else:
@ -47,7 +46,7 @@ def parse_server_spec(spec):
"Invalid server specification: %s" % spec "Invalid server specification: %s" % spec
) )
host, port = p[1:3] host, port = p[1:3]
address = tcp.Address((host.decode("ascii"), port)) address = (host.decode("ascii"), port)
scheme = p[0].decode("ascii").lower() scheme = p[0].decode("ascii").lower()
return ServerSpec(scheme, address) return ServerSpec(scheme, address)

View File

@ -101,7 +101,7 @@ class ServerConnectionMixin:
self.server_conn = None self.server_conn = None
if self.config.options.spoof_source_address and self.config.options.upstream_bind_address == '': if self.config.options.spoof_source_address and self.config.options.upstream_bind_address == '':
self.server_conn = connections.ServerConnection( self.server_conn = connections.ServerConnection(
server_address, (self.ctx.client_conn.address.host, 0), True) server_address, (self.ctx.client_conn.address[0], 0), True)
else: else:
self.server_conn = connections.ServerConnection( self.server_conn = connections.ServerConnection(
server_address, (self.config.options.upstream_bind_address, 0), server_address, (self.config.options.upstream_bind_address, 0),
@ -118,8 +118,8 @@ class ServerConnectionMixin:
address = self.server_conn.address address = self.server_conn.address
if address: if address:
self_connect = ( self_connect = (
address.port == self.config.options.listen_port and address[1] == self.config.options.listen_port and
address.host in ("localhost", "127.0.0.1", "::1") address[0] in ("localhost", "127.0.0.1", "::1")
) )
if self_connect: if self_connect:
raise exceptions.ProtocolException( raise exceptions.ProtocolException(
@ -133,7 +133,7 @@ class ServerConnectionMixin:
""" """
if self.server_conn.connected(): if self.server_conn.connected():
self.disconnect() self.disconnect()
self.log("Set new server address: " + repr(address), "debug") self.log("Set new server address: {}:{}".format(address[0], address[1]), "debug")
self.server_conn.address = address self.server_conn.address = address
self.__check_self_connect() self.__check_self_connect()
@ -150,7 +150,7 @@ class ServerConnectionMixin:
self.server_conn = connections.ServerConnection( self.server_conn = connections.ServerConnection(
address, address,
(self.server_conn.source_address.host, 0), (self.server_conn.source_address[0], 0),
self.config.options.spoof_source_address self.config.options.spoof_source_address
) )

View File

@ -8,7 +8,6 @@ from mitmproxy import http
from mitmproxy import flow from mitmproxy import flow
from mitmproxy.proxy.protocol import base from mitmproxy.proxy.protocol import base
from mitmproxy.proxy.protocol.websocket import WebSocketLayer from mitmproxy.proxy.protocol.websocket import WebSocketLayer
from mitmproxy.net import tcp
from mitmproxy.net import websockets from mitmproxy.net import websockets
@ -59,7 +58,7 @@ class ConnectServerConnection:
""" """
def __init__(self, address, ctx): def __init__(self, address, ctx):
self.address = tcp.Address.wrap(address) self.address = address
self._ctx = ctx self._ctx = ctx
@property @property
@ -112,9 +111,8 @@ class UpstreamConnectLayer(base.Layer):
def set_server(self, address): def set_server(self, address):
if self.ctx.server_conn.connected(): if self.ctx.server_conn.connected():
self.ctx.disconnect() self.ctx.disconnect()
address = tcp.Address.wrap(address) self.connect_request.host = address[0]
self.connect_request.host = address.host self.connect_request.port = address[1]
self.connect_request.port = address.port
self.server_conn.address = address self.server_conn.address = address
@ -291,7 +289,7 @@ class HttpLayer(base.Layer):
# update host header in reverse proxy mode # update host header in reverse proxy mode
if self.config.options.mode == "reverse" and not self.config.options.keep_host_header: if self.config.options.mode == "reverse" and not self.config.options.keep_host_header:
f.request.host_header = self.config.upstream_server.address.host f.request.host_header = self.config.upstream_server.address[0]
# Determine .scheme, .host and .port attributes for inline scripts. For # Determine .scheme, .host and .port attributes for inline scripts. For
# absolute-form requests, they are directly given in the request. For # absolute-form requests, they are directly given in the request. For
@ -302,8 +300,8 @@ class HttpLayer(base.Layer):
# Setting request.host also updates the host header, which we want # Setting request.host also updates the host header, which we want
# to preserve # to preserve
host_header = f.request.host_header host_header = f.request.host_header
f.request.host = self.__initial_server_conn.address.host f.request.host = self.__initial_server_conn.address[0]
f.request.port = self.__initial_server_conn.address.port f.request.port = self.__initial_server_conn.address[1]
f.request.host_header = host_header # set again as .host overwrites this. f.request.host_header = host_header # set again as .host overwrites this.
f.request.scheme = "https" if self.__initial_server_tls else "http" f.request.scheme = "https" if self.__initial_server_tls else "http"
self.channel.ask("request", f) self.channel.ask("request", f)
@ -453,14 +451,14 @@ class HttpLayer(base.Layer):
self.set_server(address) self.set_server(address)
def establish_server_connection(self, host: str, port: int, scheme: str): def establish_server_connection(self, host: str, port: int, scheme: str):
address = tcp.Address((host, port))
tls = (scheme == "https") tls = (scheme == "https")
if self.mode is HTTPMode.regular or self.mode is HTTPMode.transparent: if self.mode is HTTPMode.regular or self.mode is HTTPMode.transparent:
# If there's an existing connection that doesn't match our expectations, kill it. # If there's an existing connection that doesn't match our expectations, kill it.
address = (host, port)
if address != self.server_conn.address or tls != self.server_tls: if address != self.server_conn.address or tls != self.server_tls:
self.set_server(address) self.set_server(address)
self.set_server_tls(tls, address.host) self.set_server_tls(tls, address[0])
# Establish connection is neccessary. # Establish connection is neccessary.
if not self.server_conn.connected(): if not self.server_conn.connected():
self.connect() self.connect()

View File

@ -97,7 +97,6 @@ class Http2Layer(base.Layer):
client_side=False, client_side=False,
header_encoding=False, header_encoding=False,
validate_outbound_headers=False, validate_outbound_headers=False,
normalize_outbound_headers=False,
validate_inbound_headers=False) validate_inbound_headers=False)
self.connections[self.client_conn] = SafeH2Connection(self.client_conn, config=config) self.connections[self.client_conn] = SafeH2Connection(self.client_conn, config=config)
@ -107,7 +106,6 @@ class Http2Layer(base.Layer):
client_side=True, client_side=True,
header_encoding=False, header_encoding=False,
validate_outbound_headers=False, validate_outbound_headers=False,
normalize_outbound_headers=False,
validate_inbound_headers=False) validate_inbound_headers=False)
self.connections[self.server_conn] = SafeH2Connection(self.server_conn, config=config) self.connections[self.server_conn] = SafeH2Connection(self.server_conn, config=config)
self.connections[self.server_conn].initiate_connection() self.connections[self.server_conn].initiate_connection()
@ -599,9 +597,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
def send_response_headers(self, response): def send_response_headers(self, response):
headers = response.headers.copy() headers = response.headers.copy()
headers.insert(0, ":status", str(response.status_code)) headers.insert(0, ":status", str(response.status_code))
for forbidden_header in h2.utilities.CONNECTION_HEADERS:
if forbidden_header in headers:
del headers[forbidden_header]
with self.connections[self.client_conn].lock: with self.connections[self.client_conn].lock:
self.connections[self.client_conn].safe_send_headers( self.connections[self.client_conn].safe_send_headers(
self.raise_zombie, self.raise_zombie,

View File

@ -545,8 +545,9 @@ class TlsLayer(base.Layer):
raise exceptions.InvalidServerCertificate(str(e)) raise exceptions.InvalidServerCertificate(str(e))
except exceptions.TlsException as e: except exceptions.TlsException as e:
raise exceptions.TlsProtocolException( raise exceptions.TlsProtocolException(
"Cannot establish TLS with {address} (sni: {sni}): {e}".format( "Cannot establish TLS with {host}:{port} (sni: {sni}): {e}".format(
address=repr(self.server_conn.address), host=self.server_conn.address[0],
port=self.server_conn.address[1],
sni=self.server_sni, sni=self.server_sni,
e=repr(e) e=repr(e)
) )
@ -567,7 +568,7 @@ class TlsLayer(base.Layer):
# However, we may just want to establish TLS so that we can send an error message to the client, # However, we may just want to establish TLS so that we can send an error message to the client,
# in which case the address can be None. # in which case the address can be None.
if self.server_conn.address: if self.server_conn.address:
host = self.server_conn.address.host.encode("idna") host = self.server_conn.address[0].encode("idna")
# Should we incorporate information from the server certificate? # Should we incorporate information from the server certificate?
use_upstream_cert = ( use_upstream_cert = (

View File

@ -68,7 +68,7 @@ class RootContext:
top_layer, top_layer,
client_tls, client_tls,
top_layer.server_tls, top_layer.server_tls,
top_layer.server_conn.address.host top_layer.server_conn.address[0]
) )
if isinstance(top_layer, protocol.ServerConnectionMixin) or isinstance(top_layer, protocol.UpstreamConnectLayer): if isinstance(top_layer, protocol.ServerConnectionMixin) or isinstance(top_layer, protocol.UpstreamConnectLayer):
return protocol.TlsLayer(top_layer, client_tls, client_tls) return protocol.TlsLayer(top_layer, client_tls, client_tls)
@ -104,7 +104,7 @@ class RootContext:
Send a log message to the master. Send a log message to the master.
""" """
full_msg = [ full_msg = [
"{}: {}".format(repr(self.client_conn.address), msg) "{}:{}: {}".format(self.client_conn.address[0], self.client_conn.address[1], msg)
] ]
for i in subs: for i in subs:
full_msg.append(" -> " + i) full_msg.append(" -> " + i)

View File

@ -1,4 +1,3 @@
import socket
import sys import sys
import traceback import traceback
@ -46,10 +45,10 @@ class ProxyServer(tcp.TCPServer):
) )
if config.options.mode == "transparent": if config.options.mode == "transparent":
platform.init_transparent_mode() platform.init_transparent_mode()
except socket.error as e: except Exception as e:
raise exceptions.ServerException( raise exceptions.ServerException(
'Error starting proxy server: ' + repr(e) 'Error starting proxy server: ' + repr(e)
) ) from e
self.channel = None self.channel = None
def set_channel(self, channel): def set_channel(self, channel):

View File

@ -1,3 +1,5 @@
import io
from mitmproxy.net import websockets from mitmproxy.net import websockets
from mitmproxy.test import tutils from mitmproxy.test import tutils
from mitmproxy import tcp from mitmproxy import tcp
@ -72,7 +74,8 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
if messages is True: if messages is True:
messages = [ messages = [
websocket.WebSocketMessage(websockets.OPCODE.BINARY, True, b"hello binary"), websocket.WebSocketMessage(websockets.OPCODE.BINARY, True, b"hello binary"),
websocket.WebSocketMessage(websockets.OPCODE.TEXT, False, "hello text".encode()), websocket.WebSocketMessage(websockets.OPCODE.TEXT, True, "hello text".encode()),
websocket.WebSocketMessage(websockets.OPCODE.TEXT, False, "it's me".encode()),
] ]
if err is True: if err is True:
err = terr() err = terr()
@ -142,7 +145,7 @@ def tclient_conn():
@return: mitmproxy.proxy.connection.ClientConnection @return: mitmproxy.proxy.connection.ClientConnection
""" """
c = connections.ClientConnection.from_state(dict( c = connections.ClientConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True), address=("address", 22),
clientcert=None, clientcert=None,
mitmcert=None, mitmcert=None,
ssl_established=False, ssl_established=False,
@ -155,6 +158,8 @@ def tclient_conn():
tls_version="TLSv1.2", tls_version="TLSv1.2",
)) ))
c.reply = controller.DummyReply() c.reply = controller.DummyReply()
c.rfile = io.BytesIO()
c.wfile = io.BytesIO()
return c return c
@ -163,8 +168,8 @@ def tserver_conn():
@return: mitmproxy.proxy.connection.ServerConnection @return: mitmproxy.proxy.connection.ServerConnection
""" """
c = connections.ServerConnection.from_state(dict( c = connections.ServerConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True), address=("address", 22),
source_address=dict(address=("address", 22), use_ipv6=True), source_address=("address", 22),
ip_address=None, ip_address=None,
cert=None, cert=None,
timestamp_start=1, timestamp_start=1,
@ -174,9 +179,12 @@ def tserver_conn():
ssl_established=False, ssl_established=False,
sni="address", sni="address",
alpn_proto_negotiated=None, alpn_proto_negotiated=None,
tls_version=None,
via=None, via=None,
)) ))
c.reply = controller.DummyReply() c.reply = controller.DummyReply()
c.rfile = io.BytesIO()
c.wfile = io.BytesIO()
return c return c

View File

@ -30,8 +30,8 @@ def flowdetails(state, flow: http.HTTPFlow):
if sc is not None: if sc is not None:
text.append(urwid.Text([("head", "Server Connection:")])) text.append(urwid.Text([("head", "Server Connection:")]))
parts = [ parts = [
["Address", repr(sc.address)], ["Address", "{}:{}".format(sc.address[0], sc.address[1])],
["Resolved Address", repr(sc.ip_address)], ["Resolved Address", "{}:{}".format(sc.ip_address[0], sc.ip_address[1])],
] ]
if resp: if resp:
parts.append(["HTTP Version", resp.http_version]) parts.append(["HTTP Version", resp.http_version])
@ -92,7 +92,7 @@ def flowdetails(state, flow: http.HTTPFlow):
text.append(urwid.Text([("head", "Client Connection:")])) text.append(urwid.Text([("head", "Client Connection:")]))
parts = [ parts = [
["Address", repr(cc.address)], ["Address", "{}:{}".format(cc.address[0], cc.address[1])],
] ]
if req: if req:
parts.append(["HTTP Version", req.http_version]) parts.append(["HTTP Version", req.http_version])

View File

@ -681,7 +681,7 @@ class FlowView(tabs.Tabs):
encoding_map = { encoding_map = {
"z": "gzip", "z": "gzip",
"d": "deflate", "d": "deflate",
"b": "brotli", "b": "br",
} }
conn.encode(encoding_map[key]) conn.encode(encoding_map[key])
signals.flow_change.send(self, flow = self.flow) signals.flow_change.send(self, flow = self.flow)

View File

@ -429,9 +429,11 @@ class ConsoleMaster(master.Master):
super().tcp_message(f) super().tcp_message(f)
message = f.messages[-1] message = f.messages[-1]
direction = "->" if message.from_client else "<-" direction = "->" if message.from_client else "<-"
signals.add_log("{client} {direction} tcp {direction} {server}".format( signals.add_log("{client_host}:{client_port} {direction} tcp {direction} {server_host}:{server_port}".format(
client=repr(f.client_conn.address), client_host=f.client_conn.address[0],
server=repr(f.server_conn.address), client_port=f.client_conn.address[1],
server_host=f.server_conn.address[0],
server_port=f.server_conn.address[1],
direction=direction, direction=direction,
), "info") ), "info")
signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug") signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug")

View File

@ -43,7 +43,7 @@ class PalettePicker(urwid.WidgetWrap):
i, i,
None, None,
lambda: self.master.options.console_palette == name, lambda: self.master.options.console_palette == name,
lambda: setattr(self.master.options, "palette", name) lambda: setattr(self.master.options, "console_palette", name)
) )
for i in high: for i in high:
@ -59,7 +59,7 @@ class PalettePicker(urwid.WidgetWrap):
"Transparent", "Transparent",
"T", "T",
lambda: master.options.console_palette_transparent, lambda: master.options.console_palette_transparent,
master.options.toggler("palette_transparent") master.options.toggler("console_palette_transparent")
) )
] ]
) )

View File

@ -238,8 +238,8 @@ class StatusBar(urwid.WidgetWrap):
dst = self.master.server.config.upstream_server dst = self.master.server.config.upstream_server
r.append("[dest:%s]" % mitmproxy.net.http.url.unparse( r.append("[dest:%s]" % mitmproxy.net.http.url.unparse(
dst.scheme, dst.scheme,
dst.address.host, dst.address[0],
dst.address.port dst.address[1],
)) ))
if self.master.options.scripts: if self.master.options.scripts:
r.append("[") r.append("[")
@ -272,10 +272,10 @@ class StatusBar(urwid.WidgetWrap):
] ]
if self.master.server.bound: if self.master.server.bound:
host = self.master.server.address.host host = self.master.server.address[0]
if host == "0.0.0.0": if host == "0.0.0.0":
host = "*" host = "*"
boundaddr = "[%s:%s]" % (host, self.master.server.address.port) boundaddr = "[%s:%s]" % (host, self.master.server.address[1])
else: else:
boundaddr = "" boundaddr = ""
t.extend(self.get_status()) t.extend(self.get_status())

View File

@ -46,7 +46,7 @@ class DumpMaster(master.Master):
if not self.options.no_server: if not self.options.no_server:
self.add_log( self.add_log(
"Proxy server listening at http://{}".format(server.address), "Proxy server listening at http://{}:{}".format(server.address[0], server.address[1]),
"info" "info"
) )

View File

@ -85,6 +85,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"is_replay": flow.response.is_replay, "is_replay": flow.response.is_replay,
} }
f.get("server_conn", {}).pop("cert", None) f.get("server_conn", {}).pop("cert", None)
f.get("client_conn", {}).pop("mitmcert", None)
return f return f

View File

@ -109,7 +109,7 @@ class WebMaster(master.Master):
tornado.ioloop.PeriodicCallback(lambda: self.tick(timeout=0), 5).start() tornado.ioloop.PeriodicCallback(lambda: self.tick(timeout=0), 5).start()
self.add_log( self.add_log(
"Proxy server listening at http://{}/".format(self.server.address), "Proxy server listening at http://{}:{}/".format(self.server.address[0], self.server.address[1]),
"info" "info"
) )

View File

@ -2046,7 +2046,7 @@ function ConnectionInfo(_ref2) {
_react2.default.createElement( _react2.default.createElement(
'td', 'td',
null, null,
conn.address.address.join(':') conn.address.join(':')
) )
), ),
conn.sni && _react2.default.createElement( conn.sni && _react2.default.createElement(
@ -8449,7 +8449,7 @@ module.exports = function () {
function destination(regex) { function destination(regex) {
regex = new RegExp(regex, "i"); regex = new RegExp(regex, "i");
function destinationFilter(flow) { function destinationFilter(flow) {
return !!flow.server_conn.address && regex.test(flow.server_conn.address.address[0] + ":" + flow.server_conn.address.address[1]); return !!flow.server_conn.address && regex.test(flow.server_conn.address[0] + ":" + flow.server_conn.address[1]);
} }
destinationFilter.desc = "destination address matches " + regex; destinationFilter.desc = "destination address matches " + regex;
return destinationFilter; return destinationFilter;
@ -8509,7 +8509,7 @@ module.exports = function () {
function source(regex) { function source(regex) {
regex = new RegExp(regex, "i"); regex = new RegExp(regex, "i");
function sourceFilter(flow) { function sourceFilter(flow) {
return !!flow.client_conn.address && regex.test(flow.client_conn.address.address[0] + ":" + flow.client_conn.address.address[1]); return !!flow.client_conn.address && regex.test(flow.client_conn.address[0] + ":" + flow.client_conn.address[1]);
} }
sourceFilter.desc = "source address matches " + regex; sourceFilter.desc = "source address matches " + regex;
return sourceFilter; return sourceFilter;

View File

@ -239,7 +239,7 @@ class Pathoc(tcp.TCPClient):
is_client=True, is_client=True,
staticdir=os.getcwd(), staticdir=os.getcwd(),
unconstrained_file_access=True, unconstrained_file_access=True,
request_host=self.address.host, request_host=self.address[0],
protocol=self.protocol, protocol=self.protocol,
) )
@ -286,7 +286,7 @@ class Pathoc(tcp.TCPClient):
socks.VERSION.SOCKS5, socks.VERSION.SOCKS5,
socks.CMD.CONNECT, socks.CMD.CONNECT,
socks.ATYP.DOMAINNAME, socks.ATYP.DOMAINNAME,
tcp.Address.wrap(connect_to) connect_to,
) )
connect_request.to_file(self.wfile) connect_request.to_file(self.wfile)
self.wfile.flush() self.wfile.flush()

View File

@ -166,7 +166,7 @@ class PathodHandler(tcp.BaseHandler):
headers=headers.fields, headers=headers.fields,
http_version=http_version, http_version=http_version,
sni=self.sni, sni=self.sni,
remote_address=self.address(), remote_address=self.address,
clientcert=clientcert, clientcert=clientcert,
first_line_format=first_line_format first_line_format=first_line_format
), ),

View File

@ -172,9 +172,9 @@ class HTTP2StateProtocol:
def assemble_request(self, request): def assemble_request(self, request):
assert isinstance(request, mitmproxy.net.http.request.Request) assert isinstance(request, mitmproxy.net.http.request.Request)
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address[0]
if self.tcp_handler.address.port != 443: if self.tcp_handler.address[1] != 443:
authority += ":%d" % self.tcp_handler.address.port authority += ":%d" % self.tcp_handler.address[1]
headers = request.headers.copy() headers = request.headers.copy()

View File

@ -97,8 +97,8 @@ class _PaThread(basethread.BaseThread):
**self.daemonargs **self.daemonargs
) )
self.name = "PathodThread (%s:%s)" % ( self.name = "PathodThread (%s:%s)" % (
self.server.address.host, self.server.address[0],
self.server.address.port self.server.address[1],
) )
self.q.put(self.server.address.port) self.q.put(self.server.address[1])
self.server.serve_forever() self.server.serve_forever()

View File

@ -34,16 +34,11 @@ exclude =
mitmproxy/proxy/root_context.py mitmproxy/proxy/root_context.py
mitmproxy/proxy/server.py mitmproxy/proxy/server.py
mitmproxy/tools/ mitmproxy/tools/
mitmproxy/certs.py
mitmproxy/connections.py
mitmproxy/controller.py mitmproxy/controller.py
mitmproxy/export.py mitmproxy/export.py
mitmproxy/flow.py mitmproxy/flow.py
mitmproxy/flowfilter.py
mitmproxy/http.py
mitmproxy/io_compat.py mitmproxy/io_compat.py
mitmproxy/master.py mitmproxy/master.py
mitmproxy/optmanager.py
pathod/pathoc.py pathod/pathoc.py
pathod/pathod.py pathod/pathod.py
pathod/test.py pathod/test.py
@ -54,8 +49,6 @@ exclude =
mitmproxy/addonmanager.py mitmproxy/addonmanager.py
mitmproxy/addons/onboardingapp/app.py mitmproxy/addons/onboardingapp/app.py
mitmproxy/addons/termlog.py mitmproxy/addons/termlog.py
mitmproxy/certs.py
mitmproxy/connections.py
mitmproxy/contentviews/base.py mitmproxy/contentviews/base.py
mitmproxy/contentviews/wbxml.py mitmproxy/contentviews/wbxml.py
mitmproxy/contentviews/xml_html.py mitmproxy/contentviews/xml_html.py
@ -64,8 +57,6 @@ exclude =
mitmproxy/exceptions.py mitmproxy/exceptions.py
mitmproxy/export.py mitmproxy/export.py
mitmproxy/flow.py mitmproxy/flow.py
mitmproxy/flowfilter.py
mitmproxy/http.py
mitmproxy/io.py mitmproxy/io.py
mitmproxy/io_compat.py mitmproxy/io_compat.py
mitmproxy/log.py mitmproxy/log.py
@ -78,7 +69,6 @@ exclude =
mitmproxy/net/http/url.py mitmproxy/net/http/url.py
mitmproxy/net/tcp.py mitmproxy/net/tcp.py
mitmproxy/options.py mitmproxy/options.py
mitmproxy/optmanager.py
mitmproxy/proxy/config.py mitmproxy/proxy/config.py
mitmproxy/proxy/modes/http_proxy.py mitmproxy/proxy/modes/http_proxy.py
mitmproxy/proxy/modes/reverse_proxy.py mitmproxy/proxy/modes/reverse_proxy.py

View File

@ -113,7 +113,6 @@ setup(
], ],
'examples': [ 'examples': [
"beautifulsoup4>=4.4.1, <4.6", "beautifulsoup4>=4.4.1, <4.6",
"pytz>=2015.07.0, <=2016.10",
"Pillow>=3.2, <4.1", "Pillow>=3.2, <4.1",
] ]
} }

View File

@ -70,7 +70,7 @@ def test_simple():
flow.request = tutils.treq() flow.request = tutils.treq()
flow.request.stickycookie = True flow.request.stickycookie = True
flow.client_conn = mock.MagicMock() flow.client_conn = mock.MagicMock()
flow.client_conn.address.host = "foo" flow.client_conn.address[0] = "foo"
flow.response = tutils.tresp(content=None) flow.response = tutils.tresp(content=None)
flow.response.is_replay = True flow.response.is_replay = True
flow.response.status_code = 300 flow.response.status_code = 300
@ -176,7 +176,7 @@ def test_websocket():
ctx.configure(d, flow_detail=3, showhost=True) ctx.configure(d, flow_detail=3, showhost=True)
f = tflow.twebsocketflow() f = tflow.twebsocketflow()
d.websocket_message(f) d.websocket_message(f)
assert "hello text" in sio.getvalue() assert "it's me" in sio.getvalue()
sio.truncate(0) sio.truncate(0)
d.websocket_end(f) d.websocket_end(f)

View File

@ -0,0 +1,368 @@
import pytest
import requests
from examples.complex import xss_scanner as xss
from mitmproxy.test import tflow, tutils
class TestXSSScanner():
def test_get_XSS_info(self):
# First type of exploit: <script>PAYLOAD</script>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><script>%s</script><html>" %
xss.FULL_PAYLOAD,
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData('https://example.com',
"End of URL",
'</script><script>alert(0)</script><script>',
xss.FULL_PAYLOAD.decode('utf-8'))
assert xss_info == expected_xss_info
xss_info = xss.get_XSS_data(b"<html><script>%s</script><html>" %
xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22"),
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
'</script><script>alert(0)</script><script>',
xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22").decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable:
xss_info = xss.get_XSS_data(b"<html><script>%s</script><html>" %
xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22").replace(b"/", b"%2F"),
"https://example.com",
"End of URL")
assert xss_info is None
# Second type of exploit: <script>t='PAYLOAD'</script>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><script>t='%s';</script></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"\"", b"%22"),
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
"';alert(0);g='",
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E")
.replace(b"\"", b"%22").decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable:
xss_info = xss.get_XSS_data(b"<html><script>t='%s';</script></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b"\"", b"%22").replace(b"'", b"%22"),
"https://example.com",
"End of URL")
assert xss_info is None
# Third type of exploit: <script>t="PAYLOAD"</script>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><script>t=\"%s\";</script></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"'", b"%27"),
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
'";alert(0);g="',
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E")
.replace(b"'", b"%27").decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable:
xss_info = xss.get_XSS_data(b"<html><script>t=\"%s\";</script></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b"'", b"%27").replace(b"\"", b"%22"),
"https://example.com",
"End of URL")
assert xss_info is None
# Fourth type of exploit: <a href='PAYLOAD'>Test</a>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href='%s'>Test</a></html>" %
xss.FULL_PAYLOAD,
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
"'><script>alert(0)</script>",
xss.FULL_PAYLOAD.decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href='OtherStuff%s'>Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"'", b"%27"),
"https://example.com",
"End of URL")
assert xss_info is None
# Fifth type of exploit: <a href="PAYLOAD">Test</a>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href=\"%s\">Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"'", b"%27"),
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
"\"><script>alert(0)</script>",
xss.FULL_PAYLOAD.replace(b"'", b"%27").decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href=\"OtherStuff%s\">Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b"\"", b"%22"),
"https://example.com",
"End of URL")
assert xss_info is None
# Sixth type of exploit: <a href=PAYLOAD>Test</a>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href=%s>Test</a></html>" %
xss.FULL_PAYLOAD,
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
"><script>alert(0)</script>",
xss.FULL_PAYLOAD.decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable
xss_info = xss.get_XSS_data(b"<html><a href=OtherStuff%s>Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E")
.replace(b"=", b"%3D"),
"https://example.com",
"End of URL")
assert xss_info is None
# Seventh type of exploit: <html>PAYLOAD</html>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><b>%s</b></html>" %
xss.FULL_PAYLOAD,
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
"<script>alert(0)</script>",
xss.FULL_PAYLOAD.decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable
xss_info = xss.get_XSS_data(b"<html><b>%s</b></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"/", b"%2F"),
"https://example.com",
"End of URL")
assert xss_info is None
# Eighth type of exploit: <a href=PAYLOAD>Test</a>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href=%s>Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"),
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
"Javascript:alert(0)",
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href=OtherStuff%s>Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E")
.replace(b"=", b"%3D"),
"https://example.com",
"End of URL")
assert xss_info is None
# Ninth type of exploit: <a href="STUFF PAYLOAD">Test</a>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href=\"STUFF %s\">Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"),
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
'" onmouseover="alert(0)" t="',
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href=\"STUFF %s\">Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E")
.replace(b'"', b"%22"),
"https://example.com",
"End of URL")
assert xss_info is None
# Tenth type of exploit: <a href='STUFF PAYLOAD'>Test</a>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href='STUFF %s'>Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"),
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
"' onmouseover='alert(0)' t='",
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href='STUFF %s'>Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E")
.replace(b"'", b"%22"),
"https://example.com",
"End of URL")
assert xss_info is None
# Eleventh type of exploit: <a href=STUFF_PAYLOAD>Test</a>
# Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href=STUFF%s>Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"),
"https://example.com",
"End of URL")
expected_xss_info = xss.XSSData("https://example.com",
"End of URL",
" onmouseover=alert(0) t=",
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8'))
assert xss_info == expected_xss_info
# Non-Exploitable:
xss_info = xss.get_XSS_data(b"<html><a href=STUFF_%s>Test</a></html>" %
xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E")
.replace(b"=", b"%3D"),
"https://example.com",
"End of URL")
assert xss_info is None
def test_get_SQLi_data(self):
sqli_data = xss.get_SQLi_data("<html>SQL syntax MySQL</html>",
"<html></html>",
"https://example.com",
"End of URL")
expected_sqli_data = xss.SQLiData("https://example.com",
"End of URL",
"SQL syntax.*MySQL",
"MySQL")
assert sqli_data == expected_sqli_data
sqli_data = xss.get_SQLi_data("<html>SQL syntax MySQL</html>",
"<html>SQL syntax MySQL</html>",
"https://example.com",
"End of URL")
assert sqli_data is None
def test_inside_quote(self):
assert not xss.inside_quote("'", b"no", 0, b"no")
assert xss.inside_quote("'", b"yes", 0, b"'yes'")
assert xss.inside_quote("'", b"yes", 1, b"'yes'otherJunk'yes'more")
assert not xss.inside_quote("'", b"longStringNotInIt", 1, b"short")
def test_paths_to_text(self):
text = xss.paths_to_text("""<html><head><h1>STRING</h1></head>
<script>STRING</script>
<a href=STRING></a></html>""", "STRING")
expected_text = ["/html/head/h1", "/html/script"]
assert text == expected_text
assert xss.paths_to_text("""<html></html>""", "STRING") == []
def mocked_requests_vuln(*args, headers=None, cookies=None):
class MockResponse:
def __init__(self, html, headers=None, cookies=None):
self.text = html
return MockResponse("<html>%s</html>" % xss.FULL_PAYLOAD)
def mocked_requests_invuln(*args, headers=None, cookies=None):
class MockResponse:
def __init__(self, html, headers=None, cookies=None):
self.text = html
return MockResponse("<html></html>")
def test_test_end_of_url_injection(self, monkeypatch):
monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln)
xss_info = xss.test_end_of_URL_injection("<html></html>", "https://example.com/index.html", {})[0]
expected_xss_info = xss.XSSData('https://example.com/index.html/1029zxcs\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\eq=3847asd',
'End of URL',
'<script>alert(0)</script>',
'1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd')
sqli_info = xss.test_end_of_URL_injection("<html></html>", "https://example.com/", {})[1]
assert xss_info == expected_xss_info
assert sqli_info is None
def test_test_referer_injection(self, monkeypatch):
monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln)
xss_info = xss.test_referer_injection("<html></html>", "https://example.com/", {})[0]
expected_xss_info = xss.XSSData('https://example.com/',
'Referer',
'<script>alert(0)</script>',
'1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd')
sqli_info = xss.test_referer_injection("<html></html>", "https://example.com/", {})[1]
assert xss_info == expected_xss_info
assert sqli_info is None
def test_test_user_agent_injection(self, monkeypatch):
monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln)
xss_info = xss.test_user_agent_injection("<html></html>", "https://example.com/", {})[0]
expected_xss_info = xss.XSSData('https://example.com/',
'User Agent',
'<script>alert(0)</script>',
'1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd')
sqli_info = xss.test_user_agent_injection("<html></html>", "https://example.com/", {})[1]
assert xss_info == expected_xss_info
assert sqli_info is None
def test_test_query_injection(self, monkeypatch):
monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln)
xss_info = xss.test_query_injection("<html></html>", "https://example.com/vuln.php?cmd=ls", {})[0]
expected_xss_info = xss.XSSData('https://example.com/vuln.php?cmd=1029zxcs\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\eq=3847asd',
'Query',
'<script>alert(0)</script>',
'1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd')
sqli_info = xss.test_query_injection("<html></html>", "https://example.com/vuln.php?cmd=ls", {})[1]
assert xss_info == expected_xss_info
assert sqli_info is None
@pytest.fixture
def logger(self):
class Logger():
def __init__(self):
self.args = []
def error(self, str):
self.args.append(str)
return Logger()
def test_find_unclaimed_URLs(self, monkeypatch, logger):
logger.args = []
monkeypatch.setattr("mitmproxy.ctx.log", logger)
xss.find_unclaimed_URLs("<html><script src=\"http://google.com\"></script></html>",
"https://example.com")
assert logger.args == []
xss.find_unclaimed_URLs("<html><script src=\"http://unclaimedDomainName.com\"></script></html>",
"https://example.com")
assert logger.args[0] == 'XSS found in https://example.com due to unclaimed URL "http://unclaimedDomainName.com" in script tag.'
def test_log_XSS_data(self, monkeypatch, logger):
logger.args = []
monkeypatch.setattr("mitmproxy.ctx.log", logger)
xss.log_XSS_data(None)
assert logger.args == []
# self, url: str, injection_point: str, exploit: str, line: str
xss.log_XSS_data(xss.XSSData('https://example.com',
'Location',
'String',
'Line of HTML'))
assert logger.args[0] == '===== XSS Found ===='
assert logger.args[1] == 'XSS URL: https://example.com'
assert logger.args[2] == 'Injection Point: Location'
assert logger.args[3] == 'Suggested Exploit: String'
assert logger.args[4] == 'Line: Line of HTML'
def test_log_SQLi_data(self, monkeypatch, logger):
logger.args = []
monkeypatch.setattr("mitmproxy.ctx.log", logger)
xss.log_SQLi_data(None)
assert logger.args == []
xss.log_SQLi_data(xss.SQLiData(b'https://example.com',
b'Location',
b'Oracle.*Driver',
b'Oracle'))
assert logger.args[0] == '===== SQLi Found ====='
assert logger.args[1] == 'SQLi URL: https://example.com'
assert logger.args[2] == 'Injection Point: Location'
assert logger.args[3] == 'Regex used: Oracle.*Driver'
def test_get_cookies(self):
mocked_req = tutils.treq()
mocked_req.cookies = [("cookieName2", "cookieValue2")]
mocked_flow = tflow.tflow(req=mocked_req)
# It only uses the request cookies
assert xss.get_cookies(mocked_flow) == {"cookieName2": "cookieValue2"}
def test_response(self, monkeypatch, logger):
logger.args = []
monkeypatch.setattr("mitmproxy.ctx.log", logger)
monkeypatch.setattr(requests, 'get', self.mocked_requests_invuln)
mocked_flow = tflow.tflow(req=tutils.treq(path=b"index.html?q=1"), resp=tutils.tresp(content=b'<html></html>'))
xss.response(mocked_flow)
assert logger.args == []
def test_data_equals(self):
xssData = xss.XSSData("a", "b", "c", "d")
sqliData = xss.SQLiData("a", "b", "c", "d")
assert xssData == xssData
assert sqliData == sqliData

View File

@ -11,3 +11,4 @@ def test_is_valid_host():
assert check.is_valid_host(b"one.two.") assert check.is_valid_host(b"one.two.")
# Allow underscore # Allow underscore
assert check.is_valid_host(b"one_two") assert check.is_valid_host(b"one_two")
assert check.is_valid_host(b"::1")

View File

@ -3,7 +3,6 @@ from io import BytesIO
import pytest import pytest
from mitmproxy.net import socks from mitmproxy.net import socks
from mitmproxy.net import tcp
from mitmproxy.test import tutils from mitmproxy.test import tutils
@ -176,7 +175,7 @@ def test_message_ipv6():
msg.to_file(out) msg.to_file(out)
assert out.getvalue() == raw.getvalue()[:-2] assert out.getvalue() == raw.getvalue()[:-2]
assert msg.addr.host == ipv6_addr assert msg.addr[0] == ipv6_addr
def test_message_invalid_host(): def test_message_invalid_host():
@ -196,6 +195,6 @@ def test_message_unknown_atyp():
with pytest.raises(socks.SocksError): with pytest.raises(socks.SocksError):
socks.Message.from_file(raw) socks.Message.from_file(raw)
m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) m = socks.Message(5, 1, 0x02, ("example.com", 5050))
with pytest.raises(socks.SocksError): with pytest.raises(socks.SocksError):
m.to_file(BytesIO()) m.to_file(BytesIO())

View File

@ -116,11 +116,11 @@ class TestServerBind(tservers.ServerTestBase):
class TestServerIPv6(tservers.ServerTestBase): class TestServerIPv6(tservers.ServerTestBase):
handler = EchoHandler handler = EchoHandler
addr = tcp.Address(("localhost", 0), use_ipv6=True) addr = ("::1", 0)
def test_echo(self): def test_echo(self):
testval = b"echo!\n" testval = b"echo!\n"
c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True)) c = tcp.TCPClient(("::1", self.port))
with c.connect(): with c.connect():
c.wfile.write(testval) c.wfile.write(testval)
c.wfile.flush() c.wfile.flush()
@ -132,7 +132,7 @@ class TestEcho(tservers.ServerTestBase):
def test_echo(self): def test_echo(self):
testval = b"echo!\n" testval = b"echo!\n"
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("localhost", self.port))
with c.connect(): with c.connect():
c.wfile.write(testval) c.wfile.write(testval)
c.wfile.flush() c.wfile.flush()
@ -602,12 +602,6 @@ class TestDHParams(tservers.ServerTestBase):
ret = c.get_current_cipher() ret = c.get_current_cipher()
assert ret[0] == "DHE-RSA-AES256-SHA" assert ret[0] == "DHE-RSA-AES256-SHA"
def test_create_dhparams(self):
with tutils.tmpdir() as d:
filename = os.path.join(d, "dhparam.pem")
certs.CertStore.load_dhparam(filename)
assert os.path.exists(filename)
class TestTCPClient: class TestTCPClient:
@ -783,18 +777,6 @@ class TestPeekSSL(TestPeek):
return conn.pop() return conn.pop()
class TestAddress:
def test_simple(self):
a = tcp.Address(("localhost", 80), True)
assert a.use_ipv6
b = tcp.Address(("foo.com", 80), True)
assert not a == b
c = tcp.Address(("localhost", 80), True)
assert a == c
assert not a != c
assert repr(a) == "localhost:80"
class TestSSLKeyLogger(tservers.ServerTestBase): class TestSSLKeyLogger(tservers.ServerTestBase):
handler = EchoHandler handler = EchoHandler
ssl = dict( ssl = dict(

View File

@ -86,13 +86,13 @@ class _TServer(tcp.TCPServer):
class ServerTestBase: class ServerTestBase:
ssl = None ssl = None
handler = None handler = None
addr = ("localhost", 0) addr = ("127.0.0.1", 0)
@classmethod @classmethod
def setup_class(cls, **kwargs): def setup_class(cls, **kwargs):
cls.q = queue.Queue() cls.q = queue.Queue()
s = cls.makeserver(**kwargs) s = cls.makeserver(**kwargs)
cls.port = s.address.port cls.port = s.address[1]
cls.server = _ServerThread(s) cls.server = _ServerThread(s)
cls.server.start() cls.server.start()

View File

@ -124,10 +124,10 @@ class _Http2TestBase:
b'CONNECT', b'CONNECT',
b'', b'',
b'localhost', b'localhost',
self.server.server.address.port, self.server.server.address[1],
b'/', b'/',
b'HTTP/1.1', b'HTTP/1.1',
[(b'host', b'localhost:%d' % self.server.server.address.port)], [(b'host', b'localhost:%d' % self.server.server.address[1])],
b'', b'',
))) )))
client.wfile.flush() client.wfile.flush()
@ -231,7 +231,7 @@ class TestSimple(_Http2Test):
client.wfile, client.wfile,
h2_conn, h2_conn,
headers=[ headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -271,75 +271,6 @@ class TestSimple(_Http2Test):
assert response_body_buffer == b'response body' assert response_body_buffer == b'response body'
@requires_alpn
class TestForbiddenHeaders(_Http2Test):
@classmethod
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.StreamEnded):
import warnings
with warnings.catch_warnings():
# Ignore UnicodeWarning:
# h2/utilities.py:64: UnicodeWarning: Unicode equal comparison
# failed to convert both arguments to Unicode - interpreting
# them as being unequal.
# elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20:
warnings.simplefilter("ignore")
h2_conn.config.validate_outbound_headers = False
h2_conn.send_headers(event.stream_id, [
(':status', '200'),
('keep-alive', 'foobar'),
])
h2_conn.send_data(event.stream_id, b'response body')
h2_conn.end_stream(event.stream_id)
wfile.write(h2_conn.data_to_send())
wfile.flush()
return True
def test_forbidden_headers(self):
client, h2_conn = self._setup_connection()
self._send_request(
client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
])
done = False
while not done:
try:
raw = b''.join(http2.read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
for event in events:
if isinstance(event, h2.events.ResponseReceived):
assert 'keep-alive' not in event.headers
elif isinstance(event, h2.events.StreamEnded):
done = True
h2_conn.close_connection()
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.headers['keep-alive'] == 'foobar'
@requires_alpn @requires_alpn
class TestRequestWithPriority(_Http2Test): class TestRequestWithPriority(_Http2Test):
@ -384,7 +315,7 @@ class TestRequestWithPriority(_Http2Test):
client.wfile, client.wfile,
h2_conn, h2_conn,
headers=[ headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -469,7 +400,7 @@ class TestPriority(_Http2Test):
client.wfile, client.wfile,
h2_conn, h2_conn,
headers=[ headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -527,7 +458,7 @@ class TestStreamResetFromServer(_Http2Test):
client.wfile, client.wfile,
h2_conn, h2_conn,
headers=[ headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -576,7 +507,7 @@ class TestBodySizeLimit(_Http2Test):
client.wfile, client.wfile,
h2_conn, h2_conn,
headers=[ headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -672,7 +603,7 @@ class TestPushPromise(_Http2Test):
client, h2_conn = self._setup_connection() client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -728,7 +659,7 @@ class TestPushPromise(_Http2Test):
client, h2_conn = self._setup_connection() client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -791,7 +722,7 @@ class TestConnectionLost(_Http2Test):
client, h2_conn = self._setup_connection() client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -848,7 +779,7 @@ class TestMaxConcurrentStreams(_Http2Test):
# this will exceed MAX_CONCURRENT_STREAMS on the server connection # this will exceed MAX_CONCURRENT_STREAMS on the server connection
# and cause mitmproxy to throttle stream creation to the server # and cause mitmproxy to throttle stream creation to the server
self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),
@ -894,7 +825,7 @@ class TestConnectionTerminated(_Http2Test):
client, h2_conn = self._setup_connection() client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, headers=[ self._send_request(client.wfile, h2_conn, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
(':scheme', 'https'), (':scheme', 'https'),
(':path', '/'), (':path', '/'),

View File

@ -87,8 +87,8 @@ class _WebSocketTestBase:
"authority", "authority",
"CONNECT", "CONNECT",
"", "",
"localhost", "127.0.0.1",
self.server.server.address.port, self.server.server.address[1],
"", "",
"HTTP/1.1", "HTTP/1.1",
content=b'') content=b'')
@ -105,8 +105,8 @@ class _WebSocketTestBase:
"relative", "relative",
"GET", "GET",
"http", "http",
"localhost", "127.0.0.1",
self.server.server.address.port, self.server.server.address[1],
"/ws", "/ws",
"HTTP/1.1", "HTTP/1.1",
headers=http.Headers( headers=http.Headers(

View File

@ -17,7 +17,6 @@ from mitmproxy.net import socks
from mitmproxy import certs from mitmproxy import certs
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy.net.http import http1 from mitmproxy.net.http import http1
from mitmproxy.net.tcp import Address
from pathod import pathoc from pathod import pathoc
from pathod import pathod from pathod import pathod
@ -581,7 +580,7 @@ class TestHttps2Http(tservers.ReverseProxyTest):
def get_options(cls): def get_options(cls):
opts = super().get_options() opts = super().get_options()
s = parse_server_spec(opts.upstream_server) s = parse_server_spec(opts.upstream_server)
opts.upstream_server = "http://%s" % s.address opts.upstream_server = "http://{}:{}".format(s.address[0], s.address[1])
return opts return opts
def pathoc(self, ssl, sni=None): def pathoc(self, ssl, sni=None):
@ -740,7 +739,7 @@ class MasterRedirectRequest(tservers.TestMaster):
# This part should have no impact, but it should also not cause any exceptions. # This part should have no impact, but it should also not cause any exceptions.
addr = f.live.server_conn.address addr = f.live.server_conn.address
addr2 = Address(("127.0.0.1", self.redirect_port)) addr2 = ("127.0.0.1", self.redirect_port)
f.live.set_server(addr2) f.live.set_server(addr2)
f.live.set_server(addr) f.live.set_server(addr)
@ -750,8 +749,8 @@ class MasterRedirectRequest(tservers.TestMaster):
@controller.handler @controller.handler
def response(self, f): def response(self, f):
f.response.content = bytes(f.client_conn.address.port) f.response.content = bytes(f.client_conn.address[1])
f.response.headers["server-conn-id"] = str(f.server_conn.source_address.port) f.response.headers["server-conn-id"] = str(f.server_conn.source_address[1])
super().response(f) super().response(f)

View File

@ -117,6 +117,12 @@ class TestCertStore:
ret = ca1.get_cert(b"foo.com", []) ret = ca1.get_cert(b"foo.com", [])
assert ret[0].serial == dc[0].serial assert ret[0].serial == dc[0].serial
def test_create_dhparams(self):
with tutils.tmpdir() as d:
filename = os.path.join(d, "dhparam.pem")
certs.CertStore.load_dhparam(filename)
assert os.path.exists(filename)
class TestDummyCert: class TestDummyCert:
@ -127,9 +133,10 @@ class TestDummyCert:
ca.default_privatekey, ca.default_privatekey,
ca.default_ca, ca.default_ca,
b"foo.com", b"foo.com",
[b"one.com", b"two.com", b"*.three.com"] [b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"]
) )
assert r.cn == b"foo.com" assert r.cn == b"foo.com"
assert r.altnames == [b'one.com', b'two.com', b'*.three.com']
r = certs.dummy_cert( r = certs.dummy_cert(
ca.default_privatekey, ca.default_privatekey,
@ -138,6 +145,7 @@ class TestDummyCert:
[] []
) )
assert r.cn is None assert r.cn is None
assert r.altnames == []
class TestSSLCert: class TestSSLCert:
@ -179,3 +187,20 @@ class TestSSLCert:
d = f.read() d = f.read()
s = certs.SSLCert.from_der(d) s = certs.SSLCert.from_der(d)
assert s.cn assert s.cn
def test_state(self):
with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f:
d = f.read()
c = certs.SSLCert.from_pem(d)
c.get_state()
c2 = c.copy()
a = c.get_state()
b = c2.get_state()
assert a == b
assert c == c2
assert c is not c2
x = certs.SSLCert('')
x.set_state(a)
assert x == c

View File

@ -1 +1,210 @@
# TODO: write tests import socket
import os
import threading
import ssl
import OpenSSL
import pytest
from unittest import mock
from mitmproxy import connections
from mitmproxy import exceptions
from mitmproxy.net import tcp
from mitmproxy.net.http import http1
from mitmproxy.test import tflow
from mitmproxy.test import tutils
from .net import tservers
from pathod import test
class TestClientConnection:
def test_send(self):
c = tflow.tclient_conn()
c.send(b'foobar')
c.send([b'foo', b'bar'])
with pytest.raises(TypeError):
c.send('string')
with pytest.raises(TypeError):
c.send(['string', 'not'])
assert c.wfile.getvalue() == b'foobarfoobar'
def test_repr(self):
c = tflow.tclient_conn()
assert 'address:22' in repr(c)
assert 'ALPN' in repr(c)
assert 'TLS' not in repr(c)
c.alpn_proto_negotiated = None
c.tls_established = True
assert 'ALPN' not in repr(c)
assert 'TLS' in repr(c)
def test_tls_established_property(self):
c = tflow.tclient_conn()
c.tls_established = True
assert c.ssl_established
assert c.tls_established
c.tls_established = False
assert not c.ssl_established
assert not c.tls_established
def test_make_dummy(self):
c = connections.ClientConnection.make_dummy(('foobar', 1234))
assert c.address == ('foobar', 1234)
def test_state(self):
c = tflow.tclient_conn()
assert connections.ClientConnection.from_state(c.get_state()).get_state() == \
c.get_state()
c2 = tflow.tclient_conn()
c2.address = (c2.address[0], 4242)
assert not c == c2
c2.timestamp_start = 42
c.set_state(c2.get_state())
assert c.timestamp_start == 42
c3 = c.copy()
assert c3.get_state() == c.get_state()
class TestServerConnection:
def test_send(self):
c = tflow.tserver_conn()
c.send(b'foobar')
c.send([b'foo', b'bar'])
with pytest.raises(TypeError):
c.send('string')
with pytest.raises(TypeError):
c.send(['string', 'not'])
assert c.wfile.getvalue() == b'foobarfoobar'
def test_repr(self):
c = tflow.tserver_conn()
c.sni = 'foobar'
c.tls_established = True
c.alpn_proto_negotiated = b'h2'
assert 'address:22' in repr(c)
assert 'ALPN' in repr(c)
assert 'TLS: foobar' in repr(c)
c.sni = None
c.tls_established = True
c.alpn_proto_negotiated = None
assert 'ALPN' not in repr(c)
assert 'TLS' in repr(c)
c.sni = None
c.tls_established = False
assert 'TLS' not in repr(c)
def test_tls_established_property(self):
c = tflow.tserver_conn()
c.tls_established = True
assert c.ssl_established
assert c.tls_established
c.tls_established = False
assert not c.ssl_established
assert not c.tls_established
def test_make_dummy(self):
c = connections.ServerConnection.make_dummy(('foobar', 1234))
assert c.address == ('foobar', 1234)
def test_simple(self):
d = test.Daemon()
c = connections.ServerConnection((d.IFACE, d.port))
c.connect()
f = tflow.tflow()
f.server_conn = c
f.request.path = "/p/200:da"
# use this protocol just to assemble - not for actual sending
c.wfile.write(http1.assemble_request(f.request))
c.wfile.flush()
assert http1.read_response(c.rfile, f.request, 1000)
assert d.last_log()
c.finish()
d.shutdown()
def test_terminate_error(self):
d = test.Daemon()
c = connections.ServerConnection((d.IFACE, d.port))
c.connect()
c.connection = mock.Mock()
c.connection.recv = mock.Mock(return_value=False)
c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
c.finish()
d.shutdown()
def test_sni(self):
c = connections.ServerConnection(('', 1234))
with pytest.raises(ValueError, matches='sni must be str, not '):
c.establish_ssl(None, b'foobar')
class TestClientConnectionTLS:
@pytest.mark.parametrize("sni", [
None,
"example.com"
])
def test_tls_with_sni(self, sni):
address = ('127.0.0.1', 0)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(address)
sock.listen()
address = sock.getsockname()
def client_run():
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
s = socket.create_connection(address)
s = ctx.wrap_socket(s, server_hostname=sni)
s.send(b'foobar')
s.shutdown(socket.SHUT_RDWR)
threading.Thread(target=client_run).start()
connection, client_address = sock.accept()
c = connections.ClientConnection(connection, client_address, None)
cert = tutils.test_data.path("mitmproxy/net/data/server.crt")
key = OpenSSL.crypto.load_privatekey(
OpenSSL.crypto.FILETYPE_PEM,
open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read())
c.convert_to_ssl(cert, key)
assert c.connected()
assert c.sni == sni
assert c.tls_established
assert c.rfile.read(6) == b'foobar'
c.finish()
class TestServerConnectionTLS(tservers.ServerTestBase):
ssl = True
class handler(tcp.BaseHandler):
def handle(self):
self.finish()
@pytest.mark.parametrize("clientcert", [
None,
tutils.test_data.path("mitmproxy/data/clientcert"),
os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"),
])
def test_tls(self, clientcert):
c = connections.ServerConnection(("127.0.0.1", self.port))
c.connect()
c.establish_ssl(clientcert, "foo.com")
assert c.connected()
assert c.sni == "foo.com"
assert c.tls_established
c.close()
c.finish()

View File

@ -32,6 +32,8 @@ def test_websocket_flow(err):
assert len(f.messages) == 1 assert len(f.messages) == 1
assert next(i) == ("websocket_message", f) assert next(i) == ("websocket_message", f)
assert len(f.messages) == 2 assert len(f.messages) == 2
assert next(i) == ("websocket_message", f)
assert len(f.messages) == 3
if err: if err:
assert next(i) == ("websocket_error", f) assert next(i) == ("websocket_error", f)
assert next(i) == ("websocket_end", f) assert next(i) == ("websocket_end", f)

View File

@ -2,160 +2,18 @@ import io
import pytest import pytest
from mitmproxy.test import tflow from mitmproxy.test import tflow
from mitmproxy.net.http import Headers
import mitmproxy.io import mitmproxy.io
from mitmproxy import flowfilter, options from mitmproxy import flowfilter, options
from mitmproxy.contrib import tnetstring from mitmproxy.contrib import tnetstring
from mitmproxy.exceptions import FlowReadException, Kill from mitmproxy.exceptions import FlowReadException
from mitmproxy import flow from mitmproxy import flow
from mitmproxy import http from mitmproxy import http
from mitmproxy import connections
from mitmproxy.proxy import ProxyConfig from mitmproxy.proxy import ProxyConfig
from mitmproxy.proxy.server import DummyServer from mitmproxy.proxy.server import DummyServer
from mitmproxy import master from mitmproxy import master
from . import tservers from . import tservers
class TestHTTPFlow:
def test_copy(self):
f = tflow.tflow(resp=True)
f.get_state()
f2 = f.copy()
a = f.get_state()
b = f2.get_state()
del a["id"]
del b["id"]
assert a == b
assert not f == f2
assert f is not f2
assert f.request.get_state() == f2.request.get_state()
assert f.request is not f2.request
assert f.request.headers == f2.request.headers
assert f.request.headers is not f2.request.headers
assert f.response.get_state() == f2.response.get_state()
assert f.response is not f2.response
f = tflow.tflow(err=True)
f2 = f.copy()
assert f is not f2
assert f.request is not f2.request
assert f.request.headers == f2.request.headers
assert f.request.headers is not f2.request.headers
assert f.error.get_state() == f2.error.get_state()
assert f.error is not f2.error
def test_match(self):
f = tflow.tflow(resp=True)
assert not flowfilter.match("~b test", f)
assert flowfilter.match(None, f)
assert not flowfilter.match("~b test", f)
f = tflow.tflow(err=True)
assert flowfilter.match("~e", f)
with pytest.raises(ValueError):
flowfilter.match("~", f)
def test_backup(self):
f = tflow.tflow()
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
f.request.content = b"foo"
assert not f.modified()
f.backup()
f.request.content = b"bar"
assert f.modified()
f.revert()
assert f.request.content == b"foo"
def test_backup_idempotence(self):
f = tflow.tflow(resp=True)
f.backup()
f.revert()
f.backup()
f.revert()
def test_getset_state(self):
f = tflow.tflow(resp=True)
state = f.get_state()
assert f.get_state() == http.HTTPFlow.from_state(
state).get_state()
f.response = None
f.error = flow.Error("error")
state = f.get_state()
assert f.get_state() == http.HTTPFlow.from_state(
state).get_state()
f2 = f.copy()
f2.id = f.id # copy creates a different uuid
assert f.get_state() == f2.get_state()
assert not f == f2
f2.error = flow.Error("e2")
assert not f == f2
f.set_state(f2.get_state())
assert f.get_state() == f2.get_state()
def test_kill(self):
f = tflow.tflow()
f.reply.handle()
f.intercept()
assert f.killable
f.kill()
assert not f.killable
assert f.reply.value == Kill
def test_resume(self):
f = tflow.tflow()
f.reply.handle()
f.intercept()
assert f.reply.state == "taken"
f.resume()
assert f.reply.state == "committed"
def test_replace_unicode(self):
f = tflow.tflow(resp=True)
f.response.content = b"\xc2foo"
f.replace(b"foo", u"bar")
def test_replace_no_content(self):
f = tflow.tflow()
f.request.content = None
assert f.replace("foo", "bar") == 0
def test_replace(self):
f = tflow.tflow(resp=True)
f.request.headers["foo"] = "foo"
f.request.content = b"afoob"
f.response.headers["foo"] = "foo"
f.response.content = b"afoob"
assert f.replace("foo", "bar") == 6
assert f.request.headers["bar"] == "bar"
assert f.request.content == b"abarb"
assert f.response.headers["bar"] == "bar"
assert f.response.content == b"abarb"
def test_replace_encoded(self):
f = tflow.tflow(resp=True)
f.request.content = b"afoob"
f.request.encode("gzip")
f.response.content = b"afoob"
f.response.encode("gzip")
f.replace("foo", "bar")
assert f.request.raw_content != b"abarb"
f.request.decode()
assert f.request.raw_content == b"abarb"
assert f.response.raw_content != b"abarb"
f.response.decode()
assert f.response.raw_content == b"abarb"
class TestSerialize: class TestSerialize:
def _treader(self): def _treader(self):
@ -307,88 +165,6 @@ class TestFlowMaster:
fm.shutdown() fm.shutdown()
class TestRequest:
def test_simple(self):
f = tflow.tflow()
r = f.request
u = r.url
r.url = u
with pytest.raises(ValueError):
setattr(r, "url", "")
assert r.url == u
r2 = r.copy()
assert r.get_state() == r2.get_state()
def test_get_url(self):
r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
assert r.url == "http://address:22/path"
r.scheme = "https"
assert r.url == "https://address:22/path"
r.host = "host"
r.port = 42
assert r.url == "https://host:42/path"
r.host = "address"
r.port = 22
assert r.url == "https://address:22/path"
assert r.pretty_url == "https://address:22/path"
r.headers["Host"] = "foo.com:22"
assert r.url == "https://address:22/path"
assert r.pretty_url == "https://foo.com:22/path"
def test_replace(self):
r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
r.path = "path/foo"
r.headers["Foo"] = "fOo"
r.content = b"afoob"
assert r.replace("foo(?i)", "boo") == 4
assert r.path == "path/boo"
assert b"foo" not in r.content
assert r.headers["boo"] == "boo"
def test_constrain_encoding(self):
r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
r.headers["accept-encoding"] = "gzip, oink"
r.constrain_encoding()
assert "oink" not in r.headers["accept-encoding"]
r.headers.set_all("accept-encoding", ["gzip", "oink"])
r.constrain_encoding()
assert "oink" not in r.headers["accept-encoding"]
def test_get_content_type(self):
resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
resp.headers = Headers(content_type="text/plain")
assert resp.headers["content-type"] == "text/plain"
class TestResponse:
def test_simple(self):
f = tflow.tflow(resp=True)
resp = f.response
resp2 = resp.copy()
assert resp2.get_state() == resp.get_state()
def test_replace(self):
r = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
r.headers["Foo"] = "fOo"
r.content = b"afoob"
assert r.replace("foo(?i)", "boo") == 3
assert b"foo" not in r.content
assert r.headers["boo"] == "boo"
def test_get_content_type(self):
resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
resp.headers = Headers(content_type="text/plain")
assert resp.headers["content-type"] == "text/plain"
class TestError: class TestError:
def test_getset_state(self): def test_getset_state(self):
@ -409,23 +185,4 @@ class TestError:
def test_repr(self): def test_repr(self):
e = flow.Error("yay") e = flow.Error("yay")
assert repr(e) assert repr(e)
assert str(e)
class TestClientConnection:
def test_state(self):
c = tflow.tclient_conn()
assert connections.ClientConnection.from_state(c.get_state()).get_state() == \
c.get_state()
c2 = tflow.tclient_conn()
c2.address.address = (c2.address.host, 4242)
assert not c == c2
c2.timestamp_start = 42
c.set_state(c2.get_state())
assert c.timestamp_start == 42
c3 = c.copy()
assert c3.get_state() == c.get_state()
assert str(c)

View File

@ -1,4 +1,5 @@
import io import io
import pytest
from unittest.mock import patch from unittest.mock import patch
from mitmproxy.test import tflow from mitmproxy.test import tflow
@ -134,6 +135,12 @@ class TestMatchingHTTPFlow:
e = self.err() e = self.err()
assert self.q("~e", e) assert self.q("~e", e)
def test_fmarked(self):
q = self.req()
assert not self.q("~marked", q)
q.marked = True
assert self.q("~marked", q)
def test_head(self): def test_head(self):
q = self.req() q = self.req()
s = self.resp() s = self.resp()
@ -221,6 +228,11 @@ class TestMatchingHTTPFlow:
assert not self.q("~src :99", q) assert not self.q("~src :99", q)
assert self.q("~src address:22", q) assert self.q("~src address:22", q)
q.client_conn.address = None
assert not self.q('~src address:22', q)
q.client_conn = None
assert not self.q('~src address:22', q)
def test_dst(self): def test_dst(self):
q = self.req() q = self.req()
q.server_conn = tflow.tserver_conn() q.server_conn = tflow.tserver_conn()
@ -230,6 +242,11 @@ class TestMatchingHTTPFlow:
assert not self.q("~dst :99", q) assert not self.q("~dst :99", q)
assert self.q("~dst address:22", q) assert self.q("~dst address:22", q)
q.server_conn.address = None
assert not self.q('~dst address:22', q)
q.server_conn = None
assert not self.q('~dst address:22', q)
def test_and(self): def test_and(self):
s = self.resp() s = self.resp()
assert self.q("~c 200 & ~h head", s) assert self.q("~c 200 & ~h head", s)
@ -269,6 +286,7 @@ class TestMatchingTCPFlow:
f = self.flow() f = self.flow()
assert self.q("~tcp", f) assert self.q("~tcp", f)
assert not self.q("~http", f) assert not self.q("~http", f)
assert not self.q("~websocket", f)
def test_ferr(self): def test_ferr(self):
e = self.err() e = self.err()
@ -378,6 +396,87 @@ class TestMatchingTCPFlow:
assert not self.q("~u whatever", f) assert not self.q("~u whatever", f)
class TestMatchingWebSocketFlow:
def flow(self):
return tflow.twebsocketflow()
def err(self):
return tflow.twebsocketflow(err=True)
def q(self, q, o):
return flowfilter.parse(q)(o)
def test_websocket(self):
f = self.flow()
assert self.q("~websocket", f)
assert not self.q("~tcp", f)
assert not self.q("~http", f)
def test_ferr(self):
e = self.err()
assert self.q("~e", e)
def test_body(self):
f = self.flow()
# Messages sent by client or server
assert self.q("~b hello", f)
assert self.q("~b me", f)
assert not self.q("~b nonexistent", f)
# Messages sent by client
assert self.q("~bq hello", f)
assert not self.q("~bq me", f)
assert not self.q("~bq nonexistent", f)
# Messages sent by server
assert self.q("~bs me", f)
assert not self.q("~bs hello", f)
assert not self.q("~bs nonexistent", f)
def test_src(self):
f = self.flow()
assert self.q("~src address", f)
assert not self.q("~src foobar", f)
assert self.q("~src :22", f)
assert not self.q("~src :99", f)
assert self.q("~src address:22", f)
def test_dst(self):
f = self.flow()
f.server_conn = tflow.tserver_conn()
assert self.q("~dst address", f)
assert not self.q("~dst foobar", f)
assert self.q("~dst :22", f)
assert not self.q("~dst :99", f)
assert self.q("~dst address:22", f)
def test_and(self):
f = self.flow()
f.server_conn = tflow.tserver_conn()
assert self.q("~b hello & ~b me", f)
assert not self.q("~src wrongaddress & ~b hello", f)
assert self.q("(~src :22 & ~dst :22) & ~b hello", f)
assert not self.q("(~src address:22 & ~dst :22) & ~b nonexistent", f)
assert not self.q("(~src address:22 & ~dst :99) & ~b hello", f)
def test_or(self):
f = self.flow()
f.server_conn = tflow.tserver_conn()
assert self.q("~b hello | ~b me", f)
assert self.q("~src :22 | ~b me", f)
assert not self.q("~src :99 | ~dst :99", f)
assert self.q("(~src :22 | ~dst :22) | ~b me", f)
def test_not(self):
f = self.flow()
assert not self.q("! ~src :22", f)
assert self.q("! ~src :99", f)
assert self.q("!~src :99 !~src :99", f)
assert not self.q("!~src :99 !~src :22", f)
class TestMatchingDummyFlow: class TestMatchingDummyFlow:
def flow(self): def flow(self):
@ -411,6 +510,8 @@ class TestMatchingDummyFlow:
assert not self.q("~e", f) assert not self.q("~e", f)
assert not self.q("~http", f) assert not self.q("~http", f)
assert not self.q("~tcp", f)
assert not self.q("~websocket", f)
assert not self.q("~h whatever", f) assert not self.q("~h whatever", f)
assert not self.q("~hq whatever", f) assert not self.q("~hq whatever", f)
@ -440,3 +541,11 @@ def test_pyparsing_bug(extract_tb):
# The text is a string with leading and trailing whitespace stripped; if the source is not available it is None. # The text is a string with leading and trailing whitespace stripped; if the source is not available it is None.
extract_tb.return_value = [("", 1, "test", None)] extract_tb.return_value = [("", 1, "test", None)]
assert flowfilter.parse("test") assert flowfilter.parse("test")
def test_match():
with pytest.raises(ValueError):
flowfilter.match('[foobar', None)
assert flowfilter.match(None, None)
assert not flowfilter.match('foobar', None)

View File

@ -1 +1,256 @@
# TODO: write tests import pytest
from mitmproxy.test import tflow
from mitmproxy.net.http import Headers
import mitmproxy.io
from mitmproxy import flowfilter
from mitmproxy.exceptions import Kill
from mitmproxy import flow
from mitmproxy import http
class TestHTTPRequest:
def test_simple(self):
f = tflow.tflow()
r = f.request
u = r.url
r.url = u
with pytest.raises(ValueError):
setattr(r, "url", "")
assert r.url == u
r2 = r.copy()
assert r.get_state() == r2.get_state()
assert hash(r)
def test_get_url(self):
r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
assert r.url == "http://address:22/path"
r.scheme = "https"
assert r.url == "https://address:22/path"
r.host = "host"
r.port = 42
assert r.url == "https://host:42/path"
r.host = "address"
r.port = 22
assert r.url == "https://address:22/path"
assert r.pretty_url == "https://address:22/path"
r.headers["Host"] = "foo.com:22"
assert r.url == "https://address:22/path"
assert r.pretty_url == "https://foo.com:22/path"
def test_replace(self):
r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
r.path = "path/foo"
r.headers["Foo"] = "fOo"
r.content = b"afoob"
assert r.replace("foo(?i)", "boo") == 4
assert r.path == "path/boo"
assert b"foo" not in r.content
assert r.headers["boo"] == "boo"
def test_constrain_encoding(self):
r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq())
r.headers["accept-encoding"] = "gzip, oink"
r.constrain_encoding()
assert "oink" not in r.headers["accept-encoding"]
r.headers.set_all("accept-encoding", ["gzip", "oink"])
r.constrain_encoding()
assert "oink" not in r.headers["accept-encoding"]
def test_get_content_type(self):
resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
resp.headers = Headers(content_type="text/plain")
assert resp.headers["content-type"] == "text/plain"
class TestHTTPResponse:
def test_simple(self):
f = tflow.tflow(resp=True)
resp = f.response
resp2 = resp.copy()
assert resp2.get_state() == resp.get_state()
def test_replace(self):
r = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
r.headers["Foo"] = "fOo"
r.content = b"afoob"
assert r.replace("foo(?i)", "boo") == 3
assert b"foo" not in r.content
assert r.headers["boo"] == "boo"
def test_get_content_type(self):
resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
resp.headers = Headers(content_type="text/plain")
assert resp.headers["content-type"] == "text/plain"
class TestHTTPFlow:
def test_copy(self):
f = tflow.tflow(resp=True)
assert repr(f)
f.get_state()
f2 = f.copy()
a = f.get_state()
b = f2.get_state()
del a["id"]
del b["id"]
assert a == b
assert not f == f2
assert f is not f2
assert f.request.get_state() == f2.request.get_state()
assert f.request is not f2.request
assert f.request.headers == f2.request.headers
assert f.request.headers is not f2.request.headers
assert f.response.get_state() == f2.response.get_state()
assert f.response is not f2.response
f = tflow.tflow(err=True)
f2 = f.copy()
assert f is not f2
assert f.request is not f2.request
assert f.request.headers == f2.request.headers
assert f.request.headers is not f2.request.headers
assert f.error.get_state() == f2.error.get_state()
assert f.error is not f2.error
def test_match(self):
f = tflow.tflow(resp=True)
assert not flowfilter.match("~b test", f)
assert flowfilter.match(None, f)
assert not flowfilter.match("~b test", f)
f = tflow.tflow(err=True)
assert flowfilter.match("~e", f)
with pytest.raises(ValueError):
flowfilter.match("~", f)
def test_backup(self):
f = tflow.tflow()
f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp())
f.request.content = b"foo"
assert not f.modified()
f.backup()
f.request.content = b"bar"
assert f.modified()
f.revert()
assert f.request.content == b"foo"
def test_backup_idempotence(self):
f = tflow.tflow(resp=True)
f.backup()
f.revert()
f.backup()
f.revert()
def test_getset_state(self):
f = tflow.tflow(resp=True)
state = f.get_state()
assert f.get_state() == http.HTTPFlow.from_state(
state).get_state()
f.response = None
f.error = flow.Error("error")
state = f.get_state()
assert f.get_state() == http.HTTPFlow.from_state(
state).get_state()
f2 = f.copy()
f2.id = f.id # copy creates a different uuid
assert f.get_state() == f2.get_state()
assert not f == f2
f2.error = flow.Error("e2")
assert not f == f2
f.set_state(f2.get_state())
assert f.get_state() == f2.get_state()
def test_kill(self):
f = tflow.tflow()
f.reply.handle()
f.intercept()
assert f.killable
f.kill()
assert not f.killable
assert f.reply.value == Kill
def test_resume(self):
f = tflow.tflow()
f.reply.handle()
f.intercept()
assert f.reply.state == "taken"
f.resume()
assert f.reply.state == "committed"
def test_replace_unicode(self):
f = tflow.tflow(resp=True)
f.response.content = b"\xc2foo"
f.replace(b"foo", u"bar")
def test_replace_no_content(self):
f = tflow.tflow()
f.request.content = None
assert f.replace("foo", "bar") == 0
def test_replace(self):
f = tflow.tflow(resp=True)
f.request.headers["foo"] = "foo"
f.request.content = b"afoob"
f.response.headers["foo"] = "foo"
f.response.content = b"afoob"
assert f.replace("foo", "bar") == 6
assert f.request.headers["bar"] == "bar"
assert f.request.content == b"abarb"
assert f.response.headers["bar"] == "bar"
assert f.response.content == b"abarb"
def test_replace_encoded(self):
f = tflow.tflow(resp=True)
f.request.content = b"afoob"
f.request.encode("gzip")
f.response.content = b"afoob"
f.response.encode("gzip")
f.replace("foo", "bar")
assert f.request.raw_content != b"abarb"
f.request.decode()
assert f.request.raw_content == b"abarb"
assert f.response.raw_content != b"abarb"
f.response.decode()
assert f.response.raw_content == b"abarb"
def test_make_error_response():
resp = http.make_error_response(543, 'foobar', Headers())
assert resp
def test_make_connect_request():
req = http.make_connect_request(('invalidhost', 1234))
assert req.first_line_format == 'authority'
assert req.method == 'CONNECT'
assert req.http_version == 'HTTP/1.1'
def test_make_connect_response():
resp = http.make_connect_response('foobar')
assert resp.http_version == 'foobar'
assert resp.status_code == 200
def test_expect_continue_response():
assert http.expect_continue_response.http_version == 'HTTP/1.1'
assert http.expect_continue_response.status_code == 100

View File

@ -30,6 +30,14 @@ class TD2(TD):
super().__init__(three=three, **kwargs) super().__init__(three=three, **kwargs)
class TM(optmanager.OptManager):
def __init__(self, one="one", two=["foo"], three=None):
self.one = one
self.two = two
self.three = three
super().__init__()
def test_defaults(): def test_defaults():
assert TD2.default("one") == "done" assert TD2.default("one") == "done"
assert TD2.default("two") == "dtwo" assert TD2.default("two") == "dtwo"
@ -203,6 +211,9 @@ def test_serialize():
t = "" t = ""
o2.load(t) o2.load(t)
with pytest.raises(exceptions.OptionsError, matches='No such option: foobar'):
o2.load("foobar: '123'")
def test_serialize_defaults(): def test_serialize_defaults():
o = options.Options() o = options.Options()
@ -224,13 +235,10 @@ def test_saving():
o.load_paths(dst) o.load_paths(dst)
assert o.three == "foo" assert o.three == "foo"
with open(dst, 'a') as f:
class TM(optmanager.OptManager): f.write("foobar: '123'")
def __init__(self, one="one", two=["foo"], three=None): with pytest.raises(exceptions.OptionsError, matches=''):
self.one = one o.load_paths(dst)
self.two = two
self.three = three
super().__init__()
def test_merge(): def test_merge():

View File

@ -4,62 +4,17 @@ from unittest import mock
from OpenSSL import SSL from OpenSSL import SSL
import pytest import pytest
from mitmproxy.test import tflow
from mitmproxy.tools import cmdline from mitmproxy.tools import cmdline
from mitmproxy import options from mitmproxy import options
from mitmproxy.proxy import ProxyConfig from mitmproxy.proxy import ProxyConfig
from mitmproxy import connections
from mitmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler from mitmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler
from mitmproxy.proxy import config from mitmproxy.proxy import config
from mitmproxy import exceptions
from pathod import test
from mitmproxy.net.http import http1
from mitmproxy.test import tutils from mitmproxy.test import tutils
from ..conftest import skip_windows from ..conftest import skip_windows
class TestServerConnection:
def test_simple(self):
self.d = test.Daemon()
sc = connections.ServerConnection((self.d.IFACE, self.d.port))
sc.connect()
f = tflow.tflow()
f.server_conn = sc
f.request.path = "/p/200:da"
# use this protocol just to assemble - not for actual sending
sc.wfile.write(http1.assemble_request(f.request))
sc.wfile.flush()
assert http1.read_response(sc.rfile, f.request, 1000)
assert self.d.last_log()
sc.finish()
self.d.shutdown()
def test_terminate_error(self):
self.d = test.Daemon()
sc = connections.ServerConnection((self.d.IFACE, self.d.port))
sc.connect()
sc.connection = mock.Mock()
sc.connection.recv = mock.Mock(return_value=False)
sc.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
sc.finish()
self.d.shutdown()
def test_repr(self):
sc = tflow.tserver_conn()
assert "address:22" in repr(sc)
assert "ssl" not in repr(sc)
sc.ssl_established = True
assert "ssl" in repr(sc)
sc.sni = "foo"
assert "foo" in repr(sc)
class MockParser(argparse.ArgumentParser): class MockParser(argparse.ArgumentParser):
""" """
@ -160,7 +115,7 @@ class TestProxyServer:
ProxyServer(conf) ProxyServer(conf)
def test_err_2(self): def test_err_2(self):
conf = ProxyConfig(options.Options(listen_host="invalidhost")) conf = ProxyConfig(options.Options(listen_host="256.256.256.256"))
with pytest.raises(Exception, match="Error starting proxy server"): with pytest.raises(Exception, match="Error starting proxy server"):
ProxyServer(conf) ProxyServer(conf)

View File

@ -98,13 +98,14 @@ class ProxyThread(threading.Thread):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.tmaster = tmaster self.tmaster = tmaster
self.name = "ProxyThread (%s:%s)" % ( self.name = "ProxyThread (%s:%s)" % (
tmaster.server.address.host, tmaster.server.address.port tmaster.server.address[0],
tmaster.server.address[1],
) )
controller.should_exit = False controller.should_exit = False
@property @property
def port(self): def port(self):
return self.tmaster.server.address.port return self.tmaster.server.address[1]
@property @property
def tlog(self): def tlog(self):

View File

@ -26,7 +26,7 @@ export function ConnectionInfo({ conn }) {
<tbody> <tbody>
<tr key="address"> <tr key="address">
<td>Address:</td> <td>Address:</td>
<td>{conn.address.address.join(':')}</td> <td>{conn.address.join(':')}</td>
</tr> </tr>
{conn.sni && ( {conn.sni && (
<tr key="sni"> <tr key="sni">

View File

@ -106,7 +106,7 @@ function destination(regex){
function destinationFilter(flow){ function destinationFilter(flow){
return (!!flow.server_conn.address) return (!!flow.server_conn.address)
&& &&
regex.test(flow.server_conn.address.address[0] + ":" + flow.server_conn.address.address[1]); regex.test(flow.server_conn.address[0] + ":" + flow.server_conn.address[1]);
} }
destinationFilter.desc = "destination address matches " + regex; destinationFilter.desc = "destination address matches " + regex;
return destinationFilter; return destinationFilter;
@ -172,7 +172,7 @@ function source(regex){
function sourceFilter(flow){ function sourceFilter(flow){
return (!!flow.client_conn.address) return (!!flow.client_conn.address)
&& &&
regex.test(flow.client_conn.address.address[0] + ":" + flow.client_conn.address.address[1]); regex.test(flow.client_conn.address[0] + ":" + flow.client_conn.address[1]);
} }
sourceFilter.desc = "source address matches " + regex; sourceFilter.desc = "source address matches " + regex;
return sourceFilter; return sourceFilter;