Force choice of sub-protocol
The WebSocket standard require us to choose one of the protocols supported by the client. Enforce this with a specific check in the base class rather than relying on generous clients.
This commit is contained in:
parent
94783ea0cd
commit
c7bde00a4e
|
@ -18,6 +18,104 @@
|
|||
import unittest
|
||||
from websockify import websocket
|
||||
|
||||
class FakeSocket:
|
||||
def __init__(self):
|
||||
self.data = b''
|
||||
|
||||
def send(self, buf):
|
||||
self.data += buf
|
||||
return len(buf)
|
||||
|
||||
class AcceptTestCase(unittest.TestCase):
|
||||
def test_success(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
ws.accept(sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
|
||||
self.assertIn(b'\r\nUpgrade: websocket\r\n', sock.data)
|
||||
self.assertIn(b'\r\nConnection: Upgrade\r\n', sock.data)
|
||||
self.assertIn(b'\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n', sock.data)
|
||||
|
||||
def test_bad_version(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '5',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '20',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
|
||||
def test_bad_upgrade(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket2',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
|
||||
def test_missing_key(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13'})
|
||||
|
||||
def test_protocol(self):
|
||||
class ProtoSocket(websocket.WebSocket):
|
||||
def select_subprotocol(self, protocol):
|
||||
return 'gazonk'
|
||||
|
||||
ws = ProtoSocket()
|
||||
sock = FakeSocket()
|
||||
ws.accept(sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
|
||||
'Sec-WebSocket-Protocol': 'foobar gazonk'})
|
||||
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
|
||||
self.assertIn(b'\r\nSec-WebSocket-Protocol: gazonk\r\n', sock.data)
|
||||
|
||||
def test_no_protocol(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
ws.accept(sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
|
||||
self.assertNotIn(b'\r\nSec-WebSocket-Protocol:', sock.data)
|
||||
|
||||
def test_missing_protocol(self):
|
||||
ws = websocket.WebSocket()
|
||||
sock = FakeSocket()
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
|
||||
'Sec-WebSocket-Protocol': 'foobar gazonk'})
|
||||
|
||||
def test_protocol(self):
|
||||
class ProtoSocket(websocket.WebSocket):
|
||||
def select_subprotocol(self, protocol):
|
||||
return 'oddball'
|
||||
|
||||
ws = ProtoSocket()
|
||||
sock = FakeSocket()
|
||||
self.assertRaises(Exception, ws.accept,
|
||||
sock, {'upgrade': 'websocket',
|
||||
'Sec-WebSocket-Version': '13',
|
||||
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
|
||||
'Sec-WebSocket-Protocol': 'foobar gazonk'})
|
||||
|
||||
class HyBiEncodeDecodeTestCase(unittest.TestCase):
|
||||
def test_decode_hybi_text(self):
|
||||
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'
|
||||
|
|
|
@ -226,6 +226,14 @@ class WebSocket(object):
|
|||
if accept != expected:
|
||||
raise Exception("Invalid Sec-WebSocket-Accept header");
|
||||
|
||||
self.protocol = headers.get('Sec-WebSocket-Protocol')
|
||||
if len(protocols) == 0:
|
||||
if self.protocol is not None:
|
||||
raise Exception("Unexpected Sec-WebSocket-Protocol header")
|
||||
else:
|
||||
if self.protocol not in protocols:
|
||||
raise Exception("Invalid protocol chosen by server")
|
||||
|
||||
self._state = "done"
|
||||
|
||||
return
|
||||
|
@ -282,6 +290,10 @@ class WebSocket(object):
|
|||
protocols = headers.get('Sec-WebSocket-Protocol', '').split(',')
|
||||
if protocols:
|
||||
self.protocol = self.select_subprotocol(protocols)
|
||||
# We are required to choose one of the protocols
|
||||
# presented by the client
|
||||
if self.protocol not in protocols:
|
||||
raise Exception('Invalid protocol selected')
|
||||
|
||||
self._queue_str("HTTP/1.1 101 Switching Protocols\r\n")
|
||||
self._queue_str("Upgrade: websocket\r\n")
|
||||
|
|
Loading…
Reference in New Issue