Skip to content

Commit

Permalink
Add Rename support
Browse files Browse the repository at this point in the history
  • Loading branch information
long2ice committed Sep 25, 2020
1 parent af4d4be commit 141d720
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 192 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

- Raise error with SQLite unsupported features.
- Fix Postgres alter table. (#48)
- Add `Rename` support.

### 0.2.3

Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ test_sqlite:
$(py_warn) TEST_DB=sqlite://:memory: py.test

test_mysql:
$(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" py.test
$(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -v -s

test_postgres:
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" py.test
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest

testall: deps test_sqlite test_postgres test_mysql

build: deps
@poetry build

ci: check testall
ci: check testall
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ Success migrate 1_202029051520102929_drop_column.json
Format of migrate filename is
`{version_num}_{datetime}_{name|update}.json`

And if `aerich` guess you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`, you can choice `True` to choice rename column without column drop, or choice `False` to drop column then create.

### Upgrade to latest version

```shell
Expand Down
27 changes: 1 addition & 26 deletions aerich/cli.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import functools
import json
import os
import sys
from configparser import ConfigParser
from enum import Enum

import asyncclick as click
from asyncclick import Context, UsageError
Expand All @@ -16,28 +14,12 @@
from aerich.utils import get_app_connection, get_app_connection_name, get_tortoise_config

from . import __version__
from .enums import Color
from .models import Aerich


class Color(str, Enum):
green = "green"
red = "red"
yellow = "yellow"


parser = ConfigParser()


def close_db(func):
@functools.wraps(func)
async def close_db_inner(*args, **kwargs):
result = await func(*args, **kwargs)
await Tortoise.close_connections()
return result

return close_db_inner


@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(__version__, "-V", "--version")
@click.option(
Expand Down Expand Up @@ -81,12 +63,10 @@ async def cli(ctx: Context, config, app, name):
@cli.command(help="Generate migrate changes file.")
@click.option("--name", default="update", show_default=True, help="Migrate name.")
@click.pass_context
@close_db
async def migrate(ctx: Context, name):
config = ctx.obj["config"]
location = ctx.obj["location"]
app = ctx.obj["app"]

ret = await Migrate.migrate(name)
if not ret:
return click.secho("No changes detected", fg=Color.yellow)
Expand All @@ -96,7 +76,6 @@ async def migrate(ctx: Context, name):

@cli.command(help="Upgrade to latest version.")
@click.pass_context
@close_db
async def upgrade(ctx: Context):
config = ctx.obj["config"]
app = ctx.obj["app"]
Expand All @@ -123,7 +102,6 @@ async def upgrade(ctx: Context):

@cli.command(help="Downgrade to previous version.")
@click.pass_context
@close_db
async def downgrade(ctx: Context):
app = ctx.obj["app"]
config = ctx.obj["config"]
Expand All @@ -146,7 +124,6 @@ async def downgrade(ctx: Context):

@cli.command(help="Show current available heads in migrate location.")
@click.pass_context
@close_db
async def heads(ctx: Context):
app = ctx.obj["app"]
versions = Migrate.get_all_version_files()
Expand All @@ -161,7 +138,6 @@ async def heads(ctx: Context):

@cli.command(help="List all migrate items.")
@click.pass_context
@close_db
async def history(ctx: Context):
versions = Migrate.get_all_version_files()
for version in versions:
Expand Down Expand Up @@ -212,7 +188,6 @@ async def init(
show_default=True,
)
@click.pass_context
@close_db
async def init_db(ctx: Context, safe):
config = ctx.obj["config"]
location = ctx.obj["location"]
Expand Down
10 changes: 10 additions & 0 deletions aerich/ddl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ class BaseDDL:
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
_DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"'
_RENAME_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
)
_ADD_INDEX_TEMPLATE = (
'ALTER TABLE "{table_name}" ADD {unique} INDEX "{index_name}" ({column_names})'
)
Expand Down Expand Up @@ -125,6 +128,13 @@ def modify_column(self, model: "Type[Model]", field_object: Field):
),
)

def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str):
return self._RENAME_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table,
old_column_name=old_column_name,
new_column_name=new_column_name,
)

def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE" if unique else "",
Expand Down
3 changes: 3 additions & 0 deletions aerich/ddl/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class MysqlDDL(BaseDDL):
_DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`"
_ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}"
_DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`"
_RENAME_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
)
_ADD_INDEX_TEMPLATE = (
"ALTER TABLE `{table_name}` ADD {unique} INDEX `{index_name}` ({column_names})"
)
Expand Down
7 changes: 7 additions & 0 deletions aerich/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum


class Color(str, Enum):
green = "green"
red = "red"
yellow = "yellow"
52 changes: 44 additions & 8 deletions aerich/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from importlib import import_module
from typing import Dict, List, Tuple, Type

import click
from tortoise import (
BackwardFKRelation,
BackwardOneToOneRelation,
Expand All @@ -28,6 +29,8 @@ class Migrate:
_upgrade_m2m: List[str] = []
_downgrade_m2m: List[str] = []
_aerich = Aerich.__name__
_rename_old = []
_rename_new = []

ddl: BaseDDL
migrate_config: dict
Expand Down Expand Up @@ -264,11 +267,37 @@ def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True
if cls._exclude_field(new_field, upgrade):
continue
if new_key not in old_keys:
cls._add_operator(
cls._add_field(new_model, new_field),
upgrade,
isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
)
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("name")
new_field_dict.pop("db_column")
for diff_key in old_keys - new_keys:
old_field = old_fields_map.get(diff_key)
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("name")
old_field_dict.pop("db_column")
if old_field_dict == new_field_dict:
if upgrade:
is_rename = click.prompt(
f"Rename {diff_key} to {new_key}",
default=True,
type=bool,
show_choices=True,
)
cls._rename_new.append(new_key)
cls._rename_old.append(diff_key)
else:
is_rename = diff_key in cls._rename_new
if is_rename:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field), upgrade,
)
break
else:
cls._add_operator(
cls._add_field(new_model, new_field),
upgrade,
isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
)
else:
old_field = old_fields_map.get(new_key)
new_field_dict = new_field.describe(serializable=True)
Expand Down Expand Up @@ -319,9 +348,12 @@ def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True
for old_key in old_keys:
field = old_fields_map.get(old_key)
if old_key not in new_keys and not cls._exclude_field(field, upgrade):
cls._add_operator(
cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field),
)
if (upgrade and old_key not in cls._rename_old) or (
not upgrade and old_key not in cls._rename_new
):
cls._add_operator(
cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field),
)

for new_index in new_indexes:
if new_index not in old_indexes:
Expand Down Expand Up @@ -413,6 +445,10 @@ def _remove_field(cls, model: Type[Model], field: Field):
return cls.ddl.drop_m2m(field)
return cls.ddl.drop_column(model, field.model_field_name)

@classmethod
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name)

@classmethod
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
"""
Expand Down
Loading

0 comments on commit 141d720

Please sign in to comment.