From 0fd788a6a44a1f093ba6de4a13d55b185e429b8c Mon Sep 17 00:00:00 2001 From: Michelangelo Mori <328978+blkt@users.noreply.github.com> Date: Tue, 29 Apr 2025 19:40:31 +0200 Subject: [PATCH] Fix `codegate version` and similar commands. While refactoring I removed three lines of code managing short-circuited requests. We short-circuit requests to implement `codegate version`, `codegate workspace`, and similar commands. Given we now have a provider-native representation of the messages, it is necessary to produce the right message for the given request and provider, so some code must be added in provider-specific modules to handle that. Fixes #1362 --- src/codegate/providers/anthropic/provider.py | 2 + src/codegate/providers/base.py | 5 ++ src/codegate/providers/openai/provider.py | 4 ++ src/codegate/types/openai/__init__.py | 2 + src/codegate/types/openai/_generators.py | 51 ++++++++++++++++++++ src/codegate/updates/client.py | 2 +- 6 files changed, 65 insertions(+), 1 deletion(-) diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 13741b85..368a1356 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -77,6 +77,7 @@ async def process_request( client_type: ClientType, completion_handler: Callable | None = None, stream_generator: Callable | None = None, + short_circuiter: Callable | None = None, ): try: stream = await self.complete( @@ -86,6 +87,7 @@ async def process_request( is_fim_request, client_type, completion_handler=completion_handler, + short_circuiter=short_circuiter, ) except Exception as e: # check if we have an status code there diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 9dca5ed9..a12f705f 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -259,6 +259,7 @@ async def complete( is_fim_request: bool, client_type: ClientType, completion_handler: Callable | None = None, + 
short_circuiter: Callable | None = None, ) -> Union[Any, AsyncIterator[Any]]: """ Main completion flow with pipeline integration @@ -287,6 +288,10 @@ async def complete( is_fim_request, ) + if input_pipeline_result.response and input_pipeline_result.context: + if short_circuiter: # this if should be removed eventually + return short_circuiter(input_pipeline_result) + provider_request = normalized_request # default value if input_pipeline_result.request: provider_request = self._input_normalizer.denormalize(input_pipeline_result.request) diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index ef8c4b5b..dc418192 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -14,6 +14,7 @@ from codegate.types.openai import ( ChatCompletionRequest, completions_streaming, + short_circuiter, stream_generator, ) @@ -72,6 +73,7 @@ async def process_request( client_type: ClientType, completion_handler: Callable | None = None, stream_generator: Callable | None = None, + short_circuiter: Callable | None = None, ): try: stream = await self.complete( @@ -81,6 +83,7 @@ async def process_request( is_fim_request=is_fim_request, client_type=client_type, completion_handler=completion_handler, + short_circuiter=short_circuiter, ) except Exception as e: # Check if we have an status code there @@ -130,4 +133,5 @@ async def create_completion( self.base_url, is_fim_request, request.state.detected_client, + short_circuiter=short_circuiter, ) diff --git a/src/codegate/types/openai/__init__.py b/src/codegate/types/openai/__init__.py index ca97e268..73cf8702 100644 --- a/src/codegate/types/openai/__init__.py +++ b/src/codegate/types/openai/__init__.py @@ -2,6 +2,7 @@ from ._generators import ( completions_streaming, message_wrapper, + short_circuiter, single_response_generator, stream_generator, streaming, @@ -74,6 +75,7 @@ "completions_streaming", "message_wrapper", "single_response_generator", + 
"short_circuiter", "stream_generator", "streaming", "LegacyCompletion", diff --git a/src/codegate/types/openai/_generators.py b/src/codegate/types/openai/_generators.py index 1d0f215c..829b0844 100644 --- a/src/codegate/types/openai/_generators.py +++ b/src/codegate/types/openai/_generators.py @@ -1,5 +1,7 @@ import os +import time from typing import ( + Any, AsyncIterator, ) @@ -9,9 +11,16 @@ from ._legacy_models import ( LegacyCompletionRequest, ) +from ._request_models import ( + ChatCompletionRequest, +) from ._response_models import ( ChatCompletion, + Choice, + ChoiceDelta, ErrorDetails, + Message, + MessageDelta, MessageError, StreamingChatCompletion, VllmMessageError, @@ -20,6 +29,48 @@ logger = structlog.get_logger("codegate") +async def short_circuiter(pipeline_result) -> AsyncIterator[Any]: + # NOTE: This routine MUST be called only when we short-circuit the + # request. + assert pipeline_result.context.shortcut_response # nosec + + match pipeline_result.context.input_request.request: + case ChatCompletionRequest(stream=True): + yield StreamingChatCompletion( + id="codegate", + model=pipeline_result.response.model, + created=int(time.time()), + choices=[ + ChoiceDelta( + finish_reason="stop", + index=0, + delta=MessageDelta( + content=pipeline_result.response.content, + ), + ), + ], + ) + case ChatCompletionRequest(stream=False): + yield ChatCompletion( + id="codegate", + model=pipeline_result.response.model, + created=int(time.time()), + choices=[ + Choice( + finish_reason="stop", + index=0, + message=Message( + content=pipeline_result.response.content, + ), + ), + ], + ) + case _: + raise ValueError( + f"invalid input request: {pipeline_result.context.input_request.request}" + ) + + async def stream_generator(stream: AsyncIterator[StreamingChatCompletion]) -> AsyncIterator[str]: """OpenAI-style SSE format""" try: diff --git a/src/codegate/updates/client.py b/src/codegate/updates/client.py index 7c958d8c..fc1aa778 100644 --- 
a/src/codegate/updates/client.py +++ b/src/codegate/updates/client.py @@ -1,8 +1,8 @@ +import os from enum import Enum import requests import structlog -import os logger = structlog.get_logger("codegate")