From 081046b6cd04fd2cc84488bf5862d31a016fd0ad Mon Sep 17 00:00:00 2001 From: Edward Hope-Morley Date: Fri, 20 Sep 2013 16:34:01 +0100 Subject: [PATCH] Adds optional TCP_KEEPALIVE to WebSocketServer TCP_KEEPALIVE is now enabled by default. Settings for KEEPCNT, KEEPINTVL and KEEPIDLE can be supplied when creating WebSocketServer and KEEPALIVE can also be disabled if required. Also adds new unit test for testing. Co-authored-by: natsume.takashi@lab.ntt.co.jp --- tests/test_websocket.py | 69 +++++++++++++++++++++++++++++++++++++++++ websockify/websocket.py | 37 +++++++++++++++++++--- 2 files changed, 101 insertions(+), 5 deletions(-) create mode 100644 tests/test_websocket.py diff --git a/tests/test_websocket.py b/tests/test_websocket.py new file mode 100644 index 0000000..c603189 --- /dev/null +++ b/tests/test_websocket.py @@ -0,0 +1,69 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright(c)2013 NTT corp. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Unit tests for websockify.""" + +import socket +import unittest +from websockify import websocket as websocket + + +class WebSocketTestCase(unittest.TestCase): + + def setUp(self): + """Called automatically before each test.""" + super(WebSocketTestCase, self).setUp() + + def tearDown(self): + """Called automatically after each test.""" + super(WebSocketTestCase, self).tearDown() + + def testsocket_set_keepalive_options(self): + server = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='./', + web='./', + record='./', + daemon=True, + ssl_only=1) + keepcnt = 12 + keepidle = 34 + keepintvl = 56 + + sock = server.socket('localhost', + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) + + self.assertEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPCNT), keepcnt) + self.assertEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPIDLE), keepidle) + self.assertEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPINTVL), keepintvl) + + sock = server.socket('localhost', + tcp_keepalive=False, + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) + + self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPCNT), keepcnt) + self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPIDLE), keepidle) + self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPINTVL), keepintvl) diff --git a/websockify/websocket.py b/websockify/websocket.py index 2548fb8..14d1422 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -58,7 +58,7 @@ for mod, msg in [('numpy', 'HyBi protocol will be slower'), globals()[mod] = __import__(mod) except ImportError: globals()[mod] = None - self.msg("WARNING: no '%s' module, %s", mod, msg) + print("WARNING: no '%s' module, %s" % (mod, msg)) if multiprocessing and sys.platform == 'win32': # make sockets pickle-able/inheritable import multiprocessing.reduction @@ -96,8 +96,11 @@ Sec-WebSocket-Accept: %s\r def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False, verbose=False, cert='', key='', ssl_only=None, - daemon=False, record='', web='', file_only=False, no_parent=False, - run_once=False, timeout=0, idle_timeout=0, traffic=False): + daemon=False, record='', web='', + file_only=False, no_parent=False, + run_once=False, timeout=0, idle_timeout=0, traffic=False, + tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, + tcp_keepintvl=None): # settings self.verbose = verbose @@ -120,6 +123,10 @@ Sec-WebSocket-Accept: %s\r self.no_parent = no_parent self.logger = self.get_logger() + self.tcp_keepalive = tcp_keepalive + self.tcp_keepcnt = tcp_keepcnt + self.tcp_keepidle = tcp_keepidle + self.tcp_keepintvl = tcp_keepintvl # Make paths settings absolute self.cert = os.path.abspath(cert) @@ -172,7 +179,9 @@ Sec-WebSocket-Accept: %s\r WebSocketServer.__class__.__name__)) @staticmethod - def socket(host, port=None, connect=False, prefer_ipv6=False, unix_socket=None, use_ssl=False): + def socket(host, port=None, connect=False, prefer_ipv6=False, + unix_socket=None, use_ssl=False, tcp_keepalive=True, + tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None): """ Resolve a host (and optional port) to an IPv4 or IPv6 address. Create a socket. Bind to it if listen is set, otherwise connect to it. Return the socket. @@ -198,6 +207,19 @@ Sec-WebSocket-Accept: %s\r if prefer_ipv6: addrs.reverse() sock = socket.socket(addrs[0][0], addrs[0][1]) + + if tcp_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if tcp_keepcnt: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, + tcp_keepcnt) + if tcp_keepidle: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, + tcp_keepidle) + if tcp_keepintvl: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, + tcp_keepintvl) + if connect: sock.connect(addrs[0][4]) if use_ssl: @@ -781,7 +803,12 @@ Sec-WebSocket-Accept: %s\r is a WebSockets client then call new_client() method (which must be overridden) for each new client connection. """ - lsock = self.socket(self.listen_host, self.listen_port, False, self.prefer_ipv6) + lsock = self.socket(self.listen_host, self.listen_port, False, + self.prefer_ipv6, + tcp_keepalive=self.tcp_keepalive, + tcp_keepcnt=self.tcp_keepcnt, + tcp_keepidle=self.tcp_keepidle, + tcp_keepintvl=self.tcp_keepintvl) if self.daemon: self.daemonize(keepfd=lsock.fileno(), chdir=self.web)