@@ -256,7 +256,7 @@ def target_mapper(self):
256256
257257 @property
258258 def __isabstractmethod__ (self ):
259- # compatibility between this description __getattr__ usage and abc.ABC
259+ # compatibility between this descriptor __getattr__ usage and abc.ABC
260260 return False
261261
262262 def __get__ (self , obj , owner = None ):
@@ -267,18 +267,29 @@ def __get__(self, obj, owner=None):
267267 return obj .__dict__ [self .attribute ]
268268
269269 def __set__ (self , obj , value ):
270- obj .__dict__ [self .attribute ] = value if self .single else self .list_class (obj , self , value )
270+ if self .single :
271+ obj .__dict__ [self .attribute ] = value
272+ self .update_related_objs (obj , value )
273+ flag_dirty_attr (obj , self .source_attr )
274+ else :
275+ obj .__dict__ [self .attribute ] = self .list_class (obj , self , [])
276+ for item in value :
277+ obj .__dict__ [self .attribute ].append (item ) # ensure target_attr is set
271278
272279 def is_loaded (self , obj ):
273280 return self .attribute in obj .__dict__
274281
275282 def fetch (self , obj ):
276283 """Fetches the list of related objects from the database and loads it in the object"""
277284 r = self .target .query (self .select_from_target (obj ))
278- self .__set__ (obj , r .first () if self .single else r .all ())
285+ obj .__dict__ [self .attribute ] = r .first () if self .single else self .list_class (obj , self , r .all ())
286+
287+ def load (self , obj , values ):
288+ value = values [self .attribute ]
289+ obj .__dict__ [self .attribute ] = value if self .single else self .list_class (obj , self , value )
279290
280291
281- class RelatedObjectsList ( object ) :
292+ class RelatedObjectsList :
282293 def __init__ (self , obj , relationship , items ):
283294 self .obj = obj
284295 self .relationship = relationship
@@ -297,20 +308,12 @@ def __contains__(self, item):
297308 return item in self .items
298309
299310 def append (self , item ):
300- target_attr = self .relationship .target_attr
301- if not target_attr :
302- raise MapperError (
303- f"Missing target_attr on relationship '{ self .relationship .attribute } '"
304- )
305- setattr (item , target_attr , getattr (self .obj , self .relationship .source_attr ))
311+ self .relationship .update_related_objs (self .obj , item )
312+ flag_dirty_attr (item , self .relationship .target_attr )
306313
307314 def remove (self , item ):
308- target_attr = self .relationship .target_attr
309- if not target_attr :
310- raise MapperError (
311- f"Missing target_attr on relationship '{ self .relationship .attribute } '"
312- )
313- setattr (item , target_attr , None )
315+ self .relationship .update_related_objs (None , item )
316+ flag_dirty_attr (item , self .relationship .target_attr )
314317
315318
316319def flag_dirty_attr (obj , attr ):
@@ -367,6 +370,7 @@ def __sql__(self):
367370
368371class Model (BaseModel , abc .ABC ):
369372 """Our standard model class with CRUD methods"""
373+ __resultset_class__ = CompositeResultSet
370374
371375 class Meta :
372376 insert_update_dirty_only : bool = (
@@ -397,12 +401,12 @@ def query(cls, stmt, params=None) -> CompositeResultSet:
397401 with ensure_transaction (cls .__engine__ ) as tx :
398402 rv = _signal_rv (cls .before_query .send (cls , stmt = stmt , params = params ))
399403 if rv is False :
400- return ResultSet (None )
404+ return cls . __resultset_class__ (None )
401405 if isinstance (rv , ResultSet ):
402406 return rv
403407 if isinstance (rv , tuple ):
404408 stmt , params = rv
405- return tx .fetchhydrated (cls , stmt , params )
409+ return tx .fetchhydrated (cls , stmt , params , resultset_class = cls . __resultset_class__ )
406410
407411 @classmethod
408412 def find_all (
@@ -473,8 +477,11 @@ def __init__(self, **values):
473477 setattr (self , k , v )
474478
475479 def __setattr__ (self , name , value ):
476- self .__dict__ [name ] = value
477- flag_dirty_attr (self , name )
480+ if isinstance (getattr (self .__class__ , name , None ), (ModelColumnMixin , Relationship )):
481+ super ().__setattr__ (name , value )
482+ else :
483+ self .__dict__ [name ] = value
484+ flag_dirty_attr (self , name )
478485
479486 def refresh (self , ** select_kwargs ):
480487 stmt = self .__mapper__ .select_by_pk (self .__mapper__ .get_primary_key (self ), ** select_kwargs )
0 commit comments