diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index 7eede27..81c119e 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -139,8 +139,8 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): self.handler.server.target_port = "someport" self.assertRaises(auth_plugins.AuthenticationError, - self.handler.validate_connection) + self.handler.auth_connection) self.handler.server.target_host = "someotherhost" - self.handler.validate_connection() + self.handler.auth_connection() diff --git a/websockify/auth_plugins.py b/websockify/auth_plugins.py index 7ed3c2a..ed3a169 100644 --- a/websockify/auth_plugins.py +++ b/websockify/auth_plugins.py @@ -40,29 +40,28 @@ class BasicHTTPAuth(object): auth_header = headers.get('Authorization') if auth_header: if not auth_header.startswith('Basic '): - raise AuthenticationError(response_code=403) + self.auth_error() try: user_pass_raw = base64.b64decode(auth_header[6:]) except TypeError: - raise AuthenticationError(response_code=403) + self.auth_error() try: # http://stackoverflow.com/questions/7242316/what-encoding-should-i-use-for-http-basic-authentication user_pass_as_text = user_pass_raw.decode('ISO-8859-1') except UnicodeDecodeError: - raise AuthenticationError(response_code=403) + self.auth_error() user_pass = user_pass_as_text.split(':', 1) if len(user_pass) != 2: - raise AuthenticationError(response_code=403) + self.auth_error() if not self.validate_creds(*user_pass): - raise AuthenticationError(response_code=403) + self.demand_auth() else: - raise AuthenticationError(response_code=401, - response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) + self.demand_auth() def validate_creds(self, username, password): if '%s:%s' % (username, password) == self.src: @@ -70,6 +69,13 @@ class BasicHTTPAuth(object): else: return False + def auth_error(self): + raise AuthenticationError(response_code=403) + + def demand_auth(self): + raise AuthenticationError(response_code=401, + response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) + class ExpectOrigin(object): def __init__(self, src=None): if src is None: diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index 6c17c10..6ab21fa 100644 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -56,38 +56,42 @@ Traffic Legend: self.end_headers() def validate_connection(self): - if self.server.token_plugin: - host, port = self.get_target(self.server.token_plugin, self.path) - if host == 'unix_socket': - self.server.unix_target = port + if not self.server.token_plugin: + return - else: - self.server.target_host = host - self.server.target_port = port + host, port = self.get_target(self.server.token_plugin) + if host == 'unix_socket': + self.server.unix_target = port - if self.server.auth_plugin: + else: + self.server.target_host = host + self.server.target_port = port + + def auth_connection(self): + if not self.server.auth_plugin: + return + + try: + # get client certificate data + client_cert_data = self.request.getpeercert() + # extract subject information + client_cert_subject = client_cert_data['subject'] + # flatten data structure + client_cert_subject = dict([x[0] for x in client_cert_subject]) + # add common name to headers (apache +StdEnvVars style) + self.headers['SSL_CLIENT_S_DN_CN'] = client_cert_subject['commonName'] + except (TypeError, AttributeError, KeyError): + # not a SSL connection or client presented no certificate with valid data + pass - try: - # get client certificate data - client_cert_data = self.request.getpeercert() - # extract subject information - client_cert_subject = client_cert_data['subject'] - # flatten data structure - client_cert_subject = dict([x[0] for x in client_cert_subject]) - # add common name to headers (apache +StdEnvVars style) - self.headers['SSL_CLIENT_S_DN_CN'] = client_cert_subject['commonName'] - except (TypeError, AttributeError, KeyError): - # not a SSL connection or client presented no certificate with valid data - pass - - try: - self.server.auth_plugin.authenticate( - headers=self.headers, target_host=self.server.target_host, - target_port=self.server.target_port) - except auth.AuthenticationError: - ex = sys.exc_info()[1] - self.send_auth_error(ex) - raise + try: + self.server.auth_plugin.authenticate( + headers=self.headers, target_host=self.server.target_host, + target_port=self.server.target_port) + except auth.AuthenticationError: + ex = sys.exc_info()[1] + self.send_auth_error(ex) + raise def new_websocket_client(self): """ @@ -424,6 +428,8 @@ def websockify_init(): help="inetd mode, receive listening socket from stdin", action="store_true") parser.add_option("--web", default=None, metavar="DIR", help="run webserver on same port. Serve files from DIR.") + parser.add_option("--web-auth", action="store_true", + help="require authentication to access webserver.") parser.add_option("--wrap-mode", default="exit", metavar="MODE", choices=["exit", "ignore", "respawn"], help="action to take when the wrapped program exits " @@ -479,6 +485,12 @@ def websockify_init(): if opts.auth_source and not opts.auth_plugin: parser.error("You must use --auth-plugin to use --auth-source") + if opts.web_auth and not opts.auth_plugin: + parser.error("You must use --auth-plugin to use --web-auth") + + if opts.web_auth and not opts.web: + parser.error("You must use --web to use --web-auth") + # Transform to absolute path as daemon may chdir if opts.target_cfg: diff --git a/websockify/websockifyserver.py b/websockify/websockifyserver.py index 5e20128..e8c7ae5 100644 --- a/websockify/websockifyserver.py +++ b/websockify/websockifyserver.py @@ -92,6 +92,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandler, SimpleHTTPRequestHandler self.handler_id = getattr(server, "handler_id", False) self.file_only = getattr(server, "file_only", False) self.traffic = getattr(server, "traffic", False) + self.web_auth = getattr(server, "web_auth", False) self.logger = getattr(server, "logger", None) if self.logger is None: @@ -217,6 +218,7 @@ class WebSockifyRequestHandler(WebSocketRequestHandler, SimpleHTTPRequestHandler def handle_upgrade(self): # ensure connection is authorized, and determine the target self.validate_connection() + self.auth_connection() WebSocketRequestHandler.handle_upgrade(self) @@ -263,6 +265,10 @@ class WebSockifyRequestHandler(WebSocketRequestHandler, SimpleHTTPRequestHandler self.send_close(exc.args[0], exc.args[1]) def do_GET(self): + if self.web_auth: + # ensure connection is authorized, this seems to apply to list_directory() as well + self.auth_connection() + if self.only_upgrade: self.send_error(405, "Method Not Allowed") else: @@ -279,10 +285,17 @@ class WebSockifyRequestHandler(WebSocketRequestHandler, SimpleHTTPRequestHandler raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded") def validate_connection(self): - """ Ensure that the connection is a valid connection, and set the target. """ + """ Ensure that the connection has a valid token, and set the target. """ + pass + + def auth_connection(self): + """ Ensure that the connection is authorized. """ pass def do_HEAD(self): + if self.web_auth: + self.auth_connection() + if self.only_upgrade: self.send_error(405, "Method Not Allowed") else: @@ -328,7 +341,7 @@ class WebSockifyServer(object): listen_host='', listen_port=None, source_is_ipv6=False, verbose=False, cert='', key='', ssl_only=None, verify_client=False, cafile=None, - daemon=False, record='', web='', + daemon=False, record='', web='', web_auth=False, file_only=False, run_once=False, timeout=0, idle_timeout=0, traffic=False, tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, @@ -349,6 +362,7 @@ class WebSockifyServer(object): self.idle_timeout = idle_timeout self.traffic = traffic self.file_only = file_only + self.web_auth = web_auth self.launch_time = time.time() self.ws_connection = False