from __future__ import absolute_import, division, print_function import functools import sys import traceback from tornado.concurrent import Future from tornado import gen from tornado.httpclient import HTTPError, HTTPRequest from tornado.log import gen_log, app_log from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.template import DictLoader from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog from tornado.test.util import unittest, skipBefore35, exec_test from tornado.web import Application, RequestHandler try: import tornado.websocket # noqa from tornado.util import _websocket_mask_python except ImportError: # The unittest module presents misleading errors on ImportError # (it acts as if websocket_test could not be found, hiding the underlying # error). If we get an ImportError here (which could happen due to # TORNADO_EXTENSION=1), print some extra information before failing. traceback.print_exc() raise from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketError, WebSocketClosedError try: from tornado import speedups except ImportError: speedups = None class TestWebSocketHandler(WebSocketHandler): """Base class for testing handlers that exposes the on_close event. This allows for deterministic cleanup of the associated socket. """ def initialize(self, close_future, compression_options=None): self.close_future = close_future self.compression_options = compression_options def get_compression_options(self): return self.compression_options def on_close(self): self.close_future.set_result((self.close_code, self.close_reason)) class EchoHandler(TestWebSocketHandler): @gen.coroutine def on_message(self, message): try: yield self.write_message(message, isinstance(message, bytes)) except WebSocketClosedError: pass class ErrorInOnMessageHandler(TestWebSocketHandler): def on_message(self, message): 1 / 0 class HeaderHandler(TestWebSocketHandler): def open(self): methods_to_test = [ functools.partial(self.write, 'This should not work'), functools.partial(self.redirect, 'http://localhost/elsewhere'), functools.partial(self.set_header, 'X-Test', ''), functools.partial(self.set_cookie, 'Chocolate', 'Chip'), functools.partial(self.set_status, 503), self.flush, self.finish, ] for method in methods_to_test: try: # In a websocket context, many RequestHandler methods # raise RuntimeErrors. method() raise Exception("did not get expected exception") except RuntimeError: pass self.write_message(self.request.headers.get('X-Test', '')) class HeaderEchoHandler(TestWebSocketHandler): def set_default_headers(self): self.set_header("X-Extra-Response-Header", "Extra-Response-Value") def prepare(self): for k, v in self.request.headers.get_all(): if k.lower().startswith('x-test'): self.set_header(k, v) class NonWebSocketHandler(RequestHandler): def get(self): self.write('ok') class CloseReasonHandler(TestWebSocketHandler): def open(self): self.on_close_called = False self.close(1001, "goodbye") class AsyncPrepareHandler(TestWebSocketHandler): @gen.coroutine def prepare(self): yield gen.moment def on_message(self, message): self.write_message(message) class PathArgsHandler(TestWebSocketHandler): def open(self, arg): self.write_message(arg) class CoroutineOnMessageHandler(TestWebSocketHandler): def initialize(self, close_future, compression_options=None): super(CoroutineOnMessageHandler, self).initialize(close_future, compression_options) self.sleeping = 0 @gen.coroutine def on_message(self, message): if self.sleeping > 0: self.write_message('another coroutine is already sleeping') self.sleeping += 1 yield gen.sleep(0.01) self.sleeping -= 1 self.write_message(message) class RenderMessageHandler(TestWebSocketHandler): def on_message(self, message): self.write_message(self.render_string('message.html', message=message)) class WebSocketBaseTestCase(AsyncHTTPTestCase): @gen.coroutine def ws_connect(self, path, **kwargs): ws = yield websocket_connect( 'ws://127.0.0.1:%d%s' % (self.get_http_port(), path), **kwargs) raise gen.Return(ws) @gen.coroutine def close(self, ws): """Close a websocket connection and wait for the server side. If we don't wait here, there are sometimes leak warnings in the tests. """ ws.close() yield self.close_future class WebSocketTest(WebSocketBaseTestCase): def get_app(self): self.close_future = Future() return Application([ ('/echo', EchoHandler, dict(close_future=self.close_future)), ('/non_ws', NonWebSocketHandler), ('/header', HeaderHandler, dict(close_future=self.close_future)), ('/header_echo', HeaderEchoHandler, dict(close_future=self.close_future)), ('/close_reason', CloseReasonHandler, dict(close_future=self.close_future)), ('/error_in_on_message', ErrorInOnMessageHandler, dict(close_future=self.close_future)), ('/async_prepare', AsyncPrepareHandler, dict(close_future=self.close_future)), ('/path_args/(.*)', PathArgsHandler, dict(close_future=self.close_future)), ('/coroutine', CoroutineOnMessageHandler, dict(close_future=self.close_future)), ('/render', RenderMessageHandler, dict(close_future=self.close_future)), ], template_loader=DictLoader({ 'message.html': '{{ message }}', })) def get_http_client(self): # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient. return SimpleAsyncHTTPClient() def tearDown(self): super(WebSocketTest, self).tearDown() RequestHandler._template_loaders.clear() def test_http_request(self): # WS server, HTTP client. response = self.fetch('/echo') self.assertEqual(response.code, 400) def test_missing_websocket_key(self): response = self.fetch('/echo', headers={'Connection': 'Upgrade', 'Upgrade': 'WebSocket', 'Sec-WebSocket-Version': '13'}) self.assertEqual(response.code, 400) def test_bad_websocket_version(self): response = self.fetch('/echo', headers={'Connection': 'Upgrade', 'Upgrade': 'WebSocket', 'Sec-WebSocket-Version': '12'}) self.assertEqual(response.code, 426) @gen_test def test_websocket_gen(self): ws = yield self.ws_connect('/echo') yield ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') yield self.close(ws) def test_websocket_callbacks(self): websocket_connect( 'ws://127.0.0.1:%d/echo' % self.get_http_port(), callback=self.stop) ws = self.wait().result() ws.write_message('hello') ws.read_message(self.stop) response = self.wait().result() self.assertEqual(response, 'hello') self.close_future.add_done_callback(lambda f: self.stop()) ws.close() self.wait() @gen_test def test_binary_message(self): ws = yield self.ws_connect('/echo') ws.write_message(b'hello \xe9', binary=True) response = yield ws.read_message() self.assertEqual(response, b'hello \xe9') yield self.close(ws) @gen_test def test_unicode_message(self): ws = yield self.ws_connect('/echo') ws.write_message(u'hello \u00e9') response = yield ws.read_message() self.assertEqual(response, u'hello \u00e9') yield self.close(ws) @gen_test def test_render_message(self): ws = yield self.ws_connect('/render') ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') yield self.close(ws) @gen_test def test_error_in_on_message(self): ws = yield self.ws_connect('/error_in_on_message') ws.write_message('hello') with ExpectLog(app_log, "Uncaught exception"): response = yield ws.read_message() self.assertIs(response, None) yield self.close(ws) @gen_test def test_websocket_http_fail(self): with self.assertRaises(HTTPError) as cm: yield self.ws_connect('/notfound') self.assertEqual(cm.exception.code, 404) @gen_test def test_websocket_http_success(self): with self.assertRaises(WebSocketError): yield self.ws_connect('/non_ws') @gen_test def test_websocket_network_fail(self): sock, port = bind_unused_port() sock.close() with self.assertRaises(IOError): with ExpectLog(gen_log, ".*"): yield websocket_connect( 'ws://127.0.0.1:%d/' % port, connect_timeout=3600) @gen_test def test_websocket_close_buffered_data(self): ws = yield websocket_connect( 'ws://127.0.0.1:%d/echo' % self.get_http_port()) ws.write_message('hello') ws.write_message('world') # Close the underlying stream. ws.stream.close() yield self.close_future @gen_test def test_websocket_headers(self): # Ensure that arbitrary headers can be passed through websocket_connect. ws = yield websocket_connect( HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(), headers={'X-Test': 'hello'})) response = yield ws.read_message() self.assertEqual(response, 'hello') yield self.close(ws) @gen_test def test_websocket_header_echo(self): # Ensure that headers can be returned in the response. # Specifically, that arbitrary headers passed through websocket_connect # can be returned. ws = yield websocket_connect( HTTPRequest('ws://127.0.0.1:%d/header_echo' % self.get_http_port(), headers={'X-Test-Hello': 'hello'})) self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello') self.assertEqual(ws.headers.get('X-Extra-Response-Header'), 'Extra-Response-Value') yield self.close(ws) @gen_test def test_server_close_reason(self): ws = yield self.ws_connect('/close_reason') msg = yield ws.read_message() # A message of None means the other side closed the connection. self.assertIs(msg, None) self.assertEqual(ws.close_code, 1001) self.assertEqual(ws.close_reason, "goodbye") # The on_close callback is called no matter which side closed. code, reason = yield self.close_future # The client echoed the close code it received to the server, # so the server's close code (returned via close_future) is # the same. self.assertEqual(code, 1001) @gen_test def test_client_close_reason(self): ws = yield self.ws_connect('/echo') ws.close(1001, 'goodbye') code, reason = yield self.close_future self.assertEqual(code, 1001) self.assertEqual(reason, 'goodbye') @gen_test def test_write_after_close(self): ws = yield self.ws_connect('/close_reason') msg = yield ws.read_message() self.assertIs(msg, None) with self.assertRaises(WebSocketClosedError): ws.write_message('hello') @gen_test def test_async_prepare(self): # Previously, an async prepare method triggered a bug that would # result in a timeout on test shutdown (and a memory leak). ws = yield self.ws_connect('/async_prepare') ws.write_message('hello') res = yield ws.read_message() self.assertEqual(res, 'hello') @gen_test def test_path_args(self): ws = yield self.ws_connect('/path_args/hello') res = yield ws.read_message() self.assertEqual(res, 'hello') @gen_test def test_coroutine(self): ws = yield self.ws_connect('/coroutine') # Send both messages immediately, coroutine must process one at a time. yield ws.write_message('hello1') yield ws.write_message('hello2') res = yield ws.read_message() self.assertEqual(res, 'hello1') res = yield ws.read_message() self.assertEqual(res, 'hello2') @gen_test def test_check_origin_valid_no_path(self): port = self.get_http_port() url = 'ws://127.0.0.1:%d/echo' % port headers = {'Origin': 'http://127.0.0.1:%d' % port} ws = yield websocket_connect(HTTPRequest(url, headers=headers)) ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') yield self.close(ws) @gen_test def test_check_origin_valid_with_path(self): port = self.get_http_port() url = 'ws://127.0.0.1:%d/echo' % port headers = {'Origin': 'http://127.0.0.1:%d/something' % port} ws = yield websocket_connect(HTTPRequest(url, headers=headers)) ws.write_message('hello') response = yield ws.read_message() self.assertEqual(response, 'hello') yield self.close(ws) @gen_test def test_check_origin_invalid_partial_url(self): port = self.get_http_port() url = 'ws://127.0.0.1:%d/echo' % port headers = {'Origin': '127.0.0.1:%d' % port} with self.assertRaises(HTTPError) as cm: yield websocket_connect(HTTPRequest(url, headers=headers)) self.assertEqual(cm.exception.code, 403) @gen_test def test_check_origin_invalid(self): port = self.get_http_port() url = 'ws://127.0.0.1:%d/echo' % port # Host is 127.0.0.1, which should not be accessible from some other # domain headers = {'Origin': 'http://somewhereelse.com'} with self.assertRaises(HTTPError) as cm: yield websocket_connect(HTTPRequest(url, headers=headers)) self.assertEqual(cm.exception.code, 403) @gen_test def test_check_origin_invalid_subdomains(self): port = self.get_http_port() url = 'ws://localhost:%d/echo' % port # Subdomains should be disallowed by default. If we could pass a # resolver to websocket_connect we could test sibling domains as well. headers = {'Origin': 'http://subtenant.localhost'} with self.assertRaises(HTTPError) as cm: yield websocket_connect(HTTPRequest(url, headers=headers)) self.assertEqual(cm.exception.code, 403) if sys.version_info >= (3, 5): NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """ class NativeCoroutineOnMessageHandler(TestWebSocketHandler): def initialize(self, close_future, compression_options=None): super().initialize(close_future, compression_options) self.sleeping = 0 async def on_message(self, message): if self.sleeping > 0: self.write_message('another coroutine is already sleeping') self.sleeping += 1 await gen.sleep(0.01) self.sleeping -= 1 self.write_message(message)""")['NativeCoroutineOnMessageHandler'] class WebSocketNativeCoroutineTest(WebSocketBaseTestCase): def get_app(self): self.close_future = Future() return Application([ ('/native', NativeCoroutineOnMessageHandler, dict(close_future=self.close_future))]) @skipBefore35 @gen_test def test_native_coroutine(self): ws = yield self.ws_connect('/native') # Send both messages immediately, coroutine must process one at a time. yield ws.write_message('hello1') yield ws.write_message('hello2') res = yield ws.read_message() self.assertEqual(res, 'hello1') res = yield ws.read_message() self.assertEqual(res, 'hello2') class CompressionTestMixin(object): MESSAGE = 'Hello world. Testing 123 123' def get_app(self): self.close_future = Future() return Application([ ('/echo', EchoHandler, dict( close_future=self.close_future, compression_options=self.get_server_compression_options())), ]) def get_server_compression_options(self): return None def get_client_compression_options(self): return None @gen_test def test_message_sizes(self): ws = yield self.ws_connect( '/echo', compression_options=self.get_client_compression_options()) # Send the same message three times so we can measure the # effect of the context_takeover options. for i in range(3): ws.write_message(self.MESSAGE) response = yield ws.read_message() self.assertEqual(response, self.MESSAGE) self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3) self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3) self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out) yield self.close(ws) class UncompressedTestMixin(CompressionTestMixin): """Specialization of CompressionTestMixin when we expect no compression.""" def verify_wire_bytes(self, bytes_in, bytes_out): # Bytes out includes the 4-byte mask key per message. self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6)) self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2)) class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase): pass # If only one side tries to compress, the extension is not negotiated. class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase): def get_server_compression_options(self): return {} class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase): def get_client_compression_options(self): return {} class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase): def get_server_compression_options(self): return {} def get_client_compression_options(self): return {} def verify_wire_bytes(self, bytes_in, bytes_out): self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6)) self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2)) # Bytes out includes the 4 bytes mask key per message. self.assertEqual(bytes_out, bytes_in + 12) class MaskFunctionMixin(object): # Subclasses should define self.mask(mask, data) def test_mask(self): self.assertEqual(self.mask(b'abcd', b''), b'') self.assertEqual(self.mask(b'abcd', b'b'), b'\x03') self.assertEqual(self.mask(b'abcd', b'54321'), b'TVPVP') self.assertEqual(self.mask(b'ZXCV', b'98765432'), b'c`t`olpd') # Include test cases with \x00 bytes (to ensure that the C # extension isn't depending on null-terminated strings) and # bytes with the high bit set (to smoke out signedness issues). self.assertEqual(self.mask(b'\x00\x01\x02\x03', b'\xff\xfb\xfd\xfc\xfe\xfa'), b'\xff\xfa\xff\xff\xfe\xfb') self.assertEqual(self.mask(b'\xff\xfb\xfd\xfc', b'\x00\x01\x02\x03\x04\x05'), b'\xff\xfa\xff\xff\xfb\xfe') class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase): def mask(self, mask, data): return _websocket_mask_python(mask, data) @unittest.skipIf(speedups is None, "tornado.speedups module not present") class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase): def mask(self, mask, data): return speedups.websocket_mask(mask, data) class ServerPeriodicPingTest(WebSocketBaseTestCase): def get_app(self): class PingHandler(TestWebSocketHandler): def on_pong(self, data): self.write_message("got pong") self.close_future = Future() return Application([ ('/', PingHandler, dict(close_future=self.close_future)), ], websocket_ping_interval=0.01) @gen_test def test_server_ping(self): ws = yield self.ws_connect('/') for i in range(3): response = yield ws.read_message() self.assertEqual(response, "got pong") yield self.close(ws) # TODO: test that the connection gets closed if ping responses stop. class ClientPeriodicPingTest(WebSocketBaseTestCase): def get_app(self): class PingHandler(TestWebSocketHandler): def on_ping(self, data): self.write_message("got ping") self.close_future = Future() return Application([ ('/', PingHandler, dict(close_future=self.close_future)), ]) @gen_test def test_client_ping(self): ws = yield self.ws_connect('/', ping_interval=0.01) for i in range(3): response = yield ws.read_message() self.assertEqual(response, "got ping") yield self.close(ws) # TODO: test that the connection gets closed if ping responses stop. class MaxMessageSizeTest(WebSocketBaseTestCase): def get_app(self): self.close_future = Future() return Application([ ('/', EchoHandler, dict(close_future=self.close_future)), ], websocket_max_message_size=1024) @gen_test def test_large_message(self): ws = yield self.ws_connect('/') # Write a message that is allowed. msg = 'a' * 1024 ws.write_message(msg) resp = yield ws.read_message() self.assertEqual(resp, msg) # Write a message that is too large. ws.write_message(msg + 'b') resp = yield ws.read_message() # A message of None means the other side closed the connection. self.assertIs(resp, None) self.assertEqual(ws.close_code, 1009) self.assertEqual(ws.close_reason, "message too big") # TODO: Needs tests of messages split over multiple # continuation frames.