websockify/tests/test_websockifyserver.py

434 lines
16 KiB
Python

# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright(c)2013 NTT corp. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
""" Unit tests for websockifyserver """
import errno
import os
import logging
import select
import shutil
import socket
import ssl
try:
from mock import patch, MagicMock, ANY
except ImportError:
from unittest.mock import patch, MagicMock, ANY
import sys
import tempfile
import unittest
import socket
import signal
from websockify import websockifyserver
try:
from BaseHTTPServer import BaseHTTPRequestHandler
except ImportError:
from http.server import BaseHTTPRequestHandler
try:
from StringIO import StringIO
BytesIO = StringIO
except ImportError:
from io import StringIO
from io import BytesIO
def raise_oserror(*args, **kwargs):
    """Raise ``OSError('fake error')`` regardless of the arguments given.

    Accepts any positional and keyword arguments so it can be installed as a
    mock ``side_effect`` for callables of any signature.
    """
    fake_failure = OSError('fake error')
    raise fake_failure
class FakeSocket(object):
    """In-memory stand-in for a socket, backed by a bytes buffer.

    Only the methods used by the tests in this file are provided:
    ``recv`` (with ``MSG_PEEK`` support) and ``makefile``.
    """

    def __init__(self, data=''):
        """Store ``data`` as the bytes that this socket will hand out.

        ``data`` may be bytes (stored as-is) or text (encoded as latin-1).
        """
        if isinstance(data, bytes):
            self._data = data
        else:
            self._data = data.encode('latin_1')

    def recv(self, amt, flags=None):
        """Return up to ``amt`` bytes from the front of the buffer.

        The returned bytes are consumed unless ``flags`` contains
        ``socket.MSG_PEEK``.  ``flags=None`` (the default) means "no flags";
        previously the default made ``None & MSG_PEEK`` raise ``TypeError``.
        """
        res = self._data[0:amt]
        if not ((flags or 0) & socket.MSG_PEEK):
            self._data = self._data[amt:]
        return res

    def makefile(self, mode='r', buffsize=None):
        """Return a file-like object over the data currently buffered.

        A binary ``mode`` (containing ``'b'``) yields a ``BytesIO``; any other
        mode yields a ``StringIO`` of the latin-1 decoded data.  ``buffsize``
        is accepted for signature compatibility and ignored.
        """
        if 'b' in mode:
            return BytesIO(self._data)
        else:
            return StringIO(self._data.decode('latin_1'))
class WebSockifyRequestHandlerTestCase(unittest.TestCase):
    """Tests for how WebSockifyRequestHandler answers plain HTTP GET requests.

    Each test gets a private temporary directory (``self.tmpdir``) that is
    handed to the server as its key, web and record path.
    """

    def setUp(self):
        """Create the temp dir and mock ``os.chdir``; register their cleanup."""
        super(WebSockifyRequestHandlerTestCase, self).setUp()
        self.tmpdir = tempfile.mkdtemp('-websockify-tests')
        # Cleanups run in reverse order of registration: patches are stopped
        # first, then the directory is removed.  rmtree (rather than rmdir)
        # keeps a stray file in the directory from failing an otherwise
        # passing test during cleanup.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.addCleanup(patch.stopall)
        # Keep anything under test from changing the working directory of
        # the test process.
        patch('os.chdir').start()

    def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
                    **kwargs):
        """Build a non-daemon, non-SSL-only WebSockifyServer for a test.

        ``web`` defaults to ``self.tmpdir`` but may be overridden (including
        with ``None``); remaining keyword arguments are passed through to the
        ``WebSockifyServer`` constructor.
        """
        web = kwargs.pop('web', self.tmpdir)
        return websockifyserver.WebSockifyServer(
            handler_class, listen_host='localhost',
            listen_port=80, key=self.tmpdir, web=web,
            record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
            **kwargs)

    @patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error')
    def test_normal_get_with_only_upgrade_returns_error(self, send_error):
        """A plain GET against a server created with ``web=None`` yields 405."""
        server = self._get_server(web=None)
        handler = websockifyserver.WebSockifyRequestHandler(
            FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)
        handler.do_GET()
        send_error.assert_called_with(405, ANY)

    @patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error')
    def test_list_dir_with_file_only_returns_error(self, send_error):
        """Requesting the directory ``/`` with ``file_only=True`` yields 404."""
        server = self._get_server(file_only=True)
        handler = websockifyserver.WebSockifyRequestHandler(
            FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server)
        handler.path = '/'
        handler.do_GET()
        send_error.assert_called_with(404, ANY)
class WebSockifyServerTestCase(unittest.TestCase):
    """Tests for WebSockifyServer: daemonizing, handshakes, serving, sockets.

    Each test gets a private temporary directory (``self.tmpdir``) that is
    handed to the server as its key, web and record path.  Patches started
    with ``patch(...).start()`` inside a test are stopped automatically by
    the ``patch.stopall`` cleanup registered in ``setUp``.
    """

    def setUp(self):
        """Create the temp dir and mock ``os.chdir``; register their cleanup."""
        super(WebSockifyServerTestCase, self).setUp()
        self.tmpdir = tempfile.mkdtemp('-websockify-tests')
        # Cleanups run in reverse order of registration: patches are stopped
        # first, then the directory is removed.  rmtree (rather than rmdir)
        # keeps a stray file in the directory from failing an otherwise
        # passing test during cleanup.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)
        self.addCleanup(patch.stopall)
        # Keep anything under test from changing the working directory of
        # the test process.
        patch('os.chdir').start()

    def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
                    **kwargs):
        """Build a WebSockifyServer rooted at ``self.tmpdir``.

        All keyword arguments are passed through to the constructor.
        """
        return websockifyserver.WebSockifyServer(
            handler_class, listen_host='localhost',
            listen_port=80, key=self.tmpdir, web=self.tmpdir,
            record=self.tmpdir, **kwargs)

    def _patch_select(self, readable):
        """Patch ``select.select`` to always report ``readable`` as readable.

        The write and exception lists are always reported empty.  A fresh
        list is returned on every call so callers cannot mutate shared state.
        """
        def fake_select(rlist, wlist, xlist, timeout=None):
            return (list(readable), [], [])

        patch('select.select').start().side_effect = fake_select

    def test_daemonize_raises_error_while_closing_fds(self):
        """A generic OSError from ``os.close`` propagates out of daemonize."""
        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        patch('os.fork').start().return_value = 0
        patch('signal.signal').start()
        patch('os.setsid').start()
        patch('os.close').start().side_effect = raise_oserror
        self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')

    def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
        """EBADF from ``os.close`` is tolerated; the later ``os.open`` error is not."""
        def raise_oserror_ebadf(fd):
            raise OSError(errno.EBADF, 'fake error')

        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        patch('os.fork').start().return_value = 0
        patch('signal.signal').start()
        patch('os.setsid').start()
        patch('os.close').start().side_effect = raise_oserror_ebadf
        # The OSError asserted below comes from os.open, which proves that
        # daemonize got past the EBADF failures raised by os.close.
        patch('os.open').start().side_effect = raise_oserror
        self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')

    def test_handshake_fails_on_not_ready(self):
        """do_handshake raises EClose when select reports nothing readable."""
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
        self._patch_select([])
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            FakeSocket(), '127.0.0.1')

    def test_empty_handshake_fails(self):
        """do_handshake raises EClose when the readable socket has no data."""
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
        sock = FakeSocket('')
        self._patch_select([sock])
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_handshake_policy_request(self):
        # TODO(directxman12): implement
        pass

    def test_handshake_ssl_only_without_ssl_raises_error(self):
        """With ``ssl_only=1`` a non-SSL first byte makes do_handshake raise EClose."""
        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        sock = FakeSocket('some initial data')
        self._patch_select([sock])
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_no_ssl(self):
        """A plain handshake returns the same socket and instantiates the handler."""
        class FakeHandler(object):
            CALLED = False

            def __init__(self, *args, **kwargs):
                type(self).CALLED = True

        FakeHandler.CALLED = False
        server = self._get_server(
            handler_class=FakeHandler, daemon=True,
            ssl_only=0, idle_timeout=1)
        sock = FakeSocket('some initial data')
        self._patch_select([sock])
        self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
        # assertTrue's second argument is the failure message, not an
        # expected value, so only the flag itself is passed.
        self.assertTrue(FakeHandler.CALLED)

    def test_do_handshake_ssl(self):
        # TODO(directxman12): implement this
        pass

    def test_do_handshake_ssl_without_ssl_raises_error(self):
        # TODO(directxman12): implement this
        pass

    def test_do_handshake_ssl_without_cert_raises_error(self):
        """SSL-looking data with a bogus ``cert`` path makes do_handshake raise EClose."""
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
                                  cert='afdsfasdafdsafdsafdsafdas')
        # "\x16" as the first byte makes the payload look like SSL data.
        sock = FakeSocket("\x16some ssl data")
        self._patch_select([sock])
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_ssl_error_eof_raises_close_error(self):
        """An ``SSL_ERROR_EOF`` while wrapping the socket surfaces as EClose."""
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
        sock = FakeSocket("\x16some ssl data")

        def fake_wrap_socket(*args, **kwargs):
            raise ssl.SSLError(ssl.SSL_ERROR_EOF)

        class fake_create_default_context(object):
            """SSL context stand-in whose ``wrap_socket`` always fails with EOF."""

            def __init__(self, purpose):
                self.verify_mode = None
                self.options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                raise ssl.SSLError(ssl.SSL_ERROR_EOF)

        self._patch_select([sock])
        if hasattr(ssl, 'create_default_context'):
            # for recent versions of python
            patch('ssl.create_default_context').start().side_effect = fake_create_default_context
        else:
            # fallback for old versions of python
            # (this previously patched the misspelt, non-existent
            # 'ssl.warp_socket', which fails with AttributeError)
            patch('ssl.wrap_socket').start().side_effect = fake_wrap_socket
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_ssl_sets_ciphers(self):
        """The server's ``ssl_ciphers`` value is passed to ``set_ciphers``."""
        test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2'

        class FakeHandler(object):
            def __init__(self, *args, **kwargs):
                pass

        server = self._get_server(handler_class=FakeHandler, daemon=True,
                                  idle_timeout=1, ssl_ciphers=test_ciphers)
        sock = FakeSocket("\x16some ssl data")

        class fake_create_default_context(object):
            """SSL context stand-in that records the ciphers set on it."""

            CIPHERS = ''

            def __init__(self, purpose):
                self.verify_mode = None
                self.options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                pass

            def set_ciphers(self, ciphers_to_set):
                # Recorded on the class so the test can read it afterwards.
                fake_create_default_context.CIPHERS = ciphers_to_set

        self._patch_select([sock])
        if not hasattr(ssl, 'create_default_context'):
            # fallback for old versions of python:
            # not supported, nothing to test
            return
        # for recent versions of python
        patch('ssl.create_default_context').start().side_effect = fake_create_default_context
        server.do_handshake(sock, '127.0.0.1')
        self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers)

    # NOTE: the method name keeps its historical "opions" spelling so the
    # test id stays unchanged.
    def test_do_handshake_ssl_sets_opions(self):
        """The server's ``ssl_options`` value is applied to the context's ``options``."""
        test_options = 0xCAFEBEEF

        class FakeHandler(object):
            def __init__(self, *args, **kwargs):
                pass

        server = self._get_server(handler_class=FakeHandler, daemon=True,
                                  idle_timeout=1, ssl_options=test_options)
        sock = FakeSocket("\x16some ssl data")

        class fake_create_default_context(object):
            """SSL context stand-in that records writes to ``options``."""

            OPTIONS = 0

            def __init__(self, purpose):
                self.verify_mode = None
                self._options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                pass

            def get_options(self):
                return self._options

            def set_options(self, val):
                # Recorded on the class so the test can read it afterwards;
                # reads keep returning the initial per-instance value.
                fake_create_default_context.OPTIONS = val

            options = property(get_options, set_options)

        self._patch_select([sock])
        if not hasattr(ssl, 'create_default_context'):
            # fallback for old versions of python:
            # not supported, nothing to test
            return
        # for recent versions of python
        patch('ssl.create_default_context').start().side_effect = fake_create_default_context
        server.do_handshake(sock, '127.0.0.1')
        self.assertEqual(fake_create_default_context.OPTIONS, test_options)

    def test_fallback_sigchld_handler(self):
        # TODO(directxman12): implement this
        pass

    def test_start_server_error(self):
        """start_server returns normally when select raises a generic Exception."""
        server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
        # Created before WebSockifyServer.socket is mocked out below.  The
        # result is presumably a real socket (another test in this class
        # calls getsockopt on it), so close it instead of leaking it.
        sock = server.socket('localhost')
        self.addCleanup(sock.close)

        def fake_select(rlist, wlist, xlist, timeout=None):
            raise Exception("fake error")

        patch('websockify.websockifyserver.WebSockifyServer.socket').start()
        patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
        patch('select.select').start().side_effect = fake_select
        server.start_server()

    def test_start_server_keyboardinterrupt(self):
        """start_server returns normally when select raises KeyboardInterrupt."""
        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
        # See test_start_server_error: close the socket rather than leak it.
        sock = server.socket('localhost')
        self.addCleanup(sock.close)

        def fake_select(rlist, wlist, xlist, timeout=None):
            raise KeyboardInterrupt

        patch('websockify.websockifyserver.WebSockifyServer.socket').start()
        patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
        patch('select.select').start().side_effect = fake_select
        server.start_server()

    def test_start_server_systemexit(self):
        """start_server returns normally when select triggers SystemExit."""
        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
        # See test_start_server_error: close the socket rather than leak it.
        sock = server.socket('localhost')
        self.addCleanup(sock.close)

        def fake_select(rlist, wlist, xlist, timeout=None):
            sys.exit()

        patch('websockify.websockifyserver.WebSockifyServer.socket').start()
        patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
        patch('select.select').start().side_effect = fake_select
        server.start_server()

    def test_socket_set_keepalive_options(self):
        """Keepalive tunables are applied to the socket unless ``tcp_keepalive=False``."""
        keepcnt = 12
        keepidle = 34
        keepintvl = 56

        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)

        # Keepalive left at its default: the requested values must be set.
        sock = server.socket('localhost',
                             tcp_keepcnt=keepcnt,
                             tcp_keepidle=keepidle,
                             tcp_keepintvl=keepintvl)
        self.addCleanup(sock.close)

        if hasattr(socket, 'TCP_KEEPCNT'):
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPCNT), keepcnt)
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPIDLE), keepidle)
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPINTVL), keepintvl)

        # Keepalive explicitly disabled: the requested values must NOT be set.
        sock = server.socket('localhost',
                             tcp_keepalive=False,
                             tcp_keepcnt=keepcnt,
                             tcp_keepidle=keepidle,
                             tcp_keepintvl=keepintvl)
        self.addCleanup(sock.close)

        if hasattr(socket, 'TCP_KEEPCNT'):
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPCNT), keepcnt)
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPIDLE), keepidle)
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPINTVL), keepintvl)