Use patch() as a decorator in tests

Cleaner and more robust.
This commit is contained in:
Pierre Ossman 2021-01-29 13:09:19 +01:00
parent a82eb10b48
commit 3f17696dc6
1 changed files with 10 additions and 17 deletions

View File

@ -22,7 +22,7 @@ import unittest
import socket import socket
from io import StringIO from io import StringIO
from io import BytesIO from io import BytesIO
from unittest.mock import patch from unittest.mock import patch, MagicMock
from jwcrypto import jwt from jwcrypto import jwt
@ -102,19 +102,19 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
self.assertRaises(FakeServer.EClose, self.handler.get_target, self.assertRaises(FakeServer.EClose, self.handler.get_target,
TestPlugin(None)) TestPlugin(None))
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
def test_token_plugin(self): def test_token_plugin(self):
class TestPlugin(token_plugins.BasePlugin): class TestPlugin(token_plugins.BasePlugin):
def lookup(self, token): def lookup(self, token):
return (self.source + token).split(',') return (self.source + token).split(',')
patcher = patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error').start()
self.handler.server.token_plugin = TestPlugin("somehost,") self.handler.server.token_plugin = TestPlugin("somehost,")
self.handler.validate_connection() self.handler.validate_connection()
self.assertEqual(self.handler.server.target_host, "somehost") self.assertEqual(self.handler.server.target_host, "somehost")
self.assertEqual(self.handler.server.target_port, "blah") self.assertEqual(self.handler.server.target_port, "blah")
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
def test_asymmetric_jws_token_plugin(self): def test_asymmetric_jws_token_plugin(self):
key = jwt.JWK() key = jwt.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read() private_key = open("./tests/fixtures/private.pem", "rb").read()
@ -123,14 +123,13 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
self.handler.path = "https://localhost:6080/websockify?token={jwt_token}".format(jwt_token=jwt_token.serialize()) self.handler.path = "https://localhost:6080/websockify?token={jwt_token}".format(jwt_token=jwt_token.serialize())
patcher = patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error').start()
self.handler.server.token_plugin = token_plugins.JWTTokenApi("./tests/fixtures/public.pem") self.handler.server.token_plugin = token_plugins.JWTTokenApi("./tests/fixtures/public.pem")
self.handler.validate_connection() self.handler.validate_connection()
self.assertEqual(self.handler.server.target_host, "remote_host") self.assertEqual(self.handler.server.target_host, "remote_host")
self.assertEqual(self.handler.server.target_port, "remote_port") self.assertEqual(self.handler.server.target_port, "remote_port")
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
def test_asymmetric_jws_token_plugin_with_illigal_key_exception(self): def test_asymmetric_jws_token_plugin_with_illigal_key_exception(self):
key = jwt.JWK() key = jwt.JWK()
private_key = open("./tests/fixtures/private.pem", "rb").read() private_key = open("./tests/fixtures/private.pem", "rb").read()
@ -139,13 +138,12 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
self.handler.path = "https://localhost:6080/websockify?token={jwt_token}".format(jwt_token=jwt_token.serialize()) self.handler.path = "https://localhost:6080/websockify?token={jwt_token}".format(jwt_token=jwt_token.serialize())
patcher = patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error').start()
self.handler.server.token_plugin = token_plugins.JWTTokenApi("wrong.pub") self.handler.server.token_plugin = token_plugins.JWTTokenApi("wrong.pub")
self.assertRaises(self.handler.server.EClose, self.assertRaises(self.handler.server.EClose,
self.handler.validate_connection) self.handler.validate_connection)
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
def test_symmetric_jws_token_plugin(self): def test_symmetric_jws_token_plugin(self):
secret = open("./tests/fixtures/symmetric.key").read() secret = open("./tests/fixtures/symmetric.key").read()
key = jwt.JWK() key = jwt.JWK()
@ -154,14 +152,13 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
self.handler.path = "https://localhost:6080/websockify?token={jwt_token}".format(jwt_token=jwt_token.serialize()) self.handler.path = "https://localhost:6080/websockify?token={jwt_token}".format(jwt_token=jwt_token.serialize())
patcher = patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error').start()
self.handler.server.token_plugin = token_plugins.JWTTokenApi("./tests/fixtures/symmetric.key") self.handler.server.token_plugin = token_plugins.JWTTokenApi("./tests/fixtures/symmetric.key")
self.handler.validate_connection() self.handler.validate_connection()
self.assertEqual(self.handler.server.target_host, "remote_host") self.assertEqual(self.handler.server.target_host, "remote_host")
self.assertEqual(self.handler.server.target_port, "remote_port") self.assertEqual(self.handler.server.target_port, "remote_port")
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
def test_symmetric_jws_token_plugin_with_illigal_key_exception(self): def test_symmetric_jws_token_plugin_with_illigal_key_exception(self):
secret = open("./tests/fixtures/symmetric.key").read() secret = open("./tests/fixtures/symmetric.key").read()
key = jwt.JWK() key = jwt.JWK()
@ -170,12 +167,11 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
jwt_token.make_signed_token(key) jwt_token.make_signed_token(key)
self.handler.path = "https://localhost:6080/websockify?token={jwt_token}".format(jwt_token=jwt_token.serialize()) self.handler.path = "https://localhost:6080/websockify?token={jwt_token}".format(jwt_token=jwt_token.serialize())
patcher = patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error').start()
self.handler.server.token_plugin = token_plugins.JWTTokenApi("wrong_sauce") self.handler.server.token_plugin = token_plugins.JWTTokenApi("wrong_sauce")
self.assertRaises(self.handler.server.EClose, self.assertRaises(self.handler.server.EClose,
self.handler.validate_connection) self.handler.validate_connection)
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
def test_asymmetric_jwe_token_plugin(self): def test_asymmetric_jwe_token_plugin(self):
private_key = jwt.JWK() private_key = jwt.JWK()
public_key = jwt.JWK() public_key = jwt.JWK()
@ -191,22 +187,19 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
self.handler.path = "https://localhost:6080/websockify?token={jwt_token}".format(jwt_token=jwe_token.serialize()) self.handler.path = "https://localhost:6080/websockify?token={jwt_token}".format(jwt_token=jwe_token.serialize())
patcher = patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error').start()
self.handler.server.token_plugin = token_plugins.JWTTokenApi("./tests/fixtures/private.pem") self.handler.server.token_plugin = token_plugins.JWTTokenApi("./tests/fixtures/private.pem")
self.handler.validate_connection() self.handler.validate_connection()
self.assertEqual(self.handler.server.target_host, "remote_host") self.assertEqual(self.handler.server.target_host, "remote_host")
self.assertEqual(self.handler.server.target_port, "remote_port") self.assertEqual(self.handler.server.target_port, "remote_port")
@patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error', MagicMock())
def test_auth_plugin(self): def test_auth_plugin(self):
class TestPlugin(auth_plugins.BasePlugin): class TestPlugin(auth_plugins.BasePlugin):
def authenticate(self, headers, target_host, target_port): def authenticate(self, headers, target_host, target_port):
if target_host == self.source: if target_host == self.source:
raise auth_plugins.AuthenticationError(response_msg="some_error") raise auth_plugins.AuthenticationError(response_msg="some_error")
patcher = patch('websockify.websocketproxy.ProxyRequestHandler.send_auth_error').start()
self.handler.server.auth_plugin = TestPlugin("somehost") self.handler.server.auth_plugin = TestPlugin("somehost")
self.handler.server.target_host = "somehost" self.handler.server.target_host = "somehost"
self.handler.server.target_port = "someport" self.handler.server.target_port = "someport"