diff --git a/postgres/changelog.d/21503.fixed b/postgres/changelog.d/21503.fixed new file mode 100644 index 0000000000000..191fbaf7361ca --- /dev/null +++ b/postgres/changelog.d/21503.fixed @@ -0,0 +1 @@ +Fixed support for refreshing IAM authentication and Azure Managed Identity tokens diff --git a/postgres/datadog_checks/postgres/azure.py b/postgres/datadog_checks/postgres/azure.py index 409efef47e859..33a6ec66d5d7b 100644 --- a/postgres/datadog_checks/postgres/azure.py +++ b/postgres/datadog_checks/postgres/azure.py @@ -2,6 +2,7 @@ # All rights reserved # Licensed under a 3-clause BSD style license (see LICENSE) +from azure.core.credentials import AccessToken from azure.identity import ManagedIdentityCredential DEFAULT_PERMISSION_SCOPE = "https://ossrdbms-aad.database.windows.net/.default" @@ -9,8 +10,8 @@ # Use the azure identity API to generate a token that will be used # authenticate with either a system or user assigned managed identity -def generate_managed_identity_token(client_id: str, identity_scope: str = None): +def generate_managed_identity_token(client_id: str, identity_scope: str = None) -> AccessToken: credential = ManagedIdentityCredential(client_id=client_id) if not identity_scope: identity_scope = DEFAULT_PERMISSION_SCOPE - return credential.get_token(identity_scope).token + return credential.get_token(identity_scope) diff --git a/postgres/datadog_checks/postgres/connection_pool.py b/postgres/datadog_checks/postgres/connection_pool.py index 3ed99ff96f51b..3009ba86a0192 100644 --- a/postgres/datadog_checks/postgres/connection_pool.py +++ b/postgres/datadog_checks/postgres/connection_pool.py @@ -4,6 +4,7 @@ import threading import time +from abc import ABC, abstractmethod from collections import OrderedDict from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union @@ -14,6 +15,100 @@ from .cursor import CommenterCursor, SQLASCIITextLoader +class TokenProvider(ABC): + """ + Interface for providing a token for managed authentication. + """ + + def __init__(self, *, skew_seconds: int = 60): + self._skew = skew_seconds + self._lock = threading.Lock() + self._token: str | None = None + self._expires_at: float = 0.0 + + def get_token(self) -> str: + """ + Get a token for managed authentication. + """ + now = time.time() + with self._lock: + if self._token is None or now >= self._expires_at - self._skew: + token, expires_at = self._fetch_token() + self._token = token + self._expires_at = float(expires_at) + return self._token # type: ignore[return-value] + + @abstractmethod + def _fetch_token(self) -> Tuple[str, float]: + """ + Return (token, expires_at_epoch_seconds). + Implementations should return the absolute expiry; if the provider + has a fixed TTL, compute expires_at = time.time() + ttl_seconds. + """ + + +class AWSTokenProvider(TokenProvider): + """ + Token provider for AWS IAM authentication. + """ + + TOKEN_TTL_SECONDS = 900 # 15 minutes + + def __init__( + self, host: str, port: int, username: str, region: str, *, role_arn: str = None, skew_seconds: int = 60 + ): + super().__init__(skew_seconds=skew_seconds) + self.host = host + self.port = port + self.username = username + self.region = region + self.role_arn = role_arn + + def _fetch_token(self) -> Tuple[str, float]: + # Import aws only when this method is called + from .aws import generate_rds_iam_token + + token = generate_rds_iam_token( + host=self.host, port=self.port, username=self.username, region=self.region, role_arn=self.role_arn + ) + return token, time.time() + self.TOKEN_TTL_SECONDS + + +class AzureTokenProvider(TokenProvider): + """ + Token provider for Azure Managed Identity. + """ + + def __init__(self, client_id: str, identity_scope: str = None, skew_seconds: int = 60): + super().__init__(skew_seconds=skew_seconds) + self.client_id = client_id + self.identity_scope = identity_scope + + def _fetch_token(self) -> Tuple[str, float]: + # Import azure only when this method is called + from .azure import generate_managed_identity_token + + token = generate_managed_identity_token(client_id=self.client_id, identity_scope=self.identity_scope) + return token.token, float(token.expires_at) + + +class TokenAwareConnection(Connection): + """ + Connection that can be used for managed authentication. + """ + + token_provider: Optional[TokenProvider] = None + + @classmethod + def connect(cls, *args, **kwargs): + """ + Override the connection method to pass a refreshable token as the connection password. + """ + if cls.token_provider: + kwargs["password"] = cls.token_provider.get_token() + return super().connect(*args, **kwargs) + + @dataclass(frozen=True) class PostgresConnectionArgs: """ @@ -25,6 +120,7 @@ class PostgresConnectionArgs: host: Optional[str] = None port: Optional[int] = None password: Optional[str] = None + token_provider: Optional[TokenProvider] = None ssl_mode: Optional[str] = "allow" ssl_cert: Optional[str] = None ssl_root_cert: Optional[str] = None @@ -86,6 +182,7 @@ def __init__( pool_config: Optional[Dict[str, Any]] = None, statement_timeout: Optional[int] = None, # milliseconds sqlascii_encodings: Optional[list[str]] = None, + token_provider: Optional[TokenProvider] = None, ) -> None: """ Initialize the pool manager. @@ -101,6 +198,7 @@ def __init__( self.base_conn_args = base_conn_args self.statement_timeout = statement_timeout self.sqlascii_encodings = sqlascii_encodings + self.token_provider = token_provider self.pool_config = { **(pool_config or {}), @@ -108,6 +206,9 @@ def __init__( "max_size": 2, "open": True, } + + TokenAwareConnection.token_provider = self.token_provider + self.lock = threading.Lock() self.pools: OrderedDict[str, Tuple[ConnectionPool, float, bool]] = OrderedDict() self._closed = False @@ -139,7 +240,12 @@ def _create_pool(self, dbname: str) -> ConnectionPool: """ kwargs = self.base_conn_args.as_kwargs(dbname=dbname) - return ConnectionPool(kwargs=kwargs, configure=self._configure_connection, **self.pool_config) + return ConnectionPool( + kwargs=kwargs, + configure=self._configure_connection, + connection_class=TokenAwareConnection, + **self.pool_config, + ) def get_connection(self, dbname: str, persistent: bool = False): """ diff --git a/postgres/datadog_checks/postgres/postgres.py b/postgres/datadog_checks/postgres/postgres.py index 8775936cb3248..1c6087be57f0f 100644 --- a/postgres/datadog_checks/postgres/postgres.py +++ b/postgres/datadog_checks/postgres/postgres.py @@ -21,8 +21,14 @@ ) from datadog_checks.base.utils.db.utils import resolve_db_host as agent_host_resolver from datadog_checks.base.utils.serialization import json -from datadog_checks.postgres import aws, azure -from datadog_checks.postgres.connection_pool import LRUConnectionPoolManager, PostgresConnectionArgs +from datadog_checks.postgres.connection_pool import ( + AWSTokenProvider, + AzureTokenProvider, + LRUConnectionPoolManager, + PostgresConnectionArgs, + TokenAwareConnection, + TokenProvider, +) from datadog_checks.postgres.discovery import PostgresAutodiscovery from datadog_checks.postgres.health import PostgresHealth from datadog_checks.postgres.metadata import PostgresMetadata @@ -165,6 +171,7 @@ def __init__(self, name, init_config, instances): base_conn_args=self.build_connection_args(), statement_timeout=self._config.query_timeout, sqlascii_encodings=self._config.query_encodings, + token_provider=self.build_token_provider(), ) self.metrics_cache = PostgresMetricsCache(self._config) self.statement_metrics = PostgresStatementMetrics(self, self._config) @@ -913,6 +920,23 @@ def _collect_stats(self, instance_tags): for dynamic_query in self.dynamic_queries: dynamic_query.execute() + def build_token_provider(self) -> TokenProvider: + if self._config.aws.managed_authentication.enabled: + return AWSTokenProvider( + host=self._config.host, + port=self._config.port, + username=self._config.username, + region=self._config.aws.region, + role_arn=self._config.aws.managed_authentication.role_arn, + ) + elif self._config.azure.managed_authentication.enabled: + return AzureTokenProvider( + client_id=self._config.azure.managed_authentication.client_id, + identity_scope=self._config.azure.managed_authentication.identity_scope, + ) + else: + return None + def build_connection_args(self) -> PostgresConnectionArgs: if self._config.host == 'localhost' and self._config.password == '': return PostgresConnectionArgs( @@ -920,31 +944,12 @@ def build_connection_args(self) -> PostgresConnectionArgs: username=self._config.username, ) else: - password = self._config.password - if self._config.aws.managed_authentication.enabled: - password = aws.generate_rds_iam_token( - host=self._config.host, - username=self._config.username, - port=self._config.port, - region=self._config.aws.region, - role_arn=self._config.aws.managed_authentication.role_arn, - ) - elif self._config.azure.managed_authentication.enabled: - client_id = self._config.azure.managed_authentication.client_id - identity_scope = self._config.azure.managed_authentication.identity_scope - password = azure.generate_managed_identity_token(client_id=client_id, identity_scope=identity_scope) - - self.log.debug( - "Try to connect to %s with %s", - self._config.host, - "password" if password == self._config.password else "token", - ) return PostgresConnectionArgs( application_name=self._config.application_name, username=self._config.username, host=self._config.host, port=self._config.port, - password=password, + password=self._config.password, ssl_mode=self._config.ssl, ssl_cert=self._config.ssl_cert, ssl_root_cert=self._config.ssl_root_cert, @@ -956,7 +961,7 @@ def _new_connection(self, dbname): # TODO: Keeping this main connection outside of the pool for now to keep existing behavior. # We should move this to the pool in the future. conn_args = self.build_connection_args() - conn = psycopg.connect(**conn_args.as_kwargs(dbname=dbname)) + conn = TokenAwareConnection.connect(**conn_args.as_kwargs(dbname=dbname)) self.db_pool._configure_connection(conn) return conn diff --git a/postgres/tests/test_token_provider.py b/postgres/tests/test_token_provider.py new file mode 100644 index 0000000000000..5fc416773605d --- /dev/null +++ b/postgres/tests/test_token_provider.py @@ -0,0 +1,263 @@ +# (C) Datadog, Inc. 2025-present +# All rights reserved +# Licensed under a 3-clause BSD style license (see LICENSE) + +import time +from unittest.mock import Mock, patch + +from datadog_checks.postgres.connection_pool import AWSTokenProvider, AzureTokenProvider, TokenProvider + + +def test_get_token_first_call(): + """Test that get_token() calls _fetch_token() on first call.""" + provider = MockTokenProvider() + provider._fetch_token = Mock(return_value=("test_token", time.time() + 3600)) + + token = provider.get_token() + + assert token == "test_token" + assert provider._fetch_token.call_count == 1 + assert provider._token == "test_token" + + +def test_get_token_cached_token_valid(): + """Test that get_token() returns cached token when still valid.""" + provider = MockTokenProvider() + provider._token = "cached_token" + provider._expires_at = time.time() + 3600 # Valid for 1 hour + provider._fetch_token = Mock() + + token = provider.get_token() + + assert token == "cached_token" + assert provider._fetch_token.call_count == 0 + + +def test_get_token_cached_token_expired(): + """Test that get_token() fetches new token when cached token is expired.""" + provider = MockTokenProvider() + provider._token = "old_token" + provider._expires_at = time.time() - 1 # Expired + provider._fetch_token = Mock(return_value=("new_token", time.time() + 3600)) + + token = provider.get_token() + + assert token == "new_token" + assert provider._fetch_token.call_count == 1 + assert provider._token == "new_token" + + +def test_get_token_skew_handling(): + """Test that get_token() respects skew_seconds for token refresh.""" + provider = MockTokenProvider(skew_seconds=60) + provider._token = "cached_token" + # Token expires in 30 seconds, but skew is 60 seconds, so should refresh + provider._expires_at = time.time() + 30 + provider._fetch_token = Mock(return_value=("new_token", time.time() + 3600)) + + token = provider.get_token() + + assert token == "new_token" + assert provider._fetch_token.call_count == 1 + + +def test_thread_safety(): + """Test that TokenProvider is thread-safe.""" + import threading + + provider = MockTokenProvider() + provider._fetch_token = Mock(return_value=("test_token", time.time() + 3600)) + + results = [] + + def get_token(): + results.append(provider.get_token()) + + # Start multiple threads + threads = [threading.Thread(target=get_token) for _ in range(10)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # All threads should get the same token + assert all(token == "test_token" for token in results) + # _fetch_token should only be called + assert provider._fetch_token.call_count == 1 + + +def test_aws_token_provider_initialization(): + """Test AWSTokenProvider initialization.""" + provider = AWSTokenProvider( + host="test-host", + port=5432, + username="testuser", + region="us-east-1", + role_arn="arn:aws:iam::123456789012:role/test-role", + ) + + assert provider.host == "test-host" + assert provider.port == 5432 + assert provider.username == "testuser" + assert provider.region == "us-east-1" + assert provider.role_arn == "arn:aws:iam::123456789012:role/test-role" + assert provider.TOKEN_TTL_SECONDS == 900 + + +def test_aws_token_provider_initialization_without_role_arn(): + """Test AWSTokenProvider initialization without role_arn.""" + provider = AWSTokenProvider(host="test-host", port=5432, username="testuser", region="us-east-1") + + assert provider.role_arn is None + + +@patch('datadog_checks.postgres.connection_pool.time.time') +@patch('datadog_checks.postgres.aws.generate_rds_iam_token') +def test_aws_fetch_token_with_role_arn(mock_generate_token, mock_time): + """Test AWS token fetching with role_arn.""" + mock_time.return_value = 1000.0 + mock_generate_token.return_value = "aws_token_123" + + provider = AWSTokenProvider( + host="test-host", + port=5432, + username="testuser", + region="us-east-1", + role_arn="arn:aws:iam::123456789012:role/test-role", + ) + + token, expires_at = provider._fetch_token() + + assert token == "aws_token_123" + assert expires_at == 1900.0 # 1000.0 + 900 (TOKEN_TTL_SECONDS) + mock_generate_token.assert_called_once_with( + host="test-host", + port=5432, + username="testuser", + region="us-east-1", + role_arn="arn:aws:iam::123456789012:role/test-role", + ) + + +@patch('datadog_checks.postgres.connection_pool.time.time') +@patch('datadog_checks.postgres.aws.generate_rds_iam_token') +def test_aws_fetch_token_without_role_arn(mock_generate_token, mock_time): + """Test AWS token fetching without role_arn.""" + mock_time.return_value = 1000.0 + mock_generate_token.return_value = "aws_token_456" + + provider = AWSTokenProvider(host="test-host", port=5432, username="testuser", region="us-east-1") + + token, expires_at = provider._fetch_token() + + assert token == "aws_token_456" + assert expires_at == 1900.0 + mock_generate_token.assert_called_once_with( + host="test-host", port=5432, username="testuser", region="us-east-1", role_arn=None + ) + + +def test_aws_token_provider_integration(): + """Test AWSTokenProvider integration with get_token().""" + with patch('datadog_checks.postgres.aws.generate_rds_iam_token') as mock_generate: + mock_generate.return_value = "integration_token" + + provider = AWSTokenProvider(host="test-host", port=5432, username="testuser", region="us-east-1") + + # First call should fetch token + token1 = provider.get_token() + assert token1 == "integration_token" + assert mock_generate.call_count == 1 + + # Second call should use cached token + token2 = provider.get_token() + assert token2 == "integration_token" + assert mock_generate.call_count == 1 + + +def test_azure_token_provider_initialization(): + """Test AzureTokenProvider initialization.""" + provider = AzureTokenProvider(client_id="test-client-id", identity_scope="https://test.scope/.default") + + assert provider.client_id == "test-client-id" + assert provider.identity_scope == "https://test.scope/.default" + + +def test_azure_token_provider_initialization_without_scope(): + """Test AzureTokenProvider initialization without identity_scope.""" + provider = AzureTokenProvider(client_id="test-client-id") + + assert provider.identity_scope is None + + +@patch('datadog_checks.postgres.azure.ManagedIdentityCredential') +def test_azure_fetch_token_with_scope(mock_credential_class): + """Test Azure token fetching with custom scope.""" + mock_token = Mock() + mock_token.token = "azure_token_123" + mock_token.expires_at = 1900.0 + + mock_credential = Mock() + mock_credential.get_token.return_value = mock_token + mock_credential_class.return_value = mock_credential + + provider = AzureTokenProvider(client_id="test-client-id", identity_scope="https://custom.scope/.default") + + token, expires_at = provider._fetch_token() + + assert token == "azure_token_123" + assert expires_at == 1900.0 + mock_credential_class.assert_called_once_with(client_id="test-client-id") + mock_credential.get_token.assert_called_once_with("https://custom.scope/.default") + + +@patch('datadog_checks.postgres.azure.ManagedIdentityCredential') +def test_azure_fetch_token_without_scope(mock_credential_class): + """Test Azure token fetching without custom scope (uses default).""" + mock_token = Mock() + mock_token.token = "azure_token_456" + mock_token.expires_at = 2000.0 + + mock_credential = Mock() + mock_credential.get_token.return_value = mock_token + mock_credential_class.return_value = mock_credential + + provider = AzureTokenProvider(client_id="test-client-id") + + token, expires_at = provider._fetch_token() + + assert token == "azure_token_456" + assert expires_at == 2000.0 + mock_credential_class.assert_called_once_with(client_id="test-client-id") + mock_credential.get_token.assert_called_once_with("https://ossrdbms-aad.database.windows.net/.default") + + +def test_azure_token_provider_integration(): + """Test AzureTokenProvider integration with get_token().""" + with patch('datadog_checks.postgres.azure.ManagedIdentityCredential') as mock_credential_class: + mock_token = Mock() + mock_token.token = "integration_azure_token" + mock_token.expires_at = time.time() + 3600 + + mock_credential = Mock() + mock_credential.get_token.return_value = mock_token + mock_credential_class.return_value = mock_credential + + provider = AzureTokenProvider(client_id="test-client-id") + + # First call should fetch token + token1 = provider.get_token() + assert token1 == "integration_azure_token" + assert mock_credential.get_token.call_count == 1 + + # Second call should use cached token + token2 = provider.get_token() + assert token2 == "integration_azure_token" + assert mock_credential.get_token.call_count == 1 + + +class MockTokenProvider(TokenProvider): + """Mock implementation of TokenProvider for testing.""" + + def _fetch_token(self): + return "mock_token", time.time() + 3600