Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: adjust lambda_stmt statement tracking #128

Merged
merged 8 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,19 @@ async def exists(
def _get_base_stmt(
self,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
global_track_bound_values: bool = True,
track_closure_variables: bool = True,
enable_tracking: bool = True,
track_bound_values: bool = True,
) -> StatementLambdaElement:
if isinstance(statement, Select):
return lambda_stmt(lambda: statement)
return lambda_stmt(
lambda: statement,
track_bound_values=track_bound_values,
global_track_bound_values=global_track_bound_values,
track_closure_variables=track_closure_variables,
enable_tracking=enable_tracking,
)
return self.statement if statement is None else statement

def _get_delete_many_statement(
Expand Down Expand Up @@ -650,10 +660,13 @@ async def count(
Count of records returned by query, ignoring pagination.
"""
with wrap_sqlalchemy_exception():
statement = self._get_base_stmt(statement)
statement = self._get_base_stmt(statement, enable_tracking=False)
fragment = self.get_id_attribute_value(self.model_type)
statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True)
statement += lambda s: s.order_by(None)
statement = statement.add_criteria(
lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True),
enable_tracking=False,
)
statement = statement.add_criteria(lambda s: s.order_by(None))
statement = self._filter_select_by_kwargs(statement, kwargs)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
results = await self._execute(statement)
Expand Down Expand Up @@ -843,7 +856,7 @@ async def _list_and_count_window(
"""
statement = self._get_base_stmt(statement)
field = self.get_id_attribute_value(self.model_type)
statement += lambda s: s.add_columns(over(sql_func.count(field)))
statement = statement.add_criteria(lambda s: s.add_columns(over(sql_func.count(field))), enable_tracking=False)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
with wrap_sqlalchemy_exception():
Expand Down Expand Up @@ -893,9 +906,11 @@ async def _list_and_count_basic(

def _get_count_stmt(self, statement: StatementLambdaElement) -> StatementLambdaElement:
fragment = self.get_id_attribute_value(self.model_type)
statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True)
statement += lambda s: s.order_by(None)
return statement
statement = statement.add_criteria(
lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True),
enable_tracking=False,
)
return statement.add_criteria(lambda s: s.order_by(None))

async def upsert(
self,
Expand Down Expand Up @@ -1373,7 +1388,7 @@ def _filter_by_expression(
statement: StatementLambdaElement,
expression: ColumnElement[bool],
) -> StatementLambdaElement:
statement += lambda s: s.filter(expression)
statement += lambda s: s.where(expression)
return statement

def _filter_by_where(
Expand Down
33 changes: 24 additions & 9 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,19 @@ def exists(
def _get_base_stmt(
self,
statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None,
global_track_bound_values: bool = True,
track_closure_variables: bool = True,
enable_tracking: bool = True,
track_bound_values: bool = True,
) -> StatementLambdaElement:
if isinstance(statement, Select):
return lambda_stmt(lambda: statement)
return lambda_stmt(
lambda: statement,
track_bound_values=track_bound_values,
global_track_bound_values=global_track_bound_values,
track_closure_variables=track_closure_variables,
enable_tracking=enable_tracking,
)
return self.statement if statement is None else statement

def _get_delete_many_statement(
Expand Down Expand Up @@ -651,10 +661,13 @@ def count(
Count of records returned by query, ignoring pagination.
"""
with wrap_sqlalchemy_exception():
statement = self._get_base_stmt(statement)
statement = self._get_base_stmt(statement, enable_tracking=False)
fragment = self.get_id_attribute_value(self.model_type)
statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True)
statement += lambda s: s.order_by(None)
statement = statement.add_criteria(
lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True),
enable_tracking=False,
)
statement = statement.add_criteria(lambda s: s.order_by(None))
statement = self._filter_select_by_kwargs(statement, kwargs)
statement = self._apply_filters(*filters, apply_pagination=False, statement=statement)
results = self._execute(statement)
Expand Down Expand Up @@ -844,7 +857,7 @@ def _list_and_count_window(
"""
statement = self._get_base_stmt(statement)
field = self.get_id_attribute_value(self.model_type)
statement += lambda s: s.add_columns(over(sql_func.count(field)))
statement = statement.add_criteria(lambda s: s.add_columns(over(sql_func.count(field))), enable_tracking=False)
statement = self._apply_filters(*filters, statement=statement)
statement = self._filter_select_by_kwargs(statement, kwargs)
with wrap_sqlalchemy_exception():
Expand Down Expand Up @@ -894,9 +907,11 @@ def _list_and_count_basic(

def _get_count_stmt(self, statement: StatementLambdaElement) -> StatementLambdaElement:
fragment = self.get_id_attribute_value(self.model_type)
statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True)
statement += lambda s: s.order_by(None)
return statement
statement = statement.add_criteria(
lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True),
enable_tracking=False,
)
return statement.add_criteria(lambda s: s.order_by(None))

def upsert(
self,
Expand Down Expand Up @@ -1374,7 +1389,7 @@ def _filter_by_expression(
statement: StatementLambdaElement,
expression: ColumnElement[bool],
) -> StatementLambdaElement:
statement += lambda s: s.filter(expression)
statement += lambda s: s.where(expression)
return statement

def _filter_by_where(
Expand Down
87 changes: 87 additions & 0 deletions tests/integration/test_lambda_stmt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from sqlalchemy import ForeignKey, create_engine, func, select
from sqlalchemy.orm import Mapped, Session, mapped_column, relationship, sessionmaker

from advanced_alchemy.base import UUIDBase
from advanced_alchemy.repository import SQLAlchemySyncRepository


def test_lambda_statement_quirks() -> None:

class Country(UUIDBase):
name: Mapped[str]

class State(UUIDBase):
name: Mapped[str]
country_id: Mapped[str] = mapped_column(ForeignKey(Country.id))

country = relationship(Country)

class USStateRepository(SQLAlchemySyncRepository[State]):
model_type = State

engine = create_engine("sqlite:///:memory:", future=True, echo=True)
session_factory: sessionmaker[Session] = sessionmaker(engine, expire_on_commit=False)

with engine.begin() as conn:
State.metadata.create_all(conn)

with session_factory() as db_session:
usa = Country(name="United States of America")
france = Country(name="France")
db_session.add(usa)
db_session.add(france)

california = State(name="California", country=usa)
oregon = State(name="Oregon", country=usa)
ile_de_france = State(name="Île-de-France", country=france)

repo = USStateRepository(session=db_session)
repo.add(california)
repo.add(oregon)
repo.add(ile_de_france)
db_session.commit()

# Using only the ORM, this works fine:

stmt = select(State).where(State.country_id == usa.id).with_only_columns(func.count())
count = db_session.execute(stmt).scalar_one()
assert count == 2, f"Expected 2, got {count}"
count = db_session.execute(stmt).scalar_one()
assert count == 2, f"Expected 2, got {count}"

stmt = select(State).where(State.country == usa).with_only_columns(func.count())
count = db_session.execute(stmt).scalar_one()
assert count == 2, f"Expected 2, got {count}"
count = db_session.execute(stmt).scalar_one()
assert count == 2, f"Expected 2, got {count}"

# Using the repository, this works:
stmt1 = select(State).where(State.country_id == usa.id)

count = repo.count(statement=stmt1)
assert count == 2, f"Expected 2, got {count}"

count = repo.count(statement=stmt1)
assert count == 2, f"Expected 2, got {count}"

# But this would fail (only after the second query) (lambda caching test):
stmt2 = select(State).where(State.country == usa)

count = repo.count(statement=stmt2)
assert count == 2, f"Expected 2, got {count}"

count = repo.count(State.country == usa)
assert count == 2, f"Expected 2, got {count}"

count = repo.count(statement=stmt2)
assert count == 2, f"Expected 2, got {count}"

# It also failed with
states = repo.list(statement=stmt2)
count = len(states)
assert count == 2, f"Expected 2, got {count}"

_states, count = repo.list_and_count(statement=stmt2)
assert count == 2, f"Expected 2, got {count}"
_states, count = repo.list_and_count(statement=stmt2, force_basic_query_mode=True)
assert count == 2, f"Expected 2, got {count}"
Loading