diff --git a/tests/test_websockifyserver.py b/tests/test_websockifyserver.py index 63c9449..7ce82da 100644 --- a/tests/test_websockifyserver.py +++ b/tests/test_websockifyserver.py @@ -270,6 +270,7 @@ class WebSockifyServerTestCase(unittest.TestCase): class fake_create_default_context(): def __init__(self, purpose): self.verify_mode = None + self.options = 0 def load_cert_chain(self, certfile, keyfile): pass def set_default_verify_paths(self): @@ -290,6 +291,91 @@ class WebSockifyServerTestCase(unittest.TestCase): websockifyserver.WebSockifyServer.EClose, server.do_handshake, sock, '127.0.0.1') + def test_do_handshake_ssl_sets_ciphers(self): + test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2' + + class FakeHandler(object): + def __init__(self, *args, **kwargs): + pass + + server = self._get_server(handler_class=FakeHandler, daemon=True, + idle_timeout=1, ssl_ciphers=test_ciphers) + sock = FakeSocket("\x16some ssl data") + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + class fake_create_default_context(): + CIPHERS = '' + def __init__(self, purpose): + self.verify_mode = None + self.options = 0 + def load_cert_chain(self, certfile, keyfile): + pass + def set_default_verify_paths(self): + pass + def load_verify_locations(self, cafile): + pass + def wrap_socket(self, *args, **kwargs): + pass + def set_ciphers(self, ciphers_to_set): + fake_create_default_context.CIPHERS = ciphers_to_set + + self.stubs.Set(select, 'select', fake_select) + if (hasattr(ssl, 'create_default_context')): + # for recent versions of python + self.stubs.Set(ssl, 'create_default_context', fake_create_default_context) + server.do_handshake(sock, '127.0.0.1') + self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers) + else: + # for fallback for old versions of python + # not supperted, nothing to test + pass + + def test_do_handshake_ssl_sets_opions(self): + test_options = 0xCAFEBEEF + + class FakeHandler(object): + def __init__(self, *args, **kwargs): + pass + + server = self._get_server(handler_class=FakeHandler, daemon=True, + idle_timeout=1, ssl_options=test_options) + sock = FakeSocket("\x16some ssl data") + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + class fake_create_default_context(object): + OPTIONS = 0 + def __init__(self, purpose): + self.verify_mode = None + self._options = 0 + def load_cert_chain(self, certfile, keyfile): + pass + def set_default_verify_paths(self): + pass + def load_verify_locations(self, cafile): + pass + def wrap_socket(self, *args, **kwargs): + pass + def get_options(self): + return self._options + def set_options(self, val): + fake_create_default_context.OPTIONS = val + options = property(get_options, set_options) + + self.stubs.Set(select, 'select', fake_select) + if (hasattr(ssl, 'create_default_context')): + # for recent versions of python + self.stubs.Set(ssl, 'create_default_context', fake_create_default_context) + server.do_handshake(sock, '127.0.0.1') + self.assertEqual(fake_create_default_context.OPTIONS, test_options) + else: + # for fallback for old versions of python + # not supperted, nothing to test + pass + def test_fallback_sigchld_handler(self): # TODO(directxman12): implement this pass diff --git a/websockify/websockifyserver.py b/websockify/websockifyserver.py index b9787a6..f8dccc4 100644 --- a/websockify/websockifyserver.py +++ b/websockify/websockifyserver.py @@ -346,7 +346,7 @@ class WebSockifyServer(object): file_only=False, run_once=False, timeout=0, idle_timeout=0, traffic=False, tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, - tcp_keepintvl=None): + tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0): # settings self.RequestHandlerClass = RequestHandlerClass @@ -356,6 +356,8 @@ class WebSockifyServer(object): self.listen_port = listen_port self.prefer_ipv6 = source_is_ipv6 self.ssl_only = ssl_only + self.ssl_ciphers = ssl_ciphers + self.ssl_options = ssl_options self.verify_client = verify_client self.daemon = daemon self.run_once = run_once @@ -572,6 +574,9 @@ class WebSockifyServer(object): and callable(ssl.create_default_context)): # create new-style SSL wrapping for extended features context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + if self.ssl_ciphers is not None: + context.set_ciphers(self.ssl_ciphers) + context.options = self.ssl_options context.load_cert_chain(certfile=self.cert, keyfile=self.key) if self.verify_client: context.verify_mode = ssl.CERT_REQUIRED