Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
ClassVar,
Literal,
TypeAlias,
TypedDict,
TypeVar,
Union,
cast,
get_origin,
overload,
)

from pydantic import BaseModel, EmailStr
from pydantic import BaseModel, EmailStr, create_model
from pydantic.fields import FieldInfo as PydanticFieldInfo
from sqlalchemy import (
Boolean,
Expand All @@ -49,7 +50,7 @@
from sqlalchemy.orm.instrumentation import is_instrumented
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid
from typing_extensions import deprecated
from typing_extensions import Unpack, deprecated

from ._compat import (
PYDANTIC_MINOR_VERSION,
Expand Down Expand Up @@ -801,6 +802,22 @@ def get_column_from_field(field: Any) -> Column:
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")


class _ModelDumpKwargs(TypedDict):
mode: Literal["json", "python"] | str
include: IncEx | None
exclude: IncEx | None
context: Any | None # v2.7
by_alias: bool | None
exclude_unset: bool
exclude_defaults: bool
exclude_none: bool
exclude_computed_fields: bool # v2.12
round_trip: bool
warnings: bool | Literal["none", "warn", "error"]
fallback: Callable[[Any], Any] | None # v2.11
serialize_as_any: bool # v2.7


class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
Expand Down Expand Up @@ -984,25 +1001,29 @@ def sqlmodel_update(
obj: builtins.dict[str, Any] | BaseModel,
*,
update: builtins.dict[str, Any] | None = None,
**model_dump_kwargs: Unpack[_ModelDumpKwargs],
) -> _TSQLModel:
use_update = (update or {}).copy()
if isinstance(obj, dict):
for key, value in {**obj, **use_update}.items():
if key in get_model_fields(self):
setattr(self, key, value)
elif isinstance(obj, BaseModel):
for key in get_model_fields(obj):
if key in use_update:
value = use_update.pop(key)
else:
value = getattr(obj, key)
setattr(self, key, value)
for remaining_key, value in use_update.items():
if remaining_key in get_model_fields(self):
setattr(self, remaining_key, value)
else:
if not (isinstance(obj, dict) or isinstance(obj, BaseModel)):
raise ValueError(
"Can't use sqlmodel_update() with something that "
f"is not a dict or SQLModel or Pydantic model: {obj}"
)
if isinstance(obj, BaseModel):
# Create a temp UpdateModel schema (removes extra serialization settings)
ObjClass = obj.__class__
fields_def = {
fname: finfo.annotation
for fname, finfo in ObjClass.model_fields.items()
}
UpdateModel = create_model(f"_{ObjClass.__name__}Update_", **fields_def)
# rebuild obj instance with model_construct
obj = UpdateModel.model_construct(
_fields_set=obj.model_fields_set, **obj.__dict__
)
# Now `obj.model_dump` works with **model_dump_kwargs
obj = obj.model_dump(**model_dump_kwargs)
Comment on lines +1011 to +1024
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will break use cases with models that have fields with exclude=True:

from sqlmodel import Field, SQLModel


class Item(SQLModel):
    id: str
    param: str = Field(exclude=True)


a = Item.model_validate({"id": "1", "param": "1"})
b = Item.model_validate({"id": "1", "param": "2"})


a.sqlmodel_update(b, exclude={"id"})
# a.sqlmodel_update(b)


assert a.param == "2"

.. and probably some other cases when model has settings that change the default serialization schema.

use_update = (update or {}).copy()
for key, value in {**obj, **use_update}.items():
if key in get_model_fields(self):
setattr(self, key, value)
return self
14 changes: 13 additions & 1 deletion tests/test_update.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from pytest import raises
from sqlmodel import Field, SQLModel


def test_sqlmodel_update():
class Organization(SQLModel, table=True):
id: int = Field(default=None, primary_key=True)
name: str
city: str
headquarters: str

class OrganizationUpdate(SQLModel):
name: str
name: str = Field(exclude=True)
city: str | None = None

org = Organization(name="Example Org", city="New York", headquarters="NYC HQ")
org_in = OrganizationUpdate(name="Updated org")
Expand All @@ -17,4 +20,13 @@ class OrganizationUpdate(SQLModel):
update={
"headquarters": "-", # This field is in Organization, but not in OrganizationUpdate
},
exclude_unset=True,
)
# fields that should stay the same
assert org.city == "New York"
# fields that should be updated
assert org.name == "Updated org"
assert org.headquarters == "-"
# test raise value error when passing in updates other than dict or BaseModel
with raises(ValueError):
org.sqlmodel_update(["Boston"])
Loading