From 2e41c074b52cd8e6e4966a8c5ce0e2279949fbd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Fri, 3 Mar 2017 14:14:36 +0100 Subject: [PATCH] Python 3.6 compatibility: bytes vs strings stuff These changes mostly revolve around the change in ObjectId constructor when running on Python 3.6. Where on 2.7 the constructor would accept 12- and 24-byte strings, now only 12-byte bytes and 24-character strings are accepted. Good thing, but required some changes in our code. Other changes include hashing of strings, which isn't supported, so they are converted to bytes first, and sometimes converted back afterwards. --- pillar/api/file_storage/__init__.py | 4 +-- pillar/api/local_auth.py | 20 +++++++++----- pillar/api/service.py | 2 +- pillar/api/utils/__init__.py | 2 +- pillar/api/utils/authentication.py | 2 +- pillar/tests/__init__.py | 5 ++-- pillar/web/utils/__init__.py | 32 ++++++++++------------- tests/test_api/test_local_auth.py | 11 ++++++++ tests/test_api/test_nodes_moving.py | 4 +-- tests/test_api/test_project_management.py | 11 ++++---- tests/test_api/test_utils.py | 5 ++-- tests/test_web/test_utils.py | 10 +++---- 12 files changed, 62 insertions(+), 46 deletions(-) diff --git a/pillar/api/file_storage/__init__.py b/pillar/api/file_storage/__init__.py index c630143b..8ca6affd 100644 --- a/pillar/api/file_storage/__init__.py +++ b/pillar/api/file_storage/__init__.py @@ -296,7 +296,7 @@ def delete_file(file_item): process_file_delete(file_item) -def generate_link(backend, file_path, project_id=None, is_public=False): +def generate_link(backend, file_path: str, project_id=None, is_public=False): """Hook to check the backend of a file resource, to build an appropriate link that can be used by the client to retrieve the actual file. """ @@ -328,7 +328,7 @@ def generate_link(backend, file_path, project_id=None, is_public=False): if backend == 'cdnsun': return hash_file_path(file_path, None) if backend == 'unittest': - return 'https://unit.test/%s' % md5(file_path).hexdigest() + return 'https://unit.test/%s' % md5(file_path.encode()).hexdigest() log.warning('generate_link(): Unknown backend %r, returning empty string ' 'as new link.', diff --git a/pillar/api/local_auth.py b/pillar/api/local_auth.py index 0febed98..091c521b 100644 --- a/pillar/api/local_auth.py +++ b/pillar/api/local_auth.py @@ -1,6 +1,7 @@ import base64 import hashlib import logging +import typing import bcrypt import datetime @@ -67,12 +68,12 @@ def make_token(): return jsonify(token=token['token']) -def generate_and_store_token(user_id, days=15, prefix=''): +def generate_and_store_token(user_id, days=15, prefix=b''): """Generates token based on random bits. :param user_id: ObjectId of the owning user. :param days: token will expire in this many days. - :param prefix: the token will be prefixed by this string, for easy identification. + :param prefix: the token will be prefixed by these bytes, for easy identification. :return: the token document. """ @@ -80,17 +81,22 @@ def generate_and_store_token(user_id, days=15, prefix=''): # Use 'xy' as altargs to prevent + and / characters from appearing. # We never have to b64decode the string anyway. - token = prefix + base64.b64encode(random_bits, altchars='xy').strip('=') + token = prefix + base64.b64encode(random_bits, altchars=b'xy').strip(b'=') token_expiry = datetime.datetime.now(tz=tz_util.utc) + datetime.timedelta(days=days) - return store_token(user_id, token, token_expiry) + return store_token(user_id, token.decode('ascii'), token_expiry) -def hash_password(password, salt): +def hash_password(password: str, salt: typing.Union[str, bytes]) -> str: + password = password.encode() + if isinstance(salt, str): salt = salt.encode('utf-8') - encoded_password = base64.b64encode(hashlib.sha256(password).digest()) - return bcrypt.hashpw(encoded_password, salt) + + hash = hashlib.sha256(password).digest() + encoded_password = base64.b64encode(hash) + hashed_password = bcrypt.hashpw(encoded_password, salt) + return hashed_password.decode('ascii') def setup_app(app, url_prefix): diff --git a/pillar/api/service.py b/pillar/api/service.py index ecd03a80..3e165a2d 100644 --- a/pillar/api/service.py +++ b/pillar/api/service.py @@ -221,7 +221,7 @@ def create_service_account(email, roles, service, update_existing=None): user.update(result) # Create an authentication token that won't expire for a long time. - token = local_auth.generate_and_store_token(user['_id'], days=36500, prefix='SRV') + token = local_auth.generate_and_store_token(user['_id'], days=36500, prefix=b'SRV') return user, token diff --git a/pillar/api/utils/__init__.py b/pillar/api/utils/__init__.py index 9eb397a1..77fdc19c 100644 --- a/pillar/api/utils/__init__.py +++ b/pillar/api/utils/__init__.py @@ -136,7 +136,7 @@ def str2id(document_id): try: return bson.ObjectId(document_id) - except bson.objectid.InvalidId: + except (bson.objectid.InvalidId, TypeError): log.debug('str2id(%r): Invalid Object ID', document_id) raise wz_exceptions.BadRequest('Invalid object ID %r' % document_id) diff --git a/pillar/api/utils/authentication.py b/pillar/api/utils/authentication.py index b58fd568..b1636423 100644 --- a/pillar/api/utils/authentication.py +++ b/pillar/api/utils/authentication.py @@ -118,7 +118,7 @@ def find_token(token, is_subclient_token=False, **extra_filters): return db_token -def store_token(user_id, token, token_expiry, oauth_subclient_id=False): +def store_token(user_id, token: str, token_expiry, oauth_subclient_id=False): """Stores an authentication token. :returns: the token document from MongoDB diff --git a/pillar/tests/__init__.py b/pillar/tests/__init__.py index 9323c258..704f62bc 100644 --- a/pillar/tests/__init__.py +++ b/pillar/tests/__init__.py @@ -301,10 +301,11 @@ class AbstractPillarTest(TestMinimal): json=BLENDER_ID_USER_RESPONSE, status=200) - def make_header(self, username, subclient_id=''): + def make_header(self, username: str, subclient_id: str='') -> bytes: """Returns a Basic HTTP Authentication header value.""" - return 'basic ' + base64.b64encode('%s:%s' % (username, subclient_id)) + content = '%s:%s' % (username, subclient_id) + return b'basic ' + base64.b64encode(content.encode()) def create_standard_groups(self, additional_groups=()): """Creates standard admin/demo/subscriber groups, plus any additional. diff --git a/pillar/web/utils/__init__.py b/pillar/web/utils/__init__.py index 7848ce8b..67da2fb9 100644 --- a/pillar/web/utils/__init__.py +++ b/pillar/web/utils/__init__.py @@ -4,6 +4,7 @@ import urllib.request, urllib.parse, urllib.error import logging import traceback import sys +import typing import dateutil.parser from flask import current_app @@ -174,7 +175,7 @@ def get_main_project(): return main_project -def is_valid_id(some_id): +def is_valid_id(some_id: typing.Union[str, bytes]): """Returns True iff the given string is a valid ObjectId. Only use this if you do NOT need an ObjectId object. If you do need that, @@ -184,27 +185,22 @@ def is_valid_id(some_id): :rtype: bool """ + if isinstance(some_id, bytes): + return len(some_id) == 12 + if not isinstance(some_id, str): return False - if isinstance(some_id, str): - try: - some_id = some_id.encode('ascii') - except UnicodeEncodeError: - return False + if len(some_id) != 24: + return False - if len(some_id) == 12: - return True - elif len(some_id) == 24: - # This is more than 5x faster than checking character by - # character in a loop. - try: - int(some_id, 16) - except ValueError: - return False - return True - - return False + # This is more than 5x faster than checking character by + # character in a loop. + try: + int(some_id, 16) + except ValueError: + return False + return True def last_page_index(meta_info): diff --git a/tests/test_api/test_local_auth.py b/tests/test_api/test_local_auth.py index e2f3d274..5c3e8285 100644 --- a/tests/test_api/test_local_auth.py +++ b/tests/test_api/test_local_auth.py @@ -75,3 +75,14 @@ class LocalAuthTest(AbstractPillarTest): 'password': 'koro'}) self.assertEqual(403, resp.status_code, resp.data) + + def test_hash_password(self): + from pillar.api.local_auth import hash_password + + salt = b'$2b$12$cHdK4M8/yJ7SWp2Q.PYW0O' + self.assertEqual(hash_password('© 2017 je moeder™', salt), + '$2b$12$cHdK4M8/yJ7SWp2Q.PYW0OAU1gE3DIVdeehq0XIzOMM0Vp3ldPMb6') + self.assertIsInstance(hash_password('Резиновая уточка', salt), str) + + # The password should be encodable as ASCII. + hash_password('Резиновая уточка', salt).encode('ascii') diff --git a/tests/test_api/test_nodes_moving.py b/tests/test_api/test_nodes_moving.py index 9f735a01..246a2ce1 100644 --- a/tests/test_api/test_nodes_moving.py +++ b/tests/test_api/test_nodes_moving.py @@ -95,7 +95,7 @@ class NodeMoverTest(unittest.TestCase): ], } } - prid = ObjectId('project_dest') + prid = ObjectId(b'project_dest') new_project = { '_id': prid } @@ -124,7 +124,7 @@ class NodeMoverTest(unittest.TestCase): ], } } - prid = ObjectId('project_dest') + prid = ObjectId(b'project_dest') new_project = { '_id': prid } diff --git a/tests/test_api/test_project_management.py b/tests/test_api/test_project_management.py index 389783cb..09ba337d 100644 --- a/tests/test_api/test_project_management.py +++ b/tests/test_api/test_project_management.py @@ -146,7 +146,7 @@ class ProjectCreationTest(AbstractProjectTest): self.assertEqual({'Prøject B'}, {p['name'] for p in proj_list['_items']}) # No access to anything for user C, should result in empty list. - self._create_user_with_token(roles={'subscriber'}, token='token-c', user_id=12 * 'c') + self._create_user_with_token(roles={'subscriber'}, token='token-c', user_id=24 * 'c') resp = self.client.get('/api/projects', headers={'Authorization': self.make_header('token-c')}) self.assertEqual(200, resp.status_code) @@ -277,7 +277,7 @@ class ProjectEditTest(AbstractProjectTest): project_url = '/api/projects/%s' % project_id # Create test user. - self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeef') + self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeefcafef00dbeef') # Admin user should be able to PUT. put_project = remove_private_keys(project) @@ -324,7 +324,7 @@ class ProjectEditTest(AbstractProjectTest): # Create admin user that doesn't own the project, to check that # non-owner admins can delete projects too. - self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeef') + self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeefcafef00dbeef') # Admin user should be able to DELETE. resp = self.client.delete(project_url, @@ -361,7 +361,8 @@ class ProjectEditTest(AbstractProjectTest): project_url = '/api/projects/%s' % project_id # Create test user. - self._create_user_with_token(['subscriber'], 'mortal-token', user_id='cafef00dbeef') + self._create_user_with_token(['subscriber'], 'mortal-token', + user_id='cafef00dbeefcafef00dbeef') # Other user should NOT be able to DELETE. resp = self.client.delete(project_url, @@ -451,7 +452,7 @@ class ProjectNodeAccess(AbstractProjectTest): put_project = remove_private_keys(self.project) # Create admin user. - self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeef') + self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeefcafef00dbeef') # Make the project public put_project['permissions']['world'] = ['GET'] # make public diff --git a/tests/test_api/test_utils.py b/tests/test_api/test_utils.py index 8fcc9153..de1d350d 100644 --- a/tests/test_api/test_utils.py +++ b/tests/test_api/test_utils.py @@ -15,7 +15,7 @@ class Str2idTest(AbstractPillarTest): self.assertEqual(ObjectId(str_id), str2id(str_id)) happy(24 * 'a') - happy(12 * 'a') + happy(12 * b'a') happy('577e23ad98377323f74c368c') def test_unhappy(self): @@ -25,10 +25,11 @@ class Str2idTest(AbstractPillarTest): self.assertRaises(BadRequest, str2id, str_id) unhappy(13 * 'a') + unhappy(13 * b'a') unhappy('577e23ad 8377323f74c368c') unhappy('김치') # Kimchi unhappy('') - unhappy('') + unhappy(b'') unhappy(None) diff --git a/tests/test_web/test_utils.py b/tests/test_web/test_utils.py index ea1eff80..f14f3642 100644 --- a/tests/test_web/test_utils.py +++ b/tests/test_web/test_utils.py @@ -15,12 +15,12 @@ class IsValidIdTest(unittest.TestCase): self.assertTrue(utils.is_valid_id('deadbeefbeefcacedeadcace')) self.assertTrue(utils.is_valid_id('deadbeefbeefcacedeadcace')) - # 12-byte arbitrary ASCII strings - self.assertTrue(utils.is_valid_id('DeadBeefCake')) - self.assertTrue(utils.is_valid_id('DeadBeefCake')) + # 12-byte arbitrary ASCII bytes + self.assertTrue(utils.is_valid_id(b'DeadBeefCake')) + self.assertTrue(utils.is_valid_id(b'DeadBeefCake')) - # 12-byte str object - self.assertTrue(utils.is_valid_id('beef€67890')) + # 12-byte object + self.assertTrue(utils.is_valid_id('beef€67890'.encode())) def test_bad_length(self): self.assertFalse(utils.is_valid_id(23 * 'a'))