diff --git a/api/views/diagnostic.py b/api/views/diagnostic.py index 425d6998f..f4029f4cf 100644 --- a/api/views/diagnostic.py +++ b/api/views/diagnostic.py @@ -13,7 +13,6 @@ from rest_framework.exceptions import NotFound, PermissionDenied from rest_framework.generics import CreateAPIView, ListAPIView, UpdateAPIView from rest_framework.pagination import LimitOffsetPagination -from rest_framework.views import APIView from api.exceptions import DuplicateException from api.permissions import ( @@ -23,7 +22,7 @@ IsCanteenManager, ) from api.serializers import DiagnosticAndCanteenSerializer, ManagerDiagnosticSerializer -from api.views.utils import update_change_reason_with_auth +from api.views.utils import CSVImportApiView, update_change_reason_with_auth from common.utils import send_mail from data.models import Canteen, Teledeclaration from data.models.diagnostic import Diagnostic @@ -93,13 +92,14 @@ def perform_update(self, serializer): update_change_reason_with_auth(self, diagnostic) -class EmailDiagnosticImportFileView(APIView): +class EmailDiagnosticImportFileView(CSVImportApiView): permission_classes = [IsAuthenticated] def post(self, request): try: - file = request.data["file"] - self._verify_file_size(file) + self.file = request.data["file"] + super()._verify_file_size() + super()._verify_file_format() email = request.data.get("email", request.user.email).strip() context = { "from": email, @@ -111,7 +111,7 @@ def post(self, request): to=[settings.CONTACT_EMAIL], reply_to=[email], template="unusual_diagnostic_import_file", - attachments=[(file.name, file.read(), file.content_type)], + attachments=[(self.file.name, self.file.read(), self.file.content_type)], context=context, ) except ValidationError as e: @@ -127,11 +127,6 @@ def post(self, request): return HttpResponse() - @staticmethod - def _verify_file_size(file): - if file.size > settings.CSV_IMPORT_MAX_SIZE: - raise ValidationError("Ce fichier est trop grand, merci d'utiliser un fichier de moins de 10Mo") - class DiagnosticsToTeledeclarePagination(LimitOffsetPagination): default_limit = 100 diff --git a/api/views/diagnosticimport.py b/api/views/diagnosticimport.py index d62906f18..5891596f9 100644 --- a/api/views/diagnosticimport.py +++ b/api/views/diagnosticimport.py @@ -7,7 +7,6 @@ from decimal import Decimal, InvalidOperation import requests -from django.conf import settings from django.contrib.auth import get_user_model from django.core.exceptions import ValidationError from django.core.validators import validate_email @@ -16,11 +15,11 @@ from django.http import JsonResponse from rest_framework import status from rest_framework.exceptions import PermissionDenied -from rest_framework.views import APIView from simple_history.utils import update_change_reason from api.permissions import IsAuthenticated from api.serializers import FullCanteenSerializer +from api.views.utils import CSVImportApiView from data.models import Canteen, ImportFailure, ImportType, Sector from data.models.diagnostic import Diagnostic from data.models.teledeclaration import Teledeclaration @@ -31,7 +30,7 @@ logger = logging.getLogger(__name__) -class ImportDiagnosticsView(ABC, APIView): +class ImportDiagnosticsView(ABC, CSVImportApiView): permission_classes = [IsAuthenticated] value_error_regex = re.compile(r"Field '(.+)' expected .+? got '(.+)'.") annotated_sectors = Sector.objects.annotate(name_lower=Lower("name")) @@ -80,8 +79,8 @@ def post(self, request): try: with transaction.atomic(): self.file = request.data["file"] - ImportDiagnosticsView._verify_file_format(self.file) - ImportDiagnosticsView._verify_file_size(self.file) + super()._verify_file_size() + super()._verify_file_format() self._process_file(self.file) if self.errors: @@ -113,18 +112,6 @@ def _log_error(self, message, level="warning"): import_type=self.import_type, ) - @staticmethod - def _verify_file_format(file): - if file.content_type != "text/csv" and file.content_type != "text/tab-separated-values": - raise ValidationError( - f"Ce fichier est au format {file.content_type}, merci d'exporter votre fichier au format CSV et réessayer." - ) - - @staticmethod - def _verify_file_size(file): - if file.size > settings.CSV_IMPORT_MAX_SIZE: - raise ValidationError("Ce fichier est trop grand, merci d'utiliser un fichier de moins de 10Mo") - def check_admin_values(self, header): is_admin_import = any("admin_" in column for column in header) if is_admin_import and not self.request.user.is_staff: diff --git a/api/views/purchaseimport.py b/api/views/purchaseimport.py index 64756ec41..df34988df 100644 --- a/api/views/purchaseimport.py +++ b/api/views/purchaseimport.py @@ -1,5 +1,4 @@ import csv -import hashlib import io import json import logging @@ -13,19 +12,18 @@ from django.http import JsonResponse from rest_framework import status from rest_framework.exceptions import PermissionDenied -from rest_framework.views import APIView from api.permissions import IsAuthenticated from api.serializers import PurchaseSerializer +from api.views.utils import CSVImportApiView from data.models import Canteen, ImportFailure, ImportType, Purchase -from .diagnosticimport import ImportDiagnosticsView from .utils import camelize, decode_bytes, normalise_siret logger = logging.getLogger(__name__) -class ImportPurchasesView(APIView): +class ImportPurchasesView(CSVImportApiView): permission_classes = [IsAuthenticated] max_error_items = 30 @@ -50,8 +48,12 @@ def post(self, request): logger.info("Purchase bulk import started") try: self.file = request.data["file"] - self._verify_file_size() - ImportDiagnosticsView._verify_file_format(self.file) + super()._verify_file_size() + super()._verify_file_format() + + self.file_digest = super()._get_file_digest() + self._check_duplication() + with transaction.atomic(): self._process_file() @@ -60,10 +62,6 @@ def post(self, request): if self.errors: raise IntegrityError() - # The duplication check is called after the processing. The cost of eventually processing - # the file for nothing appears to be smaller than read the file twice. - self._check_duplication() - # Update all purchases's import source with file digest Purchase.objects.filter(import_source=self.tmp_id).update(import_source=self.file_digest) @@ -100,13 +98,10 @@ def _log_error(self, message, level="warning"): ) def _process_file(self): - file_hash = hashlib.md5() chunk = [] read_header = True row_count = 1 for row in self.file: - file_hash.update(row) - # Sniffing 1st line if read_header: # decode header, discarding encoding result that might not be accurate without more data @@ -133,12 +128,6 @@ def _process_file(self): if len(chunk) > 0: self._process_chunk(chunk) - self.file_digest = file_hash.hexdigest() - - def _verify_file_size(self): - if self.file.size > settings.CSV_IMPORT_MAX_SIZE: - raise ValidationError("Ce fichier est trop grand, merci d'utiliser un fichier de moins de 10Mo") - def _decode_chunk(self, chunk_list): if self.encoding_detected is None: chunk = b"".join(chunk_list) diff --git a/api/views/utils.py b/api/views/utils.py index e72dc355c..a184aa03a 100644 --- a/api/views/utils.py +++ b/api/views/utils.py @@ -1,13 +1,37 @@ +import hashlib import json import logging import chardet +from django.conf import settings +from django.core.exceptions import ValidationError from djangorestframework_camel_case.render import CamelCaseJSONRenderer +from rest_framework.views import APIView from simple_history.utils import update_change_reason logger = logging.getLogger(__name__) +class CSVImportApiView(APIView): + def _verify_file_size(self): + if self.file.size > settings.CSV_IMPORT_MAX_SIZE: + raise ValidationError( + f"Ce fichier est trop grand, merci d'utiliser un fichier de moins de {settings.CSV_IMPORT_MAX_SIZE_PRETTY}" + ) + + def _verify_file_format(self): + if self.file.content_type not in ["text/csv", "text/tab-separated-values"]: + raise ValidationError( + f"Ce fichier est au format {self.file.content_type}, merci d'exporter votre fichier au format CSV et réessayer." + ) + + def _get_file_digest(self): + file_hash = hashlib.md5() + for row in self.file: + file_hash.update(row) + return file_hash.hexdigest() + + def camelize(data): camel_case_bytes = CamelCaseJSONRenderer().render(data) return json.loads(camel_case_bytes.decode("utf-8")) diff --git a/macantine/settings.py b/macantine/settings.py index 339d3d610..f3fe75d6a 100644 --- a/macantine/settings.py +++ b/macantine/settings.py @@ -420,6 +420,7 @@ # Maximum CSV import file size: 10Mo CSV_IMPORT_MAX_SIZE = 10485760 +CSV_IMPORT_MAX_SIZE_PRETTY = "10Mo" # Size of each chunk when processing files CSV_PURCHASE_CHUNK_LINES = 10000