-
Notifications
You must be signed in to change notification settings - Fork 0
[Test] Add token refresh support for Postgres connectors #5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: test/base-for-token-refresh
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| Fixed support for refreshing IAM authentication and Azure Managed Identity tokens |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
|
|
||
| import threading | ||
| import time | ||
| from abc import ABC, abstractmethod | ||
| from collections import OrderedDict | ||
| from dataclasses import dataclass | ||
| from typing import Any, Dict, Optional, Tuple, Union | ||
|
|
@@ -14,6 +15,100 @@ | |
| from .cursor import CommenterCursor, SQLASCIITextLoader | ||
|
|
||
|
|
||
| class TokenProvider(ABC): | ||
| """ | ||
| Interface for providing a token for managed authentication. | ||
| """ | ||
|
|
||
| def __init__(self, *, skew_seconds: int = 60): | ||
| self._skew = skew_seconds | ||
| self._lock = threading.Lock() | ||
| self._token: str | None = None | ||
| self._expires_at: float = 0.0 | ||
|
|
||
| def get_token(self) -> str: | ||
| """ | ||
| Get a token for managed authentication. | ||
| """ | ||
| now = time.time() | ||
| with self._lock: | ||
| if self._token is None or now >= self._expires_at - self._skew: | ||
| token, expires_at = self._fetch_token() | ||
| self._token = token | ||
| self._expires_at = float(expires_at) | ||
| return self._token # type: ignore[return-value] | ||
|
|
||
| @abstractmethod | ||
| def _fetch_token(self) -> Tuple[str, float]: | ||
| """ | ||
| Return (token, expires_at_epoch_seconds). | ||
| Implementations should return the absolute expiry; if the provider | ||
| has a fixed TTL, compute expires_at = time.time() + ttl_seconds. | ||
| """ | ||
|
|
||
|
|
||
| class AWSTokenProvider(TokenProvider): | ||
| """ | ||
| Token provider for AWS IAM authentication. | ||
| """ | ||
|
|
||
| TOKEN_TTL_SECONDS = 900 # 15 minutes | ||
|
|
||
| def __init__( | ||
| self, host: str, port: int, username: str, region: str, *, role_arn: str = None, skew_seconds: int = 60 | ||
| ): | ||
| super().__init__(skew_seconds=skew_seconds) | ||
| self.host = host | ||
| self.port = port | ||
| self.username = username | ||
| self.region = region | ||
| self.role_arn = role_arn | ||
|
|
||
| def _fetch_token(self) -> Tuple[str, float]: | ||
| # Import aws only when this method is called | ||
| from .aws import generate_rds_iam_token | ||
|
|
||
| token = generate_rds_iam_token( | ||
| host=self.host, port=self.port, username=self.username, region=self.region, role_arn=self.role_arn | ||
| ) | ||
| return token, time.time() + self.TOKEN_TTL_SECONDS | ||
|
|
||
|
|
||
| class AzureTokenProvider(TokenProvider): | ||
| """ | ||
| Token provider for Azure Managed Identity. | ||
| """ | ||
|
|
||
| def __init__(self, client_id: str, identity_scope: str = None, skew_seconds: int = 60): | ||
| super().__init__(skew_seconds=skew_seconds) | ||
| self.client_id = client_id | ||
| self.identity_scope = identity_scope | ||
|
|
||
| def _fetch_token(self) -> Tuple[str, float]: | ||
| # Import azure only when this method is called | ||
| from .azure import generate_managed_identity_token | ||
|
|
||
| token = generate_managed_identity_token(client_id=self.client_id, identity_scope=self.identity_scope) | ||
| return token.token, float(token.expires_at) | ||
|
|
||
|
|
||
| class TokenAwareConnection(Connection): | ||
| """ | ||
| Connection that can be used for managed authentication. | ||
| """ | ||
|
|
||
| token_provider: Optional[TokenProvider] = None | ||
|
|
||
| @classmethod | ||
| def connect(cls, *args, **kwargs): | ||
| """ | ||
| Override the connection method to pass a refreshable token as the connection password. | ||
| """ | ||
| if cls.token_provider: | ||
| kwargs["password"] = cls.token_provider.get_token() | ||
| return super().connect(*args, **kwargs) | ||
|
|
||
|
|
||
| @dataclass(frozen=True) | ||
| class PostgresConnectionArgs: | ||
| """ | ||
|
|
@@ -25,6 +120,7 @@ class PostgresConnectionArgs: | |
| host: Optional[str] = None | ||
| port: Optional[int] = None | ||
| password: Optional[str] = None | ||
| token_provider: Optional[TokenProvider] = None | ||
| ssl_mode: Optional[str] = "allow" | ||
| ssl_cert: Optional[str] = None | ||
| ssl_root_cert: Optional[str] = None | ||
|
|
@@ -86,6 +182,7 @@ def __init__( | |
| pool_config: Optional[Dict[str, Any]] = None, | ||
| statement_timeout: Optional[int] = None, # milliseconds | ||
| sqlascii_encodings: Optional[list[str]] = None, | ||
| token_provider: Optional[TokenProvider] = None, | ||
| ) -> None: | ||
| """ | ||
| Initialize the pool manager. | ||
|
|
@@ -101,13 +198,17 @@ def __init__( | |
| self.base_conn_args = base_conn_args | ||
| self.statement_timeout = statement_timeout | ||
| self.sqlascii_encodings = sqlascii_encodings | ||
| self.token_provider = token_provider | ||
|
|
||
| self.pool_config = { | ||
| **(pool_config or {}), | ||
| "min_size": 0, | ||
| "max_size": 2, | ||
| "open": True, | ||
| } | ||
|
|
||
| TokenAwareConnection.token_provider = self.token_provider | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This stores the token provider as a class-level attribute, which creates a global state shared across all PostgreSQL check instances in the Agent process. If multiple PostgreSQL checks are configured (e.g., different hosts, or some with managed auth and others without), the last initialized check will overwrite this attribute, causing incorrect authentication behavior for other checks. The token provider should be scoped to the specific connection pool instance rather than the class. Agent: 🤖 Integrations reviewer • Prompt for Agent |
||
|
|
||
| self.lock = threading.Lock() | ||
| self.pools: OrderedDict[str, Tuple[ConnectionPool, float, bool]] = OrderedDict() | ||
| self._closed = False | ||
|
|
@@ -139,7 +240,12 @@ def _create_pool(self, dbname: str) -> ConnectionPool: | |
| """ | ||
| kwargs = self.base_conn_args.as_kwargs(dbname=dbname) | ||
|
|
||
| return ConnectionPool(kwargs=kwargs, configure=self._configure_connection, **self.pool_config) | ||
| return ConnectionPool( | ||
| kwargs=kwargs, | ||
| configure=self._configure_connection, | ||
| connection_class=TokenAwareConnection, | ||
| **self.pool_config, | ||
| ) | ||
|
|
||
| def get_connection(self, dbname: str, persistent: bool = False): | ||
| """ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Azure's AccessToken object uses
expires_onattribute, notexpires_at. This will cause an AttributeError at runtime when the AzureTokenProvider attempts to accesstoken.expires_at. The correct attribute isexpires_onwhich contains the expiration time as seconds since epoch.Agent: 🤖 Integrations reviewer •
• 
Prompt for Agent