Skip to content

Commit

Permalink
feat(db): Adding DB_SQLA_URI_VALIDATOR (#27847)
Browse files Browse the repository at this point in the history
(cherry picked from commit 8bdf457)
  • Loading branch information
craig-rueda authored and sadpandajoe committed Apr 11, 2024
1 parent 1a04eb2 commit dd6c4fd
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 1 deletion.
12 changes: 12 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from flask_caching.backends.base import BaseCache
from pandas import Series
from pandas._libs.parsers import STR_NA_VALUES # pylint: disable=no-name-in-module
from sqlalchemy.engine.url import URL
from sqlalchemy.orm.query import Query

from superset.advanced_data_type.plugins.internet_address import internet_address
Expand Down Expand Up @@ -1204,6 +1205,17 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
DB_CONNECTION_MUTATOR = None


# A callable that is invoked for every invocation of DB Engine Specs
# which allows for custom validation of the engine URI.
# See: superset.db_engine_specs.base.BaseEngineSpec.validate_database_uri
# Example:
# def DB_ENGINE_URI_VALIDATOR(sqlalchemy_uri: URL):
# if not <some condition>:
# raise Exception("URI invalid")
#
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None


# A function that intercepts the SQL to be executed and can alter it.
# The use case is can be around adding some sort of comment header
# with information such as the username and worker node information
Expand Down
3 changes: 3 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,6 +1956,9 @@ def validate_database_uri(cls, sqlalchemy_uri: URL) -> None:
:param sqlalchemy_uri:
"""
if db_engine_uri_validator := current_app.config["DB_SQLA_URI_VALIDATOR"]:
db_engine_uri_validator(sqlalchemy_uri)

if existing_disallowed := cls.disallow_uri_query_params.get(
sqlalchemy_uri.get_driver_name(), set()
).intersection(sqlalchemy_uri.query):
Expand Down
20 changes: 20 additions & 0 deletions tests/unit_tests/db_engine_specs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytest_mock import MockFixture
from sqlalchemy import types
from sqlalchemy.dialects import sqlite
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes

from superset.superset_typing import ResultSetColumnType, SQLAColumnType
Expand Down Expand Up @@ -69,6 +70,25 @@ def test_parse_sql_multi_statement() -> None:
]


def test_validate_db_uri(mocker: MockFixture) -> None:
"""
Ensures that the `validate_database_uri` method invokes the validator correctly
"""

def mock_validate(sqlalchemy_uri: URL) -> None:
raise ValueError("Invalid URI")

mocker.patch(
"superset.db_engine_specs.base.current_app.config",
{"DB_SQLA_URI_VALIDATOR": mock_validate},
)

from superset.db_engine_specs.base import BaseEngineSpec

with pytest.raises(ValueError):
BaseEngineSpec.validate_database_uri(URL.create("sqlite"))


@pytest.mark.parametrize(
"original,expected",
[
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/extensions/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_superset_limit(mocker: MockFixture, app_context: None, table1: None) ->
"""
mocker.patch(
"superset.extensions.metadb.current_app.config",
{"SUPERSET_META_DB_LIMIT": 1},
{"DB_SQLA_URI_VALIDATOR": None, "SUPERSET_META_DB_LIMIT": 1},
)
mocker.patch("superset.extensions.metadb.security_manager")

Expand Down

0 comments on commit dd6c4fd

Please sign in to comment.