Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enhance AzureAIChatCompletionClient validation and add unit tests #5417

Merged
merged 2 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ def _validate_config(config: Dict[str, Any]) -> AzureAIChatCompletionClientConfi
raise ValueError("credential is required for AzureAIChatCompletionClient")
if "model_info" not in config:
raise ValueError("model_info is required for AzureAIChatCompletionClient")
if "family" not in config["model_info"]:
raise ValueError(
"family is required for model_info in AzureAIChatCompletionClient. See autogen_core.models.ModelFamily for options."
)
if _is_github_model(config["endpoint"]) and "model" not in config:
raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient")
return cast(AzureAIChatCompletionClientConfig, config)
Expand Down Expand Up @@ -512,7 +516,8 @@ def capabilities(self) -> ModelInfo:

def __del__(self) -> None:
# TODO: This is a hack to close the open client
try:
asyncio.get_running_loop().create_task(self._client.close())
except RuntimeError:
asyncio.run(self._client.close())
if hasattr(self, "_client"):
try:
asyncio.get_running_loop().create_task(self._client.close())
except RuntimeError:
asyncio.run(self._client.close())
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from autogen_core import CancellationToken, FunctionCall, Image
from autogen_core.models import CreateResult, ModelFamily, UserMessage
from autogen_ext.models.azure import AzureAIChatCompletionClient
from autogen_ext.models.azure.config import GITHUB_MODELS_ENDPOINT
from azure.ai.inference.aio import (
ChatCompletionsClient,
)
Expand Down Expand Up @@ -104,6 +105,65 @@ def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient
)


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client_validation() -> None:
with pytest.raises(ValueError, match="endpoint is required"):
AzureAIChatCompletionClient(
model="model",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
"family": "unknown",
},
)

with pytest.raises(ValueError, match="credential is required"):
AzureAIChatCompletionClient(
model="model",
endpoint="endpoint",
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
"family": "unknown",
},
)

with pytest.raises(ValueError, match="model is required"):
AzureAIChatCompletionClient(
endpoint=GITHUB_MODELS_ENDPOINT,
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
"family": "unknown",
},
)

with pytest.raises(ValueError, match="model_info is required"):
AzureAIChatCompletionClient(
model="model",
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
)

with pytest.raises(ValueError, match="family is required"):
AzureAIChatCompletionClient(
model="model",
endpoint="endpoint",
credential=AzureKeyCredential("api_key"),
model_info={
"json_output": False,
"function_calling": False,
"vision": False,
# Missing family.
}, # type: ignore
)


@pytest.mark.asyncio
async def test_azure_ai_chat_completion_client(azure_client: AzureAIChatCompletionClient) -> None:
assert azure_client
Expand Down
Loading