Scan files with clamdscan #77

Merged
Anna Sirota merged 17 commits from scan-file into main 2024-04-12 19:11:30 +02:00
9 changed files with 136 additions and 56 deletions
Showing only changes of commit 9f7c917e23 - Show all commits

View File

@ -5,7 +5,6 @@ from django.dispatch import receiver
import django.dispatch import django.dispatch
import extensions.models import extensions.models
import extensions.tasks
import files.models import files.models
version_changed = django.dispatch.Signal() version_changed = django.dispatch.Signal()

View File

View File

@ -1,18 +0,0 @@
# Generated by Django 4.2.11 on 2024-04-11 17:13
from django.db import migrations, models
class Migration(migrations.Migration):
    """Alter ``FileValidation.validation`` to a plain ``JSONField``."""

    # Applies on top of migration 0004 of the ``files`` app.
    dependencies = [('files', '0004_alter_file_status')]

    operations = [
        migrations.AlterField(
            model_name='filevalidation',
            name='validation',
            field=models.JSONField(),
        ),
    ]

View File

@ -0,0 +1,35 @@
# Generated by Django 4.2.11 on 2024-04-12 09:05
from django.db import migrations, models
class Migration(migrations.Migration):
    """Rename ``FileValidation.validation`` to ``results`` and drop the counter fields.

    Operations, in order:
      1. rename ``validation`` to ``results``;
      2. alter ``results`` to a plain ``JSONField``;
      3. remove ``errors``, ``notices`` and ``warnings``.
    """

    # Applies on top of migration 0004 of the ``files`` app.
    dependencies = [('files', '0004_alter_file_status')]

    operations = [
        migrations.RenameField(
            model_name='filevalidation',
            old_name='validation',
            new_name='results',
        ),
        migrations.AlterField(
            model_name='filevalidation',
            name='results',
            field=models.JSONField(),
        ),
        # One RemoveField per obsolete counter, in the same order as before.
        *(
            migrations.RemoveField(model_name='filevalidation', name=obsolete_field)
            for obsolete_field in ('errors', 'notices', 'warnings')
        ),
    ]

View File

@ -1,14 +1,12 @@
from pathlib import Path from pathlib import Path
from typing import Dict, Any from typing import Dict, Any
import logging import logging
import os.path
from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.db import models from django.db import models
from common.model_mixins import CreatedModifiedMixin, TrackChangesMixin, SoftDeleteMixin from common.model_mixins import CreatedModifiedMixin, TrackChangesMixin, SoftDeleteMixin
from files.utils import get_sha256, guess_mimetype_from_ext, scan from files.utils import get_sha256, guess_mimetype_from_ext
from constants.base import ( from constants.base import (
FILE_STATUS_CHOICES, FILE_STATUS_CHOICES,
FILE_TYPE_CHOICES, FILE_TYPE_CHOICES,
@ -206,32 +204,10 @@ class File(CreatedModifiedMixin, TrackChangesMixin, SoftDeleteMixin, models.Mode
def get_submit_url(self) -> str: def get_submit_url(self) -> str:
return self.extension.get_draft_url() return self.extension.get_draft_url()
def scan(self) -> 'FileValidation':
"""Run a scanner on the source file and save its output as a FileValidation record."""
abs_path = os.path.join(settings.MEDIA_ROOT, self.source.path)
completed_process = scan(abs_path)
validation = {
'args': completed_process.args,
'stdout': completed_process.stdout.decode(),
'stderr': completed_process.stderr.decode(),
'returncode': completed_process.returncode,
}
file_validation, is_new = FileValidation.objects.get_or_create(
file=self, defaults={'validation': validation}
)
file_validation.is_valid = completed_process.returncode == 0
# FIXME: do we need `errors`/`warnings`/`notices` counters at all?
file_validation.errors = 1 if not file_validation.is_valid else 0
file_validation.save()
return file_validation
class FileValidation(CreatedModifiedMixin, TrackChangesMixin, models.Model): class FileValidation(CreatedModifiedMixin, TrackChangesMixin, models.Model):
track_changes_to_fields = {'is_valid', 'errors', 'warnings', 'notices', 'validation'} track_changes_to_fields = {'is_valid', 'errors', 'warnings', 'notices', 'validation'}
file = models.OneToOneField(File, related_name='validation', on_delete=models.CASCADE) file = models.OneToOneField(File, related_name='validation', on_delete=models.CASCADE)
is_valid = models.BooleanField(default=False) is_valid = models.BooleanField(default=False)
errors = models.IntegerField(default=0) results = models.JSONField()
warnings = models.IntegerField(default=0)
notices = models.IntegerField(default=0)
validation = models.JSONField()

View File

@ -1,7 +1,12 @@
from django.db.models.signals import pre_save import logging
from django.db.models.signals import pre_save, post_save
from django.dispatch import receiver from django.dispatch import receiver
import files.models import files.models
import files.tasks
logger = logging.getLogger(__name__)
@receiver(pre_save, sender=files.models.File) @receiver(pre_save, sender=files.models.File)
@ -9,3 +14,14 @@ def _record_changes(sender: object, instance: files.models.File, **kwargs: objec
was_changed, old_state = instance.pre_save_record() was_changed, old_state = instance.pre_save_record()
instance.record_status_change(was_changed, old_state, **kwargs) instance.record_status_change(was_changed, old_state, **kwargs)
@receiver(post_save, sender=files.models.File)
def _scan_new_file(
    sender: object, instance: files.models.File, created: bool, **kwargs: object
) -> None:
    """Request a scan of a ``File`` right after it has been saved for the first time.

    Saves of already existing files (``created`` is false) are ignored.
    """
    if created:
        logger.info('Initiating a scan for file pk=%s', instance.pk)
        # The file is passed both by primary key (the task argument) and as `creator`.
        files.tasks.scan(file_id=instance.pk, creator=instance)

26
files/tasks.py Normal file
View File

@ -0,0 +1,26 @@
import os.path
from background_task import background
from django.conf import settings
import files.models
import files.utils
@background()
def scan(file_id: int):
    """Run a scan on a given file and save its output as a FileValidation record.

    The scanner output is stored under ``results[<scanner executable name>]``
    and ``is_valid`` is set to whether the scanner exited with return code 0.
    If the file already has a FileValidation record, its stored results are
    refreshed too, so that ``is_valid`` and ``results`` describe the same scan.

    :param file_id: primary key of the ``File`` to scan.
    :raises files.models.File.DoesNotExist: if there is no file with this pk.
    """
    file = files.models.File.objects.get(pk=file_id)
    abs_path = os.path.join(settings.MEDIA_ROOT, file.source.path)
    completed_process = files.utils.run_clamdscan(abs_path)
    scan_result = {
        'args': completed_process.args,
        # The scanner echoes file names, which are not guaranteed to be valid UTF-8:
        # replace undecodable bytes instead of failing the whole task.
        'stdout': completed_process.stdout.decode(errors='replace'),
        'stderr': completed_process.stderr.decode(errors='replace'),
        'returncode': completed_process.returncode,
    }
    new_results = {completed_process.args[0]: scan_result}
    file_validation, is_new = files.models.FileValidation.objects.get_or_create(
        file=file, defaults={'results': new_results}
    )
    if not is_new:
        # `defaults` are only used on creation: without this, a re-scan would update
        # `is_valid` while keeping the output of the previous scan.
        previous = file_validation.results
        merged_results = dict(previous) if isinstance(previous, dict) else {}
        merged_results.update(new_results)
        file_validation.results = merged_results
    file_validation.is_valid = completed_process.returncode == 0
    file_validation.save()

View File

@ -1,7 +1,11 @@
import json import json
import os
import shutil import shutil
import tempfile import tempfile
import unittest
from background_task.models import Task
from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
@ -9,6 +13,7 @@ from common.admin import get_admin_change_path
from common.log_entries import entries_for from common.log_entries import entries_for
from common.tests.factories.files import FileFactory from common.tests.factories.files import FileFactory
from files.models import File from files.models import File
import files.tasks
User = get_user_model() User = get_user_model()
@ -80,23 +85,64 @@ class FileTest(TestCase):
self.assertEqual(response.status_code, 200, path) self.assertEqual(response.status_code, 200, path)
@override_settings(MEDIA_ROOT='./files/tests/files') @unittest.skipUnless(shutil.which('clamdscan'), 'requires clamdscan')
@override_settings(MEDIA_ROOT='/tmp/')
class FileScanTest(TestCase): class FileScanTest(TestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.temp_directory = tempfile.mkdtemp() self.temp_directory = tempfile.mkdtemp(prefix=settings.MEDIA_ROOT)
def tearDown(self): def tearDown(self):
super().tearDown() super().tearDown()
shutil.rmtree(self.temp_directory) shutil.rmtree(self.temp_directory)
def test_scan(self): def test_scan_flags_found_invalid(self):
# TODO: write the test files on the fly test_file_path = os.path.join(self.temp_directory, 'test_file.zip')
file = FileFactory(source='Win.Test.EICAR_HDB-1.zip') test_content = (
b'X5O!P%@AP[4\PZX54(P^)7CC)7}$EICAR-STANDARD-ANTIVIRUS-TEST-FILE!$H+H*' # noqa: W605
)
with open(test_file_path, 'wb+') as test_file:
test_file.write(test_content)
file_validation = file.scan() file = FileFactory(source=test_file_path)
self.assertEqual(file_validation.validation['returncode'], 1) # A background task should have been created
stdout_lines = file_validation.validation['stdout'].split('\n') task = Task.objects.created_by(creator=file).first()
self.assertIsNotNone(task)
self.assertEqual(task.task_name, 'files.tasks.scan')
self.assertEqual(task.task_params, f'[[], {{"file_id": {file.pk}}}]')
# Actually run the task as if by background runner
task_args, task_kwargs = task.params()
files.tasks.scan.task_function(*task_args, **task_kwargs)
self.assertFalse(file.validation.is_valid)
result = file.validation.results['clamdscan']
self.assertEqual(result['returncode'], 1)
stdout_lines = result['stdout'].split('\n')
self.assertIn(f'{file.source.name}: Win.Test.EICAR_HDB-1 FOUND', stdout_lines[0]) self.assertIn(f'{file.source.name}: Win.Test.EICAR_HDB-1 FOUND', stdout_lines[0])
self.assertEqual(file_validation.validation['stderr'], '') self.assertEqual(result['stderr'], '')
def test_scan_flags_nothing_found_valid(self):
test_file_path = os.path.join(self.temp_directory, 'test_file.zip')
with open(test_file_path, 'wb+') as test_file:
test_file.write(b'some file')
file = FileFactory(source=test_file_path)
# A background task should have been created
task = Task.objects.created_by(creator=file).first()
self.assertIsNotNone(task)
self.assertEqual(task.task_name, 'files.tasks.scan')
self.assertEqual(task.task_params, f'[[], {{"file_id": {file.pk}}}]')
# Actually run the task as if by background runner
task_args, task_kwargs = task.params()
files.tasks.scan.task_function(*task_args, **task_kwargs)
self.assertTrue(file.validation.is_valid)
result = file.validation.results['clamdscan']
self.assertEqual(result['returncode'], 0)
stdout_lines = result['stdout'].split('\n')
self.assertIn(f'{file.source.name}: OK', stdout_lines[0])
self.assertEqual(result['stderr'], '')

View File

@ -165,7 +165,7 @@ def guess_mimetype_from_content(file_obj) -> str:
return mimetype_from_bytes return mimetype_from_bytes
def scan(abs_path: str) -> 'subprocess.CompletedProcess': def run_clamdscan(abs_path: str) -> 'subprocess.CompletedProcess':
scan_args = ['clamdscan', '--fdpass', abs_path] scan_args = ['clamdscan', '--fdpass', abs_path]
logger.info('Running %s', scan_args) logger.info('Running %s', scan_args)
return subprocess.run(scan_args, capture_output=True) return subprocess.run(scan_args, capture_output=True)