diff --git a/tests/ws.py b/tests/ws.py index 40db606..301941d 100755 --- a/tests/ws.py +++ b/tests/ws.py @@ -159,4 +159,7 @@ if __name__ == '__main__': for i in range(0, 100000): rand_array.append(random.randint(0, 9)) - start_server(listen_port, test_handler) + settings['listen_port'] = listen_port + settings['daemon'] = False + settings['handler'] = test_handler + start_server() diff --git a/tests/wsencoding.py b/tests/wsencoding.py index e4d1477..c44b934 100755 --- a/tests/wsencoding.py +++ b/tests/wsencoding.py @@ -81,4 +81,7 @@ if __name__ == '__main__': print "Usage: " sys.exit(1) - start_server(listen_port, responder) + settings['listen_port'] = listen_port + settings['daemon'] = False + settings['handler'] = responder + start_server() diff --git a/utils/websocket.c b/utils/websocket.c index 7a67f22..eb4e1a6 100644 --- a/utils/websocket.c +++ b/utils/websocket.c @@ -14,6 +14,8 @@ #include #include #include +#include // daemonizing +#include // daemonizing #include #include #include /* base64 encode/decode */ @@ -37,6 +39,7 @@ const char policy_response[] = "", 22) == 0) { @@ -292,11 +296,11 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) { return NULL; } else if (bcmp(handshake, "\x16", 1) == 0) { // SSL - ws_ctx = ws_socket_ssl(sock, "self.pem"); + ws_ctx = ws_socket_ssl(sock, settings.cert); if (! ws_ctx) { return NULL; } scheme = "wss"; - printf("Using SSL socket\n"); - } else if (ssl_only) { + printf(" using SSL socket\n"); + } else if (settings.ssl_only) { printf("Non-SSL connection disallowed"); close(sock); return NULL; @@ -304,7 +308,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) { ws_ctx = ws_socket(sock); if (! ws_ctx) { return NULL; } scheme = "ws"; - printf("Using plain (not SSL) socket\n"); + printf(" using plain (not SSL) socket\n"); } len = ws_recv(ws_ctx, handshake, 4096); handshake[len] = 0; @@ -327,7 +331,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) { //printf("host: %s\n", host); //printf("origin: %s\n", origin); - // TODO: parse out client settings + // Parse client settings from the GET path args_start = strstr(path, "?"); if (args_start) { if (strstr(args_start, "#")) { @@ -337,31 +341,70 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) { } arg_idx = strstr(args_start, "b64encode"); if (arg_idx && arg_idx < args_end) { - //printf("setting b64encode\n"); + printf(" b64encode=1\n"); client_settings.do_b64encode = 1; } arg_idx = strstr(args_start, "seq_num"); if (arg_idx && arg_idx < args_end) { - //printf("setting seq_num\n"); + printf(" seq_num=1\n"); client_settings.do_seq_num = 1; } } sprintf(response, server_handshake, origin, scheme, host, path); - printf("response: %s\n", response); + //printf("response: %s\n", response); ws_send(ws_ctx, response, strlen(response)); return ws_ctx; } -void start_server(int listen_port, - void (*handler)(ws_ctx_t*), - char *listen_host, - int ssl_only) { +void signal_handler(sig) { + switch (sig) { + case SIGHUP: break; // ignore + case SIGTERM: exit(0); break; + } +} + +void daemonize() { + int pid, i; + + umask(0); + chdir('/'); + setgid(getgid()); + setuid(getuid()); + + /* Double fork to daemonize */ + pid = fork(); + if (pid<0) { fatal("fork error"); } + if (pid>0) { exit(0); } // parent exits + setsid(); // Obtain new process group + pid = fork(); + if (pid<0) { fatal("fork error"); } + if (pid>0) { exit(0); } // parent exits + + /* Signal handling */ + signal(SIGHUP, signal_handler); // catch HUP + signal(SIGTERM, signal_handler); // catch kill + + /* Close open files */ + for (i=getdtablesize(); i>=0; --i) { + close(i); + } + i=open("/dev/null", O_RDWR); // Redirect stdin + dup(i); // Redirect stdout + dup(i); // Redirect stderr +} + + +void start_server() { int lsock, csock, clilen, sopt = 1, i; struct sockaddr_in serv_addr, cli_addr; ws_ctx_t *ws_ctx; + if (settings.daemon) { + daemonize(); + } + /* Initialize buffers */ bufsize = 65536; if (! (tbuf = malloc(bufsize)) ) @@ -377,15 +420,15 @@ void start_server(int listen_port, if (lsock < 0) { error("ERROR creating listener socket"); } bzero((char *) &serv_addr, sizeof(serv_addr)); serv_addr.sin_family = AF_INET; - serv_addr.sin_port = htons(listen_port); + serv_addr.sin_port = htons(settings.listen_port); /* Resolve listen address */ - if ((listen_host == NULL) || (listen_host[0] == '\0')) { - serv_addr.sin_addr.s_addr = INADDR_ANY; - } else { - if (resolve_host(&serv_addr.sin_addr, listen_host) < -1) { + if (settings.listen_host && (settings.listen_host[0] != '\0')) { + if (resolve_host(&serv_addr.sin_addr, settings.listen_host) < -1) { fatal("Could not resolve listen address"); } + } else { + serv_addr.sin_addr.s_addr = INADDR_ANY; } setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, (char *)&sopt, sizeof(sopt)); @@ -396,10 +439,12 @@ void start_server(int listen_port, while (1) { clilen = sizeof(cli_addr); - if (listen_host) { - printf("waiting for connection on %s:%d\n", listen_host, listen_port); + if (settings.listen_host && settings.listen_host[0] != '\0') { + printf("waiting for connection on %s:%d\n", + settings.listen_host, settings.listen_port); } else { - printf("waiting for connection on port %d\n", listen_port); + printf("waiting for connection on port %d\n", + settings.listen_port); } csock = accept(lsock, (struct sockaddr *) &cli_addr, @@ -409,7 +454,7 @@ void start_server(int listen_port, continue; } printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr)); - ws_ctx = do_handshake(csock, ssl_only); + ws_ctx = do_handshake(csock); if (ws_ctx == NULL) { close(csock); continue; @@ -425,7 +470,7 @@ void start_server(int listen_port, dbufsize = (bufsize/2) - 20; } - handler(ws_ctx); + settings.handler(ws_ctx); close(csock); } diff --git a/utils/websocket.h b/utils/websocket.h index 7b1f179..9520018 100644 --- a/utils/websocket.h +++ b/utils/websocket.h @@ -1,4 +1,5 @@ #include +#include typedef struct { int sockfd; @@ -6,6 +7,16 @@ typedef struct { SSL *ssl; } ws_ctx_t; +typedef struct { + char listen_host[256]; + int listen_port; + void (*handler)(ws_ctx_t*); + int ssl_only; + int daemon; + char record[1024]; + char cert[1024]; +} settings_t; + typedef struct { int do_b64encode; int do_seq_num; diff --git a/utils/websocket.py b/utils/websocket.py index a46f794..e89b946 100755 --- a/utils/websocket.py +++ b/utils/websocket.py @@ -10,9 +10,21 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ''' import sys, socket, ssl, traceback +import os, resource, errno, signal # daemonizing from base64 import b64encode, b64decode -client_settings = {} +settings = { + 'listen_host' : '', + 'listen_port' : None, + 'handler' : None, + 'cert' : None, + 'ssl_only' : False, + 'daemon' : True, + 'record' : None, } +client_settings = { + 'b64encode' : False, + 'seq_num' : False, } + send_seq = 0 server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r @@ -33,35 +45,39 @@ def traffic(token="."): def decode(buf): """ Parse out WebSocket packets. """ if buf.count('\xff') > 1: - if client_settings["b64encode"]: + if client_settings['b64encode']: return [b64decode(d[1:]) for d in buf.split('\xff')] else: # Modified UTF-8 decode return [d[1:].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1') for d in buf.split('\xff')] else: - if client_settings["b64encode"]: + if client_settings['b64encode']: return [b64decode(buf[1:-1])] else: return [buf[1:-1].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1')] def encode(buf): global send_seq - if client_settings["b64encode"]: + if client_settings['b64encode']: buf = b64encode(buf) else: # Modified UTF-8 encode buf = buf.decode('latin-1').encode('utf-8').replace("\x00", "\xc4\x80") - if client_settings["seq_num"]: + if client_settings['seq_num']: send_seq += 1 return "\x00%d:%s\xff" % (send_seq-1, buf) else: return "\x00%s\xff" % buf -def do_handshake(sock, ssl_only=False): +def do_handshake(sock): global client_settings, send_seq + + client_settings['b64encode'] = False + client_settings['seq_num'] = False send_seq = 0 + # Peek, but don't read the data handshake = sock.recv(1024, socket.MSG_PEEK) #print "Handshake [%s]" % repr(handshake) @@ -75,54 +91,88 @@ def do_handshake(sock, ssl_only=False): retsock = ssl.wrap_socket( sock, server_side=True, - certfile='self.pem', + certfile=settings['cert'], ssl_version=ssl.PROTOCOL_TLSv1) scheme = "wss" - print "Using SSL/TLS" - elif ssl_only: + print " using SSL/TLS" + elif settings['ssl_only']: print "Non-SSL connection disallowed" sock.close() return False else: retsock = sock scheme = "ws" - print "Using plain (not SSL) socket" + print " using plain (not SSL) socket" handshake = retsock.recv(4096) req_lines = handshake.split("\r\n") _, path, _ = req_lines[0].split(" ") _, origin = req_lines[4].split(" ") _, host = req_lines[3].split(" ") - # Parse settings from the path + # Parse client settings from the GET path cvars = path.partition('?')[2].partition('#')[0].split('&') - client_settings = {'b64encode': None, 'seq_num': None} for cvar in [c for c in cvars if c]: - name, _, value = cvar.partition('=') - client_settings[name] = value and value or True - - print "client_settings:", client_settings + name, _, val = cvar.partition('=') + if name not in ['b64encode', 'seq_num']: continue + value = val and val or True + client_settings[name] = value + print " %s=%s" % (name, value) retsock.send(server_handshake % (origin, scheme, host, path)) return retsock -def start_server(listen_port, handler, listen_host='', ssl_only=False): +def daemonize(): + os.umask(0) + os.chdir('/') + os.setgid(os.getgid()) # relinquish elevations + os.setuid(os.getuid()) # relinquish elevations + + # Double fork to daemonize + if os.fork() > 0: os._exit(0) # Parent exits + os.setsid() # Obtain new process group + if os.fork() > 0: os._exit(0) # Parent exits + + # Signal handling + def terminate(a,b): os._exit(0) + signal.signal(signal.SIGTERM, terminate) + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Close open files + maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] + if maxfd == resource.RLIM_INFINITY: maxfd = 256 + for fd in reversed(range(maxfd)): + try: + os.close(fd) + except OSError, exc: + if exc.errno != errno.EBADF: raise + + # Redirect I/O to /dev/null + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno()) + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno()) + + +def start_server(): + + if settings['daemon']: daemonize() + lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - lsock.bind((listen_host, listen_port)) + lsock.bind((settings['listen_host'], settings['listen_port'])) lsock.listen(100) while True: try: - csock = None - print 'waiting for connection on port %s' % listen_port + csock = startsock = None + print 'waiting for connection on port %s' % settings['listen_port'] startsock, address = lsock.accept() print 'Got client connection from %s' % address[0] - csock = do_handshake(startsock, ssl_only=ssl_only) + csock = do_handshake(startsock) if not csock: continue - handler(csock) + settings['handler'](csock) except Exception: print "Ignoring exception:" print traceback.format_exc() if csock: csock.close() - + if startsock and startsock != csock: startsock.close() diff --git a/utils/wsproxy.c b/utils/wsproxy.c index 0f36229..a03c3a7 100644 --- a/utils/wsproxy.c +++ b/utils/wsproxy.c @@ -36,11 +36,11 @@ void usage() { exit(1); } -char *target_host; +char target_host[256]; int target_port; -char *record_filename = NULL; int recordfd = 0; +extern settings_t settings; extern char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp; extern unsigned int bufsize, dbufsize; @@ -198,6 +198,11 @@ void proxy_handler(ws_ctx_t *ws_ctx) { int tsock = 0; struct sockaddr_in taddr; + if (settings.record) { + recordfd = open(settings.record, O_WRONLY | O_CREAT | O_TRUNC, + S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); + } + printf("Connecting to: %s:%d\n", target_host, target_port); tsock = socket(AF_INET, SOCK_STREAM, 0); @@ -220,11 +225,6 @@ void proxy_handler(ws_ctx_t *ws_ctx) { return; } - if (record_filename) { - recordfd = open(record_filename, O_WRONLY | O_CREAT | O_TRUNC, - S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); - } - printf("%s", traffic_legend); do_proxy(ws_ctx, tsock); @@ -239,52 +239,74 @@ void proxy_handler(ws_ctx_t *ws_ctx) { int main(int argc, char *argv[]) { int listen_port, c, option_index = 0; - static int ssl_only = 0; - char *listen_host; + static int ssl_only = 0, foreground = 0; + char *found; static struct option long_options[] = { - {"ssl-only", no_argument, &ssl_only, 1}, + {"ssl-only", no_argument, &ssl_only, 1 }, + {"foreground", no_argument, &foreground, 'f'}, /* ---- */ - {"record", required_argument, 0, 'r'}, + {"record", required_argument, 0, 'r'}, + {"cert", required_argument, 0, 'c'}, {0, 0, 0, 0} }; + settings.record[0] = '\0'; + strcpy(settings.cert, "self.pem"); + while (1) { - c = getopt_long (argc, argv, "r:", + c = getopt_long (argc, argv, "fr:c:", long_options, &option_index); /* Detect the end */ if (c == -1) { break; } switch (c) { - case 0: break; // ignore - case 1: break; // ignore - case 'r': record_filename = optarg; break; - default: usage(); + case 0: + break; // ignore + case 1: + break; // ignore + case 'f': + foreground = 1; + break; + case 'r': + memcpy(settings.record, optarg, sizeof(settings.record)); + break; + case 'c': + memcpy(settings.cert, optarg, sizeof(settings.cert)); + break; + default: + usage(); } } + settings.ssl_only = ssl_only; + settings.daemon = foreground ? 0: 1; - printf("ssl_only: %d\n", ssl_only); - printf("record_filename: %s\n", record_filename); + printf(" ssl_only: %d\n", settings.ssl_only); + printf(" daemon: %d\n", settings.daemon); + printf(" record: %s\n", settings.record); + printf(" cert: %s\n", settings.cert); if ((argc-optind) != 2) { usage(); } - if (strstr(argv[optind], ":")) { - listen_host = strtok(argv[optind], ":"); - listen_port = strtol(strtok(NULL, ":"), NULL, 10); + found = strstr(argv[optind], ":"); + if (found) { + memcpy(settings.listen_host, argv[optind], found-argv[optind]); + settings.listen_port = strtol(found+1, NULL, 10); } else { - listen_host = NULL; - listen_port = strtol(argv[optind], NULL, 10); + settings.listen_host[0] = '\0'; + settings.listen_port = strtol(argv[optind], NULL, 10); } optind++; if ((errno != 0) || (listen_port == 0)) { usage(); } - if (strstr(argv[optind], ":")) { - target_host = strtok(argv[optind], ":"); - target_port = strtol(strtok(NULL, ":"), NULL, 10); + found = strstr(argv[optind], ":"); + if (found) { + memcpy(target_host, argv[optind], found-argv[optind]); + target_port = strtol(found+1, NULL, 10); } else { usage(); } @@ -303,7 +325,8 @@ int main(int argc, char *argv[]) if (! (cbuf_tmp = malloc(bufsize)) ) { fatal("malloc()"); } - start_server(listen_port, &proxy_handler, listen_host, ssl_only); + settings.handler = proxy_handler; + start_server(); free(tbuf); free(cbuf); diff --git a/utils/wsproxy.py b/utils/wsproxy.py index a82ce70..64aeaa8 100755 --- a/utils/wsproxy.py +++ b/utils/wsproxy.py @@ -99,14 +99,14 @@ def do_proxy(client, target): def proxy_handler(client): global target_host, target_port, options, rec + if settings['record']: + print "Opening record file: %s" % settings['record'] + rec = open(settings['record'], 'w') + print "Connecting to: %s:%s" % (target_host, target_port) tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) tsock.connect((target_host, target_port)) - if options.record: - print "Opening record file: %s" % options.record - rec = open(options.record, 'w') - print traffic_legend try: @@ -122,25 +122,35 @@ if __name__ == '__main__': parser = optparse.OptionParser(usage=usage) parser.add_option("--record", help="record session to a file", metavar="FILE") + parser.add_option("--foreground", "-f", + dest="daemon", default=True, action="store_false", + help="stay in foreground, do not daemonize") parser.add_option("--ssl-only", action="store_true", help="disallow non-encrypted connections") + parser.add_option("--cert", default="self.pem", + help="SSL certificate") (options, args) = parser.parse_args() if len(args) > 2: parser.error("Too many arguments") if len(args) < 2: parser.error("Too few arguments") if args[0].count(':') > 0: - listen_host,listen_port = args[0].split(':') + host,port = args[0].split(':') else: - listen_host = '' - listen_port = args[0] + host,port = '',args[0] if args[1].count(':') > 0: target_host,target_port = args[1].split(':') else: parser.error("Error parsing target") - try: listen_port = int(listen_port) + try: port = int(port) except: parser.error("Error parsing listen port") try: target_port = int(target_port) except: parser.error("Error parsing target port") - start_server(listen_port, proxy_handler, listen_host=listen_host, - ssl_only=options.ssl_only) + settings['listen_host'] = host + settings['listen_port'] = port + settings['handler'] = proxy_handler + settings['cert'] = os.path.abspath(options.cert) + settings['ssl_only'] = options.ssl_only + settings['daemon'] = options.daemon + settings['record'] = os.path.abspath(options.record) + start_server()