🔖 chore: support claude function
MartialBE committed Jul 2, 2024
1 parent 50562aa commit 3c45b5c
Showing 4 changed files with 312 additions and 67 deletions.
10 changes: 6 additions & 4 deletions providers/claude/base.go
@@ -52,8 +52,8 @@ func errorHandle(claudeError *ClaudeError) *types.OpenAIError {
return nil
}
return &types.OpenAIError{
- Message: claudeError.Message,
- Type: claudeError.Type,
+ Message: claudeError.Error.Message,
+ Type: claudeError.Error.Type,
Code: claudeError.Type,
}
}
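Note: the switch from claudeError.Message to claudeError.Error.Message follows the nested shape of Anthropic's error responses, which look roughly like the example below (the values are illustrative, not taken from this commit):

{
  "type": "error",
  "error": {
    "type": "invalid_request_error",
    "message": "max_tokens: field required"
  }
}

With the nested accessor, the OpenAI-style error surfaces the inner message and type rather than the outer envelope.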
@@ -84,18 +84,20 @@ func (p *ClaudeProvider) GetFullRequestURL(requestURL string) string {

func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "end_turn":
case "end_turn", "stop_sequence":
return types.FinishReasonStop
case "max_tokens":
return types.FinishReasonLength
case "tool_use":
return types.FinishReasonToolCalls
default:
return reason
}
}

func convertRole(role string) string {
switch role {
case "user":
case types.ChatMessageRoleUser, types.ChatMessageRoleTool, types.ChatMessageRoleFunction:
return types.ChatMessageRoleUser
default:
return types.ChatMessageRoleAssistant
239 changes: 196 additions & 43 deletions providers/claude/chat.go
@@ -17,10 +17,17 @@ import (
"github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream"
)

const (
StreamTollsNone = 0
StreamTollsUse = 1
StreamTollsArg = 2
)

type ClaudeStreamHandler struct {
- Usage *types.Usage
- Request *types.ChatCompletionRequest
- Prefix string
+ Usage *types.Usage
+ Request *types.ChatCompletionRequest
+ StreamTolls int
+ Prefix string
}

func (p *ClaudeProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
@@ -99,57 +106,142 @@ func ConvertFromChatOpenai(request *types.ChatCompletionRequest) (*ClaudeRequest
request.ClearEmptyMessages()
claudeRequest := ClaudeRequest{
Model: request.Model,
- Messages: []Message{},
+ Messages: make([]Message, 0),
System: "",
- MaxTokens: request.MaxTokens,
+ MaxTokens: defaultMaxTokens(request.MaxTokens),
StopSequences: nil,
Temperature: request.Temperature,
TopP: request.TopP,
Stream: request.Stream,
}
- if claudeRequest.MaxTokens == 0 {
- claudeRequest.MaxTokens = 4096
- }

- for _, message := range request.Messages {
- if message.Role == "system" {
- claudeRequest.System = message.Content.(string)
+ var prevUserMessage bool
+
+ for _, msg := range request.Messages {
+ if msg.Role == "system" && claudeRequest.System == "" {
+ claudeRequest.System = msg.StringContent()
continue
}
- content := Message{
- Role: convertRole(message.Role),
- Content: []MessageContent{},
+ messageContent, err := convertMessageContent(&msg)
+ if err != nil {
+ return nil, common.ErrorWrapper(err, "conversion_error", http.StatusBadRequest)
}
if messageContent != nil {
if messageContent.Role == "user" && prevUserMessage {
assistantMessage := Message{
Role: "assistant",
Content: []MessageContent{
{
Type: "text",
Text: "ok",
},
},
}
claudeRequest.Messages = append(claudeRequest.Messages, assistantMessage)
prevUserMessage = false
} else {
prevUserMessage = messageContent.Role == "user"
}
claudeRequest.Messages = append(claudeRequest.Messages, *messageContent)
}
}

for _, tool := range request.Tools {
tool := Tools{
Name: tool.Function.Name,
Description: tool.Function.Description,
InputSchema: tool.Function.Parameters,
}
claudeRequest.Tools = append(claudeRequest.Tools, tool)
}

if request.ToolChoice != nil {
toolType, toolFunc := request.ParseToolChoice()
claudeRequest.ToolChoice = ConvertToolChoice(toolType, toolFunc)
}

- openaiContent := message.ParseContent()
- for _, part := range openaiContent {
- if part.Type == types.ContentTypeText {
- content.Content = append(content.Content, MessageContent{
- Type: "text",
- Text: part.Text,
- })
- continue
+ return &claudeRequest, nil
}

func ConvertToolChoice(toolType, toolFunc string) *ToolChoice {
choice := &ToolChoice{Type: "auto"}

switch toolType {
case types.ToolChoiceTypeFunction:
choice.Type = "tool"
choice.Name = toolFunc
case types.ToolChoiceTypeRequired:
choice.Type = "any"
}

return choice
}

func defaultMaxTokens(maxTokens int) int {
if maxTokens == 0 {
return 4096
}
return maxTokens
}

func convertMessageContent(msg *types.ChatCompletionMessage) (*Message, error) {
content := Message{
Role: convertRole(msg.Role),
Content: make([]MessageContent, 0),
}

if msg.ToolCalls != nil {
for _, toolCall := range msg.ToolCalls {
inputParam := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputParam); err != nil {
return nil, err
}
content.Content = append(content.Content, MessageContent{
Type: ContentTypeToolUes,
Id: toolCall.Id,
Name: toolCall.Function.Name,
Input: inputParam,
})
}
return &content, nil
}

- if part.Type == types.ContentTypeImageURL {
- mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL)
- if err != nil {
- return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
- }
- content.Content = append(content.Content, MessageContent{
- Type: "image",
- Source: &ContentSource{
- Type: "base64",
- MediaType: mimeType,
- Data: data,
- },
- })
if msg.Role == types.ChatMessageRoleTool {
content.Content = append(content.Content, MessageContent{
Type: ContentTypeToolResult,
Content: msg.StringContent(),
ToolUseId: msg.ToolCallID,
})

return &content, nil
}

openaiContent := msg.ParseContent()
for _, part := range openaiContent {
if part.Type == types.ContentTypeText {
content.Content = append(content.Content, MessageContent{
Type: "text",
Text: part.Text,
})
continue
}
if part.Type == types.ContentTypeImageURL {
mimeType, data, err := image.GetImageFromUrl(part.ImageURL.URL)
if err != nil {
return nil, common.ErrorWrapper(err, "image_url_invalid", http.StatusBadRequest)
}
content.Content = append(content.Content, MessageContent{
Type: "image",
Source: &ContentSource{
Type: "base64",
MediaType: mimeType,
Data: data,
},
})
}
- claudeRequest.Messages = append(claudeRequest.Messages, content)
- }

- return &claudeRequest, nil
+ return &content, nil
}
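For orientation, here is a minimal, self-contained sketch of the tool-call mapping convertMessageContent performs, using simplified stand-in structs and illustrative values (call_123, get_weather, and the payloads are invented for the example; this is not the repository's code):

package main

import (
  "encoding/json"
  "fmt"
)

// Simplified stand-in for a Claude message content block.
type claudeContent struct {
  Type      string         `json:"type"`
  Id        string         `json:"id,omitempty"`
  Name      string         `json:"name,omitempty"`
  Input     map[string]any `json:"input,omitempty"`
  Content   string         `json:"content,omitempty"`
  ToolUseId string         `json:"tool_use_id,omitempty"`
}

func main() {
  // An assistant message carrying an OpenAI-style tool call: the arguments
  // arrive as a JSON string and are unmarshalled into Claude's structured
  // "input" field, as in the msg.ToolCalls branch above.
  arguments := `{"location":"Paris"}`
  var input map[string]any
  if err := json.Unmarshal([]byte(arguments), &input); err != nil {
    panic(err)
  }
  toolUse := claudeContent{Type: "tool_use", Id: "call_123", Name: "get_weather", Input: input}

  // A follow-up "tool"-role message becomes a user-side "tool_result" block
  // keyed by the originating call id, as in the ChatMessageRoleTool branch above.
  toolResult := claudeContent{Type: "tool_result", ToolUseId: "call_123", Content: `{"temp": 21}`}

  out, _ := json.MarshalIndent([]claudeContent{toolUse, toolResult}, "", "  ")
  fmt.Println(string(out))
}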

func ConvertToChatOpenai(provider base.ProviderInterface, response *ClaudeResponse, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
@@ -162,15 +254,29 @@ func ConvertToChatOpenai(provider base.ProviderInterface, response *ClaudeRespon
return
}

responseText := ""
if len(response.Content) > 0 {
responseText = response.Content[0].Text
}

choice := types.ChatCompletionChoice{
Index: 0,
Message: types.ChatCompletionMessage{
Role: response.Role,
- Content: strings.TrimPrefix(response.Content[0].Text, " "),
- Name: nil,
+ Content: responseText,
},
FinishReason: stopReasonClaude2OpenAI(response.StopReason),
}

if response.StopReason == FinishReasonToolUse {
for _, content := range response.Content {
if content.Type == FinishReasonToolUse {
choice.Message.ToolCalls = []*types.ChatCompletionToolCalls{content.ToOpenAITool()}
}
}
choice.FinishReason = types.FinishReasonToolCalls
}

openaiResponse = &types.ChatCompletionResponse{
ID: response.Id,
Object: "chat.completion",
@@ -245,6 +351,9 @@ func (h *ClaudeStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan strin
h.Usage.CompletionTokens += common.CountTokenText(claudeResponse.Delta.Text, h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens

case "content_block_start":
h.convertToOpenaiStream(&claudeResponse, dataChan)

default:
return
}
Expand All @@ -253,14 +362,58 @@ func (h *ClaudeStreamHandler) HandlerStream(rawLine *[]byte, dataChan chan strin
func (h *ClaudeStreamHandler) convertToOpenaiStream(claudeResponse *ClaudeStreamResponse, dataChan chan string) {
choice := types.ChatCompletionStreamChoice{
Index: claudeResponse.Index,
- Delta: types.ChatCompletionStreamChoiceDelta{
- Role: claudeResponse.Message.Role,
- Content: claudeResponse.Delta.Text,
- },
}

if claudeResponse.ContentBlock.Text != "" {
choice.Delta.Content = claudeResponse.ContentBlock.Text
}

if claudeResponse.Message.Role != "" {
choice.Delta.Role = claudeResponse.Message.Role
var toolCalls []*types.ChatCompletionToolCalls

if claudeResponse.ContentBlock.Type == ContentTypeToolUes {
toolCalls = append(toolCalls, &types.ChatCompletionToolCalls{
Id: claudeResponse.ContentBlock.Id,
Type: types.ChatMessageRoleFunction,
Function: &types.ChatCompletionToolCallsFunction{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
h.StreamTolls = StreamTollsUse
}

if claudeResponse.Delta.Type == "input_json_delta" {
if claudeResponse.Delta.PartialJson == "" {
return
}
toolCalls = append(toolCalls, &types.ChatCompletionToolCalls{
Type: types.ChatMessageRoleFunction,
Function: &types.ChatCompletionToolCallsFunction{
Arguments: claudeResponse.Delta.PartialJson,
},
})
h.StreamTolls = StreamTollsArg
}

if claudeResponse.ContentBlock.Type != ContentTypeToolUes && claudeResponse.Delta.Type != "input_json_delta" && h.StreamTolls != StreamTollsNone {
if h.StreamTolls == StreamTollsUse {
toolCalls = append(toolCalls, &types.ChatCompletionToolCalls{
Type: types.ChatMessageRoleFunction,
Function: &types.ChatCompletionToolCallsFunction{
Arguments: "{}",
},
})
}

h.StreamTolls = StreamTollsNone
}

if claudeResponse.Delta.Text != "" {
choice.Delta.Content = claudeResponse.Delta.Text
if toolCalls != nil {
choice.Delta.ToolCalls = toolCalls
}

finishReason := stopReasonClaude2OpenAI(claudeResponse.Delta.StopReason)
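For context on the streaming path: when Claude streams a tool call, the content_block_start event of type tool_use is converted into a first chunk carrying the call id and function name with empty arguments, and each following input_json_delta event is converted into a chunk carrying only an argument fragment. A minimal sketch of how a client might reassemble those chunks (hypothetical client-side code, not part of this repository; values are illustrative):

package main

import "fmt"

// Simplified stand-in for the tool-call delta carried by one streamed chunk.
type toolCallDelta struct {
  Id        string
  Name      string
  Arguments string
}

func main() {
  // Chunks roughly in the order convertToOpenaiStream emits them for one call.
  chunks := []toolCallDelta{
    {Id: "toolu_01", Name: "get_weather"}, // from content_block_start
    {Arguments: `{"location"`},            // from an input_json_delta event
    {Arguments: `:"Paris"}`},              // from an input_json_delta event
  }

  var id, name, args string
  for _, c := range chunks {
    if c.Id != "" {
      id = c.Id
    }
    if c.Name != "" {
      name = c.Name
    }
    args += c.Arguments
  }
  fmt.Printf("id=%s name=%s arguments=%s\n", id, name, args)
}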