refactoring. Basically, we are dividing WebSocketServer into two
classes: One request handler following the SocketServer Requesthandler
API, and one optional server engine. The standard Python SocketServer
engine can also be used.

websocketproxy.py has been updated to match the API change. I've also
added a new option --libserver in order to use the Python built in
server instead.

I've done a lot of testing with the new code. This includes: verbose,
daemon, run-once, timeout, idle-timeout, ssl, web, libserver. I've
tested both Python 2 and 3. I've also tested websocket.py in another
external service.

Code details follows:

* The new request handler class is called WebSocketRequestHandler,
  inheriting SimpleHTTPRequestHandler.

* The service engine is called WebSocketServer, just like before.

* do_websocket_handshake: Using send_header() etc, instead of manually
  sending HTTP response.

* A new method called handle_websocket() upgrades the connection to
  WebSocket, if requested. Otherwise, it returns False. A typical
  application use is:

    def do_GET(self):
        if not self.handle_websocket():
	   # handle normal requests

* new_client has been renamed to new_websocket_client, in order to
  have a better name in the SocketServer/HTTPServer request handler
  hierarchy.

* Note that in the request handler, configuration variables must be
  provided by the "server" object, ie self.server.target_host.
This commit is contained in:
Peter Åstrand (astrand) 2013-03-14 16:07:40 +01:00
parent 208f83b9a2
commit 7b3dd8a6f5
2 changed files with 250 additions and 153 deletions

View File

@ -64,26 +64,32 @@ if multiprocessing and sys.platform == 'win32':
import multiprocessing.reduction
class WebSocketServer(object):
"""
WebSockets server class.
Must be sub-classed with new_client method definition.
"""
# HTTP handler with WebSocket upgrade support
class WebSocketRequestHandler(SimpleHTTPRequestHandler):
buffer_size = 65536
GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
server_handshake_hybi = """HTTP/1.1 101 Switching Protocols\r
Upgrade: websocket\r
Connection: Upgrade\r
Sec-WebSocket-Accept: %s\r
"""
server_version = "WebSockify"
protocol_version = "HTTP/1.1"
# An exception while the WebSocket client was connected
class CClose(Exception):
pass
def __init__(self, req, addr, server):
# Retrieve a few configuration variables from the server
self.only_upgrade = getattr(server, "only_upgrade", False)
self.verbose = getattr(server, "verbose", False)
self.daemon = getattr(server, "daemon", False)
self.record = getattr(server, "record", False)
self.run_once = getattr(server, "run_once", False)
self.rec = None
self.handler_id = getattr(server, "handler_id", False)
SimpleHTTPRequestHandler.__init__(self, req, addr, server)
@staticmethod
def unmask(buf, hlen, plen):
pstart = hlen + 4
@ -204,7 +210,7 @@ Sec-WebSocket-Accept: %s\r
# Process 1 frame
if f['masked']:
# unmask payload
f['payload'] = WebSocketServer.unmask(buf, f['hlen'],
f['payload'] = WebSocketRequestHandler.unmask(buf, f['hlen'],
f['length'])
else:
print("Unmasked frame: %s" % repr(buf))
@ -227,9 +233,18 @@ Sec-WebSocket-Accept: %s\r
return f
#
# Main WebSocketServer methods
# WebSocketRequestHandler logging/output functions
#
def traffic(self, token="."):
""" Show traffic flow in verbose mode. """
if self.verbose and not self.daemon:
sys.stdout.write(token)
sys.stdout.flush()
#
# Main WebSocketRequestHandler methods
#
def send_frames(self, bufs=None):
""" Encode and send WebSocket frames. Any frames already
queued will be sent first. If buf is not set then only queued
@ -311,7 +326,7 @@ Sec-WebSocket-Accept: %s\r
start = frame['hlen']
end = frame['hlen'] + frame['length']
if frame['masked']:
recbuf = WebSocketServer.unmask(buf, frame['hlen'],
recbuf = WebSocketRequestHandler.unmask(buf, frame['hlen'],
frame['length'])
else:
recbuf = buf[frame['hlen']:frame['hlen'] +
@ -336,9 +351,8 @@ Sec-WebSocket-Accept: %s\r
buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False)
self.request.send(buf)
def do_websocket_handshake(self, headers, path):
h = self.headers = headers
self.path = path
def do_websocket_handshake(self):
h = self.headers
prot = 'WebSocket-Protocol'
protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',')
@ -353,7 +367,8 @@ Sec-WebSocket-Accept: %s\r
if ver in ['7', '8', '13']:
self.version = "hybi-%02d" % int(ver)
else:
raise self.EClose('Unsupported protocol version %s' % ver)
self.send_error(400, "Unsupported protocol version %s" % ver)
return False
key = h['Sec-WebSocket-Key']
@ -363,26 +378,130 @@ Sec-WebSocket-Accept: %s\r
elif 'base64' in protocols:
self.base64 = True
else:
raise self.EClose("Client must support 'binary' or 'base64' protocol")
self.send_error(400, "Client must support 'binary' or 'base64' protocol")
return False
# Generate the hash value for the accept header
accept = b64encode(sha1(s2b(key + self.GUID)).digest())
response = self.server_handshake_hybi % b2s(accept)
self.send_response(101, "Switching Protocols")
self.send_header("Upgrade", "websocket")
self.send_header("Connection", "Upgrade")
self.send_header("Sec-WebSocket-Accept", b2s(accept))
if self.base64:
response += "Sec-WebSocket-Protocol: base64\r\n"
self.send_header("Sec-WebSocket-Protocol", "base64")
else:
response += "Sec-WebSocket-Protocol: binary\r\n"
response += "\r\n"
self.send_header("Sec-WebSocket-Protocol", "binary")
self.end_headers()
return True
else:
self.send_error(400, "Missing Sec-WebSocket-Version header. Hixie protocols not supported.")
return False
def handle_websocket(self):
"""Upgrade a connection to Websocket, if requested. If this succeeds,
new_websocket_client() will be called. Otherwise, False is returned.
"""
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
if not self.do_websocket_handshake():
return False
# Indicate to server that a Websocket upgrade was done
self.server.ws_connection = True
# Initialize per client settings
self.send_parts = []
self.recv_part = None
self.start_time = int(time.time()*1000)
# client_address is empty with, say, UNIX domain sockets
client_addr = ""
is_ssl = False
try:
client_addr = self.client_address[0]
is_ssl = self.client_address[2]
except IndexError:
pass
if is_ssl:
self.stype = "SSL/TLS (wss://)"
else:
self.stype = "Plain non-SSL (ws://)"
self.log_message("%s: %s WebSocket connection" % (client_addr,
self.stype))
self.log_message("%s: Version %s, base64: '%s'" % (client_addr,
self.version, self.base64))
if self.path != '/':
self.log_message("%s: Path: '%s'" % (client_addr, self.path))
if self.record:
# Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record,
self.handler_id)
self.log_message("opening record file: %s" % fname)
self.rec = open(fname, 'w+')
encoding = "binary"
if self.base64: encoding = "base64"
self.rec.write("var VNC_frame_encoding = '%s';\n"
% encoding)
self.rec.write("var VNC_frame_data = [\n")
try:
self.new_websocket_client()
except self.CClose:
# Close the client
_, exc, _ = sys.exc_info()
self.send_close(exc.args[0], exc.args[1])
return True
else:
raise self.EClose("Missing Sec-WebSocket-Version header. Hixie protocols not supported.")
return False
return response
def new_client(self):
def do_GET(self):
"""Handle GET request. Calls handle_websocket(). If unsuccessful,
and web server is enabled, SimpleHTTPRequestHandler.do_GET will be called."""
if not self.handle_websocket():
if self.only_upgrade:
self.send_error(405, "Method Not Allowed")
else:
SimpleHTTPRequestHandler.do_GET(self)
def new_websocket_client(self):
""" Do something with a WebSockets client connection. """
raise("WebSocketServer.new_client() must be overloaded")
raise("WebSocketRequestHandler.new_websocket_client() must be overloaded")
def do_HEAD(self):
if self.only_upgrade:
self.send_error(405, "Method Not Allowed")
else:
SimpleHTTPRequestHandler.do_HEAD(self)
def finish(self):
if self.rec:
self.rec.write("'EOF'];\n")
self.rec.close()
def handle(self):
# When using run_once, we have a single process, so
# we cannot loop in BaseHTTPRequestHandler.handle; we
# must return and handle new connections
if self.run_once:
self.handle_one_request()
else:
SimpleHTTPRequestHandler.handle(self)
def log_request(self, code='-', size='-'):
if self.verbose:
SimpleHTTPRequestHandler.log_request(self, code, size)
class WebSocketServer(object):
"""
WebSockets server class.
Must be sub-classed with new_client method definition.
"""
policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n"""
@ -390,12 +509,14 @@ Sec-WebSocket-Accept: %s\r
class EClose(Exception):
pass
def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False,
def __init__(self, RequestHandlerClass, listen_host='',
listen_port=None, source_is_ipv6=False,
verbose=False, cert='', key='', ssl_only=None,
daemon=False, record='', web='',
run_once=False, timeout=0, idle_timeout=0):
# settings
self.RequestHandlerClass = RequestHandlerClass
self.verbose = verbose
self.listen_host = listen_host
self.listen_port = listen_port
@ -422,6 +543,7 @@ Sec-WebSocket-Accept: %s\r
if self.web:
os.chdir(self.web)
self.only_upgrade = not self.web
# Sanity checks
if not ssl and self.ssl_only:
@ -548,7 +670,6 @@ Sec-WebSocket-Accept: %s\r
- Send a WebSockets handshake server response.
- Return the socket for this WebSocket client.
"""
stype = ""
ready = select.select([sock], [], [], 3)[0]
@ -592,42 +713,19 @@ Sec-WebSocket-Accept: %s\r
else:
raise
self.scheme = "wss"
stype = "SSL/TLS (wss://)"
elif self.ssl_only:
raise self.EClose("non-SSL connection received but disallowed")
else:
retsock = sock
self.scheme = "ws"
stype = "Plain non-SSL (ws://)"
wsh = WSRequestHandler(retsock, address, not self.web)
if wsh.last_code == 101:
# Continue on to handle WebSocket upgrade
pass
elif wsh.last_code == 405:
raise self.EClose("Normal web request received but disallowed")
elif wsh.last_code < 200 or wsh.last_code >= 300:
raise self.EClose(wsh.last_message)
elif self.verbose:
raise self.EClose(wsh.last_message)
else:
raise self.EClose("")
# If the address is like (host, port), we are extending it
# with a flag indicating SSL. Not many other options
# available...
if len(address) == 2:
address = (address[0], address[1], (retsock != sock))
response = self.do_websocket_handshake(wsh.headers, wsh.path)
self.msg("%s: %s WebSocket connection" % (address[0], stype))
self.msg("%s: Version %s, base64: '%s'" % (address[0],
self.version, self.base64))
if self.path != '/':
self.msg("%s: Path: '%s'" % (address[0], self.path))
# Send server WebSockets handshake response
#self.msg("sending response [%s]" % response)
retsock.send(s2b(response))
self.RequestHandlerClass(retsock, address, self)
# Return the WebSockets socket which may be SSL wrapped
return retsock
@ -636,13 +734,6 @@ Sec-WebSocket-Accept: %s\r
#
# WebSocketServer logging/output functions
#
def traffic(self, token="."):
""" Show traffic flow in verbose mode. """
if self.verbose and not self.daemon:
sys.stdout.write(token)
sys.stdout.flush()
def msg(self, msg):
""" Output message with handler_id prefix. """
if not self.daemon:
@ -683,37 +774,11 @@ Sec-WebSocket-Accept: %s\r
def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """
# Initialize per client settings
self.send_parts = []
self.recv_part = None
self.base64 = False
self.rec = None
self.start_time = int(time.time()*1000)
# handler process
client = None
try:
try:
self.request = self.do_handshake(startsock, address)
if self.record:
# Record raw frame data as JavaScript array
fname = "%s.%s" % (self.record,
self.handler_id)
self.msg("opening record file: %s" % fname)
self.rec = open(fname, 'w+')
encoding = "binary"
if self.base64: encoding = "base64"
self.rec.write("var VNC_frame_encoding = '%s';\n"
% encoding)
self.rec.write("var VNC_frame_data = [\n")
self.ws_connection = True
self.new_client()
except self.CClose:
# Close the client
_, exc, _ = sys.exc_info()
if self.request:
self.send_close(exc.args[0], exc.args[1])
client = self.do_handshake(startsock, address)
except self.EClose:
_, exc, _ = sys.exc_info()
# Connection was not a WebSockets connection
@ -725,14 +790,11 @@ Sec-WebSocket-Accept: %s\r
if self.verbose:
self.msg(traceback.format_exc())
finally:
if self.rec:
self.rec.write("'EOF'];\n")
self.rec.close()
if self.request and self.request != startsock:
if client and client != startsock:
# Close the SSL wrapped socket
# Original socket closed by caller
self.request.close()
client.close()
def start_server(self):
"""
@ -758,7 +820,6 @@ Sec-WebSocket-Accept: %s\r
while True:
try:
try:
self.request = None
startsock = None
pid = err = 0
child_count = 0
@ -855,33 +916,3 @@ Sec-WebSocket-Accept: %s\r
self.vmsg("Closing socket listening at %s:%s"
% (self.listen_host, self.listen_port))
lsock.close()
# HTTP handler with WebSocket upgrade support
class WSRequestHandler(SimpleHTTPRequestHandler):
def __init__(self, req, addr, only_upgrade=False):
self.only_upgrade = only_upgrade # only allow upgrades
SimpleHTTPRequestHandler.__init__(self, req, addr, object())
def do_GET(self):
if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):
# Just indicate that an WebSocket upgrade is needed
self.last_code = 101
self.last_message = "101 Switching Protocols"
elif self.only_upgrade:
# Normal web request responses are disabled
self.last_code = 405
self.last_message = "405 Method Not Allowed"
else:
SimpleHTTPRequestHandler.do_GET(self)
def send_response(self, code, message=None):
# Save the status code
self.last_code = code
SimpleHTTPRequestHandler.send_response(self, code, message)
def log_message(self, f, *args):
# Save instead of printing
self.last_message = f % args

View File

@ -12,6 +12,8 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
import signal, socket, optparse, time, os, sys, subprocess
import SocketServer, BaseHTTPServer
from SimpleHTTPServer import SimpleHTTPRequestHandler
from select import select
import websocket
try:
@ -20,7 +22,7 @@ except:
from cgi import parse_qs
from urlparse import urlparse
class WebSocketProxy(websocket.WebSocketServer):
class CustomProxyServer(websocket.WebSocketServer):
"""
Proxy traffic to and from a WebSockets client to a normal TCP
socket server target. All traffic to/from the client is base64
@ -140,6 +142,8 @@ class WebSocketProxy(websocket.WebSocketServer):
# process.
#
class ProxyRequestHandler(websocket.WebSocketRequestHandler):
#
# Routines below this point are connection handler routines and
# will be run in a separate forked process for each connection.
@ -150,38 +154,38 @@ Traffic Legend:
} - Client receive
}. - Client receive partial
{ - Target receive
> - Target send
>. - Target send partial
< - Client send
<. - Client send partial
"""
def new_client(self):
def new_websocket_client(self):
"""
Called after a new WebSocket connection has been established.
"""
# Checks if we receive a token, and look
# for a valid target for it then
if self.target_cfg:
(self.target_host, self.target_port) = self.get_target(self.target_cfg, self.path)
if self.server.target_cfg:
(self.server.target_host, self.server.target_port) = self.get_target(self.server.target_cfg, self.path)
# Connect to the target
if self.wrap_cmd:
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port)
elif self.unix_target:
msg = "connecting to unix socket: %s" % self.unix_target
if self.server.wrap_cmd:
msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port)
elif self.server.unix_target:
msg = "connecting to unix socket: %s" % self.server.unix_target
else:
msg = "connecting to: %s:%s" % (
self.target_host, self.target_port)
self.server.target_host, self.server.target_port)
if self.ssl_target:
if self.server.ssl_target:
msg += " (using SSL)"
self.msg(msg)
self.log_message(msg)
tsock = websocket.WebSocketServer.socket(self.target_host,
self.target_port,
connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target)
tsock = websocket.WebSocketServer.socket(self.server.target_host,
self.server.target_port,
connect=True, use_ssl=self.server.ssl_target, unix_socket=self.server.unix_target)
if self.verbose and not self.daemon:
print(self.traffic_legend)
@ -193,8 +197,9 @@ Traffic Legend:
if tsock:
tsock.shutdown(socket.SHUT_RDWR)
tsock.close()
self.vmsg("%s:%s: Closed target" %(
self.target_host, self.target_port))
if self.verbose:
self.log_message("%s:%s: Closed target" %(
self.server.target_host, self.server.target_port))
raise
def get_target(self, target_cfg, path):
@ -266,8 +271,9 @@ Traffic Legend:
if closed:
# TODO: What about blocking on client socket?
self.vmsg("%s:%s: Client closed connection" %(
self.target_host, self.target_port))
if self.verbose:
self.log_message("%s:%s: Client closed connection" %(
self.server.target_host, self.server.target_port))
raise self.CClose(closed['code'], closed['reason'])
@ -287,8 +293,9 @@ Traffic Legend:
# Receive target data, encode it and queue for client
buf = target.recv(self.buffer_size)
if len(buf) == 0:
self.vmsg("%s:%s: Target closed connection" %(
self.target_host, self.target_port))
if self.verbose:
self.log_message("%s:%s: Target closed connection" %(
self.server.target_host, self.server.target_port))
raise self.CClose(1000, "Target closed")
cqueue.append(buf)
@ -346,6 +353,8 @@ def websockify_init():
help="Configuration file containing valid targets "
"in the form 'token: host:port' or, alternatively, a "
"directory containing configuration files of this form")
parser.add_option("--libserver", action="store_true",
help="use Python library SocketServer engine")
(opts, args) = parser.parse_args()
# Sanity checks
@ -387,8 +396,65 @@ def websockify_init():
except: parser.error("Error parsing target port")
# Create and start the WebSockets proxy
server = WebSocketProxy(**opts.__dict__)
server.start_server()
libserver = opts.libserver
del opts.libserver
if libserver:
# Use standard Python SocketServer framework
httpd = LibProxyServer(ProxyRequestHandler, **opts.__dict__)
httpd.serve_forever()
else:
# Use internal service framework
server = CustomProxyServer(ProxyRequestHandler, **opts.__dict__)
server.start_server()
class LibProxyServer(SocketServer.ForkingMixIn, BaseHTTPServer.HTTPServer):
"""
Just like CustomProxyServer, but uses standard Python SocketServer
framework.
"""
def __init__(self, RequestHandlerClass, **kwargs):
# Save off proxy specific options
self.target_host = kwargs.pop('target_host', None)
self.target_port = kwargs.pop('target_port', None)
self.wrap_cmd = kwargs.pop('wrap_cmd', None)
self.wrap_mode = kwargs.pop('wrap_mode', None)
self.unix_target = kwargs.pop('unix_target', None)
self.ssl_target = kwargs.pop('ssl_target', None)
self.target_cfg = kwargs.pop('target_cfg', None)
self.daemon = False
self.target_cfg = None
# Server configuration
listen_host = kwargs.pop('listen_host', '')
listen_port = kwargs.pop('listen_port', None)
web = kwargs.pop('web', '')
# Configuration affecting base request handler
self.only_upgrade = not web
self.verbose = kwargs.pop('verbose', False)
record = kwargs.pop('record', '')
if record:
self.record = os.path.abspath(record)
self.run_once = kwargs.pop('run_once', False)
self.handler_id = 0
for arg in kwargs.keys():
print("warning: option %s ignored when using --libserver" % arg)
if web:
os.chdir(web)
BaseHTTPServer.HTTPServer.__init__(self, (listen_host, listen_port),
RequestHandlerClass)
def process_request(self, request, client_address):
"""Override process_request to implement a counter"""
self.handler_id += 1
SocketServer.ForkingMixIn.process_request(self, request, client_address)
if __name__ == '__main__':
websockify_init()