Skip to content

Commit c07e28c

Browse files
committed
Reuse deserialiser logic
1 parent afc6557 commit c07e28c

File tree

4 files changed

+129
-112
lines changed

4 files changed

+129
-112
lines changed

jolpica/formula_one/importer/deserialisers.py

+35-59
Original file line numberDiff line numberDiff line change
@@ -117,31 +117,30 @@ def get_model_instance(self, model_class: type[M], **unique_fields) -> M:
117117
return cache[cache_key]
118118

119119

120-
class BaseDeserializer[O: json_models.F1Object]:
120+
class Deserialiser[O: json_models.F1Object]:
121121
"""Base class for all deserializers."""
122122

123-
MODEL: ClassVar[type[models.Model]]
124-
JSON_IMPORT_TYPE: ClassVar[type[json_models.F1Import]]
125-
UNIQUE_FIELDS: ClassVar[tuple[str, ...]]
126-
127123
_cache: ModelLookupCache
128124
legacy_import: bool
129125

130-
def __init__(self, cache: ModelLookupCache | None = None, legacy_import: bool = False):
131-
if not hasattr(self, "MODEL") or self.MODEL is None:
132-
raise NotImplementedError(f"{self.__class__.__name__} must define MODEL")
133-
if not hasattr(self, "JSON_IMPORT_TYPE") or self.JSON_IMPORT_TYPE is None:
134-
raise NotImplementedError(f"{self.__class__.__name__} must define JSON_IMPORT_TYPE")
135-
if not hasattr(self, "UNIQUE_FIELDS") or self.UNIQUE_FIELDS is None:
136-
raise NotImplementedError(f"{self.__class__.__name__} must define UNIQUE_FIELDS")
137-
126+
def __init__(
127+
self,
128+
model: type[models.Model],
129+
json_import_type: type[json_models.F1Import[O]],
130+
unique_fields: tuple[str, ...],
131+
cache: ModelLookupCache | None = None,
132+
legacy_import: bool = False,
133+
):
134+
self.model = model
135+
self.json_import_type = json_import_type
136+
self.unique_fields = unique_fields
138137
self._cache = cache if cache is not None else ModelLookupCache()
139138
self.legacy_import = legacy_import
140139

141140
def _get_common_foreign_keys(self, foreign_keys: json_models.F1ForeignKeys) -> ForeignKeyDict:
142141
"""Get the foreign keys that are required to get or create the unique model instance."""
143142
values = {}
144-
if self.MODEL == f1.RoundEntry:
143+
if self.model == f1.RoundEntry:
145144
values["round"] = self._cache.get_model_instance(
146145
f1.Round, season__year=foreign_keys.year, number=foreign_keys.round
147146
)
@@ -151,7 +150,7 @@ def _get_common_foreign_keys(self, foreign_keys: json_models.F1ForeignKeys) -> F
151150
driver__reference=foreign_keys.driver_reference,
152151
team__reference=foreign_keys.team_reference,
153152
)
154-
elif self.MODEL == f1.SessionEntry:
153+
elif self.model == f1.SessionEntry:
155154
values["session"] = self._cache.get_model_instance(
156155
f1.Session,
157156
round__season__year=foreign_keys.year,
@@ -164,38 +163,38 @@ def _get_common_foreign_keys(self, foreign_keys: json_models.F1ForeignKeys) -> F
164163
round__number=foreign_keys.round,
165164
car_number=foreign_keys.car_number,
166165
)
167-
elif self.MODEL in {f1.Lap, f1.PitStop}:
166+
elif self.model in {f1.Lap, f1.PitStop}:
168167
values["session_entry"] = self._cache.get_model_instance(
169168
f1.SessionEntry,
170169
session__round__season__year=foreign_keys.year,
171170
session__round__number=foreign_keys.round,
172171
session__type=foreign_keys.session,
173172
round_entry__car_number=foreign_keys.car_number,
174173
)
175-
if self.MODEL == f1.PitStop:
174+
if self.model == f1.PitStop:
176175
values["lap"] = self._cache.get_model_instance(
177176
f1.Lap, session_entry_id=values["session_entry"].id, number=foreign_keys.lap
178177
)
179178
return values
180179

181180
def create_model_instance(self, foreign_key_fields: ForeignKeyDict, field_values: O) -> models.Model:
182-
return self.MODEL(**foreign_key_fields, **field_values.model_dump(exclude_unset=True))
181+
return self.model(**foreign_key_fields, **field_values.model_dump(exclude_unset=True))
183182

184183
def get_unique_fields(self, data: json_models.F1Import[O], object_data: O) -> tuple[str, ...]:
185184
if (
186-
self.MODEL == f1.Lap
185+
self.model == f1.Lap
187186
and isinstance(object_data, json_models.LapObject)
188187
and self.legacy_import
189188
and data.foreign_keys.session != "R"
190189
and object_data.is_entry_fastest_lap
191190
):
192191
logger.warning(f"Legacy import for {data.object_type} overriding unique fields")
193192
return ("session_entry", "is_entry_fastest_lap")
194-
return self.UNIQUE_FIELDS
193+
return self.unique_fields
195194

196195
def deserialise(self, data_dict: dict) -> DeserialisationResult:
197196
try:
198-
data = self.JSON_IMPORT_TYPE.model_validate(data_dict)
197+
data = self.json_import_type.model_validate(data_dict)
199198
except ValidationError as ex:
200199
return DeserialisationResult(
201200
success=False, data=data_dict, errors=ex.errors(include_url=False, include_input=False)
@@ -218,7 +217,7 @@ def deserialise(self, data_dict: dict) -> DeserialisationResult:
218217
unique_fields = self.get_unique_fields(data, obj_data)
219218
model_instances[
220219
ModelImport(
221-
self.MODEL,
220+
self.model,
222221
tuple(obj_data.model_fields_set),
223222
unique_fields,
224223
)
@@ -230,49 +229,26 @@ def deserialise(self, data_dict: dict) -> DeserialisationResult:
230229
return DeserialisationResult(success=True, data=data_dict, instances=model_instances)
231230

232231

233-
class RoundEntryDeserialiser(BaseDeserializer[json_models.RoundEntryObject]):
234-
MODEL = f1.RoundEntry
235-
JSON_IMPORT_TYPE = json_models.RoundEntryImport
236-
UNIQUE_FIELDS = ("round", "team_driver", "car_number")
237-
238-
239-
class SessionEntryDeserialiser(BaseDeserializer[json_models.SessionEntryObject]):
240-
MODEL = f1.SessionEntry
241-
JSON_IMPORT_TYPE = json_models.SessionEntryImport
242-
UNIQUE_FIELDS = ("session", "round_entry")
243-
244-
245-
class LapDeserialiser(BaseDeserializer[json_models.LapObject]):
246-
MODEL = f1.Lap
247-
JSON_IMPORT_TYPE = json_models.LapImport
248-
UNIQUE_FIELDS = ("session_entry", "number")
249-
250-
251-
class PitStopDeserialiser(BaseDeserializer[json_models.PitStopObject]):
252-
MODEL = f1.PitStop
253-
JSON_IMPORT_TYPE = json_models.PitStopImport
254-
UNIQUE_FIELDS = ("session_entry", "number")
255-
256-
257232
class DeserialiserFactory:
258-
deserialisers: ClassVar[dict[str, type[BaseDeserializer]]] = {
259-
"SessionEntry": SessionEntryDeserialiser,
260-
"classification": SessionEntryDeserialiser,
261-
"session_entry": SessionEntryDeserialiser,
262-
"RoundEntry": RoundEntryDeserialiser,
263-
"Lap": LapDeserialiser,
264-
"lap": LapDeserialiser,
265-
"PitStop": PitStopDeserialiser,
266-
"pit_stop": PitStopDeserialiser,
233+
deserialisers: ClassVar[dict[str, tuple[type[models.Model], type[json_models.F1Import], tuple[str, ...]]]] = {
234+
"SessionEntry": (f1.SessionEntry, json_models.SessionEntryImport, ("session", "round_entry")),
235+
"classification": (f1.SessionEntry, json_models.SessionEntryImport, ("session", "round_entry")),
236+
"session_entry": (f1.SessionEntry, json_models.SessionEntryImport, ("session", "round_entry")),
237+
"RoundEntry": (f1.RoundEntry, json_models.RoundEntryImport, ("round", "team_driver", "car_number")),
238+
"Lap": (f1.Lap, json_models.LapImport, ("session_entry", "number")),
239+
"lap": (f1.Lap, json_models.LapImport, ("session_entry", "number")),
240+
"PitStop": (f1.PitStop, json_models.PitStopImport, ("session_entry", "number")),
241+
"pit_stop": (f1.PitStop, json_models.PitStopImport, ("session_entry", "number")),
267242
}
268243

269244
def __init__(self, cache: ModelLookupCache[models.Model] | None = None, legacy_import: bool = False):
270245
self.cache = cache if cache is not None else ModelLookupCache()
271246
self.legacy_import = legacy_import
272247

273-
def get_deserialiser(self, object_type: str) -> BaseDeserializer:
274-
deserialiser_class = self.deserialisers.get(object_type)
275-
if deserialiser_class is None:
248+
def get_deserialiser(self, object_type: str) -> Deserialiser:
249+
args = self.deserialisers.get(object_type, None)
250+
if not args:
276251
raise ValueError(f"Deserializer not found for object type: {object_type}")
252+
model, json_import_type, unique_fields = args
277253

278-
return deserialiser_class(cache=self.cache, legacy_import=self.legacy_import)
254+
return Deserialiser(model, json_import_type, unique_fields, cache=self.cache, legacy_import=self.legacy_import)

jolpica/formula_one/importer/json_models.py

+32-16
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,33 @@
22
from datetime import timedelta
33
from typing import Annotated, Any, Literal
44

5-
from pydantic import BaseModel, BeforeValidator, ConfigDict
5+
from pydantic import (
6+
BaseModel,
7+
BeforeValidator,
8+
ConfigDict,
9+
NonNegativeFloat,
10+
NonNegativeInt,
11+
PositiveFloat,
12+
PositiveInt,
13+
)
14+
15+
16+
class TimedeltaDict(BaseModel):
17+
model_config = ConfigDict(extra="forbid")
18+
milliseconds: NonNegativeInt = 0
19+
seconds: NonNegativeInt = 0
20+
minutes: NonNegativeInt = 0
21+
hours: NonNegativeInt = 0
22+
days: NonNegativeInt = 0
23+
24+
def to_timedelta(self) -> timedelta:
25+
return timedelta(**self.model_dump())
626

727

8-
# TODO: Create more validators (e.g. points > 26 likely is an error)
928
def mutate_timedelta_from_dict(value: Any) -> Any:
1029
if isinstance(value, dict) and value.get("_type") == "timedelta":
1130
del value["_type"]
12-
for val in value.keys():
13-
if val not in {"milliseconds", "seconds", "minutes", "hours", "days"}:
14-
raise ValueError(f"{val} is not a valid field for timedelta")
15-
return timedelta(**value)
31+
return TimedeltaDict(**value).to_timedelta()
1632
return value
1733

1834

@@ -47,7 +63,7 @@ class RoundEntryForeignKeys(F1ForeignKeys):
4763

4864

4965
class RoundEntryObject(F1Object):
50-
car_number: int | None = None
66+
car_number: PositiveInt | None = None
5167

5268

5369
class RoundEntryImport(F1Import):
@@ -64,16 +80,16 @@ class SessionEntryForeignKeys(F1ForeignKeys):
6480

6581

6682
class SessionEntryObject(F1Object):
67-
position: int | None = None
83+
position: PositiveInt | None = None
6884
is_classified: bool | None = None
6985
status: int | None = None
7086
detail: str | None = None
71-
points: int | None = None
87+
points: NonNegativeFloat | None = None
7288
is_eligible_for_points: bool | None = None
73-
grid: int | None = None
89+
grid: PositiveInt | None = None
7490
time: Annotated[timedelta | None, BeforeValidator(mutate_timedelta_from_dict)] = None
75-
fastest_lap_rank: int | None = None
76-
laps_completed: int | None = None
91+
fastest_lap_rank: PositiveInt | None = None
92+
laps_completed: NonNegativeInt | None = None
7793

7894

7995
class SessionEntryImport(F1Import):
@@ -90,10 +106,10 @@ class LapForeignKeys(F1ForeignKeys):
90106

91107

92108
class LapObject(F1Object):
93-
number: int | None = None
94-
position: int | None = None
109+
number: PositiveInt | None = None
110+
position: PositiveInt | None = None
95111
time: Annotated[timedelta | None, BeforeValidator(mutate_timedelta_from_dict)] = None
96-
average_speed: float | None = None
112+
average_speed: PositiveFloat | None = None
97113
is_entry_fastest_lap: bool | None = None
98114
is_deleted: bool | None = None
99115

@@ -113,7 +129,7 @@ class PitStopForeignKeys(F1ForeignKeys):
113129

114130

115131
class PitStopObject(F1Object):
116-
number: int | None = None
132+
number: PositiveInt | None = None
117133
duration: Annotated[timedelta | None, BeforeValidator(mutate_timedelta_from_dict)] = None
118134
local_timestamp: str | None = None
119135

0 commit comments

Comments
 (0)