Skip to content

Commit

Permalink
Add tests for _url_for methods (#456)
Browse files Browse the repository at this point in the history
  • Loading branch information
aminalaee authored Mar 25, 2023
1 parent bad606c commit c838595
Showing 1 changed file with 32 additions and 9 deletions.
41 changes: 32 additions & 9 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Generator
from unittest.mock import Mock, call, patch

import pytest
from markupsafe import Markup
Expand All @@ -7,6 +8,7 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from starlette.applications import Starlette
from starlette.requests import Request

from sqladmin import Admin, ModelView
from sqladmin.exceptions import InvalidColumnError, InvalidModelError
Expand Down Expand Up @@ -36,7 +38,7 @@ class User(Base):
class Address(Base):
__tablename__ = "addresses"

id = Column(Integer, primary_key=True)
pk = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))

user = relationship("User", back_populates="addresses")
Expand Down Expand Up @@ -122,10 +124,10 @@ class UserAdmin(ModelView, model=User):

def test_column_list_by_str_name() -> None:
class AddressAdmin(ModelView, model=Address):
column_list = ["id", "user_id"]
column_list = ["pk", "user_id"]

assert AddressAdmin().get_list_columns() == [
("id", Address.id),
("pk", Address.pk),
("user_id", Address.user_id),
]

Expand Down Expand Up @@ -320,10 +322,10 @@ class UserAdmin(ModelView, model=User):

def test_form_columns_by_str_name() -> None:
class AddressAdmin(ModelView, model=Address):
form_columns = ["id", "user_id"]
form_columns = ["pk", "user_id"]

assert AddressAdmin().get_form_columns() == [
("id", Address.id),
("pk", Address.pk),
("user_id", Address.user_id),
]

Expand Down Expand Up @@ -393,10 +395,10 @@ class UserAdmin(ModelView, model=User):

def test_export_columns_by_str_name() -> None:
class AddressAdmin(ModelView, model=Address):
column_export_list = ["id", "user_id"]
column_export_list = ["pk", "user_id"]

assert AddressAdmin().get_export_columns() == [
("id", Address.id),
("pk", Address.pk),
("user_id", Address.user_id),
]

Expand Down Expand Up @@ -475,14 +477,35 @@ async def test_get_model_objects_uses_list_query() -> None:
session.refresh(batman)
session.close()

class HerosAdmin(ModelView, model=User):
class UserAdmin(ModelView, model=User):
async_engine = False
sessionmaker = LocalSession

view = HerosAdmin()
view = UserAdmin()

view.list_query = select(User).filter(User.name.endswith("man"))
assert len(await view.get_model_objects()) == 1

view.list_query = select(User).filter(User.name.endswith("man").is_(False))
assert len(await view.get_model_objects()) == 0


def test_url_for() -> None:
class UserAdmin(ModelView, model=User):
...

view = UserAdmin()
request = Request({"type": "http"})
user = User(id=1)
address = Address(pk=2, user=user)

with patch("starlette.requests.Request.url_for", Mock()) as mock:
view._url_for_details(request, user)
view._url_for_edit(request, address)
view._url_for_delete(request, address)

assert mock.call_args_list == [
call("admin:details", identity="user", pk=1),
call("admin:edit", identity="address", pk=2),
call("admin:delete", identity="address"),
]

0 comments on commit c838595

Please sign in to comment.