Skip to content

Commit

Permalink
refactor: Moving get_user_datasources to security manager (apache#15467)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <[email protected]>
  • Loading branch information
john-bodley and John Bodley committed Jun 30, 2021
1 parent cad5ba8 commit ffa5175
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 121 deletions.
36 changes: 0 additions & 36 deletions superset/connectors/connector_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import defaultdict
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING

from flask_babel import _
Expand Down Expand Up @@ -100,41 +99,6 @@ def get_datasource_by_id( # pylint: disable=too-many-arguments
pass
raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id))

@classmethod
def get_user_datasources(cls, session: Session) -> List["BaseDatasource"]:
    """
    Return the datasources the current user is allowed to access.

    Two sources are merged: datasources matched by the user's
    ``datasource_access`` / ``schema_access`` view menu names, and every
    datasource belonging to a database for which
    ``security_manager.can_access_database`` is truthy.

    :param session: The SQLAlchemy session used for the queries
    :returns: The de-duplicated datasources, in no particular order
    """
    from superset import security_manager

    datasource_perms = security_manager.user_view_menu_names("datasource_access")
    schema_perms = security_manager.user_view_menu_names("schema_access")

    # explicit grants: match on the datasource or schema permission
    accessible = set()
    for model in ConnectorRegistry.sources.values():
        query = session.query(model)
        explicit_grant = or_(
            model.perm.in_(datasource_perms),
            model.schema_perm.in_(schema_perms),
        )
        accessible.update(query.filter(explicit_grant).all())

    # bucket every datasource under its database so that each database is
    # only checked once below
    by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(set)
    for datasource in cls.get_all_datasources(session):
        by_database[datasource.database].add(datasource)

    # implicit grants: access to a database implies access to its datasources
    for database, members in by_database.items():
        if security_manager.can_access_database(database):
            accessible.update(members)

    return list(accessible)

@classmethod
def get_datasource_by_name( # pylint: disable=too-many-arguments
cls,
Expand Down
51 changes: 50 additions & 1 deletion superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,19 @@
"""A set of constants and methods to manage permissions and security"""
import logging
import re
from typing import Any, Callable, cast, List, Optional, Set, Tuple, TYPE_CHECKING, Union
from collections import defaultdict
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)

from flask import current_app, g
from flask_appbuilder import Model
Expand Down Expand Up @@ -419,6 +431,43 @@ def get_table_access_link( # pylint: disable=unused-argument,no-self-use

return conf.get("PERMISSION_INSTRUCTIONS_LINK")

def get_user_datasources(self) -> List["BaseDatasource"]:
    """
    Collect datasources which the user can access, either explicitly or
    implicitly.

    A datasource is included when its ``perm`` is one of the user's
    ``datasource_access`` view menu names, when its ``schema_perm`` is one of
    the user's ``schema_access`` view menu names, or when the user can access
    the database the datasource belongs to.

    :returns: The list of datasources (de-duplicated, in no particular order)
    """

    user_perms = self.user_view_menu_names("datasource_access")
    schema_perms = self.user_view_menu_names("schema_access")

    # resolve the session once so that every query below shares it
    session = self.get_session

    # collect datasources which the user has explicit permissions to
    user_datasources = set()
    for datasource_class in ConnectorRegistry.sources.values():
        user_datasources.update(
            session.query(datasource_class)
            .filter(
                or_(
                    datasource_class.perm.in_(user_perms),
                    datasource_class.schema_perm.in_(schema_perms),
                )
            )
            .all()
        )

    # group all datasources by database so that the (per-database) access
    # check below runs once per database rather than once per datasource
    all_datasources = ConnectorRegistry.get_all_datasources(session)
    datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
        set
    )
    for datasource in all_datasources:
        datasources_by_database[datasource.database].add(datasource)

    # add datasources with implicit permission (eg, database access)
    for database, datasources in datasources_by_database.items():
        if self.can_access_database(database):
            user_datasources.update(datasources)

    return list(user_datasources)

def can_access_table(self, database: "Database", table: "Table") -> bool:
"""
Return True if the user can access the SQL table, False otherwise.
Expand Down
5 changes: 2 additions & 3 deletions superset/views/chart/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _

from superset import db, is_feature_enabled
from superset.connectors.connector_registry import ConnectorRegistry
from superset import is_feature_enabled, security_manager
from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.models.slice import Slice
from superset.typing import FlaskResponse
Expand Down Expand Up @@ -65,7 +64,7 @@ def pre_delete(self, item: "SliceModelView") -> None:
def add(self) -> FlaskResponse:
datasources = [
{"value": str(d.id) + "__" + d.type, "label": repr(d)}
for d in ConnectorRegistry.get_user_datasources(db.session)
for d in security_manager.get_user_datasources()
]
payload = {
"datasources": sorted(
Expand Down
2 changes: 1 addition & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def datasources(self) -> FlaskResponse:
sorted(
[
datasource.short_data
for datasource in ConnectorRegistry.get_user_datasources(db.session)
for datasource in security_manager.get_user_datasources()
if datasource.short_data.get("name")
],
key=lambda datasource: datasource["name"],
Expand Down
79 changes: 0 additions & 79 deletions tests/access_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""Unit tests for Superset"""
import json
import unittest
from collections import namedtuple
from unittest import mock
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices

Expand Down Expand Up @@ -627,83 +626,5 @@ def test_request_access(self):
session.commit()


class TestDatasources(SupersetTestCase):
    """Exercise ``ConnectorRegistry.get_user_datasources`` with mocked access."""

    def test_get_user_datasources_admin(self):
        # database access granted: every datasource is returned
        Datasource = namedtuple("Datasource", ["database", "schema", "name"])
        rows = (
            ("database1", "schema1", "table1"),
            ("database1", "schema1", "table2"),
            ("database2", None, "table1"),
        )

        session = mock.MagicMock()
        session.query.return_value.filter.return_value.all.return_value = []

        with mock.patch("superset.security_manager") as manager, mock.patch.object(
            ConnectorRegistry,
            "get_all_datasources",
            return_value=[Datasource(*row) for row in rows],
        ):
            manager.can_access_database.return_value = True
            result = ConnectorRegistry.get_user_datasources(session)

        assert sorted(result) == [Datasource(*row) for row in rows]

    def test_get_user_datasources_gamma(self):
        # no database access and no explicit grants: nothing is returned
        Datasource = namedtuple("Datasource", ["database", "schema", "name"])
        rows = (
            ("database1", "schema1", "table1"),
            ("database1", "schema1", "table2"),
            ("database2", None, "table1"),
        )

        session = mock.MagicMock()
        session.query.return_value.filter.return_value.all.return_value = []

        with mock.patch("superset.security_manager") as manager, mock.patch.object(
            ConnectorRegistry,
            "get_all_datasources",
            return_value=[Datasource(*row) for row in rows],
        ):
            manager.can_access_database.return_value = False
            result = ConnectorRegistry.get_user_datasources(session)

        assert result == []

    def test_get_user_datasources_gamma_with_schema(self):
        # no database access, but the explicit-permission query yields the
        # two "schema1" datasources
        Datasource = namedtuple("Datasource", ["database", "schema", "name"])
        rows = (
            ("database1", "schema1", "table1"),
            ("database1", "schema1", "table2"),
            ("database2", None, "table1"),
        )
        granted = rows[:2]

        session = mock.MagicMock()
        session.query.return_value.filter.return_value.all.return_value = [
            Datasource(*row) for row in granted
        ]

        with mock.patch("superset.security_manager") as manager, mock.patch.object(
            ConnectorRegistry,
            "get_all_datasources",
            return_value=[Datasource(*row) for row in rows],
        ):
            manager.can_access_database.return_value = False
            result = ConnectorRegistry.get_user_datasources(session)

        assert sorted(result) == [Datasource(*row) for row in granted]


# Run the tests through unittest when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()
88 changes: 87 additions & 1 deletion tests/security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import inspect
import re
import unittest

from collections import namedtuple
from unittest import mock
from unittest.mock import Mock, patch
from typing import Any, Dict

Expand Down Expand Up @@ -1220,3 +1221,88 @@ def test_access_request_enabled(self):
uri = "/accessrequestsmodelview/list/"
rv = self.client.get(uri)
self.assertLess(rv.status_code, 400)


class TestDatasources(SupersetTestCase):
    """Exercise ``security_manager.get_user_datasources`` with mocked access."""

    @patch("superset.security.manager.g")
    @patch("superset.security.SupersetSecurityManager.can_access_database")
    @patch("superset.security.SupersetSecurityManager.get_session")
    def test_get_user_datasources_admin(
        self, mock_get_session, mock_can_access_database, mock_g
    ):
        # database access granted: every datasource is returned
        Datasource = namedtuple("Datasource", ["database", "schema", "name"])
        rows = (
            ("database1", "schema1", "table1"),
            ("database1", "schema1", "table2"),
            ("database2", None, "table1"),
        )

        mock_g.user = security_manager.find_user("admin")
        mock_can_access_database.return_value = True
        mock_get_session.query.return_value.filter.return_value.all.return_value = []

        with mock.patch.object(
            ConnectorRegistry,
            "get_all_datasources",
            return_value=[Datasource(*row) for row in rows],
        ):
            result = security_manager.get_user_datasources()

        assert sorted(result) == [Datasource(*row) for row in rows]

    @patch("superset.security.manager.g")
    @patch("superset.security.SupersetSecurityManager.can_access_database")
    @patch("superset.security.SupersetSecurityManager.get_session")
    def test_get_user_datasources_gamma(
        self, mock_get_session, mock_can_access_database, mock_g
    ):
        # no database access and no explicit grants: nothing is returned
        Datasource = namedtuple("Datasource", ["database", "schema", "name"])
        rows = (
            ("database1", "schema1", "table1"),
            ("database1", "schema1", "table2"),
            ("database2", None, "table1"),
        )

        mock_g.user = security_manager.find_user("gamma")
        mock_can_access_database.return_value = False
        mock_get_session.query.return_value.filter.return_value.all.return_value = []

        with mock.patch.object(
            ConnectorRegistry,
            "get_all_datasources",
            return_value=[Datasource(*row) for row in rows],
        ):
            result = security_manager.get_user_datasources()

        assert result == []

    @patch("superset.security.manager.g")
    @patch("superset.security.SupersetSecurityManager.can_access_database")
    @patch("superset.security.SupersetSecurityManager.get_session")
    def test_get_user_datasources_gamma_with_schema(
        self, mock_get_session, mock_can_access_database, mock_g
    ):
        # no database access, but the explicit-permission query yields the
        # two "schema1" datasources
        Datasource = namedtuple("Datasource", ["database", "schema", "name"])
        rows = (
            ("database1", "schema1", "table1"),
            ("database1", "schema1", "table2"),
            ("database2", None, "table1"),
        )
        granted = rows[:2]

        mock_g.user = security_manager.find_user("gamma")
        mock_can_access_database.return_value = False
        mock_get_session.query.return_value.filter.return_value.all.return_value = [
            Datasource(*row) for row in granted
        ]

        with mock.patch.object(
            ConnectorRegistry,
            "get_all_datasources",
            return_value=[Datasource(*row) for row in rows],
        ):
            result = security_manager.get_user_datasources()

        assert sorted(result) == [Datasource(*row) for row in granted]

0 comments on commit ffa5175

Please sign in to comment.