Refactor and cleanup websocket.py and deps.
Moved websocket.py code into a class WebSocketServer. WebSockets server implementations will sub-class and define a handler() method which is passed the client socket after. Global variable settings have been changed to be parameters for WebSocketServer when created. Subclass implementations still have to handle queueing and sending but the parent class handles everything else (daemonizing, websocket handshake, encode/decode, etc). It would be better if the parent class could handle queueing and sending. This adds some buffering and polling complexity to the parent class but it would be better to do so at some point. However, the result is still much cleaner as can be seen in wsecho.py. Refactored wsproxy.py and wstest.py (formerly ws.py) to use the new class. Added wsecho.py as a simple echo server. - rename tests/ws.py to utils/wstest.py and add a symlink from tests/wstest.py - rename tests/ws.html to tests/wstest.html to match utils/wstest.py. - add utils/wsecho.py - add tests/wsecho.html which communicates with wsecho.py and simply sends periodic messages and shows what is received.
This commit is contained in:
parent
6ace64d3ae
commit
6a88340929
166
tests/ws.py
166
tests/ws.py
|
@ -1,166 +0,0 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
'''
|
||||
WebSocket server-side load test program. Sends and receives traffic
|
||||
that has a random payload (length and content) that is checksummed and
|
||||
given a sequence number. Any errors are reported and counted.
|
||||
'''
|
||||
|
||||
import sys, os, socket, ssl, time, traceback
|
||||
import random, time
|
||||
from base64 import b64encode, b64decode
|
||||
from select import select
|
||||
|
||||
sys.path.insert(0,os.path.dirname(__file__) + "/../utils/")
|
||||
from websocket import *
|
||||
|
||||
buffer_size = 65536
|
||||
max_packet_size = 10000
|
||||
recv_cnt = send_cnt = 0
|
||||
|
||||
|
||||
def check(buf):
|
||||
global recv_cnt
|
||||
|
||||
try:
|
||||
data_list = decode(buf)
|
||||
except:
|
||||
print "\n<BOF>" + repr(buf) + "<EOF>"
|
||||
return "Failed to decode"
|
||||
|
||||
err = ""
|
||||
for data in data_list:
|
||||
if data.count('$') > 1:
|
||||
raise Exception("Multiple parts within single packet")
|
||||
if len(data) == 0:
|
||||
traffic("_")
|
||||
continue
|
||||
|
||||
if data[0] != "^":
|
||||
err += "buf did not start with '^'\n"
|
||||
continue
|
||||
|
||||
try:
|
||||
cnt, length, chksum, nums = data[1:-1].split(':')
|
||||
cnt = int(cnt)
|
||||
length = int(length)
|
||||
chksum = int(chksum)
|
||||
except:
|
||||
print "\n<BOF>" + repr(data) + "<EOF>"
|
||||
err += "Invalid data format\n"
|
||||
continue
|
||||
|
||||
if recv_cnt != cnt:
|
||||
err += "Expected count %d but got %d\n" % (recv_cnt, cnt)
|
||||
recv_cnt = cnt + 1
|
||||
continue
|
||||
|
||||
recv_cnt += 1
|
||||
|
||||
if len(nums) != length:
|
||||
err += "Expected length %d but got %d\n" % (length, len(nums))
|
||||
continue
|
||||
|
||||
inv = nums.translate(None, "0123456789")
|
||||
if inv:
|
||||
err += "Invalid characters found: %s\n" % inv
|
||||
continue
|
||||
|
||||
real_chksum = 0
|
||||
for num in nums:
|
||||
real_chksum += int(num)
|
||||
|
||||
if real_chksum != chksum:
|
||||
err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum)
|
||||
return err
|
||||
|
||||
|
||||
def generate():
|
||||
global send_cnt, rand_array
|
||||
length = random.randint(10, max_packet_size)
|
||||
numlist = rand_array[max_packet_size-length:]
|
||||
# Error in length
|
||||
#numlist.append(5)
|
||||
chksum = sum(numlist)
|
||||
# Error in checksum
|
||||
#numlist[0] = 5
|
||||
nums = "".join( [str(n) for n in numlist] )
|
||||
data = "^%d:%d:%d:%s$" % (send_cnt, length, chksum, nums)
|
||||
send_cnt += 1
|
||||
|
||||
return encode(data)
|
||||
|
||||
def responder(client, delay=10):
|
||||
global errors
|
||||
cqueue = []
|
||||
cpartial = ""
|
||||
socks = [client]
|
||||
last_send = time.time() * 1000
|
||||
|
||||
while True:
|
||||
ins, outs, excepts = select(socks, socks, socks, 1)
|
||||
if excepts: raise Exception("Socket exception")
|
||||
|
||||
if client in ins:
|
||||
buf = client.recv(buffer_size)
|
||||
if len(buf) == 0: raise Exception("Client closed")
|
||||
#print "Client recv: %s (%d)" % (repr(buf[1:-1]), len(buf))
|
||||
if buf[-1] == '\xff':
|
||||
if cpartial:
|
||||
err = check(cpartial + buf)
|
||||
cpartial = ""
|
||||
else:
|
||||
err = check(buf)
|
||||
if err:
|
||||
traffic("}")
|
||||
errors = errors + 1
|
||||
print err
|
||||
else:
|
||||
traffic(">")
|
||||
else:
|
||||
traffic(".>")
|
||||
cpartial = cpartial + buf
|
||||
|
||||
now = time.time() * 1000
|
||||
if client in outs and now > (last_send + delay):
|
||||
last_send = now
|
||||
#print "Client send: %s" % repr(cqueue[0])
|
||||
client.send(generate())
|
||||
traffic("<")
|
||||
|
||||
def test_handler(client):
|
||||
global errors, delay, send_cnt, recv_cnt
|
||||
|
||||
send_cnt = 0
|
||||
recv_cnt = 0
|
||||
|
||||
try:
|
||||
responder(client, delay)
|
||||
except:
|
||||
print "accumulated errors:", errors
|
||||
errors = 0
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
errors = 0
|
||||
try:
|
||||
if len(sys.argv) < 2: raise
|
||||
listen_port = int(sys.argv[1])
|
||||
if len(sys.argv) == 3:
|
||||
delay = int(sys.argv[2])
|
||||
else:
|
||||
delay = 10
|
||||
except:
|
||||
print "Usage: <listen_port> [delay_ms]"
|
||||
sys.exit(1)
|
||||
|
||||
print "Prepopulating random array"
|
||||
rand_array = []
|
||||
for i in range(0, max_packet_size):
|
||||
rand_array.append(random.randint(0, 9))
|
||||
|
||||
settings['listen_port'] = listen_port
|
||||
settings['daemon'] = False
|
||||
settings['handler'] = test_handler
|
||||
start_server()
|
|
@ -0,0 +1,176 @@
|
|||
<html>
|
||||
|
||||
<head>
|
||||
<title>WebSockets Echo Test</title>
|
||||
<script src="include/base64.js"></script>
|
||||
<script src="include/util.js"></script>
|
||||
<script src="include/webutil.js"></script>
|
||||
<!-- Uncomment to activate firebug lite -->
|
||||
<!--
|
||||
<script type='text/javascript'
|
||||
src='http://getfirebug.com/releases/lite/1.2/firebug-lite-compressed.js'></script>
|
||||
-->
|
||||
|
||||
|
||||
</head>
|
||||
|
||||
<body>
|
||||
|
||||
Host: <input id='host' style='width:100'>
|
||||
Port: <input id='port' style='width:50'>
|
||||
Encrypt: <input id='encrypt' type='checkbox'>
|
||||
<input id='connectButton' type='button' value='Start' style='width:100px'
|
||||
onclick="connect();">
|
||||
|
||||
|
||||
<br>
|
||||
Log:<br>
|
||||
<textarea id="messages" style="font-size: 9;" cols=80 rows=25></textarea>
|
||||
</body>
|
||||
|
||||
|
||||
<script>
|
||||
var ws, host = null, port = null,
|
||||
msg_cnt = 0, send_cnt = 1, echoDelay = 500,
|
||||
echo_ref;
|
||||
|
||||
function message(str) {
|
||||
console.log(str);
|
||||
cell = $D('messages');
|
||||
cell.innerHTML += msg_cnt + ": " + str + "\n";
|
||||
cell.scrollTop = cell.scrollHeight;
|
||||
msg_cnt++;
|
||||
}
|
||||
|
||||
Array.prototype.pushStr = function (str) {
|
||||
var n = str.length;
|
||||
for (var i=0; i < n; i++) {
|
||||
this.push(str.charCodeAt(i));
|
||||
}
|
||||
}
|
||||
|
||||
function send_msg() {
|
||||
if (ws.bufferedAmount > 0) {
|
||||
console.log("Delaying send");
|
||||
return;
|
||||
}
|
||||
var str = "Message #" + send_cnt, arr = [];
|
||||
arr.pushStr(str)
|
||||
ws.send(Base64.encode(arr));
|
||||
message("Sent message: '" + str + "'");
|
||||
send_cnt++;
|
||||
}
|
||||
|
||||
function update_stats() {
|
||||
$D('sent').innerHTML = sent;
|
||||
$D('received').innerHTML = received;
|
||||
$D('errors').innerHTML = errors;
|
||||
}
|
||||
|
||||
function init_ws() {
|
||||
console.log(">> init_ws");
|
||||
console.log("<< init_ws");
|
||||
}
|
||||
|
||||
function connect() {
|
||||
var host = $D('host').value,
|
||||
port = $D('port').value,
|
||||
scheme = "ws://", uri;
|
||||
|
||||
console.log(">> connect");
|
||||
if ((!host) || (!port)) {
|
||||
console.log("must set host and port");
|
||||
return;
|
||||
}
|
||||
|
||||
if (ws) {
|
||||
ws.close();
|
||||
}
|
||||
|
||||
if ($D('encrypt').checked) {
|
||||
scheme = "wss://";
|
||||
}
|
||||
uri = scheme + host + ":" + port;
|
||||
message("connecting to " + uri);
|
||||
ws = new WebSocket(uri);
|
||||
|
||||
ws.onmessage = function(e) {
|
||||
//console.log(">> WebSockets.onmessage");
|
||||
var arr = Base64.decode(e.data), str = "", i;
|
||||
|
||||
for (i = 0; i < arr.length; i++) {
|
||||
str = str + String.fromCharCode(arr[i]);
|
||||
}
|
||||
|
||||
message("Received message '" + str + "'");
|
||||
//console.log("<< WebSockets.onmessage");
|
||||
};
|
||||
ws.onopen = function(e) {
|
||||
console.log(">> WebSockets.onopen");
|
||||
echo_ref = setInterval(send_msg, echoDelay);
|
||||
console.log("<< WebSockets.onopen");
|
||||
};
|
||||
ws.onclose = function(e) {
|
||||
console.log(">> WebSockets.onclose");
|
||||
if (echo_ref) {
|
||||
clearInterval(echo_ref);
|
||||
echo_ref = null;
|
||||
}
|
||||
console.log("<< WebSockets.onclose");
|
||||
};
|
||||
ws.onerror = function(e) {
|
||||
console.log(">> WebSockets.onerror");
|
||||
if (echo_ref) {
|
||||
clearInterval(echo_ref);
|
||||
echo_ref = null;
|
||||
}
|
||||
console.log("<< WebSockets.onerror");
|
||||
};
|
||||
|
||||
$D('connectButton').value = "Stop";
|
||||
$D('connectButton').onclick = disconnect;
|
||||
console.log("<< connect");
|
||||
}
|
||||
|
||||
function disconnect() {
|
||||
console.log(">> disconnect");
|
||||
if (ws) {
|
||||
ws.close();
|
||||
}
|
||||
|
||||
if (echo_ref) {
|
||||
clearInterval(echo_ref);
|
||||
}
|
||||
|
||||
$D('connectButton').value = "Start";
|
||||
$D('connectButton').onclick = connect;
|
||||
console.log("<< disconnect");
|
||||
}
|
||||
|
||||
|
||||
/* If no builtin websockets then load web_socket.js */
|
||||
if (window.WebSocket) {
|
||||
VNC_native_ws = true;
|
||||
} else {
|
||||
VNC_native_ws = false;
|
||||
console.log("Loading web-socket-js flash bridge");
|
||||
var extra = "<script src='include/web-socket-js/swfobject.js'><\/script>";
|
||||
extra += "<script src='include/web-socket-js/FABridge.js'><\/script>";
|
||||
extra += "<script src='include/web-socket-js/web_socket.js'><\/script>";
|
||||
document.write(extra);
|
||||
}
|
||||
|
||||
window.onload = function() {
|
||||
console.log("onload");
|
||||
if (!VNC_native_ws) {
|
||||
console.log("initializing web-socket-js flash bridge");
|
||||
WebSocket.__swfLocation = "include/web-socket-js/WebSocketMain.swf";
|
||||
WebSocket.__initialize();
|
||||
}
|
||||
var url = document.location.href;
|
||||
$D('host').value = (url.match(/host=([^&#]*)/) || ['',''])[1];
|
||||
$D('port').value = (url.match(/port=([^&#]*)/) || ['',''])[1];
|
||||
}
|
||||
</script>
|
||||
|
||||
</html>
|
|
@ -0,0 +1 @@
|
|||
../utils/wstest.py
|
|
@ -23,20 +23,13 @@ except:
|
|||
from urlparse import urlsplit
|
||||
from cgi import parse_qsl
|
||||
|
||||
settings = {
|
||||
'verbose' : False,
|
||||
'listen_host' : '',
|
||||
'listen_port' : None,
|
||||
'handler' : None,
|
||||
'handler_id' : 1,
|
||||
'cert' : None,
|
||||
'key' : None,
|
||||
'ssl_only' : False,
|
||||
'daemon' : True,
|
||||
'record' : None,
|
||||
'web' : False, }
|
||||
class WebSocketServer():
|
||||
"""
|
||||
WebSockets server class.
|
||||
Must be sub-classed with handler method definition.
|
||||
"""
|
||||
|
||||
server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
|
||||
server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
|
||||
Upgrade: WebSocket\r
|
||||
Connection: Upgrade\r
|
||||
%sWebSocket-Origin: %s\r
|
||||
|
@ -45,10 +38,323 @@ Connection: Upgrade\r
|
|||
\r
|
||||
%s"""
|
||||
|
||||
policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n"""
|
||||
policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n"""
|
||||
|
||||
class EClose(Exception):
|
||||
pass
|
||||
|
||||
def __init__(self, listen_host='', listen_port=None,
|
||||
verbose=False, cert='', key='', ssl_only=None,
|
||||
daemon=False, record='', web=''):
|
||||
|
||||
# settings
|
||||
self.verbose = verbose
|
||||
self.listen_host = listen_host
|
||||
self.listen_port = listen_port
|
||||
self.ssl_only = ssl_only
|
||||
self.daemon = daemon
|
||||
|
||||
|
||||
# Make paths settings absolute
|
||||
self.cert = os.path.abspath(cert)
|
||||
self.key = self.web = self.record = ''
|
||||
if key:
|
||||
self.key = os.path.abspath(key)
|
||||
if web:
|
||||
self.web = os.path.abspath(web)
|
||||
if record:
|
||||
self.record = os.path.abspath(record)
|
||||
|
||||
if self.web:
|
||||
os.chdir(self.web)
|
||||
|
||||
self.handler_id = 1
|
||||
|
||||
#
|
||||
# WebSocketServer static methods
|
||||
#
|
||||
@staticmethod
|
||||
def daemonize(self, keepfd=None):
|
||||
os.umask(0)
|
||||
if self.web:
|
||||
os.chdir(self.web)
|
||||
else:
|
||||
os.chdir('/')
|
||||
os.setgid(os.getgid()) # relinquish elevations
|
||||
os.setuid(os.getuid()) # relinquish elevations
|
||||
|
||||
# Double fork to daemonize
|
||||
if os.fork() > 0: os._exit(0) # Parent exits
|
||||
os.setsid() # Obtain new process group
|
||||
if os.fork() > 0: os._exit(0) # Parent exits
|
||||
|
||||
# Signal handling
|
||||
def terminate(a,b): os._exit(0)
|
||||
signal.signal(signal.SIGTERM, terminate)
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
# Close open files
|
||||
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
|
||||
if maxfd == resource.RLIM_INFINITY: maxfd = 256
|
||||
for fd in reversed(range(maxfd)):
|
||||
try:
|
||||
if fd != keepfd:
|
||||
os.close(fd)
|
||||
except OSError, exc:
|
||||
if exc.errno != errno.EBADF: raise
|
||||
|
||||
# Redirect I/O to /dev/null
|
||||
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
|
||||
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
|
||||
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
|
||||
|
||||
@staticmethod
|
||||
def encode(buf):
|
||||
""" Encode a WebSocket packet. """
|
||||
buf = b64encode(buf)
|
||||
return "\x00%s\xff" % buf
|
||||
|
||||
@staticmethod
|
||||
def decode(buf):
|
||||
""" Decode WebSocket packets. """
|
||||
if buf.count('\xff') > 1:
|
||||
return [b64decode(d[1:]) for d in buf.split('\xff')]
|
||||
else:
|
||||
return [b64decode(buf[1:-1])]
|
||||
|
||||
@staticmethod
|
||||
def parse_handshake(handshake):
|
||||
""" Parse fields from client WebSockets handshake. """
|
||||
ret = {}
|
||||
req_lines = handshake.split("\r\n")
|
||||
if not req_lines[0].startswith("GET "):
|
||||
raise Exception("Invalid handshake: no GET request line")
|
||||
ret['path'] = req_lines[0].split(" ")[1]
|
||||
for line in req_lines[1:]:
|
||||
if line == "": break
|
||||
var, val = line.split(": ")
|
||||
ret[var] = val
|
||||
|
||||
if req_lines[-2] == "":
|
||||
ret['key3'] = req_lines[-1]
|
||||
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def gen_md5(keys):
|
||||
""" Generate hash value for WebSockets handshake v76. """
|
||||
key1 = keys['Sec-WebSocket-Key1']
|
||||
key2 = keys['Sec-WebSocket-Key2']
|
||||
key3 = keys['key3']
|
||||
spaces1 = key1.count(" ")
|
||||
spaces2 = key2.count(" ")
|
||||
num1 = int("".join([c for c in key1 if c.isdigit()])) / spaces1
|
||||
num2 = int("".join([c for c in key2 if c.isdigit()])) / spaces2
|
||||
|
||||
return md5(struct.pack('>II8s', num1, num2, key3)).digest()
|
||||
|
||||
|
||||
#
|
||||
# WebSocketServer logging/output functions
|
||||
#
|
||||
|
||||
def traffic(self, token="."):
|
||||
""" Show traffic flow in verbose mode. """
|
||||
if self.verbose and not self.daemon:
|
||||
sys.stdout.write(token)
|
||||
sys.stdout.flush()
|
||||
|
||||
def msg(self, msg):
|
||||
""" Output message with handler_id prefix. """
|
||||
if not self.daemon:
|
||||
print "% 3d: %s" % (self.handler_id, msg)
|
||||
|
||||
def vmsg(self, msg):
|
||||
""" Same as msg() but only if verbose. """
|
||||
if self.verbose:
|
||||
self.msg(msg)
|
||||
|
||||
#
|
||||
# Main WebSocketServer methods
|
||||
#
|
||||
|
||||
def do_handshake(self, sock, address):
|
||||
"""
|
||||
do_handshake does the following:
|
||||
- Peek at the first few bytes from the socket.
|
||||
- If the connection is Flash policy request then answer it,
|
||||
close the socket and return.
|
||||
- If the connection is an HTTPS/SSL/TLS connection then SSL
|
||||
wrap the socket.
|
||||
- Read from the (possibly wrapped) socket.
|
||||
- If we have received a HTTP GET request and the webserver
|
||||
functionality is enabled, answer it, close the socket and
|
||||
return.
|
||||
- Assume we have a WebSockets connection, parse the client
|
||||
handshake data.
|
||||
- Send a WebSockets handshake server response.
|
||||
- Return the socket for this WebSocket client.
|
||||
"""
|
||||
|
||||
stype = ""
|
||||
|
||||
# Peek, but don't read the data
|
||||
handshake = sock.recv(1024, socket.MSG_PEEK)
|
||||
#self.msg("Handshake [%s]" % repr(handshake))
|
||||
|
||||
if handshake == "":
|
||||
raise self.EClose("ignoring empty handshake")
|
||||
|
||||
elif handshake.startswith("<policy-file-request/>"):
|
||||
# Answer Flash policy request
|
||||
handshake = sock.recv(1024)
|
||||
sock.send(self.policy_response)
|
||||
raise self.EClose("Sending flash policy response")
|
||||
|
||||
elif handshake[0] in ("\x16", "\x80"):
|
||||
# SSL wrap the connection
|
||||
if not os.path.exists(self.cert):
|
||||
raise self.EClose("SSL connection but '%s' not found"
|
||||
% self.cert)
|
||||
try:
|
||||
retsock = ssl.wrap_socket(
|
||||
sock,
|
||||
server_side=True,
|
||||
certfile=self.cert,
|
||||
keyfile=self.key)
|
||||
except ssl.SSLError, x:
|
||||
if x.args[0] == ssl.SSL_ERROR_EOF:
|
||||
raise self.EClose("")
|
||||
else:
|
||||
raise
|
||||
|
||||
scheme = "wss"
|
||||
stype = "SSL/TLS (wss://)"
|
||||
|
||||
elif self.ssl_only:
|
||||
raise self.EClose("non-SSL connection received but disallowed")
|
||||
|
||||
else:
|
||||
retsock = sock
|
||||
scheme = "ws"
|
||||
stype = "Plain non-SSL (ws://)"
|
||||
|
||||
# Now get the data from the socket
|
||||
handshake = retsock.recv(4096)
|
||||
#self.msg("handshake: " + repr(handshake))
|
||||
|
||||
if len(handshake) == 0:
|
||||
raise self.EClose("Client closed during handshake")
|
||||
|
||||
# Check for and handle normal web requests
|
||||
if handshake.startswith('GET ') and \
|
||||
handshake.find('Upgrade: WebSocket\r\n') == -1:
|
||||
if not self.web:
|
||||
raise self.EClose("Normal web request received but disallowed")
|
||||
sh = SplitHTTPHandler(handshake, retsock, address)
|
||||
if sh.last_code < 200 or sh.last_code >= 300:
|
||||
raise self.EClose(sh.last_message)
|
||||
elif self.verbose:
|
||||
raise self.EClose(sh.last_message)
|
||||
else:
|
||||
raise self.EClose("")
|
||||
|
||||
# Parse client WebSockets handshake
|
||||
h = self.parse_handshake(handshake)
|
||||
|
||||
if h.get('key3'):
|
||||
trailer = self.gen_md5(h)
|
||||
pre = "Sec-"
|
||||
ver = 76
|
||||
else:
|
||||
trailer = ""
|
||||
pre = ""
|
||||
ver = 75
|
||||
|
||||
self.msg("%s: %s WebSocket connection (version %s)"
|
||||
% (address[0], stype, ver))
|
||||
|
||||
# Send server WebSockets handshake response
|
||||
response = self.server_handshake % (pre, h['Origin'], pre,
|
||||
scheme, h['Host'], h['path'], pre, trailer)
|
||||
#self.msg("sending response:", repr(response))
|
||||
retsock.send(response)
|
||||
|
||||
# Return the WebSockets socket which may be SSL wrapped
|
||||
return retsock
|
||||
|
||||
|
||||
def handler(self, client):
|
||||
""" Do something with a WebSockets client connection. """
|
||||
raise("WebSocketServer.handler() must be overloaded")
|
||||
|
||||
def start_server(self):
|
||||
"""
|
||||
Daemonize if requested. Listen for for connections. Run
|
||||
do_handshake() method for each connection. If the connection
|
||||
is a WebSockets client then call handler() method (which must
|
||||
be overridden) for each connection.
|
||||
"""
|
||||
|
||||
lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
lsock.bind((self.listen_host, self.listen_port))
|
||||
lsock.listen(100)
|
||||
|
||||
print "WebSocket server settings:"
|
||||
print " - Listening on %s:%s" % (
|
||||
self.listen_host, self.listen_port)
|
||||
if self.daemon:
|
||||
print " - Backgrounding (daemon)"
|
||||
print " - Flash security policy server"
|
||||
if self.web:
|
||||
print " - Web server"
|
||||
if os.path.exists(self.cert):
|
||||
print " - SSL/TLS support"
|
||||
if self.ssl_only:
|
||||
print " - Deny non-SSL/TLS connections"
|
||||
|
||||
if self.daemon:
|
||||
self.daemonize(self, keepfd=lsock.fileno())
|
||||
|
||||
# Reep zombies
|
||||
signal.signal(signal.SIGCHLD, signal.SIG_IGN)
|
||||
|
||||
while True:
|
||||
try:
|
||||
csock = startsock = None
|
||||
pid = 0
|
||||
startsock, address = lsock.accept()
|
||||
self.vmsg('%s: forking handler' % address[0])
|
||||
pid = os.fork()
|
||||
|
||||
if pid == 0:
|
||||
# handler process
|
||||
csock = self.do_handshake(startsock, address)
|
||||
self.handler(csock)
|
||||
else:
|
||||
# parent process
|
||||
self.handler_id += 1
|
||||
|
||||
except self.EClose, exc:
|
||||
# Connection was not a WebSockets connection
|
||||
if exc.args[0]:
|
||||
self.msg("%s: %s" % (address[0], exc.args[0]))
|
||||
except KeyboardInterrupt, exc:
|
||||
pass
|
||||
except Exception, exc:
|
||||
self.msg("handler exception: %s" % str(exc))
|
||||
if self.verbose:
|
||||
self.msg(traceback.format_exc())
|
||||
finally:
|
||||
if csock and csock != startsock:
|
||||
csock.close()
|
||||
if startsock:
|
||||
startsock.close()
|
||||
|
||||
if pid == 0:
|
||||
break # Child process exits
|
||||
|
||||
class EClose(Exception):
|
||||
pass
|
||||
|
||||
# HTTP handler with request from a string and response to a socket
|
||||
class SplitHTTPHandler(SimpleHTTPRequestHandler):
|
||||
|
@ -73,213 +379,3 @@ class SplitHTTPHandler(SimpleHTTPRequestHandler):
|
|||
self.last_message = f % args
|
||||
|
||||
|
||||
def traffic(token="."):
|
||||
if settings['verbose'] and not settings['daemon']:
|
||||
sys.stdout.write(token)
|
||||
sys.stdout.flush()
|
||||
|
||||
def handler_msg(msg):
|
||||
if not settings['daemon']:
|
||||
print "% 3d: %s" % (settings['handler_id'], msg)
|
||||
|
||||
def handler_vmsg(msg):
|
||||
if settings['verbose']: handler_msg(msg)
|
||||
|
||||
def encode(buf):
|
||||
buf = b64encode(buf)
|
||||
|
||||
return "\x00%s\xff" % buf
|
||||
|
||||
def decode(buf):
|
||||
""" Parse out WebSocket packets. """
|
||||
if buf.count('\xff') > 1:
|
||||
return [b64decode(d[1:]) for d in buf.split('\xff')]
|
||||
else:
|
||||
return [b64decode(buf[1:-1])]
|
||||
|
||||
def parse_handshake(handshake):
|
||||
ret = {}
|
||||
req_lines = handshake.split("\r\n")
|
||||
if not req_lines[0].startswith("GET "):
|
||||
raise Exception("Invalid handshake: no GET request line")
|
||||
ret['path'] = req_lines[0].split(" ")[1]
|
||||
for line in req_lines[1:]:
|
||||
if line == "": break
|
||||
var, val = line.split(": ")
|
||||
ret[var] = val
|
||||
|
||||
if req_lines[-2] == "":
|
||||
ret['key3'] = req_lines[-1]
|
||||
|
||||
return ret
|
||||
|
||||
def gen_md5(keys):
|
||||
key1 = keys['Sec-WebSocket-Key1']
|
||||
key2 = keys['Sec-WebSocket-Key2']
|
||||
key3 = keys['key3']
|
||||
spaces1 = key1.count(" ")
|
||||
spaces2 = key2.count(" ")
|
||||
num1 = int("".join([c for c in key1 if c.isdigit()])) / spaces1
|
||||
num2 = int("".join([c for c in key2 if c.isdigit()])) / spaces2
|
||||
|
||||
return md5(struct.pack('>II8s', num1, num2, key3)).digest()
|
||||
|
||||
|
||||
def do_handshake(sock, address):
|
||||
stype = ""
|
||||
|
||||
# Peek, but don't read the data
|
||||
handshake = sock.recv(1024, socket.MSG_PEEK)
|
||||
#handler_msg("Handshake [%s]" % repr(handshake))
|
||||
if handshake == "":
|
||||
raise EClose("ignoring empty handshake")
|
||||
elif handshake.startswith("<policy-file-request/>"):
|
||||
handshake = sock.recv(1024)
|
||||
sock.send(policy_response)
|
||||
raise EClose("Sending flash policy response")
|
||||
elif handshake[0] in ("\x16", "\x80"):
|
||||
if not os.path.exists(settings['cert']):
|
||||
raise EClose("SSL connection but '%s' not found"
|
||||
% settings['cert'])
|
||||
try:
|
||||
retsock = ssl.wrap_socket(
|
||||
sock,
|
||||
server_side=True,
|
||||
certfile=settings['cert'],
|
||||
keyfile=settings['key'])
|
||||
except ssl.SSLError, x:
|
||||
if x.args[0] == ssl.SSL_ERROR_EOF:
|
||||
raise EClose("")
|
||||
else:
|
||||
raise
|
||||
|
||||
scheme = "wss"
|
||||
stype = "SSL/TLS (wss://)"
|
||||
elif settings['ssl_only']:
|
||||
raise EClose("non-SSL connection received but disallowed")
|
||||
else:
|
||||
retsock = sock
|
||||
scheme = "ws"
|
||||
stype = "Plain non-SSL (ws://)"
|
||||
|
||||
# Now get the data from the socket
|
||||
handshake = retsock.recv(4096)
|
||||
#handler_msg("handshake: " + repr(handshake))
|
||||
|
||||
if len(handshake) == 0:
|
||||
raise EClose("Client closed during handshake")
|
||||
|
||||
# Handle normal web requests
|
||||
if handshake.startswith('GET ') and \
|
||||
handshake.find('Upgrade: WebSocket\r\n') == -1:
|
||||
if not settings['web']:
|
||||
raise EClose("Normal web request received but disallowed")
|
||||
sh = SplitHTTPHandler(handshake, retsock, address)
|
||||
if sh.last_code < 200 or sh.last_code >= 300:
|
||||
raise EClose(sh.last_message)
|
||||
elif settings['verbose']:
|
||||
raise EClose(sh.last_message)
|
||||
else:
|
||||
raise EClose("")
|
||||
|
||||
# Do WebSockets handshake and return the socket
|
||||
h = parse_handshake(handshake)
|
||||
|
||||
if h.get('key3'):
|
||||
trailer = gen_md5(h)
|
||||
pre = "Sec-"
|
||||
ver = 76
|
||||
else:
|
||||
trailer = ""
|
||||
pre = ""
|
||||
ver = 75
|
||||
|
||||
handler_msg("%s WebSocket connection (version %s) from %s"
|
||||
% (stype, ver, address[0]))
|
||||
|
||||
response = server_handshake % (pre, h['Origin'], pre, scheme,
|
||||
h['Host'], h['path'], pre, trailer)
|
||||
|
||||
#handler_msg("sending response:", repr(response))
|
||||
retsock.send(response)
|
||||
return retsock
|
||||
|
||||
def daemonize(keepfd=None):
|
||||
os.umask(0)
|
||||
os.chdir('/')
|
||||
os.setgid(os.getgid()) # relinquish elevations
|
||||
os.setuid(os.getuid()) # relinquish elevations
|
||||
|
||||
# Double fork to daemonize
|
||||
if os.fork() > 0: os._exit(0) # Parent exits
|
||||
os.setsid() # Obtain new process group
|
||||
if os.fork() > 0: os._exit(0) # Parent exits
|
||||
|
||||
# Signal handling
|
||||
def terminate(a,b): os._exit(0)
|
||||
signal.signal(signal.SIGTERM, terminate)
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
# Close open files
|
||||
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
|
||||
if maxfd == resource.RLIM_INFINITY: maxfd = 256
|
||||
for fd in reversed(range(maxfd)):
|
||||
try:
|
||||
if fd != keepfd:
|
||||
os.close(fd)
|
||||
else:
|
||||
handler_vmsg("Keeping fd: %d" % fd)
|
||||
except OSError, exc:
|
||||
if exc.errno != errno.EBADF: raise
|
||||
|
||||
# Redirect I/O to /dev/null
|
||||
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
|
||||
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
|
||||
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
|
||||
|
||||
|
||||
def start_server():
|
||||
|
||||
lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
lsock.bind((settings['listen_host'], settings['listen_port']))
|
||||
lsock.listen(100)
|
||||
|
||||
if settings['daemon']:
|
||||
daemonize(keepfd=lsock.fileno())
|
||||
|
||||
# Reep zombies
|
||||
signal.signal(signal.SIGCHLD, signal.SIG_IGN)
|
||||
|
||||
print 'Waiting for connections on %s:%s' % (
|
||||
settings['listen_host'], settings['listen_port'])
|
||||
|
||||
while True:
|
||||
try:
|
||||
csock = startsock = None
|
||||
pid = 0
|
||||
startsock, address = lsock.accept()
|
||||
handler_vmsg('%s: forking handler' % address[0])
|
||||
pid = os.fork()
|
||||
|
||||
if pid == 0: # handler process
|
||||
csock = do_handshake(startsock, address)
|
||||
settings['handler'](csock)
|
||||
else: # parent process
|
||||
settings['handler_id'] += 1
|
||||
|
||||
except EClose, exc:
|
||||
if csock and csock != startsock:
|
||||
csock.close()
|
||||
startsock.close()
|
||||
if exc.args[0]:
|
||||
handler_msg("%s: %s" % (address[0], exc.args[0]))
|
||||
except Exception, exc:
|
||||
handler_msg("handler exception: %s" % str(exc))
|
||||
if settings['verbose']:
|
||||
handler_msg(traceback.format_exc())
|
||||
|
||||
if pid == 0:
|
||||
if csock: csock.close()
|
||||
if startsock and startsock != csock: startsock.close()
|
||||
break # Child process exits
|
||||
|
|
243
utils/wsproxy.py
243
utils/wsproxy.py
|
@ -11,14 +11,21 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
|
|||
|
||||
'''
|
||||
|
||||
import socket, optparse, time
|
||||
import socket, optparse, time, os
|
||||
from select import select
|
||||
from websocket import *
|
||||
from websocket import WebSocketServer
|
||||
|
||||
buffer_size = 65536
|
||||
rec = None
|
||||
class WebSocketProxy(WebSocketServer):
|
||||
"""
|
||||
Proxy traffic to and from a WebSockets client to a normal TCP
|
||||
socket server target. All traffic to/from the client is base64
|
||||
encoded/decoded to allow binary data to be sent/received to/from
|
||||
the target.
|
||||
"""
|
||||
|
||||
traffic_legend = """
|
||||
buffer_size = 65536
|
||||
|
||||
traffic_legend = """
|
||||
Traffic Legend:
|
||||
} - Client receive
|
||||
}. - Client receive partial
|
||||
|
@ -30,101 +37,122 @@ Traffic Legend:
|
|||
<. - Client send partial
|
||||
"""
|
||||
|
||||
def do_proxy(client, target):
|
||||
""" Proxy WebSocket to normal socket. """
|
||||
global rec
|
||||
cqueue = []
|
||||
cpartial = ""
|
||||
tqueue = []
|
||||
rlist = [client, target]
|
||||
tstart = int(time.time()*1000)
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Save off the target host:port
|
||||
self.target_host = kwargs.pop('target_host')
|
||||
self.target_port = kwargs.pop('target_port')
|
||||
WebSocketServer.__init__(self, *args, **kwargs)
|
||||
|
||||
while True:
|
||||
wlist = []
|
||||
tdelta = int(time.time()*1000) - tstart
|
||||
if tqueue: wlist.append(target)
|
||||
if cqueue: wlist.append(client)
|
||||
ins, outs, excepts = select(rlist, wlist, [], 1)
|
||||
if excepts: raise Exception("Socket exception")
|
||||
def handler(self, client):
|
||||
"""
|
||||
Called after a new WebSocket connection has been established.
|
||||
"""
|
||||
|
||||
if target in outs:
|
||||
dat = tqueue.pop(0)
|
||||
sent = target.send(dat)
|
||||
if sent == len(dat):
|
||||
traffic(">")
|
||||
else:
|
||||
tqueue.insert(0, dat[sent:])
|
||||
traffic(".>")
|
||||
##if rec: rec.write("Target send: %s\n" % map(ord, dat))
|
||||
self.rec = None
|
||||
if self.record:
|
||||
# Record raw frame data as a JavaScript compatible file
|
||||
fname = "%s.%s" % (self.record,
|
||||
self.handler_id)
|
||||
self.msg("opening record file: %s" % fname)
|
||||
self.rec = open(fname, 'w+')
|
||||
self.rec.write("var VNC_frame_data = [\n")
|
||||
|
||||
if client in outs:
|
||||
dat = cqueue.pop(0)
|
||||
sent = client.send(dat)
|
||||
if sent == len(dat):
|
||||
traffic("<")
|
||||
##if rec: rec.write("Client send: %s ...\n" % repr(dat[0:80]))
|
||||
if rec: rec.write("%s,\n" % repr("{%s{" % tdelta + dat[1:-1]))
|
||||
else:
|
||||
cqueue.insert(0, dat[sent:])
|
||||
traffic("<.")
|
||||
##if rec: rec.write("Client send partial: %s\n" % repr(dat[0:send]))
|
||||
# Connect to the target
|
||||
self.msg("connecting to: %s:%s" % (
|
||||
self.target_host, self.target_port))
|
||||
tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
tsock.connect((self.target_host, self.target_port))
|
||||
|
||||
if self.verbose and not self.daemon:
|
||||
print self.traffic_legend
|
||||
|
||||
if target in ins:
|
||||
buf = target.recv(buffer_size)
|
||||
if len(buf) == 0: raise EClose("Target closed")
|
||||
# Stat proxying
|
||||
try:
|
||||
self.do_proxy(client, tsock)
|
||||
except:
|
||||
if tsock: tsock.close()
|
||||
if self.rec:
|
||||
self.rec.write("'EOF']\n")
|
||||
self.rec.close()
|
||||
raise
|
||||
|
||||
cqueue.append(encode(buf))
|
||||
traffic("{")
|
||||
##if rec: rec.write("Target recv (%d): %s\n" % (len(buf), map(ord, buf)))
|
||||
def do_proxy(self, client, target):
|
||||
"""
|
||||
Proxy client WebSocket to normal target socket.
|
||||
"""
|
||||
cqueue = []
|
||||
cpartial = ""
|
||||
tqueue = []
|
||||
rlist = [client, target]
|
||||
tstart = int(time.time()*1000)
|
||||
|
||||
if client in ins:
|
||||
buf = client.recv(buffer_size)
|
||||
if len(buf) == 0: raise EClose("Client closed")
|
||||
while True:
|
||||
wlist = []
|
||||
tdelta = int(time.time()*1000) - tstart
|
||||
|
||||
if buf == '\xff\x00':
|
||||
raise EClose("Client sent orderly close frame")
|
||||
elif buf[-1] == '\xff':
|
||||
if buf.count('\xff') > 1:
|
||||
traffic(str(buf.count('\xff')))
|
||||
traffic("}")
|
||||
##if rec: rec.write("Client recv (%d): %s\n" % (len(buf), repr(buf)))
|
||||
if rec: rec.write("%s,\n" % (repr("}%s}" % tdelta + buf[1:-1])))
|
||||
if cpartial:
|
||||
tqueue.extend(decode(cpartial + buf))
|
||||
cpartial = ""
|
||||
if tqueue: wlist.append(target)
|
||||
if cqueue: wlist.append(client)
|
||||
ins, outs, excepts = select(rlist, wlist, [], 1)
|
||||
if excepts: raise Exception("Socket exception")
|
||||
|
||||
if target in outs:
|
||||
# Send queued client data to the target
|
||||
dat = tqueue.pop(0)
|
||||
sent = target.send(dat)
|
||||
if sent == len(dat):
|
||||
self.traffic(">")
|
||||
else:
|
||||
tqueue.extend(decode(buf))
|
||||
else:
|
||||
traffic(".}")
|
||||
##if rec: rec.write("Client recv partial (%d): %s\n" % (len(buf), repr(buf)))
|
||||
cpartial = cpartial + buf
|
||||
# requeue the remaining data
|
||||
tqueue.insert(0, dat[sent:])
|
||||
self.traffic(".>")
|
||||
|
||||
def proxy_handler(client):
|
||||
global target_host, target_port, options, rec, fname
|
||||
if client in outs:
|
||||
# Send queued target data to the client
|
||||
dat = cqueue.pop(0)
|
||||
sent = client.send(dat)
|
||||
if sent == len(dat):
|
||||
self.traffic("<")
|
||||
if self.rec:
|
||||
self.rec.write("%s,\n" %
|
||||
repr("{%s{" % tdelta + dat[1:-1]))
|
||||
else:
|
||||
cqueue.insert(0, dat[sent:])
|
||||
self.traffic("<.")
|
||||
|
||||
if settings['record']:
|
||||
fname = "%s.%s" % (settings['record'],
|
||||
settings['handler_id'])
|
||||
handler_msg("opening record file: %s" % fname)
|
||||
rec = open(fname, 'w+')
|
||||
rec.write("var VNC_frame_data = [\n")
|
||||
|
||||
handler_msg("connecting to: %s:%s" % (target_host, target_port))
|
||||
tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
tsock.connect((target_host, target_port))
|
||||
if target in ins:
|
||||
# Receive target data, encode it and queue for client
|
||||
buf = target.recv(self.buffer_size)
|
||||
if len(buf) == 0: raise self.EClose("Target closed")
|
||||
|
||||
if settings['verbose'] and not settings['daemon']:
|
||||
print traffic_legend
|
||||
cqueue.append(self.encode(buf))
|
||||
self.traffic("{")
|
||||
|
||||
try:
|
||||
do_proxy(client, tsock)
|
||||
except:
|
||||
if tsock: tsock.close()
|
||||
if rec:
|
||||
rec.write("'EOF']\n")
|
||||
rec.close()
|
||||
raise
|
||||
if client in ins:
|
||||
# Receive client data, decode it, and queue for target
|
||||
buf = client.recv(self.buffer_size)
|
||||
if len(buf) == 0: raise self.EClose("Client closed")
|
||||
|
||||
if buf == '\xff\x00':
|
||||
raise self.EClose("Client sent orderly close frame")
|
||||
elif buf[-1] == '\xff':
|
||||
if buf.count('\xff') > 1:
|
||||
self.traffic(str(buf.count('\xff')))
|
||||
self.traffic("}")
|
||||
if self.rec:
|
||||
self.rec.write("%s,\n" %
|
||||
(repr("}%s}" % tdelta + buf[1:-1])))
|
||||
if cpartial:
|
||||
# Prepend saved partial and decode frame(s)
|
||||
tqueue.extend(self.decode(cpartial + buf))
|
||||
cpartial = ""
|
||||
else:
|
||||
# decode frame(s)
|
||||
tqueue.extend(self.decode(buf))
|
||||
else:
|
||||
# Save off partial WebSockets frame
|
||||
self.traffic(".}")
|
||||
cpartial = cpartial + buf
|
||||
|
||||
if __name__ == '__main__':
|
||||
usage = "%prog [--record FILE]"
|
||||
|
@ -145,40 +173,31 @@ if __name__ == '__main__':
|
|||
help="disallow non-encrypted connections")
|
||||
parser.add_option("--web", default=None, metavar="DIR",
|
||||
help="run webserver on same port. Serve files from DIR.")
|
||||
(options, args) = parser.parse_args()
|
||||
(opts, args) = parser.parse_args()
|
||||
|
||||
# Sanity checks
|
||||
if len(args) > 2: parser.error("Too many arguments")
|
||||
if len(args) < 2: parser.error("Too few arguments")
|
||||
|
||||
if opts.ssl_only and not os.path.exists(opts.cert):
|
||||
parser.error("SSL only and %s not found" % opts.cert)
|
||||
elif not os.path.exists(opts.cert):
|
||||
print "Warning: %s not found" % opts.cert
|
||||
|
||||
# Parse host:port and convert ports to numbers
|
||||
if args[0].count(':') > 0:
|
||||
host,port = args[0].split(':')
|
||||
opts.listen_host, opts.listen_port = args[0].split(':')
|
||||
else:
|
||||
host,port = '',args[0]
|
||||
opts.listen_host, opts.listen_port = '', args[0]
|
||||
if args[1].count(':') > 0:
|
||||
target_host,target_port = args[1].split(':')
|
||||
opts.target_host, opts.target_port = args[1].split(':')
|
||||
else:
|
||||
parser.error("Error parsing target")
|
||||
try: port = int(port)
|
||||
try: opts.listen_port = int(opts.listen_port)
|
||||
except: parser.error("Error parsing listen port")
|
||||
try: target_port = int(target_port)
|
||||
try: opts.target_port = int(opts.target_port)
|
||||
except: parser.error("Error parsing target port")
|
||||
|
||||
if options.ssl_only and not os.path.exists(options.cert):
|
||||
parser.error("SSL only and %s not found" % options.cert)
|
||||
elif not os.path.exists(options.cert):
|
||||
print "Warning: %s not found" % options.cert
|
||||
|
||||
settings['verbose'] = options.verbose
|
||||
settings['listen_host'] = host
|
||||
settings['listen_port'] = port
|
||||
settings['handler'] = proxy_handler
|
||||
settings['cert'] = os.path.abspath(options.cert)
|
||||
if options.key:
|
||||
settings['key'] = os.path.abspath(options.key)
|
||||
settings['ssl_only'] = options.ssl_only
|
||||
settings['daemon'] = options.daemon
|
||||
if options.record:
|
||||
settings['record'] = os.path.abspath(options.record)
|
||||
if options.web:
|
||||
os.chdir = options.web
|
||||
settings['web'] = options.web
|
||||
start_server()
|
||||
# Create and start the WebSockets proxy
|
||||
server = WebSocketProxy(**opts.__dict__)
|
||||
server.start_server()
|
||||
|
|
|
@ -0,0 +1,171 @@
|
|||
#!/usr/bin/python
|
||||
|
||||
'''
|
||||
WebSocket server-side load test program. Sends and receives traffic
|
||||
that has a random payload (length and content) that is checksummed and
|
||||
given a sequence number. Any errors are reported and counted.
|
||||
'''
|
||||
|
||||
import sys, os, socket, ssl, time, traceback
|
||||
import random, time
|
||||
from select import select
|
||||
|
||||
sys.path.insert(0,os.path.dirname(__file__) + "/../utils/")
|
||||
from websocket import WebSocketServer
|
||||
|
||||
|
||||
class WebSocketTest(WebSocketServer):
|
||||
|
||||
buffer_size = 65536
|
||||
max_packet_size = 10000
|
||||
recv_cnt = 0
|
||||
send_cnt = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.errors = 0
|
||||
self.delay = kwargs.pop('delay')
|
||||
|
||||
print "Prepopulating random array"
|
||||
self.rand_array = []
|
||||
for i in range(0, self.max_packet_size):
|
||||
self.rand_array.append(random.randint(0, 9))
|
||||
|
||||
WebSocketServer.__init__(self, *args, **kwargs)
|
||||
|
||||
def handler(self, client):
|
||||
self.send_cnt = 0
|
||||
self.recv_cnt = 0
|
||||
|
||||
try:
|
||||
self.responder(client)
|
||||
except:
|
||||
print "accumulated errors:", self.errors
|
||||
self.errors = 0
|
||||
raise
|
||||
|
||||
def responder(self, client):
|
||||
cqueue = []
|
||||
cpartial = ""
|
||||
socks = [client]
|
||||
last_send = time.time() * 1000
|
||||
|
||||
while True:
|
||||
ins, outs, excepts = select(socks, socks, socks, 1)
|
||||
if excepts: raise Exception("Socket exception")
|
||||
|
||||
if client in ins:
|
||||
buf = client.recv(self.buffer_size)
|
||||
if len(buf) == 0:
|
||||
raise self.EClose("Client closed")
|
||||
#print "Client recv: %s (%d)" % (repr(buf[1:-1]), len(buf))
|
||||
if buf[-1] == '\xff':
|
||||
if cpartial:
|
||||
err = self.check(cpartial + buf)
|
||||
cpartial = ""
|
||||
else:
|
||||
err = self.check(buf)
|
||||
if err:
|
||||
self.traffic("}")
|
||||
self.errors = self.errors + 1
|
||||
print err
|
||||
else:
|
||||
self.traffic(">")
|
||||
else:
|
||||
self.traffic(".>")
|
||||
cpartial = cpartial + buf
|
||||
|
||||
now = time.time() * 1000
|
||||
if client in outs and now > (last_send + self.delay):
|
||||
last_send = now
|
||||
#print "Client send: %s" % repr(cqueue[0])
|
||||
client.send(self.generate())
|
||||
self.traffic("<")
|
||||
|
||||
def generate(self):
|
||||
length = random.randint(10, self.max_packet_size)
|
||||
numlist = self.rand_array[self.max_packet_size-length:]
|
||||
# Error in length
|
||||
#numlist.append(5)
|
||||
chksum = sum(numlist)
|
||||
# Error in checksum
|
||||
#numlist[0] = 5
|
||||
nums = "".join( [str(n) for n in numlist] )
|
||||
data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums)
|
||||
self.send_cnt += 1
|
||||
|
||||
return WebSocketServer.encode(data)
|
||||
|
||||
|
||||
def check(self, buf):
|
||||
try:
|
||||
data_list = WebSocketServer.decode(buf)
|
||||
except:
|
||||
print "\n<BOF>" + repr(buf) + "<EOF>"
|
||||
return "Failed to decode"
|
||||
|
||||
err = ""
|
||||
for data in data_list:
|
||||
if data.count('$') > 1:
|
||||
raise Exception("Multiple parts within single packet")
|
||||
if len(data) == 0:
|
||||
self.traffic("_")
|
||||
continue
|
||||
|
||||
if data[0] != "^":
|
||||
err += "buf did not start with '^'\n"
|
||||
continue
|
||||
|
||||
try:
|
||||
cnt, length, chksum, nums = data[1:-1].split(':')
|
||||
cnt = int(cnt)
|
||||
length = int(length)
|
||||
chksum = int(chksum)
|
||||
except:
|
||||
print "\n<BOF>" + repr(data) + "<EOF>"
|
||||
err += "Invalid data format\n"
|
||||
continue
|
||||
|
||||
if self.recv_cnt != cnt:
|
||||
err += "Expected count %d but got %d\n" % (self.recv_cnt, cnt)
|
||||
self.recv_cnt = cnt + 1
|
||||
continue
|
||||
|
||||
self.recv_cnt += 1
|
||||
|
||||
if len(nums) != length:
|
||||
err += "Expected length %d but got %d\n" % (length, len(nums))
|
||||
continue
|
||||
|
||||
inv = nums.translate(None, "0123456789")
|
||||
if inv:
|
||||
err += "Invalid characters found: %s\n" % inv
|
||||
continue
|
||||
|
||||
real_chksum = 0
|
||||
for num in nums:
|
||||
real_chksum += int(num)
|
||||
|
||||
if real_chksum != chksum:
|
||||
err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum)
|
||||
return err
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
if len(sys.argv) < 2: raise
|
||||
listen_port = int(sys.argv[1])
|
||||
if len(sys.argv) == 3:
|
||||
delay = int(sys.argv[2])
|
||||
else:
|
||||
delay = 10
|
||||
except:
|
||||
print "Usage: %s <listen_port> [delay_ms]" % sys.argv[0]
|
||||
sys.exit(1)
|
||||
|
||||
server = WebSocketTest(
|
||||
listen_port=listen_port,
|
||||
verbose=True,
|
||||
cert='self.pem',
|
||||
web='.',
|
||||
delay=delay)
|
||||
server.start_server()
|
Loading…
Reference in New Issue