feat(GAQ): Add Redis Sentinel Support for Global Async Queries (#29912)
Co-authored-by: Sivarajan Narayanan <[email protected]>
nsivarajan and Sivarajan Narayanan committed Aug 30, 2024
1 parent cd6b8b2 commit 103cd3d
Showing 6 changed files with 450 additions and 45 deletions.
80 changes: 65 additions & 15 deletions superset/async_events/async_query_manager.py
@@ -14,20 +14,31 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import uuid
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Union

import jwt
import redis
from flask import Flask, Request, request, Response, session
from flask_caching.backends.base import BaseCache

from superset.async_events.cache_backend import (
RedisCacheBackend,
RedisSentinelCacheBackend,
)
from superset.utils import json
from superset.utils.core import get_user_id

logger = logging.getLogger(__name__)


class CacheBackendNotInitialized(Exception):
pass


class AsyncQueryTokenException(Exception):
pass

@@ -55,13 +66,32 @@ def parse_event(event_data: tuple[str, dict[str, Any]]) -> dict[str, Any]:
return {"id": event_id, **json.loads(event_payload)}


def increment_id(redis_id: str) -> str:
def increment_id(entry_id: str) -> str:
# Redis stream IDs are in this format: '1607477697866-0'
try:
prefix, last = redis_id[:-1], int(redis_id[-1])
prefix, last = entry_id[:-1], int(entry_id[-1])
return prefix + str(last + 1)
except Exception: # pylint: disable=broad-except
return redis_id
return entry_id


def get_cache_backend(
config: dict[str, Any],
) -> Union[RedisCacheBackend, RedisSentinelCacheBackend, redis.Redis]: # type: ignore
cache_config = config.get("GLOBAL_ASYNC_QUERIES_CACHE_BACKEND", {})
cache_type = cache_config.get("CACHE_TYPE")

if cache_type == "RedisCache":
return RedisCacheBackend.from_config(cache_config)

if cache_type == "RedisSentinelCache":
return RedisSentinelCacheBackend.from_config(cache_config)

# TODO: Deprecate hardcoded plain Redis code and expand cache backend options.
# Maintain backward compatibility with 'GLOBAL_ASYNC_QUERIES_REDIS_CONFIG' until it is deprecated.
return redis.Redis(
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
)


class AsyncQueryManager:
@@ -73,7 +103,7 @@ class AsyncQueryManager:

def __init__(self) -> None:
super().__init__()
self._redis: redis.Redis # type: ignore
self._cache: Optional[BaseCache] = None
self._stream_prefix: str = ""
self._stream_limit: Optional[int]
self._stream_limit_firehose: Optional[int]
@@ -88,25 +118,24 @@ def __init__(self) -> None:

def init_app(self, app: Flask) -> None:
config = app.config
if (
config["CACHE_CONFIG"]["CACHE_TYPE"] == "null"
or config["DATA_CACHE_CONFIG"]["CACHE_TYPE"] == "null"
):
cache_type = config.get("CACHE_CONFIG", {}).get("CACHE_TYPE")
data_cache_type = config.get("DATA_CACHE_CONFIG", {}).get("CACHE_TYPE")
if cache_type in [None, "null"] or data_cache_type in [None, "null"]:
raise Exception( # pylint: disable=broad-exception-raised
"""
Cache backends (CACHE_CONFIG, DATA_CACHE_CONFIG) must be configured
and non-null in order to enable async queries
"""
)

self._cache = get_cache_backend(config)
logger.debug("Using GAQ Cache backend as %s", type(self._cache).__name__)

if len(config["GLOBAL_ASYNC_QUERIES_JWT_SECRET"]) < 32:
raise AsyncQueryTokenException(
"Please provide a JWT secret at least 32 bytes long"
)

self._redis = redis.Redis(
**config["GLOBAL_ASYNC_QUERIES_REDIS_CONFIG"], decode_responses=True
)
self._stream_prefix = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_PREFIX"]
self._stream_limit = config["GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT"]
self._stream_limit_firehose = config[
@@ -230,14 +259,35 @@ def submit_chart_data_job(
def read_events(
self, channel: str, last_id: Optional[str]
) -> list[Optional[dict[str, Any]]]:
if not self._cache:
raise CacheBackendNotInitialized("Cache backend not initialized")

stream_name = f"{self._stream_prefix}{channel}"
start_id = increment_id(last_id) if last_id else "-"
results = self._redis.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT)
results = self._cache.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT)
# Decode bytes to strings; decode_responses is not supported by RedisCache and RedisSentinelCache
if isinstance(self._cache, (RedisSentinelCacheBackend, RedisCacheBackend)):
decoded_results = [
(
event_id.decode("utf-8"),
{
key.decode("utf-8"): value.decode("utf-8")
for key, value in event_data.items()
},
)
for event_id, event_data in results
]
return (
[] if not decoded_results else list(map(parse_event, decoded_results))
)
return [] if not results else list(map(parse_event, results))

def update_job(
self, job_metadata: dict[str, Any], status: str, **kwargs: Any
) -> None:
if not self._cache:
raise CacheBackendNotInitialized("Cache backend not initialized")

if "channel_id" not in job_metadata:
raise AsyncQueryJobException("No channel ID specified")

Expand All @@ -253,5 +303,5 @@ def update_job(
logger.debug("********** logging event data to stream %s", scoped_stream_name)
logger.debug(event_data)

self._redis.xadd(scoped_stream_name, event_data, "*", self._stream_limit)
self._redis.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose)
self._cache.xadd(scoped_stream_name, event_data, "*", self._stream_limit)
self._cache.xadd(full_stream_name, event_data, "*", self._stream_limit_firehose)
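
For reference, a minimal sketch of how the new get_cache_backend() dispatch behaves. The config dict below is a hypothetical illustration and shows only the keys the function reads; in Superset itself these values come from app.config.

from superset.async_events.async_query_manager import get_cache_backend
from superset.async_events.cache_backend import RedisCacheBackend

# Hypothetical config; only the keys read by get_cache_backend() are shown.
config = {
    "GLOBAL_ASYNC_QUERIES_CACHE_BACKEND": {
        "CACHE_TYPE": "RedisCache",
        "CACHE_REDIS_HOST": "localhost",
        "CACHE_REDIS_PORT": 6379,
        "CACHE_REDIS_DB": 0,
    },
    # Legacy fallback, used only when CACHE_TYPE is neither
    # "RedisCache" nor "RedisSentinelCache".
    "GLOBAL_ASYNC_QUERIES_REDIS_CONFIG": {"host": "localhost", "port": 6379, "db": 0},
}

backend = get_cache_backend(config)  # a RedisCacheBackend in this case
assert isinstance(backend, RedisCacheBackend)

Setting CACHE_TYPE to "RedisSentinelCache" returns a RedisSentinelCacheBackend instead, and any other value falls through to the legacy redis.Redis client built from GLOBAL_ASYNC_QUERIES_REDIS_CONFIG.
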
209 changes: 209 additions & 0 deletions superset/async_events/cache_backend.py
@@ -0,0 +1,209 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional, Tuple

import redis
from flask_caching.backends.rediscache import RedisCache, RedisSentinelCache
from redis.sentinel import Sentinel


class RedisCacheBackend(RedisCache):
MAX_EVENT_COUNT = 100

def __init__( # pylint: disable=too-many-arguments
self,
host: str,
port: int,
password: Optional[str] = None,
db: int = 0,
default_timeout: int = 300,
key_prefix: Optional[str] = None,
ssl: bool = False,
ssl_certfile: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__(
host=host,
port=port,
password=password,
db=db,
default_timeout=default_timeout,
key_prefix=key_prefix,
**kwargs,
)
self._cache = redis.Redis(
host=host,
port=port,
password=password,
db=db,
ssl=ssl,
ssl_certfile=ssl_certfile,
ssl_keyfile=ssl_keyfile,
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=ssl_ca_certs,
**kwargs,
)

def xadd(
self,
stream_name: str,
event_data: Dict[str, Any],
event_id: str = "*",
maxlen: Optional[int] = None,
) -> str:
return self._cache.xadd(stream_name, event_data, event_id, maxlen)

def xrange(
self,
stream_name: str,
start: str = "-",
end: str = "+",
count: Optional[int] = None,
) -> List[Any]:
count = count or self.MAX_EVENT_COUNT
return self._cache.xrange(stream_name, start, end, count)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "RedisCacheBackend":
kwargs = {
"host": config.get("CACHE_REDIS_HOST", "localhost"),
"port": config.get("CACHE_REDIS_PORT", 6379),
"db": config.get("CACHE_REDIS_DB", 0),
"password": config.get("CACHE_REDIS_PASSWORD", None),
"key_prefix": config.get("CACHE_KEY_PREFIX", None),
"default_timeout": config.get("CACHE_DEFAULT_TIMEOUT", 300),
"ssl": config.get("CACHE_REDIS_SSL", False),
"ssl_certfile": config.get("CACHE_REDIS_SSL_CERTFILE", None),
"ssl_keyfile": config.get("CACHE_REDIS_SSL_KEYFILE", None),
"ssl_cert_reqs": config.get("CACHE_REDIS_SSL_CERT_REQS", "required"),
"ssl_ca_certs": config.get("CACHE_REDIS_SSL_CA_CERTS", None),
}
return cls(**kwargs)


class RedisSentinelCacheBackend(RedisSentinelCache):
MAX_EVENT_COUNT = 100

def __init__( # pylint: disable=too-many-arguments
self,
sentinels: List[Tuple[str, int]],
master: str,
password: Optional[str] = None,
sentinel_password: Optional[str] = None,
db: int = 0,
default_timeout: int = 300,
key_prefix: str = "",
ssl: bool = False,
ssl_certfile: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_cert_reqs: str = "required",
ssl_ca_certs: Optional[str] = None,
**kwargs: Any,
) -> None:
# Sentinel does not directly support SSL,
# so initialize Sentinel without SSL parameters
self._sentinel = Sentinel(
sentinels,
sentinel_kwargs={
"password": sentinel_password,
},
**{
k: v
for k, v in kwargs.items()
if k
not in [
"ssl",
"ssl_certfile",
"ssl_keyfile",
"ssl_cert_reqs",
"ssl_ca_certs",
]
},
)

# Prepare SSL-related arguments for master_for method
master_kwargs = {
"password": password,
"ssl": ssl,
"ssl_certfile": ssl_certfile if ssl else None,
"ssl_keyfile": ssl_keyfile if ssl else None,
"ssl_cert_reqs": ssl_cert_reqs if ssl else None,
"ssl_ca_certs": ssl_ca_certs if ssl else None,
}

# If SSL is False, remove all SSL-related keys
# SSL_* are expected only if SSL is True
if not ssl:
master_kwargs = {
k: v for k, v in master_kwargs.items() if not k.startswith("ssl")
}

# Filter out None values from master_kwargs
master_kwargs = {k: v for k, v in master_kwargs.items() if v is not None}

# Initialize Redis master connection
self._cache = self._sentinel.master_for(master, **master_kwargs)

# Call the parent class constructor
super().__init__(
host=None,
port=None,
password=password,
db=db,
default_timeout=default_timeout,
key_prefix=key_prefix,
**kwargs,
)

def xadd(
self,
stream_name: str,
event_data: Dict[str, Any],
event_id: str = "*",
maxlen: Optional[int] = None,
) -> str:
return self._cache.xadd(stream_name, event_data, event_id, maxlen)

def xrange(
self,
stream_name: str,
start: str = "-",
end: str = "+",
count: Optional[int] = None,
) -> List[Any]:
count = count or self.MAX_EVENT_COUNT
return self._cache.xrange(stream_name, start, end, count)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "RedisSentinelCacheBackend":
kwargs = {
"sentinels": config.get("CACHE_REDIS_SENTINELS", [("127.0.0.1", 26379)]),
"master": config.get("CACHE_REDIS_SENTINEL_MASTER", "mymaster"),
"password": config.get("CACHE_REDIS_PASSWORD", None),
"sentinel_password": config.get("CACHE_REDIS_SENTINEL_PASSWORD", None),
"key_prefix": config.get("CACHE_KEY_PREFIX", ""),
"db": config.get("CACHE_REDIS_DB", 0),
"ssl": config.get("CACHE_REDIS_SSL", False),
"ssl_certfile": config.get("CACHE_REDIS_SSL_CERTFILE", None),
"ssl_keyfile": config.get("CACHE_REDIS_SSL_KEYFILE", None),
"ssl_cert_reqs": config.get("CACHE_REDIS_SSL_CERT_REQS", "required"),
"ssl_ca_certs": config.get("CACHE_REDIS_SSL_CA_CERTS", None),
}
return cls(**kwargs)
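
As a usage sketch for the new Sentinel backend: the sentinel addresses, master name, and stream name below are assumptions for illustration, and the xadd/xrange calls only succeed against a reachable Sentinel and master.

from superset.async_events.cache_backend import RedisSentinelCacheBackend

# Hypothetical Sentinel deployment; adjust hosts, ports, and master name.
sentinel_config = {
    "CACHE_REDIS_SENTINELS": [("sentinel-1", 26379), ("sentinel-2", 26379)],
    "CACHE_REDIS_SENTINEL_MASTER": "mymaster",
    "CACHE_REDIS_PASSWORD": None,
    "CACHE_REDIS_SENTINEL_PASSWORD": None,
    "CACHE_REDIS_DB": 0,
    "CACHE_REDIS_SSL": False,
}

backend = RedisSentinelCacheBackend.from_config(sentinel_config)

# Stream operations proxy to the master resolved through Sentinel.
event_id = backend.xadd("test-async-events", {"status": "done"}, "*", 1000)
events = backend.xrange("test-async-events", "-", "+", 10)
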
22 changes: 22 additions & 0 deletions superset/config.py
@@ -1690,6 +1690,28 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
)
GLOBAL_ASYNC_QUERIES_WEBSOCKET_URL = "ws://127.0.0.1:8080/"

# Global async queries cache backend configuration options:
# - Set 'CACHE_TYPE' to 'RedisCache' for RedisCacheBackend.
# - Set 'CACHE_TYPE' to 'RedisSentinelCache' for RedisSentinelCacheBackend.
# - Set 'CACHE_TYPE' to None (or leave it unset) to fall back to 'GLOBAL_ASYNC_QUERIES_REDIS_CONFIG'.
GLOBAL_ASYNC_QUERIES_CACHE_BACKEND = {
"CACHE_TYPE": "RedisCache",
"CACHE_REDIS_HOST": "localhost",
"CACHE_REDIS_PORT": 6379,
"CACHE_REDIS_USER": "",
"CACHE_REDIS_PASSWORD": "",
"CACHE_REDIS_DB": 0,
"CACHE_DEFAULT_TIMEOUT": 300,
"CACHE_REDIS_SENTINELS": [("localhost", 26379)],
"CACHE_REDIS_SENTINEL_MASTER": "mymaster",
"CACHE_REDIS_SENTINEL_PASSWORD": None,
"CACHE_REDIS_SSL": False, # True or False
"CACHE_REDIS_SSL_CERTFILE": None,
"CACHE_REDIS_SSL_KEYFILE": None,
"CACHE_REDIS_SSL_CERT_REQS": "required",
"CACHE_REDIS_SSL_CA_CERTS": None,
}

# Embedded config options
GUEST_ROLE_NAME = "Public"
GUEST_TOKEN_JWT_SECRET = "test-guest-secret-change-me"
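
To enable the Sentinel backend in a deployment, an override in superset_config.py along these lines would work; host names and the master name are assumptions for illustration, and GLOBAL_ASYNC_QUERIES_JWT_SECRET must still be at least 32 bytes, as enforced in init_app() above.

# superset_config.py (illustrative override; values are assumptions)
# Assumes global async queries are enabled elsewhere in this file.
GLOBAL_ASYNC_QUERIES_CACHE_BACKEND = {
    "CACHE_TYPE": "RedisSentinelCache",
    "CACHE_REDIS_SENTINELS": [
        ("sentinel-1", 26379),
        ("sentinel-2", 26379),
        ("sentinel-3", 26379),
    ],
    "CACHE_REDIS_SENTINEL_MASTER": "mymaster",
    "CACHE_REDIS_SENTINEL_PASSWORD": None,
    "CACHE_REDIS_PASSWORD": None,
    "CACHE_REDIS_DB": 0,
    "CACHE_DEFAULT_TIMEOUT": 300,
    "CACHE_REDIS_SSL": False,
}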