diff --git a/pillar/flask_extra.py b/pillar/flask_extra.py new file mode 100644 index 00000000..c84488f2 --- /dev/null +++ b/pillar/flask_extra.py @@ -0,0 +1,29 @@ +import functools +import flask + + +def add_response_headers(headers: dict): + """This decorator adds the headers passed in to the response""" + + def decorator(f): + @functools.wraps(f) + def decorated_function(*args, **kwargs): + resp = flask.make_response(f(*args, **kwargs)) + h = resp.headers + for header, value in headers.items(): + h[header] = value + return resp + + return decorated_function + + return decorator + + +def vary_xhr(): + """View function decorator; adds HTTP header "Vary: X-Requested-With" to the response""" + + def decorator(f): + header_adder = add_response_headers({'Vary': 'X-Requested-With'}) + return header_adder(f) + + return decorator diff --git a/tests/test_flask_extra.py b/tests/test_flask_extra.py new file mode 100644 index 00000000..f373afd2 --- /dev/null +++ b/tests/test_flask_extra.py @@ -0,0 +1,35 @@ +import unittest + +import flask + + +class FlaskExtraTest(unittest.TestCase): + def test_vary_xhr(self): + import pillar.flask_extra + + class TestApp(flask.Flask): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.add_url_rule('/must-vary', 'must-vary', self.must_vary) + self.add_url_rule('/no-vary', 'no-vary', self.no_vary) + + @pillar.flask_extra.vary_xhr() + def must_vary(self): + return 'yay' + + def no_vary(self): + return 'nah', 201 + + app = TestApp(__name__) + client = app.test_client() + + resp = client.get('/must-vary') + self.assertEqual(200, resp.status_code) + self.assertEqual('X-Requested-With', resp.headers['Vary']) + self.assertEqual('yay', resp.data.decode()) + + resp = client.get('/no-vary') + self.assertEqual(201, resp.status_code) + self.assertNotIn('Vary', resp.headers) + self.assertEqual('nah', resp.data.decode())