wswrapper: Allow multiple WebSockets connections.

Allocate buffer and state memory for each accepted connection. This
allows all WebSockets connections to a given listen port to be wrapped
with WebSockets support.
This commit is contained in:
Joel Martin 2010-12-14 12:43:34 -05:00
parent 70c585968b
commit c99124b527
1 changed files with 86 additions and 79 deletions

View File

@ -48,7 +48,6 @@
return -1;
const char _WS_response[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\
Upgrade: WebSocket\r\n\
Connection: Upgrade\r\n\
@ -57,6 +56,17 @@ Connection: Upgrade\r\n\
%sWebSocket-Protocol: sample\r\n\
\r\n%s";
#define WS_BUFSIZE 65536
typedef struct {
char rbuf[WS_BUFSIZE];
char sbuf[WS_BUFSIZE];
int rcarry_cnt;
char rcarry[3];
int newframe;
} _WS_connection;
/*
* If WSWRAP_PORT environment variable is set then listen to the bind fd that
* matches WSWRAP_PORT, otherwise listen to the first socket fd that bind is
@ -65,26 +75,12 @@ Connection: Upgrade\r\n\
int _WS_listen_fd = 0;
int _WS_sockfd = 0;
typedef struct {
char _WS_rbuf[65536];
char _WS_sbuf[65536];
} _WS_connection;
_WS_connection * _WS_connections[65546];
int _WS_bufsize = 65536;
char *_WS_rbuf = NULL;
char *_WS_sbuf = NULL;
int _WS_rcarry_cnt = 0;
char _WS_rcarry[3] = "";
int _WS_newframe = 1;
int _WS_init() {
if (! (_WS_rbuf = malloc(_WS_bufsize)) ) {
return 0;
}
if (! (_WS_sbuf = malloc(_WS_bufsize)) ) {
return 0;
}
}
/*
* WebSocket handshake routines
*/
int _WS_gen_md5(char *key1, char *key2, char *key3, char *target) {
unsigned int i, spaces1 = 0, spaces2 = 0;
@ -246,13 +242,17 @@ int _WS_handshake(int sockfd)
return ret;
}
/*
* WebSockets recv and read interposer routine
*/
ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
size_t len, int flags)
{
_WS_connection *ws = _WS_connections[sockfd];
int rawcount, deccount, left, rawlen, retlen, decodelen;
int sockflags;
int i;
char * fstart, * fend, * cstart;
char *fstart, *fend, *cstart;
static void * (*rfunc)(), * (*rfunc2)();
if (!rfunc) rfunc = (void *(*)()) dlsym(RTLD_NEXT, "recv");
@ -262,7 +262,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
return 0;
}
if ((_WS_sockfd == 0) || (_WS_sockfd != sockfd)) {
if (! ws) {
// Not our file descriptor, just pass through
if (recvf) {
return (ssize_t) rfunc(sockfd, buf, len, flags);
@ -277,26 +277,26 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
retlen = 0;
// first copy in any carry-over bytes
if (_WS_rcarry_cnt) {
if (_WS_rcarry_cnt == 1) {
DEBUG("Using carry byte: %u (", _WS_rcarry[0]);
} else if (_WS_rcarry_cnt == 2) {
DEBUG("Using carry bytes: %u,%u (", _WS_rcarry[0],
_WS_rcarry[1]);
if (ws->rcarry_cnt) {
if (ws->rcarry_cnt == 1) {
DEBUG("Using carry byte: %u (", ws->rcarry[0]);
} else if (ws->rcarry_cnt == 2) {
DEBUG("Using carry bytes: %u,%u (", ws->rcarry[0],
ws->rcarry[1]);
} else {
RET_ERROR(EIO, "Too many carry-over bytes\n");
}
if (len <= _WS_rcarry_cnt) {
if (len <= ws->rcarry_cnt) {
DEBUG("final)\n");
memcpy((char *) buf, _WS_rcarry, len);
_WS_rcarry_cnt -= len;
memcpy((char *) buf, ws->rcarry, len);
ws->rcarry_cnt -= len;
return len;
} else {
DEBUG("prepending)\n");
memcpy((char *) buf, _WS_rcarry, _WS_rcarry_cnt);
retlen += _WS_rcarry_cnt;
left -= _WS_rcarry_cnt;
_WS_rcarry_cnt = 0;
memcpy((char *) buf, ws->rcarry, ws->rcarry_cnt);
retlen += ws->rcarry_cnt;
left -= ws->rcarry_cnt;
ws->rcarry_cnt = 0;
}
}
@ -304,20 +304,20 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
rawcount = (left * 4) / 3 + 3;
rawcount -= rawcount%4;
if (rawcount > _WS_bufsize - 1) {
if (rawcount > WS_BUFSIZE - 1) {
RET_ERROR(ENOMEM, "recv of %d bytes is larger than buffer\n", rawcount);
}
i = 0;
while (1) {
// Peek at everything available
rawlen = (int) rfunc(sockfd, _WS_rbuf, _WS_bufsize-1,
rawlen = (int) rfunc(sockfd, ws->rbuf, WS_BUFSIZE-1,
flags | MSG_PEEK);
if (rawlen <= 0) {
DEBUG("_WS_recv: returning because rawlen %d\n", rawlen);
return (ssize_t) rawlen;
}
fstart = _WS_rbuf;
fstart = ws->rbuf;
/*
while (rawlen >= 2 && fstart[0] == '\x00' && fstart[1] == '\xff') {
@ -326,7 +326,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
}
*/
if (rawlen >= 2 && fstart[0] == '\x00' && fstart[1] == '\xff') {
rawlen = (int) rfunc(sockfd, _WS_rbuf, 2, flags);
rawlen = (int) rfunc(sockfd, ws->rbuf, 2, flags);
if (rawlen != 2) {
RET_ERROR(EIO, "Could not strip empty frame headers\n");
}
@ -335,7 +335,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
fstart[rawlen] = '\x00';
if (rawlen - _WS_newframe >= 4) {
if (rawlen - ws->newframe >= 4) {
// We have enough to base64 decode at least 1 byte
break;
}
@ -362,19 +362,19 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
DEBUG("\n");
*/
if (_WS_newframe) {
if (ws->newframe) {
if (fstart[0] != '\x00') {
RET_ERROR(EPROTO, "Missing frame start\n");
}
fstart++;
rawlen--;
_WS_newframe = 0;
ws->newframe = 0;
}
fend = memchr(fstart, '\xff', rawlen);
if (fend) {
_WS_newframe = 1;
ws->newframe = 1;
if ((fend - fstart) % 4) {
RET_ERROR(EPROTO, "Frame length is not multiple of 4\n");
}
@ -387,7 +387,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
// How much should we consume
if (rawcount < fend - fstart) {
_WS_newframe = 0;
ws->newframe = 0;
deccount = rawcount;
} else {
deccount = fend - fstart;
@ -397,7 +397,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
if (flags & MSG_PEEK) {
MSG("*** Got MSG_PEEK ***\n");
} else {
rfunc(sockfd, _WS_rbuf, fstart - _WS_rbuf + deccount + _WS_newframe, flags);
rfunc(sockfd, ws->rbuf, fstart - ws->rbuf + deccount + ws->newframe, flags);
}
fstart[deccount] = '\x00'; // base64 terminator
@ -415,16 +415,16 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
if (! (flags & MSG_PEEK)) {
// Add anything left over to the carry-over
_WS_rcarry_cnt = decodelen - left;
if (_WS_rcarry_cnt > 2) {
ws->rcarry_cnt = decodelen - left;
if (ws->rcarry_cnt > 2) {
RET_ERROR(EPROTO, "Got too much base64 data\n");
}
memcpy(_WS_rcarry, buf + retlen, _WS_rcarry_cnt);
if (_WS_rcarry_cnt == 1) {
DEBUG("Saving carry byte: %u\n", _WS_rcarry[0]);
} else if (_WS_rcarry_cnt == 2) {
DEBUG("Saving carry bytes: %u,%u\n", _WS_rcarry[0],
_WS_rcarry[1]);
memcpy(ws->rcarry, buf + retlen, ws->rcarry_cnt);
if (ws->rcarry_cnt == 1) {
DEBUG("Saving carry byte: %u\n", ws->rcarry[0]);
} else if (ws->rcarry_cnt == 2) {
DEBUG("Saving carry bytes: %u,%u\n", ws->rcarry[0],
ws->rcarry[1]);
} else {
MSG("Waah2!\n");
}
@ -442,9 +442,13 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
return retlen;
}
/*
* WebSockets send and write interposer routine
*/
ssize_t _WS_send(int sendf, int sockfd, const void *buf,
size_t len, int flags)
{
_WS_connection *ws = _WS_connections[sockfd];
int rawlen, enclen, rlen, over, left, clen, retlen, dbufsize;
int sockflags;
char * target;
@ -453,7 +457,7 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf,
if (!sfunc) sfunc = (void *(*)()) dlsym(RTLD_NEXT, "send");
if (!sfunc2) sfunc2 = (void *(*)()) dlsym(RTLD_NEXT, "write");
if ((_WS_sockfd == 0) || (_WS_sockfd != sockfd)) {
if (! ws) {
// Not our file descriptor, just pass through
if (sendf) {
return (ssize_t) sfunc(sockfd, buf, len, flags);
@ -465,22 +469,22 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf,
sockflags = fcntl(sockfd, F_GETFL, 0);
dbufsize = (_WS_bufsize * 3)/4 - 2;
dbufsize = (WS_BUFSIZE * 3)/4 - 2;
if (len > dbufsize) {
RET_ERROR(ENOMEM, "send of %d bytes is larger than send buffer\n", len);
}
// base64 encode and add frame markers
rawlen = 0;
_WS_sbuf[rawlen++] = '\x00';
enclen = b64_ntop(buf, len, _WS_sbuf+rawlen, _WS_bufsize-rawlen);
ws->sbuf[rawlen++] = '\x00';
enclen = b64_ntop(buf, len, ws->sbuf+rawlen, WS_BUFSIZE-rawlen);
if (enclen < 0) {
RET_ERROR(EPROTO, "Base64 encoding error\n");
}
rawlen += enclen;
_WS_sbuf[rawlen++] = '\xff';
ws->sbuf[rawlen++] = '\xff';
rlen = (int) sfunc(sockfd, _WS_sbuf, rawlen, flags);
rlen = (int) sfunc(sockfd, ws->sbuf, rawlen, flags);
if (rlen <= 0) {
return rlen;
@ -490,11 +494,11 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf,
left = (4 - over) % 4 + 1; // left to send
DEBUG("_WS_send: rlen: %d (over: %d, left: %d), rawlen: %d\n", rlen, over, left, rawlen);
rlen += left;
_WS_sbuf[rlen-1] = '\xff';
ws->sbuf[rlen-1] = '\xff';
i = 0;
do {
i++;
clen = (int) sfunc(sockfd, _WS_sbuf + rlen - left, left, flags);
clen = (int) sfunc(sockfd, ws->sbuf + rlen - left, left, flags);
if (clen > 0) {
left -= clen;
} else if (clen == 0) {
@ -518,8 +522,8 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf,
// Adjust for framing
retlen = rlen - 2;
// Adjust for base64 padding
if (_WS_sbuf[rlen-1] == '=') { retlen --; }
if (_WS_sbuf[rlen-2] == '=') { retlen --; }
if (ws->sbuf[rlen-1] == '=') { retlen --; }
if (ws->sbuf[rlen-2] == '=') { retlen --; }
// Adjust for base64 encoding
retlen = (retlen*3)/4;
@ -529,13 +533,15 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf,
for (i = 0; i < retlen; i++) {
DEBUG("%u,", (unsigned char) ((char *)buf)[i]);
}
DEBUG(" as '%s' (%d)\n", _WS_sbuf+1, rlen);
DEBUG(" as '%s' (%d)\n", ws->sbuf+1, rlen);
*/
return (ssize_t) retlen;
}
/* Override network routines */
/*
* Overload (LD_PRELOAD) standard library network routines
*/
/*
int socket(int domain, int type, int protocol)
@ -603,24 +609,24 @@ int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
return fd;
}
if (_WS_sockfd == 0) {
// TODO: not just first connection
_WS_sockfd = fd;
if (!_WS_rbuf) {
if (! _WS_init()) {
RET_ERROR(ENOMEM, "Could not allocate interposer buffer\n");
}
if (_WS_connections[fd]) {
MSG("error, already interposing on fd %d\n", fd);
} else {
if (! (_WS_connections[fd] = malloc(sizeof(_WS_connection)))) {
RET_ERROR(ENOMEM, "Could not allocate interposer memory\n");
}
_WS_connections[fd]->rcarry_cnt = 0;
_WS_connections[fd]->rcarry[0] = '\0';
_WS_connections[fd]->newframe = 1;
ret = _WS_handshake(_WS_sockfd);
ret = _WS_handshake(fd);
if (ret < 0) {
free(_WS_connections[fd]);
_WS_connections[fd] = NULL;
errno = EPROTO;
return ret;
}
MSG("interposing on fd %d\n", _WS_sockfd);
} else {
DEBUG("already interposing on fd %d\n", _WS_sockfd);
MSG("interposing on fd %d (allocated memory)\n", fd);
}
return fd;
@ -631,9 +637,10 @@ int close(int fd)
static void * (*func)();
if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "close");
if ((_WS_sockfd != 0) && (_WS_sockfd == fd)) {
MSG("finished interposing on fd %d\n", _WS_sockfd);
_WS_sockfd = 0;
if (_WS_connections[fd]) {
free(_WS_connections[fd]);
_WS_connections[fd] = NULL;
MSG("finished interposing on fd %d (freed memory)\n", fd);
}
return (int) func(fd);
}