Support for UDP socket forwarding

Allow forwarding WS->UDP. Useful for experiments and hobby projects.
This commit is contained in:
Joshua Ashton 2021-08-29 08:49:30 +01:00
parent dc345815c0
commit 5d8f5ccfab
No known key found for this signature in database
GPG Key ID: C85A08669126BE8D
2 changed files with 36 additions and 23 deletions

View File

@ -86,6 +86,7 @@ Traffic Legend:
""" """
Called after a new WebSocket connection has been established. Called after a new WebSocket connection has been established.
""" """
use_tcp = False
# Checking for a token is done in validate_connection() # Checking for a token is done in validate_connection()
# Connect to the target # Connect to the target
@ -106,15 +107,17 @@ Traffic Legend:
self.server.target_port, self.server.target_port,
connect=True, connect=True,
use_ssl=self.server.ssl_target, use_ssl=self.server.ssl_target,
unix_socket=self.server.unix_target) unix_socket=self.server.unix_target,
use_tcp=use_tcp)
except Exception as e: except Exception as e:
self.log_message("Failed to connect to %s:%s: %s", self.log_message("Failed to connect to %s:%s: %s",
self.server.target_host, self.server.target_port, e) self.server.target_host, self.server.target_port, e)
raise self.CClose(1011, "Failed to connect to downstream server") raise self.CClose(1011, "Failed to connect to downstream server")
self.request.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) if use_tcp:
if not self.server.wrap_cmd and not self.server.unix_target: self.request.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
tsock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) if not self.server.wrap_cmd and not self.server.unix_target:
tsock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
self.print_traffic(self.traffic_legend) self.print_traffic(self.traffic_legend)

View File

@ -422,7 +422,8 @@ class WebSockifyServer():
@staticmethod @staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False, def socket(host, port=None, connect=False, prefer_ipv6=False,
unix_socket=None, use_ssl=False, tcp_keepalive=True, unix_socket=None, use_ssl=False, tcp_keepalive=True,
tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None): tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None,
use_tcp=True):
""" Resolve a host (and optional port) to an IPv4 or IPv6 """ Resolve a host (and optional port) to an IPv4 or IPv6
address. Create a socket. Bind to it if listen is set, address. Create a socket. Bind to it if listen is set,
otherwise connect to it. Return the socket. otherwise connect to it. Return the socket.
@ -440,8 +441,12 @@ class WebSockifyServer():
flags = flags | socket.AI_PASSIVE flags = flags | socket.AI_PASSIVE
if not unix_socket: if not unix_socket:
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, if use_tcp:
socket.IPPROTO_TCP, flags) addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM,
socket.IPPROTO_TCP, flags)
else:
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_DGRAM,
socket.IPPROTO_UDP, flags)
if not addrs: if not addrs:
raise Exception("Could not resolve host '%s'" % host) raise Exception("Could not resolve host '%s'" % host)
addrs.sort(key=lambda x: x[0]) addrs.sort(key=lambda x: x[0])
@ -449,29 +454,35 @@ class WebSockifyServer():
addrs.reverse() addrs.reverse()
sock = socket.socket(addrs[0][0], addrs[0][1]) sock = socket.socket(addrs[0][0], addrs[0][1])
if tcp_keepalive: if use_tcp:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if tcp_keepalive:
if tcp_keepcnt: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, if tcp_keepcnt:
tcp_keepcnt) sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT,
if tcp_keepidle: tcp_keepcnt)
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, if tcp_keepidle:
tcp_keepidle) sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE,
if tcp_keepintvl: tcp_keepidle)
sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, if tcp_keepintvl:
tcp_keepintvl) sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL,
tcp_keepintvl)
if connect: if connect:
sock.connect(addrs[0][4]) sock.connect(addrs[0][4])
if use_ssl: if use_ssl:
sock = ssl.wrap_socket(sock) sock = ssl.wrap_socket(sock)
else: else:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if not use_tcp:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(addrs[0][4]) sock.bind(addrs[0][4])
sock.listen(100) if use_tcp:
sock.listen(100)
else: else:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) if use_tcp:
sock.connect(unix_socket) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(unix_socket)
else:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
return sock return sock
@ -697,7 +708,6 @@ class WebSockifyServer():
is a WebSockets client then call new_websocket_client() method (which must is a WebSockets client then call new_websocket_client() method (which must
be overridden) for each new client connection. be overridden) for each new client connection.
""" """
if self.listen_fd != None: if self.listen_fd != None:
lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM) lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM)
else: else: