websockify/tests/test_websocket.py

294 lines
13 KiB
Python

# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright(c)2013 NTT corp. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
""" Unit tests for websocket """
import errno
import os
import select
import socket
import ssl
import stubout
import sys
import tempfile
import unittest
from ssl import SSLError
from websockify import websocket as websocket
from SimpleHTTPServer import SimpleHTTPRequestHandler
class MockConnection(object):
def __init__(self, path):
self.path = path
def makefile(self, mode='r', bufsize=-1):
return open(self.path, mode, bufsize)
class WebSocketTestCase(unittest.TestCase):
def setUp(self):
"""Called automatically before each test."""
super(WebSocketTestCase, self).setUp()
self.stubs = stubout.StubOutForTesting()
self.server = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='./',
web='./',
record='./',
daemon=True,
ssl_only=False)
self.soc = self.server.socket('localhost')
def tearDown(self):
"""Called automatically after each test."""
self.stubs.UnsetAll()
super(WebSocketTestCase, self).tearDown()
def _mock_os_open_oserror(self, file, flags):
raise OSError('')
def _mock_os_close_oserror(self, fd):
raise OSError('')
def _mock_os_close_oserror_EBADF(self, fd):
raise OSError(errno.EBADF, '')
def _mock_socket(self, *args, **kwargs):
return self.soc
def _mock_select(self, rlist, wlist, xlist, timeout=None):
return '_mock_select'
def _mock_select_exception(self, rlist, wlist, xlist, timeout=None):
raise Exception
def _mock_select_keyboardinterrupt(self, rlist, wlist,
xlist, timeout=None):
raise KeyboardInterrupt
def _mock_select_systemexit(self, rlist, wlist, xlist, timeout=None):
sys.exit()
def test_daemonize_error(self):
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=1,
idle_timeout=1)
self.stubs.Set(os, 'fork', lambda *args: None)
self.stubs.Set(os, 'setsid', lambda *args: None)
self.stubs.Set(os, 'close', self._mock_os_close_oserror)
self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
def test_daemonize_EBADF_error(self):
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=1,
idle_timeout=1)
self.stubs.Set(os, 'fork', lambda *args: None)
self.stubs.Set(os, 'setsid', lambda *args: None)
self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF)
self.stubs.Set(os, 'open', self._mock_os_open_oserror)
self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
def test_decode_hybi(self):
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
daemon=False,
ssl_only=1,
idle_timeout=1)
self.assertRaises(Exception, soc.decode_hybi, 'a' * 128,
base64=True)
def test_do_websocket_handshake(self):
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=0,
idle_timeout=1)
soc.scheme = 'scheme'
headers = {'Sec-WebSocket-Protocol': 'binary',
'Sec-WebSocket-Version': '7',
'Sec-WebSocket-Key': 'foo'}
soc.do_websocket_handshake(headers, '127.0.0.1')
def test_do_handshake(self):
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=0,
idle_timeout=1)
self.stubs.Set(select, 'select', self._mock_select)
self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv')
self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1')
def test_do_handshake_ssl_error(self):
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=0,
idle_timeout=1)
def _mock_wrap_socket(*args, **kwargs):
from ssl import SSLError
raise SSLError('unit test exception')
self.stubs.Set(select, 'select', self._mock_select)
self.stubs.Set(socket._socketobject, 'recv', lambda *args: '\x16')
self.stubs.Set(ssl, 'wrap_socket', _mock_wrap_socket)
self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1')
def test_fallback_SIGCHILD(self):
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=0,
idle_timeout=1)
soc.fallback_SIGCHLD(None, None)
def test_start_server_Exception(self):
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
daemon=False,
ssl_only=1,
idle_timeout=1)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None)
self.stubs.Set(select, 'select', self._mock_select_exception)
self.assertEqual(None, soc.start_server())
def test_start_server_KeyboardInterrupt(self):
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
cert='xxxxxx',
daemon=True,
ssl_only=1,
idle_timeout=1)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None)
self.stubs.Set(select, 'select', self._mock_select_keyboardinterrupt)
self.assertEqual(None, soc.start_server())
def test_start_server_systemexit(self):
websocket.ssl = None
soc = websocket.WebSocketServer(listen_host='localhost',
listen_port=80,
key='../',
web='../',
record='../',
daemon=True,
ssl_only=0,
idle_timeout=1,
verbose=True)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None)
self.stubs.Set(select, 'select', self._mock_select_systemexit)
self.assertEqual(None, soc.start_server())
def test_WSRequestHandle_do_GET_nofile(self):
request = 'GET /tmp.txt HTTP/0.9'
with tempfile.NamedTemporaryFile() as test_file:
test_file.write(request)
test_file.flush()
test_file.seek(0)
con = MockConnection(test_file.name)
soc = websocket.WSRequestHandler(con, "127.0.0.1", file_only=True)
soc.path = ''
soc.headers = {'upgrade': ''}
self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
lambda *args: None)
soc.do_GET()
self.assertEqual(404, soc.last_code)
def test_WSRequestHandle_do_GET_hidden_resource(self):
request = 'GET /tmp.txt HTTP/0.9'
with tempfile.NamedTemporaryFile() as test_file:
test_file.write(request)
test_file.flush()
test_file.seek(0)
con = MockConnection(test_file.name)
soc = websocket.WSRequestHandler(con, '127.0.0.1', no_parent=True)
soc.path = test_file.name + '?'
soc.headers = {'upgrade': ''}
soc.webroot = 'no match startswith'
self.stubs.Set(SimpleHTTPRequestHandler,
'send_response',
lambda *args: None)
soc.do_GET()
self.assertEqual(403, soc.last_code)
def testsocket_set_keepalive_options(self):
keepcnt = 12
keepidle = 34
keepintvl = 56
sock = self.server.socket('localhost',
tcp_keepcnt=keepcnt,
tcp_keepidle=keepidle,
tcp_keepintvl=keepintvl)
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt)
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPIDLE), keepidle)
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPINTVL), keepintvl)
sock = self.server.socket('localhost',
tcp_keepalive=False,
tcp_keepcnt=keepcnt,
tcp_keepidle=keepidle,
tcp_keepintvl=keepintvl)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPIDLE), keepidle)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPINTVL), keepintvl)