diff --git a/pillar/api/utils/authorization.py b/pillar/api/utils/authorization.py index b7d8f4c9..1a55cdc0 100644 --- a/pillar/api/utils/authorization.py +++ b/pillar/api/utils/authorization.py @@ -289,7 +289,8 @@ def require_login(*, require_roles=set(), require_cap='', require_all=False, redirect_to_login=False, - error_view=None): + error_view=None, + error_headers: typing.Optional[typing.Dict[str, str]]=None): """Decorator that enforces users to authenticate. Optionally only allows access to users with a certain role and/or capability. @@ -313,6 +314,7 @@ def require_login(*, require_roles=set(), requests, and mimicks the flask_login behaviour. :param error_view: Callable that returns a Flask response object. This is sent back to the client instead of the default 403 Forbidden. + :param error_headers: HTTP headers to include in error responses. """ from flask import request, redirect, url_for, Response @@ -331,9 +333,18 @@ def require_login(*, require_roles=set(), def render_error() -> Response: if error_view is None: - abort(403) - resp: Response = error_view() + resp = Forbidden().get_response() + else: + resp = error_view() resp.status_code = 403 + if error_headers: + for header_name, header_value in error_headers.items(): + resp.headers.set(header_name, header_value) + + if 'Access-Control-Allow-Origin' in error_headers: + origin = request.headers.get('Origin', '') + resp.headers.set('Access-Control-Allow-Origin', origin) + return resp def decorator(func): diff --git a/tests/test_api/test_auth.py b/tests/test_api/test_auth.py index 49cad737..349da074 100644 --- a/tests/test_api/test_auth.py +++ b/tests/test_api/test_auth.py @@ -631,21 +631,25 @@ class RequireRolesTest(AbstractPillarTest): def test_some_roles_required(self): from pillar.api.utils.authorization import require_login - called = [False] + called = False @require_login(require_roles={'admin'}) def call_me(): - called[0] = True + nonlocal called + called = True + return None with self.app.test_request_context(): self.login_api_as(ObjectId(24 * 'a'), ['succubus']) - self.assertRaises(Forbidden, call_me) - self.assertFalse(called[0]) + resp = call_me() + self.assertEqual(403, resp.status_code) + self.assertFalse(called, 'Forbidden function should not have been called') with self.app.test_request_context(): self.login_api_as(ObjectId(24 * 'a'), ['admin']) - call_me() - self.assertTrue(called[0]) + resp = call_me() + self.assertIsNone(resp) + self.assertTrue(called) def test_all_roles_required(self): from pillar.api.utils.authorization import require_login @@ -659,17 +663,20 @@ class RequireRolesTest(AbstractPillarTest): with self.app.test_request_context(): self.login_api_as(ObjectId(24 * 'a'), ['admin']) - self.assertRaises(Forbidden, call_me) + resp = call_me() + self.assertEqual(403, resp.status_code) self.assertFalse(called[0]) with self.app.test_request_context(): self.login_api_as(ObjectId(24 * 'a'), ['service']) - self.assertRaises(Forbidden, call_me) + resp = call_me() + self.assertEqual(403, resp.status_code) self.assertFalse(called[0]) with self.app.test_request_context(): self.login_api_as(ObjectId(24 * 'a'), ['badger']) - self.assertRaises(Forbidden, call_me) + resp = call_me() + self.assertEqual(403, resp.status_code) self.assertFalse(called[0]) with self.app.test_request_context(): @@ -702,7 +709,8 @@ class RequireRolesTest(AbstractPillarTest): with self.app.test_request_context(): self.login_api_as(ObjectId(24 * 'a'), ['succubus']) - self.assertRaises(Forbidden, call_me) + resp = call_me() + self.assertEqual(403, resp.status_code) self.assertFalse(called[0]) with self.app.test_request_context():