401 lines
15 KiB
Python
401 lines
15 KiB
Python
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
|
|
|
# Copyright(c)2013 NTT corp. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
""" Unit tests for websockifyserver """
|
|
import errno
|
|
import os
|
|
import logging
|
|
import select
|
|
import shutil
|
|
import socket
|
|
import ssl
|
|
from unittest.mock import patch, MagicMock, ANY
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
import socket
|
|
import signal
|
|
from http.server import BaseHTTPRequestHandler
|
|
from io import StringIO
|
|
from io import BytesIO
|
|
|
|
from websockify import websockifyserver
|
|
|
|
|
|
def raise_oserror(*args, **kwargs):
    """Raise ``OSError('fake error')`` no matter what arguments are given.

    Meant to be used as a mock ``side_effect``: every positional and keyword
    argument is accepted and ignored. The raised error carries no ``errno``
    (it is ``None``), so it is distinguishable from e.g. ``EBADF``.
    """
    failure = OSError('fake error')
    raise failure
|
|
|
|
|
|
class FakeSocket(object):
    """Minimal in-memory stand-in for a socket used to feed request data.

    Only the parts of the socket API these tests exercise are provided:
    ``recv`` (honouring ``socket.MSG_PEEK``) and ``makefile``.
    """

    def __init__(self, data=b''):
        # Unread payload; consumed by non-peeking recv() calls.
        self._data = data

    def recv(self, amt, flags=None):
        """Return up to ``amt`` bytes, consuming them unless MSG_PEEK is set.

        :param amt: maximum number of bytes to return.
        :param flags: socket flags; ``None`` is treated as ``0`` (the real
            socket default) so a plain ``recv(n)`` does not fail on
            ``None & MSG_PEEK``.
        :returns: the leading slice of the remaining data.
        """
        res = self._data[0:amt]
        if not ((flags or 0) & socket.MSG_PEEK):
            self._data = self._data[amt:]
        return res

    def makefile(self, mode='r', buffsize=None):
        """Return a file-like object over the unread data.

        Binary modes give a ``BytesIO``; text modes decode the bytes as
        Latin-1 so every byte maps to exactly one character.
        """
        if 'b' in mode:
            return BytesIO(self._data)
        return StringIO(self._data.decode('latin_1'))
|
|
|
|
|
|
class WebSockifyRequestHandlerTestCase(unittest.TestCase):
    """Tests for HTTP GET handling in ``WebSockifyRequestHandler``."""

    def setUp(self):
        super(WebSockifyRequestHandlerTestCase, self).setUp()
        self.tmpdir = tempfile.mkdtemp('-websockify-tests')
        # Registered immediately so the directory is removed even if a later
        # setUp step fails (tearDown is skipped in that case). rmtree rather
        # than rmdir so leftover files cannot make cleanup itself fail.
        # Cleanups run after tearDown, i.e. after patch.stopall().
        self.addCleanup(shutil.rmtree, self.tmpdir)
        # Mock this out cause it screws tests up
        patch('os.chdir').start()

    def tearDown(self):
        """Called automatically after each test."""
        patch.stopall()
        super(WebSockifyRequestHandlerTestCase, self).tearDown()

    def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
                    **kwargs):
        """Build a non-daemon server rooted at the per-test temp directory.

        ``web`` defaults to the temp directory; pass ``web=None`` to build a
        server without a web root. Other keyword arguments are forwarded.
        """
        web = kwargs.pop('web', self.tmpdir)
        return websockifyserver.WebSockifyServer(
            handler_class, listen_host='localhost',
            listen_port=80, key=self.tmpdir, web=web,
            record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
            **kwargs)

    @patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error')
    def test_normal_get_with_only_upgrade_returns_error(self, send_error):
        # With no web root configured, a plain GET is answered with 405.
        server = self._get_server(web=None)
        handler = websockifyserver.WebSockifyRequestHandler(
            FakeSocket(b'GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)

        handler.do_GET()
        send_error.assert_called_with(405, ANY)

    @patch('websockify.websockifyserver.WebSockifyRequestHandler.send_error')
    def test_list_dir_with_file_only_returns_error(self, send_error):
        # file_only=True must refuse directory listings with a 404.
        server = self._get_server(file_only=True)
        handler = websockifyserver.WebSockifyRequestHandler(
            FakeSocket(b'GET / HTTP/1.1'), '127.0.0.1', server)

        handler.path = '/'
        handler.do_GET()
        send_error.assert_called_with(404, ANY)
|
|
|
|
|
|
class WebSockifyServerTestCase(unittest.TestCase):
    """Tests for ``WebSockifyServer`` daemonizing, handshakes and sockets."""

    def setUp(self):
        super(WebSockifyServerTestCase, self).setUp()
        self.tmpdir = tempfile.mkdtemp('-websockify-tests')
        # Registered immediately so the directory is removed even if a later
        # setUp step fails (tearDown is skipped in that case). rmtree rather
        # than rmdir so leftover files cannot make cleanup itself fail.
        # Cleanups run after tearDown, i.e. after patch.stopall() has
        # restored any patched os.* functions.
        self.addCleanup(shutil.rmtree, self.tmpdir)
        # Mock this out cause it screws tests up
        patch('os.chdir').start()

    def tearDown(self):
        """Called automatically after each test."""
        patch.stopall()
        super(WebSockifyServerTestCase, self).tearDown()

    def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
                    **kwargs):
        """Build a server rooted at the per-test temp directory.

        Extra keyword arguments (daemon, ssl_only, idle_timeout, ...) are
        forwarded to ``WebSockifyServer``.
        """
        return websockifyserver.WebSockifyServer(
            handler_class, listen_host='localhost',
            listen_port=80, key=self.tmpdir, web=self.tmpdir,
            record=self.tmpdir, **kwargs)

    def test_daemonize_raises_error_while_closing_fds(self):
        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        patch('os.fork').start().return_value = 0
        patch('signal.signal').start()
        patch('os.setsid').start()
        # An error without errno EBADF while closing fds must propagate.
        patch('os.close').start().side_effect = raise_oserror
        self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')

    def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
        def raise_oserror_ebadf(fd):
            raise OSError(errno.EBADF, 'fake error')

        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        patch('os.fork').start().return_value = 0
        patch('signal.signal').start()
        patch('os.setsid').start()
        patch('os.close').start().side_effect = raise_oserror_ebadf
        # EBADF from os.close is tolerated, so the failure only surfaces at
        # the later os.open call.
        patch('os.open').start().side_effect = raise_oserror
        self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')

    def test_handshake_fails_on_not_ready(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([], [], [])

        patch('select.select').start().side_effect = fake_select
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            FakeSocket(), '127.0.0.1')

    def test_empty_handshake_fails(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)

        # Bytes, like every other FakeSocket payload in this module.
        sock = FakeSocket(b'')

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        patch('select.select').start().side_effect = fake_select
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_handshake_policy_request(self):
        # TODO(directxman12): implement
        pass

    def test_handshake_ssl_only_without_ssl_raises_error(self):
        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)

        sock = FakeSocket(b'some initial data')

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        patch('select.select').start().side_effect = fake_select
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_no_ssl(self):
        class FakeHandler(object):
            CALLED = False

            def __init__(self, *args, **kwargs):
                type(self).CALLED = True

        FakeHandler.CALLED = False

        server = self._get_server(
            handler_class=FakeHandler, daemon=True,
            ssl_only=0, idle_timeout=1)

        sock = FakeSocket(b'some initial data')

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        patch('select.select').start().side_effect = fake_select
        self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
        # The second positional argument of assertTrue is the failure
        # message, not an expected value, so it is not passed here.
        self.assertTrue(FakeHandler.CALLED)

    def test_do_handshake_ssl(self):
        # TODO(directxman12): implement this
        pass

    def test_do_handshake_ssl_without_ssl_raises_error(self):
        # TODO(directxman12): implement this
        pass

    def test_do_handshake_ssl_without_cert_raises_error(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
                                  cert='afdsfasdafdsafdsafdsafdas')

        sock = FakeSocket(b"\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        patch('select.select').start().side_effect = fake_select
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_ssl_error_eof_raises_close_error(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)

        sock = FakeSocket(b"\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        class fake_create_default_context():
            def __init__(self, purpose):
                self.verify_mode = None
                self.options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                # Simulate the peer hanging up during the TLS handshake.
                raise ssl.SSLError(ssl.SSL_ERROR_EOF)

        patch('select.select').start().side_effect = fake_select
        patch('ssl.create_default_context').start().side_effect = fake_create_default_context
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_ssl_sets_ciphers(self):
        test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2'

        class FakeHandler(object):
            def __init__(self, *args, **kwargs):
                pass

        server = self._get_server(handler_class=FakeHandler, daemon=True,
                                  idle_timeout=1, ssl_ciphers=test_ciphers)
        sock = FakeSocket(b"\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        class fake_create_default_context():
            # Records the last value passed to set_ciphers().
            CIPHERS = ''

            def __init__(self, purpose):
                self.verify_mode = None
                self.options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                pass

            def set_ciphers(self, ciphers_to_set):
                fake_create_default_context.CIPHERS = ciphers_to_set

        patch('select.select').start().side_effect = fake_select
        patch('ssl.create_default_context').start().side_effect = fake_create_default_context
        server.do_handshake(sock, '127.0.0.1')
        self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers)

    def test_do_handshake_ssl_sets_options(self):
        test_options = 0xCAFEBEEF

        class FakeHandler(object):
            def __init__(self, *args, **kwargs):
                pass

        server = self._get_server(handler_class=FakeHandler, daemon=True,
                                  idle_timeout=1, ssl_options=test_options)
        sock = FakeSocket(b"\x16some ssl data")

        def fake_select(rlist, wlist, xlist, timeout=None):
            return ([sock], [], [])

        class fake_create_default_context(object):
            # Records the last value assigned to the ``options`` property.
            OPTIONS = 0

            def __init__(self, purpose):
                self.verify_mode = None
                self._options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                pass

            def get_options(self):
                return self._options

            def set_options(self, val):
                fake_create_default_context.OPTIONS = val

            options = property(get_options, set_options)

        patch('select.select').start().side_effect = fake_select
        patch('ssl.create_default_context').start().side_effect = fake_create_default_context
        server.do_handshake(sock, '127.0.0.1')
        self.assertEqual(fake_create_default_context.OPTIONS, test_options)

    def test_fallback_sigchld_handler(self):
        # TODO(directxman12): implement this
        pass

    # The start_server tests below patch WebSockifyServer.socket before
    # start_server runs, so no real listening socket is needed. (Creating
    # one beforehand only leaked a file descriptor.)

    def test_start_server_error(self):
        server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)

        def fake_select(rlist, wlist, xlist, timeout=None):
            raise Exception("fake error")

        patch('websockify.websockifyserver.WebSockifyServer.socket').start()
        patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
        patch('select.select').start().side_effect = fake_select
        server.start_server()

    def test_start_server_keyboardinterrupt(self):
        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)

        def fake_select(rlist, wlist, xlist, timeout=None):
            raise KeyboardInterrupt

        patch('websockify.websockifyserver.WebSockifyServer.socket').start()
        patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
        patch('select.select').start().side_effect = fake_select
        server.start_server()

    def test_start_server_systemexit(self):
        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)

        def fake_select(rlist, wlist, xlist, timeout=None):
            sys.exit()

        patch('websockify.websockifyserver.WebSockifyServer.socket').start()
        patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
        patch('select.select').start().side_effect = fake_select
        server.start_server()

    def test_socket_set_keepalive_options(self):
        keepcnt = 12
        keepidle = 34
        keepintvl = 56

        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
        sock = server.socket('localhost',
                             tcp_keepcnt=keepcnt,
                             tcp_keepidle=keepidle,
                             tcp_keepintvl=keepintvl)
        # These are real sockets; close them so the test does not leak fds.
        self.addCleanup(sock.close)

        if hasattr(socket, 'TCP_KEEPCNT'):
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPCNT), keepcnt)
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPIDLE), keepidle)
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPINTVL), keepintvl)

        # With keepalive disabled the tuning values must not be applied.
        sock = server.socket('localhost',
                             tcp_keepalive=False,
                             tcp_keepcnt=keepcnt,
                             tcp_keepidle=keepidle,
                             tcp_keepintvl=keepintvl)
        self.addCleanup(sock.close)

        if hasattr(socket, 'TCP_KEEPCNT'):
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPCNT), keepcnt)
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPIDLE), keepidle)
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPINTVL), keepintvl)
|