Force choice of sub-protocol

The WebSocket standard require us to choose one of the protocols
supported by the client. Enforce this with a specific check in the
base class rather than relying on generous clients.
This commit is contained in:
Pierre Ossman 2017-02-08 15:45:48 +01:00
parent 94783ea0cd
commit c7bde00a4e
2 changed files with 110 additions and 0 deletions

View File

@ -18,6 +18,104 @@
import unittest
from websockify import websocket
class FakeSocket:
def __init__(self):
self.data = b''
def send(self, buf):
self.data += buf
return len(buf)
class AcceptTestCase(unittest.TestCase):
def test_success(self):
ws = websocket.WebSocket()
sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
self.assertIn(b'\r\nUpgrade: websocket\r\n', sock.data)
self.assertIn(b'\r\nConnection: Upgrade\r\n', sock.data)
self.assertIn(b'\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n', sock.data)
def test_bad_version(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '5',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '20',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
def test_bad_upgrade(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket2',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
def test_missing_key(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13'})
def test_protocol(self):
class ProtoSocket(websocket.WebSocket):
def select_subprotocol(self, protocol):
return 'gazonk'
ws = ProtoSocket()
sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'})
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
self.assertIn(b'\r\nSec-WebSocket-Protocol: gazonk\r\n', sock.data)
def test_no_protocol(self):
ws = websocket.WebSocket()
sock = FakeSocket()
ws.accept(sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
self.assertNotIn(b'\r\nSec-WebSocket-Protocol:', sock.data)
def test_missing_protocol(self):
ws = websocket.WebSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'})
def test_protocol(self):
class ProtoSocket(websocket.WebSocket):
def select_subprotocol(self, protocol):
return 'oddball'
ws = ProtoSocket()
sock = FakeSocket()
self.assertRaises(Exception, ws.accept,
sock, {'upgrade': 'websocket',
'Sec-WebSocket-Version': '13',
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
'Sec-WebSocket-Protocol': 'foobar gazonk'})
class HyBiEncodeDecodeTestCase(unittest.TestCase):
def test_decode_hybi_text(self):
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'

View File

@ -226,6 +226,14 @@ class WebSocket(object):
if accept != expected:
raise Exception("Invalid Sec-WebSocket-Accept header");
self.protocol = headers.get('Sec-WebSocket-Protocol')
if len(protocols) == 0:
if self.protocol is not None:
raise Exception("Unexpected Sec-WebSocket-Protocol header")
else:
if self.protocol not in protocols:
raise Exception("Invalid protocol chosen by server")
self._state = "done"
return
@ -282,6 +290,10 @@ class WebSocket(object):
protocols = headers.get('Sec-WebSocket-Protocol', '').split(',')
if protocols:
self.protocol = self.select_subprotocol(protocols)
# We are required to choose one of the protocols
# presented by the client
if self.protocol not in protocols:
raise Exception('Invalid protocol selected')
self._queue_str("HTTP/1.1 101 Switching Protocols\r\n")
self._queue_str("Upgrade: websocket\r\n")