Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(db_engine): Implement user impersonation support for StarRocks #28110

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion superset/db_engine_specs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ The table below (generated via `python superset/db_engine_specs/lib.py`) summari
| Masks/unmasks encrypted_extra | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | True | True | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False |
| Has column type mappings | False | False | False | False | False | True | False | False | False | False | True | False | True | True | True | True | True | True | False | False | True | False | False | False | False | False | False | False | False | False | False | False | False | False | True | True | True | False | False | False | True | True | False | False | False | False | False | True | False | True | False | True |
| Returns a list of function names | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | True | True | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | True | True | False | False | False | True | False | True |
| Supports user impersonation | False | False | False | True | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False |
| Supports user impersonation | False | False | False | True | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | True | False | True | False | False |
| Support file upload | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | False | False | True | True | True | True | True | True | True | True | True | True | True | True | True | False | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True | True |
| Returns extra table metadata | False | False | False | False | False | True | False | False | False | False | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | True | True | False | False | False | False | False | False | False | False | False | True | False | False | False | False | False | False | False | False | True | False | False |
| Maps driver exceptions to Superset exceptions | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False | False |
Expand Down
1 change: 1 addition & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,6 +1408,7 @@ def adjust_engine_params( # pylint: disable=unused-argument
@classmethod
def get_prequeries(
cls,
database: Database, # pylint: disable=unused-argument
catalog: str | None = None, # pylint: disable=unused-argument
schema: str | None = None, # pylint: disable=unused-argument
) -> list[str]:
Expand Down
1 change: 1 addition & 0 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ def get_default_catalog(
@classmethod
def get_prequeries(
cls,
database: Database,
catalog: str | None = None,
schema: str | None = None,
) -> list[str]:
Expand Down
2 changes: 2 additions & 0 deletions superset/db_engine_specs/db2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.models.core import Database
from superset.sql_parse import Table

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -93,6 +94,7 @@ def get_table_comment(
@classmethod
def get_prequeries(
cls,
database: Database,
catalog: Union[str, None] = None,
schema: Union[str, None] = None,
) -> list[str]:
Expand Down
1 change: 1 addition & 0 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def get_default_catalog(cls, database: Database) -> str | None:
@classmethod
def get_prequeries(
cls,
database: Database,
catalog: str | None = None,
schema: str | None = None,
) -> list[str]:
Expand Down
50 changes: 49 additions & 1 deletion superset/db_engine_specs/starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import re
from re import Pattern
from typing import Any, Optional
from typing import Any, Optional, Union
from urllib import parse

from flask_babel import gettext as __
Expand All @@ -28,6 +28,7 @@

from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.errors import SupersetErrorType
from superset.models.core import Database
from superset.utils.core import GenericDataType

# Regular expressions to catch custom errors
Expand Down Expand Up @@ -201,3 +202,50 @@ def get_schema_from_engine_params(
return None

return parse.unquote(database.split(".")[1])

@classmethod
def get_url_for_impersonation(
    cls,
    url: URL,
    impersonate_user: bool,
    username: Union[str, None] = None,
    access_token: Union[str, None] = None,
) -> URL:
    """
    Return the SQLAlchemy URL untouched.

    Nothing is rewritten in the URL here: the object passed in is handed back
    as is, and the remaining arguments are accepted only to keep the
    signature and are ignored. Switching to the effective user is left to the
    pre-session queries returned by ``get_prequeries``.

    :param url: SQLAlchemy URL object
    :param impersonate_user: Flag indicating if impersonation is enabled (unused)
    :param username: Effective username (unused)
    :param access_token: Personal access token (unused)
    :return: The same ``url`` object that was passed in
    """
    # Leave URL unchanged; impersonation is done with a pre-query instead.
    return url

@classmethod
def get_prequeries(
    cls,
    database: Database,
    catalog: Union[str, None] = None,
    schema: Union[str, None] = None,
) -> list[str]:
    """
    Return pre-session queries used to impersonate the effective user.

    When ``database.impersonate_user`` is set and
    ``database.get_effective_user`` yields a non-empty username, a single
    ``EXECUTE AS ... WITH NO REVERT`` statement is returned so it can be run at
    the beginning of the session. In every other case no pre-queries are
    returned. ``catalog`` and ``schema`` are accepted to keep the signature and
    are not used.

    :param database: Database whose impersonation settings are consulted
    :param catalog: Selected catalog (unused)
    :param schema: Selected schema (unused)
    :return: List holding the impersonation statement, or an empty list
    """
    if not database.impersonate_user:
        return []

    username = database.get_effective_user(database.url_object)
    if not username:
        return []

    # The username is spliced into the statement text, so neutralise the
    # characters that could end the double-quoted literal early.
    # NOTE(review): assumes MySQL-style escaping (backslash escapes and doubled
    # quotes) inside double-quoted text -- confirm against the StarRocks lexer.
    quoted = str(username).replace("\\", "\\\\").replace('"', '""')
    return [f'EXECUTE AS "{quoted}" WITH NO REVERT;']
1 change: 1 addition & 0 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def get_raw_connection(
# pre-session queries are used to set the selected schema and, in the
# future, the selected catalog
for prequery in self.db_engine_spec.get_prequeries(
database=self,
catalog=catalog,
schema=schema,
):
Expand Down
14 changes: 9 additions & 5 deletions tests/unit_tests/db_engine_specs/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,20 +247,24 @@ def test_convert_dttm(
assert_convert_dttm(spec, target_type, expected_result, dttm)


def test_get_prequeries(mocker: MockerFixture) -> None:
    """
    Test the ``get_prequeries`` method.

    A mocked database is passed on every call because ``get_prequeries`` takes
    the database as its first argument.
    """
    from superset.db_engine_specs.databricks import DatabricksNativeEngineSpec

    database = mocker.MagicMock()

    # No catalog or schema selected: nothing to run.
    assert DatabricksNativeEngineSpec.get_prequeries(database) == []
    assert DatabricksNativeEngineSpec.get_prequeries(database, schema="test") == [
        "USE SCHEMA test",
    ]
    assert DatabricksNativeEngineSpec.get_prequeries(database, catalog="test") == [
        "USE CATALOG test",
    ]
    assert DatabricksNativeEngineSpec.get_prequeries(
        database, catalog="foo", schema="bar"
    ) == [
        "USE CATALOG foo",
        "USE SCHEMA bar",
    ]
8 changes: 5 additions & 3 deletions tests/unit_tests/db_engine_specs/test_db2.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ def test_get_table_comment_empty(mocker: MockerFixture):
)


def test_get_prequeries(mocker: MockerFixture) -> None:
    """
    Test the ``get_prequeries`` method.

    A mocked database is passed on every call because ``get_prequeries`` takes
    the database as its first argument.
    """
    from superset.db_engine_specs.db2 import Db2EngineSpec

    database = mocker.MagicMock()

    # No schema selected: nothing to run.
    assert Db2EngineSpec.get_prequeries(database) == []
    assert Db2EngineSpec.get_prequeries(database, schema="my_schema") == [
        'set current_schema "my_schema"'
    ]
8 changes: 5 additions & 3 deletions tests/unit_tests/db_engine_specs/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,16 @@ def test_get_schema_from_engine_params() -> None:
)


def test_get_prequeries(mocker: MockerFixture) -> None:
    """
    Test the ``get_prequeries`` method.

    A mocked database is passed on every call because ``get_prequeries`` takes
    the database as its first argument.
    """
    from superset.db_engine_specs.postgres import PostgresEngineSpec

    database = mocker.MagicMock()

    # No schema selected: nothing to run.
    assert PostgresEngineSpec.get_prequeries(database) == []
    assert PostgresEngineSpec.get_prequeries(database, schema="test") == [
        'set search_path = "test"'
    ]

Expand Down
45 changes: 45 additions & 0 deletions tests/unit_tests/db_engine_specs/test_starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, Optional

import pytest
from pytest_mock import MockerFixture
from sqlalchemy import JSON, types
from sqlalchemy.engine.url import make_url

Expand Down Expand Up @@ -124,3 +125,47 @@ def test_get_schema_from_engine_params() -> None:
)
is None
)


def test_impersonation_username(mocker: MockerFixture) -> None:
    """
    Check impersonation for an effective user: `get_url_for_impersonation` must
    hand the URL back untouched, while `get_prequeries` must emit the statement
    that switches to that user.
    """
    from superset.db_engine_specs.starrocks import StarRocksEngineSpec

    uri = "starrocks://service_user@localhost:9030/hive.default"

    fake_db = mocker.MagicMock()
    fake_db.get_effective_user.return_value = "alice"
    fake_db.impersonate_user = True

    # The URL keeps the service account; no username is swapped in.
    returned_url = StarRocksEngineSpec.get_url_for_impersonation(
        url=make_url(uri),
        impersonate_user=True,
        username="alice",
        access_token=None,
    )
    assert returned_url == make_url(uri)

    prequeries = StarRocksEngineSpec.get_prequeries(fake_db)
    assert prequeries == ['EXECUTE AS "alice" WITH NO REVERT;']


def test_impersonation_disabled(mocker: MockerFixture) -> None:
    """
    Check that nothing is impersonated when the feature is switched off:
    `get_url_for_impersonation` returns the URL as given and `get_prequeries`
    returns no statements, even though an effective user is available.
    """
    from superset.db_engine_specs.starrocks import StarRocksEngineSpec

    uri = "starrocks://service_user@localhost:9030/hive.default"

    fake_db = mocker.MagicMock()
    fake_db.get_effective_user.return_value = "alice"
    fake_db.impersonate_user = False

    returned_url = StarRocksEngineSpec.get_url_for_impersonation(
        url=make_url(uri),
        impersonate_user=False,
        username="alice",
        access_token=None,
    )
    assert returned_url == make_url(uri)

    prequeries = StarRocksEngineSpec.get_prequeries(fake_db)
    assert prequeries == []
Loading