diff --git a/cloud/__init__.py b/cloud/__init__.py index de05e83..3f94a71 100644 --- a/cloud/__init__.py +++ b/cloud/__init__.py @@ -7,6 +7,7 @@ from pillar.api.utils import authorization from pillar.extension import PillarExtension EXTENSION_NAME = 'cloud' +ROLES_TO_BE_SUBSCRIBER = {'demo', 'subscriber', 'admin'} # TODO: get rid of this, use 'subscriber' cap class CloudExtension(PillarExtension): @@ -85,11 +86,49 @@ class CloudExtension(PillarExtension): } def setup_app(self, app): + """Links certain roles to the subscriber role. + + This means that users who get the subscriber role also get this linked + role, and when the subscriber role is revoked, the linked role is also + revoked. + """ + + from pillar.api.service import signal_user_changed_role from . import routes, webhooks + signal_user_changed_role.connect(self._user_changed_role) routes.setup_app(app) app.register_api_blueprint(webhooks.blueprint, '/webhooks') + def _user_changed_role(self, sender, user: dict): + from pillar.api import service + + linked_roles = {'flamenco-user', 'attract-user'} + link_to = {'subscriber', 'demo'} + user_roles = set(user.get('roles', [])) + + # Determine what to do + has_linked_roles = not (linked_roles - user_roles) + has_link_to = bool(link_to.intersection(user_roles)) + action = '' + if has_link_to and not has_linked_roles: + self._log.info('Granting roles %s to user %s', linked_roles, user['_id']) + action = 'grant' + elif not has_link_to and has_linked_roles: + self._log.info('Revoking roles %s from user %s', linked_roles, user['_id']) + action = 'revoke' + + if not action: + return + + # Avoid infinite loops while we're changing the user's roles. + service.signal_user_changed_role.disconnect(self._user_changed_role) + try: + if linked_roles: + service.do_badger(action, roles=linked_roles, user_id=user['_id']) + finally: + service.signal_user_changed_role.connect(self._user_changed_role) + def _get_current_cloud(): """Returns the Cloud extension of the current application.""" diff --git a/tests/test_linked_roles.py b/tests/test_linked_roles.py new file mode 100644 index 0000000..90eaa04 --- /dev/null +++ b/tests/test_linked_roles.py @@ -0,0 +1,31 @@ +from abstract_cloud_test import AbstractCloudTest + + +class LinkedRolesTest(AbstractCloudTest): + def test_linked_roles_subscriber(self): + user_id = self.create_user(roles=[]) + db_user = self.fetch_user_from_db(user_id) + + self.badger(db_user['email'], {'subscriber'}, 'grant') + db_user = self.fetch_user_from_db(user_id) + self.assertEqual({'subscriber', 'flamenco-user', 'attract-user'}, + set(db_user['roles'])) + + self.badger(db_user['email'], {'subscriber'}, 'revoke') + db_user = self.fetch_user_from_db(user_id) + self.assertEqual(set(), + set(db_user.get('roles', []))) + + def test_linked_roles_demo(self): + user_id = self.create_user(roles=[]) + db_user = self.fetch_user_from_db(user_id) + + self.badger(db_user['email'], {'demo'}, 'grant') + db_user = self.fetch_user_from_db(user_id) + self.assertEqual({'demo', 'flamenco-user', 'attract-user'}, + set(db_user['roles'])) + + self.badger(db_user['email'], {'demo'}, 'revoke') + db_user = self.fetch_user_from_db(user_id) + self.assertEqual(set(), + set(db_user.get('roles', []))) diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py index 05238ea..ed6bda0 100644 --- a/tests/test_webhooks.py +++ b/tests/test_webhooks.py @@ -98,7 +98,7 @@ class UserModifiedTest(AbstractCloudTest): db_user = self.fetch_user_from_db(self.uid) self.assertEqual('old@email.address', db_user['email']) self.assertEqual('ကြယ်ဆွတ်', db_user['full_name']) - self.assertEqual({'demo'}, set(db_user['roles'])) + self.assertEqual({'flamenco-user', 'attract-user', 'demo'}, set(db_user['roles'])) def test_bad_hmac(self): payload = {'id': 1112333,