diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py index 7de19b3484df5..86081cd1b539b 100644 --- a/superset/connectors/connector_registry.py +++ b/superset/connectors/connector_registry.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from collections import defaultdict from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING from flask_babel import _ @@ -100,41 +99,6 @@ def get_datasource_by_id( # pylint: disable=too-many-arguments pass raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id)) - @classmethod - def get_user_datasources(cls, session: Session) -> List["BaseDatasource"]: - from superset import security_manager - - # collect datasources which the user has explicit permissions to - user_perms = security_manager.user_view_menu_names("datasource_access") - schema_perms = security_manager.user_view_menu_names("schema_access") - user_datasources = set() - for datasource_class in ConnectorRegistry.sources.values(): - user_datasources.update( - session.query(datasource_class) - .filter( - or_( - datasource_class.perm.in_(user_perms), - datasource_class.schema_perm.in_(schema_perms), - ) - ) - .all() - ) - - # group all datasources by database - all_datasources = cls.get_all_datasources(session) - datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict( - set - ) - for datasource in all_datasources: - datasources_by_database[datasource.database].add(datasource) - - # add datasources with implicit permission (eg, database access) - for database, datasources in datasources_by_database.items(): - if security_manager.can_access_database(database): - user_datasources.update(datasources) - - return list(user_datasources) - @classmethod def get_datasource_by_name( # pylint: disable=too-many-arguments cls, diff --git a/superset/security/manager.py b/superset/security/manager.py index a7a84648fb978..9dc4bf9efc81b 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -18,7 +18,19 @@ """A set of constants and methods to manage permissions and security""" import logging import re -from typing import Any, Callable, cast, List, Optional, Set, Tuple, TYPE_CHECKING, Union +from collections import defaultdict +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) from flask import current_app, g from flask_appbuilder import Model @@ -419,6 +431,43 @@ def get_table_access_link( # pylint: disable=unused-argument,no-self-use return conf.get("PERMISSION_INSTRUCTIONS_LINK") + def get_user_datasources(self) -> List["BaseDatasource"]: + """ + Collect datasources which the user has explicit permissions to. + + :returns: The list of datasources + """ + + user_perms = self.user_view_menu_names("datasource_access") + schema_perms = self.user_view_menu_names("schema_access") + user_datasources = set() + for datasource_class in ConnectorRegistry.sources.values(): + user_datasources.update( + self.get_session.query(datasource_class) + .filter( + or_( + datasource_class.perm.in_(user_perms), + datasource_class.schema_perm.in_(schema_perms), + ) + ) + .all() + ) + + # group all datasources by database + all_datasources = ConnectorRegistry.get_all_datasources(self.get_session) + datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict( + set + ) + for datasource in all_datasources: + datasources_by_database[datasource.database].add(datasource) + + # add datasources with implicit permission (eg, database access) + for database, datasources in datasources_by_database.items(): + if self.can_access_database(database): + user_datasources.update(datasources) + + return list(user_datasources) + def can_access_table(self, database: "Database", table: "Table") -> bool: """ Return True if the user can access the SQL table, False otherwise. diff --git a/superset/views/chart/views.py b/superset/views/chart/views.py index 45dbd2308caa1..68c19cc2cb514 100644 --- a/superset/views/chart/views.py +++ b/superset/views/chart/views.py @@ -21,8 +21,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import lazy_gettext as _ -from superset import db, is_feature_enabled -from superset.connectors.connector_registry import ConnectorRegistry +from superset import is_feature_enabled, security_manager from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.models.slice import Slice from superset.typing import FlaskResponse @@ -65,7 +64,7 @@ def pre_delete(self, item: "SliceModelView") -> None: def add(self) -> FlaskResponse: datasources = [ {"value": str(d.id) + "__" + d.type, "label": repr(d)} - for d in ConnectorRegistry.get_user_datasources(db.session) + for d in security_manager.get_user_datasources() ] payload = { "datasources": sorted( diff --git a/superset/views/core.py b/superset/views/core.py index b800728dc5d72..67ab3b13f1d67 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -186,7 +186,7 @@ def datasources(self) -> FlaskResponse: sorted( [ datasource.short_data - for datasource in ConnectorRegistry.get_user_datasources(db.session) + for datasource in security_manager.get_user_datasources() if datasource.short_data.get("name") ], key=lambda datasource: datasource["name"], diff --git a/tests/access_tests.py b/tests/access_tests.py index 795ca98c30cfd..d3cc55a1c9df5 100644 --- a/tests/access_tests.py +++ b/tests/access_tests.py @@ -18,7 +18,6 @@ """Unit tests for Superset""" import json import unittest -from collections import namedtuple from unittest import mock from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices @@ -627,83 +626,5 @@ def test_request_access(self): session.commit() -class TestDatasources(SupersetTestCase): - def test_get_user_datasources_admin(self): - Datasource = namedtuple("Datasource", ["database", "schema", "name"]) - - mock_session = mock.MagicMock() - mock_session.query.return_value.filter.return_value.all.return_value = [] - - with mock.patch("superset.security_manager") as mock_security_manager: - mock_security_manager.can_access_database.return_value = True - - with mock.patch.object( - ConnectorRegistry, "get_all_datasources" - ) as mock_get_all_datasources: - mock_get_all_datasources.return_value = [ - Datasource("database1", "schema1", "table1"), - Datasource("database1", "schema1", "table2"), - Datasource("database2", None, "table1"), - ] - - datasources = ConnectorRegistry.get_user_datasources(mock_session) - - assert sorted(datasources) == [ - Datasource("database1", "schema1", "table1"), - Datasource("database1", "schema1", "table2"), - Datasource("database2", None, "table1"), - ] - - def test_get_user_datasources_gamma(self): - Datasource = namedtuple("Datasource", ["database", "schema", "name"]) - - mock_session = mock.MagicMock() - mock_session.query.return_value.filter.return_value.all.return_value = [] - - with mock.patch("superset.security_manager") as mock_security_manager: - mock_security_manager.can_access_database.return_value = False - - with mock.patch.object( - ConnectorRegistry, "get_all_datasources" - ) as mock_get_all_datasources: - mock_get_all_datasources.return_value = [ - Datasource("database1", "schema1", "table1"), - Datasource("database1", "schema1", "table2"), - Datasource("database2", None, "table1"), - ] - - datasources = ConnectorRegistry.get_user_datasources(mock_session) - - assert datasources == [] - - def test_get_user_datasources_gamma_with_schema(self): - Datasource = namedtuple("Datasource", ["database", "schema", "name"]) - - mock_session = mock.MagicMock() - mock_session.query.return_value.filter.return_value.all.return_value = [ - Datasource("database1", "schema1", "table1"), - Datasource("database1", "schema1", "table2"), - ] - - with mock.patch("superset.security_manager") as mock_security_manager: - mock_security_manager.can_access_database.return_value = False - - with mock.patch.object( - ConnectorRegistry, "get_all_datasources" - ) as mock_get_all_datasources: - mock_get_all_datasources.return_value = [ - Datasource("database1", "schema1", "table1"), - Datasource("database1", "schema1", "table2"), - Datasource("database2", None, "table1"), - ] - - datasources = ConnectorRegistry.get_user_datasources(mock_session) - - assert sorted(datasources) == [ - Datasource("database1", "schema1", "table1"), - Datasource("database1", "schema1", "table2"), - ] - - if __name__ == "__main__": unittest.main() diff --git a/tests/security_tests.py b/tests/security_tests.py index ae636a75ac15e..a59118183a462 100644 --- a/tests/security_tests.py +++ b/tests/security_tests.py @@ -18,7 +18,8 @@ import inspect import re import unittest - +from collections import namedtuple +from unittest import mock from unittest.mock import Mock, patch from typing import Any, Dict @@ -1220,3 +1221,88 @@ def test_access_request_enabled(self): uri = "/accessrequestsmodelview/list/" rv = self.client.get(uri) self.assertLess(rv.status_code, 400) + + +class TestDatasources(SupersetTestCase): + @patch("superset.security.manager.g") + @patch("superset.security.SupersetSecurityManager.can_access_database") + @patch("superset.security.SupersetSecurityManager.get_session") + def test_get_user_datasources_admin( + self, mock_get_session, mock_can_access_database, mock_g + ): + Datasource = namedtuple("Datasource", ["database", "schema", "name"]) + mock_g.user = security_manager.find_user("admin") + mock_can_access_database.return_value = True + mock_get_session.query.return_value.filter.return_value.all.return_value = [] + + with mock.patch.object( + ConnectorRegistry, "get_all_datasources" + ) as mock_get_all_datasources: + mock_get_all_datasources.return_value = [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + Datasource("database2", None, "table1"), + ] + + datasources = security_manager.get_user_datasources() + + assert sorted(datasources) == [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + Datasource("database2", None, "table1"), + ] + + @patch("superset.security.manager.g") + @patch("superset.security.SupersetSecurityManager.can_access_database") + @patch("superset.security.SupersetSecurityManager.get_session") + def test_get_user_datasources_gamma( + self, mock_get_session, mock_can_access_database, mock_g + ): + Datasource = namedtuple("Datasource", ["database", "schema", "name"]) + mock_g.user = security_manager.find_user("gamma") + mock_can_access_database.return_value = False + mock_get_session.query.return_value.filter.return_value.all.return_value = [] + + with mock.patch.object( + ConnectorRegistry, "get_all_datasources" + ) as mock_get_all_datasources: + mock_get_all_datasources.return_value = [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + Datasource("database2", None, "table1"), + ] + + datasources = security_manager.get_user_datasources() + + assert datasources == [] + + @patch("superset.security.manager.g") + @patch("superset.security.SupersetSecurityManager.can_access_database") + @patch("superset.security.SupersetSecurityManager.get_session") + def test_get_user_datasources_gamma_with_schema( + self, mock_get_session, mock_can_access_database, mock_g + ): + Datasource = namedtuple("Datasource", ["database", "schema", "name"]) + mock_g.user = security_manager.find_user("gamma") + mock_can_access_database.return_value = False + + mock_get_session.query.return_value.filter.return_value.all.return_value = [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + ] + + with mock.patch.object( + ConnectorRegistry, "get_all_datasources" + ) as mock_get_all_datasources: + mock_get_all_datasources.return_value = [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + Datasource("database2", None, "table1"), + ] + + datasources = security_manager.get_user_datasources() + + assert sorted(datasources) == [ + Datasource("database1", "schema1", "table1"), + Datasource("database1", "schema1", "table2"), + ]