@@ -117,31 +117,30 @@ def get_model_instance(self, model_class: type[M], **unique_fields) -> M:
117
117
return cache [cache_key ]
118
118
119
119
120
- class BaseDeserializer [O : json_models .F1Object ]:
120
+ class Deserialiser [O : json_models .F1Object ]:
121
121
"""Base class for all deserializers."""
122
122
123
- MODEL : ClassVar [type [models .Model ]]
124
- JSON_IMPORT_TYPE : ClassVar [type [json_models .F1Import ]]
125
- UNIQUE_FIELDS : ClassVar [tuple [str , ...]]
126
-
127
123
_cache : ModelLookupCache
128
124
legacy_import : bool
129
125
130
- def __init__ (self , cache : ModelLookupCache | None = None , legacy_import : bool = False ):
131
- if not hasattr (self , "MODEL" ) or self .MODEL is None :
132
- raise NotImplementedError (f"{ self .__class__ .__name__ } must define MODEL" )
133
- if not hasattr (self , "JSON_IMPORT_TYPE" ) or self .JSON_IMPORT_TYPE is None :
134
- raise NotImplementedError (f"{ self .__class__ .__name__ } must define JSON_IMPORT_TYPE" )
135
- if not hasattr (self , "UNIQUE_FIELDS" ) or self .UNIQUE_FIELDS is None :
136
- raise NotImplementedError (f"{ self .__class__ .__name__ } must define UNIQUE_FIELDS" )
137
-
126
+ def __init__ (
127
+ self ,
128
+ model : type [models .Model ],
129
+ json_import_type : type [json_models .F1Import [O ]],
130
+ unique_fields : tuple [str , ...],
131
+ cache : ModelLookupCache | None = None ,
132
+ legacy_import : bool = False ,
133
+ ):
134
+ self .model = model
135
+ self .json_import_type = json_import_type
136
+ self .unique_fields = unique_fields
138
137
self ._cache = cache if cache is not None else ModelLookupCache ()
139
138
self .legacy_import = legacy_import
140
139
141
140
def _get_common_foreign_keys (self , foreign_keys : json_models .F1ForeignKeys ) -> ForeignKeyDict :
142
141
"""Get the foreign keys that are required to get or create the unique model instance."""
143
142
values = {}
144
- if self .MODEL == f1 .RoundEntry :
143
+ if self .model == f1 .RoundEntry :
145
144
values ["round" ] = self ._cache .get_model_instance (
146
145
f1 .Round , season__year = foreign_keys .year , number = foreign_keys .round
147
146
)
@@ -151,7 +150,7 @@ def _get_common_foreign_keys(self, foreign_keys: json_models.F1ForeignKeys) -> F
151
150
driver__reference = foreign_keys .driver_reference ,
152
151
team__reference = foreign_keys .team_reference ,
153
152
)
154
- elif self .MODEL == f1 .SessionEntry :
153
+ elif self .model == f1 .SessionEntry :
155
154
values ["session" ] = self ._cache .get_model_instance (
156
155
f1 .Session ,
157
156
round__season__year = foreign_keys .year ,
@@ -164,38 +163,38 @@ def _get_common_foreign_keys(self, foreign_keys: json_models.F1ForeignKeys) -> F
164
163
round__number = foreign_keys .round ,
165
164
car_number = foreign_keys .car_number ,
166
165
)
167
- elif self .MODEL in {f1 .Lap , f1 .PitStop }:
166
+ elif self .model in {f1 .Lap , f1 .PitStop }:
168
167
values ["session_entry" ] = self ._cache .get_model_instance (
169
168
f1 .SessionEntry ,
170
169
session__round__season__year = foreign_keys .year ,
171
170
session__round__number = foreign_keys .round ,
172
171
session__type = foreign_keys .session ,
173
172
round_entry__car_number = foreign_keys .car_number ,
174
173
)
175
- if self .MODEL == f1 .PitStop :
174
+ if self .model == f1 .PitStop :
176
175
values ["lap" ] = self ._cache .get_model_instance (
177
176
f1 .Lap , session_entry_id = values ["session_entry" ].id , number = foreign_keys .lap
178
177
)
179
178
return values
180
179
181
180
def create_model_instance (self , foreign_key_fields : ForeignKeyDict , field_values : O ) -> models .Model :
182
- return self .MODEL (** foreign_key_fields , ** field_values .model_dump (exclude_unset = True ))
181
+ return self .model (** foreign_key_fields , ** field_values .model_dump (exclude_unset = True ))
183
182
184
183
def get_unique_fields (self , data : json_models .F1Import [O ], object_data : O ) -> tuple [str , ...]:
185
184
if (
186
- self .MODEL == f1 .Lap
185
+ self .model == f1 .Lap
187
186
and isinstance (object_data , json_models .LapObject )
188
187
and self .legacy_import
189
188
and data .foreign_keys .session != "R"
190
189
and object_data .is_entry_fastest_lap
191
190
):
192
191
logger .warning (f"Legacy import for { data .object_type } overriding unique fields" )
193
192
return ("session_entry" , "is_entry_fastest_lap" )
194
- return self .UNIQUE_FIELDS
193
+ return self .unique_fields
195
194
196
195
def deserialise (self , data_dict : dict ) -> DeserialisationResult :
197
196
try :
198
- data = self .JSON_IMPORT_TYPE .model_validate (data_dict )
197
+ data = self .json_import_type .model_validate (data_dict )
199
198
except ValidationError as ex :
200
199
return DeserialisationResult (
201
200
success = False , data = data_dict , errors = ex .errors (include_url = False , include_input = False )
@@ -218,7 +217,7 @@ def deserialise(self, data_dict: dict) -> DeserialisationResult:
218
217
unique_fields = self .get_unique_fields (data , obj_data )
219
218
model_instances [
220
219
ModelImport (
221
- self .MODEL ,
220
+ self .model ,
222
221
tuple (obj_data .model_fields_set ),
223
222
unique_fields ,
224
223
)
@@ -230,49 +229,26 @@ def deserialise(self, data_dict: dict) -> DeserialisationResult:
230
229
return DeserialisationResult (success = True , data = data_dict , instances = model_instances )
231
230
232
231
233
- class RoundEntryDeserialiser (BaseDeserializer [json_models .RoundEntryObject ]):
234
- MODEL = f1 .RoundEntry
235
- JSON_IMPORT_TYPE = json_models .RoundEntryImport
236
- UNIQUE_FIELDS = ("round" , "team_driver" , "car_number" )
237
-
238
-
239
- class SessionEntryDeserialiser (BaseDeserializer [json_models .SessionEntryObject ]):
240
- MODEL = f1 .SessionEntry
241
- JSON_IMPORT_TYPE = json_models .SessionEntryImport
242
- UNIQUE_FIELDS = ("session" , "round_entry" )
243
-
244
-
245
- class LapDeserialiser (BaseDeserializer [json_models .LapObject ]):
246
- MODEL = f1 .Lap
247
- JSON_IMPORT_TYPE = json_models .LapImport
248
- UNIQUE_FIELDS = ("session_entry" , "number" )
249
-
250
-
251
- class PitStopDeserialiser (BaseDeserializer [json_models .PitStopObject ]):
252
- MODEL = f1 .PitStop
253
- JSON_IMPORT_TYPE = json_models .PitStopImport
254
- UNIQUE_FIELDS = ("session_entry" , "number" )
255
-
256
-
257
232
class DeserialiserFactory :
258
- deserialisers : ClassVar [dict [str , type [BaseDeserializer ]]] = {
259
- "SessionEntry" : SessionEntryDeserialiser ,
260
- "classification" : SessionEntryDeserialiser ,
261
- "session_entry" : SessionEntryDeserialiser ,
262
- "RoundEntry" : RoundEntryDeserialiser ,
263
- "Lap" : LapDeserialiser ,
264
- "lap" : LapDeserialiser ,
265
- "PitStop" : PitStopDeserialiser ,
266
- "pit_stop" : PitStopDeserialiser ,
233
+ deserialisers : ClassVar [dict [str , tuple [ type [models . Model ], type [ json_models . F1Import ], tuple [ str , ...] ]]] = {
234
+ "SessionEntry" : ( f1 . SessionEntry , json_models . SessionEntryImport , ( "session" , "round_entry" )) ,
235
+ "classification" : ( f1 . SessionEntry , json_models . SessionEntryImport , ( "session" , "round_entry" )) ,
236
+ "session_entry" : ( f1 . SessionEntry , json_models . SessionEntryImport , ( "session" , "round_entry" )) ,
237
+ "RoundEntry" : ( f1 . RoundEntry , json_models . RoundEntryImport , ( "round" , "team_driver" , "car_number" )) ,
238
+ "Lap" : ( f1 . Lap , json_models . LapImport , ( "session_entry" , "number" )) ,
239
+ "lap" : ( f1 . Lap , json_models . LapImport , ( "session_entry" , "number" )) ,
240
+ "PitStop" : ( f1 . PitStop , json_models . PitStopImport , ( "session_entry" , "number" )) ,
241
+ "pit_stop" : ( f1 . PitStop , json_models . PitStopImport , ( "session_entry" , "number" )) ,
267
242
}
268
243
269
244
def __init__ (self , cache : ModelLookupCache [models .Model ] | None = None , legacy_import : bool = False ):
270
245
self .cache = cache if cache is not None else ModelLookupCache ()
271
246
self .legacy_import = legacy_import
272
247
273
- def get_deserialiser (self , object_type : str ) -> BaseDeserializer :
274
- deserialiser_class = self .deserialisers .get (object_type )
275
- if deserialiser_class is None :
248
+ def get_deserialiser (self , object_type : str ) -> Deserialiser :
249
+ args = self .deserialisers .get (object_type , None )
250
+ if not args :
276
251
raise ValueError (f"Deserializer not found for object type: { object_type } " )
252
+ model , json_import_type , unique_fields = args
277
253
278
- return deserialiser_class ( cache = self .cache , legacy_import = self .legacy_import )
254
+ return Deserialiser ( model , json_import_type , unique_fields , cache = self .cache , legacy_import = self .legacy_import )
0 commit comments