diff --git a/.travis.yml b/.travis.yml index 7cec876..62e88a5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,5 @@ language: python python: - - 2.7 - 3.6 install: diff --git a/test-requirements.txt b/test-requirements.txt index 8e01437..1fc147a 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,5 +1,5 @@ mock nose -jwcrypto;python_version>="2.7" -redis;python_version>="2.7" -simplejson;python_version>="2.7" +jwcrypto +redis +simplejson diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index d8a4916..74ed9a4 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -20,32 +20,20 @@ import sys import unittest import unittest import socket -try: - from mock import patch -except ImportError: - from unittest.mock import patch +from io import StringIO +from io import BytesIO +from unittest.mock import patch + +from jwcrypto import jwt from websockify import websocketproxy from websockify import token_plugins from websockify import auth_plugins -if sys.version_info >= (2,7): - from jwcrypto import jwt - -try: - from StringIO import StringIO - BytesIO = StringIO -except ImportError: - from io import StringIO - from io import BytesIO - class FakeSocket(object): - def __init__(self, data=''): - if isinstance(data, bytes): - self._data = data - else: - self._data = data.encode('latin_1') + def __init__(self, data=b''): + self._data = data def recv(self, amt, flags=None): res = self._data[0:amt] @@ -76,7 +64,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): def setUp(self): super(ProxyRequestHandlerTestCase, self).setUp() self.handler = websocketproxy.ProxyRequestHandler( - FakeSocket(''), "127.0.0.1", FakeServer()) + FakeSocket(), "127.0.0.1", FakeServer()) self.handler.path = "https://localhost:6080/websockify?token=blah" self.handler.headers = None patch('websockify.websockifyserver.WebSockifyServer.socket').start() diff --git a/tests/test_websockifyserver.py b/tests/test_websockifyserver.py index a089f55..4f46cb1 100644 --- a/tests/test_websockifyserver.py +++ b/tests/test_websockifyserver.py @@ -22,42 +22,26 @@ import select import shutil import socket import ssl -try: - from mock import patch, MagicMock, ANY -except ImportError: - from unittest.mock import patch, MagicMock, ANY +from unittest.mock import patch, MagicMock, ANY import sys import tempfile import unittest import socket import signal +from http.server import BaseHTTPRequestHandler +from io import StringIO +from io import BytesIO + from websockify import websockifyserver -try: - from BaseHTTPServer import BaseHTTPRequestHandler -except ImportError: - from http.server import BaseHTTPRequestHandler - -try: - from StringIO import StringIO - BytesIO = StringIO -except ImportError: - from io import StringIO - from io import BytesIO - - - def raise_oserror(*args, **kwargs): raise OSError('fake error') class FakeSocket(object): - def __init__(self, data=''): - if isinstance(data, bytes): - self._data = data - else: - self._data = data.encode('latin_1') + def __init__(self, data=b''): + self._data = data def recv(self, amt, flags=None): res = self._data[0:amt] @@ -99,7 +83,7 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase): def test_normal_get_with_only_upgrade_returns_error(self, send_error): server = self._get_server(web=None) handler = websockifyserver.WebSockifyRequestHandler( - FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server) + FakeSocket(b'GET /tmp.txt HTTP/1.1'), '127.0.0.1', server) handler.do_GET() send_error.assert_called_with(405, ANY) @@ -108,7 +92,7 @@ class WebSockifyRequestHandlerTestCase(unittest.TestCase): def test_list_dir_with_file_only_returns_error(self, send_error): server = self._get_server(file_only=True) handler = websockifyserver.WebSockifyRequestHandler( - FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server) + FakeSocket(b'GET / HTTP/1.1'), '127.0.0.1', server) handler.path = '/' handler.do_GET() @@ -186,7 +170,7 @@ class WebSockifyServerTestCase(unittest.TestCase): def test_handshake_ssl_only_without_ssl_raises_error(self): server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) - sock = FakeSocket('some initial data') + sock = FakeSocket(b'some initial data') def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) @@ -208,7 +192,7 @@ class WebSockifyServerTestCase(unittest.TestCase): handler_class=FakeHandler, daemon=True, ssl_only=0, idle_timeout=1) - sock = FakeSocket('some initial data') + sock = FakeSocket(b'some initial data') def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) @@ -229,7 +213,7 @@ class WebSockifyServerTestCase(unittest.TestCase): server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1, cert='afdsfasdafdsafdsafdsafdas') - sock = FakeSocket("\x16some ssl data") + sock = FakeSocket(b"\x16some ssl data") def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) @@ -242,7 +226,7 @@ class WebSockifyServerTestCase(unittest.TestCase): def test_do_handshake_ssl_error_eof_raises_close_error(self): server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) - sock = FakeSocket("\x16some ssl data") + sock = FakeSocket(b"\x16some ssl data") def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) @@ -264,12 +248,7 @@ class WebSockifyServerTestCase(unittest.TestCase): raise ssl.SSLError(ssl.SSL_ERROR_EOF) patch('select.select').start().side_effect = fake_select - if (hasattr(ssl, 'create_default_context')): - # for recent versions of python - patch('ssl.create_default_context').start().side_effect = fake_create_default_context - else: - # for fallback for old versions of python - patch('ssl.warp_socket').start().side_effect = fake_wrap_socket + patch('ssl.create_default_context').start().side_effect = fake_create_default_context self.assertRaises( websockifyserver.WebSockifyServer.EClose, server.do_handshake, sock, '127.0.0.1') @@ -283,7 +262,7 @@ class WebSockifyServerTestCase(unittest.TestCase): server = self._get_server(handler_class=FakeHandler, daemon=True, idle_timeout=1, ssl_ciphers=test_ciphers) - sock = FakeSocket("\x16some ssl data") + sock = FakeSocket(b"\x16some ssl data") def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) @@ -305,15 +284,9 @@ class WebSockifyServerTestCase(unittest.TestCase): fake_create_default_context.CIPHERS = ciphers_to_set patch('select.select').start().side_effect = fake_select - if (hasattr(ssl, 'create_default_context')): - # for recent versions of python - patch('ssl.create_default_context').start().side_effect = fake_create_default_context - server.do_handshake(sock, '127.0.0.1') - self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers) - else: - # for fallback for old versions of python - # not supperted, nothing to test - pass + patch('ssl.create_default_context').start().side_effect = fake_create_default_context + server.do_handshake(sock, '127.0.0.1') + self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers) def test_do_handshake_ssl_sets_opions(self): test_options = 0xCAFEBEEF @@ -324,7 +297,7 @@ class WebSockifyServerTestCase(unittest.TestCase): server = self._get_server(handler_class=FakeHandler, daemon=True, idle_timeout=1, ssl_options=test_options) - sock = FakeSocket("\x16some ssl data") + sock = FakeSocket(b"\x16some ssl data") def fake_select(rlist, wlist, xlist, timeout=None): return ([sock], [], []) @@ -349,15 +322,9 @@ class WebSockifyServerTestCase(unittest.TestCase): options = property(get_options, set_options) patch('select.select').start().side_effect = fake_select - if (hasattr(ssl, 'create_default_context')): - # for recent versions of python - patch('ssl.create_default_context').start().side_effect = fake_create_default_context - server.do_handshake(sock, '127.0.0.1') - self.assertEqual(fake_create_default_context.OPTIONS, test_options) - else: - # for fallback for old versions of python - # not supperted, nothing to test - pass + patch('ssl.create_default_context').start().side_effect = fake_create_default_context + server.do_handshake(sock, '127.0.0.1') + self.assertEqual(fake_create_default_context.OPTIONS, test_options) def test_fallback_sigchld_handler(self): # TODO(directxman12): implement this diff --git a/tox.ini b/tox.ini index 79f7201..526eff6 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [tox] -envlist = py24,py26,py27,py33,py34 +envlist = py34 [testenv] commands = nosetests {posargs} diff --git a/websockify/auth_plugins.py b/websockify/auth_plugins.py index 2d636c2..36fac52 100644 --- a/websockify/auth_plugins.py +++ b/websockify/auth_plugins.py @@ -1,4 +1,4 @@ -class BasePlugin(object): +class BasePlugin(): def __init__(self, src=None): self.source = src @@ -15,7 +15,7 @@ class AuthenticationError(Exception): if log_msg is None: log_msg = response_msg - super(AuthenticationError, self).__init__('%s %s' % (self.code, log_msg)) + super().__init__('%s %s' % (self.code, log_msg)) class InvalidOriginError(AuthenticationError): @@ -23,13 +23,13 @@ class InvalidOriginError(AuthenticationError): self.expected_origin = expected self.actual_origin = actual - super(InvalidOriginError, self).__init__( + super().__init__( response_msg='Invalid Origin', log_msg="Invalid Origin Header: Expected one of " "%s, got '%s'" % (expected, actual)) -class BasicHTTPAuth(object): +class BasicHTTPAuth(): """Verifies Basic Auth headers. Specify src as username:password""" def __init__(self, src=None): @@ -76,7 +76,7 @@ class BasicHTTPAuth(object): raise AuthenticationError(response_code=401, response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) -class ExpectOrigin(object): +class ExpectOrigin(): def __init__(self, src=None): if src is None: self.source = [] @@ -88,7 +88,7 @@ class ExpectOrigin(object): if origin is None or origin not in self.source: raise InvalidOriginError(expected=self.source, actual=origin) -class ClientCertCNAuth(object): +class ClientCertCNAuth(): """Verifies client by SSL certificate. Specify src as whitespace separated list of common names.""" def __init__(self, src=None): diff --git a/websockify/sysloghandler.py b/websockify/sysloghandler.py index 92ca66f..37ee9dd 100644 --- a/websockify/sysloghandler.py +++ b/websockify/sysloghandler.py @@ -44,7 +44,7 @@ class WebsockifySysLogHandler(handlers.SysLogHandler): self._legacy = True self._head_fmt = self._legacy_head_fmt - handlers.SysLogHandler.__init__(self, address, facility, socktype) + super().__init__(address, facility, socktype) def emit(self, record): diff --git a/websockify/token_plugins.py b/websockify/token_plugins.py index 7a988a3..e03839a 100644 --- a/websockify/token_plugins.py +++ b/websockify/token_plugins.py @@ -1,8 +1,7 @@ -from __future__ import print_function import os import sys -class BasePlugin(object): +class BasePlugin(): def __init__(self, src): self.source = src @@ -15,7 +14,7 @@ class ReadOnlyTokenFile(BasePlugin): # token: host:port # or a directory of such files def __init__(self, *args, **kwargs): - super(ReadOnlyTokenFile, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self._targets = None def _load_targets(self): @@ -57,7 +56,7 @@ class TokenFile(ReadOnlyTokenFile): def lookup(self, token): self._load_targets() - return super(TokenFile, self).lookup(token) + return super().lookup(token) class BaseTokenAPI(BasePlugin): @@ -137,7 +136,7 @@ class JWTTokenApi(BasePlugin): print("package jwcrypto not found, are you sure you've installed it correctly?", file=sys.stderr) return None -class TokenRedis(object): +class TokenRedis(): def __init__(self, src): self._server, self._port = src.split(":") @@ -162,7 +161,7 @@ class TokenRedis(object): class UnixDomainSocketDirectory(BasePlugin): def __init__(self, *args, **kwargs): - super(UnixDomainSocketDirectory, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self._dir_path = os.path.abspath(self.source) def lookup(self, token): diff --git a/websockify/websocket.py b/websockify/websocket.py index a0dd5b8..d1857f3 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -22,6 +22,7 @@ import ssl import struct from base64 import b64encode from hashlib import sha1 +from urllib.parse import urlparse try: import numpy @@ -30,25 +31,10 @@ except ImportError: warnings.warn("no 'numpy' module, HyBi protocol will be slower") numpy = None -# python 3.0 differences -try: - from urllib.parse import urlparse -except ImportError: - from urlparse import urlparse - -# SSLWant*Error is 2.7.9+ -try: - class WebSocketWantReadError(ssl.SSLWantReadError): - pass - class WebSocketWantWriteError(ssl.SSLWantWriteError): - pass -except AttributeError: - class WebSocketWantReadError(OSError): - def __init__(self): - OSError.__init__(self, errno.EWOULDBLOCK) - class WebSocketWantWriteError(OSError): - def __init__(self): - OSError.__init__(self, errno.EWOULDBLOCK) +class WebSocketWantReadError(ssl.SSLWantReadError): + pass +class WebSocketWantWriteError(ssl.SSLWantWriteError): + pass class WebSocket(object): """WebSocket protocol socket like class. @@ -87,11 +73,11 @@ class WebSocket(object): self._state = "new" - self._partial_msg = ''.encode("ascii") + self._partial_msg = b'' - self._recv_buffer = ''.encode("ascii") + self._recv_buffer = b'' self._recv_queue = [] - self._send_buffer = ''.encode("ascii") + self._send_buffer = b'' self._previous_sendmsg = None @@ -166,9 +152,7 @@ class WebSocket(object): self._key = '' for i in range(16): self._key += chr(random.randrange(256)) - if sys.hexversion >= 0x3000000: - self._key = bytes(self._key, "latin-1") - self._key = b64encode(self._key).decode("ascii") + self._key = b64encode(self._key.encode("latin-1")).decode("ascii") path = uri.path if not path: @@ -198,10 +182,10 @@ class WebSocket(object): if not self._recv(): raise Exception("Socket closed unexpectedly") - if self._recv_buffer.find('\r\n\r\n'.encode("ascii")) == -1: + if self._recv_buffer.find(b'\r\n\r\n') == -1: raise WebSocketWantReadError - (request, self._recv_buffer) = self._recv_buffer.split('\r\n'.encode("ascii"), 1) + (request, self._recv_buffer) = self._recv_buffer.split(b'\r\n', 1) request = request.decode("latin-1") words = request.split() @@ -210,7 +194,7 @@ class WebSocket(object): if words[1] != "101": raise Exception("WebSocket request denied: %s" % " ".join(words[1:])) - (headers, self._recv_buffer) = self._recv_buffer.split('\r\n\r\n'.encode("ascii"), 1) + (headers, self._recv_buffer) = self._recv_buffer.split(b'\r\n\r\n', 1) headers = headers.decode('latin-1') + '\r\n' headers = email.message_from_string(headers) @@ -463,7 +447,7 @@ class WebSocket(object): return len(msg) - def ping(self, data=''.encode('ascii')): + def ping(self, data=b''): """Write a ping message to the WebSocket WebSocketWantWriteError can be raised if there is insufficient @@ -488,7 +472,7 @@ class WebSocket(object): self._previous_sendmsg = data raise - def pong(self, data=''.encode('ascii')): + def pong(self, data=b''): """Write a pong message to the WebSocket WebSocketWantWriteError can be raised if there is insufficient @@ -542,7 +526,7 @@ class WebSocket(object): self._sent_close = True - msg = ''.encode('ascii') + msg = b'' if code is not None: msg += struct.pack(">H", code) if reason is not None: @@ -571,16 +555,9 @@ class WebSocket(object): while True: try: data = self.socket.recv(4096) - except (socket.error, OSError): - exc = sys.exc_info()[1] - if hasattr(exc, 'errno'): - err = exc.errno - else: - err = exc[0] - - if err == errno.EWOULDBLOCK: + except OSError as exc: + if exc.errno == errno.EWOULDBLOCK: raise WebSocketWantReadError - raise if len(data) == 0: @@ -637,7 +614,7 @@ class WebSocket(object): if frame["fin"]: msg = self._partial_msg - self._partial_msg = ''.encode("ascii") + self._partial_msg = b'' return msg elif frame["opcode"] == 0x1: self.shutdown(socket.SHUT_RDWR, 1003, "Unsupported: Text frames are not supported") @@ -712,16 +689,9 @@ class WebSocket(object): try: sent = self.socket.send(self._send_buffer) - except (socket.error, OSError): - exc = sys.exc_info()[1] - if hasattr(exc, 'errno'): - err = exc.errno - else: - err = exc[0] - - if err == errno.EWOULDBLOCK: + except OSError as exc: + if exc.errno == errno.EWOULDBLOCK: raise WebSocketWantWriteError - raise self._send_buffer = self._send_buffer[sent:] @@ -747,11 +717,9 @@ class WebSocket(object): def _sendmsg(self, opcode, msg): # Sends a standard data message if self.client: - mask = '' + mask = b'' for i in range(4): - mask += chr(random.randrange(256)) - if sys.hexversion >= 0x3000000: - mask = bytes(mask, "latin-1") + mask += random.randrange(256) frame = self._encode_hybi(opcode, msg, mask) else: frame = self._encode_hybi(opcode, msg) @@ -773,7 +741,7 @@ class WebSocket(object): plen = len(buf) pstart = 0 pend = plen - b = c = ''.encode('ascii') + b = c = b'' if plen >= 4: dtype=numpy.dtype(' 0x3000000: - s2b = lambda s: s.encode('latin_1') -else: - s2b = lambda s: s # No-op - -try: - from http.server import SimpleHTTPRequestHandler -except ImportError: - from SimpleHTTPServer import SimpleHTTPRequestHandler +from http.server import SimpleHTTPRequestHandler # Degraded functionality if these imports are missing for mod, msg in [('ssl', 'TLS/SSL/wss is disabled'), @@ -96,7 +84,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa if self.logger is None: self.logger = WebSockifyServer.get_logger() - SimpleHTTPRequestHandler.__init__(self, req, addr, server) + super().__init__(req, addr, server) def log_message(self, format, *args): self.logger.info("%s - - [%s] %s" % (self.client_address[0], self.log_date_time_string(), format % args)) @@ -212,7 +200,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa self.validate_connection() self.auth_connection() - WebSocketRequestHandlerMixIn.handle_upgrade(self) + super().handle_upgrade() def handle_websocket(self): # Indicate to server that a Websocket upgrade was done @@ -264,13 +252,13 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa if self.only_upgrade: self.send_error(405, "Method Not Allowed") else: - SimpleHTTPRequestHandler.do_GET(self) + super().do_GET() def list_directory(self, path): if self.file_only: self.send_error(404, "No such file") else: - return SimpleHTTPRequestHandler.list_directory(self, path) + return super().list_directory(path) def new_websocket_client(self): """ Do something with a WebSockets client connection. """ @@ -291,13 +279,13 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa if self.only_upgrade: self.send_error(405, "Method Not Allowed") else: - SimpleHTTPRequestHandler.do_HEAD(self) + super().do_HEAD() def finish(self): if self.rec: self.rec.write("'EOF'];\n") self.rec.close() - SimpleHTTPRequestHandler.finish(self) + super().finish() def handle(self): # When using run_once, we have a single process, so @@ -306,14 +294,14 @@ class WebSockifyRequestHandler(WebSocketRequestHandlerMixIn, SimpleHTTPRequestHa if self.run_once: self.handle_one_request() else: - SimpleHTTPRequestHandler.handle(self) + super().handle() def log_request(self, code='-', size='-'): if self.verbose: - SimpleHTTPRequestHandler.log_request(self, code, size) + super().log_request(code, size) -class WebSockifyServer(object): +class WebSockifyServer(): """ WebSockets server class. As an alternative, the standard library SocketServer can be used @@ -553,7 +541,7 @@ class WebSockifyServer(object): if not handshake: raise self.EClose("") - elif handshake[0] in ("\x16", "\x80", 22, 128): + elif handshake[0] in (22, 128): # SSL wrap the connection if not ssl: raise self.EClose("SSL connection but no 'ssl' module") @@ -562,32 +550,21 @@ class WebSockifyServer(object): % self.cert) retsock = None try: - if (hasattr(ssl, 'create_default_context') - and callable(ssl.create_default_context)): - # create new-style SSL wrapping for extended features - context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - if self.ssl_ciphers is not None: - context.set_ciphers(self.ssl_ciphers) - context.options = self.ssl_options - context.load_cert_chain(certfile=self.cert, keyfile=self.key, password=self.key_password) - if self.verify_client: - context.verify_mode = ssl.CERT_REQUIRED - if self.cafile: - context.load_verify_locations(cafile=self.cafile) - else: - context.set_default_verify_paths() - retsock = context.wrap_socket( - sock, - server_side=True) - else: - if self.verify_client: - raise self.EClose("Client certificate verification requested, but this Python is too old.") - # new-style SSL wrapping is not needed, using to old style - retsock = ssl.wrap_socket( - sock, - server_side=True, - certfile=self.cert, - keyfile=self.key) + # create new-style SSL wrapping for extended features + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + if self.ssl_ciphers is not None: + context.set_ciphers(self.ssl_ciphers) + context.options = self.ssl_options + context.load_cert_chain(certfile=self.cert, keyfile=self.key, password=self.key_password) + if self.verify_client: + context.verify_mode = ssl.CERT_REQUIRED + if self.cafile: + context.load_verify_locations(cafile=self.cafile) + else: + context.set_default_verify_paths() + retsock = context.wrap_socket( + sock, + server_side=True) except ssl.SSLError: _, x, _ = sys.exc_info() if x.args[0] == ssl.SSL_ERROR_EOF: @@ -723,10 +700,6 @@ class WebSockifyServer(object): if self.listen_fd != None: lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM) - if sys.hexversion < 0x3000000: - # For python 2 we have to wrap the "raw" socket into a socket object, - # otherwise ssl wrap_socket doesn't work. - lsock = socket.socket(_sock=lsock) else: lsock = self.socket(self.listen_host, self.listen_port, False, self.prefer_ipv6,