From 9aaffea20c5bbea4fcddb123b04a3e9ac076e110 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 6 Feb 2025 21:05:27 -0800 Subject: [PATCH] feat: enhance AzureAIChatCompletionClient validation and add unit tests --- .../models/azure/_azure_ai_client.py | 13 ++-- .../models/test_azure_ai_model_client.py | 60 +++++++++++++++++++ 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py index 4d8d5eb50630..317bf1db8bb1 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py @@ -246,6 +246,10 @@ def _validate_config(config: Dict[str, Any]) -> AzureAIChatCompletionClientConfi raise ValueError("credential is required for AzureAIChatCompletionClient") if "model_info" not in config: raise ValueError("model_info is required for AzureAIChatCompletionClient") + if "family" not in config["model_info"]: + raise ValueError( + "family is required for model_info in AzureAIChatCompletionClient. See autogen_core.models.ModelFamily for options." + ) if _is_github_model(config["endpoint"]) and "model" not in config: raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient") return cast(AzureAIChatCompletionClientConfig, config) @@ -512,7 +516,8 @@ def capabilities(self) -> ModelInfo: def __del__(self) -> None: # TODO: This is a hack to close the open client - try: - asyncio.get_running_loop().create_task(self._client.close()) - except RuntimeError: - asyncio.run(self._client.close()) + if hasattr(self, "_client"): + try: + asyncio.get_running_loop().create_task(self._client.close()) + except RuntimeError: + asyncio.run(self._client.close()) diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py index d2662a0a270b..d21c249b9571 100644 --- a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py @@ -7,6 +7,7 @@ from autogen_core import CancellationToken, FunctionCall, Image from autogen_core.models import CreateResult, ModelFamily, UserMessage from autogen_ext.models.azure import AzureAIChatCompletionClient +from autogen_ext.models.azure.config import GITHUB_MODELS_ENDPOINT from azure.ai.inference.aio import ( ChatCompletionsClient, ) @@ -104,6 +105,65 @@ def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient ) +@pytest.mark.asyncio +async def test_azure_ai_chat_completion_client_validation() -> None: + with pytest.raises(ValueError, match="endpoint is required"): + AzureAIChatCompletionClient( + model="model", + credential=AzureKeyCredential("api_key"), + model_info={ + "json_output": False, + "function_calling": False, + "vision": False, + "family": "unknown", + }, + ) + + with pytest.raises(ValueError, match="credential is required"): + AzureAIChatCompletionClient( + model="model", + endpoint="endpoint", + model_info={ + "json_output": False, + "function_calling": False, + "vision": False, + "family": "unknown", + }, + ) + + with pytest.raises(ValueError, match="model is required"): + AzureAIChatCompletionClient( + endpoint=GITHUB_MODELS_ENDPOINT, + credential=AzureKeyCredential("api_key"), + model_info={ + "json_output": False, + "function_calling": False, + "vision": False, + "family": "unknown", + }, + ) + + with pytest.raises(ValueError, match="model_info is required"): + AzureAIChatCompletionClient( + model="model", + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + ) + + with pytest.raises(ValueError, match="family is required"): + AzureAIChatCompletionClient( + model="model", + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_info={ + "json_output": False, + "function_calling": False, + "vision": False, + # Missing family. + }, # type: ignore + ) + + @pytest.mark.asyncio async def test_azure_ai_chat_completion_client(azure_client: AzureAIChatCompletionClient) -> None: assert azure_client