diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 21eadb63..7ad902c0 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -128,9 +128,14 @@ def __init__( options: Sequence[str], *, indentation: str = " ", + dynamic_schema_import_path: str | None = None, + dynamic_schema_value: str | None = None, ): super().__init__(metadata, bind, options) self.indentation: str = indentation + # TODO add check if there is a "." in the value if set? + self.dynamic_schema_import_path: str | None = dynamic_schema_import_path + self.dynamic_schema_value: str | None = dynamic_schema_value self.imports: dict[str, set[str]] = defaultdict(set) self.module_imports: set[str] = set() @@ -197,6 +202,8 @@ def collect_imports(self, models: Iterable[Model]) -> None: for model in models: self.collect_imports_for_model(model) + if self.dynamic_schema_import_path: + self.add_literal_import(*self.dynamic_schema_import_path.rsplit(".", 1)) def collect_imports_for_model(self, model: Model) -> None: if model.__class__ is Model: @@ -374,7 +381,9 @@ def render_table(self, table: Table) -> str: if len(index.columns) > 1 or not uses_default_name(index): args.append(self.render_index(index)) - if table.schema: + if self.dynamic_schema_value: + kwargs["schema"] = self.dynamic_schema_value + elif table.schema: kwargs["schema"] = repr(table.schema) table_comment = getattr(table, "comment", None) @@ -722,9 +731,18 @@ def __init__( options: Sequence[str], *, indentation: str = " ", + dynamic_schema_import_path: str | None = None, + dynamic_schema_value: str | None = None, base_class_name: str = "Base", ): - super().__init__(metadata, bind, options, indentation=indentation) + super().__init__( + metadata, + bind, + options, + indentation=indentation, + dynamic_schema_import_path=dynamic_schema_import_path, + dynamic_schema_value=dynamic_schema_value, + ) self.base_class_name: str = base_class_name self.inflect_engine = inflect.engine() @@ -1159,14 +1177,23 @@ def render_table_args(self, table: Table) -> str: if len(index.columns) > 1 or not uses_default_name(index): args.append(self.render_index(index)) - if table.schema: + if self.dynamic_schema_value: + kwargs["schema"] = self.dynamic_schema_value + elif table.schema: kwargs["schema"] = table.schema if table.comment: kwargs["comment"] = table.comment if kwargs: - formatted_kwargs = pformat(kwargs) + # NB: using pformat on the dict turns schema value (python code) to a string + formatted_kwargs = f",\n{self.indentation}".join( + f"'{k}': {pformat(v)}" + if v != self.dynamic_schema_value + else f"'{k}': {v}" + for k, v in kwargs.items() + ) + formatted_kwargs = f"{{{formatted_kwargs}}}" if not args: return formatted_kwargs else: @@ -1309,6 +1336,8 @@ def __init__( options: Sequence[str], *, indentation: str = " ", + dynamic_schema_import_path: str | None = None, + dynamic_schema_value: str | None = None, base_class_name: str = "Base", quote_annotations: bool = False, metadata_key: str = "sa", @@ -1318,6 +1347,8 @@ def __init__( bind, options, indentation=indentation, + dynamic_schema_import_path=dynamic_schema_import_path, + dynamic_schema_value=dynamic_schema_value, base_class_name=base_class_name, ) self.metadata_key: str = metadata_key @@ -1348,6 +1379,8 @@ def __init__( options: Sequence[str], *, indentation: str = " ", + dynamic_schema_import_path: str | None = None, + dynamic_schema_value: str | None = None, base_class_name: str = "SQLModel", ): super().__init__( @@ -1355,6 +1388,8 @@ def __init__( bind, options, indentation=indentation, + dynamic_schema_import_path=dynamic_schema_import_path, + dynamic_schema_value=dynamic_schema_value, base_class_name=base_class_name, ) diff --git a/tests/conftest.py b/tests/conftest.py index 022e786c..ae3822d4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from textwrap import dedent import pytest @@ -31,3 +32,12 @@ def validate_code(generated_code: str, expected_code: str) -> None: configure_mappers() finally: clear_mappers() + + +@dataclass +class SchemaObject: + name: str + + +# NB: not a fixture on purpose +schema_obj = SchemaObject(name="best_schema") diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index d9bf7b53..7ce25c89 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -1,5 +1,7 @@ from __future__ import annotations +from textwrap import dedent + import pytest from _pytest.fixtures import FixtureRequest from sqlalchemy import PrimaryKeyConstraint @@ -30,6 +32,20 @@ def generator( return DeclarativeGenerator(metadata, engine, options) +@pytest.fixture +def generator_dynamic_schema( + request: FixtureRequest, metadata: MetaData, engine: Engine +) -> CodeGenerator: + schema_import_path, schema_value = getattr(request, "param", (None, None)) + return DeclarativeGenerator( + metadata, + engine, + [], + dynamic_schema_import_path=schema_import_path, + dynamic_schema_value=schema_value, + ) + + def test_indexes(generator: CodeGenerator) -> None: simple_items = Table( "simple_items", @@ -1509,3 +1525,37 @@ class Simple(Base): server_default=text("'test'")) """, ) + + +@pytest.mark.parametrize( + "generator_dynamic_schema", + [[".conftest.schema_obj", "schema_obj.name"]], + indirect=True, +) +def test_use_dynamic_schema(generator_dynamic_schema: CodeGenerator) -> None: + Table( + "simple_items", + generator_dynamic_schema.metadata, + Column("id", INTEGER, primary_key=True), + ) + + expected_code = """\ +from .conftest import schema_obj +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + __table_args__ = {'schema': schema_obj.name} + + id: Mapped[int] = mapped_column(Integer, primary_key=True) +""" + generated_code = generator_dynamic_schema.generate() + expected_code = dedent(expected_code) + assert generated_code == expected_code + # TODO: code execution fails with KeyError: "'__name__' not in globals", any idea? + # validate_code(generated_code, expected_code)