Skip to content

Commit

Permalink
add overwrite option
Browse files Browse the repository at this point in the history
  • Loading branch information
weaverba137 committed Jan 24, 2025
1 parent 44621bb commit 5560758
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
6 changes: 4 additions & 2 deletions dlairflow/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def pg_restore_schema(connection, schema, dump_dir=None):
append_env=True)


def q3c_index(connection, schema, table, ra='ra', dec='dec'):
def q3c_index(connection, schema, table, ra='ra', dec='dec', overwrite=False):
"""Create a q3c index on `schema`.`table`.
Parameters
Expand All @@ -111,6 +111,8 @@ def q3c_index(connection, schema, table, ra='ra', dec='dec'):
Name of the column containing Right Ascension, default 'ra'.
dec : :class:`str`, optional
Name of the column containing Declination, default 'dec'.
overwrite : :class:`bool`, optional
If ``True`` replace any existing SQL template file.
Returns
-------
Expand All @@ -119,7 +121,7 @@ def q3c_index(connection, schema, table, ra='ra', dec='dec'):
"""
sql_dir = ensure_sql()
sql_file = os.path.join(sql_dir, "dlairflow.postgresql.q3c_index.sql")
if not os.path.exists(sql_file):
if overwrite or not os.path.exists(sql_file):
sql_data = """--
-- Created by dlairflow.postgresql.q3c_index().
--
Expand Down
7 changes: 4 additions & 3 deletions dlairflow/test/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_pg_dump_schema(monkeypatch, temporary_airflow_home, task_function, dump
else:
assert test_operator.params['dump_dir'] == 'dump_dir'


def test_q3c_index(monkeypatch, temporary_airflow_home):
@pytest.mark.parametrize('overwrite', [(False, ), (True, )])
def test_q3c_index(monkeypatch, temporary_airflow_home, overwrite):
"""Test the q3c_index function.
"""
#
Expand All @@ -85,7 +85,8 @@ def test_q3c_index(monkeypatch, temporary_airflow_home):
p = import_module('..postgresql', package='dlairflow.test')

tf = p.__dict__['q3c_index']
test_operator = tf("login,password,host,schema", 'q3c_schema', 'q3c_table')
test_operator = tf("login,password,host,schema", 'q3c_schema', 'q3c_table',
overwrite=overwrite)
assert isinstance(test_operator, PostgresOperator)
assert os.path.exists(str(temporary_airflow_home / 'dags' / 'sql' / 'dlairflow.postgresql.q3c_index.sql'))
assert test_operator.task_id == 'q3c_index'
Expand Down

0 comments on commit 5560758

Please sign in to comment.