Force choice of sub-protocol
The WebSocket standard require us to choose one of the protocols supported by the client. Enforce this with a specific check in the base class rather than relying on generous clients.
This commit is contained in:
parent
94783ea0cd
commit
c7bde00a4e
|
@ -18,6 +18,104 @@
|
||||||
import unittest
|
import unittest
|
||||||
from websockify import websocket
|
from websockify import websocket
|
||||||
|
|
||||||
|
class FakeSocket:
|
||||||
|
def __init__(self):
|
||||||
|
self.data = b''
|
||||||
|
|
||||||
|
def send(self, buf):
|
||||||
|
self.data += buf
|
||||||
|
return len(buf)
|
||||||
|
|
||||||
|
class AcceptTestCase(unittest.TestCase):
|
||||||
|
def test_success(self):
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
sock = FakeSocket()
|
||||||
|
ws.accept(sock, {'upgrade': 'websocket',
|
||||||
|
'Sec-WebSocket-Version': '13',
|
||||||
|
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||||
|
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
|
||||||
|
self.assertIn(b'\r\nUpgrade: websocket\r\n', sock.data)
|
||||||
|
self.assertIn(b'\r\nConnection: Upgrade\r\n', sock.data)
|
||||||
|
self.assertIn(b'\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n', sock.data)
|
||||||
|
|
||||||
|
def test_bad_version(self):
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
sock = FakeSocket()
|
||||||
|
self.assertRaises(Exception, ws.accept,
|
||||||
|
sock, {'upgrade': 'websocket',
|
||||||
|
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||||
|
self.assertRaises(Exception, ws.accept,
|
||||||
|
sock, {'upgrade': 'websocket',
|
||||||
|
'Sec-WebSocket-Version': '5',
|
||||||
|
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||||
|
self.assertRaises(Exception, ws.accept,
|
||||||
|
sock, {'upgrade': 'websocket',
|
||||||
|
'Sec-WebSocket-Version': '20',
|
||||||
|
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||||
|
|
||||||
|
def test_bad_upgrade(self):
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
sock = FakeSocket()
|
||||||
|
self.assertRaises(Exception, ws.accept,
|
||||||
|
sock, {'Sec-WebSocket-Version': '13',
|
||||||
|
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||||
|
self.assertRaises(Exception, ws.accept,
|
||||||
|
sock, {'upgrade': 'websocket2',
|
||||||
|
'Sec-WebSocket-Version': '13',
|
||||||
|
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||||
|
|
||||||
|
def test_missing_key(self):
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
sock = FakeSocket()
|
||||||
|
self.assertRaises(Exception, ws.accept,
|
||||||
|
sock, {'upgrade': 'websocket',
|
||||||
|
'Sec-WebSocket-Version': '13'})
|
||||||
|
|
||||||
|
def test_protocol(self):
|
||||||
|
class ProtoSocket(websocket.WebSocket):
|
||||||
|
def select_subprotocol(self, protocol):
|
||||||
|
return 'gazonk'
|
||||||
|
|
||||||
|
ws = ProtoSocket()
|
||||||
|
sock = FakeSocket()
|
||||||
|
ws.accept(sock, {'upgrade': 'websocket',
|
||||||
|
'Sec-WebSocket-Version': '13',
|
||||||
|
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
|
||||||
|
'Sec-WebSocket-Protocol': 'foobar gazonk'})
|
||||||
|
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
|
||||||
|
self.assertIn(b'\r\nSec-WebSocket-Protocol: gazonk\r\n', sock.data)
|
||||||
|
|
||||||
|
def test_no_protocol(self):
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
sock = FakeSocket()
|
||||||
|
ws.accept(sock, {'upgrade': 'websocket',
|
||||||
|
'Sec-WebSocket-Version': '13',
|
||||||
|
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
|
||||||
|
self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
|
||||||
|
self.assertNotIn(b'\r\nSec-WebSocket-Protocol:', sock.data)
|
||||||
|
|
||||||
|
def test_missing_protocol(self):
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
sock = FakeSocket()
|
||||||
|
self.assertRaises(Exception, ws.accept,
|
||||||
|
sock, {'upgrade': 'websocket',
|
||||||
|
'Sec-WebSocket-Version': '13',
|
||||||
|
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
|
||||||
|
'Sec-WebSocket-Protocol': 'foobar gazonk'})
|
||||||
|
|
||||||
|
def test_protocol(self):
|
||||||
|
class ProtoSocket(websocket.WebSocket):
|
||||||
|
def select_subprotocol(self, protocol):
|
||||||
|
return 'oddball'
|
||||||
|
|
||||||
|
ws = ProtoSocket()
|
||||||
|
sock = FakeSocket()
|
||||||
|
self.assertRaises(Exception, ws.accept,
|
||||||
|
sock, {'upgrade': 'websocket',
|
||||||
|
'Sec-WebSocket-Version': '13',
|
||||||
|
'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
|
||||||
|
'Sec-WebSocket-Protocol': 'foobar gazonk'})
|
||||||
|
|
||||||
class HyBiEncodeDecodeTestCase(unittest.TestCase):
|
class HyBiEncodeDecodeTestCase(unittest.TestCase):
|
||||||
def test_decode_hybi_text(self):
|
def test_decode_hybi_text(self):
|
||||||
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'
|
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'
|
||||||
|
|
|
@ -226,6 +226,14 @@ class WebSocket(object):
|
||||||
if accept != expected:
|
if accept != expected:
|
||||||
raise Exception("Invalid Sec-WebSocket-Accept header");
|
raise Exception("Invalid Sec-WebSocket-Accept header");
|
||||||
|
|
||||||
|
self.protocol = headers.get('Sec-WebSocket-Protocol')
|
||||||
|
if len(protocols) == 0:
|
||||||
|
if self.protocol is not None:
|
||||||
|
raise Exception("Unexpected Sec-WebSocket-Protocol header")
|
||||||
|
else:
|
||||||
|
if self.protocol not in protocols:
|
||||||
|
raise Exception("Invalid protocol chosen by server")
|
||||||
|
|
||||||
self._state = "done"
|
self._state = "done"
|
||||||
|
|
||||||
return
|
return
|
||||||
|
@ -282,6 +290,10 @@ class WebSocket(object):
|
||||||
protocols = headers.get('Sec-WebSocket-Protocol', '').split(',')
|
protocols = headers.get('Sec-WebSocket-Protocol', '').split(',')
|
||||||
if protocols:
|
if protocols:
|
||||||
self.protocol = self.select_subprotocol(protocols)
|
self.protocol = self.select_subprotocol(protocols)
|
||||||
|
# We are required to choose one of the protocols
|
||||||
|
# presented by the client
|
||||||
|
if self.protocol not in protocols:
|
||||||
|
raise Exception('Invalid protocol selected')
|
||||||
|
|
||||||
self._queue_str("HTTP/1.1 101 Switching Protocols\r\n")
|
self._queue_str("HTTP/1.1 101 Switching Protocols\r\n")
|
||||||
self._queue_str("Upgrade: websocket\r\n")
|
self._queue_str("Upgrade: websocket\r\n")
|
||||||
|
|
Loading…
Reference in New Issue