diff --git a/README.md b/README.md
index 1d81a3a7..9d646fbe 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,51 @@ pip install shell-gpt
 ```
 By default, ShellGPT uses OpenAI's API and GPT-4 model. You'll need an API key, you can generate one [here](https://beta.openai.com/account/api-keys). You will be prompted for your key which will then be stored in `~/.config/shell_gpt/.sgptrc`. OpenAI API is not free of charge, please refer to the [OpenAI pricing](https://openai.com/pricing) for more information.
+### Azure OpenAI Provider
+ShellGPT also supports the Azure OpenAI provider. To use Azure OpenAI, you need to configure several Azure-specific parameters:
+
+#### 1. Set the Provider
+```shell
+export OPENAI_PROVIDER=azure-openai
+```
+
+#### 2. Configure Azure Resource Endpoint
+```shell
+export AZURE_RESOURCE_ENDPOINT=https://your-resource.cognitiveservices.azure.com
+```
+
+#### 3. Configure Deployment Name
+```shell
+export AZURE_DEPLOYMENT_NAME=your-deployment-name
+```
+
+#### 4. Set API Version
+```shell
+export API_VERSION=2025-01-01-preview
+```
+
+#### 5. Set API Key
+```shell
+export OPENAI_API_KEY=your_azure_openai_api_key
+```
+
+#### Configuration File
+You can also set these in your configuration file `~/.config/shell_gpt/.sgptrc`:
+```text
+OPENAI_PROVIDER=azure-openai
+AZURE_RESOURCE_ENDPOINT=https://your-resource.cognitiveservices.azure.com
+AZURE_DEPLOYMENT_NAME=your-deployment-name
+API_VERSION=2025-01-01-preview
+OPENAI_API_KEY=your_azure_openai_api_key
+```
+
+#### URL Structure
+Azure OpenAI uses a different URL structure than standard OpenAI:
+- **Standard OpenAI**: `https://api.openai.com/v1/chat/completions`
+- **Azure OpenAI**: `https://<resource-endpoint>/openai/deployments/<deployment-name>/chat/completions?api-version=<api-version>` (constructed automatically by the `AzureOpenAI` client)
+
+The Azure OpenAI provider uses the official `AzureOpenAI` client from the OpenAI library, which handles the endpoint, deployment name, and API version automatically.
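+
+For reference, here is a minimal sketch of roughly what the provider does internally (illustrative only, using the placeholder values above; this is not part of the ShellGPT API):
+```python
+from openai import AzureOpenAI
+
+# Values correspond to the configuration keys above; replace with your own.
+client = AzureOpenAI(
+    azure_endpoint="https://your-resource.cognitiveservices.azure.com",
+    api_version="2025-01-01-preview",
+    api_key="your_azure_openai_api_key",
+)
+
+# With Azure OpenAI, the deployment name is passed as the `model` argument.
+response = client.chat.completions.create(
+    model="your-deployment-name",
+    messages=[{"role": "user", "content": "Hello"}],
+)
+print(response.choices[0].message.content)
+```
+
+Once these variables are set, all ShellGPT modes (`--shell`, `--code`, `--chat`) work unchanged.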
+""" + +import os +import subprocess +import sys +from pathlib import Path + +def setup_azure_openai_config(): + """Set up Azure OpenAI configuration.""" + config_dir = Path.home() / ".config" / "shell_gpt" + config_file = config_dir / ".sgptrc" + + # Create config directory if it doesn't exist + config_dir.mkdir(parents=True, exist_ok=True) + + # Read existing config or create new one + config_lines = [] + if config_file.exists(): + with open(config_file, 'r') as f: + config_lines = f.readlines() + + # Check if OPENAI_PROVIDER is already set + provider_set = any(line.startswith("OPENAI_PROVIDER=") for line in config_lines) + + if not provider_set: + config_lines.append("OPENAI_PROVIDER=azure-openai\n") + + with open(config_file, 'w') as f: + f.writelines(config_lines) + print(f"āœ… Added OPENAI_PROVIDER=azure-openai to {config_file}") + else: + print(f"āœ… OPENAI_PROVIDER already configured in {config_file}") + + +def run_sgpt_example(): + """Run a simple ShellGPT example with Azure OpenAI.""" + print("\nšŸš€ Running ShellGPT example with Azure OpenAI...") + print("Note: Shell commands in interactive mode will prompt for Execute/Describe/Abort") + + # Set environment variable for this session + os.environ["OPENAI_PROVIDER"] = "azure-openai" + + try: + # Example 1: Simple question + print("\nšŸ“ Example 1: Simple question") + result = subprocess.run( + ["sgpt", "What is the capital of France?"], + capture_output=True, + text=True, + timeout=30 + ) + if result.returncode == 0: + print(f"Response: {result.stdout.strip()}") + else: + print(f"Error: {result.stderr.strip()}") + + # Example 2: Shell command generation (non-interactive) + print("\nšŸ’» Example 2: Shell command generation (non-interactive)") + print("Using --no-interaction flag to avoid interactive prompts") + result = subprocess.run( + ["sgpt", "--shell", "--no-interaction", "list all files in current directory"], + capture_output=True, + text=True, + timeout=30 + ) + if result.returncode == 0: + print(f"Generated command: {result.stdout.strip()}") + else: + print(f"Error: {result.stderr.strip()}") + + + # Example 3: Code generation + print("\nšŸ Example 3: Code generation") + result = subprocess.run( + ["sgpt", "--code", "hello world in python"], + capture_output=True, + text=True, + timeout=30 + ) + if result.returncode == 0: + print(f"Generated code:\n{result.stdout.strip()}") + else: + print(f"Error: {result.stderr.strip()}") + + # Example 4: Chat mode + print("\nšŸ’¬ Example 4: Chat mode") + result = subprocess.run( + ["sgpt", "--chat", "test_session", "Remember my name is Alice"], + capture_output=True, + text=True, + timeout=30 + ) + if result.returncode == 0: + print(f"Chat response: {result.stdout.strip()}") + else: + print(f"Error: {result.stderr.strip()}") + + except subprocess.TimeoutExpired: + print("āŒ Request timed out. Please check your Azure OpenAI configuration.") + except FileNotFoundError: + print("āŒ ShellGPT not found. 
+    except Exception as e:
+        print(f"āŒ Error: {e}")
+
+
+def check_configuration():
+    """Check if Azure OpenAI is properly configured."""
+    print("šŸ” Checking Azure OpenAI Configuration")
+    print("=" * 50)
+
+    # Check if provider is set to azure-openai
+    if os.getenv("OPENAI_PROVIDER") != "azure-openai":
+        print("\nāš ļø Warning: OPENAI_PROVIDER not set to azure-openai.")
+        print("   Please set: export OPENAI_PROVIDER=azure-openai")
+        print("\n   Or add it to your config file:")
+        print("   echo 'OPENAI_PROVIDER=azure-openai' >> ~/.config/shell_gpt/.sgptrc")
+
+    # Check if Azure-specific configuration is set
+    if not os.getenv("AZURE_RESOURCE_ENDPOINT"):
+        print("\nāš ļø Warning: AZURE_RESOURCE_ENDPOINT not configured.")
+        print("   For Azure OpenAI, please set your resource endpoint:")
+        print("   export AZURE_RESOURCE_ENDPOINT=https://your-resource.cognitiveservices.azure.com")
+        print("\n   Or add it to your config file:")
+        print("   echo 'AZURE_RESOURCE_ENDPOINT=https://your-resource.cognitiveservices.azure.com' >> ~/.config/shell_gpt/.sgptrc")
+
+    if not os.getenv("AZURE_DEPLOYMENT_NAME"):
+        print("\nāš ļø Warning: AZURE_DEPLOYMENT_NAME not configured.")
+        print("   For Azure OpenAI, please set your deployment name:")
+        print("   export AZURE_DEPLOYMENT_NAME=your-deployment-name")
+        print("\n   Or add it to your config file:")
+        print("   echo 'AZURE_DEPLOYMENT_NAME=your-deployment-name' >> ~/.config/shell_gpt/.sgptrc")
+
+    if not os.getenv("API_VERSION"):
+        print("\nāš ļø Warning: API_VERSION not configured.")
+        print("   For Azure OpenAI, please set your API version:")
+        print("   export API_VERSION=2025-01-01-preview")
+        print("\n   Or add it to your config file:")
+        print("   echo 'API_VERSION=2025-01-01-preview' >> ~/.config/shell_gpt/.sgptrc")
+
+    if not os.getenv("OPENAI_API_KEY"):
+        print("\nāš ļø Warning: OPENAI_API_KEY not configured.")
+        print("   Please set your Azure OpenAI API key:")
+        print("   export OPENAI_API_KEY=your_azure_openai_api_key")
+        print("\n   Or add it to your config file:")
+        print("   echo 'OPENAI_API_KEY=your_azure_openai_api_key' >> ~/.config/shell_gpt/.sgptrc")
+
+    # Show example configuration (placeholder values)
+    print("\nšŸ“‹ Example Azure OpenAI Configuration:")
+    print("   export OPENAI_PROVIDER=azure-openai")
+    print("   export AZURE_RESOURCE_ENDPOINT=https://your-resource.cognitiveservices.azure.com")
+    print("   export AZURE_DEPLOYMENT_NAME=your-deployment-name")
+    print("   export API_VERSION=2025-01-01-preview")
+    print("   export OPENAI_API_KEY=your_azure_openai_api_key")
+
+    print("\nšŸ”— This will use the AzureOpenAI client with:")
+    print("   - Endpoint: https://your-resource.cognitiveservices.azure.com")
+    print("   - Deployment: your-deployment-name")
+    print("   - API Version: 2025-01-01-preview")
+    print("   - Model parameter: your-deployment-name (deployment name)")
+
+
+def main():
+    """Main function."""
+    print("šŸ”§ ShellGPT Azure OpenAI Provider Example")
+    print("=" * 50)
+
+    # Check if ShellGPT is installed
+    try:
+        subprocess.run(["sgpt", "--version"], capture_output=True, check=True)
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        print("āŒ ShellGPT not found. Please install it first:")
Please install it first:") + print(" pip install shell-gpt") + sys.exit(1) + + # Set up configuration + setup_azure_openai_config() + + # Check configuration + check_configuration() + + # Run examples + run_sgpt_example() + + + print("\nāœ… Example completed!") + print("\nšŸ“š For more information, see:") + print(" - README.md for detailed usage") + print(" - Azure OpenAI documentation: https://docs.microsoft.com/en-us/azure/cognitive-services/openai/") + print("\nšŸ’” Tips for using ShellGPT with Azure OpenAI:") + print(" - Use --no-interaction for non-interactive shell commands") + print(" - Use --shell for interactive command generation") + print(" - Use --code for pure code generation") + print(" - Use --chat for conversation mode") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/sgpt/config.py b/sgpt/config.py index bc083733..8b9ee084 100644 --- a/sgpt/config.py +++ b/sgpt/config.py @@ -37,6 +37,10 @@ "SHELL_INTERACTION": os.getenv("SHELL_INTERACTION ", "true"), "OS_NAME": os.getenv("OS_NAME", "auto"), "SHELL_NAME": os.getenv("SHELL_NAME", "auto"), + "OPENAI_PROVIDER": os.getenv("OPENAI_PROVIDER", "openai"), + "API_VERSION": os.getenv("API_VERSION", "2024-02-15-preview"), + "AZURE_DEPLOYMENT_NAME": os.getenv("AZURE_DEPLOYMENT_NAME", ""), + "AZURE_RESOURCE_ENDPOINT": os.getenv("AZURE_RESOURCE_ENDPOINT", ""), # New features might add their own config variables here. } diff --git a/sgpt/handlers/handler.py b/sgpt/handlers/handler.py index a17d8024..fb2926e5 100644 --- a/sgpt/handlers/handler.py +++ b/sgpt/handlers/handler.py @@ -10,8 +10,10 @@ completion: Callable[..., Any] = lambda *args, **kwargs: Generator[Any, None, None] +# Initialize basic kwargs without provider-specific parameters base_url = cfg.get("API_BASE_URL") use_litellm = cfg.get("USE_LITELLM") == "true" + additional_kwargs = { "timeout": int(cfg.get("REQUEST_TIMEOUT")), "api_key": cfg.get("OPENAI_API_KEY"), @@ -32,6 +34,49 @@ additional_kwargs = {} +def get_provider() -> str: + """Safely get the provider with fallback.""" + try: + return cfg.get("OPENAI_PROVIDER") + except: + return "openai" + + +def get_azure_client(): + """Get Azure OpenAI client with proper configuration.""" + try: + provider = get_provider() + if provider == "azure-openai": + # Get Azure-specific configuration + azure_endpoint = "" + api_version = "2024-02-15-preview" + + try: + azure_endpoint = cfg.get("AZURE_RESOURCE_ENDPOINT") + api_version = cfg.get("API_VERSION") + except: + pass + + # Validate Azure configuration + if not azure_endpoint: + raise Exception("Azure OpenAI requires AZURE_RESOURCE_ENDPOINT configuration") + + # Create Azure OpenAI client with proper parameters + from openai import AzureOpenAI + + azure_kwargs = { + "api_version": api_version, + "azure_endpoint": azure_endpoint, + "api_key": cfg.get("OPENAI_API_KEY"), + "timeout": int(cfg.get("REQUEST_TIMEOUT")), + } + + return AzureOpenAI(**azure_kwargs) + except Exception as e: + print(f"Warning: Failed to create Azure OpenAI client: {e}") + return None + + class Handler: cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH"))) @@ -44,6 +89,10 @@ def __init__(self, role: SystemRole, markdown: bool) -> None: self.markdown = "APPLY MARKDOWN" in self.role.role and markdown self.code_theme, self.color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR") + + # Get provider and client + self.provider = get_provider() + self.azure_client = get_azure_client() if self.provider == "azure-openai" else None @property def printer(self) -> Printer: @@ 
@@ -56,6 +105,25 @@ def printer(self) -> Printer:
     def make_messages(self, prompt: str) -> List[Dict[str, str]]:
         raise NotImplementedError
 
+    def _get_model_name(self, model: str) -> str:
+        """Get the appropriate model name based on provider."""
+        if self.provider == "azure-openai":
+            # For Azure OpenAI, we need to use the deployment name from config.
+            # The deployment name is passed as the model parameter.
+            try:
+                deployment_name = cfg.get("AZURE_DEPLOYMENT_NAME")
+                if deployment_name:
+                    return deployment_name
+                else:
+                    # Fallback to model name if deployment name not configured
+                    return model
+            except Exception:
+                # If deployment name not found in config, use model name
+                return model
+        else:
+            # Standard OpenAI model names
+            return model
+
     def handle_function_call(
         self,
         messages: List[dict[str, Any]],
@@ -92,6 +160,8 @@ def get_completion(
         functions: Optional[List[Dict[str, str]]],
     ) -> Generator[str, None, None]:
         name = arguments = ""
+        local_kwargs: Dict[str, Any] = {}  # Per-request kwargs; avoids mutating module-level additional_kwargs
+
         is_shell_role = self.role.name == DefaultRoles.SHELL.value
         is_code_role = self.role.name == DefaultRoles.CODE.value
         is_dsc_shell_role = self.role.name == DefaultRoles.DESCRIBE_SHELL.value
@@ -99,48 +169,157 @@ def get_completion(
             functions = None
 
         if functions:
-            additional_kwargs["tool_choice"] = "auto"
-            additional_kwargs["tools"] = functions
-            additional_kwargs["parallel_tool_calls"] = False
+            local_kwargs["tool_choice"] = "auto"
+            local_kwargs["tools"] = functions
+            local_kwargs["parallel_tool_calls"] = False
 
-        response = completion(
-            model=model,
-            temperature=temperature,
-            top_p=top_p,
-            messages=messages,
-            stream=True,
-            **additional_kwargs,
-        )
+        # Prepare API call parameters
+        disable_stream = cfg.get("DISABLE_STREAMING") == "true"
+        api_call_kwargs = {
+            "temperature": temperature,
+            "top_p": top_p,
+            "messages": messages,
+            "stream": not disable_stream,  # Respect DISABLE_STREAMING configuration
+            "model": self._get_model_name(model),  # Always pass model parameter
+            **local_kwargs,
+        }
+
+        # Use appropriate client based on provider
+        if self.provider == "azure-openai" and self.azure_client:
+            response = self.azure_client.chat.completions.create(**api_call_kwargs)
+        else:
+            # Keep passing api_key/base_url/timeout to the default client
+            response = completion(**api_call_kwargs, **additional_kwargs)
+
+        if disable_stream:
+            # Non-streaming calls return a single completion object; wrap it so
+            # the loop below can handle both response shapes uniformly.
+            response = [response]
 
         try:
             for chunk in response:
-                delta = chunk.choices[0].delta
-
-                # LiteLLM uses dict instead of Pydantic object like OpenAI does.
-                tool_calls = (
-                    delta.get("tool_calls") if use_litellm else delta.tool_calls
-                )
-                if tool_calls:
-                    for tool_call in tool_calls:
-                        if tool_call.function.name:
-                            name = tool_call.function.name
-                        if tool_call.function.arguments:
-                            arguments += tool_call.function.arguments
-                if chunk.choices[0].finish_reason == "tool_calls":
-                    yield from self.handle_function_call(messages, name, arguments)
-                    yield from self.get_completion(
-                        model=model,
-                        temperature=temperature,
-                        top_p=top_p,
-                        messages=messages,
-                        functions=functions,
-                        caching=False,
-                    )
-                    return
-
-                yield delta.content or ""
+                # Safety check for chunk structure
+                if not hasattr(chunk, "choices") or not chunk.choices:
+                    continue
+
+                choice = chunk.choices[0]
+
+                # Handle both streaming and non-streaming responses
+                if hasattr(choice, "delta") and choice.delta is not None:
+                    # Streaming response structure
+                    delta = choice.delta
+                    finish_reason = choice.finish_reason
+
+                    # Handle tool calls (function calls) for streaming
+                    tool_calls = None
+                    if self.provider == "azure-openai":
+                        # Azure OpenAI tool calls - handle as object properties
+                        if hasattr(delta, "tool_calls") and delta.tool_calls:
+                            tool_calls = delta.tool_calls
+                    else:
+                        # LiteLLM uses dict instead of Pydantic object like OpenAI does.
+                        tool_calls = (
+                            delta.get("tool_calls") if use_litellm else delta.tool_calls
+                        )
+
+                    if tool_calls:
+                        for tool_call in tool_calls:
+                            if hasattr(tool_call, "function"):
+                                func = tool_call.function
+                            else:
+                                func = tool_call.get("function", {})
+
+                            if hasattr(func, "name") and func.name:
+                                name = func.name
+                            elif isinstance(func, dict) and func.get("name"):
+                                name = func["name"]
+
+                            if hasattr(func, "arguments") and func.arguments:
+                                arguments += func.arguments
+                            elif isinstance(func, dict) and func.get("arguments"):
+                                arguments += func["arguments"]
+
+                    if finish_reason == "tool_calls":
+                        yield from self.handle_function_call(messages, name, arguments)
+                        yield from self.get_completion(
+                            model=model,
+                            temperature=temperature,
+                            top_p=top_p,
+                            messages=messages,
+                            functions=functions,
+                            caching=False,
+                        )
+                        return
+
+                    # Extract content from delta
+                    content = delta.content or ""
+                    yield content
+
+                elif hasattr(choice, "message") and choice.message is not None:
+                    # Non-streaming response structure
+                    message = choice.message
+                    finish_reason = choice.finish_reason
+
+                    # Handle tool calls for non-streaming
+                    tool_calls = None
+                    if hasattr(message, "tool_calls") and message.tool_calls:
+                        tool_calls = message.tool_calls
+
+                    if tool_calls:
+                        for tool_call in tool_calls:
+                            if hasattr(tool_call, "function"):
+                                func = tool_call.function
+                            else:
+                                func = tool_call.get("function", {})
+
+                            if hasattr(func, "name") and func.name:
+                                name = func.name
+                            elif isinstance(func, dict) and func.get("name"):
+                                name = func["name"]
+
+                            if hasattr(func, "arguments") and func.arguments:
+                                arguments += func.arguments
+                            elif isinstance(func, dict) and func.get("arguments"):
+                                arguments += func["arguments"]
+
+                    if finish_reason == "tool_calls":
+                        yield from self.handle_function_call(messages, name, arguments)
+                        yield from self.get_completion(
+                            model=model,
+                            temperature=temperature,
+                            top_p=top_p,
+                            messages=messages,
+                            functions=functions,
+                            caching=False,
+                        )
+                        return
+
+                    # Extract content from message
+                    content = message.content or ""
+                    yield content
+
+                else:
+                    # Unknown response structure, skip
+                    continue
+
         except KeyboardInterrupt:
             response.close()
+        except Exception as e:
+            # Handle Azure OpenAI specific errors
+            if self.provider == "azure-openai":
+                error_msg = str(e)
+                if "deployment" in error_msg.lower():
+                    raise Exception(
+                        f"Azure OpenAI deployment error: {error_msg}. "
+                        "Please check your deployment name and API configuration."
+                    ) from e
+                elif "api_version" in error_msg.lower():
+                    raise Exception(
+                        f"Azure OpenAI API version error: {error_msg}. "
+                        "Please check your API version configuration."
+                    ) from e
+                elif "list index out of range" in error_msg.lower():
+                    raise Exception(
+                        f"Azure OpenAI response parsing error: {error_msg}. "
+                        "This might be due to an unexpected response format or empty choices."
+                    ) from e
+                else:
+                    raise Exception(f"Azure OpenAI error: {error_msg}") from e
+            else:
+                raise
 
     def handle(
         self,
diff --git a/tests/test_azure_openai.py b/tests/test_azure_openai.py
new file mode 100644
index 00000000..0b887b6a
--- /dev/null
+++ b/tests/test_azure_openai.py
@@ -0,0 +1,135 @@
+import os
+from pathlib import Path
+from unittest.mock import patch
+
+from sgpt import config, main
+from sgpt.role import DefaultRoles, SystemRole
+
+from .utils import app, cmd_args, comp_args, mock_azure_comp, runner
+
+role = SystemRole.get(DefaultRoles.DEFAULT.value)
+cfg = config.cfg
+
+
+@patch("sgpt.handlers.handler.completion")
+def test_azure_openai_provider(completion):
+    """Test that Azure OpenAI provider works correctly."""
+    completion.return_value = mock_azure_comp("Azure OpenAI response")
+
+    # Set environment variable for Azure OpenAI provider
+    os.environ["OPENAI_PROVIDER"] = "azure-openai"
+
+    args = {"prompt": "test prompt"}
+    result = runner.invoke(app, cmd_args(**args))
+
+    completion.assert_called_once_with(**comp_args(role, **args))
+    assert result.exit_code == 0
+    assert "Azure OpenAI response" in result.stdout
+
+    # Clean up environment variable
+    os.environ.pop("OPENAI_PROVIDER", None)
+
+
+@patch("sgpt.handlers.handler.completion")
+def test_azure_openai_shell_command(completion):
+    """Test Azure OpenAI provider with shell command generation."""
+    completion.return_value = mock_azure_comp("ls -la")
+
+    # Set environment variable for Azure OpenAI provider
+    os.environ["OPENAI_PROVIDER"] = "azure-openai"
+
+    # Shell mode uses the SHELL role; CLI flags are not completion kwargs,
+    # so pass only the prompt to comp_args.
+    shell_role = SystemRole.get(DefaultRoles.SHELL.value)
+    args = {"prompt": "list files", "--shell": True}
+    result = runner.invoke(app, cmd_args(**args))
+
+    completion.assert_called_once_with(**comp_args(shell_role, args["prompt"]))
+    assert result.exit_code == 0
+    assert "ls -la" in result.stdout
+
+    # Clean up environment variable
+    os.environ.pop("OPENAI_PROVIDER", None)
+
+
+@patch("sgpt.handlers.handler.completion")
+def test_azure_openai_code_generation(completion):
+    """Test Azure OpenAI provider with code generation."""
+    completion.return_value = mock_azure_comp("print('Hello World')")
+
+    # Set environment variable for Azure OpenAI provider
+    os.environ["OPENAI_PROVIDER"] = "azure-openai"
+
+    # Code mode uses the CODE role.
+    code_role = SystemRole.get(DefaultRoles.CODE.value)
+    args = {"prompt": "hello world in python", "--code": True}
+    result = runner.invoke(app, cmd_args(**args))
+
+    completion.assert_called_once_with(**comp_args(code_role, args["prompt"]))
+    assert result.exit_code == 0
+    assert "print('Hello World')" in result.stdout
+
+    # Clean up environment variable
+    os.environ.pop("OPENAI_PROVIDER", None)
+
+
+@patch("sgpt.handlers.handler.completion")
+def test_azure_openai_chat_mode(completion):
+    """Test Azure OpenAI provider with chat mode."""
+    completion.side_effect = [mock_azure_comp("ok"), mock_azure_comp("4")]
+
+    # Set environment variable for Azure OpenAI provider
+    os.environ["OPENAI_PROVIDER"] = "azure-openai"
+
+    chat_name = "_test_azure"
+    chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
cfg.get("CHAT_CACHE_PATH") / chat_name + chat_path.unlink(missing_ok=True) + + args = {"prompt": "my number is 2", "--chat": chat_name} + result = runner.invoke(app, cmd_args(**args)) + assert result.exit_code == 0 + assert "ok" in result.stdout + assert chat_path.exists() + + args["prompt"] = "my number + 2?" + result = runner.invoke(app, cmd_args(**args)) + assert result.exit_code == 0 + assert "4" in result.stdout + + expected_messages = [ + {"role": "system", "content": role.role}, + {"role": "user", "content": "my number is 2"}, + {"role": "assistant", "content": "ok"}, + {"role": "user", "content": "my number + 2?"}, + {"role": "assistant", "content": "4"}, + ] + expected_args = comp_args(role, "", messages=expected_messages) + completion.assert_called_with(**expected_args) + assert completion.call_count == 2 + + chat_path.unlink() + + # Clean up environment variable + os.environ.pop("OPENAI_PROVIDER", None) + + +@patch("sgpt.handlers.handler.completion") +def test_provider_configuration(completion): + """Test that provider configuration is properly read.""" + completion.return_value = mock_azure_comp("test response") + + # Test with different provider values + test_cases = [ + ("azure-openai", "Azure response"), + ("openai", "OpenAI response"), + ] + + for provider, expected_content in test_cases: + os.environ["OPENAI_PROVIDER"] = provider + completion.return_value = mock_azure_comp(expected_content) + + args = {"prompt": "test"} + result = runner.invoke(app, cmd_args(**args)) + + assert result.exit_code == 0 + assert expected_content in result.stdout + + # Clean up environment variable + os.environ.pop("OPENAI_PROVIDER", None) \ No newline at end of file