|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "encoding/json" |
| 6 | + "fmt" |
| 7 | + "io" |
| 8 | + "net/http" |
| 9 | + "os" |
| 10 | + "strings" |
| 11 | + "time" |
| 12 | +) |
| 13 | + |
| 14 | +// ChatMessage |
| 15 | +type ChatMessage struct { |
| 16 | + Role string `json:"role"` |
| 17 | + Content string `json:"content"` |
| 18 | + // 省略 ToolCalls 等仅用于工具调用的字段 |
| 19 | +} |
| 20 | + |
| 21 | +// ChatCompletionRequest |
| 22 | +type ChatCompletionRequest struct { |
| 23 | + Model string `json:"model"` |
| 24 | + Messages []ChatMessage `json:"messages"` // 包含完整的对话历史 |
| 25 | + Temperature float64 `json:"temperature,omitempty"` |
| 26 | + // 使用明确的 max_completion_tokens (如果 API 支持) 或 max_tokens |
| 27 | + MaxTokens int `json:"max_tokens,omitempty"` // 兼容旧模型或 Ollama |
| 28 | + Stream bool `json:"stream,omitempty"` |
| 29 | +} |
| 30 | + |
| 31 | +// ResponseMessage |
| 32 | +type ResponseMessage struct { |
| 33 | + Role string `json:"role"` |
| 34 | + Content *string `json:"content"` // 使用指针处理 null |
| 35 | + // 省略 ToolCalls |
| 36 | +} |
| 37 | + |
| 38 | +// Choice |
| 39 | +type Choice struct { |
| 40 | + Index int `json:"index"` |
| 41 | + Message ResponseMessage `json:"message"` |
| 42 | + FinishReason string `json:"finish_reason"` |
| 43 | + // 省略 logprobs |
| 44 | +} |
| 45 | + |
| 46 | +// UsageInfo |
| 47 | +type UsageInfo struct { |
| 48 | + PromptTokens int `json:"prompt_tokens"` |
| 49 | + CompletionTokens int `json:"completion_tokens"` |
| 50 | + TotalTokens int `json:"total_tokens"` |
| 51 | +} |
| 52 | + |
| 53 | +// ChatCompletionResponse |
| 54 | +type ChatCompletionResponse struct { |
| 55 | + ID string `json:"id"` |
| 56 | + Object string `json:"object"` |
| 57 | + Created int64 `json:"created"` |
| 58 | + Model string `json:"model"` |
| 59 | + Choices []Choice `json:"choices"` |
| 60 | + Usage UsageInfo `json:"usage"` |
| 61 | + // 省略 system_fingerprint |
| 62 | +} |
| 63 | + |
| 64 | +// --- 主函数 --- |
| 65 | + |
| 66 | +const maxTurns = 5 // 控制对话轮数 |
| 67 | + |
| 68 | +func main() { |
| 69 | + apiKey := os.Getenv("OPENAI_API_KEY") |
| 70 | + // 对本地 Ollama 等不需要 key 的服务,apiKey 可以为空 |
| 71 | + // if apiKey == "" { |
| 72 | + // fmt.Println("Warning: OPENAI_API_KEY environment variable not set. Assuming local service.") |
| 73 | + // } |
| 74 | + |
| 75 | + // --- 配置 API 端点和模型 --- |
| 76 | + // apiURL := "https://api.openai.com/v1/chat/completions" // OpenAI |
| 77 | + // modelID := "gpt-3.5-turbo" |
| 78 | + apiURL := "https://api.deepseek.com/chat/completions" |
| 79 | + modelID := "deepseek-chat" // DeepSeek V3 |
| 80 | + |
| 81 | + // --- 初始化对话历史 --- |
| 82 | + // 开发者消息设定角色和目标 |
| 83 | + conversationHistory := []ChatMessage{ |
| 84 | + { |
| 85 | + Role: "system", |
| 86 | + Content: "You are a Go language assistant. Answer questions concisely about Go features.", |
| 87 | + }, |
| 88 | + } |
| 89 | + |
| 90 | + fmt.Printf("Starting a %d-turn conversation with model %s via %s...\n", maxTurns, modelID, apiURL) |
| 91 | + fmt.Println("--------------------------------------------------") |
| 92 | + |
| 93 | + // --- 创建 HTTP 客户端 --- |
| 94 | + client := &http.Client{Timeout: 60 * time.Second} |
| 95 | + |
| 96 | + // --- 自动进行多轮对话 --- |
| 97 | + for turn := 1; turn <= maxTurns; turn++ { |
| 98 | + fmt.Printf("--- Turn %d ---\n", turn) |
| 99 | + |
| 100 | + // 1. 模拟生成用户本轮输入 (基于上一轮的简单逻辑) |
| 101 | + var userPrompt string |
| 102 | + if turn == 1 { |
| 103 | + userPrompt = "What are Go channels used for?" |
| 104 | + } else { |
| 105 | + // 简单追问,实际应用会更复杂 |
| 106 | + // 获取上一轮助手的回答来构造问题 |
| 107 | + lastAssistantMessage := "" |
| 108 | + if len(conversationHistory) > 0 && conversationHistory[len(conversationHistory)-1].Role == "assistant" { |
| 109 | + lastAssistantMessage = conversationHistory[len(conversationHistory)-1].Content |
| 110 | + } |
| 111 | + if strings.Contains(lastAssistantMessage, "communication") { |
| 112 | + userPrompt = "Can you give a simple code example of channel communication?" |
| 113 | + } else if strings.Contains(lastAssistantMessage, "synchronization") { |
| 114 | + userPrompt = "How does channel synchronization compare to using sync.Mutex?" |
| 115 | + } else { |
| 116 | + // 如果上轮没捕捉到关键词,就问个相关问题 |
| 117 | + userPrompt = "What about goroutines? How do they relate to channels?" |
| 118 | + } |
| 119 | + } |
| 120 | + fmt.Printf("User: %s\n", userPrompt) |
| 121 | + |
| 122 | + // 2. 将用户消息添加到历史记录 (发送前) |
| 123 | + userMessage := ChatMessage{Role: "user", Content: userPrompt} |
| 124 | + conversationHistory = append(conversationHistory, userMessage) |
| 125 | + |
| 126 | + // 3. 准备 API 请求体 (包含完整历史) |
| 127 | + requestPayload := ChatCompletionRequest{ |
| 128 | + Model: modelID, |
| 129 | + Messages: conversationHistory, // **关键: 传递了更新后的完整历史** |
| 130 | + Temperature: 0.6, |
| 131 | + MaxTokens: 100, // **限制每轮回复长度** |
| 132 | + Stream: false, |
| 133 | + } |
| 134 | + requestBodyBytes, err := json.Marshal(requestPayload) |
| 135 | + if err != nil { |
| 136 | + fmt.Printf("[Error] Marshalling request for turn %d: %v\n", turn, err) |
| 137 | + // 出错时移除刚添加的用户消息,避免影响下一轮(如果继续) |
| 138 | + conversationHistory = conversationHistory[:len(conversationHistory)-1] |
| 139 | + continue // 或 break |
| 140 | + } |
| 141 | + |
| 142 | + // 4. 创建并发送 HTTP 请求 |
| 143 | + req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(requestBodyBytes)) |
| 144 | + if err != nil { |
| 145 | + fmt.Printf("[Error] Creating request for turn %d: %v\n", turn, err) |
| 146 | + conversationHistory = conversationHistory[:len(conversationHistory)-1] |
| 147 | + continue |
| 148 | + } |
| 149 | + req.Header.Set("Content-Type", "application/json") |
| 150 | + // 仅在 apiKey 存在时添加 Authorization 头 |
| 151 | + if apiKey != "" { |
| 152 | + req.Header.Set("Authorization", "Bearer "+apiKey) |
| 153 | + } |
| 154 | + |
| 155 | + resp, err := client.Do(req) |
| 156 | + if err != nil { |
| 157 | + fmt.Printf("[Error] Sending request for turn %d: %v\n", turn, err) |
| 158 | + conversationHistory = conversationHistory[:len(conversationHistory)-1] |
| 159 | + continue |
| 160 | + } |
| 161 | + |
| 162 | + // 5. 处理响应 |
| 163 | + func() { // 使用匿名函数方便 defer resp.Body.Close() |
| 164 | + defer resp.Body.Close() |
| 165 | + responseBodyBytes, readErr := io.ReadAll(resp.Body) |
| 166 | + if readErr != nil { |
| 167 | + fmt.Printf("[Error] Reading response body for turn %d: %v\n", turn, readErr) |
| 168 | + conversationHistory = conversationHistory[:len(conversationHistory)-1] |
| 169 | + return // 从匿名函数返回 |
| 170 | + } |
| 171 | + |
| 172 | + if resp.StatusCode != http.StatusOK { |
| 173 | + fmt.Printf("[Error] Non-OK status code for turn %d: %d\nResponse: %s\n", turn, resp.StatusCode, string(responseBodyBytes)) |
| 174 | + conversationHistory = conversationHistory[:len(conversationHistory)-1] |
| 175 | + return |
| 176 | + } |
| 177 | + |
| 178 | + var chatResponse ChatCompletionResponse |
| 179 | + unmarshalErr := json.Unmarshal(responseBodyBytes, &chatResponse) |
| 180 | + if unmarshalErr != nil { |
| 181 | + fmt.Printf("[Error] Unmarshalling response for turn %d: %v\n", turn, unmarshalErr) |
| 182 | + fmt.Printf("Raw Response: %s\n", string(responseBodyBytes)) |
| 183 | + conversationHistory = conversationHistory[:len(conversationHistory)-1] |
| 184 | + return |
| 185 | + } |
| 186 | + |
| 187 | + // 6. 提取、打印并存储助手响应 |
| 188 | + if len(chatResponse.Choices) > 0 { |
| 189 | + choice := chatResponse.Choices[0] |
| 190 | + assistantContent := "" |
| 191 | + if choice.Message.Content != nil { // 检查 content 是否为 null |
| 192 | + assistantContent = *choice.Message.Content |
| 193 | + } |
| 194 | + |
| 195 | + fmt.Printf("Assistant: %s\n", assistantContent) |
| 196 | + fmt.Printf("(Finish Reason: %s, Tokens: %d prompt + %d completion = %d total)\n", |
| 197 | + choice.FinishReason, |
| 198 | + chatResponse.Usage.PromptTokens, |
| 199 | + chatResponse.Usage.CompletionTokens, |
| 200 | + chatResponse.Usage.TotalTokens) |
| 201 | + |
| 202 | + // 将有效的助手响应添加到历史记录 (为下一轮准备) |
| 203 | + if assistantContent != "" || choice.FinishReason == "tool_calls" { // 即使无文本但有工具调用也应记录 |
| 204 | + assistantMessage := ChatMessage{Role: "assistant", Content: assistantContent} |
| 205 | + // 如果是工具调用,还需要处理 tool_calls 字段,这里简化处理 |
| 206 | + conversationHistory = append(conversationHistory, assistantMessage) |
| 207 | + } else { |
| 208 | + fmt.Printf("[Warning] Assistant response content was empty for turn %d.\n", turn) |
| 209 | + // 如果响应无效,可以选择不添加到历史,或添加一个标记 |
| 210 | + // 这里同样移除用户消息,避免空响应影响后续 |
| 211 | + conversationHistory = conversationHistory[:len(conversationHistory)-1] |
| 212 | + } |
| 213 | + } else { |
| 214 | + fmt.Printf("[Error] No choices received in response for turn %d.\n", turn) |
| 215 | + conversationHistory = conversationHistory[:len(conversationHistory)-1] |
| 216 | + } |
| 217 | + }() // 立即执行匿名函数 |
| 218 | + |
| 219 | + fmt.Println("--------------------------------------------------") |
| 220 | + // 可以加个短暂休眠,模拟思考时间或避免触发速率限制 |
| 221 | + // time.Sleep(1 * time.Second) |
| 222 | + } |
| 223 | + |
| 224 | + fmt.Println("Conversation finished after 5 turns.") |
| 225 | +} |
0 commit comments