Skip to content

Commit 1c501ef

Browse files
committed
Improve data import dry_run functionality
1 parent 75c752b commit 1c501ef

File tree

6 files changed

+139
-59
lines changed

6 files changed

+139
-59
lines changed

jolpica/formula_one/importer/importer.py

+45-16
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,6 @@ def deserialise(self, data: dict) -> DeserialisationResult:
1919
deserialiser = self.factory.get_deserialiser(data["object_type"])
2020
return deserialiser.deserialise(data)
2121

22-
def get_object_priority(self, object_dict: dict) -> int:
23-
object_type = object_dict.get("object_type")
24-
25-
if object_type in self.OBJECT_TYPE_PRIORITY:
26-
return self.OBJECT_TYPE_PRIORITY.index(object_type)
27-
else:
28-
return len(self.OBJECT_TYPE_PRIORITY)
29-
30-
def get_model_import_priority(self, import_key: ModelImport) -> int:
31-
object_type = import_key.model_class.__name__
32-
return self.get_object_priority({"object_type": object_type})
33-
3422
def deserialise_all(self, data: list[dict]) -> DeserialisationResult:
3523
prioritised_data = sorted(enumerate(data), key=lambda x: self.get_object_priority(x[1]))
3624
indexed_results = [(i, self.deserialise(item)) for i, item in prioritised_data]
@@ -51,11 +39,52 @@ def deserialise_all(self, data: list[dict]) -> DeserialisationResult:
5139
errors=errors,
5240
)
5341

54-
def save_deserialisation_result_to_db(self, result: DeserialisationResult):
55-
prioritised_items = sorted(result.instances.items(), key=lambda x: self.get_model_import_priority(x[0]))
42+
@classmethod
def get_object_priority(cls, object_dict: dict) -> int:
    """Return the sort rank of an import payload based on its ``object_type``.

    The rank is the position of the type name in ``cls.OBJECT_TYPE_PRIORITY``.
    A missing or unlisted type gets ``len(cls.OBJECT_TYPE_PRIORITY)``, so it
    sorts after every listed type.
    """
    priority_order = cls.OBJECT_TYPE_PRIORITY
    type_name = object_dict.get("object_type")

    # Guard clause: every unknown type shares the last rank.
    if type_name not in priority_order:
        return len(priority_order)
    return priority_order.index(type_name)
50+
51+
@classmethod
def get_model_import_priority(cls, import_key: ModelImport) -> int:
    """Return the sort rank for a ``ModelImport`` key.

    The name of the key's model class is used as the object type and ranked
    by ``get_object_priority``.
    """
    return cls.get_object_priority({"object_type": import_key.model_class.__name__})
55+
56+
@classmethod
def save_deserialisation_result_to_db(cls, result: DeserialisationResult) -> dict:
    """Upsert every deserialised instance and report what was created or updated.

    Instance groups are processed in ``get_model_import_priority`` order. Each
    instance is written with ``update_or_create`` (matched on the import's
    ``unique_fields``, with defaults taken from its ``update_fields``) and its
    ``pk`` is then set to the pk of the stored row.

    Returns a dict with overall ``created_count``, ``updated_count`` and
    ``total_count``, plus a ``models`` mapping from model class name to
    per-model counts and the lists of created / updated pks.
    """
    ordered_groups = sorted(result.instances.items(), key=lambda entry: cls.get_model_import_priority(entry[0]))

    import_stats = {"updated_count": 0, "created_count": 0, "models": {}}
    per_model = import_stats["models"]

    for model_import, instances in ordered_groups:
        model_cls = model_import.model_class
        # Several import keys may share a model class, so reuse existing stats.
        model_stats = per_model.setdefault(
            model_cls.__name__,
            {
                "updated_count": 0,
                "created_count": 0,
                "updated": [],
                "created": [],
            },
        )

        for ins in instances:
            lookup = {field: getattr(ins, field) for field in model_import.unique_fields}
            defaults = {field: getattr(ins, field) for field in model_import.update_fields}
            stored, is_created = model_cls.objects.update_or_create(  # type: ignore[attr-defined]
                **lookup,
                defaults=defaults,
            )
            ins.pk = stored.pk

            # Pick the counter / pk-list pair once, then bump both levels.
            count_key, pk_key = ("created_count", "created") if is_created else ("updated_count", "updated")
            import_stats[count_key] += 1
            model_stats[count_key] += 1
            model_stats[pk_key].append(stored.pk)

    import_stats["total_count"] = import_stats["created_count"] + import_stats["updated_count"]
    return import_stats

jolpica_api/data_import/admin.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
class DataImportLogAdmin(admin.ModelAdmin):
    """Admin options that list all model fields except the bulky result columns."""

    def __init__(self, model, admin_site):
        # Every model field becomes a list column except the two named here.
        hidden_columns = {"import_result", "errors"}
        self.list_display = [f.name for f in model._meta.fields if f.name not in hidden_columns]
        self.list_filter = ["dry_run", "error_type", "is_success"]
        super().__init__(model, admin_site)
1013

1114

jolpica_api/data_import/migrations/0002_dataimportlog_description.py → jolpica_api/data_import/migrations/0002_rename_updated_records_dataimportlog_import_result_and_more.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Generated by Django 5.1.6 on 2025-03-03 21:29
1+
# Generated by Django 5.1.6 on 2025-03-03 23:08
22

33
from django.db import migrations, models
44

@@ -10,10 +10,20 @@ class Migration(migrations.Migration):
1010
]
1111

1212
operations = [
13+
migrations.RenameField(
14+
model_name="dataimportlog",
15+
old_name="updated_records",
16+
new_name="import_result",
17+
),
1318
migrations.AddField(
1419
model_name="dataimportlog",
1520
name="description",
1621
field=models.CharField(default="", max_length=255),
1722
preserve_default=False,
1823
),
24+
migrations.AddField(
25+
model_name="dataimportlog",
26+
name="dry_run",
27+
field=models.BooleanField(default=True),
28+
),
1929
]

jolpica_api/data_import/models.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ class DataImportLog(models.Model):
88
id = models.BigAutoField(primary_key=True)
99
description = models.CharField(max_length=255)
1010
user = models.ForeignKey(User, on_delete=models.SET_NULL, null=True)
11+
dry_run = models.BooleanField(default=True)
1112
is_success = models.BooleanField(default=False)
1213
completed_at = models.DateTimeField(auto_now_add=True)
1314
total_records = models.PositiveIntegerField(null=True)
14-
updated_records = models.JSONField(null=True)
15+
import_result = models.JSONField(null=True)
1516
error_type = models.CharField(max_length=255, null=True)
1617
errors = models.JSONField(null=True)

jolpica_api/data_import/tests/test_views.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,12 @@ def test_successful_import(client):
216216
assert DataImportLog.objects.count() == 1
217217
log = DataImportLog.objects.first()
218218
assert log.completed_at is not None
219+
assert not log.dry_run
219220
assert log.error_type is None
220-
assert log.updated_records == {"Driver": [831]}
221+
assert log.import_result["created_count"] == 0
222+
assert log.import_result["updated_count"] == 1
223+
assert list(log.import_result["models"].keys()) == ["Driver"]
224+
assert log.import_result["models"]["Driver"]["updated"] == [831]
221225
assert log.is_success
222226
assert log.error_type is None
223227
assert log.errors is None
@@ -242,6 +246,7 @@ def test_validation_error_has_logs(client):
242246
assert DataImportLog.objects.count() == 1
243247
log = DataImportLog.objects.first()
244248
assert not log.is_success
249+
assert not log.dry_run
245250
assert log.error_type == "VALIDATION"
246251
assert log.errors[0]["type"] == "missing"
247252

@@ -265,6 +270,7 @@ def test_deserialisation_error_has_log(client):
265270
assert DataImportLog.objects.count() == 1
266271
log = DataImportLog.objects.first()
267272
assert not log.is_success
273+
assert not log.dry_run
268274
assert log.error_type == "DESERIALISATION"
269275
assert response.json()["errors"][0] == {
270276
"index": 0,
@@ -276,27 +282,31 @@ def test_deserialisation_error_has_log(client):
276282
@pytest.mark.django_db
def test_dry_run(client):
    """A dry-run import returns 200, logs the attempt and leaves the driver unchanged."""

    def current_forename():
        return f1.Driver.objects.get(reference="max_verstappen").forename

    assert current_forename() == "Max"

    driver_payload = {
        "object_type": "Driver",
        "foreign_keys": {},
        "objects": [{"reference": "max_verstappen", "forename": "Maxxx"}],
    }
    response = client.put("/data/import/", {"dry_run": True, "data": [driver_payload]}, format="json")

    assert response.status_code == status.HTTP_200_OK
    assert current_forename() == "Max"  # the submitted rename must not persist
    assert DataImportLog.objects.count() == 1
    assert DataImportLog.objects.first().dry_run
293302

294303

304+
@pytest.mark.parametrize(["dry_run"], [(True,), (False,)])
295305
@pytest.mark.django_db
296-
def test_db_error(client):
306+
def test_db_error(client, dry_run):
297307
"""Test database error during import."""
298308
data = {
299-
"dry_run": False,
309+
"dry_run": dry_run,
300310
"legacy_import": False,
301311
"data": [
302312
{
@@ -317,5 +327,6 @@ def test_db_error(client):
317327
assert response.status_code == status.HTTP_400_BAD_REQUEST
318328
assert DataImportLog.objects.count() == 1
319329
log = DataImportLog.objects.first()
330+
assert log.dry_run == dry_run
320331
assert not log.is_success
321332
assert log.error_type == "IMPORT"

jolpica_api/data_import/views.py

+61-35
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
from collections import defaultdict
2-
31
from django.contrib.auth.models import User
2+
from django.db import transaction
43
from pydantic import BaseModel, Field, ValidationError
54
from rest_framework.permissions import IsAdminUser, IsAuthenticated
65
from rest_framework.request import Request
@@ -14,6 +13,10 @@
1413
from .models import DataImportLog
1514

1615

16+
class DryRunError(Exception):
    """Marker exception for aborting a dry-run import so its database changes are rolled back."""
18+
19+
1720
class ImportDataRequestData(BaseModel):
1821
# If True, the data will be validated but not saved to the database
1922
dry_run: bool = True
@@ -38,51 +41,74 @@ def put(self, request: Request) -> Response:
3841
request_data = ImportDataRequestData.model_validate(request.data)
3942
except ValidationError as ex:
4043
errors = ex.errors(include_url=False)
41-
DataImportLog(user=request.user, is_success=False, error_type="VALIDATION", errors=errors).save()
44+
dry_run = request.data.get("dry_run", True)
45+
log_data_import_result(
46+
request.user,
47+
dry_run=dry_run if isinstance(dry_run, bool) else True,
48+
error_type="VALIDATION",
49+
errors=errors,
50+
)
4251
return Response({"errors": errors}, status=400)
4352

4453
model_importer = JSONModelImporter(legacy_import=request_data.legacy_import)
4554
result = model_importer.deserialise_all(request.data["data"])
4655

4756
if not result.success:
48-
DataImportLog(
57+
log_data_import_result(
58+
request.user,
59+
dry_run=request_data.dry_run,
4960
description=request_data.description,
50-
user=request.user,
51-
is_success=False,
5261
error_type="DESERIALISATION",
5362
errors=result.errors,
54-
).save()
63+
)
5564
return Response({"errors": result.errors}, status=400)
5665

57-
if not request_data.dry_run:
58-
try:
59-
model_importer.save_deserialisation_result_to_db(result)
60-
except Exception as ex:
61-
DataImportLog(
62-
description=request_data.description,
63-
user=request.user,
64-
is_success=False,
65-
error_type="IMPORT",
66-
errors=[repr(ex)],
67-
).save()
68-
return Response({"errors": [{"type": "import_error", "message": repr(ex)}]}, status=400)
69-
70-
save_successful_import_to_db(request_data.description, request.user, result)
71-
return Response({})
72-
73-
74-
def save_successful_import_to_db(description: str, user: User | None, result: DeserialisationResult) -> None:
75-
updated_record_count = 0
76-
updated_records = defaultdict(list)
77-
for model_import, instances in result.instances.items():
78-
instance_pks = [ins.pk for ins in instances]
79-
updated_record_count += len(instance_pks)
80-
updated_records[model_import.model_class.__name__].extend(instance_pks)
66+
try:
67+
import_stats = save_deserialisation_result_to_db(result, request_data.dry_run)
68+
except Exception as ex:
69+
errors = [{"type": "import_error", "message": repr(ex)}]
70+
log_data_import_result(
71+
request.user,
72+
dry_run=request_data.dry_run,
73+
description=request_data.description,
74+
error_type="IMPORT",
75+
errors=errors,
76+
)
77+
return Response({"errors": errors}, status=400)
78+
79+
log_data_import_result(
80+
request.user, dry_run=request_data.dry_run, description=request_data.description, import_stats=import_stats
81+
)
82+
return Response(import_stats)
8183

84+
85+
def save_deserialisation_result_to_db(result: DeserialisationResult, dry_run: bool) -> dict:
    """Write a deserialisation result to the database inside a single transaction.

    Args:
        result: Deserialised instances, saved via
            ``JSONModelImporter.save_deserialisation_result_to_db``.
        dry_run: When true, the writes are still executed so that statistics
            can be gathered, but the transaction is rolled back on exit
            instead of being committed.

    Returns:
        The import statistics dict produced by the importer.

    Raises:
        Exception: Anything raised while saving propagates to the caller; the
            atomic block rolls the transaction back first.
    """
    with transaction.atomic():
        import_stats = JSONModelImporter.save_deserialisation_result_to_db(result)

        if dry_run:
            # Mark the atomic block for rollback when it exits. This replaces
            # raising and swallowing a sentinel exception, so no try/except or
            # lint suppression is needed and import_stats is always bound.
            transaction.set_rollback(True)
    return import_stats
95+
96+
97+
def log_data_import_result(
    user: User,
    dry_run: bool,
    description: str = "",
    import_stats: dict | None = None,
    error_type: str | None = None,
    errors: list | None = None,
):
    """Persist a ``DataImportLog`` row describing one import attempt.

    The attempt is recorded as successful when ``error_type`` is falsy.
    ``total_records`` is read from ``import_stats`` via ``get("total_count")``
    when stats are supplied; otherwise it is stored as ``None``.
    """
    total_records = import_stats.get("total_count") if import_stats else None
    log_entry = DataImportLog(
        is_success=not error_type,
        dry_run=dry_run,
        user=user,
        description=description,
        total_records=total_records,
        import_result=import_stats,
        error_type=error_type,
        errors=errors,
    )
    log_entry.save()

0 commit comments

Comments
 (0)