From 32c1abd5d9643296e0d30abe2b2ccde324d5abcc Mon Sep 17 00:00:00 2001 From: Edward Hope-Morley Date: Wed, 13 Nov 2013 12:01:05 +0000 Subject: [PATCH] Added temp dir for unit test data and cleanup Unit test data will now go to a temporary dir that will be deleted once the test completes. The unit tests also setup a logger which will persist so that it can be inspected once tests complete. Also fixes a bug where instance var is missing from decode_hybi() Co-authored-by: natsume.takashi@lab.ntt.co.jp --- tests/test_websocket.py | 130 +++++++++++------------------------ tests/test_websocketproxy.py | 44 +++++++++--- tests/tox.ini | 8 ++- websockify/websocket.py | 15 ++-- 4 files changed, 90 insertions(+), 107 deletions(-) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 49efe81..c7a106f 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -17,7 +17,9 @@ """ Unit tests for websocket """ import errno import os +import logging import select +import shutil import socket import ssl import stubout @@ -39,24 +41,44 @@ class MockConnection(object): class WebSocketTestCase(unittest.TestCase): + def _init_logger(self, tmpdir): + name = 'websocket-unittest' + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.propagate = True + filename = "%s.log" % (name) + handler = logging.FileHandler(filename) + handler.setFormatter(logging.Formatter("%(message)s")) + logger.addHandler(handler) + def setUp(self): """Called automatically before each test.""" super(WebSocketTestCase, self).setUp() self.stubs = stubout.StubOutForTesting() - self.server = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='./', - web='./', - record='./', - daemon=True, - ssl_only=False) + # Temporary dir for test data + self.tmpdir = tempfile.mkdtemp() + # Put log somewhere persistent + self._init_logger('./') + # Mock this out cause it screws tests up + self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) + self.server = self._get_websockserver(daemon=True, + ssl_only=False) self.soc = self.server.socket('localhost') def tearDown(self): """Called automatically after each test.""" self.stubs.UnsetAll() + shutil.rmtree(self.tmpdir) super(WebSocketTestCase, self).tearDown() + def _get_websockserver(self, **kwargs): + return websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key=self.tmpdir, + web=self.tmpdir, + record=self.tmpdir, + **kwargs) + def _mock_os_open_oserror(self, file, flags): raise OSError('') @@ -83,28 +105,14 @@ class WebSocketTestCase(unittest.TestCase): sys.exit() def test_daemonize_error(self): - soc = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='../', - web='../', - record='../', - daemon=True, - ssl_only=1, - idle_timeout=1) + soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1) self.stubs.Set(os, 'fork', lambda *args: None) self.stubs.Set(os, 'setsid', lambda *args: None) self.stubs.Set(os, 'close', self._mock_os_close_oserror) self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') def test_daemonize_EBADF_error(self): - soc = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='../', - web='../', - record='../', - daemon=True, - ssl_only=1, - idle_timeout=1) + soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1) self.stubs.Set(os, 'fork', lambda *args: None) self.stubs.Set(os, 'setsid', lambda *args: None) self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF) @@ -112,27 +120,12 @@ class WebSocketTestCase(unittest.TestCase): self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') def test_decode_hybi(self): - soc = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='../', - web='../', - record='../', - daemon=False, - ssl_only=1, - idle_timeout=1) - + soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) self.assertRaises(Exception, soc.decode_hybi, 'a' * 128, base64=True) def test_do_websocket_handshake(self): - soc = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='../', - web='../', - record='../', - daemon=True, - ssl_only=0, - idle_timeout=1) + soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) soc.scheme = 'scheme' headers = {'Sec-WebSocket-Protocol': 'binary', 'Sec-WebSocket-Version': '7', @@ -140,27 +133,13 @@ class WebSocketTestCase(unittest.TestCase): soc.do_websocket_handshake(headers, '127.0.0.1') def test_do_handshake(self): - soc = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='../', - web='../', - record='../', - daemon=True, - ssl_only=0, - idle_timeout=1) + soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) self.stubs.Set(select, 'select', self._mock_select) self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv') self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1') def test_do_handshake_ssl_error(self): - soc = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='../', - web='../', - record='../', - daemon=True, - ssl_only=0, - idle_timeout=1) + soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) def _mock_wrap_socket(*args, **kwargs): from ssl import SSLError @@ -172,25 +151,11 @@ class WebSocketTestCase(unittest.TestCase): self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1') def test_fallback_SIGCHILD(self): - soc = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='../', - web='../', - record='../', - daemon=True, - ssl_only=0, - idle_timeout=1) + soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) soc.fallback_SIGCHLD(None, None) def test_start_server_Exception(self): - soc = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='../', - web='../', - record='../', - daemon=False, - ssl_only=1, - idle_timeout=1) + soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) self.stubs.Set(websocket.WebSocketServer, 'daemonize', lambda *args, **kwargs: None) @@ -198,15 +163,7 @@ class WebSocketTestCase(unittest.TestCase): self.assertEqual(None, soc.start_server()) def test_start_server_KeyboardInterrupt(self): - soc = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='../', - web='../', - record='../', - cert='xxxxxx', - daemon=True, - ssl_only=1, - idle_timeout=1) + soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) self.stubs.Set(websocket.WebSocketServer, 'daemonize', lambda *args, **kwargs: None) @@ -215,19 +172,12 @@ class WebSocketTestCase(unittest.TestCase): def test_start_server_systemexit(self): websocket.ssl = None - soc = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='../', - web='../', - record='../', - daemon=True, - ssl_only=0, - idle_timeout=1, - verbose=True) self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) self.stubs.Set(websocket.WebSocketServer, 'daemonize', lambda *args, **kwargs: None) self.stubs.Set(select, 'select', self._mock_select_systemexit) + soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1, + verbose=True) self.assertEqual(None, soc.start_server()) def test_WSRequestHandle_do_GET_nofile(self): diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index 0197ce5..0fdd0fb 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -15,11 +15,15 @@ # under the License. """ Unit tests for websocketproxy """ -import unittest -import time -import subprocess -import stubout +import os +import logging import select +import shutil +import stubout +import subprocess +import tempfile +import time +import unittest from websockify import websocketproxy @@ -37,20 +41,42 @@ class MockSocket(object): class WebSocketProxyTest(unittest.TestCase): + def _init_logger(self, tmpdir): + name = 'websocket-unittest' + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.propagate = True + filename = "%s.log" % (name) + handler = logging.FileHandler(filename) + handler.setFormatter(logging.Formatter("%(message)s")) + logger.addHandler(handler) + def setUp(self): """Called automatically before each test.""" super(WebSocketProxyTest, self).setUp() - self.soc = '' self.stubs = stubout.StubOutForTesting() + # Temporary dir for test data + self.tmpdir = tempfile.mkdtemp() + # Put log somewhere persistent + self._init_logger('./') + # Mock this out cause it screws tests up + self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) def tearDown(self): """Called automatically after each test.""" self.stubs.UnsetAll() + shutil.rmtree(self.tmpdir) super(WebSocketProxyTest, self).tearDown() + def _get_websockproxy(self, **kwargs): + return websocketproxy.WebSocketProxy(key=self.tmpdir, + web=self.tmpdir, + record=self.tmpdir, + **kwargs) + def test_run_wrap_cmd(self): - web_socket_proxy = websocketproxy.WebSocketProxy() + web_socket_proxy = self._get_websockproxy() web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" def mock_Popen(*args, **kwargs): @@ -61,7 +87,7 @@ class WebSocketProxyTest(unittest.TestCase): self.assertEquals(web_socket_proxy.spawn_message, True) def test_started(self): - web_socket_proxy = websocketproxy.WebSocketProxy() + web_socket_proxy = self._get_websockproxy() web_socket_proxy.__dict__["spawn_message"] = False web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" @@ -73,7 +99,7 @@ class WebSocketProxyTest(unittest.TestCase): self.assertEquals(web_socket_proxy.__dict__["spawn_message"], True) def test_poll(self): - web_socket_proxy = websocketproxy.WebSocketProxy() + web_socket_proxy = self._get_websockproxy() web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" web_socket_proxy.__dict__["wrap_mode"] = "respawn" web_socket_proxy.__dict__["wrap_times"] = [99999999] @@ -84,7 +110,7 @@ class WebSocketProxyTest(unittest.TestCase): self.assertEquals(web_socket_proxy.spawn_message, False) def test_new_client(self): - web_socket_proxy = websocketproxy.WebSocketProxy() + web_socket_proxy = self._get_websockproxy() web_socket_proxy.__dict__["verbose"] = "verbose" web_socket_proxy.__dict__["daemon"] = None web_socket_proxy.__dict__["client"] = "client" diff --git a/tests/tox.ini b/tests/tox.ini index 4f28f3f..098e89c 100644 --- a/tests/tox.ini +++ b/tests/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = py24, py25, py26, py27, py30 +envlist = py24,py25,py26,py27,py30 setupdir = ../ [testenv] @@ -12,3 +12,9 @@ commands = nosetests {posargs} deps = mox nose + +# At some point we should enable this since tox epdctes it to exist but +# the code will need pep8ising first. +#[testenv:pep8] +#commands = flake8 +#dep = flake8 diff --git a/websockify/websocket.py b/websockify/websocket.py index d93a1fc..d5ea96b 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -331,7 +331,7 @@ Sec-WebSocket-Accept: %s\r return header + buf, len(header), 0 @staticmethod - def decode_hybi(buf, base64=False): + def decode_hybi(buf, base64=False, logger=None): """ Decode HyBi style WebSocket packets. Returns: {'fin' : 0_or_1, @@ -355,7 +355,8 @@ Sec-WebSocket-Accept: %s\r 'close_code' : 1000, 'close_reason' : ''} - logger = WebSocketServer.get_logger() + if logger is None: + logger = WebSocketServer.get_logger() blen = len(buf) f['left'] = blen @@ -395,16 +396,15 @@ Sec-WebSocket-Accept: %s\r f['payload'] = WebSocketServer.unmask(buf, f['hlen'], f['length']) else: - self.vmsg("Unmasked frame: %s" % repr(buf)) + logger.debug("Unmasked frame: %s" % repr(buf)) f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len] if base64 and f['opcode'] in [1, 2]: try: f['payload'] = b64decode(f['payload']) except: - self.warn("Exception while b64decoding buffer: %s", - repr(buf)) - self.vmsg('Exception', exc_info=True) + logger.exception("Exception while b64decoding buffer: %s" % + (repr(buf))) raise if f['opcode'] == 0x08: @@ -510,7 +510,8 @@ Sec-WebSocket-Accept: %s\r self.recv_part = None while buf: - frame = self.decode_hybi(buf, base64=self.base64) + frame = self.decode_hybi(buf, base64=self.base64, + logger=self.logger) #self.msg("Received buf: %s, frame: %s", repr(buf), frame) if frame['payload'] == None: