diff --git a/piccolo_api/crud/serializers.py b/piccolo_api/crud/serializers.py index 9421da9a..35681703 100644 --- a/piccolo_api/crud/serializers.py +++ b/piccolo_api/crud/serializers.py @@ -50,6 +50,10 @@ def create_pydantic_model( """ Create a Pydantic model representing a table. + Make sure that you are using the ``required`` attribute on each ``Field`` + of your model because it is used here to find out whether each created + Pydantic field should have a default value or not. + :param table: The Piccolo ``Table`` you want to create a Pydantic serialiser model for. @@ -131,9 +135,10 @@ def create_pydantic_model( ####################################################################### params: t.Dict[str, t.Any] = { - "default": None if is_optional else ..., "nullable": column._meta.null, } + if not column._meta.required: + params.update({"default_factory": column.get_default_value}) extra = { "help_text": column._meta.help_text, diff --git a/tests/crud/test_crud_endpoints.py b/tests/crud/test_crud_endpoints.py index 52a7534e..073f8744 100644 --- a/tests/crud/test_crud_endpoints.py +++ b/tests/crud/test_crud_endpoints.py @@ -870,8 +870,14 @@ def test_new(self): response = client.get("/new/") self.assertEqual(response.status_code, 200) + response_json = response.json() + self.assertEqual( + response_json["name"], + "", + ) self.assertEqual( - response.json(), {"id": None, "name": "", "rating": 0} + response_json["rating"], + 0, ) diff --git a/tests/crud/test_serializers.py b/tests/crud/test_serializers.py index 60f8dfd6..64e6b4d1 100644 --- a/tests/crud/test_serializers.py +++ b/tests/crud/test_serializers.py @@ -2,7 +2,7 @@ from unittest import TestCase import pydantic -from piccolo.columns import Array, Numeric, Text, Varchar +from piccolo.columns import Array, Integer, Numeric, Text, Varchar from piccolo.columns.column_types import JSON, JSONB, Secret from piccolo.table import Table from pydantic import ValidationError @@ -256,3 +256,50 @@ class Computer2(Table): with self.assertRaises(ValueError): create_pydantic_model(Computer, exclude_columns=(Computer2.CPU,)) + + +class TestDefaultColumn(TestCase): + def test_default(self): + class Monitor(Table): + refresh_rate = Integer(default=144) + resolution = Varchar(required=True) + + pydantic_model = create_pydantic_model(Monitor) + + assert pydantic_model.schema()["required"] == ["resolution"] + + pydantic_instance = pydantic_model(resolution="1440*2560") + + assert pydantic_instance.refresh_rate == 144 + assert pydantic_instance.resolution == "1440*2560" + + def test_default_factory(self): + class Monitor(Table): + refresh_rate = Integer(required=True) + resolution = Varchar(default=lambda: "1920*1080") + + pydantic_model = create_pydantic_model(Monitor) + + assert pydantic_model.schema()["required"] == ["refresh_rate"] + + pydantic_instance = pydantic_model(refresh_rate=60) + + assert pydantic_instance.refresh_rate == 60 + assert pydantic_instance.resolution == "1920*1080" + + def test_override_default(self): + class Monitor(Table): + refresh_rate = Integer(default=240) + resolution = Varchar(default=lambda: "1440*2560") + + pydantic_model = create_pydantic_model(Monitor) + + assert not pydantic_model.schema().get("required") + + pydantic_instance = pydantic_model( + refresh_rate=60, + resolution="1080*1920", + ) + + assert pydantic_instance.refresh_rate == 60 + assert pydantic_instance.resolution == "1080*1920" diff --git a/tests/fastapi/test_fastapi_endpoints.py b/tests/fastapi/test_fastapi_endpoints.py index 0ec2803b..a187960b 100644 --- a/tests/fastapi/test_fastapi_endpoints.py +++ b/tests/fastapi/test_fastapi_endpoints.py @@ -111,9 +111,14 @@ def test_get_responses(self): response = client.get("/movies/new/") self.assertEqual(response.status_code, 200) + response_json = response.json() self.assertEqual( - response.json(), - {"id": None, "name": "", "rating": 0}, + response_json["name"], + "", + ) + self.assertEqual( + response_json["rating"], + 0, ) response = client.get("/movies/references/")