Skip to content

Commit 9c0006a

Browse files
added find_one (#21)
1 parent 0610c90 commit 9c0006a

File tree

3 files changed

+59
-16
lines changed

3 files changed

+59
-16
lines changed

sqlmodel_repository/base_repository.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Generic, List, Optional, Type, TypeVar, get_args
44

55
from sqlalchemy.orm import Session
6+
from sqlalchemy.sql.elements import ColumnClause
67
from sqlmodel import col
78
from structlog import WriteLogger
89

@@ -15,6 +16,7 @@
1516

1617
class BaseRepository(Generic[GenericEntity], ABC):
1718
"""Abstract base class for all repositories"""
19+
1820
_default_excluded_keys = ["_sa_instance_state"]
1921

2022
def __init__(self, logger: Optional[WriteLogger] = None, sensitive_attribute_keys: Optional[list[str]] = None):
@@ -45,20 +47,31 @@ def find(self, **kwargs) -> List[GenericEntity]:
4547
4648
Returns:
4749
List[GenericEntity]: The entities that were found in the repository for the given filters
48-
50+
4951
Notes:
5052
- Success log is covered by get_batch
5153
"""
5254
filters = []
5355
self._emit_operation_begin_log("Finding", **kwargs)
54-
55-
for key, value in kwargs.items():
56-
try:
57-
filters.append(col(getattr(self.entity, key)) == value)
58-
except AttributeError as attribute_error:
59-
raise EntityDoesNotPossessAttributeException(f"Entity {self.entity} does not have the attribute {key}") from attribute_error
56+
filters = self._create_filters(**kwargs)
6057
return self.get_batch(filters=filters)
6158

59+
def find_one(self, **kwargs) -> GenericEntity:
60+
"""Get a single entity with one query by filters
61+
62+
Args:
63+
**kwargs: The filters to apply
64+
65+
Returns:
66+
GenericEntity: The entity that was found in the repository for the given filters
67+
"""
68+
session = self.get_session()
69+
self._emit_operation_begin_log("Finding one", **kwargs)
70+
filters = self._create_filters(**kwargs)
71+
result = session.query(self.entity).filter(*filters).one()
72+
self._emit_operation_success_log("Finding one", entities=[result])
73+
return result
74+
6275
def update(self, entity: GenericEntity, **kwargs) -> GenericEntity:
6376
"""Updates an entity with the given attributes (keyword arguments) if they are not None
6477
@@ -89,7 +102,7 @@ def update(self, entity: GenericEntity, **kwargs) -> GenericEntity:
89102

90103
session.commit()
91104
session.refresh(entity)
92-
105+
93106
self._emit_operation_success_log("Updating", entities=[entity])
94107
return entity
95108

@@ -141,7 +154,7 @@ def get(self, entity_id: int) -> GenericEntity:
141154
result = session.query(self.entity).filter(self.entity.id == entity_id).one_or_none()
142155
if result is None:
143156
raise EntityNotFoundException(f"Entity {GenericEntity.__name__} with ID {entity_id} not found")
144-
157+
145158
self._emit_operation_success_log("Getting", entities=[result])
146159
return result
147160

@@ -162,7 +175,7 @@ def get_batch(self, filters: Optional[list] = None) -> list[GenericEntity]:
162175
self._emit_operation_begin_log("Batch get")
163176

164177
result = session.query(self.entity).filter(*filters).all()
165-
178+
166179
self._emit_operation_success_log("Batch get", entities=result)
167180
return result
168181

@@ -185,8 +198,8 @@ def create(self, entity: GenericEntity) -> GenericEntity:
185198
session.add(entity)
186199
session.commit()
187200
session.refresh(entity)
188-
189-
self._emit_operation_success_log("Creating", entities=[entity])
201+
202+
self._emit_operation_success_log("Creating", entities=[entity])
190203
return entity
191204
except Exception as exception:
192205
session.rollback()
@@ -215,7 +228,7 @@ def create_batch(self, entities: list[GenericEntity]) -> list[GenericEntity]:
215228

216229
for entity in entities:
217230
session.refresh(entity)
218-
231+
219232
self._emit_operation_success_log("Batch creating", entities=entities)
220233
return entities
221234

@@ -237,7 +250,7 @@ def delete(self, entity: GenericEntity) -> GenericEntity:
237250
try:
238251
session.delete(entity)
239252
session.commit()
240-
253+
241254
self._emit_operation_success_log("Deleting", entities=[entity])
242255
return entity
243256
except Exception as exception:
@@ -266,6 +279,23 @@ def delete_batch(self, entities: list[GenericEntity]) -> None:
266279
session.rollback()
267280
raise CouldNotDeleteEntityException from exception
268281

282+
def _create_filters(self, **kwargs) -> list[ColumnClause]:
283+
"""Creates a list of filters for a query
284+
285+
Args:
286+
**kwargs: The filters to build
287+
288+
Returns:
289+
list: The filters to apply to a query
290+
"""
291+
filters = []
292+
for key, value in kwargs.items():
293+
try:
294+
filters.append(col(getattr(self.entity, key)) == value)
295+
except AttributeError as attribute_error:
296+
raise EntityDoesNotPossessAttributeException(f"Entity {self.entity} does not have the attribute {key}") from attribute_error
297+
return filters
298+
269299
def _safe_kwargs(self, prefix: str = "", **kwargs) -> dict[str, str]:
270300
"""Filters out sensitive attributes from the log kwargs
271301
@@ -290,7 +320,7 @@ def _emit_operation_success_log(self, operation: str, entities: Optional[list[Ge
290320
entity_ids = [entity.id for entity in entities]
291321
entity_log: dict = {"entity_ids": entity_ids}
292322
self.logger.debug(f"{operation} {self.entity.__name__} succeeded", **entity_log)
293-
except Exception as exception: # pylint: disable=broad-except:
323+
except Exception as exception: # pylint: disable=broad-except:
294324
# We want to catch all exceptions here. Logs must be written by all means. It's no silent passing and thereby acceptable.
295325
self.logger.exception(f"Could not emit log for concluding {operation} {self.entity.__name__}", exception=exception) # type: ignore TODO: fix this
296326

tests/integration/test_base_repository_with_database.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,19 @@ def test_raises_entity_does_not_possess_attribute(self, pet_base_repository: Pet
106106
with pytest.raises(EntityDoesNotPossessAttributeException):
107107
pet_base_repository.find(legs=12)
108108

109+
class TestFindOne:
110+
"""Tests for the find_one method."""
111+
112+
def test_find_one_by_attribute(self, pet_base_repository: PetBaseRepository, dog: Pet):
113+
"""Test to find an entity"""
114+
assert pet_base_repository.find_one(name=dog.name) == dog
115+
assert pet_base_repository.find_one(type=PetType.DOG) == dog
116+
117+
def test_raises_entity_does_not_possess_attribute(self, pet_base_repository: PetBaseRepository, dog: Pet):
118+
"""Test to find an entity"""
119+
with pytest.raises(EntityDoesNotPossessAttributeException):
120+
pet_base_repository.find_one(legs=12)
121+
109122
class TestCreateBatch:
110123
"""Tests for the _create_batch method"""
111124

tests/unit/test_base_repository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ class TestRepository(BaseRepository[invalid_entity_class]): # type: ignore
2121
pass
2222

2323
with pytest.raises(TypeError):
24-
TestRepository._entity_class()
24+
TestRepository._entity_class()

0 commit comments

Comments
 (0)