Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions postgres/changelog.d/21503.fixed
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed support for refreshing IAM authentication and Azure Managed Identity tokens
5 changes: 3 additions & 2 deletions postgres/datadog_checks/postgres/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
# All rights reserved
# Licensed under a 3-clause BSD style license (see LICENSE)

from azure.core.credentials import AccessToken
from azure.identity import ManagedIdentityCredential

DEFAULT_PERMISSION_SCOPE = "https://ossrdbms-aad.database.windows.net/.default"


# Use the azure identity API to generate a token that will be used to
# authenticate with either a system or user assigned managed identity
def generate_managed_identity_token(client_id: str, identity_scope: str = None) -> AccessToken:
    """
    Generate an Azure AD access token via Managed Identity.

    :param client_id: client id of the user-assigned managed identity
        (ignored for a system-assigned identity).
    :param identity_scope: permission scope to request; falls back to the
        Azure Database for PostgreSQL default scope when not provided.
    :return: the full AccessToken (token string plus ``expires_on`` epoch
        seconds) so callers can refresh the token before it expires.
    """
    credential = ManagedIdentityCredential(client_id=client_id)
    if not identity_scope:
        identity_scope = DEFAULT_PERMISSION_SCOPE
    # Return the whole AccessToken rather than just `.token`; the expiry
    # is needed by the token provider to schedule refreshes.
    return credential.get_token(identity_scope)
108 changes: 107 additions & 1 deletion postgres/datadog_checks/postgres/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import threading
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
Expand All @@ -14,6 +15,100 @@
from .cursor import CommenterCursor, SQLASCIITextLoader


class TokenProvider(ABC):
    """
    Interface for providing a token for managed authentication.

    Caches the most recently fetched token and transparently refreshes it
    shortly before expiry (controlled by ``skew_seconds``). Access is
    serialized with a lock so concurrent callers never double-fetch.
    """

    def __init__(self, *, skew_seconds: int = 60):
        # Refresh this many seconds before the token actually expires.
        self._skew = skew_seconds
        self._lock = threading.Lock()
        self._token: str | None = None
        self._expires_at: float = 0.0

    def get_token(self) -> str:
        """
        Get a token for managed authentication.

        Returns the cached token, fetching a fresh one first if none is
        cached yet or the cached one is within the skew window of expiry.
        """
        current_time = time.time()
        with self._lock:
            expired_soon = current_time >= self._expires_at - self._skew
            if self._token is None or expired_soon:
                fresh_token, fresh_expiry = self._fetch_token()
                self._token = fresh_token
                self._expires_at = float(fresh_expiry)
            return self._token  # type: ignore[return-value]

    @abstractmethod
    def _fetch_token(self) -> Tuple[str, float]:
        """
        Return (token, expires_at_epoch_seconds).
        Implementations should return the absolute expiry; if the provider
        has a fixed TTL, compute expires_at = time.time() + ttl_seconds.
        """

class AWSTokenProvider(TokenProvider):
    """
    Token provider for AWS IAM authentication.

    RDS IAM auth tokens are valid for a fixed 15 minutes, so the expiry
    is computed locally from the time of fetch.
    """

    TOKEN_TTL_SECONDS = 900  # 15 minutes

    def __init__(
        self,
        host: str,
        port: int,
        username: str,
        region: str,
        *,
        role_arn: Optional[str] = None,
        skew_seconds: int = 60,
    ):
        super().__init__(skew_seconds=skew_seconds)
        self.host = host
        self.port = port
        self.username = username
        self.region = region
        # Optional IAM role ARN forwarded to token generation.
        self.role_arn = role_arn

    def _fetch_token(self) -> Tuple[str, float]:
        # Import aws only when this method is called, so the AWS
        # dependency is not required unless IAM auth is enabled.
        from .aws import generate_rds_iam_token

        token = generate_rds_iam_token(
            host=self.host, port=self.port, username=self.username, region=self.region, role_arn=self.role_arn
        )
        return token, time.time() + self.TOKEN_TTL_SECONDS


class AzureTokenProvider(TokenProvider):
    """
    Token provider for Azure Managed Identity.
    """

    def __init__(self, client_id: str, identity_scope: Optional[str] = None, skew_seconds: int = 60):
        super().__init__(skew_seconds=skew_seconds)
        self.client_id = client_id
        self.identity_scope = identity_scope

    def _fetch_token(self) -> Tuple[str, float]:
        # Import azure only when this method is called, so the Azure
        # dependency is not required unless managed identity is enabled.
        from .azure import generate_managed_identity_token

        token = generate_managed_identity_token(client_id=self.client_id, identity_scope=self.identity_scope)
        # Azure's AccessToken exposes the expiry as `expires_on` (seconds
        # since epoch), not `expires_at` — using the wrong attribute raises
        # AttributeError at runtime.
        return token.token, float(token.expires_on)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

High

Azure's AccessToken object uses expires_on attribute, not expires_at. This will cause an AttributeError at runtime when the AzureTokenProvider attempts to access token.expires_at. The correct attribute is expires_on which contains the expiration time as seconds since epoch.

Agent: 🤖 Integrations reviewer • Fix in Cursor • Fix in Claude

Prompt for Agent
Task: Address review feedback left on GitHub.
Repository: mesa-dot-dev/integrations-core#5
File: postgres/datadog_checks/postgres/connection_pool.py#L92
Action: Open this file location in your editor, inspect the highlighted code, and resolve the issue described below.

Feedback:
Azure's AccessToken object uses `expires_on` attribute, not `expires_at`. This will cause an AttributeError at runtime when the AzureTokenProvider attempts to access `token.expires_at`. The correct attribute is `expires_on` which contains the expiration time as seconds since epoch.



class TokenAwareConnection(Connection):
    """
    Connection that can be used for managed authentication.

    When ``token_provider`` is set, every (re)connect fetches a current
    token and uses it as the connection password, so short-lived IAM /
    managed-identity tokens are refreshed instead of going stale.
    """

    # NOTE(review): this is a *class-level* attribute, shared by every
    # check instance in the process; the last pool manager to assign it
    # wins. If multiple differently-authenticated checks can run in one
    # Agent, the provider should be scoped per pool instance — confirm.
    token_provider: Optional[TokenProvider] = None

    @classmethod
    def connect(cls, *args, **kwargs):
        """
        Override the connection method to pass a refreshable token as the connection password.
        """
        # Only override the password when a provider is configured;
        # otherwise use whatever password the caller supplied.
        if cls.token_provider:
            kwargs["password"] = cls.token_provider.get_token()
        return super().connect(*args, **kwargs)


@dataclass(frozen=True)
class PostgresConnectionArgs:
"""
Expand All @@ -25,6 +120,7 @@ class PostgresConnectionArgs:
host: Optional[str] = None
port: Optional[int] = None
password: Optional[str] = None
token_provider: Optional[TokenProvider] = None
ssl_mode: Optional[str] = "allow"
ssl_cert: Optional[str] = None
ssl_root_cert: Optional[str] = None
Expand Down Expand Up @@ -86,6 +182,7 @@ def __init__(
pool_config: Optional[Dict[str, Any]] = None,
statement_timeout: Optional[int] = None, # milliseconds
sqlascii_encodings: Optional[list[str]] = None,
token_provider: Optional[TokenProvider] = None,
) -> None:
"""
Initialize the pool manager.
Expand All @@ -101,13 +198,17 @@ def __init__(
self.base_conn_args = base_conn_args
self.statement_timeout = statement_timeout
self.sqlascii_encodings = sqlascii_encodings
self.token_provider = token_provider

self.pool_config = {
**(pool_config or {}),
"min_size": 0,
"max_size": 2,
"open": True,
}

TokenAwareConnection.token_provider = self.token_provider
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

High

This stores the token provider as a class-level attribute, which creates a global state shared across all PostgreSQL check instances in the Agent process. If multiple PostgreSQL checks are configured (e.g., different hosts, or some with managed auth and others without), the last initialized check will overwrite this attribute, causing incorrect authentication behavior for other checks. The token provider should be scoped to the specific connection pool instance rather than the class.

Agent: 🤖 Integrations reviewer • Fix in Cursor • Fix in Claude

Prompt for Agent
Task: Address review feedback left on GitHub.
Repository: mesa-dot-dev/integrations-core#5
File: postgres/datadog_checks/postgres/connection_pool.py#L210
Action: Open this file location in your editor, inspect the highlighted code, and resolve the issue described below.

Feedback:
This stores the token provider as a class-level attribute, which creates a global state shared across all PostgreSQL check instances in the Agent process. If multiple PostgreSQL checks are configured (e.g., different hosts, or some with managed auth and others without), the last initialized check will overwrite this attribute, causing incorrect authentication behavior for other checks. The token provider should be scoped to the specific connection pool instance rather than the class.


self.lock = threading.Lock()
self.pools: OrderedDict[str, Tuple[ConnectionPool, float, bool]] = OrderedDict()
self._closed = False
Expand Down Expand Up @@ -139,7 +240,12 @@ def _create_pool(self, dbname: str) -> ConnectionPool:
"""
kwargs = self.base_conn_args.as_kwargs(dbname=dbname)

return ConnectionPool(kwargs=kwargs, configure=self._configure_connection, **self.pool_config)
return ConnectionPool(
kwargs=kwargs,
configure=self._configure_connection,
connection_class=TokenAwareConnection,
**self.pool_config,
)

def get_connection(self, dbname: str, persistent: bool = False):
"""
Expand Down
51 changes: 28 additions & 23 deletions postgres/datadog_checks/postgres/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@
)
from datadog_checks.base.utils.db.utils import resolve_db_host as agent_host_resolver
from datadog_checks.base.utils.serialization import json
from datadog_checks.postgres import aws, azure
from datadog_checks.postgres.connection_pool import LRUConnectionPoolManager, PostgresConnectionArgs
from datadog_checks.postgres.connection_pool import (
AWSTokenProvider,
AzureTokenProvider,
LRUConnectionPoolManager,
PostgresConnectionArgs,
TokenAwareConnection,
TokenProvider,
)
from datadog_checks.postgres.discovery import PostgresAutodiscovery
from datadog_checks.postgres.health import PostgresHealth
from datadog_checks.postgres.metadata import PostgresMetadata
Expand Down Expand Up @@ -165,6 +171,7 @@ def __init__(self, name, init_config, instances):
base_conn_args=self.build_connection_args(),
statement_timeout=self._config.query_timeout,
sqlascii_encodings=self._config.query_encodings,
token_provider=self.build_token_provider(),
)
self.metrics_cache = PostgresMetricsCache(self._config)
self.statement_metrics = PostgresStatementMetrics(self, self._config)
Expand Down Expand Up @@ -913,38 +920,36 @@ def _collect_stats(self, instance_tags):
for dynamic_query in self.dynamic_queries:
dynamic_query.execute()

def build_token_provider(self) -> "TokenProvider | None":
    """
    Build the token provider matching the configured managed
    authentication mechanism.

    Returns an AWSTokenProvider when AWS IAM auth is enabled, an
    AzureTokenProvider when Azure Managed Identity is enabled, and None
    when neither is configured (plain password auth).
    """
    if self._config.aws.managed_authentication.enabled:
        return AWSTokenProvider(
            host=self._config.host,
            port=self._config.port,
            username=self._config.username,
            region=self._config.aws.region,
            role_arn=self._config.aws.managed_authentication.role_arn,
        )
    elif self._config.azure.managed_authentication.enabled:
        return AzureTokenProvider(
            client_id=self._config.azure.managed_authentication.client_id,
            identity_scope=self._config.azure.managed_authentication.identity_scope,
        )
    else:
        return None

def build_connection_args(self) -> PostgresConnectionArgs:
if self._config.host == 'localhost' and self._config.password == '':
return PostgresConnectionArgs(
application_name=self._config.application_name,
username=self._config.username,
)
else:
password = self._config.password
if self._config.aws.managed_authentication.enabled:
password = aws.generate_rds_iam_token(
host=self._config.host,
username=self._config.username,
port=self._config.port,
region=self._config.aws.region,
role_arn=self._config.aws.managed_authentication.role_arn,
)
elif self._config.azure.managed_authentication.enabled:
client_id = self._config.azure.managed_authentication.client_id
identity_scope = self._config.azure.managed_authentication.identity_scope
password = azure.generate_managed_identity_token(client_id=client_id, identity_scope=identity_scope)

self.log.debug(
"Try to connect to %s with %s",
self._config.host,
"password" if password == self._config.password else "token",
)
return PostgresConnectionArgs(
application_name=self._config.application_name,
username=self._config.username,
host=self._config.host,
port=self._config.port,
password=password,
password=self._config.password,
ssl_mode=self._config.ssl,
ssl_cert=self._config.ssl_cert,
ssl_root_cert=self._config.ssl_root_cert,
Expand All @@ -956,7 +961,7 @@ def _new_connection(self, dbname):
# TODO: Keeping this main connection outside of the pool for now to keep existing behavior.
# We should move this to the pool in the future.
conn_args = self.build_connection_args()
conn = psycopg.connect(**conn_args.as_kwargs(dbname=dbname))
conn = TokenAwareConnection.connect(**conn_args.as_kwargs(dbname=dbname))
self.db_pool._configure_connection(conn)
return conn

Expand Down
Loading