Skip to content

Commit c28aa36

Browse files
Add edit_form_query method (#745)
Co-authored-by: Amin Alaee <[email protected]>
1 parent 9ed5414 commit c28aa36

File tree

7 files changed

+81
-11
lines changed

7 files changed

+81
-11
lines changed

docs/api_reference/model_view.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
- count_query
5050
- search_query
5151
- sort_query
52+
- edit_form_query
5253
- on_model_change
5354
- after_model_change
5455
- on_model_delete

docs/configurations.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ The forms are based on `WTForms` package and include the following options:
200200
* `form_include_pk`: Control if primary key column should be included in create/edit forms. Default is `False`.
201201
* `form_ajax_refs`: Use Ajax with Select2 for loading relationship models async. This is use ful when the related model has a lot of records.
202202
* `form_converter`: Allow adding custom converters to support additional column types.
203+
* `edit_form_query`: A method with the signature of `(request) -> stmt` which can customize the edit form data.
203204

204205
!!! example
205206

docs/cookbook/optimize_relationship_loading.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,22 @@ which should be available in the form.
6060
class ParentAdmin(ModelView, model=Parent):
6161
form_excluded_columns = [Parent.children]
6262
```
63+
64+
### Using `edit_form_query` to customize the edit form data
65+
66+
If you would like to fully customize the query to populate the edit object form, you may override
67+
the `edit_form_query` function with your own SQLAlchemy query. In the following example, overriding
68+
the default query will allow you to filter relationships to show only related children of the parent.
69+
70+
```py
71+
class ParentAdmin(ModelView, model=Parent):
72+
def edit_form_query(self, request: Request) -> Select:
73+
parent_id = request.path_params["pk"]
74+
return (
75+
super()
76+
.edit_form_query(request)
77+
.join(Child)
78+
.options(contains_eager(Parent.children))
79+
.filter(Child.parent_id == parent_id)
80+
)
81+
```

sqladmin/application.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ async def edit(self, request: Request) -> Response:
542542
identity = request.path_params["identity"]
543543
model_view = self._find_model_view(identity)
544544

545-
model = await model_view.get_object_for_edit(request.path_params["pk"])
545+
model = await model_view.get_object_for_edit(request)
546546
if not model:
547547
raise HTTPException(status_code=404)
548548

sqladmin/models.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -807,12 +807,8 @@ async def get_object_for_details(self, value: Any) -> Any:
807807

808808
return await self._get_object_by_pk(stmt)
809809

810-
async def get_object_for_edit(self, value: Any) -> Any:
811-
stmt = self._stmt_by_identifier(value)
812-
813-
for relation in self._form_relations:
814-
stmt = stmt.options(joinedload(relation))
815-
810+
async def get_object_for_edit(self, request: Request) -> Any:
811+
stmt = self.edit_form_query(request)
816812
return await self._get_object_by_pk(stmt)
817813

818814
async def get_object_for_delete(self, value: Any) -> Any:
@@ -1045,6 +1041,18 @@ def list_query(self, request: Request) -> Select:
10451041

10461042
return select(self.model)
10471043

1044+
def edit_form_query(self, request: Request) -> Select:
1045+
"""
1046+
The SQLAlchemy select expression used for the edit form page which can be
1047+
customized. By default it will select the object by primary key(s) without any
1048+
additional filters.
1049+
"""
1050+
1051+
stmt = self._stmt_by_identifier(request.path_params["pk"])
1052+
for relation in self._form_relations:
1053+
stmt = stmt.options(joinedload(relation))
1054+
return stmt
1055+
10481056
def count_query(self, request: Request) -> Select:
10491057
"""
10501058
The SQLAlchemy select expression used for the count query

tests/test_models.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from markupsafe import Markup
66
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, select
77
from sqlalchemy.dialects.postgresql import UUID
8-
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
8+
from sqlalchemy.orm import contains_eager, declarative_base, relationship, sessionmaker
99
from sqlalchemy.sql.expression import Select
1010
from starlette.applications import Starlette
1111
from starlette.requests import Request
@@ -52,6 +52,7 @@ class Address(Base):
5252
__tablename__ = "addresses"
5353

5454
id = Column(Integer, primary_key=True)
55+
name = Column(String)
5556
user_id = Column(Integer, ForeignKey("users.id"))
5657

5758
user = relationship("User", back_populates="addresses")
@@ -381,13 +382,47 @@ def list_query(self, request: Request) -> Select:
381382
assert len(await view.get_model_objects(request)) == 1
382383

383384

385+
async def test_edit_form_query() -> None:
386+
session = session_maker()
387+
batman = User(id=123, name="batman")
388+
batcave = Address(user=batman, name="bat cave")
389+
wayne_manor = Address(user=batman, name="wayne manor")
390+
session.add(batman)
391+
session.add(batcave)
392+
session.add(wayne_manor)
393+
session.commit()
394+
395+
class UserAdmin(ModelView, model=User):
396+
async_engine = False
397+
session_maker = session_maker
398+
399+
def edit_form_query(self, request: Request) -> Select:
400+
return (
401+
select(self.model)
402+
.join(Address)
403+
.options(contains_eager(User.addresses))
404+
.filter(Address.name == "bat cave")
405+
)
406+
407+
view = UserAdmin()
408+
409+
class RequestObject(object):
410+
pass
411+
412+
request_object = RequestObject()
413+
request_object.path_params = {"pk": 123}
414+
user_obj = await view.get_object_for_edit(request_object)
415+
416+
assert len(user_obj.addresses) == 1
417+
418+
384419
def test_model_columns_all_keyword() -> None:
385420
class AddressAdmin(ModelView, model=Address):
386421
column_list = "__all__"
387422
column_details_list = "__all__"
388423

389-
assert AddressAdmin().get_list_columns() == ["user", "id", "user_id"]
390-
assert AddressAdmin().get_details_columns() == ["user", "id", "user_id"]
424+
assert AddressAdmin().get_list_columns() == ["user", "id", "name", "user_id"]
425+
assert AddressAdmin().get_details_columns() == ["user", "id", "name", "user_id"]
391426

392427

393428
async def test_get_prop_value() -> None:

tests/test_models_action.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,14 @@ async def _action_stub(self, request: Request) -> Response:
3535
pks = request.query_params.get("pks", "")
3636

3737
obj_strs: List[str] = []
38+
39+
class RequestObject(object):
40+
pass
41+
3842
for pk in pks.split(","):
39-
obj = await self.get_object_for_edit(pk)
43+
request_object = RequestObject()
44+
request_object.path_params = {"pk": pk}
45+
obj = await self.get_object_for_edit(request_object)
4046

4147
obj_strs.append(repr(obj))
4248

0 commit comments

Comments
 (0)