pathod: register websocket key from client handshake

This commit is contained in:
Aldo Cortesi 2015-04-23 17:11:20 +12:00
parent 2306a7ab6d
commit dacb350040
2 changed files with 20 additions and 9 deletions

View File

@ -1,9 +1,10 @@
import urllib import copy
import threading
import logging import logging
import os import os
import sys import sys
from netlib import tcp, http, wsgi, certutils import threading
import urllib
from netlib import tcp, http, wsgi, certutils, websockets
import netlib.utils import netlib.utils
from . import version, app, language, utils from . import version, app, language, utils
@ -57,6 +58,10 @@ class PathodHandler(tcp.BaseHandler):
wbufsize = 0 wbufsize = 0
sni = None sni = None
def __init__(self, connection, address, server, settings):
tcp.BaseHandler.__init__(self, connection, address, server)
self.settings = copy.copy(settings)
def info(self, s): def info(self, s):
logger.info( logger.info(
"%s:%s: %s" % (self.address.host, self.address.port, str(s)) "%s:%s: %s" % (self.address.host, self.address.port, str(s))
@ -67,11 +72,11 @@ class PathodHandler(tcp.BaseHandler):
def serve_crafted(self, crafted): def serve_crafted(self, crafted):
error, crafted = self.server.check_policy( error, crafted = self.server.check_policy(
crafted, self.server.settings crafted, self.settings
) )
if error: if error:
err = language.make_error_response(error) err = language.make_error_response(error)
language.serve(err, self.wfile, self.server.settings) language.serve(err, self.wfile, self.settings)
log = dict( log = dict(
type="error", type="error",
msg = error msg = error
@ -79,12 +84,12 @@ class PathodHandler(tcp.BaseHandler):
return False, log return False, log
if self.server.explain and not isinstance(crafted, language.PathodErrorResponse): if self.server.explain and not isinstance(crafted, language.PathodErrorResponse):
crafted = crafted.freeze(self.server.settings) crafted = crafted.freeze(self.settings)
self.info(">> Spec: %s" % crafted.spec()) self.info(">> Spec: %s" % crafted.spec())
response_log = language.serve( response_log = language.serve(
crafted, crafted,
self.wfile, self.wfile,
self.server.settings self.settings
) )
if response_log["disconnect"]: if response_log["disconnect"]:
return False, response_log return False, response_log
@ -197,6 +202,9 @@ class PathodHandler(tcp.BaseHandler):
return again return again
if not self.server.nocraft and path.startswith(self.server.craftanchor): if not self.server.nocraft and path.startswith(self.server.craftanchor):
key = websockets.check_client_handshake(headers)
if key:
self.settings.websocket_key = key
spec = urllib.unquote(path)[len(self.server.craftanchor):] spec = urllib.unquote(path)[len(self.server.craftanchor):]
self.info("crafting spec: %s" % spec) self.info("crafting spec: %s" % spec)
try: try:
@ -212,7 +220,7 @@ class PathodHandler(tcp.BaseHandler):
return again return again
elif self.server.noweb: elif self.server.noweb:
crafted = language.make_error_response("Access Denied") crafted = language.make_error_response("Access Denied")
language.serve(crafted, self.wfile, self.server.settings) language.serve(crafted, self.wfile, self.settings)
self.addlog(dict( self.addlog(dict(
type="error", type="error",
msg="Access denied: web interface disabled" msg="Access denied: web interface disabled"
@ -360,7 +368,7 @@ class Pathod(tcp.TCPServer):
return None, req return None, req
def handle_client_connection(self, request, client_address): def handle_client_connection(self, request, client_address):
h = PathodHandler(request, client_address, self) h = PathodHandler(request, client_address, self, self.settings)
try: try:
h.handle() h.handle()
h.finish() h.finish()

View File

@ -184,6 +184,9 @@ class CommonTests(tutils.DaemonTests):
r = self.pathoc(r"get:'http://foo.com/p/202':da") r = self.pathoc(r"get:'http://foo.com/p/202':da")
assert r.status_code == 202 assert r.status_code == 202
def test_websocket(self):
r = self.pathoc("ws:/p/")
class TestDaemon(CommonTests): class TestDaemon(CommonTests):
ssl = False ssl = False