[Issue #1978] further optimizations for load chunking (#1984)
## Summary
Fixes #1978

## Changes proposed
- Optimize the chunked load further by moving the chunking logic from PostgreSQL into the Python load task.

## Context for reviewers
Instead of using `LIMIT` to carry out chunking in PostgreSQL, read the
full set of ids as a first step, then issue a series of INSERT / UPDATE
queries, each restricted to one chunk of those ids.

This is expected to be faster. With the previous method, the PostgreSQL
optimizer did not produce an ideal plan and read all rows and columns from
the Oracle database. By splitting the query, the first step reads only the
id columns for the new or updated rows, and the follow-up queries then
select only the rows that have changed.
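For readers who want the shape of the new pattern without reading the full diff, here is a minimal standalone sketch. It is not taken from the diff: the names `CHUNK_SIZE` and `load_in_chunks` are illustrative, and `itertools.batched` assumes Python 3.12+.

```python
# Minimal sketch of the new chunking pattern: fetch the full list of ids once,
# then process it in fixed-size batches, mirroring the loop in do_insert.
# Assumes Python 3.12+ (itertools.batched); CHUNK_SIZE is an illustrative value.
import itertools

CHUNK_SIZE = 4000


def load_in_chunks(new_ids: list[tuple[int, int]]) -> int:
    """Process ids in batches; in the real task each batch drives one INSERT."""
    processed = 0
    for batch in itertools.batched(new_ids, CHUNK_SIZE):
        # The real task builds an INSERT ... SELECT restricted to this batch of
        # ids and executes it; this sketch only counts rows and logs progress.
        processed += len(batch)
        print(f"chunk done: count={processed} total={len(new_ids)}")
    return processed


if __name__ == "__main__":
    ids = [(i, i + 1) for i in range(10_000)]
    assert load_in_chunks(ids) == len(ids)
```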

## Additional information
N/A
jamesbursa committed May 10, 2024
1 parent 425a91f commit fa38c32
Showing 3 changed files with 102 additions and 95 deletions.
54 changes: 36 additions & 18 deletions api/src/data_migration/load/load_oracle_data_task.py
@@ -1,7 +1,7 @@
#
# Load data from legacy (Oracle) tables to staging tables.
#

import itertools
import logging
import time

@@ -94,23 +94,23 @@ def load_data_for_table(self, table_name: str) -> None:
def do_insert(self, foreign_table: sqlalchemy.Table, staging_table: sqlalchemy.Table) -> int:
"""Determine new rows by primary key, and copy them into the staging table."""

insert_from_select_sql, select_sql = sql.build_insert_select_sql(
foreign_table, staging_table, self.insert_chunk_size
)
select_sql = sql.build_select_new_rows_sql(foreign_table, staging_table)
new_ids = self.db_session.execute(select_sql).all()

t0 = time.monotonic()
insert_chunk_count = []
while True:
insert_count = self.db_session.query(select_sql.subquery()).count()
if insert_count == 0:
break
for batch_of_new_ids in itertools.batched(new_ids, self.insert_chunk_size):
insert_from_select_sql = sql.build_insert_select_sql(
foreign_table, staging_table, batch_of_new_ids
)

# Execute the INSERT.
self.db_session.execute(insert_from_select_sql)

insert_chunk_count.append(insert_count)
insert_chunk_count.append(len(batch_of_new_ids))
logger.info(
"insert chunk done", extra={"count": insert_count, "total": sum(insert_chunk_count)}
"insert chunk done",
extra={"count": sum(insert_chunk_count), "total": len(new_ids)},
)

t1 = time.monotonic()
@@ -129,18 +129,36 @@ def do_insert(self, foreign_table: sqlalchemy.Table, staging_table: sqlalchemy.T
def do_update(self, foreign_table: sqlalchemy.Table, staging_table: sqlalchemy.Table) -> int:
"""Find updated rows using last_upd_date, copy them, and reset transformed_at to NULL."""

update_sql = sql.build_update_sql(foreign_table, staging_table).values(transformed_at=None)
select_sql = sql.build_select_updated_rows_sql(foreign_table, staging_table)
update_ids = self.db_session.execute(select_sql).all()

t0 = time.monotonic()
result = self.db_session.execute(update_sql)
t1 = time.monotonic()
update_count = result.rowcount
update_chunk_count = []
for batch_of_update_ids in itertools.batched(update_ids, self.insert_chunk_size):
update_sql = sql.build_update_sql(
foreign_table, staging_table, batch_of_update_ids
).values(transformed_at=None)

self.db_session.execute(update_sql)

update_chunk_count.append(len(batch_of_update_ids))
logger.info(
"update chunk done",
extra={"count": sum(update_chunk_count), "total": len(update_ids)},
)

self.increment("count.update.total", update_count)
self.set_metrics({f"count.update.{staging_table.name}": update_count})
self.set_metrics({f"time.update.{staging_table.name}": round(t1 - t0, 3)})
t1 = time.monotonic()
total_update_count = sum(update_chunk_count)
self.increment("count.update.total", total_update_count)
self.increment(f"count.update.{staging_table.name}", total_update_count)
self.set_metrics(
{
f"count.update.chunk.{staging_table.name}": ",".join(map(str, update_chunk_count)),
f"time.update.{staging_table.name}": round(t1 - t0, 3),
}
)

return update_count
return total_update_count

def do_mark_deleted(
self, foreign_table: sqlalchemy.Table, staging_table: sqlalchemy.Table
78 changes: 37 additions & 41 deletions api/src/data_migration/load/sql.py
@@ -2,26 +2,18 @@
# SQL building for data load process.
#

import sqlalchemy

from typing import Iterable

def build_insert_select_sql(
source_table: sqlalchemy.Table, destination_table: sqlalchemy.Table, limit: int = 1000
) -> tuple[sqlalchemy.Insert, sqlalchemy.Select]:
"""Build an `INSERT INTO ... SELECT ... FROM ...` query for new rows."""
import sqlalchemy

all_columns = tuple(c.name for c in source_table.columns)

# Optimization: use a Common Table Expression (`WITH`) marked as MATERIALIZED. This directs the PostgreSQL
# optimizer to run it first (prevents folding it into the parent query), so it only fetches the primary keys and
# last_upd_date columns from Oracle to perform the date comparison. Without this materialized CTE, it fetches all
# columns and all rows from Oracle before applying the WHERE, which is very slow for large tables.
#
# See https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-CTE-MATERIALIZATION
def build_select_new_rows_sql(
source_table: sqlalchemy.Table, destination_table: sqlalchemy.Table
) -> sqlalchemy.Select:
"""Build a `SELECT id1, id2, ... FROM <source_table>` query that finds new rows in source_table."""

# `WITH insert_pks AS MATERIALIZED (`
cte = (
# `SELECT id1, id2, id3, ... FROM <source_table>` (id1, id2, ... is the multipart primary key)
# `SELECT id1, id2, id3, ... FROM <source_table>` (id1, id2, ... is the multipart primary key)
return (
sqlalchemy.select(*source_table.primary_key.columns)
.where(
# `WHERE (id1, id2, id3, ...) NOT IN`
@@ -30,43 +22,42 @@ def build_insert_select_sql(
sqlalchemy.select(*destination_table.primary_key.columns)
)
)
.limit(limit)
.cte("insert_pks")
.prefix_with("MATERIALIZED")
.order_by(*source_table.primary_key.columns)
)


def build_insert_select_sql(
source_table: sqlalchemy.Table,
destination_table: sqlalchemy.Table,
ids: Iterable[tuple | sqlalchemy.Row],
) -> sqlalchemy.Insert:
"""Build an `INSERT INTO ... SELECT ... FROM ...` query for new rows."""

all_columns = tuple(c.name for c in source_table.columns)

# `SELECT col1, col2, ..., FALSE AS is_deleted FROM <source_table>`
select_sql = sqlalchemy.select(
source_table, sqlalchemy.literal_column("FALSE").label("is_deleted")
).where(
# `WHERE (id1, id2, ...)
# IN (SELECT insert_pks.id1, insert_pks.id2
# FROM insert_pks)`
sqlalchemy.tuple_(*source_table.primary_key.columns).in_(sqlalchemy.select(*cte.columns)),
# IN ((a1, a2), (b1, b2), ...)`
sqlalchemy.tuple_(*source_table.primary_key.columns).in_(ids),
)
# `INSERT INTO <destination_table> (col1, col2, ..., is_deleted) SELECT ...`
insert_from_select_sql = sqlalchemy.insert(destination_table).from_select(
all_columns + (destination_table.c.is_deleted,), select_sql
)

return insert_from_select_sql, select_sql
return insert_from_select_sql


def build_update_sql(
def build_select_updated_rows_sql(
source_table: sqlalchemy.Table, destination_table: sqlalchemy.Table
) -> sqlalchemy.Update:
"""Build an `UPDATE ... SET ... WHERE ...` statement for updated rows."""

# Optimization: use a Common Table Expression (`WITH`) marked as MATERIALIZED. This directs the PostgreSQL
# optimizer to run it first (prevents folding it into the parent query), so it only fetches the primary keys and
# last_upd_date columns from Oracle to perform the date comparison. Without this materialized CTE, it fetches all
# columns and all rows from Oracle before applying the WHERE, which is very slow for large tables.
#
# See https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-CTE-MATERIALIZATION
) -> sqlalchemy.Select:
"""Build a `SELECT id1, id2, ... FROM <source_table>` query that finds updated rows in source_table."""

# `WITH update_pks AS MATERIALIZED (`
cte = (
# `SELECT id1, id2, id3, ... FROM <destination_table>`
# `SELECT id1, id2, id3, ... FROM <destination_table>`
return (
sqlalchemy.select(*destination_table.primary_key.columns)
.join(
# `JOIN <source_table>
@@ -77,10 +68,17 @@ def build_update_sql(
)
# `WHERE ...`
.where(destination_table.c.last_upd_date < source_table.c.last_upd_date)
.cte("update_pks")
.prefix_with("MATERIALIZED")
.order_by(*source_table.primary_key.columns)
)


def build_update_sql(
source_table: sqlalchemy.Table,
destination_table: sqlalchemy.Table,
ids: Iterable[tuple | sqlalchemy.Row],
) -> sqlalchemy.Update:
"""Build an `UPDATE ... SET ... WHERE ...` statement for updated rows."""

return (
# `UPDATE <destination_table>`
sqlalchemy.update(destination_table)
@@ -90,9 +88,7 @@ def build_update_sql(
.where(
sqlalchemy.tuple_(*destination_table.primary_key.columns)
== sqlalchemy.tuple_(*source_table.primary_key.columns),
sqlalchemy.tuple_(*destination_table.primary_key.columns).in_(
sqlalchemy.select(*cte.columns)
),
sqlalchemy.tuple_(*destination_table.primary_key.columns).in_(ids),
)
)

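As a side note on the `IN` clause used above: the list of id tuples is bound as an expanding parameter, so the compiled statement shows a `POSTCOMPILE` placeholder until execution. A small standalone sketch (illustrative table, assuming SQLAlchemy 2.x):

```python
# Standalone illustration of the tuple-IN filter used by
# build_insert_select_sql and build_update_sql; this table is made up.
import sqlalchemy

metadata = sqlalchemy.MetaData()
source = sqlalchemy.Table(
    "source",
    metadata,
    sqlalchemy.Column("id1", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("id2", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("x", sqlalchemy.Text),
)

ids = [(1, 2), (3, 4), (5, 6)]
stmt = sqlalchemy.select(source).where(
    sqlalchemy.tuple_(*source.primary_key.columns).in_(ids)
)

# The id list is an expanding bind parameter, so str(stmt) shows a
# POSTCOMPILE placeholder (as the updated tests assert); the tuples are
# expanded into the IN list only when the statement is executed.
print(stmt)
# SELECT source.id1, source.id2, source.x
# FROM source
# WHERE (source.id1, source.id2) IN (__[POSTCOMPILE_param_1])
```
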
65 changes: 29 additions & 36 deletions api/tests/src/data_migration/load/test_sql.py
@@ -38,60 +38,53 @@ def destination_table(sqlalchemy_metadata):
)


def test_build_insert_select_sql(source_table, destination_table):
insert, select = sql.build_insert_select_sql(source_table, destination_table)
assert str(insert) == (
"WITH insert_pks AS MATERIALIZED \n"
"(SELECT test_source_table.id1 AS id1, test_source_table.id2 AS id2 \n"
def test_build_select_new_rows_sql(source_table, destination_table):
select = sql.build_select_new_rows_sql(source_table, destination_table)
assert str(select) == (
"SELECT test_source_table.id1, test_source_table.id2 \n"
"FROM test_source_table \n"
"WHERE ((test_source_table.id1, test_source_table.id2) "
"NOT IN (SELECT test_destination_table.id1, test_destination_table.id2 \n"
"FROM test_destination_table))\n "
"LIMIT :param_1)\n "
"INSERT INTO test_destination_table (id1, id2, x, last_upd_date, is_deleted) "
"SELECT test_source_table.id1, test_source_table.id2, test_source_table.x, "
"test_source_table.last_upd_date, FALSE AS is_deleted \n"
"FROM test_source_table \n"
"WHERE (test_source_table.id1, test_source_table.id2) IN "
"(SELECT insert_pks.id1, insert_pks.id2 \n"
"FROM insert_pks)"
"NOT IN ("
"SELECT test_destination_table.id1, test_destination_table.id2 \n"
"FROM test_destination_table)) "
"ORDER BY test_source_table.id1, test_source_table.id2"
)


def test_build_select_updated_rows_sql(source_table, destination_table):
select = sql.build_select_updated_rows_sql(source_table, destination_table)
assert str(select) == (
"WITH insert_pks AS MATERIALIZED \n"
"(SELECT test_source_table.id1 AS id1, test_source_table.id2 AS id2 \n"
"FROM test_source_table \n"
"WHERE ((test_source_table.id1, test_source_table.id2) "
"NOT IN (SELECT test_destination_table.id1, test_destination_table.id2 \n"
"FROM test_destination_table))\n "
"LIMIT :param_1)\n "
"SELECT test_destination_table.id1, test_destination_table.id2 \n"
"FROM test_destination_table "
"JOIN test_source_table ON "
"(test_destination_table.id1, test_destination_table.id2) = "
"(test_source_table.id1, test_source_table.id2) \n"
"WHERE test_destination_table.last_upd_date < test_source_table.last_upd_date "
"ORDER BY test_source_table.id1, test_source_table.id2"
)


def test_build_insert_select_sql(source_table, destination_table):
insert = sql.build_insert_select_sql(source_table, destination_table, [(1, 2), (3, 4), (5, 6)])
assert str(insert) == (
"INSERT INTO test_destination_table (id1, id2, x, last_upd_date, is_deleted) "
"SELECT test_source_table.id1, test_source_table.id2, test_source_table.x, "
"test_source_table.last_upd_date, FALSE AS is_deleted \n"
"FROM test_source_table \n"
"WHERE (test_source_table.id1, test_source_table.id2) IN "
"(SELECT insert_pks.id1, insert_pks.id2 \n"
"FROM insert_pks)"
"WHERE (test_source_table.id1, test_source_table.id2) IN (__[POSTCOMPILE_param_1])"
)


def test_build_update_sql(source_table, destination_table):
update = sql.build_update_sql(source_table, destination_table)
update = sql.build_update_sql(source_table, destination_table, [(1, 2), (3, 4), (5, 6)])
assert str(update) == (
"WITH update_pks AS MATERIALIZED \n"
"(SELECT test_destination_table.id1 AS id1, test_destination_table.id2 AS id2 \n"
"FROM test_destination_table "
"JOIN test_source_table "
"ON (test_destination_table.id1, test_destination_table.id2) = "
"(test_source_table.id1, test_source_table.id2) \n"
"WHERE test_destination_table.last_upd_date < "
"test_source_table.last_upd_date)\n "
"UPDATE test_destination_table "
"SET id1=test_source_table.id1, id2=test_source_table.id2, x=test_source_table.x, "
"last_upd_date=test_source_table.last_upd_date FROM test_source_table "
"WHERE (test_destination_table.id1, test_destination_table.id2) = "
"(test_source_table.id1, test_source_table.id2) AND "
"(test_destination_table.id1, test_destination_table.id2) "
"IN (SELECT update_pks.id1, update_pks.id2 \n"
"FROM update_pks)"
"IN (__[POSTCOMPILE_param_1])"
)



1 comment on commit fa38c32

@github-actions

Coverage report for ./frontend

| Category | Percentage | Covered / Total |
| --- | --- | --- |
| 🟢 Statements | 84.11% | 868/1032 |
| 🟡 Branches | 65.01% | 223/343 |
| 🟡 Functions | 75.58% | 164/217 |
| 🟢 Lines | 84.15% | 807/959 |

Test suite run success

164 tests passing in 56 suites.

Report generated by 🧪jest coverage report action from fa38c32
