Skip to content

Commit

Permalink
Adding DB_ENGINE_URI_VALIDATOR
Browse files Browse the repository at this point in the history
  • Loading branch information
craig-rueda committed Apr 1, 2024
1 parent ca47717 commit f0ac4e1
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
11 changes: 11 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,17 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
DB_CONNECTION_MUTATOR = None


# A callable that is invoked for every invocation of DB Engine Specs
# which allows for custom validation of the engine URI.
# See: superset.db_engine_specs.base.BaseEngineSpec.validate_database_uri
# Example:
# def DB_ENGINE_URI_VALIDATOR(sqlalchemy_uri: URL):
# if not <some condition>:
# raise Exception("URI invalid")
#
DB_ENGINE_URI_VALIDATOR = None


# A function that intercepts the SQL to be executed and can alter it.
# The use case is can be around adding some sort of comment header
# with information such as the username and worker node information
Expand Down
3 changes: 3 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,6 +1956,9 @@ def validate_database_uri(cls, sqlalchemy_uri: URL) -> None:
:param sqlalchemy_uri:
"""
if db_engine_uri_validator := current_app.config["DB_ENGINE_URI_VALIDATOR"]:
db_engine_uri_validator(sqlalchemy_uri)

Check warning on line 1960 in superset/db_engine_specs/base.py

View check run for this annotation

Codecov / codecov/patch

superset/db_engine_specs/base.py#L1960

Added line #L1960 was not covered by tests

if existing_disallowed := cls.disallow_uri_query_params.get(
sqlalchemy_uri.get_driver_name(), set()
).intersection(sqlalchemy_uri.query):
Expand Down
19 changes: 19 additions & 0 deletions tests/unit_tests/db_engine_specs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

from textwrap import dedent
from typing import Any, Optional
from unittest.mock import patch

import pytest
from pytest_mock import MockFixture
from sqlalchemy import types
from sqlalchemy.dialects import sqlite
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes

from superset.superset_typing import ResultSetColumnType, SQLAColumnType
Expand Down Expand Up @@ -69,6 +71,23 @@ def test_parse_sql_multi_statement() -> None:
]


@patch("superset.db_engine_specs.base.current_app")
def test_validate_db_uri(current_app) -> None:
"""
Ensures that the `validate_database_uri` method invokes the validator correctly
"""

def mock_validate(sqlalchemy_uri: URL) -> None:
raise ValueError("Invalid URI")

current_app.config = {"DB_ENGINE_URI_VALIDATOR": mock_validate}

from superset.db_engine_specs.base import BaseEngineSpec

with pytest.raises(ValueError):
BaseEngineSpec.validate_database_uri(URL.create("sqlite"))


@pytest.mark.parametrize(
"original,expected",
[
Expand Down

0 comments on commit f0ac4e1

Please sign in to comment.