Skip to content

Commit

Permalink
fix conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed Dec 21, 2024
2 parents 99e7258 + 7d22518 commit 19adfe8
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 55 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### [0.8.1](Unreleased)

#### Fixed
- fix: add o2o field does not create constraint when migrating. (#396)
- Migration with duplicate renaming of columns in some cases. (#395)
- fix: intermediate table for m2m relation not created. (#394)
- Migrate add m2m field with custom through generate duplicated table. (#393)
Expand All @@ -16,6 +17,7 @@
- Fix configuration file reading error when containing Chinese characters. (#286)
- sqlite: failed to create/drop index. (#302)
- PostgreSQL: Cannot drop constraint after deleting or rename FK on a model. (#378)
- Fix create/drop indexes in every migration. (#377)
- Sort m2m fields before comparing them with diff. (#271)

#### Changed
Expand Down
157 changes: 103 additions & 54 deletions aerich/migrate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import hashlib
from __future__ import annotations

import importlib
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast

import asyncclick as click
import tortoise
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Model, Tortoise
from tortoise.exceptions import OperationalError
Expand Down Expand Up @@ -201,21 +203,25 @@ def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:

@classmethod
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
ret: list = []

def index_hash(self) -> str:
h = hashlib.new("MD5", usedforsecurity=False) # type:ignore[call-arg]
h.update(
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
)
return h.hexdigest()

for index in indexes:
if isinstance(index, Index):
index.__hash__ = index_hash # type:ignore[method-assign,assignment]
ret.append(index)
return ret
if tortoise.__version__ > "0.22.2":
# The min version of tortoise is '0.11.0', so we can compare it by a `>`,
# tortoise>0.22.2 have __eq__/__hash__ with Index class since 313ee76.
return indexes
if index_classes := set(index.__class__ for index in indexes if isinstance(index, Index)):
# Leave magic patch here to compare with older version of tortoise-orm
# TODO: limit tortoise>0.22.2 in pyproject.toml and remove this function when v0.9.0 released
for index_cls in index_classes:
if index_cls(fields=("id",)) != index_cls(fields=("id",)):

def _hash(self) -> int:
return hash((tuple(sorted(self.fields)), self.name, self.expressions))

def _eq(self, other) -> bool:
return type(self) is type(other) and self.__dict__ == other.__dict__

setattr(index_cls, "__hash__", _hash)
setattr(index_cls, "__eq__", _eq)
return indexes

@classmethod
def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
Expand Down Expand Up @@ -280,6 +286,68 @@ def _handle_m2m_fields(
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, True)

@classmethod
def _handle_relational(
cls,
key: str,
old_model_describe: Dict,
new_model_describe: Dict,
model: Type[Model],
old_models: Dict,
new_models: Dict,
upgrade=True,
) -> None:
old_fk_fields = cast(List[dict], old_model_describe.get(key))
new_fk_fields = cast(List[dict], new_model_describe.get(key))

old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]

# add
for new_fk_field_name in set(new_fk_fields_name).difference(set(old_fk_fields_name)):
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
if fk_field.get("db_constraint"):
ref_describe = cast(dict, new_models[fk_field["python_type"]])
sql = cls._add_fk(model, fk_field, ref_describe)
cls._add_operator(sql, upgrade, fk_m2m_index=True)
# drop
for old_fk_field_name in set(old_fk_fields_name).difference(set(new_fk_fields_name)):
old_fk_field = cls.get_field_by_name(old_fk_field_name, cast(List[dict], old_fk_fields))
if old_fk_field.get("db_constraint"):
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
sql = cls._drop_fk(model, old_fk_field, ref_describe)
cls._add_operator(sql, upgrade, fk_m2m_index=True)

@classmethod
def _handle_fk_fields(
cls,
old_model_describe: Dict,
new_model_describe: Dict,
model: Type[Model],
old_models: Dict,
new_models: Dict,
upgrade=True,
) -> None:
key = "fk_fields"
cls._handle_relational(
key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade
)

@classmethod
def _handle_o2o_fields(
cls,
old_model_describe: Dict,
new_model_describe: Dict,
model: Type[Model],
old_models: Dict,
new_models: Dict,
upgrade=True,
) -> None:
key = "o2o_fields"
cls._handle_relational(
key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade
)

@classmethod
def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
Expand Down Expand Up @@ -334,6 +402,13 @@ def diff_models(
# current only support rename pk
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# fk fields
args = (old_model_describe, new_model_describe, model, old_models, new_models)
cls._handle_fk_fields(*args, upgrade=upgrade)
# o2o fields
cls._handle_o2o_fields(*args, upgrade=upgrade)
old_o2o_columns = [i["raw_field"] for i in old_model_describe.get("o2o_fields", [])]
new_o2o_columns = [i["raw_field"] for i in new_model_describe.get("o2o_fields", [])]
# m2m fields
cls._handle_m2m_fields(
old_model_describe, new_model_describe, model, new_models, upgrade
Expand Down Expand Up @@ -447,7 +522,10 @@ def diff_models(
)
if not is_rename:
cls._add_operator(cls._add_field(model, new_data_field), upgrade)
if new_data_field["indexed"]:
if (
new_data_field["indexed"]
and new_data_field["db_column"] not in new_o2o_columns
):
cls._add_operator(
cls._add_index(
model, (new_data_field["db_column"],), new_data_field["unique"]
Expand All @@ -456,14 +534,14 @@ def diff_models(
True,
)
# remove fields
model_rename_fields = cls._rename_fields.get(new_model_str)
rename_fields = cls._rename_fields.get(new_model_str)
for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name)
):
# don't remove field if is renamed
if model_rename_fields and (
(upgrade and old_data_field_name in model_rename_fields)
or (not upgrade and old_data_field_name in model_rename_fields.values())
if rename_fields and (
(upgrade and old_data_field_name in rename_fields)
or (not upgrade and old_data_field_name in rename_fields.values())
):
continue
old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields)
Expand All @@ -472,46 +550,17 @@ def diff_models(
cls._remove_field(model, db_column),
upgrade,
)
if old_data_field["indexed"]:
if (
old_data_field["indexed"]
and old_data_field["db_column"] not in old_o2o_columns
):
is_unique_field = old_data_field.get("unique")
cls._add_operator(
cls._drop_index(model, {db_column}, is_unique_field),
upgrade,
True,
)

old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields"))
new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields"))

old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]

# add fk
for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name)
):
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
if fk_field.get("db_constraint"):
ref_describe = cast(dict, new_models[fk_field["python_type"]])
cls._add_operator(
cls._add_fk(model, fk_field, ref_describe),
upgrade,
fk_m2m_index=True,
)
# drop fk
for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name)
):
old_fk_field = cls.get_field_by_name(
old_fk_field_name, cast(List[dict], old_fk_fields)
)
if old_fk_field.get("db_constraint"):
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
cls._add_operator(
cls._drop_fk(model, old_fk_field, ref_describe),
upgrade,
fk_m2m_index=True,
)
# change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = cls.get_field_by_name(field_name, old_data_fields)
Expand Down
7 changes: 7 additions & 0 deletions tests/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from tortoise.indexes import Index


class CustomIndex(Index):
def __init__(self, *args, **kw) -> None:
super().__init__(*args, **kw)
self._foo = ""
10 changes: 10 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from enum import IntEnum

from tortoise import Model, fields
from tortoise.indexes import Index

from tests.indexes import CustomIndex


class ProductType(IntEnum):
Expand Down Expand Up @@ -33,13 +36,18 @@ class User(Model):

products: fields.ManyToManyRelation["Product"]

class Meta:
# reverse indexes elements
indexes = [CustomIndex(fields=("is_superuser",)), Index(fields=("username", "is_active"))]


class Email(Model):
email_id = fields.IntField(primary_key=True)
email = fields.CharField(max_length=200, db_index=True)
is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200)
users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")
config: fields.OneToOneRelation["Config"] = fields.OneToOneField("models.Config")


def default_name():
Expand Down Expand Up @@ -92,6 +100,8 @@ class Config(Model):
"models.User", description="User"
)

email: fields.OneToOneRelation["Email"]


class NewModel(Model):
name = fields.CharField(max_length=50)
6 changes: 6 additions & 0 deletions tests/old_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from enum import IntEnum

from tortoise import Model, fields
from tortoise.indexes import Index

from tests.indexes import CustomIndex


class ProductType(IntEnum):
Expand Down Expand Up @@ -31,6 +34,9 @@ class User(Model):
intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=12, decimal_places=9)

class Meta:
indexes = [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))]


class Email(Model):
email = fields.CharField(max_length=200)
Expand Down
13 changes: 12 additions & 1 deletion tests/test_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import pytest
import tortoise
from pytest_mock import MockerFixture
from tortoise.indexes import Index

from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError
from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.utils import get_models_describe
from tests.indexes import CustomIndex

# tortoise-orm>=0.21 changes IntField constraints
# from {"ge": 1, "le": 2147483647} to {"ge": -2147483648, "le": 2147483647}
Expand Down Expand Up @@ -638,7 +640,7 @@
"description": None,
"docstring": None,
"unique_together": [],
"indexes": [],
"indexes": [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))],
"pk_field": {
"name": "id",
"field_type": "IntField",
Expand Down Expand Up @@ -915,6 +917,7 @@ def test_migrate(mocker: MockerFixture):
- drop field: User.avatar
- add index: Email.email
- add many to many: Email.users
- add one to one: Email.config
- remove unique: Category.title
- add unique: User.username
- change column: length User.password
Expand Down Expand Up @@ -957,6 +960,8 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
"ALTER TABLE `config` MODIFY COLUMN `value` JSON NOT NULL",
"ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_config_76a9dc71` FOREIGN KEY (`config_id`) REFERENCES `config` (`id`) ON DELETE CASCADE",
"ALTER TABLE `email` ADD `config_id` INT NOT NULL UNIQUE",
"ALTER TABLE `configs` RENAME TO `config`",
"ALTER TABLE `product` DROP COLUMN `uuid`",
"ALTER TABLE `product` DROP INDEX `uuid`",
Expand Down Expand Up @@ -996,6 +1001,8 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `email` DROP COLUMN `config_id`",
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_config_76a9dc71`",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
Expand Down Expand Up @@ -1047,6 +1054,8 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL',
'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"',
'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_config_76a9dc71" FOREIGN KEY ("config_id") REFERENCES "config" ("id") ON DELETE CASCADE',
'ALTER TABLE "email" ADD "config_id" INT NOT NULL UNIQUE',
'DROP INDEX IF EXISTS "uid_product_uuid_d33c18"',
'ALTER TABLE "product" DROP COLUMN "uuid"',
'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
Expand Down Expand Up @@ -1087,6 +1096,8 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "email" ADD "user_id" INT NOT NULL',
'ALTER TABLE "email" DROP COLUMN "address"',
'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
'ALTER TABLE "email" DROP COLUMN "config_id"',
'ALTER TABLE "email" DROP CONSTRAINT IF EXISTS "fk_email_config_76a9dc71"',
'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE',
'CREATE UNIQUE INDEX "uid_product_uuid_d33c18" ON "product" ("uuid")',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
Expand Down

0 comments on commit 19adfe8

Please sign in to comment.