From 5e721c61b96408d00999893145274a9fb75ea1be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Wed, 24 May 2017 10:56:41 +0200 Subject: [PATCH] Added function to easily remove someone from a group. --- pillar/api/users/__init__.py | 43 ++++++++++++++++++++++++------------ tests/test_api/test_users.py | 38 +++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/pillar/api/users/__init__.py b/pillar/api/users/__init__.py index 78c16b11..d08078a7 100644 --- a/pillar/api/users/__init__.py +++ b/pillar/api/users/__init__.py @@ -9,34 +9,49 @@ from .routes import blueprint_api log = logging.getLogger(__name__) +def remove_user_from_group(user_id: bson.ObjectId, group_id: bson.ObjectId): + """Removes the user from the given group. + + Directly uses MongoDB, so that it doesn't require any special permissions. + """ + + log.info('Removing user %s from group %s', user_id, group_id) + user_group_action(user_id, group_id, '$pull') + + def add_user_to_group(user_id: bson.ObjectId, group_id: bson.ObjectId): """Makes the user member of the given group. - + Directly uses MongoDB, so that it doesn't require any special permissions. """ + log.info('Adding user %s to group %s', user_id, group_id) + user_group_action(user_id, group_id, '$addToSet') + + +def user_group_action(user_id: bson.ObjectId, group_id: bson.ObjectId, action: str): + """Performs a group action (add/remove). + + :param user_id: the user's ObjectID. + :param group_id: the group's ObjectID. + :param action: either '$pull' to remove from a group, or '$addToSet' to add to a group. + """ + from pymongo.results import UpdateResult assert isinstance(user_id, bson.ObjectId) assert isinstance(group_id, bson.ObjectId) - - log.info('Adding user %s to group %s', user_id, group_id) + assert action in {'$pull', '$addToSet'} users_coll = current_app.db('users') - db_user = users_coll.find_one(user_id, projection={'groups': 1}) - if db_user is None: - raise ValueError('user %s not found', user_id, group_id) - - groups = set(db_user.get('groups', [])) - groups.add(group_id) - - # Sort the groups so that we have predictable, repeatable results. result: UpdateResult = users_coll.update_one( - {'_id': db_user['_id']}, - {'$set': {'groups': sorted(groups)}}) + {'_id': user_id}, + {action: {'groups': group_id}}, + ) if result.matched_count == 0: - raise ValueError('Unable to add user %s to group %s; user not found.') + raise ValueError('Unable to %s user %s membership of group %s; user not found.', + action, user_id, group_id) def setup_app(app, api_prefix): diff --git a/tests/test_api/test_users.py b/tests/test_api/test_users.py index 6010dfbb..7465d112 100644 --- a/tests/test_api/test_users.py +++ b/tests/test_api/test_users.py @@ -48,3 +48,41 @@ class UsersTest(AbstractPillarTest): self.assertEqual([ self.group_map['test1'], ], db_user['groups']) + + def test_remove_user_from_group_happy(self): + from pillar.api import users + + user_id = bson.ObjectId(24 * '1') + + self.create_user(user_id, roles={'subscriber'}, groups=[ + self.group_map['subscriber'], + self.group_map['test1'], + ]) + + # Remove from existing group + with self.app.test_request_context(): + users.remove_user_from_group(user_id, self.group_map['test1']) + + db_user = self.fetch_user_from_db(user_id) + self.assertEqual([self.group_map['subscriber']], db_user['groups']) + + # Remove same group again, should be no-op + with self.app.test_request_context(): + users.remove_user_from_group(user_id, self.group_map['test1']) + + db_user = self.fetch_user_from_db(user_id) + self.assertEqual([self.group_map['subscriber']], db_user['groups']) + + # Remove from last group, should result in empty list. + with self.app.test_request_context(): + users.remove_user_from_group(user_id, self.group_map['subscriber']) + + db_user = self.fetch_user_from_db(user_id) + self.assertEqual([], db_user['groups']) + + # Remove non-existing group from empty list, should also work. + with self.app.test_request_context(): + users.remove_user_from_group(user_id, bson.ObjectId()) + + db_user = self.fetch_user_from_db(user_id) + self.assertEqual([], db_user['groups'])