diff --git a/apitokens/authentication.py b/apitokens/authentication.py index 857b7bf8..4f43cd23 100644 --- a/apitokens/authentication.py +++ b/apitokens/authentication.py @@ -1,4 +1,4 @@ -import datetime +from django.utils import timezone from rest_framework.authentication import BaseAuthentication from rest_framework.exceptions import AuthenticationFailed @@ -27,7 +27,7 @@ class UserTokenAuthentication(BaseAuthentication): raise AuthenticationFailed('Invalid token') token.ip_address_last_access = clean_ip_address(request) - token.date_last_access = datetime.datetime.now() + token.date_last_access = timezone.now() token.save(update_fields={'ip_address_last_access', 'date_last_access'}) return (token.user, token) diff --git a/apitokens/tests/test_user_token.py b/apitokens/tests/test_user_token.py index b763f356..10503562 100644 --- a/apitokens/tests/test_user_token.py +++ b/apitokens/tests/test_user_token.py @@ -6,6 +6,7 @@ from django.urls import reverse from apitokens.models import UserToken from common.tests.factories.users import UserFactory +from common.tests.utils import create_user_token class UserTokenTest(TestCase): @@ -44,13 +45,7 @@ class UserTokenTest(TestCase): self.assertNotContains(response, token_key) def test_list_page_does_not_display_full_token_value(self): - token_key = UserToken.generate_token_key() - - token_prefix = UserToken.generate_token_prefix(token_key) - token_hash = UserToken.generate_hash(token_key) - token = UserToken.objects.create( - user=self.user, name='Test Token', token_prefix=token_prefix, token_hash=token_hash - ) + token, token_key = create_user_token(user=self.user, name='Test Token') response = self.client.get(reverse('apitokens:list')) self.assertContains(response, str(token.token_prefix)) diff --git a/common/tests/utils.py b/common/tests/utils.py index 194778b2..e95e4bad 100644 --- a/common/tests/utils.py +++ b/common/tests/utils.py @@ -1,9 +1,13 @@ import itertools +from typing import Tuple import django.urls as urls from django.utils.functional import cached_property from django.utils.regex_helper import normalize +from apitokens.models import UserToken + + try: # Django 2.0 url_resolver_types = (urls.URLResolver,) DJANGO_2 = True @@ -109,3 +113,11 @@ class CheckFilePropertiesMixin: self.assertEqual(file.original_name, kwargs.get('original_name')) if 'size_bytes' in kwargs: self.assertEqual(file.size_bytes, kwargs.get('size_bytes')) + + +def create_user_token(*args, **kwargs) -> Tuple['UserToken', str]: + token_key = UserToken.generate_token_key() + kwargs['token_hash'] = UserToken.generate_hash(token_key) + kwargs['token_prefix'] = UserToken.generate_token_prefix(token_key) + token = UserToken.objects.create(*args, **kwargs) + return token, token_key diff --git a/extensions/tests/test_api.py b/extensions/tests/test_api.py new file mode 100644 index 00000000..54d66b96 --- /dev/null +++ b/extensions/tests/test_api.py @@ -0,0 +1,118 @@ +from pathlib import Path + +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APITestCase, APIClient + +from common.tests.factories.users import UserFactory +from common.tests.factories.extensions import create_approved_version +from common.tests.utils import create_user_token + +from extensions.models import Version + + +TEST_FILES_DIR = Path(__file__).resolve().parent / 'files' + + +class VersionUploadAPITest(APITestCase): + def setUp(self): + self.user = UserFactory() + self.token, self.token_key = create_user_token(user=self.user) + + self.client = APIClient() + self.version = create_approved_version( + extension__extension_id="amaranth", + version="1.0.7", + file__user=self.user, + ) + self.extension = self.version.extension + self.file_path = TEST_FILES_DIR / "amaranth-1.0.8.zip" + + @staticmethod + def _get_upload_url(extension_id): + upload_url = reverse('extensions:upload-extension-version', args=(extension_id,)) + return upload_url + + def test_version_upload_unauthenticated(self): + with open(self.file_path, 'rb') as version_file: + response = self.client.post( + self._get_upload_url(self.extension.extension_id), + { + 'version_file': version_file, + 'release_notes': 'These are the release notes', + }, + format='multipart', + ) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_version_upload_extension_not_maintained_by_user(self): + other_user = UserFactory() + other_extension = create_approved_version( + extension__extension_id='other_extension', file__user=other_user + ).extension + + with open(self.file_path, 'rb') as version_file: + response = self.client.post( + self._get_upload_url(other_extension.extension_id), + { + 'version_file': version_file, + 'release_notes': 'These are the release notes', + }, + format='multipart', + HTTP_AUTHORIZATION=f'Bearer {self.token_key}', + ) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + self.assertEqual( + response.data['message'], + f'Extension "{other_extension.extension_id}" not maintained by user "{self.user.full_name}"', + ) + + def test_version_upload_extension_does_not_exist(self): + extension_name = 'extension_do_not_exist' + with open(self.file_path, 'rb') as version_file: + response = self.client.post( + self._get_upload_url(extension_name), + { + 'version_file': version_file, + 'release_notes': 'These are the release notes', + }, + format='multipart', + HTTP_AUTHORIZATION=f'Bearer {self.token_key}', + ) + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(response.data['message'], f'Extension "{extension_name}" not found') + + def test_version_upload_success(self): + self.assertEqual(Version.objects.filter(extension=self.extension).count(), 1) + with open(self.file_path, 'rb') as version_file: + response = self.client.post( + self._get_upload_url(self.extension.extension_id), + { + 'version_file': version_file, + 'release_notes': 'These are the release notes', + }, + format='multipart', + HTTP_AUTHORIZATION=f'Bearer {self.token_key}', + ) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(Version.objects.filter(extension=self.extension).count(), 2) + + def test_date_last_access(self): + self.assertIsNone(self.token.date_last_access) + with open(self.file_path, 'rb') as version_file: + response = self.client.post( + self._get_upload_url(self.extension.extension_id), + { + 'version_file': version_file, + 'release_notes': 'These are the release notes', + }, + format='multipart', + HTTP_AUTHORIZATION=f'Bearer {self.token_key}', + ) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.token.refresh_from_db() + self.assertIsNotNone(self.token.date_last_access) diff --git a/extensions/urls.py b/extensions/urls.py index f247b371..e3531431 100644 --- a/extensions/urls.py +++ b/extensions/urls.py @@ -16,6 +16,11 @@ urlpatterns = [ ), # API path('api/v1/extensions/', api.ExtensionsAPIView.as_view(), name='api'), + path( + 'api/v1/extensions//versions/new/', + api.UploadExtensionVersionView.as_view(), + name='upload-extension-version', + ), # Public pages path('', public.HomeView.as_view(), name='home'), path('search/', public.SearchView.as_view(), name='search'), diff --git a/extensions/views/api.py b/extensions/views/api.py index 8f60fc02..8db7fb06 100644 --- a/extensions/views/api.py +++ b/extensions/views/api.py @@ -2,14 +2,18 @@ import logging from rest_framework.permissions import AllowAny from rest_framework.response import Response -from rest_framework import serializers +from rest_framework import serializers, status from rest_framework.views import APIView +from rest_framework.permissions import IsAuthenticated from drf_spectacular.utils import OpenApiParameter, extend_schema from django.core.exceptions import ValidationError +from django.db import transaction from common.compare import is_in_version_range, version -from extensions.models import Extension, Platform +from extensions.models import Extension, Platform, Version from extensions.utils import clean_json_dictionary_from_optional_fields +from extensions.views.manage import NewVersionView +from files.forms import FileFormSkipAgreed from constants.base import ( @@ -151,3 +155,76 @@ class ExtensionsAPIView(APIView): 'version': 'v1', } ) + + +class ExtensionVersionSerializer(serializers.Serializer): + version_file = serializers.FileField() + release_notes = serializers.CharField(max_length=1024, required=False) + + +class UploadExtensionVersionView(APIView): + permission_classes = [IsAuthenticated] + + @extend_schema( + request=ExtensionVersionSerializer, + responses={201: 'Extension version uploaded successfully!'}, + ) + def post(self, request, extension_id, *args, **kwargs): + serializer = ExtensionVersionSerializer(data=request.data) + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + user = request.user + version_file = serializer.validated_data['version_file'] + release_notes = serializer.validated_data.get('release_notes', '') + + extension = Extension.objects.filter(extension_id=extension_id).first() + if not extension: + return Response( + { + 'message': f'Extension "{extension_id}" not found', + }, + status=status.HTTP_404_NOT_FOUND, + ) + + if not extension.has_maintainer(user): + return Response( + { + 'message': f'Extension "{extension_id}" not maintained by user "{user}"', + }, + status=status.HTTP_403_FORBIDDEN, + ) + + # Create a NewVersionView instance to handle file creation + new_version_view = NewVersionView(request=request, extension=extension) + + # Pass the version_file to the form + form = new_version_view.get_form(FileFormSkipAgreed) + form.fields['source'].initial = version_file + + if not form.is_valid(): + return Response({'message': form.errors}, status=status.HTTP_400_BAD_REQUEST) + + with transaction.atomic(): + # Create the file instance + file_instance = form.save(commit=False) + file_instance.user = user + file_instance.save() + + # Create the version from the file + version = Version.objects.update_or_create( + extension=extension, + file=file_instance, + release_notes=release_notes, + **file_instance.parsed_version_fields, + )[0] + + return Response( + { + 'message': 'Extension version uploaded successfully!', + 'extension_id': extension_id, + 'version_file': version_file.name, + 'release_notes': version.release_notes, + }, + status=status.HTTP_201_CREATED, + ) diff --git a/files/forms.py b/files/forms.py index a2798787..ae1be39d 100644 --- a/files/forms.py +++ b/files/forms.py @@ -167,6 +167,16 @@ class FileForm(forms.ModelForm): return self.cleaned_data +class FileFormSkipAgreed(FileForm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fields['agreed_with_terms'].required = False + + def clean(self): + self.cleaned_data['agreed_with_terms'] = True + super().clean() + + class BaseMediaFileForm(forms.ModelForm): class Meta: model = files.models.File