diff --git a/bokeh/server/tests/test_tornado.py b/bokeh/server/tests/test_tornado.py index bc69ec298d9..d7228d3f191 100644 --- a/bokeh/server/tests/test_tornado.py +++ b/bokeh/server/tests/test_tornado.py @@ -17,4 +17,16 @@ def test__whitelist_replaces_prepare_only_once(): tornado._whitelist(h) new_prepare = h.prepare tornado._whitelist(h) - assert h.prepare == new_prepare \ No newline at end of file + assert h.prepare == new_prepare + +def test_check_whitelist_rejects_port_mismatch(): + assert False == tornado.check_whitelist("foo:100", ["foo:101", "foo:102"]) + +def test_check_whitelist_rejects_name_mismatch(): + assert False == tornado.check_whitelist("foo:100", ["bar:100", "baz:100"]) + +def test_check_whitelist_accepts_name_port_match(): + assert True == tornado.check_whitelist("foo:100", ["foo:100", "baz:100"]) + +def test_check_whitelist_accepts_implicit_port_80(): + assert True == tornado.check_whitelist("foo", ["foo:80"]) \ No newline at end of file diff --git a/bokeh/server/tornado.py b/bokeh/server/tornado.py index 09b2ebf863b..6d2ccae16e4 100644 --- a/bokeh/server/tornado.py +++ b/bokeh/server/tornado.py @@ -25,12 +25,29 @@ from .application_context import ApplicationContext from .views.static_handler import StaticHandler +# factored out to be easier to test +def check_whitelist(request_host, whitelist): + ''' Check a given request host against a whitelist. + + ''' + if request_host not in whitelist: + + # see if the request came with no port, assume port 80 in that case + if len(request_host.split(':')) == 1: + host = request_host + ":80" + return host in whitelist + else: + return False + + return True + + def _whitelist(handler_class): if hasattr(handler_class.prepare, 'patched'): return old_prepare = handler_class.prepare def _prepare(self, *args, **kw): - if self.request.host not in self.application._hosts: + if not check_whitelist(self.request.host, self.application._hosts): log.info("Rejected connection from host '%s' because it is not in the --host whitelist" % self.request.host) raise HTTPError(403) return old_prepare(self, *args, **kw) diff --git a/bokeh/server/views/ws.py b/bokeh/server/views/ws.py index f0be930e37f..ee8606977b5 100644 --- a/bokeh/server/views/ws.py +++ b/bokeh/server/views/ws.py @@ -42,12 +42,14 @@ def initialize(self, application_context, bokeh_websocket_path): pass def check_origin(self, origin): + from ..tornado import check_whitelist parsed_origin = urlparse(origin) origin_host = parsed_origin.netloc.lower() - allowed = self.application.websocket_origins + allowed_hosts = self.application.websocket_origins - if origin_host in allowed: + allowed = check_whitelist(origin_host, allowed_hosts) + if allowed: return True else: log.error("Refusing websocket connection from Origin '%s'; use --allow-websocket-origin=%s to permit this; currently we allow origins %r", origin, origin_host, allowed)