diff --git a/libpathod/pathod.py b/libpathod/pathod.py index 1bb3bae1d..2e93a3304 100644 --- a/libpathod/pathod.py +++ b/libpathod/pathod.py @@ -1,9 +1,10 @@ -import urllib -import threading +import copy import logging import os import sys -from netlib import tcp, http, wsgi, certutils +import threading +import urllib +from netlib import tcp, http, wsgi, certutils, websockets import netlib.utils from . import version, app, language, utils @@ -57,6 +58,10 @@ class PathodHandler(tcp.BaseHandler): wbufsize = 0 sni = None + def __init__(self, connection, address, server, settings): + tcp.BaseHandler.__init__(self, connection, address, server) + self.settings = copy.copy(settings) + def info(self, s): logger.info( "%s:%s: %s" % (self.address.host, self.address.port, str(s)) @@ -67,11 +72,11 @@ class PathodHandler(tcp.BaseHandler): def serve_crafted(self, crafted): error, crafted = self.server.check_policy( - crafted, self.server.settings + crafted, self.settings ) if error: err = language.make_error_response(error) - language.serve(err, self.wfile, self.server.settings) + language.serve(err, self.wfile, self.settings) log = dict( type="error", msg = error @@ -79,12 +84,12 @@ class PathodHandler(tcp.BaseHandler): return False, log if self.server.explain and not isinstance(crafted, language.PathodErrorResponse): - crafted = crafted.freeze(self.server.settings) + crafted = crafted.freeze(self.settings) self.info(">> Spec: %s" % crafted.spec()) response_log = language.serve( crafted, self.wfile, - self.server.settings + self.settings ) if response_log["disconnect"]: return False, response_log @@ -197,6 +202,9 @@ class PathodHandler(tcp.BaseHandler): return again if not self.server.nocraft and path.startswith(self.server.craftanchor): + key = websockets.check_client_handshake(headers) + if key: + self.settings.websocket_key = key spec = urllib.unquote(path)[len(self.server.craftanchor):] self.info("crafting spec: %s" % spec) try: @@ -212,7 +220,7 @@ class PathodHandler(tcp.BaseHandler): return again elif self.server.noweb: crafted = language.make_error_response("Access Denied") - language.serve(crafted, self.wfile, self.server.settings) + language.serve(crafted, self.wfile, self.settings) self.addlog(dict( type="error", msg="Access denied: web interface disabled" @@ -360,7 +368,7 @@ class Pathod(tcp.TCPServer): return None, req def handle_client_connection(self, request, client_address): - h = PathodHandler(request, client_address, self) + h = PathodHandler(request, client_address, self, self.settings) try: h.handle() h.finish() diff --git a/test/test_pathod.py b/test/test_pathod.py index 00634e270..266f41abd 100644 --- a/test/test_pathod.py +++ b/test/test_pathod.py @@ -184,6 +184,9 @@ class CommonTests(tutils.DaemonTests): r = self.pathoc(r"get:'http://foo.com/p/202':da") assert r.status_code == 202 + def test_websocket(self): + r = self.pathoc("ws:/p/") + class TestDaemon(CommonTests): ssl = False