From 9377227e06f818c6e4fcf1715d38e27978ec64de Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 5 Apr 2024 12:17:39 -0400 Subject: [PATCH] chore(OAuth2): refactor for custom OAuth2 clients (#27880) --- superset/common/query_object.py | 2 +- superset/config.py | 20 +- superset/connectors/sqla/utils.py | 2 +- superset/databases/api.py | 6 +- superset/db_engine_specs/README.md | 62 +++--- superset/db_engine_specs/base.py | 176 +++++++++++++----- superset/db_engine_specs/gsheets.py | 91 +-------- superset/db_engine_specs/hive.py | 3 +- superset/db_engine_specs/impala.py | 14 +- superset/db_engine_specs/presto.py | 9 +- superset/db_engine_specs/trino.py | 4 +- superset/models/core.py | 39 +++- superset/sql_validators/presto_db.py | 2 +- superset/superset_typing.py | 49 +++++ superset/utils/oauth2.py | 12 +- tests/unit_tests/databases/api_test.py | 12 +- .../db_engine_specs/test_clickhouse.py | 3 +- .../db_engine_specs/test_databend.py | 3 +- .../db_engine_specs/test_elasticsearch.py | 3 +- .../db_engine_specs/test_gsheets.py | 150 +++++++-------- .../unit_tests/extensions/test_sqlalchemy.py | 6 +- tests/unit_tests/utils/oauth2_tests.py | 8 +- 22 files changed, 382 insertions(+), 294 deletions(-) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index e4a305316e7b..a16166134408 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -108,7 +108,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes time_range: str | None to_dttm: datetime | None - def __init__( # pylint: disable=too-many-arguments,too-many-locals + def __init__( # pylint: disable=too-many-locals, too-many-arguments self, *, annotation_layers: list[dict[str, Any]] | None = None, diff --git a/superset/config.py b/superset/config.py index 19bb7224ae1c..1b06f96db8e7 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1409,12 +1409,20 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # Details needed for databases that allows user to authenticate using personal # OAuth2 tokens. See https://github.com/apache/superset/issues/20300 for more -# information -DATABASE_OAUTH2_CREDENTIALS: dict[str, dict[str, Any]] = { +# information. The scope and URIs are optional. +DATABASE_OAUTH2_CLIENTS: dict[str, dict[str, Any]] = { # "Google Sheets": { - # "CLIENT_ID": "XXX.apps.googleusercontent.com", - # "CLIENT_SECRET": "GOCSPX-YYY", - # "BASEURL": "https://accounts.google.com/o/oauth2/v2/auth", + # "id": "XXX.apps.googleusercontent.com", + # "secret": "GOCSPX-YYY", + # "scope": " ".join( + # [ + # "https://www.googleapis.com/auth/drive.readonly", + # "https://www.googleapis.com/auth/spreadsheets", + # "https://spreadsheets.google.com/feeds", + # ] + # ), + # "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth", + # "token_request_uri": "https://oauth2.googleapis.com/token", # }, } # OAuth2 state is encoded in a JWT using the alogorithm below. @@ -1425,6 +1433,8 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # applications. In that case, the proxy can forward the request to the correct instance # by looking at the `default_redirect_uri` attribute in the OAuth2 state object. # DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/" +# Timeout when fetching access and refresh tokens. +DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30) # Enable/disable CSP warning CONTENT_SECURITY_POLICY_WARNING = True diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index d0922e40f3c4..4bc11aee42d8 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -145,7 +145,7 @@ def get_columns_description( cursor = conn.cursor() query = database.apply_limit_to_sql(query, limit=1) cursor.execute(query) - db_engine_spec.execute(cursor, query, database.id) + db_engine_spec.execute(cursor, query, database) result = db_engine_spec.fetch_data(cursor, limit=1) result_set = SupersetResultSet(result, cursor.description, db_engine_spec) return result_set.columns diff --git a/superset/databases/api.py b/superset/databases/api.py index ce83f0de59b0..5eb9de90d44a 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -1115,9 +1115,13 @@ def oauth2(self) -> FlaskResponse: if database is None: return self.response_404() + oauth2_config = database.get_oauth2_config() + if oauth2_config is None: + raise OAuth2Error("No configuration found for OAuth2") + token_response = database.db_engine_spec.get_oauth2_token( + oauth2_config, parameters["code"], - state, ) # delete old tokens diff --git a/superset/db_engine_specs/README.md b/superset/db_engine_specs/README.md index 0be1f2914667..f158a3a41bb4 100644 --- a/superset/db_engine_specs/README.md +++ b/superset/db_engine_specs/README.md @@ -547,65 +547,53 @@ Alternatively, it's also possible to impersonate users by implementing the `upda Support for authenticating to a database using personal OAuth2 access tokens was introduced in [SIP-85](https://github.com/apache/superset/issues/20300). The Google Sheets DB engine spec is the reference implementation. -To add support for OAuth2 to a DB engine spec, the following attribute and methods are needed: +Note that this API is still experimental and evolving quickly, subject to breaking changes. Currently, to add support for OAuth2 to a DB engine spec, the following attributes are needed: ```python class BaseEngineSpec: + supports_oauth2 = True oauth2_exception = OAuth2RedirectError - @staticmethod - def is_oauth2_enabled() -> bool: - return False - - @staticmethod - def get_oauth2_authorization_uri(state: OAuth2State) -> str: - raise NotImplementedError() - - @staticmethod - def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse: - raise NotImplementedError() - - @staticmethod - def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse: - raise NotImplementedError() + oauth2_scope = " ".join([ + "https://example.org/scope1", + "https://example.org/scope2", + ]) + oauth2_authorization_request_uri = "https://example.org/authorize" + oauth2_token_request_uri = "https://example.org/token" ``` The `oauth2_exception` is an exception that is raised by `cursor.execute` when OAuth2 is needed. This will start the OAuth2 dance when `BaseEngineSpec.execute` is called, by returning the custom error `OAUTH2_REDIRECT` to the frontend. If the database driver doesn't have a specific exception, it might be necessary to overload the `execute` method in the DB engine spec, so that the `BaseEngineSpec.start_oauth2_dance` method gets called whenever OAuth2 is needed. -The first method, `is_oauth2_enabled`, is used to inform if the database supports OAuth2. This can be dynamic; for example, the Google Sheets DB engine spec checks if the Superset configuration has the necessary section: - -```python -from flask import current_app - +The DB engine should implement logic in either `get_url_for_impersonation` or `update_impersonation_config` to update the connection with the personal access token. See the Google Sheets DB engine spec for a reference implementation. -class GSheetsEngineSpec(ShillelaghEngineSpec): - @staticmethod - def is_oauth2_enabled() -> bool: - return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CREDENTIALS"] -``` - -Where the configuration for OAuth2 would look like this: +Currently OAuth2 needs to be configured at the DB engine spec level, ie, with one client for each DB engien spec. The configuration lives in `superset_config.py`: ```python # superset_config.py -DATABASE_OAUTH2_CREDENTIALS = { +DATABASE_OAUTH2_CLIENTS = { "Google Sheets": { - "CLIENT_ID": "XXX.apps.googleusercontent.com", - "CLIENT_SECRET": "GOCSPX-YYY", + "id": "XXX.apps.googleusercontent.com", + "secret": "GOCSPX-YYY", + "scope": " ".join( + [ + "https://www.googleapis.com/auth/drive.readonly", + "https://www.googleapis.com/auth/spreadsheets", + "https://spreadsheets.google.com/feeds", + ], + ), + "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth", + "token_request_uri": "https://oauth2.googleapis.com/token", }, } DATABASE_OAUTH2_JWT_ALGORITHM = "HS256" DATABASE_OAUTH2_REDIRECT_URI = "http://localhost:8088/api/v1/database/oauth2/" +DATABASE_OAUTH2_TIMEOUT = timedelta(seconds=30) ``` -The second method, `get_oauth2_authorization_uri`, is responsible for building the URL where the user is sent to initiate OAuth2. This method receives a `state`. The state is an encoded JWT that is passed to the OAuth2 provider, and is received unmodified when the user is redirected back to Superset. The default state contains the user ID and the database ID, so that Superset can know where to store the received OAuth2 tokens. - -Additionally, the state also contains a `tab_id`, which is a random UUID4 used as a shared secret for communication between browser tabs. When OAuth2 starts, Superset will open a new browser tab, where the user will grant permissions to Superset. When authentication is complete and successful this opened tab will send a message to the original tab, so that the original query can be re-run. The `tab_id` is sent by the opened tab and verified by the original tab to prevent malicious messages from other sites. As an additional security measure the origin of the message should match the OAuth2 redirect URL. - -State also contains a `defaul_redirect_uri`, which is the enpoint in Supeset that receives the tokens from the OAuth2 provider (`/api/v1/database/oauth2/`). The redirect URL can be overwritten in the config file via the `DATABASE_OAUTH2_REDIRECT_URI` parameter. This might be useful where you have multiple Superset instances. Since the OAuth2 provider requires the redirect URL to be registered a priori, it might be easier (or needed) to register a single URL for a proxy service; the proxy service can then inspect the JWT and redirect the request to `defaul_redirect_uri`. +When configuring a client only the ID and secret are required; the DB engine spec should have default values for the scope and endpoints. The `DATABASE_OAUTH2_REDIRECT_URI` attribute is optional, and defaults to `/api/v1/databases/oauth2/` in Superset. -Finally, `get_oauth2_token` and `get_oauth2_fresh_token` are used to actually retrieve a token and refresh an expired token, respectively. +In the future we plan to support adding custom clients via the Superset UI, and being able to manually assign clients to specific databases. ### File upload diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 12797fc6a31e..bcb4035c9c67 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -33,9 +33,11 @@ TypedDict, Union, ) +from urllib.parse import urlencode, urljoin from uuid import uuid4 import pandas as pd +import requests import sqlparse from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin @@ -62,11 +64,18 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.sql_parse import ParsedQuery, SQLScript, Table -from superset.superset_typing import ResultSetColumnType, SQLAColumnType +from superset.superset_typing import ( + OAuth2ClientConfig, + OAuth2State, + OAuth2TokenResponse, + ResultSetColumnType, + SQLAColumnType, +) from superset.utils import core as utils from superset.utils.core import ColumnSpec, GenericDataType from superset.utils.hashing import md5_sha_from_str from superset.utils.network import is_hostname_valid, is_port_open +from superset.utils.oauth2 import encode_oauth2_state if TYPE_CHECKING: from superset.connectors.sqla.models import TableColumn @@ -173,31 +182,6 @@ class MetricType(TypedDict, total=False): extra: str | None -class OAuth2TokenResponse(TypedDict, total=False): - """ - Type for an OAuth2 response when exchanging or refreshing tokens. - """ - - access_token: str - expires_in: int - scope: str - token_type: str - - # only present when exchanging code for refresh/access tokens - refresh_token: str - - -class OAuth2State(TypedDict): - """ - Type for the state passed during OAuth2. - """ - - database_id: int - user_id: int - default_redirect_uri: str - tab_id: str - - class BaseEngineSpec: # pylint: disable=too-many-public-methods """Abstract class for database engine specific configurations @@ -425,15 +409,25 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # Can the catalog be changed on a per-query basis? supports_dynamic_catalog = False + # Does the engine supports OAuth 2.0? This requires logic to be added to one of the + # the user impersonation methods to handle personal tokens. + supports_oauth2 = False + oauth2_scope = "" + oauth2_authorization_request_uri = "" # pylint: disable=invalid-name + oauth2_token_request_uri = "" + # Driver-specific exception that should be mapped to OAuth2RedirectError oauth2_exception = OAuth2RedirectError - @staticmethod - def is_oauth2_enabled() -> bool: - return False + @classmethod + def is_oauth2_enabled(cls) -> bool: + return ( + cls.supports_oauth2 + and cls.engine_name in current_app.config["DATABASE_OAUTH2_CLIENTS"] + ) @classmethod - def start_oauth2_dance(cls, database_id: int) -> None: + def start_oauth2_dance(cls, database: Database) -> None: """ Start the OAuth2 dance. @@ -446,10 +440,6 @@ def start_oauth2_dance(cls, database_id: int) -> None: """ tab_id = str(uuid4()) default_redirect_uri = url_for("DatabaseRestApi.oauth2", _external=True) - redirect_uri = current_app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - default_redirect_uri, - ) # The state is passed to the OAuth2 provider, and sent back to Superset after # the user authorizes the access. The redirect endpoint in Superset can then @@ -457,7 +447,7 @@ def start_oauth2_dance(cls, database_id: int) -> None: # belongs to. state: OAuth2State = { # Database ID and user ID are the primary key associated with the token. - "database_id": database_id, + "database_id": database.id, "user_id": g.user.id, # In multi-instance deployments there might be a single proxy handling # redirects, with a custom `DATABASE_OAUTH2_REDIRECT_URI`. Since the OAuth2 @@ -473,30 +463,114 @@ def start_oauth2_dance(cls, database_id: int) -> None: # message. "tab_id": tab_id, } - oauth_url = cls.get_oauth2_authorization_uri(state) + oauth2_config = database.get_oauth2_config() + if oauth2_config is None: + raise OAuth2Error("No configuration found for OAuth2") - raise OAuth2RedirectError(oauth_url, tab_id, redirect_uri) + oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state) - @staticmethod - def get_oauth2_authorization_uri(state: OAuth2State) -> str: + raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri) + + @classmethod + def get_oauth2_config(cls) -> OAuth2ClientConfig | None: + """ + Build the DB engine spec level OAuth2 client config. + """ + oauth2_config = current_app.config["DATABASE_OAUTH2_CLIENTS"] + if cls.engine_name not in oauth2_config: + return None + + db_engine_spec_config = oauth2_config[cls.engine_name] + redirect_uri = current_app.config.get( + "DATABASE_OAUTH2_REDIRECT_URI", + url_for("DatabaseRestApi.oauth2", _external=True), + ) + + config: OAuth2ClientConfig = { + "id": db_engine_spec_config["id"], + "secret": db_engine_spec_config["secret"], + "scope": db_engine_spec_config.get("scope") or cls.oauth2_scope, + "redirect_uri": redirect_uri, + "authorization_request_uri": db_engine_spec_config.get( + "authorization_request_uri", + cls.oauth2_authorization_request_uri, + ), + "token_request_uri": db_engine_spec_config.get( + "token_request_uri", + cls.oauth2_token_request_uri, + ), + } + + return config + + @classmethod + def get_oauth2_authorization_uri( + cls, + config: OAuth2ClientConfig, + state: OAuth2State, + ) -> str: """ Return URI for initial OAuth2 request. """ - raise OAuth2Error("Subclasses must implement `get_oauth2_authorization_uri`") + uri = config["authorization_request_uri"] + params = { + "scope": config["scope"], + "access_type": "offline", + "include_granted_scopes": "false", + "response_type": "code", + "state": encode_oauth2_state(state), + "redirect_uri": config["redirect_uri"], + "client_id": config["id"], + "prompt": "consent", + } + return urljoin(uri, "?" + urlencode(params)) - @staticmethod - def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse: + @classmethod + def get_oauth2_token( + cls, + config: OAuth2ClientConfig, + code: str, + ) -> OAuth2TokenResponse: """ Exchange authorization code for refresh/access tokens. """ - raise OAuth2Error("Subclasses must implement `get_oauth2_token`") + timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + uri = config["token_request_uri"] + response = requests.post( + uri, + json={ + "code": code, + "client_id": config["id"], + "client_secret": config["secret"], + "redirect_uri": config["redirect_uri"], + "grant_type": "authorization_code", + }, + timeout=timeout, + ) + return response.json() - @staticmethod - def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse: + @classmethod + def get_oauth2_fresh_token( + cls, + config: OAuth2ClientConfig, + refresh_token: str, + ) -> OAuth2TokenResponse: """ Refresh an access token that has expired. """ - raise OAuth2Error("Subclasses must implement `get_oauth2_fresh_token`") + timeout = current_app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds() + uri = config["token_request_uri"] + response = requests.post( + uri, + json={ + "client_id": config["id"], + "client_secret": config["secret"], + "refresh_token": refresh_token, + "grant_type": "refresh_token", + }, + timeout=timeout, + ) + return response.json() @classmethod def get_allows_alias_in_select( @@ -1196,7 +1270,7 @@ def execute_with_cursor( in a timely manner and facilitate operations such as query stop """ logger.debug("Query %d: Running query: %s", query.id, sql) - cls.execute(cursor, sql, query.database.id, async_=True) + cls.execute(cursor, sql, query.database, async_=True) logger.debug("Query %d: Handling cursor", query.id) cls.handle_cursor(cursor, query) @@ -1667,6 +1741,7 @@ def update_impersonation_config( connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary @@ -1675,6 +1750,7 @@ def update_impersonation_config( :param connect_args: config to be updated :param uri: URI :param username: Effective username + :param access_token: Personal access token for OAuth2 :return: None """ @@ -1683,7 +1759,7 @@ def execute( # pylint: disable=unused-argument cls, cursor: Any, query: str, - database_id: int, + database: Database, **kwargs: Any, ) -> None: """ @@ -1703,8 +1779,8 @@ def execute( # pylint: disable=unused-argument try: cursor.execute(query) except cls.oauth2_exception as ex: - if cls.is_oauth2_enabled() and g.user: - cls.start_oauth2_dance(database_id) + if database.is_oauth2_enabled() and g and g.user: + cls.start_oauth2_dance(database) raise cls.get_dbapi_mapped_exception(ex) from ex except Exception as ex: raise cls.get_dbapi_mapped_exception(ex) from ex diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 28e95811d48a..3e33aa8e32dd 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -23,13 +23,11 @@ import re from re import Pattern from typing import Any, TYPE_CHECKING, TypedDict -from urllib.parse import urlencode, urljoin import pandas as pd -import urllib3 from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin -from flask import current_app, g +from flask import g from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.exceptions import ValidationError @@ -42,11 +40,9 @@ from superset import db, security_manager from superset.constants import PASSWORD_MASK from superset.databases.schemas import encrypted_field_properties, EncryptedString -from superset.db_engine_specs.base import OAuth2State, OAuth2TokenResponse from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException -from superset.utils.oauth2 import encode_oauth2_state if TYPE_CHECKING: from superset.models.core import Database @@ -62,7 +58,6 @@ SYNTAX_ERROR_REGEX = re.compile('SQLError: near "(?P.*?)": syntax error') ma_plugin = MarshmallowPlugin() -http = urllib3.PoolManager() class GSheetsParametersSchema(Schema): @@ -111,7 +106,13 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): supports_file_upload = True - # exception raised by shillelagh that should trigger OAuth2 + # OAuth 2.0 + supports_oauth2 = True + oauth2_scope = " ".join(SCOPES) + oauth2_authorization_request_uri = ( # pylint: disable=invalid-name + "https://accounts.google.com/o/oauth2/v2/auth" + ) + oauth2_token_request_uri = "https://oauth2.googleapis.com/token" oauth2_exception = UnauthenticatedError @classmethod @@ -153,82 +154,6 @@ def extra_table_metadata( return {"metadata": metadata["extra"]} - @staticmethod - def is_oauth2_enabled() -> bool: - """ - Return if OAuth2 is enabled for GSheets. - """ - return "Google Sheets" in current_app.config["DATABASE_OAUTH2_CREDENTIALS"] - - @classmethod - def get_oauth2_authorization_uri(cls, state: OAuth2State) -> str: - """ - Return URI for initial OAuth2 request. - - https://developers.google.com/identity/protocols/oauth2/web-server#creatingclient - """ - config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"] - baseurl = config.get("BASEURL", "https://accounts.google.com/o/oauth2/v2/auth") - redirect_uri = current_app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - state["default_redirect_uri"], - ) - - params = { - "scope": " ".join(SCOPES), - "access_type": "offline", - "include_granted_scopes": "false", - "response_type": "code", - "state": encode_oauth2_state(state), - "redirect_uri": redirect_uri, - "client_id": config["CLIENT_ID"], - "prompt": "consent", - } - return urljoin(baseurl, "?" + urlencode(params)) - - @staticmethod - def get_oauth2_token(code: str, state: OAuth2State) -> OAuth2TokenResponse: - """ - Exchange authorization code for refresh/access tokens. - """ - config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"] - redirect_uri = current_app.config.get( - "DATABASE_OAUTH2_REDIRECT_URI", - state["default_redirect_uri"], - ) - - response = http.request( - "POST", - "https://oauth2.googleapis.com/token", - fields={ - "code": code, - "client_id": config["CLIENT_ID"], - "client_secret": config["CLIENT_SECRET"], - "redirect_uri": redirect_uri, - "grant_type": "authorization_code", - }, - ) - return json.loads(response.data.decode("utf-8")) - - @staticmethod - def get_oauth2_fresh_token(refresh_token: str) -> OAuth2TokenResponse: - """ - Refresh an access token that has expired. - """ - config = current_app.config["DATABASE_OAUTH2_CREDENTIALS"]["Google Sheets"] - - response = http.request( - "POST", - "https://oauth2.googleapis.com/token", - fields={ - "client_id": config["CLIENT_ID"], - "client_secret": config["CLIENT_SECRET"], - "refresh_token": refresh_token, - "grant_type": "refresh_token", - }, - ) - return json.loads(response.data.decode("utf-8")) - @classmethod def build_sqlalchemy_uri( cls, diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index a97dd88aefdd..2655ed6c9af6 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -528,6 +528,7 @@ def update_impersonation_config( connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary @@ -553,7 +554,7 @@ def update_impersonation_config( def execute( # type: ignore cursor, query: str, - database_id: int, + database: Database, async_: bool = False, ): # pylint: disable=arguments-differ kwargs = {"async": async_} diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index 8cda5b586183..1d3ec4e9e5b0 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -14,11 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + import logging import re import time from datetime import datetime -from typing import Any, Optional +from typing import Any, TYPE_CHECKING from flask import current_app from sqlalchemy import types @@ -29,6 +32,9 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.models.sql_lab import Query +if TYPE_CHECKING: + from superset.models.core import Database + logger = logging.getLogger(__name__) # Query 5543ffdf692b7d02:f78a944000000000: 3% Complete (17 out of 547) QUERY_PROGRESS_REGEX = re.compile(r"Query.*: (?P[0-9]+)%") @@ -57,8 +63,8 @@ def epoch_to_dttm(cls) -> str: @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): @@ -93,7 +99,7 @@ def execute( cls, cursor: Any, query: str, - database_id: int, + database: Database, **kwargs: Any, ) -> None: try: diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 3a409e7189c7..d749d2cb185b 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -717,6 +717,7 @@ def update_impersonation_config( connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary @@ -724,6 +725,7 @@ def update_impersonation_config( :param connect_args: config to be updated :param uri: URI string :param username: Effective username + :param access_token: Personal access token for OAuth2 :return: None """ url = make_url_safe(uri) @@ -1271,7 +1273,7 @@ def get_create_view( cursor = conn.cursor() sql = f"SHOW CREATE VIEW {schema}.{table}" try: - cls.execute(cursor, sql, database.id) + cls.execute(cursor, sql, database) rows = cls.fetch_data(cursor, 1) return rows[0][0] @@ -1329,10 +1331,7 @@ def handle_cursor(cls, cursor: Cursor, query: Query) -> None: completed_splits, total_splits, ) - if ( # pylint: disable=consider-using-min-builtin - progress > query.progress - ): - query.progress = progress + query.progress = max(query.progress, progress) db.session.commit() time.sleep(poll_interval) logger.info("Query %i: Polling the cursor for progress", query_id) diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 4513d63c606b..2185de8c867a 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -112,6 +112,7 @@ def update_impersonation_config( connect_args: dict[str, Any], uri: str, username: str | None, + access_token: str | None, ) -> None: """ Update a configuration dictionary @@ -119,6 +120,7 @@ def update_impersonation_config( :param connect_args: config to be updated :param uri: URI string :param username: Effective username + :param access_token: Personal access token for OAuth2 :return: None """ url = make_url_safe(uri) @@ -219,7 +221,7 @@ def _execute(results: dict[str, Any], event: threading.Event) -> None: logger.debug("Query %d: Running query: %s", query_id, sql) try: - cls.execute(cursor, sql, query.database.id) + cls.execute(cursor, sql, query.database) except Exception as ex: # pylint: disable=broad-except results["error"] = ex finally: diff --git a/superset/models/core.py b/superset/models/core.py index cf84b90ac29e..92f6946f1e0e 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -71,7 +71,7 @@ ) from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.result_set import SupersetResultSet -from superset.superset_typing import ResultSetColumnType +from superset.superset_typing import OAuth2ClientConfig, ResultSetColumnType from superset.utils import cache as cache_util, core as utils from superset.utils.backports import StrEnum from superset.utils.core import get_username @@ -466,9 +466,15 @@ def _get_sqla_engine( ) effective_username = self.get_effective_user(sqlalchemy_url) + oauth2_config = self.get_oauth2_config() access_token = ( - get_oauth2_access_token(self.id, g.user.id, self.db_engine_spec) - if hasattr(g, "user") and hasattr(g.user, "id") + get_oauth2_access_token( + oauth2_config, + self.id, + g.user.id, + self.db_engine_spec, + ) + if hasattr(g, "user") and hasattr(g.user, "id") and oauth2_config else None ) # If using MySQL or Presto for example, will set url.username @@ -489,6 +495,7 @@ def _get_sqla_engine( connect_args, str(sqlalchemy_url), effective_username, + access_token, ) if connect_args: @@ -599,7 +606,7 @@ def _log_query(sql: str) -> None: database=None, ) _log_query(sql_) - self.db_engine_spec.execute(cursor, sql_, self.id) + self.db_engine_spec.execute(cursor, sql_, self) cursor.fetchall() if mutate_after_split: @@ -609,10 +616,10 @@ def _log_query(sql: str) -> None: database=None, ) _log_query(last_sql) - self.db_engine_spec.execute(cursor, last_sql, self.id) + self.db_engine_spec.execute(cursor, last_sql, self) else: _log_query(sqls[-1]) - self.db_engine_spec.execute(cursor, sqls[-1], self.id) + self.db_engine_spec.execute(cursor, sqls[-1], self) data = self.db_engine_spec.fetch_data(cursor) result_set = SupersetResultSet( @@ -983,6 +990,26 @@ def make_sqla_column_compatible( sqla_col.key = label_expected return sqla_col + def is_oauth2_enabled(self) -> bool: + """ + Is OAuth2 enabled in the database for authentication? + + Currently this looks for a global config at the DB engine spec level, but in the + future we want to be allow admins to create custom OAuth2 clients from the + Superset UI, and assign them to specific databases. + """ + return self.db_engine_spec.is_oauth2_enabled() + + def get_oauth2_config(self) -> OAuth2ClientConfig | None: + """ + Return OAuth2 client configuration. + + This includes client ID, client secret, scope, redirect URI, endpointsm etc. + Currently this reads the global DB engine spec config, but in the future it + should first check if there's a custom client assigned to the database. + """ + return self.db_engine_spec.get_oauth2_config() + sqla.event.listen(Database, "after_insert", security_manager.database_after_insert) sqla.event.listen(Database, "after_update", security_manager.database_after_update) diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 8c815ad63ed3..4852f70ee46b 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -73,7 +73,7 @@ def validate_statement( from pyhive.exc import DatabaseError try: - db_engine_spec.execute(cursor, sql, database.id) + db_engine_spec.execute(cursor, sql, database) polled = cursor.poll() while polled: logger.info("polling presto for validation progress") diff --git a/superset/superset_typing.py b/superset/superset_typing.py index c71dcea3f1a2..ba623f581905 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -121,3 +121,52 @@ class ResultSetColumnType(TypedDict): tuple[Base, Status, Headers], tuple[Response, Status], ] + + +class OAuth2ClientConfig(TypedDict): + """ + Configuration for an OAuth2 client. + """ + + # The client ID and secret. + id: str + secret: str + + # The scopes requested; this is usually a space separated list of URLs. + scope: str + + # The URI where the user is redirected to after authorizing the client; by default + # this points to `/api/v1/databases/oauth2/`, but it can be overridden by the admin. + redirect_uri: str + + # The URI used to getting a code. + authorization_request_uri: str + + # The URI used when exchaing the code for an access token, or when refreshing an + # expired access token. + token_request_uri: str + + +class OAuth2TokenResponse(TypedDict, total=False): + """ + Type for an OAuth2 response when exchanging or refreshing tokens. + """ + + access_token: str + expires_in: int + scope: str + token_type: str + + # only present when exchanging code for refresh/access tokens + refresh_token: str + + +class OAuth2State(TypedDict): + """ + Type for the state passed during OAuth2. + """ + + database_id: int + user_id: int + default_redirect_uri: str + tab_id: str diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 7e80df959915..9cc58a0b7ffc 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -26,11 +26,12 @@ from marshmallow import EXCLUDE, fields, post_load, Schema from superset import db -from superset.db_engine_specs.base import BaseEngineSpec, OAuth2State from superset.exceptions import CreateKeyValueDistributedLockFailedException +from superset.superset_typing import OAuth2ClientConfig, OAuth2State from superset.utils.lock import KeyValueDistributedLock if TYPE_CHECKING: + from superset.db_engine_specs.base import BaseEngineSpec from superset.models.core import DatabaseUserOAuth2Tokens JWT_EXPIRATION = timedelta(minutes=5) @@ -44,6 +45,7 @@ max_tries=5, ) def get_oauth2_access_token( + config: OAuth2ClientConfig, database_id: int, user_id: int, db_engine_spec: type[BaseEngineSpec], @@ -73,7 +75,7 @@ def get_oauth2_access_token( return token.access_token if token.refresh_token: - return refresh_oauth2_token(database_id, user_id, db_engine_spec, token) + return refresh_oauth2_token(config, database_id, user_id, db_engine_spec, token) # since the access token is expired and there's no refresh token, delete the entry db.session.delete(token) @@ -82,6 +84,7 @@ def get_oauth2_access_token( def refresh_oauth2_token( + config: OAuth2ClientConfig, database_id: int, user_id: int, db_engine_spec: type[BaseEngineSpec], @@ -92,7 +95,10 @@ def refresh_oauth2_token( user_id=user_id, database_id=database_id, ): - token_response = db_engine_spec.get_oauth2_fresh_token(token.refresh_token) + token_response = db_engine_spec.get_oauth2_fresh_token( + config, + token.refresh_token, + ) # store new access token; note that the refresh token might be revoked, in which # case there would be no access token in the response diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 3b99e69d22be..ff6dcb23349d 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -668,6 +668,11 @@ def test_oauth2_happy_path( ) db.session.commit() + mocker.patch.object( + SqliteEngineSpec, + "get_oauth2_config", + return_value={"id": "one", "secret": "two"}, + ) get_oauth2_token = mocker.patch.object(SqliteEngineSpec, "get_oauth2_token") get_oauth2_token.return_value = { "access_token": "YYY", @@ -696,7 +701,7 @@ def test_oauth2_happy_path( assert response.status_code == 200 decode_oauth2_state.assert_called_with("some%2Estate") - get_oauth2_token.assert_called_with("XXX", state) + get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX") token = db.session.query(DatabaseUserOAuth2Tokens).one() assert token.user_id == 1 @@ -731,6 +736,11 @@ def test_oauth2_multiple_tokens( ) db.session.commit() + mocker.patch.object( + SqliteEngineSpec, + "get_oauth2_config", + return_value={"id": "one", "secret": "two"}, + ) get_oauth2_token = mocker.patch.object(SqliteEngineSpec, "get_oauth2_token") get_oauth2_token.side_effect = [ { diff --git a/tests/unit_tests/db_engine_specs/test_clickhouse.py b/tests/unit_tests/db_engine_specs/test_clickhouse.py index 94b70ba5264e..65f4d7903cab 100644 --- a/tests/unit_tests/db_engine_specs/test_clickhouse.py +++ b/tests/unit_tests/db_engine_specs/test_clickhouse.py @@ -61,12 +61,13 @@ def test_execute_connection_error() -> None: from superset.db_engine_specs.clickhouse import ClickHouseEngineSpec from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError + database = Mock() cursor = Mock() cursor.execute.side_effect = NewConnectionError( HTTPConnection("localhost"), "Exception with sensitive data" ) with pytest.raises(SupersetDBAPIDatabaseError) as excinfo: - ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1", 1) + ClickHouseEngineSpec.execute(cursor, "SELECT col1 from table1", database) assert str(excinfo.value) == "Connection failed" diff --git a/tests/unit_tests/db_engine_specs/test_databend.py b/tests/unit_tests/db_engine_specs/test_databend.py index 06fab791884e..8e8cfe310997 100644 --- a/tests/unit_tests/db_engine_specs/test_databend.py +++ b/tests/unit_tests/db_engine_specs/test_databend.py @@ -62,12 +62,13 @@ def test_execute_connection_error() -> None: from superset.db_engine_specs.databend import DatabendEngineSpec from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError + database = Mock() cursor = Mock() cursor.execute.side_effect = NewConnectionError( HTTPConnection("Dummypool"), "Exception with sensitive data" ) with pytest.raises(SupersetDBAPIDatabaseError) as excinfo: - DatabendEngineSpec.execute(cursor, "SELECT col1 from table1", 1) + DatabendEngineSpec.execute(cursor, "SELECT col1 from table1", database) assert str(excinfo.value) == "Connection failed" diff --git a/tests/unit_tests/db_engine_specs/test_elasticsearch.py b/tests/unit_tests/db_engine_specs/test_elasticsearch.py index ed80454d3c69..1fc3d11ca4e1 100644 --- a/tests/unit_tests/db_engine_specs/test_elasticsearch.py +++ b/tests/unit_tests/db_engine_specs/test_elasticsearch.py @@ -97,12 +97,13 @@ def test_opendistro_strip_comments() -> None: """ from superset.db_engine_specs.elasticsearch import OpenDistroEngineSpec + mock_database = MagicMock() mock_cursor = MagicMock() mock_cursor.execute.return_value = [] OpenDistroEngineSpec.execute( mock_cursor, "-- some comment \nSELECT 1\n --other comment", - 1, + mock_database, ) mock_cursor.execute.assert_called_once_with("SELECT 1\n") diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index eb72878a7834..b3d754a24ecf 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -29,6 +29,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException from superset.sql_parse import Table +from superset.superset_typing import OAuth2ClientConfig from superset.utils.oauth2 import decode_oauth2_state if TYPE_CHECKING: @@ -450,8 +451,8 @@ def test_is_oauth2_enabled_no_config(mocker: MockFixture) -> None: from superset.db_engine_specs.gsheets import GSheetsEngineSpec mocker.patch( - "superset.db_engine_specs.gsheets.current_app.config", - new={"DATABASE_OAUTH2_CREDENTIALS": {}}, + "superset.db_engine_specs.base.current_app.config", + new={"DATABASE_OAUTH2_CLIENTS": {}}, ) assert GSheetsEngineSpec.is_oauth2_enabled() is False @@ -464,12 +465,12 @@ def test_is_oauth2_enabled_config(mocker: MockFixture) -> None: from superset.db_engine_specs.gsheets import GSheetsEngineSpec mocker.patch( - "superset.db_engine_specs.gsheets.current_app.config", + "superset.db_engine_specs.base.current_app.config", new={ - "DATABASE_OAUTH2_CREDENTIALS": { + "DATABASE_OAUTH2_CLIENTS": { "Google Sheets": { - "CLIENT_ID": "XXX.apps.googleusercontent.com", - "CLIENT_SECRET": "GOCSPX-YYY", + "id": "XXX.apps.googleusercontent.com", + "secret": "GOCSPX-YYY", }, } }, @@ -478,26 +479,36 @@ def test_is_oauth2_enabled_config(mocker: MockFixture) -> None: assert GSheetsEngineSpec.is_oauth2_enabled() is True -def test_get_oauth2_authorization_uri(mocker: MockFixture) -> None: +@pytest.fixture +def oauth2_config() -> OAuth2ClientConfig: + """ + Config for GSheets OAuth2. + """ + return { + "id": "XXX.apps.googleusercontent.com", + "secret": "GOCSPX-YYY", + "scope": " ".join( + [ + "https://www.googleapis.com/auth/drive.readonly " + "https://www.googleapis.com/auth/spreadsheets " + "https://spreadsheets.google.com/feeds" + ] + ), + "redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth", + "token_request_uri": "https://oauth2.googleapis.com/token", + } + + +def test_get_oauth2_authorization_uri( + mocker: MockFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: """ Test `get_oauth2_authorization_uri`. """ from superset.db_engine_specs.gsheets import GSheetsEngineSpec - mocker.patch( - "superset.db_engine_specs.gsheets.current_app.config", - new={ - "DATABASE_OAUTH2_CREDENTIALS": { - "Google Sheets": { - "CLIENT_ID": "XXX.apps.googleusercontent.com", - "CLIENT_SECRET": "GOCSPX-YYY", - }, - }, - "SECRET_KEY": "not-a-secret", - "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", - }, - ) - state: OAuth2State = { "database_id": 1, "user_id": 1, @@ -505,7 +516,7 @@ def test_get_oauth2_authorization_uri(mocker: MockFixture) -> None: "tab_id": "1234", } - url = GSheetsEngineSpec.get_oauth2_authorization_uri(state) + url = GSheetsEngineSpec.get_oauth2_authorization_uri(oauth2_config, state) parsed = urlparse(url) assert parsed.netloc == "accounts.google.com" assert parsed.path == "/o/oauth2/v2/auth" @@ -520,109 +531,76 @@ def test_get_oauth2_authorization_uri(mocker: MockFixture) -> None: assert decode_oauth2_state(encoded_state) == state -def test_get_oauth2_token(mocker: MockFixture) -> None: +def test_get_oauth2_token( + mocker: MockFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: """ Test `get_oauth2_token`. """ from superset.db_engine_specs.gsheets import GSheetsEngineSpec - http = mocker.patch("superset.db_engine_specs.gsheets.http") - http.request().data.decode.return_value = json.dumps( - { - "access_token": "access-token", - "expires_in": 3600, - "scope": "scope", - "token_type": "Bearer", - "refresh_token": "refresh-token", - } - ) - - mocker.patch( - "superset.db_engine_specs.gsheets.current_app.config", - new={ - "DATABASE_OAUTH2_CREDENTIALS": { - "Google Sheets": { - "CLIENT_ID": "XXX.apps.googleusercontent.com", - "CLIENT_SECRET": "GOCSPX-YYY", - }, - }, - "SECRET_KEY": "not-a-secret", - "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", - }, - ) - - state: OAuth2State = { - "database_id": 1, - "user_id": 1, - "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", - "tab_id": "1234", + requests = mocker.patch("superset.db_engine_specs.base.requests") + requests.post().json.return_value = { + "access_token": "access-token", + "expires_in": 3600, + "scope": "scope", + "token_type": "Bearer", + "refresh_token": "refresh-token", } - assert GSheetsEngineSpec.get_oauth2_token("code", state) == { + assert GSheetsEngineSpec.get_oauth2_token(oauth2_config, "code") == { "access_token": "access-token", "expires_in": 3600, "scope": "scope", "token_type": "Bearer", "refresh_token": "refresh-token", } - http.request.assert_called_with( - "POST", + requests.post.assert_called_with( "https://oauth2.googleapis.com/token", - fields={ + json={ "code": "code", "client_id": "XXX.apps.googleusercontent.com", "client_secret": "GOCSPX-YYY", "redirect_uri": "http://localhost:8088/api/v1/oauth2/", "grant_type": "authorization_code", }, + timeout=30.0, ) -def test_get_oauth2_fresh_token(mocker: MockFixture) -> None: +def test_get_oauth2_fresh_token( + mocker: MockFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: """ Test `get_oauth2_token`. """ from superset.db_engine_specs.gsheets import GSheetsEngineSpec - http = mocker.patch("superset.db_engine_specs.gsheets.http") - http.request().data.decode.return_value = json.dumps( - { - "access_token": "access-token", - "expires_in": 3600, - "scope": "scope", - "token_type": "Bearer", - "refresh_token": "refresh-token", - } - ) - - mocker.patch( - "superset.db_engine_specs.gsheets.current_app.config", - new={ - "DATABASE_OAUTH2_CREDENTIALS": { - "Google Sheets": { - "CLIENT_ID": "XXX.apps.googleusercontent.com", - "CLIENT_SECRET": "GOCSPX-YYY", - }, - }, - "SECRET_KEY": "not-a-secret", - "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", - }, - ) + requests = mocker.patch("superset.db_engine_specs.base.requests") + requests.post().json.return_value = { + "access_token": "access-token", + "expires_in": 3600, + "scope": "scope", + "token_type": "Bearer", + "refresh_token": "refresh-token", + } - assert GSheetsEngineSpec.get_oauth2_fresh_token("refresh-token") == { + assert GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token") == { "access_token": "access-token", "expires_in": 3600, "scope": "scope", "token_type": "Bearer", "refresh_token": "refresh-token", } - http.request.assert_called_with( - "POST", + requests.post.assert_called_with( "https://oauth2.googleapis.com/token", - fields={ + json={ "client_id": "XXX.apps.googleusercontent.com", "client_secret": "GOCSPX-YYY", "refresh_token": "refresh-token", "grant_type": "refresh_token", }, + timeout=30.0, ) diff --git a/tests/unit_tests/extensions/test_sqlalchemy.py b/tests/unit_tests/extensions/test_sqlalchemy.py index c0fd49f9eb0e..df36dc44ef60 100644 --- a/tests/unit_tests/extensions/test_sqlalchemy.py +++ b/tests/unit_tests/extensions/test_sqlalchemy.py @@ -124,7 +124,11 @@ def test_superset_limit(mocker: MockFixture, app_context: None, table1: None) -> """ mocker.patch( "superset.extensions.metadb.current_app.config", - {"DB_SQLA_URI_VALIDATOR": None, "SUPERSET_META_DB_LIMIT": 1}, + { + "DB_SQLA_URI_VALIDATOR": None, + "SUPERSET_META_DB_LIMIT": 1, + "DATABASE_OAUTH2_CLIENTS": {}, + }, ) mocker.patch("superset.extensions.metadb.security_manager") diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py index 6c859a538f04..19e8ad5aa639 100644 --- a/tests/unit_tests/utils/oauth2_tests.py +++ b/tests/unit_tests/utils/oauth2_tests.py @@ -33,7 +33,7 @@ def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None: db_engine_spec = mocker.MagicMock() db.session.query().filter_by().one_or_none.return_value = None - assert get_oauth2_access_token(1, 1, db_engine_spec) is None + assert get_oauth2_access_token({}, 1, 1, db_engine_spec) is None def test_get_oauth2_access_token_base_token_valid(mocker: MockerFixture) -> None: @@ -48,7 +48,7 @@ def test_get_oauth2_access_token_base_token_valid(mocker: MockerFixture) -> None db.session.query().filter_by().one_or_none.return_value = token with freeze_time("2024-01-01"): - assert get_oauth2_access_token(1, 1, db_engine_spec) == "access-token" + assert get_oauth2_access_token({}, 1, 1, db_engine_spec) == "access-token" def test_get_oauth2_access_token_base_refresh(mocker: MockerFixture) -> None: @@ -68,7 +68,7 @@ def test_get_oauth2_access_token_base_refresh(mocker: MockerFixture) -> None: db.session.query().filter_by().one_or_none.return_value = token with freeze_time("2024-01-02"): - assert get_oauth2_access_token(1, 1, db_engine_spec) == "new-token" + assert get_oauth2_access_token({}, 1, 1, db_engine_spec) == "new-token" # check that token was updated assert token.access_token == "new-token" @@ -89,7 +89,7 @@ def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None: db.session.query().filter_by().one_or_none.return_value = token with freeze_time("2024-01-02"): - assert get_oauth2_access_token(1, 1, db_engine_spec) is None + assert get_oauth2_access_token({}, 1, 1, db_engine_spec) is None # check that token was deleted db.session.delete.assert_called_with(token)