From 251f5ac86a2d55cfd35aaf1d62a8dfb816749ca0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Thu, 7 Jul 2016 14:56:30 +0200 Subject: [PATCH] Added app.utils.str2id() to convert IDs on URLs to ObjectId. Raises a BadRequest exception when the ID is malformed. --- pillar/application/utils/__init__.py | 21 ++++++++++++++++++- tests/test_utils.py | 31 ++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 tests/test_utils.py diff --git a/pillar/application/utils/__init__.py b/pillar/application/utils/__init__.py index 42ef8423..4a444375 100644 --- a/pillar/application/utils/__init__.py +++ b/pillar/application/utils/__init__.py @@ -4,9 +4,11 @@ import datetime import functools import logging -import bson +import bson.objectid from eve import RFC1123_DATE_FORMAT from flask import current_app +from werkzeug import exceptions as wz_exceptions + __all__ = ('remove_private_keys', 'PillarJSONEncoder') log = logging.getLogger(__name__) @@ -80,3 +82,20 @@ def project_get_node_type(project_document, node_type_node_name): return next((node_type for node_type in project_document['node_types'] if node_type['name'] == node_type_node_name), None) + + +def str2id(document_id): + """Returns the document ID as ObjectID, or raises a BadRequest exception. + + :type document_id: str + :rtype: bson.ObjectId + :raises: wz_exceptions.BadRequest + """ + + if not document_id: + raise wz_exceptions.BadRequest('Invalid object ID %r', document_id) + + try: + return bson.ObjectId(document_id) + except bson.objectid.InvalidId: + raise wz_exceptions.BadRequest('Invalid object ID %r', document_id) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..e616e148 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,31 @@ +# -*- encoding: utf-8 -*- + +from bson import ObjectId +from werkzeug.exceptions import BadRequest + +from common_test_class import AbstractPillarTest + + +class Str2idTest(AbstractPillarTest): + def test_happy(self): + from application.utils import str2id + + def happy(str_id): + self.assertEqual(ObjectId(str_id), str2id(str_id)) + + happy(24 * 'a') + happy(12 * 'a') + happy(u'577e23ad98377323f74c368c') + + def test_unhappy(self): + from application.utils import str2id + + def unhappy(str_id): + self.assertRaises(BadRequest, str2id, str_id) + + unhappy(13 * 'a') + unhappy(u'577e23ad 8377323f74c368c') + unhappy(u'김치') # Kimchi + unhappy('') + unhappy(u'') + unhappy(None)