Merge commit 'a04edfe80f54b44df5a3579f71710560c6b7b4fc'

* commit 'a04edfe80f54b44df5a3579f71710560c6b7b4fc':
  Added temp dir for unit test data and cleanup
This commit is contained in:
Peter Åstrand (astrand) 2013-11-28 09:34:33 +01:00
commit 7ecfa4f384
4 changed files with 90 additions and 107 deletions

View File

@ -17,7 +17,9 @@
""" Unit tests for websocket """ """ Unit tests for websocket """
import errno import errno
import os import os
import logging
import select import select
import shutil
import socket import socket
import ssl import ssl
import stubout import stubout
@ -39,24 +41,44 @@ class MockConnection(object):
class WebSocketTestCase(unittest.TestCase): class WebSocketTestCase(unittest.TestCase):
def _init_logger(self, tmpdir):
name = 'websocket-unittest'
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = True
filename = "%s.log" % (name)
handler = logging.FileHandler(filename)
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
def setUp(self): def setUp(self):
"""Called automatically before each test.""" """Called automatically before each test."""
super(WebSocketTestCase, self).setUp() super(WebSocketTestCase, self).setUp()
self.stubs = stubout.StubOutForTesting() self.stubs = stubout.StubOutForTesting()
self.server = websocket.WebSocketServer(listen_host='localhost', # Temporary dir for test data
listen_port=80, self.tmpdir = tempfile.mkdtemp()
key='./', # Put log somewhere persistent
web='./', self._init_logger('./')
record='./', # Mock this out cause it screws tests up
daemon=True, self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
ssl_only=False) self.server = self._get_websockserver(daemon=True,
ssl_only=False)
self.soc = self.server.socket('localhost') self.soc = self.server.socket('localhost')
def tearDown(self): def tearDown(self):
"""Called automatically after each test.""" """Called automatically after each test."""
self.stubs.UnsetAll() self.stubs.UnsetAll()
shutil.rmtree(self.tmpdir)
super(WebSocketTestCase, self).tearDown() super(WebSocketTestCase, self).tearDown()
def _get_websockserver(self, **kwargs):
return websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key=self.tmpdir,
web=self.tmpdir,
record=self.tmpdir,
**kwargs)
def _mock_os_open_oserror(self, file, flags): def _mock_os_open_oserror(self, file, flags):
raise OSError('') raise OSError('')
@ -83,28 +105,14 @@ class WebSocketTestCase(unittest.TestCase):
sys.exit() sys.exit()
def test_daemonize_error(self): def test_daemonize_error(self):
soc = websocket.WebSocketServer(listen_host='localhost', soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1)
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=1,
idle_timeout=1)
self.stubs.Set(os, 'fork', lambda *args: None) self.stubs.Set(os, 'fork', lambda *args: None)
self.stubs.Set(os, 'setsid', lambda *args: None) self.stubs.Set(os, 'setsid', lambda *args: None)
self.stubs.Set(os, 'close', self._mock_os_close_oserror) self.stubs.Set(os, 'close', self._mock_os_close_oserror)
self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
def test_daemonize_EBADF_error(self): def test_daemonize_EBADF_error(self):
soc = websocket.WebSocketServer(listen_host='localhost', soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1)
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=1,
idle_timeout=1)
self.stubs.Set(os, 'fork', lambda *args: None) self.stubs.Set(os, 'fork', lambda *args: None)
self.stubs.Set(os, 'setsid', lambda *args: None) self.stubs.Set(os, 'setsid', lambda *args: None)
self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF) self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF)
@ -112,27 +120,12 @@ class WebSocketTestCase(unittest.TestCase):
self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
def test_decode_hybi(self): def test_decode_hybi(self):
soc = websocket.WebSocketServer(listen_host='localhost', soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
listen_port=80,
key='../',
web='../',
record='../',
daemon=False,
ssl_only=1,
idle_timeout=1)
self.assertRaises(Exception, soc.decode_hybi, 'a' * 128, self.assertRaises(Exception, soc.decode_hybi, 'a' * 128,
base64=True) base64=True)
def test_do_websocket_handshake(self): def test_do_websocket_handshake(self):
soc = websocket.WebSocketServer(listen_host='localhost', soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=0,
idle_timeout=1)
soc.scheme = 'scheme' soc.scheme = 'scheme'
headers = {'Sec-WebSocket-Protocol': 'binary', headers = {'Sec-WebSocket-Protocol': 'binary',
'Sec-WebSocket-Version': '7', 'Sec-WebSocket-Version': '7',
@ -140,27 +133,13 @@ class WebSocketTestCase(unittest.TestCase):
soc.do_websocket_handshake(headers, '127.0.0.1') soc.do_websocket_handshake(headers, '127.0.0.1')
def test_do_handshake(self): def test_do_handshake(self):
soc = websocket.WebSocketServer(listen_host='localhost', soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=0,
idle_timeout=1)
self.stubs.Set(select, 'select', self._mock_select) self.stubs.Set(select, 'select', self._mock_select)
self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv') self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv')
self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1') self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1')
def test_do_handshake_ssl_error(self): def test_do_handshake_ssl_error(self):
soc = websocket.WebSocketServer(listen_host='localhost', soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=0,
idle_timeout=1)
def _mock_wrap_socket(*args, **kwargs): def _mock_wrap_socket(*args, **kwargs):
from ssl import SSLError from ssl import SSLError
@ -172,25 +151,11 @@ class WebSocketTestCase(unittest.TestCase):
self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1') self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1')
def test_fallback_SIGCHILD(self): def test_fallback_SIGCHILD(self):
soc = websocket.WebSocketServer(listen_host='localhost', soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=0,
idle_timeout=1)
soc.fallback_SIGCHLD(None, None) soc.fallback_SIGCHLD(None, None)
def test_start_server_Exception(self): def test_start_server_Exception(self):
soc = websocket.WebSocketServer(listen_host='localhost', soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
listen_port=80,
key='../',
web='../',
record='../',
daemon=False,
ssl_only=1,
idle_timeout=1)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
self.stubs.Set(websocket.WebSocketServer, 'daemonize', self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None) lambda *args, **kwargs: None)
@ -198,15 +163,7 @@ class WebSocketTestCase(unittest.TestCase):
self.assertEqual(None, soc.start_server()) self.assertEqual(None, soc.start_server())
def test_start_server_KeyboardInterrupt(self): def test_start_server_KeyboardInterrupt(self):
soc = websocket.WebSocketServer(listen_host='localhost', soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
listen_port=80,
key='../',
web='../',
record='../',
cert='xxxxxx',
daemon=True,
ssl_only=1,
idle_timeout=1)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
self.stubs.Set(websocket.WebSocketServer, 'daemonize', self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None) lambda *args, **kwargs: None)
@ -215,19 +172,12 @@ class WebSocketTestCase(unittest.TestCase):
def test_start_server_systemexit(self): def test_start_server_systemexit(self):
websocket.ssl = None websocket.ssl = None
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=0,
idle_timeout=1,
verbose=True)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
self.stubs.Set(websocket.WebSocketServer, 'daemonize', self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None) lambda *args, **kwargs: None)
self.stubs.Set(select, 'select', self._mock_select_systemexit) self.stubs.Set(select, 'select', self._mock_select_systemexit)
soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1,
verbose=True)
self.assertEqual(None, soc.start_server()) self.assertEqual(None, soc.start_server())
def test_WSRequestHandle_do_GET_nofile(self): def test_WSRequestHandle_do_GET_nofile(self):

View File

@ -15,11 +15,15 @@
# under the License. # under the License.
""" Unit tests for websocketproxy """ """ Unit tests for websocketproxy """
import unittest import os
import time import logging
import subprocess
import stubout
import select import select
import shutil
import stubout
import subprocess
import tempfile
import time
import unittest
from websockify import websocketproxy from websockify import websocketproxy
@ -37,20 +41,42 @@ class MockSocket(object):
class WebSocketProxyTest(unittest.TestCase): class WebSocketProxyTest(unittest.TestCase):
def _init_logger(self, tmpdir):
name = 'websocket-unittest'
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = True
filename = "%s.log" % (name)
handler = logging.FileHandler(filename)
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
def setUp(self): def setUp(self):
"""Called automatically before each test.""" """Called automatically before each test."""
super(WebSocketProxyTest, self).setUp() super(WebSocketProxyTest, self).setUp()
self.soc = '' self.soc = ''
self.stubs = stubout.StubOutForTesting() self.stubs = stubout.StubOutForTesting()
# Temporary dir for test data
self.tmpdir = tempfile.mkdtemp()
# Put log somewhere persistent
self._init_logger('./')
# Mock this out cause it screws tests up
self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
def tearDown(self): def tearDown(self):
"""Called automatically after each test.""" """Called automatically after each test."""
self.stubs.UnsetAll() self.stubs.UnsetAll()
shutil.rmtree(self.tmpdir)
super(WebSocketProxyTest, self).tearDown() super(WebSocketProxyTest, self).tearDown()
def _get_websockproxy(self, **kwargs):
return websocketproxy.WebSocketProxy(key=self.tmpdir,
web=self.tmpdir,
record=self.tmpdir,
**kwargs)
def test_run_wrap_cmd(self): def test_run_wrap_cmd(self):
web_socket_proxy = websocketproxy.WebSocketProxy() web_socket_proxy = self._get_websockproxy()
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
def mock_Popen(*args, **kwargs): def mock_Popen(*args, **kwargs):
@ -61,7 +87,7 @@ class WebSocketProxyTest(unittest.TestCase):
self.assertEquals(web_socket_proxy.spawn_message, True) self.assertEquals(web_socket_proxy.spawn_message, True)
def test_started(self): def test_started(self):
web_socket_proxy = websocketproxy.WebSocketProxy() web_socket_proxy = self._get_websockproxy()
web_socket_proxy.__dict__["spawn_message"] = False web_socket_proxy.__dict__["spawn_message"] = False
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
@ -73,7 +99,7 @@ class WebSocketProxyTest(unittest.TestCase):
self.assertEquals(web_socket_proxy.__dict__["spawn_message"], True) self.assertEquals(web_socket_proxy.__dict__["spawn_message"], True)
def test_poll(self): def test_poll(self):
web_socket_proxy = websocketproxy.WebSocketProxy() web_socket_proxy = self._get_websockproxy()
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
web_socket_proxy.__dict__["wrap_mode"] = "respawn" web_socket_proxy.__dict__["wrap_mode"] = "respawn"
web_socket_proxy.__dict__["wrap_times"] = [99999999] web_socket_proxy.__dict__["wrap_times"] = [99999999]
@ -84,7 +110,7 @@ class WebSocketProxyTest(unittest.TestCase):
self.assertEquals(web_socket_proxy.spawn_message, False) self.assertEquals(web_socket_proxy.spawn_message, False)
def test_new_client(self): def test_new_client(self):
web_socket_proxy = websocketproxy.WebSocketProxy() web_socket_proxy = self._get_websockproxy()
web_socket_proxy.__dict__["verbose"] = "verbose" web_socket_proxy.__dict__["verbose"] = "verbose"
web_socket_proxy.__dict__["daemon"] = None web_socket_proxy.__dict__["daemon"] = None
web_socket_proxy.__dict__["client"] = "client" web_socket_proxy.__dict__["client"] = "client"

View File

@ -4,7 +4,7 @@
# and then run "tox" from this directory. # and then run "tox" from this directory.
[tox] [tox]
envlist = py24, py25, py26, py27, py30 envlist = py24,py25,py26,py27,py30
setupdir = ../ setupdir = ../
[testenv] [testenv]
@ -12,3 +12,9 @@ commands = nosetests {posargs}
deps = deps =
mox mox
nose nose
# At some point we should enable this since tox epdctes it to exist but
# the code will need pep8ising first.
#[testenv:pep8]
#commands = flake8
#dep = flake8

View File

@ -168,7 +168,7 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
return header + buf, len(header), 0 return header + buf, len(header), 0
@staticmethod @staticmethod
def decode_hybi(buf, base64=False): def decode_hybi(buf, base64=False, logger=None):
""" Decode HyBi style WebSocket packets. """ Decode HyBi style WebSocket packets.
Returns: Returns:
{'fin' : 0_or_1, {'fin' : 0_or_1,
@ -192,7 +192,8 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
'close_code' : 1000, 'close_code' : 1000,
'close_reason' : ''} 'close_reason' : ''}
logger = WebSocketServer.get_logger() if logger is None:
logger = WebSocketServer.get_logger()
blen = len(buf) blen = len(buf)
f['left'] = blen f['left'] = blen
@ -232,16 +233,15 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
f['payload'] = WebSocketRequestHandler.unmask(buf, f['hlen'], f['payload'] = WebSocketRequestHandler.unmask(buf, f['hlen'],
f['length']) f['length'])
else: else:
self.vmsg("Unmasked frame: %s" % repr(buf)) logger.debug("Unmasked frame: %s" % repr(buf))
f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len] f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len]
if base64 and f['opcode'] in [1, 2]: if base64 and f['opcode'] in [1, 2]:
try: try:
f['payload'] = b64decode(f['payload']) f['payload'] = b64decode(f['payload'])
except: except:
self.warn("Exception while b64decoding buffer: %s", logger.exception("Exception while b64decoding buffer: %s" %
repr(buf)) (repr(buf)))
self.vmsg('Exception', exc_info=True)
raise raise
if f['opcode'] == 0x08: if f['opcode'] == 0x08:
@ -340,7 +340,8 @@ class WebSocketRequestHandler(SimpleHTTPRequestHandler):
self.recv_part = None self.recv_part = None
while buf: while buf:
frame = self.decode_hybi(buf, base64=self.base64) frame = self.decode_hybi(buf, base64=self.base64,
logger=self.logger)
#self.msg("Received buf: %s, frame: %s", repr(buf), frame) #self.msg("Received buf: %s, frame: %s", repr(buf), frame)
if frame['payload'] == None: if frame['payload'] == None: