diff --git a/superset/commands/dashboard/permalink/create.py b/superset/commands/dashboard/permalink/create.py index 7d08f78e9a9b..20bc5118f576 100644 --- a/superset/commands/dashboard/permalink/create.py +++ b/superset/commands/dashboard/permalink/create.py @@ -19,9 +19,10 @@ from sqlalchemy.exc import SQLAlchemyError +from superset import db from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand -from superset.commands.key_value.upsert import UpsertKeyValueCommand from superset.daos.dashboard import DashboardDAO +from superset.daos.key_value import KeyValueDAO from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError from superset.dashboards.permalink.types import DashboardPermalinkState from superset.key_value.exceptions import ( @@ -70,14 +71,15 @@ def run(self) -> str: "state": self.state, } user_id = get_user_id() - key = UpsertKeyValueCommand( + entry = KeyValueDAO.upsert_entry( resource=self.resource, key=get_deterministic_uuid(self.salt, (user_id, value)), value=value, codec=self.codec, - ).run() - assert key.id # for type checks - return encode_permalink_key(key=key.id, salt=self.salt) + ) + db.session.flush() + assert entry.id # for type checks + return encode_permalink_key(key=entry.id, salt=self.salt) def validate(self) -> None: pass diff --git a/superset/commands/dashboard/permalink/get.py b/superset/commands/dashboard/permalink/get.py index 32efa688813c..e87711a5bfeb 100644 --- a/superset/commands/dashboard/permalink/get.py +++ b/superset/commands/dashboard/permalink/get.py @@ -21,8 +21,8 @@ from superset.commands.dashboard.exceptions import DashboardNotFoundError from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand -from superset.commands.key_value.get import GetKeyValueCommand from superset.daos.dashboard import DashboardDAO +from superset.daos.key_value import KeyValueDAO from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailedError from superset.dashboards.permalink.types import DashboardPermalinkValue from superset.key_value.exceptions import ( @@ -43,12 +43,7 @@ def run(self) -> Optional[DashboardPermalinkValue]: self.validate() try: key = decode_permalink_id(self.key, salt=self.salt) - command = GetKeyValueCommand( - resource=self.resource, - key=key, - codec=self.codec, - ) - value: Optional[DashboardPermalinkValue] = command.run() + value = KeyValueDAO.get_value(self.resource, key, self.codec) if value: DashboardDAO.get_by_id_or_slug(value["dashboardId"]) return value diff --git a/superset/commands/distributed_lock/__init__.py b/superset/commands/distributed_lock/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/integration_tests/key_value/__init__.py b/superset/commands/distributed_lock/base.py similarity index 54% rename from tests/integration_tests/key_value/__init__.py rename to superset/commands/distributed_lock/base.py index 13a83393a912..322063f54e89 100644 --- a/tests/integration_tests/key_value/__init__.py +++ b/superset/commands/distributed_lock/base.py @@ -14,3 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import logging +import uuid +from typing import Any + +from flask import current_app + +from superset.commands.base import BaseCommand +from superset.distributed_lock.utils import get_key +from superset.key_value.types import JsonKeyValueCodec, KeyValueResource + +logger = logging.getLogger(__name__) +stats_logger = current_app.config["STATS_LOGGER"] + + +class BaseDistributedLockCommand(BaseCommand): + key: uuid.UUID + codec = JsonKeyValueCodec() + resource = KeyValueResource.LOCK + + def __init__(self, namespace: str, params: dict[str, Any] | None = None): + self.key = get_key(namespace, **(params or {})) + + def validate(self) -> None: + pass diff --git a/superset/commands/distributed_lock/create.py b/superset/commands/distributed_lock/create.py new file mode 100644 index 000000000000..c654089336af --- /dev/null +++ b/superset/commands/distributed_lock/create.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from datetime import datetime, timedelta +from functools import partial + +from flask import current_app +from sqlalchemy.exc import SQLAlchemyError + +from superset.commands.distributed_lock.base import BaseDistributedLockCommand +from superset.daos.key_value import KeyValueDAO +from superset.exceptions import CreateKeyValueDistributedLockFailedException +from superset.key_value.exceptions import ( + KeyValueCodecEncodeException, + KeyValueUpsertFailedError, +) +from superset.key_value.types import KeyValueResource +from superset.utils.decorators import on_error, transaction + +logger = logging.getLogger(__name__) +stats_logger = current_app.config["STATS_LOGGER"] + + +class CreateDistributedLock(BaseDistributedLockCommand): + lock_expiration = timedelta(seconds=30) + + def validate(self) -> None: + pass + + @transaction( + on_error=partial( + on_error, + catches=( + KeyValueCodecEncodeException, + KeyValueUpsertFailedError, + SQLAlchemyError, + ), + reraise=CreateKeyValueDistributedLockFailedException, + ), + ) + def run(self) -> None: + KeyValueDAO.delete_expired_entries(self.resource) + KeyValueDAO.create_entry( + resource=KeyValueResource.LOCK, + value={"value": True}, + codec=self.codec, + key=self.key, + expires_on=datetime.now() + self.lock_expiration, + ) diff --git a/superset/commands/key_value/delete_expired.py b/superset/commands/distributed_lock/delete.py similarity index 51% rename from superset/commands/key_value/delete_expired.py rename to superset/commands/distributed_lock/delete.py index 54991c7531d2..cd279dbe2409 100644 --- a/superset/commands/key_value/delete_expired.py +++ b/superset/commands/distributed_lock/delete.py @@ -14,49 +14,36 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import logging -from datetime import datetime from functools import partial -from sqlalchemy import and_ +from flask import current_app +from sqlalchemy.exc import SQLAlchemyError -from superset import db -from superset.commands.base import BaseCommand +from superset.commands.distributed_lock.base import BaseDistributedLockCommand +from superset.daos.key_value import KeyValueDAO +from superset.exceptions import DeleteKeyValueDistributedLockFailedException from superset.key_value.exceptions import KeyValueDeleteFailedError -from superset.key_value.models import KeyValueEntry -from superset.key_value.types import KeyValueResource from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) +stats_logger = current_app.config["STATS_LOGGER"] -class DeleteExpiredKeyValueCommand(BaseCommand): - resource: KeyValueResource - - def __init__(self, resource: KeyValueResource): - """ - Delete all expired key-value pairs - - :param resource: the resource (dashboard, chart etc) - :return: was the entry deleted or not - """ - self.resource = resource - - @transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError)) - def run(self) -> None: - self.delete_expired() - +class DeleteDistributedLock(BaseDistributedLockCommand): def validate(self) -> None: pass - def delete_expired(self) -> None: - ( - db.session.query(KeyValueEntry) - .filter( - and_( - KeyValueEntry.resource == self.resource.value, - KeyValueEntry.expires_on <= datetime.now(), - ) - ) - .delete() - ) + @transaction( + on_error=partial( + on_error, + catches=( + KeyValueDeleteFailedError, + SQLAlchemyError, + ), + reraise=DeleteKeyValueDistributedLockFailedException, + ), + ) + def run(self) -> None: + KeyValueDAO.delete_entry(self.resource, self.key) diff --git a/superset/commands/distributed_lock/get.py b/superset/commands/distributed_lock/get.py new file mode 100644 index 000000000000..562456410935 --- /dev/null +++ b/superset/commands/distributed_lock/get.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +from typing import cast + +from flask import current_app + +from superset.commands.distributed_lock.base import BaseDistributedLockCommand +from superset.daos.key_value import KeyValueDAO +from superset.distributed_lock.types import LockValue + +logger = logging.getLogger(__name__) +stats_logger = current_app.config["STATS_LOGGER"] + + +class GetDistributedLock(BaseDistributedLockCommand): + def validate(self) -> None: + pass + + def run(self) -> LockValue | None: + entry = KeyValueDAO.get_entry( + resource=self.resource, + key=self.key, + ) + if not entry or entry.is_expired(): + return None + + return cast(LockValue, self.codec.decode(entry.value)) diff --git a/superset/commands/explore/permalink/create.py b/superset/commands/explore/permalink/create.py index 2128fa4b8c40..926b9ba919f4 100644 --- a/superset/commands/explore/permalink/create.py +++ b/superset/commands/explore/permalink/create.py @@ -20,8 +20,9 @@ from sqlalchemy.exc import SQLAlchemyError +from superset import db from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand -from superset.commands.key_value.create import CreateKeyValueCommand +from superset.daos.key_value import KeyValueDAO from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError from superset.explore.utils import check_access as check_chart_access from superset.key_value.exceptions import ( @@ -65,15 +66,12 @@ def run(self) -> str: "datasource": self.datasource, "state": self.state, } - command = CreateKeyValueCommand( - resource=self.resource, - value=value, - codec=self.codec, - ) - key = command.run() - if key.id is None: + entry = KeyValueDAO.create_entry(self.resource, value, self.codec) + db.session.flush() + key = entry.id + if key is None: raise ExplorePermalinkCreateFailedError("Unexpected missing key id") - return encode_permalink_key(key=key.id, salt=self.salt) + return encode_permalink_key(key=key, salt=self.salt) def validate(self) -> None: pass diff --git a/superset/commands/explore/permalink/get.py b/superset/commands/explore/permalink/get.py index 4c01db1ccab4..7dc1db40df24 100644 --- a/superset/commands/explore/permalink/get.py +++ b/superset/commands/explore/permalink/get.py @@ -21,7 +21,7 @@ from superset.commands.dataset.exceptions import DatasetNotFoundError from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand -from superset.commands.key_value.get import GetKeyValueCommand +from superset.daos.key_value import KeyValueDAO from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError from superset.explore.permalink.types import ExplorePermalinkValue from superset.explore.utils import check_access as check_chart_access @@ -44,11 +44,7 @@ def run(self) -> Optional[ExplorePermalinkValue]: self.validate() try: key = decode_permalink_id(self.key, salt=self.salt) - value: Optional[ExplorePermalinkValue] = GetKeyValueCommand( - resource=self.resource, - key=key, - codec=self.codec, - ).run() + value = KeyValueDAO.get_value(self.resource, key, self.codec) if value: chart_id: Optional[int] = value.get("chartId") # keep this backward compatible for old permalinks diff --git a/superset/commands/key_value/create.py b/superset/commands/key_value/create.py deleted file mode 100644 index 81b7c4c3d4a9..000000000000 --- a/superset/commands/key_value/create.py +++ /dev/null @@ -1,102 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from datetime import datetime -from functools import partial -from typing import Any, Optional, Union -from uuid import UUID - -from superset import db -from superset.commands.base import BaseCommand -from superset.key_value.exceptions import KeyValueCreateFailedError -from superset.key_value.models import KeyValueEntry -from superset.key_value.types import Key, KeyValueCodec, KeyValueResource -from superset.utils.core import get_user_id -from superset.utils.decorators import on_error, transaction - -logger = logging.getLogger(__name__) - - -class CreateKeyValueCommand(BaseCommand): - resource: KeyValueResource - value: Any - codec: KeyValueCodec - key: Optional[Union[int, UUID]] - expires_on: Optional[datetime] - - def __init__( # pylint: disable=too-many-arguments - self, - resource: KeyValueResource, - value: Any, - codec: KeyValueCodec, - key: Optional[Union[int, UUID]] = None, - expires_on: Optional[datetime] = None, - ): - """ - Create a new key-value pair - - :param resource: the resource (dashboard, chart etc) - :param value: the value to persist in the key-value store - :param codec: codec used to encode the value - :param key: id of entry (autogenerated if undefined) - :param expires_on: entry expiration time - : - """ - self.resource = resource - self.value = value - self.codec = codec - self.key = key - self.expires_on = expires_on - - @transaction(on_error=partial(on_error, reraise=KeyValueCreateFailedError)) - def run(self) -> Key: - """ - Persist the value - - :return: the key associated with the persisted value - - """ - - return self.create() - - def validate(self) -> None: - pass - - def create(self) -> Key: - try: - value = self.codec.encode(self.value) - except Exception as ex: - raise KeyValueCreateFailedError("Unable to encode value") from ex - entry = KeyValueEntry( - resource=self.resource.value, - value=value, - created_on=datetime.now(), - created_by_fk=get_user_id(), - expires_on=self.expires_on, - ) - if self.key is not None: - try: - if isinstance(self.key, UUID): - entry.uuid = self.key - else: - entry.id = self.key - except ValueError as ex: - raise KeyValueCreateFailedError() from ex - - db.session.add(entry) - db.session.flush() - return Key(id=entry.id, uuid=entry.uuid) diff --git a/superset/commands/key_value/delete.py b/superset/commands/key_value/delete.py deleted file mode 100644 index a3fdf079c73c..000000000000 --- a/superset/commands/key_value/delete.py +++ /dev/null @@ -1,63 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from functools import partial -from typing import Union -from uuid import UUID - -from superset import db -from superset.commands.base import BaseCommand -from superset.key_value.exceptions import KeyValueDeleteFailedError -from superset.key_value.models import KeyValueEntry -from superset.key_value.types import KeyValueResource -from superset.key_value.utils import get_filter -from superset.utils.decorators import on_error, transaction - -logger = logging.getLogger(__name__) - - -class DeleteKeyValueCommand(BaseCommand): - key: Union[int, UUID] - resource: KeyValueResource - - def __init__(self, resource: KeyValueResource, key: Union[int, UUID]): - """ - Delete a key-value pair - - :param resource: the resource (dashboard, chart etc) - :param key: the key to delete - :return: was the entry deleted or not - """ - self.resource = resource - self.key = key - - @transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError)) - def run(self) -> bool: - return self.delete() - - def validate(self) -> None: - pass - - def delete(self) -> bool: - if ( - entry := db.session.query(KeyValueEntry) - .filter_by(**get_filter(self.resource, self.key)) - .first() - ): - db.session.delete(entry) - return True - return False diff --git a/superset/commands/key_value/get.py b/superset/commands/key_value/get.py deleted file mode 100644 index 93550ee840c3..000000000000 --- a/superset/commands/key_value/get.py +++ /dev/null @@ -1,71 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import logging -from typing import Any, Optional, Union -from uuid import UUID - -from sqlalchemy.exc import SQLAlchemyError - -from superset import db -from superset.commands.base import BaseCommand -from superset.key_value.exceptions import KeyValueGetFailedError -from superset.key_value.models import KeyValueEntry -from superset.key_value.types import KeyValueCodec, KeyValueResource -from superset.key_value.utils import get_filter - -logger = logging.getLogger(__name__) - - -class GetKeyValueCommand(BaseCommand): - resource: KeyValueResource - key: Union[int, UUID] - codec: KeyValueCodec - - def __init__( - self, - resource: KeyValueResource, - key: Union[int, UUID], - codec: KeyValueCodec, - ): - """ - Retrieve a key value entry - - :param resource: the resource (dashboard, chart etc) - :param key: the key to retrieve - :param codec: codec used to decode the value - :return: the value associated with the key if present - """ - self.resource = resource - self.key = key - self.codec = codec - - def run(self) -> Any: - try: - return self.get() - except SQLAlchemyError as ex: - raise KeyValueGetFailedError() from ex - - def validate(self) -> None: - pass - - def get(self) -> Optional[Any]: - filter_ = get_filter(self.resource, self.key) - entry = db.session.query(KeyValueEntry).filter_by(**filter_).first() - if entry and not entry.is_expired(): - return self.codec.decode(entry.value) - return None diff --git a/superset/commands/key_value/update.py b/superset/commands/key_value/update.py deleted file mode 100644 index b6ffc22174f6..000000000000 --- a/superset/commands/key_value/update.py +++ /dev/null @@ -1,87 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import logging -from datetime import datetime -from functools import partial -from typing import Any, Optional, Union -from uuid import UUID - -from superset import db -from superset.commands.base import BaseCommand -from superset.key_value.exceptions import KeyValueUpdateFailedError -from superset.key_value.models import KeyValueEntry -from superset.key_value.types import Key, KeyValueCodec, KeyValueResource -from superset.key_value.utils import get_filter -from superset.utils.core import get_user_id -from superset.utils.decorators import on_error, transaction - -logger = logging.getLogger(__name__) - - -class UpdateKeyValueCommand(BaseCommand): - resource: KeyValueResource - value: Any - codec: KeyValueCodec - key: Union[int, UUID] - expires_on: Optional[datetime] - - def __init__( # pylint: disable=too-many-arguments - self, - resource: KeyValueResource, - key: Union[int, UUID], - value: Any, - codec: KeyValueCodec, - expires_on: Optional[datetime] = None, - ): - """ - Update a key value entry - - :param resource: the resource (dashboard, chart etc) - :param key: the key to update - :param value: the value to persist in the key-value store - :param codec: codec used to encode the value - :param expires_on: entry expiration time - :return: the key associated with the updated value - """ - self.resource = resource - self.key = key - self.value = value - self.codec = codec - self.expires_on = expires_on - - @transaction(on_error=partial(on_error, reraise=KeyValueUpdateFailedError)) - def run(self) -> Optional[Key]: - return self.update() - - def validate(self) -> None: - pass - - def update(self) -> Optional[Key]: - filter_ = get_filter(self.resource, self.key) - entry: KeyValueEntry = ( - db.session.query(KeyValueEntry).filter_by(**filter_).first() - ) - if entry: - entry.value = self.codec.encode(self.value) - entry.expires_on = self.expires_on - entry.changed_on = datetime.now() - entry.changed_by_fk = get_user_id() - db.session.flush() - return Key(id=entry.id, uuid=entry.uuid) - - return None diff --git a/superset/commands/key_value/upsert.py b/superset/commands/key_value/upsert.py deleted file mode 100644 index 32918d9b1439..000000000000 --- a/superset/commands/key_value/upsert.py +++ /dev/null @@ -1,104 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import logging -from datetime import datetime -from functools import partial -from typing import Any, Optional, Union -from uuid import UUID - -from sqlalchemy.exc import SQLAlchemyError - -from superset import db -from superset.commands.base import BaseCommand -from superset.commands.key_value.create import CreateKeyValueCommand -from superset.key_value.exceptions import ( - KeyValueCreateFailedError, - KeyValueUpsertFailedError, -) -from superset.key_value.models import KeyValueEntry -from superset.key_value.types import Key, KeyValueCodec, KeyValueResource -from superset.key_value.utils import get_filter -from superset.utils.core import get_user_id -from superset.utils.decorators import on_error, transaction - -logger = logging.getLogger(__name__) - - -class UpsertKeyValueCommand(BaseCommand): - resource: KeyValueResource - value: Any - key: Union[int, UUID] - codec: KeyValueCodec - expires_on: Optional[datetime] - - def __init__( # pylint: disable=too-many-arguments - self, - resource: KeyValueResource, - key: Union[int, UUID], - value: Any, - codec: KeyValueCodec, - expires_on: Optional[datetime] = None, - ): - """ - Upsert a key value entry - - :param resource: the resource (dashboard, chart etc) - :param key: the key to update - :param value: the value to persist in the key-value store - :param codec: codec used to encode the value - :param expires_on: entry expiration time - :return: the key associated with the updated value - """ - self.resource = resource - self.key = key - self.value = value - self.codec = codec - self.expires_on = expires_on - - @transaction( - on_error=partial( - on_error, - catches=(KeyValueCreateFailedError, SQLAlchemyError), - reraise=KeyValueUpsertFailedError, - ), - ) - def run(self) -> Key: - return self.upsert() - - def validate(self) -> None: - pass - - def upsert(self) -> Key: - if ( - entry := db.session.query(KeyValueEntry) - .filter_by(**get_filter(self.resource, self.key)) - .first() - ): - entry.value = self.codec.encode(self.value) - entry.expires_on = self.expires_on - entry.changed_on = datetime.now() - entry.changed_by_fk = get_user_id() - return Key(entry.id, entry.uuid) - - return CreateKeyValueCommand( - resource=self.resource, - value=self.value, - codec=self.codec, - key=self.key, - expires_on=self.expires_on, - ).run() diff --git a/superset/daos/key_value.py b/superset/daos/key_value.py new file mode 100644 index 000000000000..f15293abcab8 --- /dev/null +++ b/superset/daos/key_value.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any +from uuid import UUID + +from sqlalchemy import and_ + +from superset import db +from superset.daos.base import BaseDAO +from superset.key_value.exceptions import ( + KeyValueCreateFailedError, + KeyValueUpdateFailedError, +) +from superset.key_value.models import KeyValueEntry +from superset.key_value.types import Key, KeyValueCodec, KeyValueResource +from superset.key_value.utils import get_filter +from superset.utils.core import get_user_id + +logger = logging.getLogger(__name__) + + +class KeyValueDAO(BaseDAO[KeyValueEntry]): + @staticmethod + def get_entry( + resource: KeyValueResource, + key: Key, + ) -> KeyValueEntry | None: + filter_ = get_filter(resource, key) + return db.session.query(KeyValueEntry).filter_by(**filter_).first() + + @classmethod + def get_value( + cls, + resource: KeyValueResource, + key: Key, + codec: KeyValueCodec, + ) -> Any: + entry = cls.get_entry(resource, key) + if not entry or entry.is_expired(): + return None + + return codec.decode(entry.value) + + @staticmethod + def delete_entry(resource: KeyValueResource, key: Key) -> bool: + if entry := KeyValueDAO.get_entry(resource, key): + db.session.delete(entry) + return True + + return False + + @staticmethod + def delete_expired_entries(resource: KeyValueResource) -> None: + ( + db.session.query(KeyValueEntry) + .filter( + and_( + KeyValueEntry.resource == resource.value, + KeyValueEntry.expires_on <= datetime.now(), + ) + ) + .delete() + ) + + @staticmethod + def create_entry( + resource: KeyValueResource, + value: Any, + codec: KeyValueCodec, + key: Key | None = None, + expires_on: datetime | None = None, + ) -> KeyValueEntry: + try: + encoded_value = codec.encode(value) + except Exception as ex: + raise KeyValueCreateFailedError("Unable to encode value") from ex + entry = KeyValueEntry( + resource=resource.value, + value=encoded_value, + created_on=datetime.now(), + created_by_fk=get_user_id(), + expires_on=expires_on, + ) + if key is not None: + try: + if isinstance(key, UUID): + entry.uuid = key + else: + entry.id = key + except ValueError as ex: + raise KeyValueCreateFailedError() from ex + db.session.add(entry) + return entry + + @staticmethod + def upsert_entry( + resource: KeyValueResource, + value: Any, + codec: KeyValueCodec, + key: Key, + expires_on: datetime | None = None, + ) -> KeyValueEntry: + if entry := KeyValueDAO.get_entry(resource, key): + entry.value = codec.encode(value) + entry.expires_on = expires_on + entry.changed_on = datetime.now() + entry.changed_by_fk = get_user_id() + return entry + + return KeyValueDAO.create_entry(resource, value, codec, key, expires_on) + + @staticmethod + def update_entry( + resource: KeyValueResource, + value: Any, + codec: KeyValueCodec, + key: Key, + expires_on: datetime | None = None, + ) -> KeyValueEntry: + if entry := KeyValueDAO.get_entry(resource, key): + entry.value = codec.encode(value) + entry.expires_on = expires_on + entry.changed_on = datetime.now() + entry.changed_by_fk = get_user_id() + return entry + + raise KeyValueUpdateFailedError() diff --git a/superset/utils/lock.py b/superset/distributed_lock/__init__.py similarity index 55% rename from superset/utils/lock.py rename to superset/distributed_lock/__init__.py index 4723b57fa1b0..c4af73ac0f09 100644 --- a/superset/utils/lock.py +++ b/superset/distributed_lock/__init__.py @@ -21,40 +21,18 @@ import uuid from collections.abc import Iterator from contextlib import contextmanager -from datetime import datetime, timedelta -from typing import Any, cast, TypeVar, Union +from datetime import timedelta +from typing import Any +from superset.distributed_lock.utils import get_key from superset.exceptions import CreateKeyValueDistributedLockFailedException -from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.types import JsonKeyValueCodec, KeyValueResource -from superset.utils import json -LOCK_EXPIRATION = timedelta(seconds=30) logger = logging.getLogger(__name__) - -def serialize(params: dict[str, Any]) -> str: - """ - Serialize parameters into a string. - """ - - T = TypeVar( - "T", - bound=Union[dict[str, Any], list[Any], int, float, str, bool, None], - ) - - def sort(obj: T) -> T: - if isinstance(obj, dict): - return cast(T, {k: sort(v) for k, v in sorted(obj.items())}) - if isinstance(obj, list): - return cast(T, [sort(x) for x in obj]) - return obj - - return json.dumps(params) - - -def get_key(namespace: str, **kwargs: Any) -> uuid.UUID: - return uuid.uuid5(uuid.uuid5(uuid.NAMESPACE_DNS, namespace), serialize(kwargs)) +CODEC = JsonKeyValueCodec() +LOCK_EXPIRATION = timedelta(seconds=30) +RESOURCE = KeyValueResource.LOCK @contextmanager @@ -75,28 +53,25 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name :yields: A unique identifier (UUID) for the acquired lock (the KV key). :raises CreateKeyValueDistributedLockFailedException: If the lock is taken. """ + # pylint: disable=import-outside-toplevel - from superset.commands.key_value.create import CreateKeyValueCommand - from superset.commands.key_value.delete import DeleteKeyValueCommand - from superset.commands.key_value.delete_expired import DeleteExpiredKeyValueCommand + from superset.commands.distributed_lock.create import CreateDistributedLock + from superset.commands.distributed_lock.delete import DeleteDistributedLock + from superset.commands.distributed_lock.get import GetDistributedLock key = get_key(namespace, **kwargs) + value = GetDistributedLock(namespace=namespace, params=kwargs).run() + if value: + logger.debug("Lock on namespace %s for key %s already taken", namespace, key) + raise CreateKeyValueDistributedLockFailedException("Lock already taken") + logger.debug("Acquiring lock on namespace %s for key %s", namespace, key) try: - DeleteExpiredKeyValueCommand(resource=KeyValueResource.LOCK).run() - CreateKeyValueCommand( - resource=KeyValueResource.LOCK, - codec=JsonKeyValueCodec(), - key=key, - value=True, - expires_on=datetime.now() + LOCK_EXPIRATION, - ).run() - - yield key - - DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run() - logger.debug("Removed lock on namespace %s for key %s", namespace, key) - except KeyValueCreateFailedError as ex: - raise CreateKeyValueDistributedLockFailedException( - "Error acquiring lock" - ) from ex + CreateDistributedLock(namespace=namespace, params=kwargs).run() + except CreateKeyValueDistributedLockFailedException as ex: + logger.debug("Lock on namespace %s for key %s already taken", namespace, key) + raise CreateKeyValueDistributedLockFailedException("Lock already taken") from ex + + yield key + DeleteDistributedLock(namespace=namespace, params=kwargs).run() + logger.debug("Removed lock on namespace %s for key %s", namespace, key) diff --git a/superset/commands/key_value/__init__.py b/superset/distributed_lock/types.py similarity index 91% rename from superset/commands/key_value/__init__.py rename to superset/distributed_lock/types.py index 13a83393a912..b714913e8e8b 100644 --- a/superset/commands/key_value/__init__.py +++ b/superset/distributed_lock/types.py @@ -14,3 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import TypedDict + + +class LockValue(TypedDict): + value: bool diff --git a/tests/integration_tests/key_value/commands/__init__.py b/superset/distributed_lock/utils.py similarity index 52% rename from tests/integration_tests/key_value/commands/__init__.py rename to superset/distributed_lock/utils.py index 13a83393a912..09ed12d704d9 100644 --- a/tests/integration_tests/key_value/commands/__init__.py +++ b/superset/distributed_lock/utils.py @@ -14,3 +14,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import uuid +from typing import Any, cast, TypeVar, Union + +from superset.utils import json + + +def serialize(params: dict[str, Any]) -> str: + """ + Serialize parameters into a string. + """ + + T = TypeVar( + "T", + bound=Union[dict[str, Any], list[Any], int, float, str, bool, None], + ) + + def sort(obj: T) -> T: + if isinstance(obj, dict): + return cast(T, {k: sort(v) for k, v in sorted(obj.items())}) + if isinstance(obj, list): + return cast(T, [sort(x) for x in obj]) + return obj + + return json.dumps(params) + + +def get_key(namespace: str, **kwargs: Any) -> uuid.UUID: + return uuid.uuid5(uuid.uuid5(uuid.NAMESPACE_DNS, namespace), serialize(kwargs)) diff --git a/superset/exceptions.py b/superset/exceptions.py index 47cd511f8f20..dd669f5b72ae 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -379,6 +379,12 @@ class CreateKeyValueDistributedLockFailedException(Exception): """ +class DeleteKeyValueDistributedLockFailedException(Exception): + """ + Exception to signalize failure to delete lock. + """ + + class DatabaseNotFoundException(SupersetErrorException): status = 404 diff --git a/superset/extensions/metastore_cache.py b/superset/extensions/metastore_cache.py index 1c89e8459774..1195bd8edf65 100644 --- a/superset/extensions/metastore_cache.py +++ b/superset/extensions/metastore_cache.py @@ -21,7 +21,10 @@ from flask import current_app, Flask, has_app_context from flask_caching import BaseCache +from sqlalchemy.exc import SQLAlchemyError +from superset import db +from superset.daos.key_value import KeyValueDAO from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.types import ( KeyValueCodec, @@ -29,6 +32,7 @@ PickleKeyValueCodec, ) from superset.key_value.utils import get_uuid_namespace +from superset.utils.decorators import transaction RESOURCE = KeyValueResource.METASTORE_CACHE @@ -68,15 +72,6 @@ def factory( def get_key(self, key: str) -> UUID: return uuid3(self.namespace, key) - @staticmethod - def _prune() -> None: - # pylint: disable=import-outside-toplevel - from superset.commands.key_value.delete_expired import ( - DeleteExpiredKeyValueCommand, - ) - - DeleteExpiredKeyValueCommand(resource=RESOURCE).run() - def _get_expiry(self, timeout: Optional[int]) -> Optional[datetime]: timeout = self._normalize_timeout(timeout) if timeout is not None and timeout > 0: @@ -84,44 +79,34 @@ def _get_expiry(self, timeout: Optional[int]) -> Optional[datetime]: return None def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool: - # pylint: disable=import-outside-toplevel - from superset.commands.key_value.upsert import UpsertKeyValueCommand - - UpsertKeyValueCommand( + KeyValueDAO.upsert_entry( resource=RESOURCE, key=self.get_key(key), value=value, codec=self.codec, expires_on=self._get_expiry(timeout), - ).run() + ) + db.session.commit() # pylint: disable=consider-using-transaction return True def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool: - # pylint: disable=import-outside-toplevel - from superset.commands.key_value.create import CreateKeyValueCommand - try: - self._prune() - CreateKeyValueCommand( + KeyValueDAO.delete_expired_entries(RESOURCE) + KeyValueDAO.create_entry( resource=RESOURCE, value=value, codec=self.codec, key=self.get_key(key), expires_on=self._get_expiry(timeout), - ).run() + ) + db.session.commit() # pylint: disable=consider-using-transaction return True - except KeyValueCreateFailedError: + except (SQLAlchemyError, KeyValueCreateFailedError): + db.session.rollback() # pylint: disable=consider-using-transaction return False def get(self, key: str) -> Any: - # pylint: disable=import-outside-toplevel - from superset.commands.key_value.get import GetKeyValueCommand - - return GetKeyValueCommand( - resource=RESOURCE, - key=self.get_key(key), - codec=self.codec, - ).run() + return KeyValueDAO.get_value(RESOURCE, self.get_key(key), self.codec) def has(self, key: str) -> bool: entry = self.get(key) @@ -129,8 +114,6 @@ def has(self, key: str) -> bool: return True return False + @transaction() def delete(self, key: str) -> Any: - # pylint: disable=import-outside-toplevel - from superset.commands.key_value.delete import DeleteKeyValueCommand - - return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run() + return KeyValueDAO.delete_entry(RESOURCE, self.get_key(key)) diff --git a/superset/key_value/shared_entries.py b/superset/key_value/shared_entries.py index 130313157a53..c2acafa80752 100644 --- a/superset/key_value/shared_entries.py +++ b/superset/key_value/shared_entries.py @@ -18,8 +18,10 @@ from typing import Any, Optional from uuid import uuid3 +from superset.daos.key_value import KeyValueDAO from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey from superset.key_value.utils import get_uuid_namespace, random_key +from superset.utils.decorators import transaction RESOURCE = KeyValueResource.APP NAMESPACE = get_uuid_namespace("") @@ -27,24 +29,14 @@ def get_shared_value(key: SharedKey) -> Optional[Any]: - # pylint: disable=import-outside-toplevel - from superset.commands.key_value.get import GetKeyValueCommand - uuid_key = uuid3(NAMESPACE, key) - return GetKeyValueCommand(RESOURCE, key=uuid_key, codec=CODEC).run() + return KeyValueDAO.get_value(RESOURCE, uuid_key, CODEC) +@transaction() def set_shared_value(key: SharedKey, value: Any) -> None: - # pylint: disable=import-outside-toplevel - from superset.commands.key_value.create import CreateKeyValueCommand - uuid_key = uuid3(NAMESPACE, key) - CreateKeyValueCommand( - resource=RESOURCE, - value=value, - key=uuid_key, - codec=CODEC, - ).run() + KeyValueDAO.create_entry(RESOURCE, value, CODEC, uuid_key) def get_permalink_salt(key: SharedKey) -> str: diff --git a/superset/key_value/types.py b/superset/key_value/types.py index 7b0130c0e6ce..f6459c330283 100644 --- a/superset/key_value/types.py +++ b/superset/key_value/types.py @@ -19,8 +19,7 @@ import json import pickle from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, TypedDict +from typing import Any, TypedDict, Union from uuid import UUID from marshmallow import Schema, ValidationError @@ -31,11 +30,7 @@ ) from superset.utils.backports import StrEnum - -@dataclass -class Key: - id: int | None - uuid: UUID | None +Key = Union[int, UUID] class KeyValueFilter(TypedDict, total=False): diff --git a/superset/key_value/utils.py b/superset/key_value/utils.py index 1a22cfaa747b..0a4e63778aa6 100644 --- a/superset/key_value/utils.py +++ b/superset/key_value/utils.py @@ -25,7 +25,7 @@ from flask_babel import gettext as _ from superset.key_value.exceptions import KeyValueParseKeyError -from superset.key_value.types import KeyValueFilter, KeyValueResource +from superset.key_value.types import Key, KeyValueFilter, KeyValueResource from superset.utils.json import json_dumps_w_dates HASHIDS_MIN_LENGTH = 11 @@ -35,7 +35,7 @@ def random_key() -> str: return token_urlsafe(48) -def get_filter(resource: KeyValueResource, key: int | UUID) -> KeyValueFilter: +def get_filter(resource: KeyValueResource, key: Key) -> KeyValueFilter: try: filter_: KeyValueFilter = {"resource": resource.value} if isinstance(key, UUID): diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index bc4805fd8192..b889ef83c5e7 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -26,9 +26,9 @@ from marshmallow import EXCLUDE, fields, post_load, Schema from superset import db +from superset.distributed_lock import KeyValueDistributedLock from superset.exceptions import CreateKeyValueDistributedLockFailedException from superset.superset_typing import OAuth2ClientConfig, OAuth2State -from superset.utils.lock import KeyValueDistributedLock if TYPE_CHECKING: from superset.db_engine_specs.base import BaseEngineSpec diff --git a/tests/integration_tests/explore/permalink/commands_tests.py b/tests/integration_tests/explore/permalink/commands_tests.py index 4993e33f1895..c17d8bafdb00 100644 --- a/tests/integration_tests/explore/permalink/commands_tests.py +++ b/tests/integration_tests/explore/permalink/commands_tests.py @@ -133,11 +133,11 @@ def test_get_permalink_command(self, mock_g): assert cache_data.get("datasource") == datasource @patch("superset.security.manager.g") - @patch("superset.commands.key_value.get.GetKeyValueCommand.run") + @patch("superset.daos.key_value.KeyValueDAO.get_value") @patch("superset.commands.explore.permalink.get.decode_permalink_id") @pytest.mark.usefixtures("create_dataset", "create_slice") def test_get_permalink_command_with_old_dataset_key( - self, decode_id_mock, get_kv_command_mock, mock_g + self, decode_id_mock, kv_get_value_mock, mock_g ): mock_g.user = security_manager.find_user("admin") app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = { @@ -149,13 +149,14 @@ def test_get_permalink_command_with_old_dataset_key( ) slice = db.session.query(Slice).filter_by(slice_name="slice_name").first() - datasource_string = f"{dataset.id}__{DatasourceType.TABLE}" + datasource_string = f"{dataset.id}__{DatasourceType.TABLE.value}" decode_id_mock.return_value = "123456" - get_kv_command_mock.return_value = { + kv_get_value_mock.return_value = { "chartId": slice.id, "datasetId": dataset.id, "datasource": datasource_string, + "datasourceType": DatasourceType.TABLE.value, "state": { "formData": {"datasource": datasource_string, "slice_id": slice.id} }, diff --git a/tests/integration_tests/extensions/metastore_cache_test.py b/tests/integration_tests/extensions/metastore_cache_test.py index c69340a7a2a3..238e8fd46a50 100644 --- a/tests/integration_tests/extensions/metastore_cache_test.py +++ b/tests/integration_tests/extensions/metastore_cache_test.py @@ -60,6 +60,7 @@ def test_caching_flow(app_context: AppContext, cache: SupersetMetastoreCache) -> assert cache.has(FIRST_KEY) is False assert cache.add(FIRST_KEY, FIRST_KEY_INITIAL_VALUE) is True assert cache.has(FIRST_KEY) is True + assert cache.get(FIRST_KEY) == FIRST_KEY_INITIAL_VALUE cache.set(SECOND_KEY, SECOND_VALUE) assert cache.get(FIRST_KEY) == FIRST_KEY_INITIAL_VALUE assert cache.get(SECOND_KEY) == SECOND_VALUE diff --git a/tests/integration_tests/key_value/commands/create_test.py b/tests/integration_tests/key_value/commands/create_test.py deleted file mode 100644 index b18b9886d6ff..000000000000 --- a/tests/integration_tests/key_value/commands/create_test.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pickle - -import pytest -from flask.ctx import AppContext -from flask_appbuilder.security.sqla.models import User - -from superset.extensions import db -from superset.key_value.exceptions import KeyValueCreateFailedError -from superset.utils import json -from superset.utils.core import override_user -from tests.integration_tests.key_value.commands.fixtures import ( - admin, # noqa: F401 - JSON_CODEC, - JSON_VALUE, - PICKLE_CODEC, - PICKLE_VALUE, - RESOURCE, -) - - -def test_create_id_entry(app_context: AppContext, admin: User) -> None: # noqa: F811 - from superset.commands.key_value.create import CreateKeyValueCommand - from superset.key_value.models import KeyValueEntry - - with override_user(admin): - key = CreateKeyValueCommand( - resource=RESOURCE, - value=JSON_VALUE, - codec=JSON_CODEC, - ).run() - entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one() - assert json.loads(entry.value) == JSON_VALUE - assert entry.created_by_fk == admin.id - db.session.delete(entry) - db.session.commit() - - -def test_create_uuid_entry(app_context: AppContext, admin: User) -> None: # noqa: F811 - from superset.commands.key_value.create import CreateKeyValueCommand - from superset.key_value.models import KeyValueEntry - - with override_user(admin): - key = CreateKeyValueCommand( - resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC - ).run() - entry = db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).one() - assert json.loads(entry.value) == JSON_VALUE - assert entry.created_by_fk == admin.id - db.session.delete(entry) - db.session.commit() - - -def test_create_fail_json_entry(app_context: AppContext, admin: User) -> None: # noqa: F811 - from superset.commands.key_value.create import CreateKeyValueCommand - - with pytest.raises(KeyValueCreateFailedError): - CreateKeyValueCommand( - resource=RESOURCE, - value=PICKLE_VALUE, - codec=JSON_CODEC, - ).run() - - -def test_create_pickle_entry(app_context: AppContext, admin: User) -> None: # noqa: F811 - from superset.commands.key_value.create import CreateKeyValueCommand - from superset.key_value.models import KeyValueEntry - - with override_user(admin): - key = CreateKeyValueCommand( - resource=RESOURCE, - value=PICKLE_VALUE, - codec=PICKLE_CODEC, - ).run() - entry = db.session.query(KeyValueEntry).filter_by(id=key.id).one() - assert type(pickle.loads(entry.value)) == type(PICKLE_VALUE) - assert entry.created_by_fk == admin.id - db.session.delete(entry) - db.session.commit() diff --git a/tests/integration_tests/key_value/commands/delete_test.py b/tests/integration_tests/key_value/commands/delete_test.py deleted file mode 100644 index b45a5d075d21..000000000000 --- a/tests/integration_tests/key_value/commands/delete_test.py +++ /dev/null @@ -1,84 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from typing import TYPE_CHECKING -from uuid import UUID - -import pytest -from flask.ctx import AppContext -from flask_appbuilder.security.sqla.models import User - -from superset.extensions import db -from superset.utils import json -from tests.integration_tests.key_value.commands.fixtures import ( - admin, # noqa: F401 - JSON_VALUE, - RESOURCE, -) - -if TYPE_CHECKING: - from superset.key_value.models import KeyValueEntry - -ID_KEY = 234 -UUID_KEY = UUID("5aae143c-44f1-478e-9153-ae6154df333a") - - -@pytest.fixture -def key_value_entry() -> KeyValueEntry: - from superset.key_value.models import KeyValueEntry - - entry = KeyValueEntry( - id=ID_KEY, - uuid=UUID_KEY, - resource=RESOURCE, - value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"), - ) - db.session.add(entry) - db.session.flush() - return entry - - -def test_delete_id_entry( - app_context: AppContext, - admin: User, # noqa: F811 - key_value_entry: KeyValueEntry, -) -> None: - from superset.commands.key_value.delete import DeleteKeyValueCommand - - assert DeleteKeyValueCommand(resource=RESOURCE, key=ID_KEY).run() is True - db.session.commit() - - -def test_delete_uuid_entry( - app_context: AppContext, - admin: User, # noqa: F811 - key_value_entry: KeyValueEntry, -) -> None: - from superset.commands.key_value.delete import DeleteKeyValueCommand - - assert DeleteKeyValueCommand(resource=RESOURCE, key=UUID_KEY).run() is True - db.session.commit() - - -def test_delete_entry_missing( - app_context: AppContext, - admin: User, # noqa: F811 -) -> None: - from superset.commands.key_value.delete import DeleteKeyValueCommand - - assert DeleteKeyValueCommand(resource=RESOURCE, key=456).run() is False diff --git a/tests/integration_tests/key_value/commands/fixtures.py b/tests/integration_tests/key_value/commands/fixtures.py deleted file mode 100644 index 74bf809301c1..000000000000 --- a/tests/integration_tests/key_value/commands/fixtures.py +++ /dev/null @@ -1,69 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import annotations - -from collections.abc import Generator -from typing import TYPE_CHECKING -from uuid import UUID - -import pytest -from flask_appbuilder.security.sqla.models import User - -from superset.extensions import db -from superset.key_value.types import ( - JsonKeyValueCodec, - KeyValueResource, - PickleKeyValueCodec, -) -from superset.utils import json -from tests.integration_tests.test_app import app - -if TYPE_CHECKING: - from superset.key_value.models import KeyValueEntry - -ID_KEY = 123 -UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc") -RESOURCE = KeyValueResource.APP -JSON_VALUE = {"foo": "bar"} -PICKLE_VALUE = object() -JSON_CODEC = JsonKeyValueCodec() -PICKLE_CODEC = PickleKeyValueCodec() - - -@pytest.fixture -def key_value_entry() -> Generator[KeyValueEntry, None, None]: - from superset.key_value.models import KeyValueEntry - - entry = KeyValueEntry( - id=ID_KEY, - uuid=UUID_KEY, - resource=RESOURCE, - value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"), - ) - db.session.add(entry) - db.session.flush() - yield entry - db.session.delete(entry) - db.session.commit() - - -@pytest.fixture -def admin() -> User: - with app.app_context(): # noqa: F841 - admin = db.session.query(User).filter_by(username="admin").one() - return admin diff --git a/tests/integration_tests/key_value/commands/get_test.py b/tests/integration_tests/key_value/commands/get_test.py deleted file mode 100644 index 131b615b7c2e..000000000000 --- a/tests/integration_tests/key_value/commands/get_test.py +++ /dev/null @@ -1,103 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import uuid -from datetime import datetime, timedelta -from typing import TYPE_CHECKING - -from flask.ctx import AppContext - -from superset.extensions import db -from superset.utils import json -from tests.integration_tests.key_value.commands.fixtures import ( - ID_KEY, - JSON_CODEC, - JSON_VALUE, - key_value_entry, # noqa: F401 - RESOURCE, - UUID_KEY, -) - -if TYPE_CHECKING: - from superset.key_value.models import KeyValueEntry - - -def test_get_id_entry(app_context: AppContext, key_value_entry: KeyValueEntry) -> None: # noqa: F811 - from superset.commands.key_value.get import GetKeyValueCommand - - value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, codec=JSON_CODEC).run() - assert value == JSON_VALUE - - -def test_get_uuid_entry( - app_context: AppContext, - key_value_entry: KeyValueEntry, # noqa: F811 -) -> None: - from superset.commands.key_value.get import GetKeyValueCommand - - value = GetKeyValueCommand(resource=RESOURCE, key=UUID_KEY, codec=JSON_CODEC).run() - assert value == JSON_VALUE - - -def test_get_id_entry_missing( - app_context: AppContext, - key_value_entry: KeyValueEntry, # noqa: F811 -) -> None: - from superset.commands.key_value.get import GetKeyValueCommand - - value = GetKeyValueCommand(resource=RESOURCE, key=456, codec=JSON_CODEC).run() - assert value is None - - -def test_get_expired_entry(app_context: AppContext) -> None: - from superset.commands.key_value.get import GetKeyValueCommand - from superset.key_value.models import KeyValueEntry - - entry = KeyValueEntry( - id=678, - uuid=uuid.uuid4(), - resource=RESOURCE, - value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"), - expires_on=datetime.now() - timedelta(days=1), - ) - db.session.add(entry) - db.session.flush() - value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, codec=JSON_CODEC).run() - assert value is None - db.session.delete(entry) - db.session.commit() - - -def test_get_future_expiring_entry(app_context: AppContext) -> None: - from superset.commands.key_value.get import GetKeyValueCommand - from superset.key_value.models import KeyValueEntry - - id_ = 789 - entry = KeyValueEntry( - id=id_, - uuid=uuid.uuid4(), - resource=RESOURCE, - value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"), - expires_on=datetime.now() + timedelta(days=1), - ) - db.session.add(entry) - db.session.flush() - value = GetKeyValueCommand(resource=RESOURCE, key=id_, codec=JSON_CODEC).run() - assert value == JSON_VALUE - db.session.delete(entry) - db.session.commit() diff --git a/tests/integration_tests/key_value/commands/update_test.py b/tests/integration_tests/key_value/commands/update_test.py deleted file mode 100644 index bb434ec3b98b..000000000000 --- a/tests/integration_tests/key_value/commands/update_test.py +++ /dev/null @@ -1,97 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from typing import TYPE_CHECKING - -from flask.ctx import AppContext -from flask_appbuilder.security.sqla.models import User - -from superset.extensions import db -from superset.utils import json -from superset.utils.core import override_user -from tests.integration_tests.key_value.commands.fixtures import ( - admin, # noqa: F401 - ID_KEY, - JSON_CODEC, - key_value_entry, # noqa: F401 - RESOURCE, - UUID_KEY, -) - -if TYPE_CHECKING: - from superset.key_value.models import KeyValueEntry - - -NEW_VALUE = "new value" - - -def test_update_id_entry( - app_context: AppContext, - admin: User, # noqa: F811 - key_value_entry: KeyValueEntry, # noqa: F811 -) -> None: - from superset.commands.key_value.update import UpdateKeyValueCommand - from superset.key_value.models import KeyValueEntry - - with override_user(admin): - key = UpdateKeyValueCommand( - resource=RESOURCE, - key=ID_KEY, - value=NEW_VALUE, - codec=JSON_CODEC, - ).run() - assert key is not None - assert key.id == ID_KEY - entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).one() - assert json.loads(entry.value) == NEW_VALUE - assert entry.changed_by_fk == admin.id - - -def test_update_uuid_entry( - app_context: AppContext, - admin: User, # noqa: F811 - key_value_entry: KeyValueEntry, # noqa: F811 -) -> None: - from superset.commands.key_value.update import UpdateKeyValueCommand - from superset.key_value.models import KeyValueEntry - - with override_user(admin): - key = UpdateKeyValueCommand( - resource=RESOURCE, - key=UUID_KEY, - value=NEW_VALUE, - codec=JSON_CODEC, - ).run() - assert key is not None - assert key.uuid == UUID_KEY - entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one() - assert json.loads(entry.value) == NEW_VALUE - assert entry.changed_by_fk == admin.id - - -def test_update_missing_entry(app_context: AppContext, admin: User) -> None: # noqa: F811 - from superset.commands.key_value.update import UpdateKeyValueCommand - - with override_user(admin): - key = UpdateKeyValueCommand( - resource=RESOURCE, - key=456, - value=NEW_VALUE, - codec=JSON_CODEC, - ).run() - assert key is None diff --git a/tests/integration_tests/key_value/commands/upsert_test.py b/tests/integration_tests/key_value/commands/upsert_test.py deleted file mode 100644 index 6ff61423f1a7..000000000000 --- a/tests/integration_tests/key_value/commands/upsert_test.py +++ /dev/null @@ -1,101 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from typing import TYPE_CHECKING - -from flask.ctx import AppContext -from flask_appbuilder.security.sqla.models import User - -from superset.extensions import db -from superset.utils import json -from superset.utils.core import override_user -from tests.integration_tests.key_value.commands.fixtures import ( - admin, # noqa: F401 - ID_KEY, - JSON_CODEC, - key_value_entry, # noqa: F401 - RESOURCE, - UUID_KEY, -) - -if TYPE_CHECKING: - from superset.key_value.models import KeyValueEntry - - -NEW_VALUE = "new value" - - -def test_upsert_id_entry( - app_context: AppContext, - admin: User, # noqa: F811 - key_value_entry: KeyValueEntry, # noqa: F811 -) -> None: - from superset.commands.key_value.upsert import UpsertKeyValueCommand - from superset.key_value.models import KeyValueEntry - - with override_user(admin): - key = UpsertKeyValueCommand( - resource=RESOURCE, - key=ID_KEY, - value=NEW_VALUE, - codec=JSON_CODEC, - ).run() - assert key is not None - assert key.id == ID_KEY - entry = db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).one() - assert json.loads(entry.value) == NEW_VALUE - assert entry.changed_by_fk == admin.id - - -def test_upsert_uuid_entry( - app_context: AppContext, - admin: User, # noqa: F811 - key_value_entry: KeyValueEntry, # noqa: F811 -) -> None: - from superset.commands.key_value.upsert import UpsertKeyValueCommand - from superset.key_value.models import KeyValueEntry - - with override_user(admin): - key = UpsertKeyValueCommand( - resource=RESOURCE, - key=UUID_KEY, - value=NEW_VALUE, - codec=JSON_CODEC, - ).run() - assert key is not None - assert key.uuid == UUID_KEY - entry = db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).one() - assert json.loads(entry.value) == NEW_VALUE - assert entry.changed_by_fk == admin.id - - -def test_upsert_missing_entry(app_context: AppContext, admin: User) -> None: # noqa: F811 - from superset.commands.key_value.upsert import UpsertKeyValueCommand - from superset.key_value.models import KeyValueEntry - - with override_user(admin): - key = UpsertKeyValueCommand( - resource=RESOURCE, - key=456, - value=NEW_VALUE, - codec=JSON_CODEC, - ).run() - assert key is not None - assert key.id == 456 - db.session.query(KeyValueEntry).filter_by(id=456).delete() - db.session.commit() diff --git a/tests/unit_tests/dao/key_value_test.py b/tests/unit_tests/dao/key_value_test.py new file mode 100644 index 000000000000..18c0dfb25f94 --- /dev/null +++ b/tests/unit_tests/dao/key_value_test.py @@ -0,0 +1,395 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument, import-outside-toplevel, unused-import +from __future__ import annotations + +import pickle +from datetime import datetime, timedelta +from typing import Generator, TYPE_CHECKING +from uuid import UUID + +import pytest +from flask.ctx import AppContext +from flask_appbuilder.security.sqla.models import User + +from superset.extensions import db +from superset.key_value.exceptions import ( + KeyValueCreateFailedError, + KeyValueUpdateFailedError, +) +from superset.key_value.types import ( + JsonKeyValueCodec, + KeyValueResource, + PickleKeyValueCodec, +) +from superset.utils import json +from superset.utils.core import override_user +from tests.unit_tests.fixtures.common import admin_user, after_each # noqa: F401 + +if TYPE_CHECKING: + from superset.key_value.models import KeyValueEntry + +ID_KEY = 123 +UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc") +RESOURCE = KeyValueResource.APP +JSON_VALUE = {"foo": "bar"} +PICKLE_VALUE = object() +JSON_CODEC = JsonKeyValueCodec() +PICKLE_CODEC = PickleKeyValueCodec() +NEW_VALUE = {"foo": "baz"} + + +@pytest.fixture +def key_value_entry() -> Generator[KeyValueEntry, None, None]: + from superset.key_value.models import KeyValueEntry + + entry = KeyValueEntry( + id=ID_KEY, + uuid=UUID_KEY, + resource=RESOURCE, + value=JSON_CODEC.encode(JSON_VALUE), + ) + db.session.add(entry) + db.session.flush() + yield entry + + +def test_create_id_entry( + app_context: AppContext, + admin_user: User, # noqa: F811 + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + from superset.key_value.models import KeyValueEntry + + with override_user(admin_user): + created_entry = KeyValueDAO.create_entry( + resource=RESOURCE, + value=JSON_VALUE, + codec=JSON_CODEC, + ) + db.session.flush() + found_entry = ( + db.session.query(KeyValueEntry).filter_by(id=created_entry.id).one() + ) + assert json.loads(found_entry.value) == JSON_VALUE + assert found_entry.created_by_fk == admin_user.id + + +def test_create_uuid_entry( + app_context: AppContext, + admin_user: User, # noqa: F811 + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + from superset.key_value.models import KeyValueEntry + + with override_user(admin_user): + created_entry = KeyValueDAO.create_entry( + resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC + ) + db.session.flush() + + found_entry = ( + db.session.query(KeyValueEntry).filter_by(uuid=created_entry.uuid).one() + ) + assert json.loads(found_entry.value) == JSON_VALUE + assert found_entry.created_by_fk == admin_user.id + + +def test_create_fail_json_entry( + app_context: AppContext, + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + with pytest.raises(KeyValueCreateFailedError): + KeyValueDAO.create_entry( + resource=RESOURCE, + value=PICKLE_VALUE, + codec=JSON_CODEC, + ) + + +def test_create_pickle_entry( + app_context: AppContext, + admin_user: User, # noqa: F811 + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + from superset.key_value.models import KeyValueEntry + + with override_user(admin_user): + created_entry = KeyValueDAO.create_entry( + resource=RESOURCE, + value=PICKLE_VALUE, + codec=PICKLE_CODEC, + ) + db.session.flush() + found_entry = ( + db.session.query(KeyValueEntry).filter_by(id=created_entry.id).one() + ) + assert type(pickle.loads(found_entry.value)) == type(PICKLE_VALUE) + assert found_entry.created_by_fk == admin_user.id + + +def test_get_value( + app_context: AppContext, + key_value_entry: KeyValueEntry, + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + value = KeyValueDAO.get_value( + resource=RESOURCE, + key=key_value_entry.id, + codec=JSON_CODEC, + ) + assert value == JSON_VALUE + + +def test_get_id_entry( + app_context: AppContext, + key_value_entry: KeyValueEntry, + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=key_value_entry.id) + assert found_entry is not None + assert found_entry.id == key_value_entry.id + + +def test_get_uuid_entry( + app_context: AppContext, + key_value_entry: KeyValueEntry, # noqa: F811 + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=key_value_entry.uuid) + assert found_entry is not None + assert JSON_CODEC.decode(found_entry.value) == JSON_VALUE + + +def test_get_id_entry_missing( + app_context: AppContext, + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + entry = KeyValueDAO.get_entry(resource=RESOURCE, key=456) + assert entry is None + + +def test_get_expired_entry( + app_context: AppContext, + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + created_entry = KeyValueDAO.create_entry( + resource=RESOURCE, + value=JSON_VALUE, + codec=JSON_CODEC, + key=ID_KEY, + expires_on=datetime.now() - timedelta(days=1), + ) + found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=created_entry.id) + assert found_entry is not None + assert found_entry.is_expired() is True + + +def test_get_future_expiring_entry( + app_context: AppContext, + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + created_entry = KeyValueDAO.create_entry( + resource=RESOURCE, + value=JSON_VALUE, + codec=JSON_CODEC, + key=ID_KEY, + expires_on=datetime.now() + timedelta(days=1), + ) + found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=created_entry.id) + assert found_entry is not None + assert found_entry.is_expired() is False + + +def test_update_id_entry( + app_context: AppContext, + key_value_entry: KeyValueEntry, # noqa: F811 + admin_user: User, # noqa: F811 + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + with override_user(admin_user): + updated_entry = KeyValueDAO.update_entry( + resource=RESOURCE, + key=ID_KEY, + value=NEW_VALUE, + codec=JSON_CODEC, + ) + db.session.flush() + assert updated_entry is not None + assert JSON_CODEC.decode(updated_entry.value) == NEW_VALUE + assert updated_entry.id == ID_KEY + assert updated_entry.uuid == UUID_KEY + found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY) + assert found_entry is not None + assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE + assert found_entry.changed_by_fk == admin_user.id + + +def test_update_uuid_entry( + app_context: AppContext, + key_value_entry: KeyValueEntry, # noqa: F811 + admin_user: User, # noqa: F811 + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + with override_user(admin_user): + updated_entry = KeyValueDAO.update_entry( + resource=RESOURCE, + key=UUID_KEY, + value=NEW_VALUE, + codec=JSON_CODEC, + ) + db.session.flush() + assert updated_entry is not None + assert JSON_CODEC.decode(updated_entry.value) == NEW_VALUE + assert updated_entry.id == ID_KEY + assert updated_entry.uuid == UUID_KEY + found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=UUID_KEY) + assert found_entry is not None + assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE + assert found_entry.changed_by_fk == admin_user.id + + +def test_update_missing_entry( + app_context: AppContext, + admin_user: User, # noqa: F811 + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + with override_user(admin_user): + with pytest.raises(KeyValueUpdateFailedError): + KeyValueDAO.update_entry( + resource=RESOURCE, + key=456, + value=NEW_VALUE, + codec=JSON_CODEC, + ) + + +def test_upsert_id_entry( + app_context: AppContext, + key_value_entry: KeyValueEntry, # noqa: F811 + admin_user: User, # noqa: F811 + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + with override_user(admin_user): + entry = KeyValueDAO.upsert_entry( + resource=RESOURCE, + key=ID_KEY, + value=NEW_VALUE, + codec=JSON_CODEC, + ) + found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY) + assert found_entry is not None + assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE + assert entry.changed_by_fk == admin_user.id + + +def test_upsert_uuid_entry( + app_context: AppContext, + key_value_entry: KeyValueEntry, # noqa: F811 + admin_user: User, # noqa: F811 + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + with override_user(admin_user): + entry = KeyValueDAO.upsert_entry( + resource=RESOURCE, + key=UUID_KEY, + value=NEW_VALUE, + codec=JSON_CODEC, + ) + db.session.flush() + assert entry is not None + assert entry.id == ID_KEY + assert entry.uuid == UUID_KEY + found_entry = KeyValueDAO.get_entry(resource=RESOURCE, key=UUID_KEY) + assert found_entry is not None + assert JSON_CODEC.decode(found_entry.value) == NEW_VALUE + assert entry.changed_by_fk == admin_user.id + + +def test_upsert_missing_entry( + app_context: AppContext, + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY) + assert entry is None + KeyValueDAO.upsert_entry( + resource=RESOURCE, + key=ID_KEY, + value=NEW_VALUE, + codec=JSON_CODEC, + ) + entry = KeyValueDAO.get_entry(resource=RESOURCE, key=ID_KEY) + assert entry is not None + assert JSON_CODEC.decode(entry.value) == NEW_VALUE + + +def test_delete_id_entry( + app_context: AppContext, + key_value_entry: KeyValueEntry, + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + assert KeyValueDAO.delete_entry(resource=RESOURCE, key=ID_KEY) is True + + +def test_delete_uuid_entry( + app_context: AppContext, + key_value_entry: KeyValueEntry, + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + assert KeyValueDAO.delete_entry(resource=RESOURCE, key=UUID_KEY) is True + + +def test_delete_entry_missing( + app_context: AppContext, + after_each: None, # noqa: F811 +) -> None: + from superset.daos.key_value import KeyValueDAO + + assert KeyValueDAO.delete_entry(resource=RESOURCE, key=12345678) is False diff --git a/tests/unit_tests/distributed_lock/__init__.py b/tests/unit_tests/distributed_lock/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/unit_tests/utils/lock_tests.py b/tests/unit_tests/distributed_lock/distributed_lock_tests.py similarity index 51% rename from tests/unit_tests/utils/lock_tests.py rename to tests/unit_tests/distributed_lock/distributed_lock_tests.py index 4c9121fe3874..6fe363f0978d 100644 --- a/tests/unit_tests/utils/lock_tests.py +++ b/tests/unit_tests/distributed_lock/distributed_lock_tests.py @@ -22,17 +22,21 @@ import pytest from freezegun import freeze_time +from sqlalchemy.orm import Session, sessionmaker from superset import db +from superset.distributed_lock import KeyValueDistributedLock +from superset.distributed_lock.types import LockValue +from superset.distributed_lock.utils import get_key from superset.exceptions import CreateKeyValueDistributedLockFailedException from superset.key_value.types import JsonKeyValueCodec -from superset.utils.lock import get_key, KeyValueDistributedLock +LOCK_VALUE: LockValue = {"value": True} MAIN_KEY = get_key("ns", a=1, b=2) OTHER_KEY = get_key("ns2", a=1, b=2) -def _get_lock(key: UUID) -> Any: +def _get_lock(key: UUID, session: Session) -> Any: from superset.key_value.models import KeyValueEntry entry = db.session.query(KeyValueEntry).filter_by(uuid=key).first() @@ -42,41 +46,56 @@ def _get_lock(key: UUID) -> Any: return JsonKeyValueCodec().decode(entry.value) +def _get_other_session() -> Session: + # This session is used to simulate what another worker will find in the metastore + # during the locking process. + from superset import db + + bind = db.session.get_bind() + SessionMaker = sessionmaker(bind=bind) + return SessionMaker() + + def test_key_value_distributed_lock_happy_path() -> None: """ Test successfully acquiring and returning the distributed lock. - Note we use a nested transaction to ensure that the cleanup from the outer context - manager is correctly invoked, otherwise a partial rollback would occur leaving the - database in a fractured state. + Note, we're using another session for asserting the lock state in the Metastore + to simulate what another worker will observe. Otherwise, there's the risk that + the assertions would only be using the non-committed state from the main session. """ + session = _get_other_session() with freeze_time("2021-01-01"): - assert _get_lock(MAIN_KEY) is None + assert _get_lock(MAIN_KEY, session) is None with KeyValueDistributedLock("ns", a=1, b=2) as key: assert key == MAIN_KEY - assert _get_lock(key) is True - assert _get_lock(OTHER_KEY) is None + assert _get_lock(key, session) == LOCK_VALUE + assert _get_lock(OTHER_KEY, session) is None - with db.session.begin_nested(): - with pytest.raises(CreateKeyValueDistributedLockFailedException): - with KeyValueDistributedLock("ns", a=1, b=2): - pass + with pytest.raises(CreateKeyValueDistributedLockFailedException): + with KeyValueDistributedLock("ns", a=1, b=2): + pass - assert _get_lock(MAIN_KEY) is None + assert _get_lock(MAIN_KEY, session) is None def test_key_value_distributed_lock_expired() -> None: """ Test expiration of the distributed lock + + Note, we're using another session for asserting the lock state in the Metastore + to simulate what another worker will observe. Otherwise, there's the risk that + the assertions would only be using the non-committed state from the main session. """ + session = _get_other_session() with freeze_time("2021-01-01"): - assert _get_lock(MAIN_KEY) is None + assert _get_lock(MAIN_KEY, session) is None with KeyValueDistributedLock("ns", a=1, b=2): - assert _get_lock(MAIN_KEY) is True + assert _get_lock(MAIN_KEY, session) == LOCK_VALUE with freeze_time("2022-01-01"): - assert _get_lock(MAIN_KEY) is None + assert _get_lock(MAIN_KEY, session) is None - assert _get_lock(MAIN_KEY) is None + assert _get_lock(MAIN_KEY, session) is None diff --git a/tests/unit_tests/fixtures/common.py b/tests/unit_tests/fixtures/common.py index 5aea8472c04a..4ee1d9d0ee34 100644 --- a/tests/unit_tests/fixtures/common.py +++ b/tests/unit_tests/fixtures/common.py @@ -20,12 +20,15 @@ import csv from datetime import datetime from io import BytesIO, StringIO -from typing import Any +from typing import Any, Generator import pandas as pd import pytest +from flask_appbuilder.security.sqla.models import Role, User from werkzeug.datastructures import FileStorage +from superset import db + @pytest.fixture def dttm() -> datetime: @@ -73,3 +76,24 @@ def create_columnar_file( df.to_parquet(buffer, index=False) buffer.seek(0) return FileStorage(stream=buffer, filename=filename) + + +@pytest.fixture +def admin_user() -> Generator[User, None, None]: + role = db.session.query(Role).filter_by(name="Admin").one() + user = User( + first_name="Alice", + last_name="Admin", + email="alice_admin@example.org", + username="alice_admin", + roles=[role], + ) + db.session.add(user) + db.session.flush() + yield user + + +@pytest.fixture +def after_each() -> Generator[None, None, None]: + yield + db.session.rollback()