From c7bde00a4e34d7194242bbe61da540394ec2a8c6 Mon Sep 17 00:00:00 2001 From: Pierre Ossman Date: Wed, 8 Feb 2017 15:45:48 +0100 Subject: [PATCH] Force choice of sub-protocol The WebSocket standard require us to choose one of the protocols supported by the client. Enforce this with a specific check in the base class rather than relying on generous clients. --- tests/test_websocket.py | 98 +++++++++++++++++++++++++++++++++++++++++ websockify/websocket.py | 12 +++++ 2 files changed, 110 insertions(+) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 77d0eca..49cc620 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -18,6 +18,104 @@ import unittest from websockify import websocket +class FakeSocket: + def __init__(self): + self.data = b'' + + def send(self, buf): + self.data += buf + return len(buf) + +class AcceptTestCase(unittest.TestCase): + def test_success(self): + ws = websocket.WebSocket() + sock = FakeSocket() + ws.accept(sock, {'upgrade': 'websocket', + 'Sec-WebSocket-Version': '13', + 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) + self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ') + self.assertIn(b'\r\nUpgrade: websocket\r\n', sock.data) + self.assertIn(b'\r\nConnection: Upgrade\r\n', sock.data) + self.assertIn(b'\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n', sock.data) + + def test_bad_version(self): + ws = websocket.WebSocket() + sock = FakeSocket() + self.assertRaises(Exception, ws.accept, + sock, {'upgrade': 'websocket', + 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) + self.assertRaises(Exception, ws.accept, + sock, {'upgrade': 'websocket', + 'Sec-WebSocket-Version': '5', + 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) + self.assertRaises(Exception, ws.accept, + sock, {'upgrade': 'websocket', + 'Sec-WebSocket-Version': '20', + 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) + + def test_bad_upgrade(self): + ws = websocket.WebSocket() + sock = FakeSocket() + self.assertRaises(Exception, ws.accept, + sock, {'Sec-WebSocket-Version': '13', + 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) + self.assertRaises(Exception, ws.accept, + sock, {'upgrade': 'websocket2', + 'Sec-WebSocket-Version': '13', + 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) + + def test_missing_key(self): + ws = websocket.WebSocket() + sock = FakeSocket() + self.assertRaises(Exception, ws.accept, + sock, {'upgrade': 'websocket', + 'Sec-WebSocket-Version': '13'}) + + def test_protocol(self): + class ProtoSocket(websocket.WebSocket): + def select_subprotocol(self, protocol): + return 'gazonk' + + ws = ProtoSocket() + sock = FakeSocket() + ws.accept(sock, {'upgrade': 'websocket', + 'Sec-WebSocket-Version': '13', + 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', + 'Sec-WebSocket-Protocol': 'foobar gazonk'}) + self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ') + self.assertIn(b'\r\nSec-WebSocket-Protocol: gazonk\r\n', sock.data) + + def test_no_protocol(self): + ws = websocket.WebSocket() + sock = FakeSocket() + ws.accept(sock, {'upgrade': 'websocket', + 'Sec-WebSocket-Version': '13', + 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='}) + self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ') + self.assertNotIn(b'\r\nSec-WebSocket-Protocol:', sock.data) + + def test_missing_protocol(self): + ws = websocket.WebSocket() + sock = FakeSocket() + self.assertRaises(Exception, ws.accept, + sock, {'upgrade': 'websocket', + 'Sec-WebSocket-Version': '13', + 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', + 'Sec-WebSocket-Protocol': 'foobar gazonk'}) + + def test_protocol(self): + class ProtoSocket(websocket.WebSocket): + def select_subprotocol(self, protocol): + return 'oddball' + + ws = ProtoSocket() + sock = FakeSocket() + self.assertRaises(Exception, ws.accept, + sock, {'upgrade': 'websocket', + 'Sec-WebSocket-Version': '13', + 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==', + 'Sec-WebSocket-Protocol': 'foobar gazonk'}) + class HyBiEncodeDecodeTestCase(unittest.TestCase): def test_decode_hybi_text(self): buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58' diff --git a/websockify/websocket.py b/websockify/websocket.py index 72a269c..6d4d8cb 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -226,6 +226,14 @@ class WebSocket(object): if accept != expected: raise Exception("Invalid Sec-WebSocket-Accept header"); + self.protocol = headers.get('Sec-WebSocket-Protocol') + if len(protocols) == 0: + if self.protocol is not None: + raise Exception("Unexpected Sec-WebSocket-Protocol header") + else: + if self.protocol not in protocols: + raise Exception("Invalid protocol chosen by server") + self._state = "done" return @@ -282,6 +290,10 @@ class WebSocket(object): protocols = headers.get('Sec-WebSocket-Protocol', '').split(',') if protocols: self.protocol = self.select_subprotocol(protocols) + # We are required to choose one of the protocols + # presented by the client + if self.protocol not in protocols: + raise Exception('Invalid protocol selected') self._queue_str("HTTP/1.1 101 Switching Protocols\r\n") self._queue_str("Upgrade: websocket\r\n")