Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion application_sdk/clients/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ class DatabaseConfig(BaseModel):
)
parameters: Optional[List[str]] = Field(
default=None,
description="List of additional connection parameter names that can be dynamically added from credentials",
description="List of additional connection parameter names that can be dynamically added from credentials to the connection string. ex: ['ssl_mode'] will be added to the connection string as ?ssl_mode=require",
)
connect_args: Dict[str, Any] = Field(
default_factory=dict,
description="Additional connection arguments to be passed to SQLAlchemy. ex: {'sslmode': 'require'}",
)

class Config:
Expand Down
17 changes: 8 additions & 9 deletions application_sdk/clients/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@ class BaseSQLClient(ClientInterface):
Attributes:
connection: Database connection instance.
engine: SQLAlchemy engine instance.
sql_alchemy_connect_args (Dict[str, Any]): Additional connection arguments.
credentials (Dict[str, Any]): Database credentials.
resolved_credentials (Dict[str, Any]): Resolved credentials after reading from secret manager.
use_server_side_cursor (bool): Whether to use server-side cursors.
"""

connection = None
engine = None
sql_alchemy_connect_args: Dict[str, Any] = {}
credentials: Dict[str, Any] = {}
resolved_credentials: Dict[str, Any] = {}
use_server_side_cursor: bool = USE_SERVER_SIDE_CURSOR
Expand All @@ -55,7 +53,6 @@ def __init__(
self,
use_server_side_cursor: bool = USE_SERVER_SIDE_CURSOR,
credentials: Dict[str, Any] = {},
sql_alchemy_connect_args: Dict[str, Any] = {},
):
"""
Initialize the SQL client.
Expand All @@ -64,12 +61,9 @@ def __init__(
use_server_side_cursor (bool, optional): Whether to use server-side cursors.
Defaults to USE_SERVER_SIDE_CURSOR.
credentials (Dict[str, Any], optional): Database credentials. Defaults to {}.
sql_alchemy_connect_args (Dict[str, Any], optional): Additional SQLAlchemy
connection arguments. Defaults to {}.
"""
self.use_server_side_cursor = use_server_side_cursor
self.credentials = credentials
self.sql_alchemy_connect_args = sql_alchemy_connect_args

async def load(self, credentials: Dict[str, Any]) -> None:
"""Load credentials and prepare engine for lazy connections.
Expand All @@ -83,14 +77,17 @@ async def load(self, credentials: Dict[str, Any]) -> None:
Raises:
ClientError: If credentials are invalid or engine creation fails
"""
if not self.DB_CONFIG:
raise ValueError("DB_CONFIG is not configured for this SQL client.")

self.credentials = credentials # Update the instance credentials
try:
from sqlalchemy import create_engine

# Create engine but no persistent connection
self.engine = create_engine(
self.get_sqlalchemy_connection_string(),
connect_args=self.sql_alchemy_connect_args,
connect_args=self.DB_CONFIG.connect_args,
pool_pre_ping=True,
)

Expand Down Expand Up @@ -397,7 +394,6 @@ class AsyncBaseSQLClient(BaseSQLClient):
Attributes:
connection (AsyncConnection): Async database connection instance.
engine (AsyncEngine): Async SQLAlchemy engine instance.
sql_alchemy_connect_args (Dict[str, Any]): Additional connection arguments.
credentials (Dict[str, Any]): Database credentials.
use_server_side_cursor (bool): Whether to use server-side cursors.
"""
Expand All @@ -419,13 +415,16 @@ async def load(self, credentials: Dict[str, Any]) -> None:
ValueError: If credentials are invalid or engine creation fails.
"""
self.credentials = credentials
if not self.DB_CONFIG:
raise ValueError("DB_CONFIG is not configured for this SQL client.")

try:
from sqlalchemy.ext.asyncio import create_async_engine

# Create async engine but no persistent connection
self.engine = create_async_engine(
self.get_sqlalchemy_connection_string(),
connect_args=self.sql_alchemy_connect_args,
connect_args=self.DB_CONFIG.connect_args,
pool_pre_ping=True,
)
if not self.engine:
Expand Down
2 changes: 2 additions & 0 deletions docs/docs/concepts/clients.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Both SQL client classes are typically **subclassed** for specific database types
* **`required` (list[str]):** Keys that must be present in `credentials`/`credentials.extra`. `{password}` is resolved via `get_auth_token()` depending on `authType`.
* **`parameters` (list[str], optional):** Optional keys appended as URL query parameters when present in `credentials`/`extra`.
* **`defaults` (dict[str, Any], optional):** Default URL parameters always appended unless already in the template.
* **`connect_args` (dict[str, Any], optional):** Additional connection arguments to be passed directly to SQLAlchemy's `create_engine` or `create_async_engine`. Useful for driver-specific connection parameters that are not part of the connection URL. Defaults to `{}`.
* **Credentials Note:** The `credentials` dictionary can include an `extra` field (JSON or dict). Lookups for `required` and `parameters` first check `credentials`, then `extra`.

2. **Loading (`load` method):**
Expand All @@ -61,6 +62,7 @@ class SnowflakeClient(BaseSQLClient):
required=["username", "password", "account_id"],
parameters=["warehouse", "role"],
defaults={"client_session_keep_alive": "true"},
connect_args={"sslmode": "require"}, # Optional: driver-specific connection arguments
)
```

Expand Down
9 changes: 8 additions & 1 deletion tests/unit/clients/test_async_sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@

import pytest

from application_sdk.clients.models import DatabaseConfig
from application_sdk.clients.sql import AsyncBaseSQLClient
from application_sdk.handlers.sql import BaseSQLHandler


@pytest.fixture
def async_sql_client():
client = AsyncBaseSQLClient()
client.DB_CONFIG = DatabaseConfig(
template="test://{username}:{password}@{host}:{port}/{database}",
required=["username", "password", "host", "port", "database"],
connect_args={},
)
client.get_sqlalchemy_connection_string = lambda: "test_connection_string"
return client

Expand Down Expand Up @@ -75,9 +81,10 @@ async def test_load(
await async_sql_client.load(credentials)

# Assertions to verify behavior
assert async_sql_client.DB_CONFIG is not None
create_async_engine.assert_called_once_with(
async_sql_client.get_sqlalchemy_connection_string(),
connect_args=async_sql_client.sql_alchemy_connect_args,
connect_args=async_sql_client.DB_CONFIG.connect_args,
pool_pre_ping=True,
)
assert async_sql_client.engine == mock_engine
Expand Down
16 changes: 12 additions & 4 deletions tests/unit/clients/test_sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
@pytest.fixture
def sql_client():
client = BaseSQLClient()
client.DB_CONFIG = DatabaseConfig(
template="test://{username}:{password}@{host}:{port}/{database}",
required=["username", "password", "host", "port", "database"],
connect_args={},
)
client.get_sqlalchemy_connection_string = lambda: "test_connection_string"
return client

Expand All @@ -51,9 +56,10 @@ def test_load(mock_create_engine: Any, sql_client: BaseSQLClient):
asyncio.run(sql_client.load(credentials))

# Assertions to verify behavior
assert sql_client.DB_CONFIG is not None
mock_create_engine.assert_called_once_with(
sql_client.get_sqlalchemy_connection_string(),
connect_args=sql_client.sql_alchemy_connect_args,
connect_args=sql_client.DB_CONFIG.connect_args,
pool_pre_ping=True,
)
assert sql_client.engine == mock_engine
Expand All @@ -78,8 +84,9 @@ def test_load_property_based(
mock_create_engine.return_value = mock_engine
mock_engine.connect.return_value = mock_connection

# Set the connection arguments
sql_client.sql_alchemy_connect_args = connect_args
# Set the connection arguments in DB_CONFIG
assert sql_client.DB_CONFIG is not None
sql_client.DB_CONFIG.connect_args = connect_args

# Run the load function
asyncio.run(sql_client.load(credentials))
Expand Down Expand Up @@ -386,9 +393,10 @@ def test_connection_string_property_based(
asyncio.run(sql_client.load(credentials))

# Assertions to verify behavior
assert sql_client.DB_CONFIG is not None
mock_create_engine.assert_called_once_with(
connection_string,
connect_args=sql_client.sql_alchemy_connect_args,
connect_args=sql_client.DB_CONFIG.connect_args,
pool_pre_ping=True,
)
assert sql_client.engine == mock_engine
Expand Down