Skip to content

Commit

Permalink
Merge pull request #32 from psbleep/close-tortoise-connections
Browse files Browse the repository at this point in the history
Close database connections
  • Loading branch information
long2ice authored Jul 26, 2020
2 parents d74e7b5 + dfe13ea commit 18cb75f
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion aerich/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import json
import os
import sys
Expand Down Expand Up @@ -27,6 +28,16 @@ class Color(str, Enum):
parser = ConfigParser()


def close_db(func):
@functools.wraps(func)
async def close_db_inner(*args, **kwargs):
result = await func(*args, **kwargs)
await Tortoise.close_connections()
return result

return close_db_inner


@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(__version__, "-V", "--version")
@click.option(
Expand Down Expand Up @@ -70,6 +81,7 @@ async def cli(ctx: Context, config, app, name):
@cli.command(help="Generate migrate changes file.")
@click.option("--name", default="update", show_default=True, help="Migrate name.")
@click.pass_context
@close_db
async def migrate(ctx: Context, name):
config = ctx.obj["config"]
location = ctx.obj["location"]
Expand All @@ -84,6 +96,7 @@ async def migrate(ctx: Context, name):

@cli.command(help="Upgrade to latest version.")
@click.pass_context
@close_db
async def upgrade(ctx: Context):
config = ctx.obj["config"]
app = ctx.obj["app"]
Expand All @@ -110,6 +123,7 @@ async def upgrade(ctx: Context):

@cli.command(help="Downgrade to previous version.")
@click.pass_context
@close_db
async def downgrade(ctx: Context):
app = ctx.obj["app"]
config = ctx.obj["config"]
Expand All @@ -132,6 +146,7 @@ async def downgrade(ctx: Context):

@cli.command(help="Show current available heads in migrate location.")
@click.pass_context
@close_db
async def heads(ctx: Context):
app = ctx.obj["app"]
versions = Migrate.get_all_version_files()
Expand All @@ -146,7 +161,8 @@ async def heads(ctx: Context):

@cli.command(help="List all migrate items.")
@click.pass_context
def history(ctx):
@close_db
async def history(ctx: Context):
versions = Migrate.get_all_version_files()
for version in versions:
click.secho(version, fg=Color.green)
Expand Down Expand Up @@ -196,6 +212,7 @@ async def init(
show_default=True,
)
@click.pass_context
@close_db
async def init_db(ctx: Context, safe):
config = ctx.obj["config"]
location = ctx.obj["location"]
Expand Down

0 comments on commit 18cb75f

Please sign in to comment.