Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix db migration tools #4

Merged
merged 6 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,33 @@ jobs:
- name: Lint
run: npx nx run-many --target=lint

test:
runs-on: ubuntu-latest
test-py:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: [ 3.11 ]
os: [ ubuntu-latest ]
postgres-version: [ 15-bookworm ]
services:
postgres:
image: postgres:${{ matrix.postgres-version }}
env:
POSTGRES_PASSWORD: "password"
POSTGRES_USER: "sftkit"
POSTGRES_DB: "sftkit_test"
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
env:
SFTKIT_TEST_DB_USER: "sftkit"
SFTKIT_TEST_DB_HOST: "localhost"
SFTKIT_TEST_DB_PORT: "5432"
SFTKIT_TEST_DB_DBNAME: "sftkit_test"
SFTKIT_TEST_DB_PASSWORD: "password"
steps:
- uses: actions/checkout@v4

Expand All @@ -49,13 +74,34 @@ jobs:
- name: Set up Python with PDM
uses: pdm-project/setup-pdm@v3
with:
python-version: "3.11"
python-version: ${{ matrix.python-version }}

- name: Install Python dependencies
run: pdm sync -d

- name: Test
run: npx nx run-many --target=test
run: npx nx run-many --target=test --projects=tag:lang:python

# JavaScript test job: runs all nx projects tagged lang:javascript on a
# node-version matrix (currently only Node 20).
test-js:
  runs-on: ubuntu-latest
  strategy:
    matrix:
      node-version: [20]
  steps:
    - uses: actions/checkout@v4

    - name: Set up Nodejs
      uses: actions/setup-node@v4
      with:
        # Fix: added the missing space before "}}" for consistency with the
        # other ${{ ... }} expressions in this workflow.
        node-version: ${{ matrix.node-version }}
        # Cache npm downloads keyed on the lockfile to speed up `npm ci`.
        cache: "npm"
        cache-dependency-path: package-lock.json

    - name: Install node dependencies
      run: npm ci

    - name: Test
      run: npx nx run-many --target=test --projects=tag:lang:javascript

build:
runs-on: ubuntu-latest
Expand Down
977 changes: 179 additions & 798 deletions pdm.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions sftkit/project.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{
"name": "sftkit",
"$schema": "../../node_modules/nx/schemas/project-schema.json",
"$schema": "../node_modules/nx/schemas/project-schema.json",
"sourceRoot": "sftkit/sftkit",
"projectType": "library",
"tags": [],
"tags": ["lang:python"],
"targets": {
"typecheck": {
"executor": "nx:run-commands",
Expand Down
12 changes: 6 additions & 6 deletions sftkit/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
Expand All @@ -17,11 +16,11 @@ classifiers = [
]
requires-python = ">=3.11"
dependencies = [
"fastapi>=0.111.0",
"typer>=0.12.3",
"uvicorn[standard]>=0.22.0",
"asyncpg>=0.29.0",
"pydantic[email]==2.7.4",
"fastapi>=0.115.6",
"typer>=0.15.1",
"uvicorn>=0.34.0",
"asyncpg>=0.30.0",
"pydantic[email]==2.10.4",
]

[project.urls]
Expand All @@ -35,6 +34,7 @@ source = ["sftkit"]

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
minversion = "6.0"
testpaths = ["tests"]

Expand Down
88 changes: 34 additions & 54 deletions sftkit/sftkit/database/_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@

import asyncpg

from sftkit.database import Connection
from sftkit.database.introspection import list_constraints, list_functions, list_triggers, list_views

logger = logging.getLogger(__name__)

# Header line that declares a migration file's version, e.g. "-- migration: 0001".
MIGRATION_VERSION_RE = re.compile(r"^-- migration: (?P<version>\w+)$")
# Header line that declares a dependency on another migration's version.
# NOTE(review): "REQURES" is a typo for "REQUIRES"; the name may be referenced
# outside this view, so it is left unchanged here — rename in a dedicated change.
MIGRATION_REQURES_RE = re.compile(r"^-- requires: (?P<version>\w+)$")
# Name of the bookkeeping table — presumably tracks the applied schema
# revision; its usage is not visible in this view, confirm against callers.
MIGRATION_TABLE = "schema_revision"


async def _run_postgres_code(conn: asyncpg.Connection, code: str, file_name: Path):
async def _run_postgres_code(conn: Connection, code: str, file_name: Path):
if all(line.startswith("--") for line in code.splitlines()):
return
try:
Expand All @@ -32,33 +35,23 @@ async def _run_postgres_code(conn: asyncpg.Connection, code: str, file_name: Pat
raise ValueError(f"Syntax or Access error when executing SQL code ({file_name!s}): {message!r}") from exc


async def _drop_all_views(conn: asyncpg.Connection, schema: str):
async def _drop_all_views(conn: Connection, schema: str):
# TODO: we might have to find out the dependency order of the views if drop cascade does not work
result = await conn.fetch(
"select table_name from information_schema.views where table_schema = $1 and table_name !~ '^pg_';",
schema,
)
views = [row["table_name"] for row in result]
views = await list_views(conn, schema)
if len(views) == 0:
return

# we use drop if exists here as the cascade dropping might lead the view to being already dropped
# due to being a dependency of another view
drop_statements = "\n".join([f"drop view if exists {view} cascade;" for view in views])
drop_statements = "\n".join([f"drop view if exists {view.table_name} cascade;" for view in views])
await conn.execute(drop_statements)


async def _drop_all_triggers(conn: asyncpg.Connection, schema: str):
result = await conn.fetch(
"select distinct on (trigger_name, event_object_table) trigger_name, event_object_table "
"from information_schema.triggers where trigger_schema = $1",
schema,
)
async def _drop_all_triggers(conn: Connection, schema: str):
triggers = await list_triggers(conn, schema)
statements = []
for row in result:
trigger_name = row["trigger_name"]
table = row["event_object_table"]
statements.append(f"drop trigger {trigger_name} on {table};")
for trigger in triggers:
statements.append(f'drop trigger "{trigger.trigger_name}" on "{trigger.event_object_table}";')

if len(statements) == 0:
return
Expand All @@ -67,27 +60,20 @@ async def _drop_all_triggers(conn: asyncpg.Connection, schema: str):
await conn.execute(drop_statements)


async def _drop_all_functions(conn: asyncpg.Connection, schema: str):
result = await conn.fetch(
"select proname, pg_get_function_identity_arguments(oid) as signature, prokind from pg_proc "
"where pronamespace = $1::regnamespace;",
schema,
)
async def _drop_all_functions(conn: Connection, schema: str):
funcs = await list_functions(conn, schema)
drop_statements = []
for row in result:
kind = row["prokind"].decode("utf-8")
name = row["proname"]
signature = row["signature"]
if kind in ("f", "w"):
for func in funcs:
if func.prokind in ("f", "w"):
drop_type = "function"
elif kind == "a":
elif func.prokind == "a":
drop_type = "aggregate"
elif kind == "p":
elif func.prokind == "p":
drop_type = "procedure"
else:
raise RuntimeError(f'Unknown postgres function type "{kind}"')
raise RuntimeError(f'Unknown postgres function type "{func.prokind}"')

drop_statements.append(f"drop {drop_type} {name}({signature}) cascade;")
drop_statements.append(f'drop {drop_type} "{func.proname}"({func.signature}) cascade;')

if len(drop_statements) == 0:
return
Expand All @@ -96,37 +82,31 @@ async def _drop_all_functions(conn: asyncpg.Connection, schema: str):
await conn.execute(drop_code)


async def _drop_all_constraints(conn: asyncpg.Connection, schema: str):
async def _drop_all_constraints(conn: Connection, schema: str):
"""drop all constraints in the given schema which are not unique, primary or foreign key constraints"""
result = await conn.fetch(
"select con.conname as constraint_name, rel.relname as table_name, con.contype as constraint_type "
"from pg_catalog.pg_constraint con "
" join pg_catalog.pg_namespace nsp on nsp.oid = con.connamespace "
" left join pg_catalog.pg_class rel on rel.oid = con.conrelid "
"where nsp.nspname = $1 and con.conname !~ '^pg_' "
" and con.contype != 'p' and con.contype != 'f' and con.contype != 'u';",
schema,
)
constraints = []
for row in result:
constraint_name = row["constraint_name"]
constraint_type = row["constraint_type"].decode("utf-8")
table_name = row["table_name"]
constraints = await list_constraints(conn, schema)
drop_statements = []
for constraint in constraints:
constraint_name = constraint.conname
constraint_type = constraint.contype
table_name = constraint.relname
if constraint_type in ("p", "f", "u"):
continue
if constraint_type == "c":
constraints.append(f"alter table {table_name} drop constraint {constraint_name};")
drop_statements.append(f'alter table "{table_name}" drop constraint "{constraint_name}";')
elif constraint_type == "t":
constraints.append(f"drop constraint trigger {constraint_name};")
drop_statements.append(f"drop constraint trigger {constraint_name};")
else:
raise RuntimeError(f'Unknown constraint type "{constraint_type}" for constraint "{constraint_name}"')

if len(constraints) == 0:
if len(drop_statements) == 0:
return

drop_statements = "\n".join(constraints)
await conn.execute(drop_statements)
drop_cmd = "\n".join(drop_statements)
await conn.execute(drop_cmd)


async def _drop_db_code(conn: asyncpg.Connection, schema: str):
async def _drop_db_code(conn: Connection, schema: str):
await _drop_all_triggers(conn, schema=schema)
await _drop_all_functions(conn, schema=schema)
await _drop_all_views(conn, schema=schema)
Expand Down
89 changes: 89 additions & 0 deletions sftkit/sftkit/database/introspection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from pydantic import BaseModel

from sftkit.database import Connection


class PgFunctionDef(BaseModel):
    """One row of ``pg_catalog.pg_proc`` plus a computed identity ``signature``.

    Field names mirror the catalog columns so rows can be validated directly.
    ``pg_node_tree`` columns are intentionally omitted (see commented fields).
    """

    proname: str  # function name
    pronamespace: int  # oid
    proowner: int  # oid
    prolang: int  # oid
    procost: float
    prorows: int
    provariadic: int  # oid
    prosupport: str
    prokind: str  # 'f' function, 'p' procedure, 'a' aggregate, 'w' window
    prosecdef: bool
    proleakproof: bool
    proisstrict: bool
    proretset: bool
    provolatile: str
    proparallel: str
    pronargs: int
    pronargdefaults: int
    prorettype: int  # oid
    proargtypes: list[int]  # oid
    proallargtypes: list[int] | None  # oid
    proargmodes: list[str] | None
    proargnames: list[str] | None
    # proargdefaults: pg_node_tree | None
    protrftypes: list[str] | None
    prosrc: str
    probin: str | None
    # prosqlbody: pg_node_tree | None
    proconfig: list[str] | None
    proacl: list[str] | None
    # Computed via pg_get_function_identity_arguments(); used to build
    # drop statements that disambiguate overloaded functions.
    signature: str


async def list_functions(conn: Connection, schema: str) -> list[PgFunctionDef]:
    """Fetch all non-system functions defined in ``schema``.

    Includes the identity-argument signature needed to address overloads.
    """
    query = (
        "select pg_proc.*, pg_get_function_identity_arguments(oid) as signature from pg_proc "
        "where pronamespace = $1::regnamespace and pg_proc.proname !~ '^pg_';"
    )
    return await conn.fetch_many(PgFunctionDef, query, schema)


class PgViewDef(BaseModel):
    """One view row from ``information_schema.views`` (name only)."""

    table_name: str  # name of the view


async def list_views(conn: Connection, schema: str) -> list[PgViewDef]:
    """Return the names of all non-system views in ``schema``."""
    query = "select table_name from information_schema.views where table_schema = $1 and table_name !~ '^pg_';"
    return await conn.fetch_many(PgViewDef, query, schema)


class PgTriggerDef(BaseModel):
    """One trigger row from ``information_schema.triggers``."""

    trigger_name: str  # name of the trigger
    event_object_table: str  # table the trigger is attached to


async def list_triggers(conn: Connection, schema: str) -> list[PgTriggerDef]:
    """List triggers in ``schema``, deduplicated by (name, table).

    The ``distinct on`` collapses the one-row-per-event entries that
    ``information_schema.triggers`` emits for multi-event triggers.
    """
    query = (
        "select distinct on (trigger_name, event_object_table) trigger_name, event_object_table "
        "from information_schema.triggers where trigger_schema = $1"
    )
    return await conn.fetch_many(PgTriggerDef, query, schema)


class PgConstraintDef(BaseModel):
    """One row of ``pg_catalog.pg_constraint`` joined with its relation name."""

    conname: str  # constraint name
    # NOTE(review): the left join in list_constraints can yield NULL here for
    # constraints not attached to a table (e.g. domain constraints), which
    # would fail this str validation — confirm intended scope.
    relname: str  # name of the table the constraint belongs to
    contype: str  # single-char type code as in pg_constraint.contype


async def list_constraints(conn: Connection, schema: str) -> list[PgConstraintDef]:
    """List all non-system table constraints in ``schema``.

    Restricted to constraints attached to a relation (``con.conrelid != 0``):
    domain constraints have no associated table, so the left join would yield
    a NULL ``relname`` and fail ``PgConstraintDef`` validation (``relname: str``).
    """
    return await conn.fetch_many(
        PgConstraintDef,
        "select con.conname, rel.relname, con.contype "
        "from pg_catalog.pg_constraint con "
        "   join pg_catalog.pg_namespace nsp on nsp.oid = con.connamespace "
        "   left join pg_catalog.pg_class rel on rel.oid = con.conrelid "
        "where nsp.nspname = $1 and con.conname !~ '^pg_' and con.conrelid != 0;",
        schema,
    )
1 change: 1 addition & 0 deletions sftkit/tests/assets/minimal_db/code/constraints.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-- Test asset: a check constraint on "user", used to exercise the migration
-- tooling's constraint introspection/drop logic.
alter table "user" add constraint username_allowlist check (name != 'exclusion');
13 changes: 13 additions & 0 deletions sftkit/tests/assets/minimal_db/code/functions.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
-- Test asset: a plpgsql function (with a <<locals>> block label) used to
-- exercise the migration tooling's function introspection/drop logic.
create or replace function test_func(
    arg1 bigint,
    arg2 text
) returns boolean as
$$
<<locals>> declare
    -- Fix: "arg1 > 0 and arg2 != 'bla'" evaluates to boolean, but the
    -- variable was declared "double precision"; PostgreSQL has no
    -- boolean -> double precision conversion, so the assignment (and thus
    -- any call to this function) failed at runtime. Declare it boolean,
    -- matching the declared return type.
    tmp_var boolean;
begin
    tmp_var = arg1 > 0 and arg2 != 'bla';
    return tmp_var;
end;
$$ language plpgsql
set search_path = "$user", public;
14 changes: 14 additions & 0 deletions sftkit/tests/assets/minimal_db/code/triggers.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- Test asset: a trigger function and a trigger on "user", used to exercise
-- the migration tooling's trigger introspection/drop logic.
-- NOTE(review): the function is marked "stable" — unusual for a trigger
-- function (volatile is the norm); harmless here since it only returns NEW,
-- but confirm it was intentional.
create or replace function user_trigger() returns trigger as
$$
begin
    return NEW;
end
$$ language plpgsql
stable
set search_path = "$user", public;

-- Fires before each inserted row on "user"; the function passes the row
-- through unchanged.
create trigger create_user_trigger
    before insert
    on "user"
    for each row
execute function user_trigger();
6 changes: 6 additions & 0 deletions sftkit/tests/assets/minimal_db/code/views.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- Test asset: a view used to exercise the migration tooling's view
-- introspection/drop logic.
-- NOTE(review): the inner join omits users with zero posts — presumably fine
-- for this fixture; confirm if tests ever assert on the view's contents.
create view user_with_post_count as
select u.*, author_counts.count
from "user" as u
         join (
    select p.author_id, count(*) as count from post as p group by p.author_id
) as author_counts on u.id = author_counts.author_id;
Loading
Loading