Add daemonization support to wsproxy.*.

Refactor how settings are passed around.
This commit is contained in:
Joel Martin 2010-06-17 16:06:18 -05:00
parent b2fd1bc374
commit 6ee61a4cf6
7 changed files with 230 additions and 85 deletions

View File

@ -159,4 +159,7 @@ if __name__ == '__main__':
for i in range(0, 100000):
rand_array.append(random.randint(0, 9))
start_server(listen_port, test_handler)
settings['listen_port'] = listen_port
settings['daemon'] = False
settings['handler'] = test_handler
start_server()

View File

@ -81,4 +81,7 @@ if __name__ == '__main__':
print "Usage: <listen_port>"
sys.exit(1)
start_server(listen_port, responder)
settings['listen_port'] = listen_port
settings['daemon'] = False
settings['handler'] = responder
start_server()

View File

@ -14,6 +14,8 @@
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <signal.h> // daemonizing
#include <fcntl.h> // daemonizing
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <resolv.h> /* base64 encode/decode */
@ -37,6 +39,7 @@ const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\
int ssl_initialized = 0;
char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
unsigned int bufsize, dbufsize;
settings_t settings;
client_settings_t client_settings;
void traffic(char * token) {
@ -269,7 +272,7 @@ int decode(char *src, size_t srclength, u_char *target, size_t targsize) {
return retlen;
}
ws_ctx_t *do_handshake(int sock, int ssl_only) {
ws_ctx_t *do_handshake(int sock) {
char handshake[4096], response[4096];
char *scheme, *line, *path, *host, *origin;
char *args_start, *args_end, *arg_idx;
@ -281,6 +284,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
client_settings.do_seq_num = 0;
client_settings.seq_num = 0;
// Peek, but don't read the data
len = recv(sock, handshake, 1024, MSG_PEEK);
handshake[len] = 0;
if (bcmp(handshake, "<policy-file-request/>", 22) == 0) {
@ -292,11 +296,11 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
return NULL;
} else if (bcmp(handshake, "\x16", 1) == 0) {
// SSL
ws_ctx = ws_socket_ssl(sock, "self.pem");
ws_ctx = ws_socket_ssl(sock, settings.cert);
if (! ws_ctx) { return NULL; }
scheme = "wss";
printf("Using SSL socket\n");
} else if (ssl_only) {
printf(" using SSL socket\n");
} else if (settings.ssl_only) {
printf("Non-SSL connection disallowed");
close(sock);
return NULL;
@ -304,7 +308,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
ws_ctx = ws_socket(sock);
if (! ws_ctx) { return NULL; }
scheme = "ws";
printf("Using plain (not SSL) socket\n");
printf(" using plain (not SSL) socket\n");
}
len = ws_recv(ws_ctx, handshake, 4096);
handshake[len] = 0;
@ -327,7 +331,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
//printf("host: %s\n", host);
//printf("origin: %s\n", origin);
// TODO: parse out client settings
// Parse client settings from the GET path
args_start = strstr(path, "?");
if (args_start) {
if (strstr(args_start, "#")) {
@ -337,31 +341,70 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
}
arg_idx = strstr(args_start, "b64encode");
if (arg_idx && arg_idx < args_end) {
//printf("setting b64encode\n");
printf(" b64encode=1\n");
client_settings.do_b64encode = 1;
}
arg_idx = strstr(args_start, "seq_num");
if (arg_idx && arg_idx < args_end) {
//printf("setting seq_num\n");
printf(" seq_num=1\n");
client_settings.do_seq_num = 1;
}
}
sprintf(response, server_handshake, origin, scheme, host, path);
printf("response: %s\n", response);
//printf("response: %s\n", response);
ws_send(ws_ctx, response, strlen(response));
return ws_ctx;
}
void start_server(int listen_port,
void (*handler)(ws_ctx_t*),
char *listen_host,
int ssl_only) {
void signal_handler(sig) {
switch (sig) {
case SIGHUP: break; // ignore
case SIGTERM: exit(0); break;
}
}
void daemonize() {
int pid, i;
umask(0);
chdir('/');
setgid(getgid());
setuid(getuid());
/* Double fork to daemonize */
pid = fork();
if (pid<0) { fatal("fork error"); }
if (pid>0) { exit(0); } // parent exits
setsid(); // Obtain new process group
pid = fork();
if (pid<0) { fatal("fork error"); }
if (pid>0) { exit(0); } // parent exits
/* Signal handling */
signal(SIGHUP, signal_handler); // catch HUP
signal(SIGTERM, signal_handler); // catch kill
/* Close open files */
for (i=getdtablesize(); i>=0; --i) {
close(i);
}
i=open("/dev/null", O_RDWR); // Redirect stdin
dup(i); // Redirect stdout
dup(i); // Redirect stderr
}
void start_server() {
int lsock, csock, clilen, sopt = 1, i;
struct sockaddr_in serv_addr, cli_addr;
ws_ctx_t *ws_ctx;
if (settings.daemon) {
daemonize();
}
/* Initialize buffers */
bufsize = 65536;
if (! (tbuf = malloc(bufsize)) )
@ -377,15 +420,15 @@ void start_server(int listen_port,
if (lsock < 0) { error("ERROR creating listener socket"); }
bzero((char *) &serv_addr, sizeof(serv_addr));
serv_addr.sin_family = AF_INET;
serv_addr.sin_port = htons(listen_port);
serv_addr.sin_port = htons(settings.listen_port);
/* Resolve listen address */
if ((listen_host == NULL) || (listen_host[0] == '\0')) {
serv_addr.sin_addr.s_addr = INADDR_ANY;
} else {
if (resolve_host(&serv_addr.sin_addr, listen_host) < -1) {
if (settings.listen_host && (settings.listen_host[0] != '\0')) {
if (resolve_host(&serv_addr.sin_addr, settings.listen_host) < -1) {
fatal("Could not resolve listen address");
}
} else {
serv_addr.sin_addr.s_addr = INADDR_ANY;
}
setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, (char *)&sopt, sizeof(sopt));
@ -396,10 +439,12 @@ void start_server(int listen_port,
while (1) {
clilen = sizeof(cli_addr);
if (listen_host) {
printf("waiting for connection on %s:%d\n", listen_host, listen_port);
if (settings.listen_host && settings.listen_host[0] != '\0') {
printf("waiting for connection on %s:%d\n",
settings.listen_host, settings.listen_port);
} else {
printf("waiting for connection on port %d\n", listen_port);
printf("waiting for connection on port %d\n",
settings.listen_port);
}
csock = accept(lsock,
(struct sockaddr *) &cli_addr,
@ -409,7 +454,7 @@ void start_server(int listen_port,
continue;
}
printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr));
ws_ctx = do_handshake(csock, ssl_only);
ws_ctx = do_handshake(csock);
if (ws_ctx == NULL) {
close(csock);
continue;
@ -425,7 +470,7 @@ void start_server(int listen_port,
dbufsize = (bufsize/2) - 20;
}
handler(ws_ctx);
settings.handler(ws_ctx);
close(csock);
}

View File

@ -1,4 +1,5 @@
#include <openssl/ssl.h>
#include <unistd.h>
typedef struct {
int sockfd;
@ -6,6 +7,16 @@ typedef struct {
SSL *ssl;
} ws_ctx_t;
typedef struct {
char listen_host[256];
int listen_port;
void (*handler)(ws_ctx_t*);
int ssl_only;
int daemon;
char record[1024];
char cert[1024];
} settings_t;
typedef struct {
int do_b64encode;
int do_seq_num;

View File

@ -10,9 +10,21 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
import sys, socket, ssl, traceback
import os, resource, errno, signal # daemonizing
from base64 import b64encode, b64decode
client_settings = {}
settings = {
'listen_host' : '',
'listen_port' : None,
'handler' : None,
'cert' : None,
'ssl_only' : False,
'daemon' : True,
'record' : None, }
client_settings = {
'b64encode' : False,
'seq_num' : False, }
send_seq = 0
server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
@ -33,35 +45,39 @@ def traffic(token="."):
def decode(buf):
""" Parse out WebSocket packets. """
if buf.count('\xff') > 1:
if client_settings["b64encode"]:
if client_settings['b64encode']:
return [b64decode(d[1:]) for d in buf.split('\xff')]
else:
# Modified UTF-8 decode
return [d[1:].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1') for d in buf.split('\xff')]
else:
if client_settings["b64encode"]:
if client_settings['b64encode']:
return [b64decode(buf[1:-1])]
else:
return [buf[1:-1].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1')]
def encode(buf):
global send_seq
if client_settings["b64encode"]:
if client_settings['b64encode']:
buf = b64encode(buf)
else:
# Modified UTF-8 encode
buf = buf.decode('latin-1').encode('utf-8').replace("\x00", "\xc4\x80")
if client_settings["seq_num"]:
if client_settings['seq_num']:
send_seq += 1
return "\x00%d:%s\xff" % (send_seq-1, buf)
else:
return "\x00%s\xff" % buf
def do_handshake(sock, ssl_only=False):
def do_handshake(sock):
global client_settings, send_seq
client_settings['b64encode'] = False
client_settings['seq_num'] = False
send_seq = 0
# Peek, but don't read the data
handshake = sock.recv(1024, socket.MSG_PEEK)
#print "Handshake [%s]" % repr(handshake)
@ -75,54 +91,88 @@ def do_handshake(sock, ssl_only=False):
retsock = ssl.wrap_socket(
sock,
server_side=True,
certfile='self.pem',
certfile=settings['cert'],
ssl_version=ssl.PROTOCOL_TLSv1)
scheme = "wss"
print "Using SSL/TLS"
elif ssl_only:
print " using SSL/TLS"
elif settings['ssl_only']:
print "Non-SSL connection disallowed"
sock.close()
return False
else:
retsock = sock
scheme = "ws"
print "Using plain (not SSL) socket"
print " using plain (not SSL) socket"
handshake = retsock.recv(4096)
req_lines = handshake.split("\r\n")
_, path, _ = req_lines[0].split(" ")
_, origin = req_lines[4].split(" ")
_, host = req_lines[3].split(" ")
# Parse settings from the path
# Parse client settings from the GET path
cvars = path.partition('?')[2].partition('#')[0].split('&')
client_settings = {'b64encode': None, 'seq_num': None}
for cvar in [c for c in cvars if c]:
name, _, value = cvar.partition('=')
client_settings[name] = value and value or True
print "client_settings:", client_settings
name, _, val = cvar.partition('=')
if name not in ['b64encode', 'seq_num']: continue
value = val and val or True
client_settings[name] = value
print " %s=%s" % (name, value)
retsock.send(server_handshake % (origin, scheme, host, path))
return retsock
def start_server(listen_port, handler, listen_host='', ssl_only=False):
def daemonize():
os.umask(0)
os.chdir('/')
os.setgid(os.getgid()) # relinquish elevations
os.setuid(os.getuid()) # relinquish elevations
# Double fork to daemonize
if os.fork() > 0: os._exit(0) # Parent exits
os.setsid() # Obtain new process group
if os.fork() > 0: os._exit(0) # Parent exits
# Signal handling
def terminate(a,b): os._exit(0)
signal.signal(signal.SIGTERM, terminate)
signal.signal(signal.SIGINT, signal.SIG_IGN)
# Close open files
maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
if maxfd == resource.RLIM_INFINITY: maxfd = 256
for fd in reversed(range(maxfd)):
try:
os.close(fd)
except OSError, exc:
if exc.errno != errno.EBADF: raise
# Redirect I/O to /dev/null
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
def start_server():
if settings['daemon']: daemonize()
lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
lsock.bind((listen_host, listen_port))
lsock.bind((settings['listen_host'], settings['listen_port']))
lsock.listen(100)
while True:
try:
csock = None
print 'waiting for connection on port %s' % listen_port
csock = startsock = None
print 'waiting for connection on port %s' % settings['listen_port']
startsock, address = lsock.accept()
print 'Got client connection from %s' % address[0]
csock = do_handshake(startsock, ssl_only=ssl_only)
csock = do_handshake(startsock)
if not csock: continue
handler(csock)
settings['handler'](csock)
except Exception:
print "Ignoring exception:"
print traceback.format_exc()
if csock: csock.close()
if startsock and startsock != csock: startsock.close()

View File

@ -36,11 +36,11 @@ void usage() {
exit(1);
}
char *target_host;
char target_host[256];
int target_port;
char *record_filename = NULL;
int recordfd = 0;
extern settings_t settings;
extern char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
extern unsigned int bufsize, dbufsize;
@ -198,6 +198,11 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
int tsock = 0;
struct sockaddr_in taddr;
if (settings.record) {
recordfd = open(settings.record, O_WRONLY | O_CREAT | O_TRUNC,
S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
}
printf("Connecting to: %s:%d\n", target_host, target_port);
tsock = socket(AF_INET, SOCK_STREAM, 0);
@ -220,11 +225,6 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
return;
}
if (record_filename) {
recordfd = open(record_filename, O_WRONLY | O_CREAT | O_TRUNC,
S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
}
printf("%s", traffic_legend);
do_proxy(ws_ctx, tsock);
@ -239,52 +239,74 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
int main(int argc, char *argv[])
{
int listen_port, c, option_index = 0;
static int ssl_only = 0;
char *listen_host;
static int ssl_only = 0, foreground = 0;
char *found;
static struct option long_options[] = {
{"ssl-only", no_argument, &ssl_only, 1},
{"ssl-only", no_argument, &ssl_only, 1 },
{"foreground", no_argument, &foreground, 'f'},
/* ---- */
{"record", required_argument, 0, 'r'},
{"record", required_argument, 0, 'r'},
{"cert", required_argument, 0, 'c'},
{0, 0, 0, 0}
};
settings.record[0] = '\0';
strcpy(settings.cert, "self.pem");
while (1) {
c = getopt_long (argc, argv, "r:",
c = getopt_long (argc, argv, "fr:c:",
long_options, &option_index);
/* Detect the end */
if (c == -1) { break; }
switch (c) {
case 0: break; // ignore
case 1: break; // ignore
case 'r': record_filename = optarg; break;
default: usage();
case 0:
break; // ignore
case 1:
break; // ignore
case 'f':
foreground = 1;
break;
case 'r':
memcpy(settings.record, optarg, sizeof(settings.record));
break;
case 'c':
memcpy(settings.cert, optarg, sizeof(settings.cert));
break;
default:
usage();
}
}
settings.ssl_only = ssl_only;
settings.daemon = foreground ? 0: 1;
printf("ssl_only: %d\n", ssl_only);
printf("record_filename: %s\n", record_filename);
printf(" ssl_only: %d\n", settings.ssl_only);
printf(" daemon: %d\n", settings.daemon);
printf(" record: %s\n", settings.record);
printf(" cert: %s\n", settings.cert);
if ((argc-optind) != 2) {
usage();
}
if (strstr(argv[optind], ":")) {
listen_host = strtok(argv[optind], ":");
listen_port = strtol(strtok(NULL, ":"), NULL, 10);
found = strstr(argv[optind], ":");
if (found) {
memcpy(settings.listen_host, argv[optind], found-argv[optind]);
settings.listen_port = strtol(found+1, NULL, 10);
} else {
listen_host = NULL;
listen_port = strtol(argv[optind], NULL, 10);
settings.listen_host[0] = '\0';
settings.listen_port = strtol(argv[optind], NULL, 10);
}
optind++;
if ((errno != 0) || (listen_port == 0)) {
usage();
}
if (strstr(argv[optind], ":")) {
target_host = strtok(argv[optind], ":");
target_port = strtol(strtok(NULL, ":"), NULL, 10);
found = strstr(argv[optind], ":");
if (found) {
memcpy(target_host, argv[optind], found-argv[optind]);
target_port = strtol(found+1, NULL, 10);
} else {
usage();
}
@ -303,7 +325,8 @@ int main(int argc, char *argv[])
if (! (cbuf_tmp = malloc(bufsize)) )
{ fatal("malloc()"); }
start_server(listen_port, &proxy_handler, listen_host, ssl_only);
settings.handler = proxy_handler;
start_server();
free(tbuf);
free(cbuf);

View File

@ -99,14 +99,14 @@ def do_proxy(client, target):
def proxy_handler(client):
global target_host, target_port, options, rec
if settings['record']:
print "Opening record file: %s" % settings['record']
rec = open(settings['record'], 'w')
print "Connecting to: %s:%s" % (target_host, target_port)
tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tsock.connect((target_host, target_port))
if options.record:
print "Opening record file: %s" % options.record
rec = open(options.record, 'w')
print traffic_legend
try:
@ -122,25 +122,35 @@ if __name__ == '__main__':
parser = optparse.OptionParser(usage=usage)
parser.add_option("--record",
help="record session to a file", metavar="FILE")
parser.add_option("--foreground", "-f",
dest="daemon", default=True, action="store_false",
help="stay in foreground, do not daemonize")
parser.add_option("--ssl-only", action="store_true",
help="disallow non-encrypted connections")
parser.add_option("--cert", default="self.pem",
help="SSL certificate")
(options, args) = parser.parse_args()
if len(args) > 2: parser.error("Too many arguments")
if len(args) < 2: parser.error("Too few arguments")
if args[0].count(':') > 0:
listen_host,listen_port = args[0].split(':')
host,port = args[0].split(':')
else:
listen_host = ''
listen_port = args[0]
host,port = '',args[0]
if args[1].count(':') > 0:
target_host,target_port = args[1].split(':')
else:
parser.error("Error parsing target")
try: listen_port = int(listen_port)
try: port = int(port)
except: parser.error("Error parsing listen port")
try: target_port = int(target_port)
except: parser.error("Error parsing target port")
start_server(listen_port, proxy_handler, listen_host=listen_host,
ssl_only=options.ssl_only)
settings['listen_host'] = host
settings['listen_port'] = port
settings['handler'] = proxy_handler
settings['cert'] = os.path.abspath(options.cert)
settings['ssl_only'] = options.ssl_only
settings['daemon'] = options.daemon
settings['record'] = os.path.abspath(options.record)
start_server()