commit 2aa2376e3a50c286a1373362c42297ae1f5a3194 Author: Ed Schouten Date: Tue Jun 28 11:44:35 2011 +0200 Add initial version of wsproxy source code. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8c82ed6 --- /dev/null +++ b/Makefile @@ -0,0 +1,7 @@ +CFLAGS=-Wall -Wmissing-prototypes -Wstrict-prototypes -Wold-style-definition -Werror -O2 +LDFLAGS=-lresolv -lssl + +all: wsproxy + +clean: + rm -f wsproxy diff --git a/wsproxy.c b/wsproxy.c new file mode 100644 index 0000000..a40a580 --- /dev/null +++ b/wsproxy.c @@ -0,0 +1,422 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +static pid_t other; + +static void +die(int exitcode) +{ + + kill(other, SIGTERM); + exit(exitcode); +} + +static void +usage(void) +{ + + fprintf(stderr, "usage: wsproxy minport maxport\n"); + exit(1); +} + +static int +pgetc(FILE *fp) +{ + int ret; + + ret = fgetc(fp); + if (ret == EOF) + die(0); + return (ret); +} + +#if 0 /* UTF-8 */ + +static void +pputc(FILE *fp, unsigned char ch) +{ + int ret; + + ret = fputc(ch, fp); + if (ret == EOF) + die(0); +} + +static void +decode(FILE *in, int outfd) +{ + FILE *out; + int ch; + unsigned char och; + + out = fdopen(outfd, "w"); + if (out == NULL) { + perror("fdopen"); + die(1); + } + + for (;;) { + /* Frame header. */ + ch = pgetc(in); + if (ch != 0x00) { + fprintf(stderr, "malformed frame header received\n"); + die(1); + } + + for (;;) { + /* Frame trailer. */ + ch = pgetc(in); + if (ch == EOF) + die(0); + if (ch == 0xff) { + fflush(out); + break; + } + + /* UTF-8 character, only allowing points 0 to 255. */ + if (ch < 0x80) + och = ch; + else if ((ch & 0xf3) == 0xc0) { + och = ch << 6; + ch = pgetc(in); + if ((ch & 0xc0) != 0x80) + goto malformed; + och |= ch & 0x3f; + } else + goto malformed; + pputc(out, och); + } + } + +malformed: + fprintf(stderr, "malformed UTF-8 sequence received\n"); + die(1); +} + +static int +encode(int in, int out) +{ + unsigned char inbuf[512]; + unsigned char outbuf[sizeof inbuf * 2 + 2]; + unsigned char *op; + ssize_t len, i; + + for (;;) { + len = read(in, inbuf, sizeof inbuf); + if (len == -1) { + perror("read"); + die(1); + } else if (len == 0) + die(0); + + op = outbuf; + /* Frame header. */ + *op++ = 0x00; + for (i = 0; i < len; i++) { + /* Encode data as UTF-8. */ + if (inbuf[i] < 0x80) + *op++ = inbuf[i]; + else { + *op++ = 0xc0 | (inbuf[i] >> 6); + *op++ = 0x80 | (inbuf[i] & 0x3f); + } + } + /* Frame trailer. */ + *op++ = 0xff; + assert(op <= outbuf + sizeof outbuf); + len = write(out, outbuf, op - outbuf); + if (len == -1) { + perror("write"); + die(1); + } else if (len != op - outbuf) + die(0); + } +} + +#else /* base64 */ + +static void +putb64(FILE *out, char *inb, size_t *inblen) +{ + char inbuf[5] = { 0 }; + unsigned char outbuf[3]; + ssize_t outbuflen; + + if (*inblen == 0) + return; + + assert(*inblen <= 4); + memcpy(inbuf, inb, *inblen); + outbuflen = b64_pton(inbuf, outbuf, sizeof outbuf); + if (outbuflen <= 0) { + fprintf(stderr, "invalid Base64 data\n"); + die(1); + } + if (fwrite(outbuf, outbuflen, 1, out) != 1) { + perror("fwrite"); + die(1); + } + *inblen = 0; +} + +static void +decode(FILE *in, int outfd) +{ + FILE *out; + int ch; + char inb[4]; + size_t inblen = 0; + + out = fdopen(outfd, "w"); + if (out == NULL) { + perror("fdopen"); + die(1); + } + + for (;;) { + /* Frame header. */ + ch = pgetc(in); + if (ch != 0x00) { + fprintf(stderr, "malformed frame header received\n"); + die(1); + } + + for (;;) { + ch = pgetc(in); + if (ch == EOF) { + putb64(out, inb, &inblen); + die(0); + } + /* Frame trailer. */ + if (ch == 0xff) { + putb64(out, inb, &inblen); + if (fflush(out) == -1) { + perror("fflush"); + die(1); + } + break; + } + + if (!((ch >= 'A' && ch <= 'Z') || + (ch >= 'a' && ch <= 'z') || + (ch >= '0' && ch <= '9') || + ch == '+' || ch == '/' || ch == '=')) { + fprintf(stderr, + "non-Base64 character received\n"); + die(1); + } + + /* Base64 character. */ + inb[inblen++] = ch; + if (inblen == sizeof inb) + putb64(out, inb, &inblen); + } + } +} + +static int +encode(int in, int out) +{ + unsigned char inbuf[512]; + char outbuf[sizeof inbuf * 2 + 2]; + ssize_t len, wlen; + + for (;;) { + len = read(in, inbuf, sizeof inbuf); + if (len == -1) { + perror("read"); + die(1); + } else if (len == 0) + die(0); + + /* Frame header. */ + outbuf[0] = 0x00; + /* Encode data as Base64. */ + len = b64_ntop(inbuf, len, outbuf + 1, sizeof outbuf - 1) + 1; + assert(len >= 1); + /* Frame footer. */ + outbuf[len++] = 0xff; + + wlen = write(out, outbuf, len); + if (wlen == -1) { + perror("write"); + die(1); + } else if (wlen != len) + die(0); + } +} + +#endif + +static char * +parsestring(const char *in) +{ + size_t len; + + len = strlen(in); + if (len > 0 && in[len - 1] == '\n') + len--; + if (len > 0 && in[len - 1] == '\r') + len--; + if (len == 0) + return (NULL); + return (strndup(in, len)); +} + +static uint32_t +parsehdrkey(const char *key) +{ + uint32_t sum = 0, spaces = 0; + + for (; *key != '\0'; key++) { + if (*key >= '0' && *key <= '9') + sum = sum * 10 + *key - '0'; + else if (*key == ' ') + spaces++; + } + return (sum / spaces); +} + +static void +calcresponse(uint32_t key1, uint32_t key2, const char *key3, char *out) +{ + MD5_CTX c; + char in[16]; + + in[0] = key1 >> 24; + in[1] = key1 >> 16; + in[2] = key1 >> 8; + in[3] = key1; + in[4] = key2 >> 24; + in[5] = key2 >> 16; + in[6] = key2 >> 8; + in[7] = key2; + memcpy(in + 8, key3, 8); + + MD5_Init(&c); + MD5_Update(&c, (void *)in, sizeof in); + MD5_Final((void *)out, &c); +} + +int +main(int argc, char *argv[]) +{ + struct sockaddr_storage sa; + char line[512], key3[8], response[16], *host = NULL, *origin = NULL; + unsigned long minport, maxport, port; + uint32_t key1 = 0, key2 = 0; + socklen_t salen; + pid_t pid; + int s; + + if (argc != 3) + usage(); + minport = strtoul(argv[1], NULL, 10); + maxport = strtoul(argv[2], NULL, 10); + if (1 > minport || minport > maxport || maxport > 65535) + usage(); + + /* GET / header. */ + if (fgets(line, sizeof line, stdin) == NULL) { + fprintf(stderr, "no HTTP header received\n"); + return (1); + } + if (strncmp(line, "GET /", 5) != 0) { + fprintf(stderr, "malformed HTTP header received\n"); + return (1); + } + port = strtoul(line + 5, NULL, 10); + if (port < minport || port > maxport) { + fprintf(stderr, "port not allowed\n"); + return (1); + } + + /* Parse HTTP headers. */ + do { + if (fgets(line, sizeof line, stdin) == NULL) { + fprintf(stderr, "partial HTTP header received\n"); + return (1); + } + if (strncasecmp(line, "Host: ", 6) == 0) { + host = parsestring(line + 6); + } else if (strncasecmp(line, "Origin: ", 8) == 0) { + origin = parsestring(line + 8); + } else if (strncasecmp(line, "Sec-WebSocket-Key1: ", 20) == 0) { + key1 = parsehdrkey(line + 20); + } else if (strncasecmp(line, "Sec-WebSocket-Key2: ", 20) == 0) { + key2 = parsehdrkey(line + 20); + } + } while (strcmp(line, "\n") != 0 && strcmp(line, "\r\n") != 0); + + /* Eight byte payload. */ + if (fread(key3, sizeof key3, 1, stdin) != 1) { + fprintf(stderr, "key data missing\n"); + return (1); + } + + /* Use our own address. Fall back to 127.0.0.1 on failure. */ + salen = sizeof sa; + if (getsockname(STDIN_FILENO, (struct sockaddr *)&sa, &salen) == -1) { + struct sockaddr_in *sin = (struct sockaddr_in *)&sa; + salen = sizeof *sin; + memset(sin, 0, salen); + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = inet_addr("127.0.0.1"); + } + switch (sa.ss_family) { + case AF_INET: + ((struct sockaddr_in *)&sa)->sin_port = htons(port); + break; + case AF_INET6: + ((struct sockaddr_in6 *)&sa)->sin6_port = htons(port); + break; + default: + /* Unknown protocol. */ + fprintf(stderr, "unsupported network protocol\n"); + return (1); + } + s = socket(sa.ss_family, SOCK_STREAM, 0); + if (s == -1) { + perror("socket"); + return (1); + } + if (connect(s, (struct sockaddr *)&sa, salen) == -1) { + perror("connect"); + return (1); + } + + /* Send HTTP response. */ + calcresponse(key1, key2, key3, response); + printf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Origin: %s\r\n" + "Sec-WebSocket-Location: ws://%s/%lu\r\n" + "Sec-WebSocket-Protocol: base64\r\n\r\n", origin, host, port); + fwrite(response, sizeof response, 1, stdout); + fflush(stdout); + + /* Spawn child process for bi-directional pipe. */ + pid = fork(); + if (pid == -1) { + perror("fork"); + return (1); + } else if (pid == 0) { + other = getppid(); + decode(stdin, s); + } else { + other = pid; + encode(s, STDOUT_FILENO); + } + assert(0); +}