Allow HTTP headers to be set for @require_login() error responses
This makes the `require_login` decorator always return a Flask response. Previously it could also raise a `Forbidden` exception; now it returns a 403 Forbidden response in that case too.
This commit is contained in:
@@ -289,7 +289,8 @@ def require_login(*, require_roles=set(),
|
|||||||
require_cap='',
|
require_cap='',
|
||||||
require_all=False,
|
require_all=False,
|
||||||
redirect_to_login=False,
|
redirect_to_login=False,
|
||||||
error_view=None):
|
error_view=None,
|
||||||
|
error_headers: typing.Optional[typing.Dict[str, str]]=None):
|
||||||
"""Decorator that enforces users to authenticate.
|
"""Decorator that enforces users to authenticate.
|
||||||
|
|
||||||
Optionally only allows access to users with a certain role and/or capability.
|
Optionally only allows access to users with a certain role and/or capability.
|
||||||
@@ -313,6 +314,7 @@ def require_login(*, require_roles=set(),
|
|||||||
requests, and mimicks the flask_login behaviour.
|
requests, and mimicks the flask_login behaviour.
|
||||||
:param error_view: Callable that returns a Flask response object. This is
|
:param error_view: Callable that returns a Flask response object. This is
|
||||||
sent back to the client instead of the default 403 Forbidden.
|
sent back to the client instead of the default 403 Forbidden.
|
||||||
|
:param error_headers: HTTP headers to include in error responses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from flask import request, redirect, url_for, Response
|
from flask import request, redirect, url_for, Response
|
||||||
@@ -331,9 +333,18 @@ def require_login(*, require_roles=set(),
|
|||||||
|
|
||||||
def render_error() -> Response:
|
def render_error() -> Response:
|
||||||
if error_view is None:
|
if error_view is None:
|
||||||
abort(403)
|
resp = Forbidden().get_response()
|
||||||
resp: Response = error_view()
|
else:
|
||||||
|
resp = error_view()
|
||||||
resp.status_code = 403
|
resp.status_code = 403
|
||||||
|
if error_headers:
|
||||||
|
for header_name, header_value in error_headers.items():
|
||||||
|
resp.headers.set(header_name, header_value)
|
||||||
|
|
||||||
|
if 'Access-Control-Allow-Origin' in error_headers:
|
||||||
|
origin = request.headers.get('Origin', '')
|
||||||
|
resp.headers.set('Access-Control-Allow-Origin', origin)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
|
@@ -631,21 +631,25 @@ class RequireRolesTest(AbstractPillarTest):
|
|||||||
def test_some_roles_required(self):
|
def test_some_roles_required(self):
|
||||||
from pillar.api.utils.authorization import require_login
|
from pillar.api.utils.authorization import require_login
|
||||||
|
|
||||||
called = [False]
|
called = False
|
||||||
|
|
||||||
@require_login(require_roles={'admin'})
|
@require_login(require_roles={'admin'})
|
||||||
def call_me():
|
def call_me():
|
||||||
called[0] = True
|
nonlocal called
|
||||||
|
called = True
|
||||||
|
return None
|
||||||
|
|
||||||
with self.app.test_request_context():
|
with self.app.test_request_context():
|
||||||
self.login_api_as(ObjectId(24 * 'a'), ['succubus'])
|
self.login_api_as(ObjectId(24 * 'a'), ['succubus'])
|
||||||
self.assertRaises(Forbidden, call_me)
|
resp = call_me()
|
||||||
self.assertFalse(called[0])
|
self.assertEqual(403, resp.status_code)
|
||||||
|
self.assertFalse(called, 'Forbidden function should not have been called')
|
||||||
|
|
||||||
with self.app.test_request_context():
|
with self.app.test_request_context():
|
||||||
self.login_api_as(ObjectId(24 * 'a'), ['admin'])
|
self.login_api_as(ObjectId(24 * 'a'), ['admin'])
|
||||||
call_me()
|
resp = call_me()
|
||||||
self.assertTrue(called[0])
|
self.assertIsNone(resp)
|
||||||
|
self.assertTrue(called)
|
||||||
|
|
||||||
def test_all_roles_required(self):
|
def test_all_roles_required(self):
|
||||||
from pillar.api.utils.authorization import require_login
|
from pillar.api.utils.authorization import require_login
|
||||||
@@ -659,17 +663,20 @@ class RequireRolesTest(AbstractPillarTest):
|
|||||||
|
|
||||||
with self.app.test_request_context():
|
with self.app.test_request_context():
|
||||||
self.login_api_as(ObjectId(24 * 'a'), ['admin'])
|
self.login_api_as(ObjectId(24 * 'a'), ['admin'])
|
||||||
self.assertRaises(Forbidden, call_me)
|
resp = call_me()
|
||||||
|
self.assertEqual(403, resp.status_code)
|
||||||
self.assertFalse(called[0])
|
self.assertFalse(called[0])
|
||||||
|
|
||||||
with self.app.test_request_context():
|
with self.app.test_request_context():
|
||||||
self.login_api_as(ObjectId(24 * 'a'), ['service'])
|
self.login_api_as(ObjectId(24 * 'a'), ['service'])
|
||||||
self.assertRaises(Forbidden, call_me)
|
resp = call_me()
|
||||||
|
self.assertEqual(403, resp.status_code)
|
||||||
self.assertFalse(called[0])
|
self.assertFalse(called[0])
|
||||||
|
|
||||||
with self.app.test_request_context():
|
with self.app.test_request_context():
|
||||||
self.login_api_as(ObjectId(24 * 'a'), ['badger'])
|
self.login_api_as(ObjectId(24 * 'a'), ['badger'])
|
||||||
self.assertRaises(Forbidden, call_me)
|
resp = call_me()
|
||||||
|
self.assertEqual(403, resp.status_code)
|
||||||
self.assertFalse(called[0])
|
self.assertFalse(called[0])
|
||||||
|
|
||||||
with self.app.test_request_context():
|
with self.app.test_request_context():
|
||||||
@@ -702,7 +709,8 @@ class RequireRolesTest(AbstractPillarTest):
|
|||||||
|
|
||||||
with self.app.test_request_context():
|
with self.app.test_request_context():
|
||||||
self.login_api_as(ObjectId(24 * 'a'), ['succubus'])
|
self.login_api_as(ObjectId(24 * 'a'), ['succubus'])
|
||||||
self.assertRaises(Forbidden, call_me)
|
resp = call_me()
|
||||||
|
self.assertEqual(403, resp.status_code)
|
||||||
self.assertFalse(called[0])
|
self.assertFalse(called[0])
|
||||||
|
|
||||||
with self.app.test_request_context():
|
with self.app.test_request_context():
|
||||||
|
Reference in New Issue
Block a user