diff --git a/tests/test_websocket.py b/tests/test_websocket.py index c603189..49efe81 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -14,11 +14,27 @@ # License for the specific language governing permissions and limitations # under the License. -"""Unit tests for websockify.""" - +""" Unit tests for websocket """ +import errno +import os +import select import socket +import ssl +import stubout +import sys +import tempfile import unittest +from ssl import SSLError from websockify import websocket as websocket +from SimpleHTTPServer import SimpleHTTPRequestHandler + + +class MockConnection(object): + def __init__(self, path): + self.path = path + + def makefile(self, mode='r', bufsize=-1): + return open(self.path, mode, bufsize) class WebSocketTestCase(unittest.TestCase): @@ -26,27 +42,235 @@ class WebSocketTestCase(unittest.TestCase): def setUp(self): """Called automatically before each test.""" super(WebSocketTestCase, self).setUp() + self.stubs = stubout.StubOutForTesting() + self.server = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='./', + web='./', + record='./', + daemon=True, + ssl_only=False) + self.soc = self.server.socket('localhost') def tearDown(self): """Called automatically after each test.""" + self.stubs.UnsetAll() super(WebSocketTestCase, self).tearDown() + def _mock_os_open_oserror(self, file, flags): + raise OSError('') + + def _mock_os_close_oserror(self, fd): + raise OSError('') + + def _mock_os_close_oserror_EBADF(self, fd): + raise OSError(errno.EBADF, '') + + def _mock_socket(self, *args, **kwargs): + return self.soc + + def _mock_select(self, rlist, wlist, xlist, timeout=None): + return '_mock_select' + + def _mock_select_exception(self, rlist, wlist, xlist, timeout=None): + raise Exception + + def _mock_select_keyboardinterrupt(self, rlist, wlist, + xlist, timeout=None): + raise KeyboardInterrupt + + def _mock_select_systemexit(self, rlist, wlist, xlist, timeout=None): + sys.exit() + + def test_daemonize_error(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=1, + idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: None) + self.stubs.Set(os, 'setsid', lambda *args: None) + self.stubs.Set(os, 'close', self._mock_os_close_oserror) + self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') + + def test_daemonize_EBADF_error(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=1, + idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: None) + self.stubs.Set(os, 'setsid', lambda *args: None) + self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF) + self.stubs.Set(os, 'open', self._mock_os_open_oserror) + self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') + + def test_decode_hybi(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=False, + ssl_only=1, + idle_timeout=1) + + self.assertRaises(Exception, soc.decode_hybi, 'a' * 128, + base64=True) + + def test_do_websocket_handshake(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=0, + idle_timeout=1) + soc.scheme = 'scheme' + headers = {'Sec-WebSocket-Protocol': 'binary', + 'Sec-WebSocket-Version': '7', + 'Sec-WebSocket-Key': 'foo'} + soc.do_websocket_handshake(headers, '127.0.0.1') + + def test_do_handshake(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=0, + idle_timeout=1) + self.stubs.Set(select, 'select', self._mock_select) + self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv') + self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1') + + def test_do_handshake_ssl_error(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=0, + idle_timeout=1) + + def _mock_wrap_socket(*args, **kwargs): + from ssl import SSLError + raise SSLError('unit test exception') + + self.stubs.Set(select, 'select', self._mock_select) + self.stubs.Set(socket._socketobject, 'recv', lambda *args: '\x16') + self.stubs.Set(ssl, 'wrap_socket', _mock_wrap_socket) + self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1') + + def test_fallback_SIGCHILD(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=0, + idle_timeout=1) + soc.fallback_SIGCHLD(None, None) + + def test_start_server_Exception(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=False, + ssl_only=1, + idle_timeout=1) + self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + self.stubs.Set(websocket.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', self._mock_select_exception) + self.assertEqual(None, soc.start_server()) + + def test_start_server_KeyboardInterrupt(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + cert='xxxxxx', + daemon=True, + ssl_only=1, + idle_timeout=1) + self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + self.stubs.Set(websocket.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', self._mock_select_keyboardinterrupt) + self.assertEqual(None, soc.start_server()) + + def test_start_server_systemexit(self): + websocket.ssl = None + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=0, + idle_timeout=1, + verbose=True) + self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + self.stubs.Set(websocket.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', self._mock_select_systemexit) + self.assertEqual(None, soc.start_server()) + + def test_WSRequestHandle_do_GET_nofile(self): + request = 'GET /tmp.txt HTTP/0.9' + with tempfile.NamedTemporaryFile() as test_file: + test_file.write(request) + test_file.flush() + test_file.seek(0) + con = MockConnection(test_file.name) + soc = websocket.WSRequestHandler(con, "127.0.0.1", file_only=True) + soc.path = '' + soc.headers = {'upgrade': ''} + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + lambda *args: None) + soc.do_GET() + self.assertEqual(404, soc.last_code) + + def test_WSRequestHandle_do_GET_hidden_resource(self): + request = 'GET /tmp.txt HTTP/0.9' + with tempfile.NamedTemporaryFile() as test_file: + test_file.write(request) + test_file.flush() + test_file.seek(0) + con = MockConnection(test_file.name) + soc = websocket.WSRequestHandler(con, '127.0.0.1', no_parent=True) + soc.path = test_file.name + '?' + soc.headers = {'upgrade': ''} + soc.webroot = 'no match startswith' + self.stubs.Set(SimpleHTTPRequestHandler, + 'send_response', + lambda *args: None) + soc.do_GET() + self.assertEqual(403, soc.last_code) + def testsocket_set_keepalive_options(self): - server = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='./', - web='./', - record='./', - daemon=True, - ssl_only=1) keepcnt = 12 keepidle = 34 keepintvl = 56 - sock = server.socket('localhost', - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) + sock = self.server.socket('localhost', + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt) @@ -55,11 +279,11 @@ class WebSocketTestCase(unittest.TestCase): self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl) - sock = server.socket('localhost', - tcp_keepalive=False, - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) + sock = self.server.socket('localhost', + tcp_keepalive=False, + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt) diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py new file mode 100644 index 0000000..0197ce5 --- /dev/null +++ b/tests/test_websocketproxy.py @@ -0,0 +1,101 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright(c)2013 NTT corp. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" Unit tests for websocketproxy """ +import unittest +import time +import subprocess +import stubout +import select + +from websockify import websocketproxy + + +class MockSocket(object): + def __init__(*args, **kwargs): + pass + + def shutdown(*args): + pass + + def close(*args): + pass + + +class WebSocketProxyTest(unittest.TestCase): + + def setUp(self): + """Called automatically before each test.""" + super(WebSocketProxyTest, self).setUp() + + self.soc = '' + self.stubs = stubout.StubOutForTesting() + + def tearDown(self): + """Called automatically after each test.""" + self.stubs.UnsetAll() + super(WebSocketProxyTest, self).tearDown() + + def test_run_wrap_cmd(self): + web_socket_proxy = websocketproxy.WebSocketProxy() + web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" + + def mock_Popen(*args, **kwargs): + return '_mock_cmd' + + self.stubs.Set(subprocess, 'Popen', mock_Popen) + web_socket_proxy.run_wrap_cmd() + self.assertEquals(web_socket_proxy.spawn_message, True) + + def test_started(self): + web_socket_proxy = websocketproxy.WebSocketProxy() + web_socket_proxy.__dict__["spawn_message"] = False + web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" + + def mock_run_wrap_cmd(*args, **kwargs): + web_socket_proxy.__dict__["spawn_message"] = True + + self.stubs.Set(web_socket_proxy, 'run_wrap_cmd', mock_run_wrap_cmd) + web_socket_proxy.started() + self.assertEquals(web_socket_proxy.__dict__["spawn_message"], True) + + def test_poll(self): + web_socket_proxy = websocketproxy.WebSocketProxy() + web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" + web_socket_proxy.__dict__["wrap_mode"] = "respawn" + web_socket_proxy.__dict__["wrap_times"] = [99999999] + web_socket_proxy.__dict__["spawn_message"] = True + web_socket_proxy.__dict__["cmd"] = None + self.stubs.Set(time, 'time', lambda: 100000000.000) + web_socket_proxy.poll() + self.assertEquals(web_socket_proxy.spawn_message, False) + + def test_new_client(self): + web_socket_proxy = websocketproxy.WebSocketProxy() + web_socket_proxy.__dict__["verbose"] = "verbose" + web_socket_proxy.__dict__["daemon"] = None + web_socket_proxy.__dict__["client"] = "client" + + self.stubs.Set(web_socket_proxy, 'socket', MockSocket) + + def mock_select(*args, **kwargs): + ins = None + outs = None + excepts = "excepts" + return ins, outs, excepts + + self.stubs.Set(select, 'select', mock_select) + self.assertRaises(Exception, web_socket_proxy.new_client) diff --git a/tests/tox.ini b/tests/tox.ini new file mode 100644 index 0000000..4f28f3f --- /dev/null +++ b/tests/tox.ini @@ -0,0 +1,14 @@ +# Tox (http://tox.testrun.org/) is a tool for running tests +# in multiple virtualenvs. This configuration file will run the +# test suite on all supported python versions. To use it, "pip install tox" +# and then run "tox" from this directory. + +[tox] +envlist = py24, py25, py26, py27, py30 +setupdir = ../ + +[testenv] +commands = nosetests {posargs} +deps = + mox + nose diff --git a/websockify/websocket.py b/websockify/websocket.py index 57b6b8d..d93a1fc 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -59,6 +59,7 @@ for mod, msg in [('numpy', 'HyBi protocol will be slower'), except ImportError: globals()[mod] = None print("WARNING: no '%s' module, %s" % (mod, msg)) + if multiprocessing and sys.platform == 'win32': # make sockets pickle-able/inheritable import multiprocessing.reduction