websockify/websocket.py

426 lines
14 KiB
Python
Executable File

#!/usr/bin/python
'''
Python WebSocket library with support for "wss://" encryption.
Copyright 2010 Joel Martin
Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
You can make a cert/key with openssl using:
openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
import sys, socket, ssl, struct, traceback, select
import os, resource, errno, signal # daemonizing
from SimpleHTTPServer import SimpleHTTPRequestHandler
from cStringIO import StringIO
from base64 import b64encode, b64decode
try:
from hashlib import md5
except:
from md5 import md5 # Support python 2.4
from urlparse import urlsplit
from cgi import parse_qsl
class WebSocketServer(object):
    """
    WebSockets server class.

    Listens on a TCP port, answers Flash policy requests and (when a web
    directory is configured) plain HTTP GET requests, performs the
    WebSockets version 75/76 handshake (optionally over SSL/TLS) and
    hands every established connection to new_client() in a forked
    handler process.

    Must be sub-classed with new_client method definition.
    """

    # Handshake response template. Fields, in order: header prefix,
    # origin, header prefix, scheme, host, path, header prefix, trailer.
    server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r
Connection: Upgrade\r
%sWebSocket-Origin: %s\r
%sWebSocket-Location: %s://%s%s\r
%sWebSocket-Protocol: sample\r
\r
%s"""

    # Reply sent to a Flash "<policy-file-request/>"
    policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n"""

    class EClose(Exception):
        """
        Raised by do_handshake() to drop a connection. start_server()
        logs the message (when it is non-empty) and closes the socket.
        """
        pass

    def __init__(self, listen_host='', listen_port=None,
            verbose=False, cert='', key='', ssl_only=None,
            daemon=False, record='', web=''):
        """
        Store the settings and print a summary of them.

        listen_host/listen_port: address that start_server() binds to.
        verbose:  enable vmsg() and traffic() output.
        cert/key: certificate and key files given to ssl.wrap_socket().
        ssl_only: reject connections that are not SSL/TLS.
        daemon:   background the process in start_server().
        record:   stored (as an absolute path) but not used by this class.
        web:      directory to change into; also enables answering
                  normal web requests.

        Note: changes the working directory when web is set.
        """
        # settings
        self.verbose        = verbose
        self.listen_host    = listen_host
        self.listen_port    = listen_port
        self.ssl_only       = ssl_only
        self.daemon         = daemon

        # Make paths settings absolute. Unset paths stay '':
        # os.path.abspath('') is the current directory, which always
        # exists and would be mistaken for an available cert file.
        self.cert = self.key = self.web = self.record = ''
        if cert:
            self.cert = os.path.abspath(cert)
        if key:
            self.key = os.path.abspath(key)
        if web:
            self.web = os.path.abspath(web)
        if record:
            self.record = os.path.abspath(record)

        # Only after the paths above have been made absolute
        if self.web:
            os.chdir(self.web)

        self.handler_id = 1

        # Single-argument print(...) emits the same text whether print
        # is a statement or a function.
        print("WebSocket server settings:")
        print("  - Listen on %s:%s" % (
                self.listen_host, self.listen_port))
        print("  - Flash security policy server")
        if self.web:
            print("  - Web server")
        if os.path.exists(self.cert):
            print("  - SSL/TLS support")
            if self.ssl_only:
                print("  - Deny non-SSL/TLS connections")
        else:
            print("  - No SSL/TLS support (no cert file)")
        if self.daemon:
            print("  - Backgrounding (daemon)")

    #
    # WebSocketServer static methods
    #
    @staticmethod
    def daemonize(self, keepfd=None):
        """
        Detach the current process and run it in the background.

        This is a static method that still takes the server as its
        first argument, so it has to be called as
        daemonize(server, keepfd=...). Every file descriptor except
        keepfd is closed and stdin/stdout/stderr are pointed at
        os.devnull.
        """
        os.umask(0)
        if self.web:
            os.chdir(self.web)
        else:
            os.chdir('/')
        os.setgid(os.getgid())  # relinquish elevations
        os.setuid(os.getuid())  # relinquish elevations

        # Double fork to daemonize
        if os.fork() > 0: os._exit(0)  # Parent exits
        os.setsid()                    # Obtain new process group
        if os.fork() > 0: os._exit(0)  # Parent exits

        # Signal handling
        def terminate(a, b): os._exit(0)
        signal.signal(signal.SIGTERM, terminate)
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        # Close open files
        maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
        if maxfd == resource.RLIM_INFINITY: maxfd = 256
        for fd in reversed(range(maxfd)):
            try:
                if fd != keepfd:
                    os.close(fd)
            except OSError:
                # Descriptors that were never open are expected here
                exc = sys.exc_info()[1]
                if exc.errno != errno.EBADF: raise

        # Redirect I/O to /dev/null
        os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
        os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
        os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())

    @staticmethod
    def encode(buf):
        """ Encode a WebSocket packet: base64 between \\x00 and \\xff. """
        buf = b64encode(buf)
        return "\x00%s\xff" % buf

    @staticmethod
    def decode(buf):
        """
        Decode WebSocket packets.

        Returns a list with one base64-decoded payload per frame in buf.
        """
        if buf.count('\xff') > 1:
            # split() leaves an empty fragment after the final '\xff';
            # drop it so it is not reported as an extra empty packet.
            # A real empty frame is '\x00' here and is kept.
            return [b64decode(d[1:]) for d in buf.split('\xff') if d]
        else:
            return [b64decode(buf[1:-1])]

    @staticmethod
    def parse_handshake(handshake):
        """
        Parse fields from client WebSockets handshake.

        Returns a dict holding 'path', one entry per "Name: value"
        header line and, when the request has a body after the blank
        line, that body as 'key3'. Raises Exception when the request
        line is not a GET.
        """
        ret = {}
        req_lines = handshake.split("\r\n")
        if not req_lines[0].startswith("GET "):
            raise Exception("Invalid handshake: no GET request line")
        ret['path'] = req_lines[0].split(" ")[1]
        for line in req_lines[1:]:
            if line == "": break
            # Split on the first separator only: header values (e.g. an
            # Origin) may themselves contain ": ".
            parts = line.split(": ", 1)
            if len(parts) != 2:
                continue  # no separator, nothing to record
            ret[parts[0]] = parts[1]

        # A one-line request has no second-to-last line to look at
        if len(req_lines) > 1 and req_lines[-2] == "":
            ret['key3'] = req_lines[-1]

        return ret

    @staticmethod
    def gen_md5(keys):
        """
        Generate hash value for WebSockets handshake v76.

        keys is a parse_handshake() result. Raises Exception when one of
        the two key headers contains no space.
        """
        key1 = keys['Sec-WebSocket-Key1']
        key2 = keys['Sec-WebSocket-Key2']
        key3 = keys['key3']
        spaces1 = key1.count(" ")
        spaces2 = key2.count(" ")
        if spaces1 == 0 or spaces2 == 0:
            # Otherwise this surfaces as an unexplained ZeroDivisionError
            raise Exception("Invalid handshake: key without spaces")
        # '//' keeps this an integer division on every Python version;
        # struct.pack('>I', ...) needs integers.
        num1 = int("".join([c for c in key1 if c.isdigit()])) // spaces1
        num2 = int("".join([c for c in key2 if c.isdigit()])) // spaces2

        return md5(struct.pack('>II8s', num1, num2, key3)).digest()

    #
    # WebSocketServer logging/output functions
    #

    def traffic(self, token="."):
        """ Show traffic flow in verbose mode. """
        if self.verbose and not self.daemon:
            sys.stdout.write(token)
            sys.stdout.flush()

    def msg(self, msg):
        """ Output message with handler_id prefix. """
        if not self.daemon:
            print("% 3d: %s" % (self.handler_id, msg))

    def vmsg(self, msg):
        """ Same as msg() but only if verbose. """
        if self.verbose:
            self.msg(msg)

    #
    # Main WebSocketServer methods
    #
    def do_handshake(self, sock, address):
        """
        do_handshake does the following:
        - Peek at the first few bytes from the socket.
        - If the connection is Flash policy request then answer it,
          close the socket and return.
        - If the connection is an HTTPS/SSL/TLS connection then SSL
          wrap the socket.
        - Read from the (possibly wrapped) socket.
        - If we have received a HTTP GET request and the webserver
          functionality is enabled, answer it, close the socket and
          return.
        - Assume we have a WebSockets connection, parse the client
          handshake data.
        - Send a WebSockets handshake server response.
        - Return the socket for this WebSocket client.

        Every "close the socket and return" case is signalled by raising
        self.EClose; the caller closes the socket.
        """

        stype = ""

        # Peek, but don't read the data
        handshake = sock.recv(1024, socket.MSG_PEEK)
        #self.msg("Handshake [%s]" % repr(handshake))

        if handshake == "":
            raise self.EClose("ignoring empty handshake")

        elif handshake.startswith("<policy-file-request/>"):
            # Answer Flash policy request
            handshake = sock.recv(1024)
            # sendall(): send() may write only part of the data
            sock.sendall(self.policy_response)
            raise self.EClose("Sending flash policy response")

        elif handshake[0] in ("\x16", "\x80"):
            # SSL wrap the connection
            if not os.path.exists(self.cert):
                raise self.EClose("SSL connection but '%s' not found"
                                  % self.cert)
            try:
                retsock = ssl.wrap_socket(
                        sock,
                        server_side=True,
                        certfile=self.cert,
                        keyfile=self.key)
            except ssl.SSLError:
                x = sys.exc_info()[1]
                if x.args[0] == ssl.SSL_ERROR_EOF:
                    raise self.EClose("")
                else:
                    raise

            scheme = "wss"
            stype = "SSL/TLS (wss://)"

        elif self.ssl_only:
            raise self.EClose("non-SSL connection received but disallowed")

        else:
            retsock = sock
            scheme = "ws"
            stype = "Plain non-SSL (ws://)"

        # Now get the data from the socket
        handshake = retsock.recv(4096)
        #self.msg("handshake: " + repr(handshake))

        if len(handshake) == 0:
            raise self.EClose("Client closed during handshake")

        # Check for and handle normal web requests
        if handshake.startswith('GET ') and \
                handshake.find('Upgrade: WebSocket\r\n') == -1:
            if not self.web:
                raise self.EClose("Normal web request received but disallowed")
            sh = SplitHTTPHandler(handshake, retsock, address)
            if sh.last_code < 200 or sh.last_code >= 300:
                raise self.EClose(sh.last_message)
            elif self.verbose:
                raise self.EClose(sh.last_message)
            else:
                raise self.EClose("")

        # Parse client WebSockets handshake
        h = self.parse_handshake(handshake)

        # The response below cannot be built without these two headers
        for field in ('Origin', 'Host'):
            if field not in h:
                raise self.EClose("Invalid handshake: no %s header" % field)

        if h.get('key3'):
            trailer = self.gen_md5(h)
            pre = "Sec-"
            ver = 76
        else:
            trailer = ""
            pre = ""
            ver = 75

        self.msg("%s: %s WebSocket connection (version %s)"
                    % (address[0], stype, ver))

        # Send server WebSockets handshake response
        response = self.server_handshake % (pre, h['Origin'], pre,
                scheme, h['Host'], h['path'], pre, trailer)

        #self.msg("sending response:", repr(response))
        retsock.sendall(response)

        # Return the WebSockets socket which may be SSL wrapped
        return retsock

    #
    # Events that can/should be overridden in sub-classes
    #
    def started(self):
        """ Called after WebSockets startup """
        self.vmsg("WebSockets server started")

    def poll(self):
        """ Run periodically while waiting for connections. """
        self.msg("Running poll()")

    def do_SIGCHLD(self, sig, stack):
        """ SIGCHLD handler. This default only logs the signal. """
        self.vmsg("Got SIGCHLD, ignoring")

    def do_SIGINT(self, sig, stack):
        """ SIGINT handler: log and exit the process. """
        self.msg("Got SIGINT, exiting")
        sys.exit(0)

    def new_client(self, client):
        """ Do something with a WebSockets client connection. """
        # Raising a plain string is itself an error on Python 2.6+
        raise NotImplementedError(
                "WebSocketServer.new_client() must be overloaded")

    def start_server(self):
        """
        Daemonize if requested. Listen for connections. Run
        do_handshake() method for each connection. If the connection
        is a WebSockets client then call new_client() method (which must
        be overridden) for each new client connection.

        Each accepted connection is handled in a forked child process.
        The child returns from this method when its connection is done.
        """

        lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        lsock.bind((self.listen_host, self.listen_port))
        lsock.listen(100)

        if self.daemon:
            self.daemonize(self, keepfd=lsock.fileno())

        self.started()  # Some things need to happen after daemonizing

        # Reap zombies
        # NOTE(review): the default do_SIGCHLD() only logs and never
        # wait()s on the child - confirm that is sufficient here.
        signal.signal(signal.SIGCHLD, self.do_SIGCHLD)
        signal.signal(signal.SIGINT, self.do_SIGINT)

        while True:
            # try/except nested in try/finally (instead of one combined
            # statement) is kept from the original.
            try:
                try:
                    csock = startsock = None
                    pid = err = 0

                    try:
                        self.poll()

                        # Wait at most one second so poll() keeps running
                        ready = select.select([lsock], [], [], 1)[0]
                        if lsock in ready:
                            startsock, address = lsock.accept()
                        else:
                            continue
                    except Exception:
                        exc = sys.exc_info()[1]
                        if hasattr(exc, 'errno'):
                            err = exc.errno
                        elif exc.args:
                            err = exc.args[0]
                        else:
                            # exc[0] would raise IndexError here and
                            # hide the real exception
                            err = None
                        if err == errno.EINTR:
                            self.vmsg("Ignoring interrupted syscall")
                            continue
                        else:
                            raise

                    self.vmsg('%s: forking handler' % address[0])
                    pid = os.fork()

                    if pid == 0:
                        # handler process
                        csock = self.do_handshake(startsock, address)
                        self.new_client(csock)
                    else:
                        # parent process
                        self.handler_id += 1

                except self.EClose:
                    # Connection was not a WebSockets connection
                    exc = sys.exc_info()[1]
                    if exc.args[0]:
                        self.msg("%s: %s" % (address[0], exc.args[0]))
                except KeyboardInterrupt:
                    pass
                except Exception:
                    exc = sys.exc_info()[1]
                    self.msg("handler exception: %s" % str(exc))
                    if self.verbose:
                        self.msg(traceback.format_exc())

            finally:
                if csock and csock != startsock:
                    csock.close()
                if startsock:
                    startsock.close()

            # pid is also still 0 in the parent when an exception was
            # raised before fork(); the parent leaves the loop then too.
            if pid == 0:
                break  # Child process exits
# HTTP handler with request from a string and response to a socket
class SplitHTTPHandler(SimpleHTTPRequestHandler):
    """
    SimpleHTTPRequestHandler that reads an already-received request from
    a string and writes the response to a socket.

    The request is handled during construction (the base class
    constructor runs setup(), handle() and finish()). Afterwards
    last_code holds the last status code sent and last_message the last
    log line produced.
    """

    # Defaults, so both attributes can always be read even when the base
    # class finished without calling send_response() or log_message().
    last_code = 0
    last_message = ''

    def __init__(self, req, resp, addr):
        """
        req:  the raw HTTP request text.
        resp: socket the response is written to.
        addr: client address, passed on to the base class.
        """
        # Save the response socket
        self.response = resp
        # object() stands in for the server argument of the base class
        SimpleHTTPRequestHandler.__init__(self, req, addr, object())

    def setup(self):
        """ Read from the request string, write to the response socket. """
        self.connection = self.response
        # Duck type request string to file object
        self.rfile = StringIO(self.request)
        self.wfile = self.connection.makefile('wb', self.wbufsize)

    def send_response(self, code, message=None):
        """ Send the response as usual, remembering its status code. """
        # Save the status code
        self.last_code = code
        SimpleHTTPRequestHandler.send_response(self, code, message)

    def log_message(self, f, *args):
        """ Keep the formatted log line instead of printing it. """
        # Save instead of printing
        self.last_message = f % args