Skip to content

Commit

Permalink
fix: add o2o field does not create constraint when migrating
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng committed Dec 18, 2024
1 parent 69ce0ca commit fe58032
Showing 1 changed file with 48 additions and 34 deletions.
82 changes: 48 additions & 34 deletions aerich/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,39 @@ def _handle_m2m_fields(
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, True)

@classmethod
def _handle_fk_fields(
cls,
old_model_describe: Dict,
new_model_describe: Dict,
model,
old_models,
new_models,
upgrade=True,
is_o2o=False,
) -> None:
key = "o2o_fields" if is_o2o else "fk_fields"
old_fk_fields = cast(List[dict], old_model_describe.get(key))
new_fk_fields = cast(List[dict], new_model_describe.get(key))

old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]

# add
for new_fk_field_name in set(new_fk_fields_name).difference(set(old_fk_fields_name)):
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
if fk_field.get("db_constraint"):
ref_describe = cast(dict, new_models[fk_field["python_type"]])
sql = cls._add_fk(model, fk_field, ref_describe)
cls._add_operator(sql, upgrade, fk_m2m_index=True)
# drop
for old_fk_field_name in set(old_fk_fields_name).difference(set(new_fk_fields_name)):
old_fk_field = cls.get_field_by_name(old_fk_field_name, cast(List[dict], old_fk_fields))
if old_fk_field.get("db_constraint"):
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
sql = cls._drop_fk(model, old_fk_field, ref_describe)
cls._add_operator(sql, upgrade, fk_m2m_index=True)

@classmethod
def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
Expand Down Expand Up @@ -334,6 +367,13 @@ def diff_models(
# current only support rename pk
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# fk fields
args = (old_model_describe, new_model_describe, model, old_models, new_models)
cls._handle_fk_fields(*args, upgrade=upgrade)
# o2o fields
cls._handle_fk_fields(*args, upgrade=upgrade, is_o2o=True)
old_o2o_columns = [i["raw_field"] for i in old_model_describe.get("o2o_fields", [])]
new_o2o_columns = [i["raw_field"] for i in new_model_describe.get("o2o_fields", [])]
# m2m fields
cls._handle_m2m_fields(
old_model_describe, new_model_describe, model, new_models, upgrade
Expand Down Expand Up @@ -424,7 +464,10 @@ def diff_models(
),
upgrade,
)
if new_data_field["indexed"]:
if (
new_data_field["indexed"]
and new_data_field["db_column"] not in new_o2o_columns
):
cls._add_operator(
cls._add_index(
model, (new_data_field["db_column"],), new_data_field["unique"]
Expand All @@ -447,46 +490,17 @@ def diff_models(
cls._remove_field(model, db_column),
upgrade,
)
if old_data_field["indexed"]:
if (
old_data_field["indexed"]
and old_data_field["db_column"] not in old_o2o_columns
):
is_unique_field = old_data_field.get("unique")
cls._add_operator(
cls._drop_index(model, {db_column}, is_unique_field),
upgrade,
True,
)

old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields"))
new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields"))

old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]

# add fk
for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name)
):
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
if fk_field.get("db_constraint"):
ref_describe = cast(dict, new_models[fk_field["python_type"]])
cls._add_operator(
cls._add_fk(model, fk_field, ref_describe),
upgrade,
fk_m2m_index=True,
)
# drop fk
for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name)
):
old_fk_field = cls.get_field_by_name(
old_fk_field_name, cast(List[dict], old_fk_fields)
)
if old_fk_field.get("db_constraint"):
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
cls._add_operator(
cls._drop_fk(model, old_fk_field, ref_describe),
upgrade,
fk_m2m_index=True,
)
# change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = cls.get_field_by_name(field_name, old_data_fields)
Expand Down

0 comments on commit fe58032

Please sign in to comment.