From 5560758ef9f18bf088fe38e6955d88f3e32db445 Mon Sep 17 00:00:00 2001 From: Benjamin Alan Weaver Date: Fri, 24 Jan 2025 11:27:55 -0700 Subject: [PATCH] add overwrite option --- dlairflow/postgresql.py | 6 ++++-- dlairflow/test/test_postgresql.py | 7 ++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/dlairflow/postgresql.py b/dlairflow/postgresql.py index 80cca5c..d5a7dd8 100644 --- a/dlairflow/postgresql.py +++ b/dlairflow/postgresql.py @@ -96,7 +96,7 @@ def pg_restore_schema(connection, schema, dump_dir=None): append_env=True) -def q3c_index(connection, schema, table, ra='ra', dec='dec'): +def q3c_index(connection, schema, table, ra='ra', dec='dec', overwrite=False): """Create a q3c index on `schema`.`table`. Parameters @@ -111,6 +111,8 @@ def q3c_index(connection, schema, table, ra='ra', dec='dec'): Name of the column containing Right Ascension, default 'ra'. dec : :class:`str`, optional Name of the column containing Declination, default 'dec'. + overwrite : :class:`bool`, optional + If ``True`` replace any existing SQL template file. Returns ------- @@ -119,7 +121,7 @@ def q3c_index(connection, schema, table, ra='ra', dec='dec'): """ sql_dir = ensure_sql() sql_file = os.path.join(sql_dir, "dlairflow.postgresql.q3c_index.sql") - if not os.path.exists(sql_file): + if overwrite or not os.path.exists(sql_file): sql_data = """-- -- Created by dlairflow.postgresql.q3c_index(). -- diff --git a/dlairflow/test/test_postgresql.py b/dlairflow/test/test_postgresql.py index 469b080..119e79a 100644 --- a/dlairflow/test/test_postgresql.py +++ b/dlairflow/test/test_postgresql.py @@ -70,8 +70,8 @@ def test_pg_dump_schema(monkeypatch, temporary_airflow_home, task_function, dump else: assert test_operator.params['dump_dir'] == 'dump_dir' - -def test_q3c_index(monkeypatch, temporary_airflow_home): +@pytest.mark.parametrize('overwrite', [(False, ), (True, )]) +def test_q3c_index(monkeypatch, temporary_airflow_home, overwrite): """Test the q3c_index function. """ # @@ -85,7 +85,8 @@ def test_q3c_index(monkeypatch, temporary_airflow_home): p = import_module('..postgresql', package='dlairflow.test') tf = p.__dict__['q3c_index'] - test_operator = tf("login,password,host,schema", 'q3c_schema', 'q3c_table') + test_operator = tf("login,password,host,schema", 'q3c_schema', 'q3c_table', + overwrite=overwrite) assert isinstance(test_operator, PostgresOperator) assert os.path.exists(str(temporary_airflow_home / 'dags' / 'sql' / 'dlairflow.postgresql.q3c_index.sql')) assert test_operator.task_id == 'q3c_index'