Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package ai

import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
Expand Down Expand Up @@ -244,7 +245,16 @@ func buildVariables(variables any) (map[string]any, error) {

v := reflect.Indirect(reflect.ValueOf(variables))
if v.Kind() == reflect.Map {
return variables.(map[string]any), nil
// ensure JSON tags are taken in consideration (allowing snake case fields)
jsonData, err := json.Marshal(variables)
if err != nil {
return nil, fmt.Errorf("unable to marshal prompt field values: %w", err)
}
var resultVariables map[string]any
if err := json.Unmarshal(jsonData, &resultVariables); err != nil {
return nil, fmt.Errorf("unable to unmarshal prompt field values: %w", err)
}
return resultVariables, nil
}
if v.Kind() != reflect.Struct {
return nil, errors.New("prompt.buildVariables: fields not a struct or pointer to a struct or a map")
Expand Down
79 changes: 72 additions & 7 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ output:
---
Hello, {{name}}!
`
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644)
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644)
if err != nil {
t.Fatalf("Failed to create mock prompt file: %v", err)
}
Expand Down Expand Up @@ -941,6 +941,71 @@ Hello, {{name}}!
}
}

func TestLoadPromptSnakeCase(t *testing.T) {
tempDir := t.TempDir()
mockPromptFile := filepath.Join(tempDir, "snake.prompt")
mockPromptContent := `---
model: googleai/gemini-2.5-flash
input:
schema:
items(array):
teamColor: string
team_name: string
---
{{#each items as |it|}}
{{ it.teamColor }},{{ it.team_name }}
{{/each}}
`
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644)
if err != nil {
t.Fatalf("Failed to create mock prompt file: %v", err)
}

reg := registry.New()
LoadPrompt(reg, tempDir, "snake.prompt", "snake-namespace")

prompt := LookupPrompt(reg, "snake-namespace/snake")
if prompt == nil {
t.Fatalf("prompt was not registered")
}

type SnakeInput struct {
TeamColor string `json:"teamColor"` // intentionally leaving camel case to test snake + camel support
TeamName string `json:"team_name"`
}

input := map[string]any{"items": []SnakeInput{
{TeamColor: "RED", TeamName: "Firebase"},
{TeamColor: "BLUE", TeamName: "Gophers"},
{TeamColor: "GREEN", TeamName: "Google"},
}}

actionOpts, err := prompt.Render(context.Background(), input)
if err != nil {
t.Fatalf("error rendering prompt: %v", err)
}
if actionOpts.Messages == nil {
t.Fatal("expecting messages to be rendered")
}
renderedPrompt := actionOpts.Messages[0].Text()
for line := range strings.SplitSeq(renderedPrompt, "\n") {
trimmedLine := strings.TrimSpace(line)
if strings.HasPrefix(trimmedLine, "RED") {
if !strings.Contains(trimmedLine, "Firebase") {
t.Fatalf("wrong template render, want: RED,Firebase, got: %s", trimmedLine)
}
} else if strings.HasPrefix(trimmedLine, "BLUE") {
if !strings.Contains(trimmedLine, "Gophers") {
t.Fatalf("wrong template render, want: BLUE,Gophers, got: %s", trimmedLine)
}
} else if strings.HasPrefix(trimmedLine, "GREEN") {
if !strings.Contains(trimmedLine, "Google") {
t.Fatalf("wrong template render, want: GREEN,Google, got: %s", trimmedLine)
}
}
}
}

func TestLoadPrompt_FileNotFound(t *testing.T) {
// Initialize a mock registry
reg := registry.New()
Expand All @@ -962,7 +1027,7 @@ func TestLoadPrompt_InvalidPromptFile(t *testing.T) {
// Create an invalid .prompt file
invalidPromptFile := filepath.Join(tempDir, "invalid.prompt")
invalidPromptContent := `invalid json content`
err := os.WriteFile(invalidPromptFile, []byte(invalidPromptContent), 0644)
err := os.WriteFile(invalidPromptFile, []byte(invalidPromptContent), 0o644)
if err != nil {
t.Fatalf("Failed to create invalid prompt file: %v", err)
}
Expand Down Expand Up @@ -993,7 +1058,7 @@ description: A test prompt

Hello, {{name}}!
`
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644)
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644)
if err != nil {
t.Fatalf("Failed to create mock prompt file: %v", err)
}
Expand All @@ -1018,7 +1083,7 @@ func TestLoadPromptFolder(t *testing.T) {
// Create mock prompt and partial files
mockPromptFile := filepath.Join(tempDir, "example.prompt")
mockSubDir := filepath.Join(tempDir, "subdir")
err := os.Mkdir(mockSubDir, 0755)
err := os.Mkdir(mockSubDir, 0o755)
if err != nil {
t.Fatalf("Failed to create subdirectory: %v", err)
}
Expand All @@ -1041,14 +1106,14 @@ output:
Hello, {{name}}!
`

err = os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644)
err = os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644)
if err != nil {
t.Fatalf("Failed to create mock prompt file: %v", err)
}

// Create a mock prompt file in the subdirectory
mockSubPromptFile := filepath.Join(mockSubDir, "sub_example.prompt")
err = os.WriteFile(mockSubPromptFile, []byte(mockPromptContent), 0644)
err = os.WriteFile(mockSubPromptFile, []byte(mockPromptContent), 0o644)
if err != nil {
t.Fatalf("Failed to create mock prompt file in subdirectory: %v", err)
}
Expand Down Expand Up @@ -1131,7 +1196,7 @@ You are a pirate!
{{ role "user" }}
Hello!
`
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0644)
err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644)
if err != nil {
t.Fatalf("Failed to create mock prompt file: %v", err)
}
Expand Down
Loading