Python 3.6 compatibility: bytes vs strings stuff

These changes mostly revolve around the change in the ObjectId constructor
when running on Python 3.6. Where on 2.7 the constructor would accept
12- and 24-byte strings, it now only accepts 12-byte bytes objects and
24-character hexadecimal strings. This is a good thing, but it required
some changes in our code.

Other changes concern hashing: on Python 3.6 the hash functions do not
accept str objects, so strings are encoded to bytes before hashing, and
the result is sometimes decoded back to a string afterwards.
This commit is contained in:
Sybren A. Stüvel 2017-03-03 14:14:36 +01:00
parent c2206e6b27
commit 2e41c074b5
12 changed files with 62 additions and 46 deletions

View File

@ -296,7 +296,7 @@ def delete_file(file_item):
process_file_delete(file_item) process_file_delete(file_item)
def generate_link(backend, file_path, project_id=None, is_public=False): def generate_link(backend, file_path: str, project_id=None, is_public=False):
"""Hook to check the backend of a file resource, to build an appropriate link """Hook to check the backend of a file resource, to build an appropriate link
that can be used by the client to retrieve the actual file. that can be used by the client to retrieve the actual file.
""" """
@ -328,7 +328,7 @@ def generate_link(backend, file_path, project_id=None, is_public=False):
if backend == 'cdnsun': if backend == 'cdnsun':
return hash_file_path(file_path, None) return hash_file_path(file_path, None)
if backend == 'unittest': if backend == 'unittest':
return 'https://unit.test/%s' % md5(file_path).hexdigest() return 'https://unit.test/%s' % md5(file_path.encode()).hexdigest()
log.warning('generate_link(): Unknown backend %r, returning empty string ' log.warning('generate_link(): Unknown backend %r, returning empty string '
'as new link.', 'as new link.',

View File

@ -1,6 +1,7 @@
import base64 import base64
import hashlib import hashlib
import logging import logging
import typing
import bcrypt import bcrypt
import datetime import datetime
@ -67,12 +68,12 @@ def make_token():
return jsonify(token=token['token']) return jsonify(token=token['token'])
def generate_and_store_token(user_id, days=15, prefix=''): def generate_and_store_token(user_id, days=15, prefix=b''):
"""Generates token based on random bits. """Generates token based on random bits.
:param user_id: ObjectId of the owning user. :param user_id: ObjectId of the owning user.
:param days: token will expire in this many days. :param days: token will expire in this many days.
:param prefix: the token will be prefixed by this string, for easy identification. :param prefix: the token will be prefixed by these bytes, for easy identification.
:return: the token document. :return: the token document.
""" """
@ -80,17 +81,22 @@ def generate_and_store_token(user_id, days=15, prefix=''):
# Use 'xy' as altargs to prevent + and / characters from appearing. # Use 'xy' as altargs to prevent + and / characters from appearing.
# We never have to b64decode the string anyway. # We never have to b64decode the string anyway.
token = prefix + base64.b64encode(random_bits, altchars='xy').strip('=') token = prefix + base64.b64encode(random_bits, altchars=b'xy').strip(b'=')
token_expiry = datetime.datetime.now(tz=tz_util.utc) + datetime.timedelta(days=days) token_expiry = datetime.datetime.now(tz=tz_util.utc) + datetime.timedelta(days=days)
return store_token(user_id, token, token_expiry) return store_token(user_id, token.decode('ascii'), token_expiry)
def hash_password(password, salt): def hash_password(password: str, salt: typing.Union[str, bytes]) -> str:
password = password.encode()
if isinstance(salt, str): if isinstance(salt, str):
salt = salt.encode('utf-8') salt = salt.encode('utf-8')
encoded_password = base64.b64encode(hashlib.sha256(password).digest())
return bcrypt.hashpw(encoded_password, salt) hash = hashlib.sha256(password).digest()
encoded_password = base64.b64encode(hash)
hashed_password = bcrypt.hashpw(encoded_password, salt)
return hashed_password.decode('ascii')
def setup_app(app, url_prefix): def setup_app(app, url_prefix):

View File

@ -221,7 +221,7 @@ def create_service_account(email, roles, service, update_existing=None):
user.update(result) user.update(result)
# Create an authentication token that won't expire for a long time. # Create an authentication token that won't expire for a long time.
token = local_auth.generate_and_store_token(user['_id'], days=36500, prefix='SRV') token = local_auth.generate_and_store_token(user['_id'], days=36500, prefix=b'SRV')
return user, token return user, token

View File

@ -136,7 +136,7 @@ def str2id(document_id):
try: try:
return bson.ObjectId(document_id) return bson.ObjectId(document_id)
except bson.objectid.InvalidId: except (bson.objectid.InvalidId, TypeError):
log.debug('str2id(%r): Invalid Object ID', document_id) log.debug('str2id(%r): Invalid Object ID', document_id)
raise wz_exceptions.BadRequest('Invalid object ID %r' % document_id) raise wz_exceptions.BadRequest('Invalid object ID %r' % document_id)

View File

@ -118,7 +118,7 @@ def find_token(token, is_subclient_token=False, **extra_filters):
return db_token return db_token
def store_token(user_id, token, token_expiry, oauth_subclient_id=False): def store_token(user_id, token: str, token_expiry, oauth_subclient_id=False):
"""Stores an authentication token. """Stores an authentication token.
:returns: the token document from MongoDB :returns: the token document from MongoDB

View File

@ -301,10 +301,11 @@ class AbstractPillarTest(TestMinimal):
json=BLENDER_ID_USER_RESPONSE, json=BLENDER_ID_USER_RESPONSE,
status=200) status=200)
def make_header(self, username, subclient_id=''): def make_header(self, username: str, subclient_id: str='') -> bytes:
"""Returns a Basic HTTP Authentication header value.""" """Returns a Basic HTTP Authentication header value."""
return 'basic ' + base64.b64encode('%s:%s' % (username, subclient_id)) content = '%s:%s' % (username, subclient_id)
return b'basic ' + base64.b64encode(content.encode())
def create_standard_groups(self, additional_groups=()): def create_standard_groups(self, additional_groups=()):
"""Creates standard admin/demo/subscriber groups, plus any additional. """Creates standard admin/demo/subscriber groups, plus any additional.

View File

@ -4,6 +4,7 @@ import urllib.request, urllib.parse, urllib.error
import logging import logging
import traceback import traceback
import sys import sys
import typing
import dateutil.parser import dateutil.parser
from flask import current_app from flask import current_app
@ -174,7 +175,7 @@ def get_main_project():
return main_project return main_project
def is_valid_id(some_id): def is_valid_id(some_id: typing.Union[str, bytes]):
"""Returns True iff the given string is a valid ObjectId. """Returns True iff the given string is a valid ObjectId.
Only use this if you do NOT need an ObjectId object. If you do need that, Only use this if you do NOT need an ObjectId object. If you do need that,
@ -184,27 +185,22 @@ def is_valid_id(some_id):
:rtype: bool :rtype: bool
""" """
if isinstance(some_id, bytes):
return len(some_id) == 12
if not isinstance(some_id, str): if not isinstance(some_id, str):
return False return False
if isinstance(some_id, str): if len(some_id) != 24:
try: return False
some_id = some_id.encode('ascii')
except UnicodeEncodeError:
return False
if len(some_id) == 12: # This is more than 5x faster than checking character by
return True # character in a loop.
elif len(some_id) == 24: try:
# This is more than 5x faster than checking character by int(some_id, 16)
# character in a loop. except ValueError:
try: return False
int(some_id, 16) return True
except ValueError:
return False
return True
return False
def last_page_index(meta_info): def last_page_index(meta_info):

View File

@ -75,3 +75,14 @@ class LocalAuthTest(AbstractPillarTest):
'password': 'koro'}) 'password': 'koro'})
self.assertEqual(403, resp.status_code, resp.data) self.assertEqual(403, resp.status_code, resp.data)
def test_hash_password(self):
from pillar.api.local_auth import hash_password
salt = b'$2b$12$cHdK4M8/yJ7SWp2Q.PYW0O'
self.assertEqual(hash_password('© 2017 je moeder™', salt),
'$2b$12$cHdK4M8/yJ7SWp2Q.PYW0OAU1gE3DIVdeehq0XIzOMM0Vp3ldPMb6')
self.assertIsInstance(hash_password('Резиновая уточка', salt), str)
# The password should be encodable as ASCII.
hash_password('Резиновая уточка', salt).encode('ascii')

View File

@ -95,7 +95,7 @@ class NodeMoverTest(unittest.TestCase):
], ],
} }
} }
prid = ObjectId('project_dest') prid = ObjectId(b'project_dest')
new_project = { new_project = {
'_id': prid '_id': prid
} }
@ -124,7 +124,7 @@ class NodeMoverTest(unittest.TestCase):
], ],
} }
} }
prid = ObjectId('project_dest') prid = ObjectId(b'project_dest')
new_project = { new_project = {
'_id': prid '_id': prid
} }

View File

@ -146,7 +146,7 @@ class ProjectCreationTest(AbstractProjectTest):
self.assertEqual({'Prøject B'}, {p['name'] for p in proj_list['_items']}) self.assertEqual({'Prøject B'}, {p['name'] for p in proj_list['_items']})
# No access to anything for user C, should result in empty list. # No access to anything for user C, should result in empty list.
self._create_user_with_token(roles={'subscriber'}, token='token-c', user_id=12 * 'c') self._create_user_with_token(roles={'subscriber'}, token='token-c', user_id=24 * 'c')
resp = self.client.get('/api/projects', resp = self.client.get('/api/projects',
headers={'Authorization': self.make_header('token-c')}) headers={'Authorization': self.make_header('token-c')})
self.assertEqual(200, resp.status_code) self.assertEqual(200, resp.status_code)
@ -277,7 +277,7 @@ class ProjectEditTest(AbstractProjectTest):
project_url = '/api/projects/%s' % project_id project_url = '/api/projects/%s' % project_id
# Create test user. # Create test user.
self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeef') self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeefcafef00dbeef')
# Admin user should be able to PUT. # Admin user should be able to PUT.
put_project = remove_private_keys(project) put_project = remove_private_keys(project)
@ -324,7 +324,7 @@ class ProjectEditTest(AbstractProjectTest):
# Create admin user that doesn't own the project, to check that # Create admin user that doesn't own the project, to check that
# non-owner admins can delete projects too. # non-owner admins can delete projects too.
self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeef') self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeefcafef00dbeef')
# Admin user should be able to DELETE. # Admin user should be able to DELETE.
resp = self.client.delete(project_url, resp = self.client.delete(project_url,
@ -361,7 +361,8 @@ class ProjectEditTest(AbstractProjectTest):
project_url = '/api/projects/%s' % project_id project_url = '/api/projects/%s' % project_id
# Create test user. # Create test user.
self._create_user_with_token(['subscriber'], 'mortal-token', user_id='cafef00dbeef') self._create_user_with_token(['subscriber'], 'mortal-token',
user_id='cafef00dbeefcafef00dbeef')
# Other user should NOT be able to DELETE. # Other user should NOT be able to DELETE.
resp = self.client.delete(project_url, resp = self.client.delete(project_url,
@ -451,7 +452,7 @@ class ProjectNodeAccess(AbstractProjectTest):
put_project = remove_private_keys(self.project) put_project = remove_private_keys(self.project)
# Create admin user. # Create admin user.
self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeef') self._create_user_with_token(['admin'], 'admin-token', user_id='cafef00dbeefcafef00dbeef')
# Make the project public # Make the project public
put_project['permissions']['world'] = ['GET'] # make public put_project['permissions']['world'] = ['GET'] # make public

View File

@ -15,7 +15,7 @@ class Str2idTest(AbstractPillarTest):
self.assertEqual(ObjectId(str_id), str2id(str_id)) self.assertEqual(ObjectId(str_id), str2id(str_id))
happy(24 * 'a') happy(24 * 'a')
happy(12 * 'a') happy(12 * b'a')
happy('577e23ad98377323f74c368c') happy('577e23ad98377323f74c368c')
def test_unhappy(self): def test_unhappy(self):
@ -25,10 +25,11 @@ class Str2idTest(AbstractPillarTest):
self.assertRaises(BadRequest, str2id, str_id) self.assertRaises(BadRequest, str2id, str_id)
unhappy(13 * 'a') unhappy(13 * 'a')
unhappy(13 * b'a')
unhappy('577e23ad 8377323f74c368c') unhappy('577e23ad 8377323f74c368c')
unhappy('김치') # Kimchi unhappy('김치') # Kimchi
unhappy('') unhappy('')
unhappy('') unhappy(b'')
unhappy(None) unhappy(None)

View File

@ -15,12 +15,12 @@ class IsValidIdTest(unittest.TestCase):
self.assertTrue(utils.is_valid_id('deadbeefbeefcacedeadcace')) self.assertTrue(utils.is_valid_id('deadbeefbeefcacedeadcace'))
self.assertTrue(utils.is_valid_id('deadbeefbeefcacedeadcace')) self.assertTrue(utils.is_valid_id('deadbeefbeefcacedeadcace'))
# 12-byte arbitrary ASCII strings # 12-byte arbitrary ASCII bytes
self.assertTrue(utils.is_valid_id('DeadBeefCake')) self.assertTrue(utils.is_valid_id(b'DeadBeefCake'))
self.assertTrue(utils.is_valid_id('DeadBeefCake')) self.assertTrue(utils.is_valid_id(b'DeadBeefCake'))
# 12-byte str object # 12-byte object
self.assertTrue(utils.is_valid_id('beef€67890')) self.assertTrue(utils.is_valid_id('beef€67890'.encode()))
def test_bad_length(self): def test_bad_length(self):
self.assertFalse(utils.is_valid_id(23 * 'a')) self.assertFalse(utils.is_valid_id(23 * 'a'))