Skip to content

Combined column metadata sqlmodel with dynamic schema #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 119 additions & 7 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,14 @@ def __init__(
options: Sequence[str],
*,
indentation: str = " ",
dynamic_schema_import_path: str | None = None,
dynamic_schema_value: str | None = None,
):
super().__init__(metadata, bind, options)
self.indentation: str = indentation
# TODO add check if there is a "." in the value if set?
self.dynamic_schema_import_path: str | None = dynamic_schema_import_path
self.dynamic_schema_value: str | None = dynamic_schema_value
self.imports: dict[str, set[str]] = defaultdict(set)
self.module_imports: set[str] = set()

Expand Down Expand Up @@ -197,6 +202,8 @@ def collect_imports(self, models: Iterable[Model]) -> None:

for model in models:
self.collect_imports_for_model(model)
if self.dynamic_schema_import_path:
self.add_literal_import(*self.dynamic_schema_import_path.rsplit(".", 1))

def collect_imports_for_model(self, model: Model) -> None:
if model.__class__ is Model:
Expand Down Expand Up @@ -374,7 +381,9 @@ def render_table(self, table: Table) -> str:
if len(index.columns) > 1 or not uses_default_name(index):
args.append(self.render_index(index))

if table.schema:
if self.dynamic_schema_value:
kwargs["schema"] = self.dynamic_schema_value
elif table.schema:
kwargs["schema"] = repr(table.schema)

table_comment = getattr(table, "comment", None)
Expand Down Expand Up @@ -722,9 +731,18 @@ def __init__(
options: Sequence[str],
*,
indentation: str = " ",
dynamic_schema_import_path: str | None = None,
dynamic_schema_value: str | None = None,
base_class_name: str = "Base",
):
super().__init__(metadata, bind, options, indentation=indentation)
super().__init__(
metadata,
bind,
options,
indentation=indentation,
dynamic_schema_import_path=dynamic_schema_import_path,
dynamic_schema_value=dynamic_schema_value,
)
self.base_class_name: str = base_class_name
self.inflect_engine = inflect.engine()

Expand Down Expand Up @@ -1159,14 +1177,23 @@ def render_table_args(self, table: Table) -> str:
if len(index.columns) > 1 or not uses_default_name(index):
args.append(self.render_index(index))

if table.schema:
if self.dynamic_schema_value:
kwargs["schema"] = self.dynamic_schema_value
elif table.schema:
kwargs["schema"] = table.schema

if table.comment:
kwargs["comment"] = table.comment

if kwargs:
formatted_kwargs = pformat(kwargs)
# NB: using pformat on the dict turns schema value (python code) to a string
formatted_kwargs = f",\n{self.indentation}".join(
f"'{k}': {pformat(v)}"
if v != self.dynamic_schema_value
else f"'{k}': {v}"
for k, v in kwargs.items()
)
formatted_kwargs = f"{{{formatted_kwargs}}}"
if not args:
return formatted_kwargs
else:
Expand Down Expand Up @@ -1309,6 +1336,8 @@ def __init__(
options: Sequence[str],
*,
indentation: str = " ",
dynamic_schema_import_path: str | None = None,
dynamic_schema_value: str | None = None,
base_class_name: str = "Base",
quote_annotations: bool = False,
metadata_key: str = "sa",
Expand All @@ -1318,6 +1347,8 @@ def __init__(
bind,
options,
indentation=indentation,
dynamic_schema_import_path=dynamic_schema_import_path,
dynamic_schema_value=dynamic_schema_value,
base_class_name=base_class_name,
)
self.metadata_key: str = metadata_key
Expand Down Expand Up @@ -1348,16 +1379,98 @@ def __init__(
options: Sequence[str],
*,
indentation: str = " ",
dynamic_schema_import_path: str | None = None,
dynamic_schema_value: str | None = None,
base_class_name: str = "SQLModel",
):
super().__init__(
metadata,
bind,
options,
indentation=indentation,
dynamic_schema_import_path=dynamic_schema_import_path,
dynamic_schema_value=dynamic_schema_value,
base_class_name=base_class_name,
)

def generate_models(self) -> list[Model]:
models_by_table_name: dict[str, Model] = {}

# Pick association tables from the metadata into their own set, don't process
# them normally
links: defaultdict[str, list[Model]] = defaultdict(lambda: [])
for table in self.metadata.sorted_tables:
qualified_name = qualified_table_name(table)

# Link tables have exactly two foreign key constraints and all columns are
# involved in them
fk_constraints = sorted(
table.foreign_key_constraints, key=get_constraint_sort_key
)
if len(fk_constraints) == 2 and all(
col.foreign_keys for col in table.columns
):
model = models_by_table_name[qualified_name] = Model(table)
tablename = fk_constraints[0].elements[0].column.table.name
links[tablename].append(model)
continue

# Only form model classes for tables that have a primary key and are not
# association tables
if not table.primary_key:
models_by_table_name[qualified_name] = Model(table)
else:
model = ModelClass(table)
models_by_table_name[qualified_name] = model

# Fill in the columns
for column in table.c:
column_attr = ColumnAttribute(model, column)
model.columns.append(column_attr)

# Add relationships
for model in models_by_table_name.values():
if isinstance(model, ModelClass):
self.generate_relationships(
model, models_by_table_name, links[model.table.name]
)

# Nest inherited classes in their superclasses to ensure proper ordering
if "nojoined" not in self.options:
for model in list(models_by_table_name.values()):
if not isinstance(model, ModelClass):
continue

pk_column_names = {col.name for col in model.table.primary_key.columns}
for constraint in model.table.foreign_key_constraints:
if set(get_column_names(constraint)) == pk_column_names:
target = models_by_table_name[
qualified_table_name(constraint.elements[0].column.table)
]
if isinstance(target, ModelClass):
model.parent_class = target
target.children.append(model)

# Change base if we have both tables and model classes
if any(
not isinstance(model, ModelClass) for model in models_by_table_name.values()
):
TablesGenerator.generate_base(self)

# Collect the imports
self.collect_imports(models_by_table_name.values())

# Rename models and their attributes that conflict with imports or other
# attributes
global_names = {
name for namespace in self.imports.values() for name in namespace
}
for model in models_by_table_name.values():
self.generate_model_name(model, global_names)
global_names.add(model.name)

return list(models_by_table_name.values())

def generate_base(self) -> None:
self.base = Base(
literal_imports=[],
Expand All @@ -1368,7 +1481,6 @@ def generate_base(self) -> None:
def collect_imports(self, models: Iterable[Model]) -> None:
super(DeclarativeGenerator, self).collect_imports(models)
if any(isinstance(model, ModelClass) for model in models):
self.remove_literal_import("sqlalchemy", "MetaData")
self.add_literal_import("sqlmodel", "SQLModel")
self.add_literal_import("sqlmodel", "Field")

Expand Down Expand Up @@ -1400,7 +1512,7 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:
self.add_import(python_type)

def render_module_variables(self, models: list[Model]) -> str:
declarations: list[str] = []
declarations: list[str] = self.base.declarations
if any(not isinstance(model, ModelClass) for model in models):
if self.base.table_metadata_declaration is not None:
declarations.append(self.base.table_metadata_declaration)
Expand Down Expand Up @@ -1446,7 +1558,7 @@ def render_column_attribute(self, column_attr: ColumnAttribute) -> str:
kwargs["default"] = None
python_type_name = f"Optional[{python_type_name}]"

rendered_column = self.render_column(column, True)
rendered_column = self.render_column(column, True, is_table=True)
kwargs["sa_column"] = f"{rendered_column}"
rendered_field = render_callable("Field", kwargs=kwargs)
return f"{column_attr.name}: {python_type_name} = {rendered_field}"
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from textwrap import dedent

import pytest
Expand Down Expand Up @@ -31,3 +32,12 @@ def validate_code(generated_code: str, expected_code: str) -> None:
configure_mappers()
finally:
clear_mappers()


@dataclass
class SchemaObject:
name: str


# NB: not a fixture on purpose
schema_obj = SchemaObject(name="best_schema")
50 changes: 50 additions & 0 deletions tests/test_generator_declarative.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from textwrap import dedent

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import PrimaryKeyConstraint
Expand Down Expand Up @@ -30,6 +32,20 @@ def generator(
return DeclarativeGenerator(metadata, engine, options)


@pytest.fixture
def generator_dynamic_schema(
request: FixtureRequest, metadata: MetaData, engine: Engine
) -> CodeGenerator:
schema_import_path, schema_value = getattr(request, "param", (None, None))
return DeclarativeGenerator(
metadata,
engine,
[],
dynamic_schema_import_path=schema_import_path,
dynamic_schema_value=schema_value,
)


def test_indexes(generator: CodeGenerator) -> None:
simple_items = Table(
"simple_items",
Expand Down Expand Up @@ -1509,3 +1525,37 @@ class Simple(Base):
server_default=text("'test'"))
""",
)


@pytest.mark.parametrize(
"generator_dynamic_schema",
[[".conftest.schema_obj", "schema_obj.name"]],
indirect=True,
)
def test_use_dynamic_schema(generator_dynamic_schema: CodeGenerator) -> None:
Table(
"simple_items",
generator_dynamic_schema.metadata,
Column("id", INTEGER, primary_key=True),
)

expected_code = """\
from .conftest import schema_obj
from sqlalchemy import Integer
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
pass


class SimpleItems(Base):
__tablename__ = 'simple_items'
__table_args__ = {'schema': schema_obj.name}

id: Mapped[int] = mapped_column(Integer, primary_key=True)
"""
generated_code = generator_dynamic_schema.generate()
expected_code = dedent(expected_code)
assert generated_code == expected_code
# TODO: code execution fails with KeyError: "'__name__' not in globals", any idea?
# validate_code(generated_code, expected_code)