diff --git a/CHANGELOG.md b/CHANGELOG.md index b134a53..1923429 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### [0.8.1](Unreleased) #### Fixed +- fix: add o2o field does not create constraint when migrating. (#396) - Migration with duplicate renaming of columns in some cases. (#395) - fix: intermediate table for m2m relation not created. (#394) - Migrate add m2m field with custom through generate duplicated table. (#393) @@ -16,6 +17,7 @@ - Fix configuration file reading error when containing Chinese characters. (#286) - sqlite: failed to create/drop index. (#302) - PostgreSQL: Cannot drop constraint after deleting or rename FK on a model. (#378) +- Fix create/drop indexes in every migration. (#377) - Sort m2m fields before comparing them with diff. (#271) #### Changed diff --git a/aerich/migrate.py b/aerich/migrate.py index 7abb4a9..856018f 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -1,4 +1,5 @@ -import hashlib +from __future__ import annotations + import importlib import os from datetime import datetime @@ -6,6 +7,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast import asyncclick as click +import tortoise from dictdiffer import diff from tortoise import BaseDBAsyncClient, ConfigurationError, Model, Tortoise from tortoise.exceptions import OperationalError @@ -201,21 +203,25 @@ def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None: @classmethod def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list: - ret: list = [] - - def index_hash(self) -> str: - h = hashlib.new("MD5", usedforsecurity=False) # type:ignore[call-arg] - h.update( - self.index_name(cls.ddl.schema_generator, model).encode() - + self.__class__.__name__.encode() - ) - return h.hexdigest() - - for index in indexes: - if isinstance(index, Index): - index.__hash__ = index_hash # type:ignore[method-assign,assignment] - ret.append(index) - return ret + if tortoise.__version__ > "0.22.2": + # The min version of tortoise is '0.11.0', so we can compare it by a `>`, + # tortoise>0.22.2 have __eq__/__hash__ with Index class since 313ee76. + return indexes + if index_classes := set(index.__class__ for index in indexes if isinstance(index, Index)): + # Leave magic patch here to compare with older version of tortoise-orm + # TODO: limit tortoise>0.22.2 in pyproject.toml and remove this function when v0.9.0 released + for index_cls in index_classes: + if index_cls(fields=("id",)) != index_cls(fields=("id",)): + + def _hash(self) -> int: + return hash((tuple(sorted(self.fields)), self.name, self.expressions)) + + def _eq(self, other) -> bool: + return type(self) is type(other) and self.__dict__ == other.__dict__ + + setattr(index_cls, "__hash__", _hash) + setattr(index_cls, "__eq__", _eq) + return indexes @classmethod def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]: @@ -282,6 +288,68 @@ def _handle_m2m_fields( if add: cls._add_operator(cls.drop_m2m(table), upgrade, True) + @classmethod + def _handle_relational( + cls, + key: str, + old_model_describe: Dict, + new_model_describe: Dict, + model: Type[Model], + old_models: Dict, + new_models: Dict, + upgrade=True, + ) -> None: + old_fk_fields = cast(List[dict], old_model_describe.get(key)) + new_fk_fields = cast(List[dict], new_model_describe.get(key)) + + old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields] + new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields] + + # add + for new_fk_field_name in set(new_fk_fields_name).difference(set(old_fk_fields_name)): + fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields) + if fk_field.get("db_constraint"): + ref_describe = cast(dict, new_models[fk_field["python_type"]]) + sql = cls._add_fk(model, fk_field, ref_describe) + cls._add_operator(sql, upgrade, fk_m2m_index=True) + # drop + for old_fk_field_name in set(old_fk_fields_name).difference(set(new_fk_fields_name)): + old_fk_field = cls.get_field_by_name(old_fk_field_name, cast(List[dict], old_fk_fields)) + if old_fk_field.get("db_constraint"): + ref_describe = cast(dict, old_models[old_fk_field["python_type"]]) + sql = cls._drop_fk(model, old_fk_field, ref_describe) + cls._add_operator(sql, upgrade, fk_m2m_index=True) + + @classmethod + def _handle_fk_fields( + cls, + old_model_describe: Dict, + new_model_describe: Dict, + model: Type[Model], + old_models: Dict, + new_models: Dict, + upgrade=True, + ) -> None: + key = "fk_fields" + cls._handle_relational( + key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade + ) + + @classmethod + def _handle_o2o_fields( + cls, + old_model_describe: Dict, + new_model_describe: Dict, + model: Type[Model], + old_models: Dict, + new_models: Dict, + upgrade=True, + ) -> None: + key = "o2o_fields" + cls._handle_relational( + key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade + ) + @classmethod def diff_models( cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True @@ -296,7 +364,7 @@ def diff_models( _aerich = f"{cls.app}.{cls._aerich}" old_models.pop(_aerich, None) new_models.pop(_aerich, None) - models_with_rename_field: Set[str] = set() + models_with_rename_field: Set[str] = set() # models that trigger the click.prompt for new_model_str, new_model_describe in new_models.items(): model = cls._get_model(new_model_describe["name"].split(".")[1]) @@ -336,6 +404,13 @@ def diff_models( # current only support rename pk if action == "change" and option == "name": cls._add_operator(cls._rename_field(model, *change), upgrade) + # fk fields + args = (old_model_describe, new_model_describe, model, old_models, new_models) + cls._handle_fk_fields(*args, upgrade=upgrade) + # o2o fields + cls._handle_o2o_fields(*args, upgrade=upgrade) + old_o2o_columns = [i["raw_field"] for i in old_model_describe.get("o2o_fields", [])] + new_o2o_columns = [i["raw_field"] for i in new_model_describe.get("o2o_fields", [])] # m2m fields cls._handle_m2m_fields( old_model_describe, new_model_describe, model, new_models, upgrade @@ -369,12 +444,10 @@ def diff_models( new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields]) # add fields or rename fields - rename_fields: Dict[str, str] = {} for new_data_field_name in set(new_data_fields_name).difference( set(old_data_fields_name) ): new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields) - model_rename_fields = cls._rename_fields.get(new_model_str) is_rename = False field_type = new_data_field.get("field_type") db_column = new_data_field.get("db_column") @@ -399,8 +472,11 @@ def diff_models( and old_data_field_name not in new_data_fields_name ): if upgrade: - if old_data_field_name in rename_fields or ( - new_data_field_name in rename_fields.values() + if ( + rename_fields := cls._rename_fields.get(new_model_str) + ) and ( + old_data_field_name in rename_fields + or new_data_field_name in rename_fields.values() ): continue prefix = f"({new_model_str}) " @@ -417,22 +493,18 @@ def diff_models( show_choices=True, ) if is_rename: + if rename_fields is None: + rename_fields = cls._rename_fields[new_model_str] = {} rename_fields[old_data_field_name] = new_data_field_name else: is_rename = False - if model_rename_fields and ( - rename_to := model_rename_fields.get(new_data_field_name) + if rename_to := cls._rename_fields.get(new_model_str, {}).get( + new_data_field_name ): is_rename = True if rename_to != old_data_field_name: continue if is_rename: - if upgrade: - if new_model_str not in cls._rename_fields: - cls._rename_fields[new_model_str] = {} - cls._rename_fields[new_model_str][ - old_data_field_name - ] = new_data_field_name # only MySQL8+ has rename syntax if ( cls.dialect == "mysql" @@ -452,7 +524,10 @@ def diff_models( ) if not is_rename: cls._add_operator(cls._add_field(model, new_data_field), upgrade) - if new_data_field["indexed"]: + if ( + new_data_field["indexed"] + and new_data_field["db_column"] not in new_o2o_columns + ): cls._add_operator( cls._add_index( model, (new_data_field["db_column"],), new_data_field["unique"] @@ -461,14 +536,14 @@ def diff_models( True, ) # remove fields - model_rename_fields = cls._rename_fields.get(new_model_str) + rename_fields = cls._rename_fields.get(new_model_str) for old_data_field_name in set(old_data_fields_name).difference( set(new_data_fields_name) ): # don't remove field if is renamed - if model_rename_fields and ( - (upgrade and old_data_field_name in model_rename_fields) - or (not upgrade and old_data_field_name in model_rename_fields.values()) + if rename_fields and ( + (upgrade and old_data_field_name in rename_fields) + or (not upgrade and old_data_field_name in rename_fields.values()) ): continue old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields) @@ -477,7 +552,10 @@ def diff_models( cls._remove_field(model, db_column), upgrade, ) - if old_data_field["indexed"]: + if ( + old_data_field["indexed"] + and old_data_field["db_column"] not in old_o2o_columns + ): is_unique_field = old_data_field.get("unique") cls._add_operator( cls._drop_index(model, {db_column}, is_unique_field), @@ -485,38 +563,6 @@ def diff_models( True, ) - old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields")) - new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields")) - - old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields] - new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields] - - # add fk - for new_fk_field_name in set(new_fk_fields_name).difference( - set(old_fk_fields_name) - ): - fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields) - if fk_field.get("db_constraint"): - ref_describe = cast(dict, new_models[fk_field["python_type"]]) - cls._add_operator( - cls._add_fk(model, fk_field, ref_describe), - upgrade, - fk_m2m_index=True, - ) - # drop fk - for old_fk_field_name in set(old_fk_fields_name).difference( - set(new_fk_fields_name) - ): - old_fk_field = cls.get_field_by_name( - old_fk_field_name, cast(List[dict], old_fk_fields) - ) - if old_fk_field.get("db_constraint"): - ref_describe = cast(dict, old_models[old_fk_field["python_type"]]) - cls._add_operator( - cls._drop_fk(model, old_fk_field, ref_describe), - upgrade, - fk_m2m_index=True, - ) # change fields for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)): old_data_field = cls.get_field_by_name(field_name, old_data_fields) diff --git a/tests/indexes.py b/tests/indexes.py new file mode 100644 index 0000000..83959f8 --- /dev/null +++ b/tests/indexes.py @@ -0,0 +1,7 @@ +from tortoise.indexes import Index + + +class CustomIndex(Index): + def __init__(self, *args, **kw) -> None: + super().__init__(*args, **kw) + self._foo = "" diff --git a/tests/models.py b/tests/models.py index 959ba8f..527af12 100644 --- a/tests/models.py +++ b/tests/models.py @@ -3,6 +3,9 @@ from enum import IntEnum from tortoise import Model, fields +from tortoise.indexes import Index + +from tests.indexes import CustomIndex class ProductType(IntEnum): @@ -33,6 +36,10 @@ class User(Model): products: fields.ManyToManyRelation["Product"] + class Meta: + # reverse indexes elements + indexes = [CustomIndex(fields=("is_superuser",)), Index(fields=("username", "is_active"))] + class Email(Model): email_id = fields.IntField(primary_key=True) @@ -40,6 +47,7 @@ class Email(Model): is_primary = fields.BooleanField(default=False) address = fields.CharField(max_length=200) users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User") + config: fields.OneToOneRelation["Config"] = fields.OneToOneField("models.Config") def default_name(): @@ -92,6 +100,8 @@ class Config(Model): "models.User", description="User" ) + email: fields.OneToOneRelation["Email"] + class NewModel(Model): name = fields.CharField(max_length=50) diff --git a/tests/old_models.py b/tests/old_models.py index faffd41..92eb7f8 100644 --- a/tests/old_models.py +++ b/tests/old_models.py @@ -2,6 +2,9 @@ from enum import IntEnum from tortoise import Model, fields +from tortoise.indexes import Index + +from tests.indexes import CustomIndex class ProductType(IntEnum): @@ -31,6 +34,9 @@ class User(Model): intro = fields.TextField(default="") longitude = fields.DecimalField(max_digits=12, decimal_places=9) + class Meta: + indexes = [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))] + class Email(Model): email = fields.CharField(max_length=200) diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 30d9d11..9e85a2e 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -3,6 +3,7 @@ import pytest import tortoise from pytest_mock import MockerFixture +from tortoise.indexes import Index from aerich.ddl.mysql import MysqlDDL from aerich.ddl.postgres import PostgresDDL @@ -10,6 +11,7 @@ from aerich.exceptions import NotSupportError from aerich.migrate import MIGRATE_TEMPLATE, Migrate from aerich.utils import get_models_describe +from tests.indexes import CustomIndex # tortoise-orm>=0.21 changes IntField constraints # from {"ge": 1, "le": 2147483647} to {"ge": -2147483648, "le": 2147483647} @@ -638,7 +640,7 @@ "description": None, "docstring": None, "unique_together": [], - "indexes": [], + "indexes": [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))], "pk_field": { "name": "id", "field_type": "IntField", @@ -915,6 +917,7 @@ def test_migrate(mocker: MockerFixture): - drop field: User.avatar - add index: Email.email - add many to many: Email.users + - add one to one: Email.config - remove unique: Category.title - add unique: User.username - change column: length User.password @@ -957,6 +960,8 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT", "ALTER TABLE `config` MODIFY COLUMN `value` JSON NOT NULL", "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL", + "ALTER TABLE `email` ADD CONSTRAINT `fk_email_config_76a9dc71` FOREIGN KEY (`config_id`) REFERENCES `config` (`id`) ON DELETE CASCADE", + "ALTER TABLE `email` ADD `config_id` INT NOT NULL UNIQUE", "ALTER TABLE `configs` RENAME TO `config`", "ALTER TABLE `product` DROP COLUMN `uuid`", "ALTER TABLE `product` DROP INDEX `uuid`", @@ -996,6 +1001,8 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `email` ADD `user_id` INT NOT NULL", "ALTER TABLE `config` DROP COLUMN `user_id`", "ALTER TABLE `email` DROP COLUMN `address`", + "ALTER TABLE `email` DROP COLUMN `config_id`", + "ALTER TABLE `email` DROP FOREIGN KEY `fk_email_config_76a9dc71`", "ALTER TABLE `config` RENAME TO `configs`", "ALTER TABLE `product` RENAME COLUMN `pic` TO `image`", "ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`", @@ -1047,6 +1054,8 @@ def test_migrate(mocker: MockerFixture): 'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL', 'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"', 'ALTER TABLE "email" DROP COLUMN "user_id"', + 'ALTER TABLE "email" ADD CONSTRAINT "fk_email_config_76a9dc71" FOREIGN KEY ("config_id") REFERENCES "config" ("id") ON DELETE CASCADE', + 'ALTER TABLE "email" ADD "config_id" INT NOT NULL UNIQUE', 'DROP INDEX IF EXISTS "uid_product_uuid_d33c18"', 'ALTER TABLE "product" DROP COLUMN "uuid"', 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0', @@ -1087,6 +1096,8 @@ def test_migrate(mocker: MockerFixture): 'ALTER TABLE "email" ADD "user_id" INT NOT NULL', 'ALTER TABLE "email" DROP COLUMN "address"', 'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"', + 'ALTER TABLE "email" DROP COLUMN "config_id"', + 'ALTER TABLE "email" DROP CONSTRAINT IF EXISTS "fk_email_config_76a9dc71"', 'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE', 'CREATE UNIQUE INDEX "uid_product_uuid_d33c18" ON "product" ("uuid")', 'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',