diff --git a/tests/test_api/test_file_storage_backends.py b/tests/test_api/test_file_storage_backends.py index 9e661ab7..89e3d5db 100644 --- a/tests/test_api/test_file_storage_backends.py +++ b/tests/test_api/test_file_storage_backends.py @@ -4,18 +4,20 @@ from pillar.tests import AbstractPillarTest class LocalStorageBackendTest(AbstractPillarTest): - - def test_upload_download(self): + def create_test_file(self) -> (typing.IO, bytes): import io import secrets - from pillar.api.file_storage_backends import Bucket - file_contents = secrets.token_bytes(512) test_file: typing.IO = io.BytesIO(file_contents) + return test_file, file_contents + + def test_upload_download(self): + test_file, file_contents = self.create_test_file() + with self.app.test_request_context(): - bucket_class = Bucket.for_backend('local') + bucket_class = self.storage_backend() bucket = bucket_class('buckettest') blob = bucket.blob('somefile.bin') @@ -30,3 +32,35 @@ class LocalStorageBackendTest(AbstractPillarTest): self.assertEqual(200, resp.status_code) self.assertEqual('512', resp.headers['Content-Length']) self.assertEqual(file_contents, resp.data) + + def storage_backend(self): + from pillar.api.file_storage_backends import Bucket + + return Bucket.for_backend('local') + + def test_copy_to_bucket(self): + from bson import ObjectId + + test_file, file_contents = self.create_test_file() + + src_project_id = ObjectId(24 * 'a') + dest_project_id = ObjectId(24 * 'd') + + with self.app.test_request_context(): + bucket_class = self.storage_backend() + bucket1 = bucket_class(str(src_project_id)) + src_blob = bucket1.blob('somefile.bin') + src_blob.create_from_file(test_file, content_type='application/octet-stream') + + bucket_class.copy_to_bucket('somefile.bin', src_project_id, dest_project_id) + + # Test that the file now exists at the new bucket. + bucket2 = bucket_class(str(dest_project_id)) + dest_blob = bucket2.blob('somefile.bin') + url = dest_blob.get_url(is_public=True) + + resp = self.get(url) + + self.assertEqual(200, resp.status_code) + self.assertEqual('512', resp.headers['Content-Length']) + self.assertEqual(file_contents, resp.data)