diff --git a/pillar/api/utils/authorization.py b/pillar/api/utils/authorization.py index 1a55cdc0..9595114b 100644 --- a/pillar/api/utils/authorization.py +++ b/pillar/api/utils/authorization.py @@ -289,8 +289,7 @@ def require_login(*, require_roles=set(), require_cap='', require_all=False, redirect_to_login=False, - error_view=None, - error_headers: typing.Optional[typing.Dict[str, str]]=None): + error_view=None): """Decorator that enforces users to authenticate. Optionally only allows access to users with a certain role and/or capability. @@ -314,7 +313,6 @@ def require_login(*, require_roles=set(), requests, and mimicks the flask_login behaviour. :param error_view: Callable that returns a Flask response object. This is sent back to the client instead of the default 403 Forbidden. - :param error_headers: HTTP headers to include in error responses. """ from flask import request, redirect, url_for, Response @@ -337,14 +335,6 @@ def require_login(*, require_roles=set(), else: resp = error_view() resp.status_code = 403 - if error_headers: - for header_name, header_value in error_headers.items(): - resp.headers.set(header_name, header_value) - - if 'Access-Control-Allow-Origin' in error_headers: - origin = request.headers.get('Origin', '') - resp.headers.set('Access-Control-Allow-Origin', origin) - return resp def decorator(func): diff --git a/pillar/auth/cors.py b/pillar/auth/cors.py new file mode 100644 index 00000000..10515042 --- /dev/null +++ b/pillar/auth/cors.py @@ -0,0 +1,48 @@ +"""Support for adding CORS headers to responses.""" + +import functools + +import flask +import werkzeug.wrappers as wz_wrappers +import werkzeug.exceptions as wz_exceptions + + +def allow(*, allow_credentials=False): + """Flask endpoint decorator, adds CORS headers to the response. + + If the request has a non-empty 'Origin' header, the response header + 'Access-Control-Allow-Origin' is set to the value of that request header, + and some other CORS headers are set. + """ + def decorator(wrapped): + @functools.wraps(wrapped) + def wrapper(*args, **kwargs): + request_origin = flask.request.headers.get('Origin') + if not request_origin: + # No CORS headers requested, so don't bother touching the response. + return wrapped(*args, **kwargs) + + try: + response = wrapped(*args, **kwargs) + except wz_exceptions.HTTPException as ex: + response = ex.get_response() + else: + if isinstance(response, tuple): + response = flask.make_response(*response) + elif isinstance(response, str): + response = flask.make_response(response) + elif isinstance(response, wz_wrappers.Response): + pass + else: + raise TypeError(f'unknown response type {type(response)}') + + assert isinstance(response, wz_wrappers.Response) + + response.headers.set('Access-Control-Allow-Origin', request_origin) + response.headers.set('Access-Control-Allow-Headers', 'x-requested-with') + if allow_credentials: + response.headers.set('Access-Control-Allow-Credentials', 'true') + + return response + return wrapper + return decorator diff --git a/tests/test_auth/__init__.py b/tests/test_auth/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_auth/test_cors.py b/tests/test_auth/test_cors.py new file mode 100644 index 00000000..9649ba32 --- /dev/null +++ b/tests/test_auth/test_cors.py @@ -0,0 +1,127 @@ +from pillar.tests import AbstractPillarTest + +import flask +import werkzeug.wrappers as wz_wrappers +import werkzeug.exceptions as wz_exceptions + + +class CorsWrapperTest(AbstractPillarTest): + def test_noncors_request(self): + from pillar.auth.cors import allow + + @allow() + def wrapped(a, b): + return f'{a} and {b}' + + with self.app.test_request_context(): + resp = wrapped('x', 'y') + + self.assertEqual('x and y', resp, 'Non-CORS request should not be modified') + + def test_string_response(self): + from pillar.auth.cors import allow + + @allow() + def wrapped(a, b): + return f'{a} and {b}' + + with self.app.test_request_context(headers={'Origin': 'http://jemoeder.nl:1234/'}): + resp = wrapped('x', 'y') + + self.assertIsInstance(resp, wz_wrappers.Response) + self.assertEqual(b'x and y', resp.data) + self.assertEqual(200, resp.status_code) + + self.assertEqual('http://jemoeder.nl:1234/', resp.headers['Access-Control-Allow-Origin']) + self.assertEqual('x-requested-with', resp.headers['Access-Control-Allow-Headers']) + self.assertNotIn('Access-Control-Allow-Credentials', resp.headers) + + def test_string_with_code_response(self): + from pillar.auth.cors import allow + + @allow() + def wrapped(a, b): + return f'{a} and {b}', 403 + + with self.app.test_request_context(headers={'Origin': 'http://jemoeder.nl:1234/'}): + resp = wrapped('x', 'y') + + self.assertIsInstance(resp, wz_wrappers.Response) + self.assertEqual(b'x and y', resp.data) + self.assertEqual(403, resp.status_code) + + self.assertEqual('http://jemoeder.nl:1234/', resp.headers['Access-Control-Allow-Origin']) + self.assertEqual('x-requested-with', resp.headers['Access-Control-Allow-Headers']) + self.assertNotIn('Access-Control-Allow-Credentials', resp.headers) + + def test_flask_response_object(self): + from pillar.auth.cors import allow + + @allow() + def wrapped(a, b): + return flask.Response(f'{a} and {b}', status=147, headers={'op-je': 'hoofd'}) + + with self.app.test_request_context(headers={'Origin': 'http://jemoeder.nl:1234/'}): + resp = wrapped('x', 'y') + + self.assertIsInstance(resp, wz_wrappers.Response) + self.assertEqual(b'x and y', resp.data) + self.assertEqual(147, resp.status_code) + self.assertEqual('hoofd', resp.headers['Op-Je']) + + self.assertEqual('http://jemoeder.nl:1234/', resp.headers['Access-Control-Allow-Origin']) + self.assertEqual('x-requested-with', resp.headers['Access-Control-Allow-Headers']) + self.assertNotIn('Access-Control-Allow-Credentials', resp.headers) + + def test_wz_exception(self): + from pillar.auth.cors import allow + + @allow() + def wrapped(a, b): + raise wz_exceptions.NotImplemented('nee') + + with self.app.test_request_context(headers={'Origin': 'http://jemoeder.nl:1234/'}): + resp = wrapped('x', 'y') + + self.assertIsInstance(resp, wz_wrappers.Response) + self.assertIn(b'nee', resp.data) + self.assertEqual(501, resp.status_code) + + self.assertEqual('http://jemoeder.nl:1234/', resp.headers['Access-Control-Allow-Origin']) + self.assertEqual('x-requested-with', resp.headers['Access-Control-Allow-Headers']) + self.assertNotIn('Access-Control-Allow-Credentials', resp.headers) + + def test_flask_abort(self): + from pillar.auth.cors import allow + + @allow() + def wrapped(a, b): + raise flask.abort(401) + + with self.app.test_request_context(headers={'Origin': 'http://jemoeder.nl:1234/'}): + resp = wrapped('x', 'y') + + self.assertIsInstance(resp, wz_wrappers.Response) + self.assertEqual(401, resp.status_code) + + self.assertEqual('http://jemoeder.nl:1234/', resp.headers['Access-Control-Allow-Origin']) + self.assertEqual('x-requested-with', resp.headers['Access-Control-Allow-Headers']) + self.assertNotIn('Access-Control-Allow-Credentials', resp.headers) + + def test_with_credentials(self): + from pillar.auth.cors import allow + + @allow(allow_credentials=True) + def wrapped(a, b): + return f'{a} and {b}' + + with self.app.test_request_context(headers={'Origin': 'http://jemoeder.nl:1234/'}): + resp = wrapped('x', 'y') + + self.assertIsInstance(resp, wz_wrappers.Response) + self.assertEqual(b'x and y', resp.data) + self.assertEqual(200, resp.status_code) + + self.assertEqual('http://jemoeder.nl:1234/', resp.headers['Access-Control-Allow-Origin']) + self.assertEqual('x-requested-with', resp.headers['Access-Control-Allow-Headers']) + self.assertEqual('true', resp.headers['Access-Control-Allow-Credentials'])