diff --git a/pillar/web/projects/routes.py b/pillar/web/projects/routes.py index 5ae3c753..708de3a2 100644 --- a/pillar/web/projects/routes.py +++ b/pillar/web/projects/routes.py @@ -322,6 +322,9 @@ def view_node(project_url, node_id): node_id=node_id[1:]), code=301) # permanent redirect + if not utils.is_valid_id(node_id): + raise wz_exceptions.NotFound('No such node') + api = system_util.pillar_api() theatre_mode = 't' in request.args diff --git a/pillar/web/utils/__init__.py b/pillar/web/utils/__init__.py index 798bae57..f8fe796d 100644 --- a/pillar/web/utils/__init__.py +++ b/pillar/web/utils/__init__.py @@ -133,3 +133,36 @@ def get_main_project(): except KeyError: raise ConfigError('MAIN_PROJECT_ID missing from config.py') return main_project + + +def is_valid_id(some_id): + """Returns True iff the given string is a valid ObjectId. + + Only use this if you do NOT need an ObjectId object. If you do need that, + use pillar.api.utils.str2id() instead. + + :type some_id: unicode + :rtype: bool + """ + + if not isinstance(some_id, basestring): + return False + + if isinstance(some_id, unicode): + try: + some_id = some_id.encode('ascii') + except UnicodeEncodeError: + return False + + if len(some_id) == 12: + return True + elif len(some_id) == 24: + # This is more than 5x faster than checking character by + # character in a loop. + try: + int(some_id, 16) + except ValueError: + return False + return True + + return False diff --git a/tests/test_api/__init__.py b/tests/test_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_web/__init__.py b/tests/test_web/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_web/test_utils.py b/tests/test_web/test_utils.py new file mode 100644 index 00000000..389c1a0d --- /dev/null +++ b/tests/test_web/test_utils.py @@ -0,0 +1,37 @@ +# -*- encoding: utf-8 -*- + +import unittest + +from pillar.web import utils + + +class IsValidIdTest(unittest.TestCase): + def test_valid(self): + # 24-byte hex strings + self.assertTrue(utils.is_valid_id(24 * 'a')) + self.assertTrue(utils.is_valid_id(24 * u'a')) + self.assertTrue(utils.is_valid_id('deadbeefbeefcacedeadcace')) + self.assertTrue(utils.is_valid_id(u'deadbeefbeefcacedeadcace')) + + # 12-byte arbitrary ASCII strings + self.assertTrue(utils.is_valid_id('DeadBeefCake')) + self.assertTrue(utils.is_valid_id(u'DeadBeefCake')) + + # 12-byte str object + self.assertTrue(utils.is_valid_id('beef€67890')) + + def test_bad_length(self): + self.assertFalse(utils.is_valid_id(23 * 'a')) + self.assertFalse(utils.is_valid_id(25 * u'a')) + + def test_non_string(self): + self.assertFalse(utils.is_valid_id(None)) + self.assertFalse(utils.is_valid_id(1234)) + self.assertFalse(utils.is_valid_id([24 * 'a'])) + + def test_bad_content(self): + # 24-character non-hexadecimal string + self.assertFalse(utils.is_valid_id('deadbeefbeefcakedeadcake')) + + # unicode variant of valid 12-byte str object + self.assertFalse(utils.is_valid_id(u'beef€67890'))