diff --git a/application_sdk/clients/models.py b/application_sdk/clients/models.py index cb103d3a4..4fad3d532 100644 --- a/application_sdk/clients/models.py +++ b/application_sdk/clients/models.py @@ -31,7 +31,11 @@ class DatabaseConfig(BaseModel): ) parameters: Optional[List[str]] = Field( default=None, - description="List of additional connection parameter names that can be dynamically added from credentials", + description="List of additional connection parameter names that can be dynamically added from credentials to the connection string. ex: ['ssl_mode'] will be added to the connection string as ?ssl_mode=require", + ) + connect_args: Dict[str, Any] = Field( + default_factory=dict, + description="Additional connection arguments to be passed to SQLAlchemy. ex: {'sslmode': 'require'}", ) class Config: diff --git a/application_sdk/clients/sql.py b/application_sdk/clients/sql.py index 411dec877..1603e3395 100644 --- a/application_sdk/clients/sql.py +++ b/application_sdk/clients/sql.py @@ -37,7 +37,6 @@ class BaseSQLClient(ClientInterface): Attributes: connection: Database connection instance. engine: SQLAlchemy engine instance. - sql_alchemy_connect_args (Dict[str, Any]): Additional connection arguments. credentials (Dict[str, Any]): Database credentials. resolved_credentials (Dict[str, Any]): Resolved credentials after reading from secret manager. use_server_side_cursor (bool): Whether to use server-side cursors. @@ -45,7 +44,6 @@ class BaseSQLClient(ClientInterface): connection = None engine = None - sql_alchemy_connect_args: Dict[str, Any] = {} credentials: Dict[str, Any] = {} resolved_credentials: Dict[str, Any] = {} use_server_side_cursor: bool = USE_SERVER_SIDE_CURSOR @@ -55,7 +53,6 @@ def __init__( self, use_server_side_cursor: bool = USE_SERVER_SIDE_CURSOR, credentials: Dict[str, Any] = {}, - sql_alchemy_connect_args: Dict[str, Any] = {}, ): """ Initialize the SQL client. @@ -64,12 +61,9 @@ def __init__( use_server_side_cursor (bool, optional): Whether to use server-side cursors. Defaults to USE_SERVER_SIDE_CURSOR. credentials (Dict[str, Any], optional): Database credentials. Defaults to {}. - sql_alchemy_connect_args (Dict[str, Any], optional): Additional SQLAlchemy - connection arguments. Defaults to {}. """ self.use_server_side_cursor = use_server_side_cursor self.credentials = credentials - self.sql_alchemy_connect_args = sql_alchemy_connect_args async def load(self, credentials: Dict[str, Any]) -> None: """Load credentials and prepare engine for lazy connections. @@ -83,6 +77,9 @@ async def load(self, credentials: Dict[str, Any]) -> None: Raises: ClientError: If credentials are invalid or engine creation fails """ + if not self.DB_CONFIG: + raise ValueError("DB_CONFIG is not configured for this SQL client.") + self.credentials = credentials # Update the instance credentials try: from sqlalchemy import create_engine @@ -90,7 +87,7 @@ async def load(self, credentials: Dict[str, Any]) -> None: # Create engine but no persistent connection self.engine = create_engine( self.get_sqlalchemy_connection_string(), - connect_args=self.sql_alchemy_connect_args, + connect_args=self.DB_CONFIG.connect_args, pool_pre_ping=True, ) @@ -397,7 +394,6 @@ class AsyncBaseSQLClient(BaseSQLClient): Attributes: connection (AsyncConnection): Async database connection instance. engine (AsyncEngine): Async SQLAlchemy engine instance. - sql_alchemy_connect_args (Dict[str, Any]): Additional connection arguments. credentials (Dict[str, Any]): Database credentials. use_server_side_cursor (bool): Whether to use server-side cursors. """ @@ -419,13 +415,16 @@ async def load(self, credentials: Dict[str, Any]) -> None: ValueError: If credentials are invalid or engine creation fails. """ self.credentials = credentials + if not self.DB_CONFIG: + raise ValueError("DB_CONFIG is not configured for this SQL client.") + try: from sqlalchemy.ext.asyncio import create_async_engine # Create async engine but no persistent connection self.engine = create_async_engine( self.get_sqlalchemy_connection_string(), - connect_args=self.sql_alchemy_connect_args, + connect_args=self.DB_CONFIG.connect_args, pool_pre_ping=True, ) if not self.engine: diff --git a/docs/docs/concepts/clients.md b/docs/docs/concepts/clients.md index 718a37aea..9b9575a6f 100644 --- a/docs/docs/concepts/clients.md +++ b/docs/docs/concepts/clients.md @@ -36,6 +36,7 @@ Both SQL client classes are typically **subclassed** for specific database types * **`required` (list[str]):** Keys that must be present in `credentials`/`credentials.extra`. `{password}` is resolved via `get_auth_token()` depending on `authType`. * **`parameters` (list[str], optional):** Optional keys appended as URL query parameters when present in `credentials`/`extra`. * **`defaults` (dict[str, Any], optional):** Default URL parameters always appended unless already in the template. + * **`connect_args` (dict[str, Any], optional):** Additional connection arguments to be passed directly to SQLAlchemy's `create_engine` or `create_async_engine`. Useful for driver-specific connection parameters that are not part of the connection URL. Defaults to `{}`. * **Credentials Note:** The `credentials` dictionary can include an `extra` field (JSON or dict). Lookups for `required` and `parameters` first check `credentials`, then `extra`. 2. **Loading (`load` method):** @@ -61,6 +62,7 @@ class SnowflakeClient(BaseSQLClient): required=["username", "password", "account_id"], parameters=["warehouse", "role"], defaults={"client_session_keep_alive": "true"}, + connect_args={"sslmode": "require"}, # Optional: driver-specific connection arguments ) ``` diff --git a/tests/unit/clients/test_async_sql_client.py b/tests/unit/clients/test_async_sql_client.py index 1ecbf5c41..a5d3fe45d 100644 --- a/tests/unit/clients/test_async_sql_client.py +++ b/tests/unit/clients/test_async_sql_client.py @@ -3,6 +3,7 @@ import pytest +from application_sdk.clients.models import DatabaseConfig from application_sdk.clients.sql import AsyncBaseSQLClient from application_sdk.handlers.sql import BaseSQLHandler @@ -10,6 +11,11 @@ @pytest.fixture def async_sql_client(): client = AsyncBaseSQLClient() + client.DB_CONFIG = DatabaseConfig( + template="test://{username}:{password}@{host}:{port}/{database}", + required=["username", "password", "host", "port", "database"], + connect_args={}, + ) client.get_sqlalchemy_connection_string = lambda: "test_connection_string" return client @@ -75,9 +81,10 @@ async def test_load( await async_sql_client.load(credentials) # Assertions to verify behavior + assert async_sql_client.DB_CONFIG is not None create_async_engine.assert_called_once_with( async_sql_client.get_sqlalchemy_connection_string(), - connect_args=async_sql_client.sql_alchemy_connect_args, + connect_args=async_sql_client.DB_CONFIG.connect_args, pool_pre_ping=True, ) assert async_sql_client.engine == mock_engine diff --git a/tests/unit/clients/test_sql_client.py b/tests/unit/clients/test_sql_client.py index 587991160..77a30fa05 100644 --- a/tests/unit/clients/test_sql_client.py +++ b/tests/unit/clients/test_sql_client.py @@ -25,6 +25,11 @@ @pytest.fixture def sql_client(): client = BaseSQLClient() + client.DB_CONFIG = DatabaseConfig( + template="test://{username}:{password}@{host}:{port}/{database}", + required=["username", "password", "host", "port", "database"], + connect_args={}, + ) client.get_sqlalchemy_connection_string = lambda: "test_connection_string" return client @@ -51,9 +56,10 @@ def test_load(mock_create_engine: Any, sql_client: BaseSQLClient): asyncio.run(sql_client.load(credentials)) # Assertions to verify behavior + assert sql_client.DB_CONFIG is not None mock_create_engine.assert_called_once_with( sql_client.get_sqlalchemy_connection_string(), - connect_args=sql_client.sql_alchemy_connect_args, + connect_args=sql_client.DB_CONFIG.connect_args, pool_pre_ping=True, ) assert sql_client.engine == mock_engine @@ -78,8 +84,9 @@ def test_load_property_based( mock_create_engine.return_value = mock_engine mock_engine.connect.return_value = mock_connection - # Set the connection arguments - sql_client.sql_alchemy_connect_args = connect_args + # Set the connection arguments in DB_CONFIG + assert sql_client.DB_CONFIG is not None + sql_client.DB_CONFIG.connect_args = connect_args # Run the load function asyncio.run(sql_client.load(credentials)) @@ -386,9 +393,10 @@ def test_connection_string_property_based( asyncio.run(sql_client.load(credentials)) # Assertions to verify behavior + assert sql_client.DB_CONFIG is not None mock_create_engine.assert_called_once_with( connection_string, - connect_args=sql_client.sql_alchemy_connect_args, + connect_args=sql_client.DB_CONFIG.connect_args, pool_pre_ping=True, ) assert sql_client.engine == mock_engine