Allow HTTP headers to be set for @require_login() error responses

This makes the `require_login` decorator always return a Flask response.
Previously it could also raise a `Forbidden` exception; now it returns a
403 Forbidden response in that case too.
This commit is contained in:
2019-03-18 14:42:00 +01:00
parent cfff5ef189
commit 0ee1d0d3da
2 changed files with 32 additions and 13 deletions

View File

@@ -289,7 +289,8 @@ def require_login(*, require_roles=set(),
require_cap='', require_cap='',
require_all=False, require_all=False,
redirect_to_login=False, redirect_to_login=False,
error_view=None): error_view=None,
error_headers: typing.Optional[typing.Dict[str, str]]=None):
"""Decorator that enforces users to authenticate. """Decorator that enforces users to authenticate.
Optionally only allows access to users with a certain role and/or capability. Optionally only allows access to users with a certain role and/or capability.
@@ -313,6 +314,7 @@ def require_login(*, require_roles=set(),
requests, and mimicks the flask_login behaviour. requests, and mimicks the flask_login behaviour.
:param error_view: Callable that returns a Flask response object. This is :param error_view: Callable that returns a Flask response object. This is
sent back to the client instead of the default 403 Forbidden. sent back to the client instead of the default 403 Forbidden.
:param error_headers: HTTP headers to include in error responses.
""" """
from flask import request, redirect, url_for, Response from flask import request, redirect, url_for, Response
@@ -331,9 +333,18 @@ def require_login(*, require_roles=set(),
def render_error() -> Response: def render_error() -> Response:
if error_view is None: if error_view is None:
abort(403) resp = Forbidden().get_response()
resp: Response = error_view() else:
resp = error_view()
resp.status_code = 403 resp.status_code = 403
if error_headers:
for header_name, header_value in error_headers.items():
resp.headers.set(header_name, header_value)
if 'Access-Control-Allow-Origin' in error_headers:
origin = request.headers.get('Origin', '')
resp.headers.set('Access-Control-Allow-Origin', origin)
return resp return resp
def decorator(func): def decorator(func):

View File

@@ -631,21 +631,25 @@ class RequireRolesTest(AbstractPillarTest):
def test_some_roles_required(self): def test_some_roles_required(self):
from pillar.api.utils.authorization import require_login from pillar.api.utils.authorization import require_login
called = [False] called = False
@require_login(require_roles={'admin'}) @require_login(require_roles={'admin'})
def call_me(): def call_me():
called[0] = True nonlocal called
called = True
return None
with self.app.test_request_context(): with self.app.test_request_context():
self.login_api_as(ObjectId(24 * 'a'), ['succubus']) self.login_api_as(ObjectId(24 * 'a'), ['succubus'])
self.assertRaises(Forbidden, call_me) resp = call_me()
self.assertFalse(called[0]) self.assertEqual(403, resp.status_code)
self.assertFalse(called, 'Forbidden function should not have been called')
with self.app.test_request_context(): with self.app.test_request_context():
self.login_api_as(ObjectId(24 * 'a'), ['admin']) self.login_api_as(ObjectId(24 * 'a'), ['admin'])
call_me() resp = call_me()
self.assertTrue(called[0]) self.assertIsNone(resp)
self.assertTrue(called)
def test_all_roles_required(self): def test_all_roles_required(self):
from pillar.api.utils.authorization import require_login from pillar.api.utils.authorization import require_login
@@ -659,17 +663,20 @@ class RequireRolesTest(AbstractPillarTest):
with self.app.test_request_context(): with self.app.test_request_context():
self.login_api_as(ObjectId(24 * 'a'), ['admin']) self.login_api_as(ObjectId(24 * 'a'), ['admin'])
self.assertRaises(Forbidden, call_me) resp = call_me()
self.assertEqual(403, resp.status_code)
self.assertFalse(called[0]) self.assertFalse(called[0])
with self.app.test_request_context(): with self.app.test_request_context():
self.login_api_as(ObjectId(24 * 'a'), ['service']) self.login_api_as(ObjectId(24 * 'a'), ['service'])
self.assertRaises(Forbidden, call_me) resp = call_me()
self.assertEqual(403, resp.status_code)
self.assertFalse(called[0]) self.assertFalse(called[0])
with self.app.test_request_context(): with self.app.test_request_context():
self.login_api_as(ObjectId(24 * 'a'), ['badger']) self.login_api_as(ObjectId(24 * 'a'), ['badger'])
self.assertRaises(Forbidden, call_me) resp = call_me()
self.assertEqual(403, resp.status_code)
self.assertFalse(called[0]) self.assertFalse(called[0])
with self.app.test_request_context(): with self.app.test_request_context():
@@ -702,7 +709,8 @@ class RequireRolesTest(AbstractPillarTest):
with self.app.test_request_context(): with self.app.test_request_context():
self.login_api_as(ObjectId(24 * 'a'), ['succubus']) self.login_api_as(ObjectId(24 * 'a'), ['succubus'])
self.assertRaises(Forbidden, call_me) resp = call_me()
self.assertEqual(403, resp.status_code)
self.assertFalse(called[0]) self.assertFalse(called[0])
with self.app.test_request_context(): with self.app.test_request_context():