Skip to content

Commit

Permalink
Merge pull request #330 from phenobarbital/dev
Browse files Browse the repository at this point in the history
manage enum on serialization of ModelViews
  • Loading branch information
phenobarbital authored Dec 19, 2024
2 parents 763dc35 + 5fc78cd commit c5ac732
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 41 deletions.
34 changes: 32 additions & 2 deletions examples/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union
import asyncio
from enum import Enum
from datetime import datetime
from aiohttp import web
from navconfig.logging import logging
Expand All @@ -9,14 +10,33 @@
from navigator import Application
from navigator.responses import HTMLResponse
from navigator.views import ModelView
from navigator.conf import PG_USER, PG_PWD, PG_HOST, PG_PORT

# Example DSN:
dsn = f'postgresql://{PG_USER}:{PG_PWD}@{PG_HOST}:{PG_PORT}/pruebas'

class AirportType(Enum):
"""
Enum for Airport Types.
"""
CITY = 1
INTERNATIONAL = 2
DOMESTIC = 3


class Country(Model):
country_code: str = Column(primary_key=True)
country: str

class Airport(Model):
iata: str = Column(primary_key=True, required=True, label='IATA Code')
airport: str = Column(required=True, label="Airport Name")
airport_type: AirportType = Column(
required=True,
label='Airport Type',
choices=AirportType,
default=AirportType.CITY
)
city: str
country: str
created_by: int
Expand All @@ -40,10 +60,19 @@ async def hola(request: web.Request) -> web.Response:
class AirportHandler(ModelView):
model: Model = Airport
pk: Union[str, list] = ['iata']
dsn: str = dsn

async def _get_created_by(self, value, column, **kwargs):
async def _set_created_by(self, value, column, **kwargs):
return await self.get_userid(session=self._session)

async def _put_callback(self, response: web.Response, result, *args, **kwargs):
print('RESULT > ', result)
print('RESPONSE > ', response)
print('PUT CALLBACK')
return response

_post_callback = _put_callback

async def on_startup(self, *args, **kwargs):
print(args, kwargs)
print('THIS CODE RUN ON STARTUP')
Expand All @@ -67,6 +96,7 @@ async def start_example(db):
iata character varying(3),
airport character varying(60),
city character varying(20),
airport_type integer,
country character varying(30),
created_by integer,
created_at timestamp with time zone NOT NULL DEFAULT now(),
Expand Down Expand Up @@ -122,7 +152,7 @@ async def end_example(db):
"password": "12345678",
"host": "127.0.0.1",
"port": "5432",
"database": "navigator",
"database": "pruebas",
"DEBUG": True,
}
kwargs = {
Expand Down
8 changes: 8 additions & 0 deletions navigator/libs/json.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from psycopg2 import Binary # Import Binary from psycopg2
from typing import Any, Union
from pathlib import PosixPath, PurePath, Path
from decimal import Decimal
from enum import Enum, EnumType
from ..exceptions.exceptions cimport ValidationError
import orjson

Expand Down Expand Up @@ -52,6 +53,13 @@ cdef class JSONContent:
return [obj.lower, up]
elif hasattr(obj, 'tolist'): # numpy array
return obj.tolist()
elif isinstance(obj, Enum): # Handle Enum serialization
if obj is None:
return None
return obj.value if hasattr(obj, 'value') else obj.name
elif isinstance(obj, type) and issubclass(obj, Enum):
return [{'value': e.value, 'name': e.name} for e in obj]
# return [e.name for e in obj] # Serialize the names of the Enum class members
elif isinstance(obj, _MISSING_TYPE):
return None
elif obj == MISSING:
Expand Down
2 changes: 1 addition & 1 deletion navigator/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
__description__ = (
"Navigator Web Framework based on aiohttp, " "with batteries included."
)
__version__ = "2.12.7"
__version__ = "2.12.8"
__author__ = "Jesus Lara"
__author_email__ = "[email protected]"
__license__ = "BSD"
70 changes: 45 additions & 25 deletions navigator/views/abstract.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Optional, Union, Any, TypeVar
from collections.abc import Callable
from typing import Optional, Union, Any, TypeVar, Type
from collections.abc import Awaitable, Callable
import asyncio
import copy
from aiohttp import web, hdrs
import traceback
from functools import wraps
from aiohttp import web, hdrs
try:
import babel
BABEL_INSTALLED = True
Expand Down Expand Up @@ -82,13 +81,15 @@ async def __aenter__(self):
async def default_connection(self, request: web.Request):
if self._dbname in request.app:
return request.app[self._dbname]
kwargs = {
"server_settings": {
'client_min_messages': 'notice',
'max_parallel_workers': '24',
'tcp_keepalives_idle': '30'
kwargs = {}
if self.driver == 'pg':
kwargs = {
"server_settings": {
'client_min_messages': 'notice',
'max_parallel_workers': '24',
'tcp_keepalives_idle': '30'
}
}
}
pool = AsyncPool(
self.driver,
dsn=default_dsn,
Expand Down Expand Up @@ -148,37 +149,37 @@ class AbstractModel(BaseView):
in: Model
type: BaseModel
required: true
description: DataModel to be used.
- name: get_model
in: Model
type: BaseModel
required: false
description: DataModel to be used.
"""
model: BaseModel = None
get_model: BaseModel = None
model: Type[BaseModel] = None
get_model: Type[BaseModel] = None
# Signal for startup method for this ModelView
on_startup: Optional[Callable] = None
on_shutdown: Optional[Callable] = None
model_kwargs: dict = {}
name: str = "Model"
# Connection parameters
driver: str = 'pg'
dsn: str = None
credentials: dict = None
dbname: str = 'nav.model'
handler: ConnectionHandler

def __init__(self, request, *args, **kwargs):
self.__name__ = self.model.__name__
self._session = None
driver = kwargs.pop('driver', 'pg')
dsn = kwargs.pop('dsn', None)
credentials = kwargs.pop('credentials', {})
dbname = kwargs.pop('dbname', 'nav.model')
## getting get Model:
if not self.get_model:
self.get_model = self.model
super().__init__(request, *args, **kwargs)
# Database Connection Handler
self.handler = ConnectionHandler(
driver,
dsn=dsn,
dbname=dbname,
credentials=credentials,
model_kwargs=self.model_kwargs
)

@classmethod
def configure(cls, app: WebApp, path: str = None) -> WebApp:
def configure(cls, app: WebApp, path: str = None, **kwargs) -> WebApp:
"""configure.
Expand Down Expand Up @@ -221,6 +222,25 @@ def configure(cls, app: WebApp, path: str = None) -> WebApp:
app.router.add_view(
r"{url}{{meta:(:.*)?}}".format(url=url), cls
)
# Use kwargs to reconfigure the connection handler if needed
if 'driver' in kwargs:
cls.driver = kwargs['driver']
if 'dsn' in kwargs:
cls.dsn = kwargs['dsn']
if 'credentials' in kwargs:
cls.credentials = kwargs['credentials']
if 'dbname' in kwargs:
cls.dbname = kwargs['dbname']
if 'model_kwargs' in kwargs:
cls.model_kwargs = kwargs['model_kwargs']
# Database Connection Handler
cls.handler = ConnectionHandler(
cls.driver,
dsn=cls.dsn,
dbname=cls.dbname,
credentials=cls.credentials,
model_kwargs=cls.model_kwargs
)

async def validate_payload(self, data: Optional[Union[dict, list]] = None):
"""Get information for usage in Form."""
Expand Down
34 changes: 21 additions & 13 deletions navigator/views/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Awaitable
from typing import Optional, Union, Any
from collections.abc import Iterable
from typing import Optional, Union, Any, Awaitable, Callable
import importlib
import asyncio
from aiohttp import web
Expand Down Expand Up @@ -51,6 +51,9 @@ async def load_model(tablename: str, schema: str, connection: Any) -> Model:
)


CallbackType = Optional[Callable[[web.Response, BaseModel], Awaitable[None]]]


class ModelView(AbstractModel):
"""ModelView.
Expand All @@ -70,16 +73,16 @@ class ModelView(AbstractModel):
get_model: BaseModel = None
model_name: str = None # Override the current model with other.
path: str = None
pk: Union[str, list] = None
pk: Optional[Iterable] = None
_required: list = []
_primaries: list = []
_hidden: list = []
# New Callables to be used on response:
_get_callback: Optional[Awaitable] = None
_put_callback: Optional[Awaitable] = None
_post_callback: Optional[Awaitable] = None
_patch_callback: Optional[Awaitable] = None
_delete_callback: Optional[Awaitable] = None
_get_callback: CallbackType = None
_put_callback: CallbackType = None
_post_callback: CallbackType = None
_patch_callback: CallbackType = None
_delete_callback: CallbackType = None

def __init__(self, request, *args, **kwargs):
if self.model_name is not None:
Expand Down Expand Up @@ -371,7 +374,12 @@ async def _get_filters():
if len(res) == 1:
return res[0]
return res
args = {self.pk: _primary}
elif isinstance(self.pk, str):
args = {self.pk: _primary} # pylint: disable=E1143
else:
raise ValueError(
f"Invalid PK definition for {self.__name__}: {self.pk}"
)
args = {**_filter, **args}
return await self.get_model.get(**args)
elif len(qp) > 0:
Expand Down Expand Up @@ -1271,7 +1279,7 @@ def _del_primary(self, args: dict = None) -> Any:
try:
_args = {}
paramlist = [
item.strip() for item in args["id"].split("/") if item.strip()
item.strip() for item in args.get('id', '').split("/") if item.strip()
]
if not paramlist:
return None
Expand Down Expand Up @@ -1306,8 +1314,8 @@ def _del_primary(self, args: dict = None) -> Any:
# TODO: use validation from datamodel
# evaluate the corrected type for fields:
val = paramlist.pop(0)
args[key] = val
return args
_args[key] = val
return _args
except KeyError:
pass
else:
Expand Down Expand Up @@ -1352,7 +1360,7 @@ async def delete(self):
if isinstance(objid, list):
data = []
for entry in objid:
args = {self.pk: entry}
args = {self.pk: entry} # noqa
obj = await self.model.get(**args)
data.append(await obj.delete())
else:
Expand Down

0 comments on commit c5ac732

Please sign in to comment.