Skip to content

Commit

Permalink
Merge branch 'fix-380-rename-multi-fields' of github.com:waketzheng/a…
Browse files Browse the repository at this point in the history
…erich into fix-380-rename-multi-fields
  • Loading branch information
waketzheng committed Dec 23, 2024
2 parents 5d460be + 19adfe8 commit 6f8b524
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 68 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### [0.8.1](Unreleased)

#### Fixed
- fix: add o2o field does not create constraint when migrating. (#396)
- Migration with duplicate renaming of columns in some cases. (#395)
- fix: intermediate table for m2m relation not created. (#394)
- Migrate add m2m field with custom through generate duplicated table. (#393)
Expand All @@ -16,6 +17,7 @@
- Fix configuration file reading error when containing Chinese characters. (#286)
- sqlite: failed to create/drop index. (#302)
- PostgreSQL: Cannot drop constraint after deleting or rename FK on a model. (#378)
- Fix create/drop indexes in every migration. (#377)
- Sort m2m fields before comparing them with diff. (#271)

#### Changed
Expand Down
180 changes: 113 additions & 67 deletions aerich/migrate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import hashlib
from __future__ import annotations

import importlib
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast

import asyncclick as click
import tortoise
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, ConfigurationError, Model, Tortoise
from tortoise.exceptions import OperationalError
Expand Down Expand Up @@ -201,21 +203,25 @@ def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:

@classmethod
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
ret: list = []

def index_hash(self) -> str:
h = hashlib.new("MD5", usedforsecurity=False) # type:ignore[call-arg]
h.update(
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
)
return h.hexdigest()

for index in indexes:
if isinstance(index, Index):
index.__hash__ = index_hash # type:ignore[method-assign,assignment]
ret.append(index)
return ret
if tortoise.__version__ > "0.22.2":
# The min version of tortoise is '0.11.0', so we can compare it by a `>`,
# tortoise>0.22.2 have __eq__/__hash__ with Index class since 313ee76.
return indexes
if index_classes := set(index.__class__ for index in indexes if isinstance(index, Index)):
# Leave magic patch here to compare with older version of tortoise-orm
# TODO: limit tortoise>0.22.2 in pyproject.toml and remove this function when v0.9.0 released
for index_cls in index_classes:
if index_cls(fields=("id",)) != index_cls(fields=("id",)):

def _hash(self) -> int:
return hash((tuple(sorted(self.fields)), self.name, self.expressions))

def _eq(self, other) -> bool:
return type(self) is type(other) and self.__dict__ == other.__dict__

setattr(index_cls, "__hash__", _hash)
setattr(index_cls, "__eq__", _eq)
return indexes

@classmethod
def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
Expand Down Expand Up @@ -282,6 +288,68 @@ def _handle_m2m_fields(
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, True)

@classmethod
def _handle_relational(
cls,
key: str,
old_model_describe: Dict,
new_model_describe: Dict,
model: Type[Model],
old_models: Dict,
new_models: Dict,
upgrade=True,
) -> None:
old_fk_fields = cast(List[dict], old_model_describe.get(key))
new_fk_fields = cast(List[dict], new_model_describe.get(key))

old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]

# add
for new_fk_field_name in set(new_fk_fields_name).difference(set(old_fk_fields_name)):
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
if fk_field.get("db_constraint"):
ref_describe = cast(dict, new_models[fk_field["python_type"]])
sql = cls._add_fk(model, fk_field, ref_describe)
cls._add_operator(sql, upgrade, fk_m2m_index=True)
# drop
for old_fk_field_name in set(old_fk_fields_name).difference(set(new_fk_fields_name)):
old_fk_field = cls.get_field_by_name(old_fk_field_name, cast(List[dict], old_fk_fields))
if old_fk_field.get("db_constraint"):
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
sql = cls._drop_fk(model, old_fk_field, ref_describe)
cls._add_operator(sql, upgrade, fk_m2m_index=True)

@classmethod
def _handle_fk_fields(
cls,
old_model_describe: Dict,
new_model_describe: Dict,
model: Type[Model],
old_models: Dict,
new_models: Dict,
upgrade=True,
) -> None:
key = "fk_fields"
cls._handle_relational(
key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade
)

@classmethod
def _handle_o2o_fields(
cls,
old_model_describe: Dict,
new_model_describe: Dict,
model: Type[Model],
old_models: Dict,
new_models: Dict,
upgrade=True,
) -> None:
key = "o2o_fields"
cls._handle_relational(
key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade
)

@classmethod
def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
Expand All @@ -296,7 +364,7 @@ def diff_models(
_aerich = f"{cls.app}.{cls._aerich}"
old_models.pop(_aerich, None)
new_models.pop(_aerich, None)
models_with_rename_field: Set[str] = set()
models_with_rename_field: Set[str] = set() # models that trigger the click.prompt

for new_model_str, new_model_describe in new_models.items():
model = cls._get_model(new_model_describe["name"].split(".")[1])
Expand Down Expand Up @@ -336,6 +404,13 @@ def diff_models(
# current only support rename pk
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# fk fields
args = (old_model_describe, new_model_describe, model, old_models, new_models)
cls._handle_fk_fields(*args, upgrade=upgrade)
# o2o fields
cls._handle_o2o_fields(*args, upgrade=upgrade)
old_o2o_columns = [i["raw_field"] for i in old_model_describe.get("o2o_fields", [])]
new_o2o_columns = [i["raw_field"] for i in new_model_describe.get("o2o_fields", [])]
# m2m fields
cls._handle_m2m_fields(
old_model_describe, new_model_describe, model, new_models, upgrade
Expand Down Expand Up @@ -369,12 +444,10 @@ def diff_models(
new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields])

# add fields or rename fields
rename_fields: Dict[str, str] = {}
for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name)
):
new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields)
model_rename_fields = cls._rename_fields.get(new_model_str)
is_rename = False
field_type = new_data_field.get("field_type")
db_column = new_data_field.get("db_column")
Expand All @@ -399,8 +472,11 @@ def diff_models(
and old_data_field_name not in new_data_fields_name
):
if upgrade:
if old_data_field_name in rename_fields or (
new_data_field_name in rename_fields.values()
if (
rename_fields := cls._rename_fields.get(new_model_str)
) and (
old_data_field_name in rename_fields
or new_data_field_name in rename_fields.values()
):
continue
prefix = f"({new_model_str}) "
Expand All @@ -417,22 +493,18 @@ def diff_models(
show_choices=True,
)
if is_rename:
if rename_fields is None:
rename_fields = cls._rename_fields[new_model_str] = {}
rename_fields[old_data_field_name] = new_data_field_name
else:
is_rename = False
if model_rename_fields and (
rename_to := model_rename_fields.get(new_data_field_name)
if rename_to := cls._rename_fields.get(new_model_str, {}).get(
new_data_field_name
):
is_rename = True
if rename_to != old_data_field_name:
continue
if is_rename:
if upgrade:
if new_model_str not in cls._rename_fields:
cls._rename_fields[new_model_str] = {}
cls._rename_fields[new_model_str][
old_data_field_name
] = new_data_field_name
# only MySQL8+ has rename syntax
if (
cls.dialect == "mysql"
Expand All @@ -452,7 +524,10 @@ def diff_models(
)
if not is_rename:
cls._add_operator(cls._add_field(model, new_data_field), upgrade)
if new_data_field["indexed"]:
if (
new_data_field["indexed"]
and new_data_field["db_column"] not in new_o2o_columns
):
cls._add_operator(
cls._add_index(
model, (new_data_field["db_column"],), new_data_field["unique"]
Expand All @@ -461,14 +536,14 @@ def diff_models(
True,
)
# remove fields
model_rename_fields = cls._rename_fields.get(new_model_str)
rename_fields = cls._rename_fields.get(new_model_str)
for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name)
):
# don't remove field if is renamed
if model_rename_fields and (
(upgrade and old_data_field_name in model_rename_fields)
or (not upgrade and old_data_field_name in model_rename_fields.values())
if rename_fields and (
(upgrade and old_data_field_name in rename_fields)
or (not upgrade and old_data_field_name in rename_fields.values())
):
continue
old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields)
Expand All @@ -477,46 +552,17 @@ def diff_models(
cls._remove_field(model, db_column),
upgrade,
)
if old_data_field["indexed"]:
if (
old_data_field["indexed"]
and old_data_field["db_column"] not in old_o2o_columns
):
is_unique_field = old_data_field.get("unique")
cls._add_operator(
cls._drop_index(model, {db_column}, is_unique_field),
upgrade,
True,
)

old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields"))
new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields"))

old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]

# add fk
for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name)
):
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
if fk_field.get("db_constraint"):
ref_describe = cast(dict, new_models[fk_field["python_type"]])
cls._add_operator(
cls._add_fk(model, fk_field, ref_describe),
upgrade,
fk_m2m_index=True,
)
# drop fk
for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name)
):
old_fk_field = cls.get_field_by_name(
old_fk_field_name, cast(List[dict], old_fk_fields)
)
if old_fk_field.get("db_constraint"):
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
cls._add_operator(
cls._drop_fk(model, old_fk_field, ref_describe),
upgrade,
fk_m2m_index=True,
)
# change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = cls.get_field_by_name(field_name, old_data_fields)
Expand Down
7 changes: 7 additions & 0 deletions tests/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from tortoise.indexes import Index


class CustomIndex(Index):
def __init__(self, *args, **kw) -> None:
super().__init__(*args, **kw)
self._foo = ""
10 changes: 10 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from enum import IntEnum

from tortoise import Model, fields
from tortoise.indexes import Index

from tests.indexes import CustomIndex


class ProductType(IntEnum):
Expand Down Expand Up @@ -33,13 +36,18 @@ class User(Model):

products: fields.ManyToManyRelation["Product"]

class Meta:
# reverse indexes elements
indexes = [CustomIndex(fields=("is_superuser",)), Index(fields=("username", "is_active"))]


class Email(Model):
email_id = fields.IntField(primary_key=True)
email = fields.CharField(max_length=200, db_index=True)
is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200)
users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")
config: fields.OneToOneRelation["Config"] = fields.OneToOneField("models.Config")


def default_name():
Expand Down Expand Up @@ -92,6 +100,8 @@ class Config(Model):
"models.User", description="User"
)

email: fields.OneToOneRelation["Email"]


class NewModel(Model):
name = fields.CharField(max_length=50)
6 changes: 6 additions & 0 deletions tests/old_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from enum import IntEnum

from tortoise import Model, fields
from tortoise.indexes import Index

from tests.indexes import CustomIndex


class ProductType(IntEnum):
Expand Down Expand Up @@ -31,6 +34,9 @@ class User(Model):
intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=12, decimal_places=9)

class Meta:
indexes = [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))]


class Email(Model):
email = fields.CharField(max_length=200)
Expand Down
Loading

0 comments on commit 6f8b524

Please sign in to comment.