Skip to content

Commit

Permalink
Merge pull request #8 from sangyuxiaowu/chat
Browse files Browse the repository at this point in the history
feat: 函数调用
  • Loading branch information
sangyuxiaowu authored Jul 25, 2024
2 parents 36e72fe + 82732e7 commit b3b3998
Show file tree
Hide file tree
Showing 16 changed files with 886 additions and 258 deletions.
22 changes: 20 additions & 2 deletions LLamaWorker.OpenAIModels/ChatCompletionModels.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
namespace LLamaWorker.OpenAIModels
using System.Text.Json.Serialization;

namespace LLamaWorker.OpenAIModels
{
/// <summary>
/// 对话完成请求
Expand Down Expand Up @@ -40,14 +42,30 @@ public class ChatCompletionMessage
{
/// <summary>
/// 角色
/// system, user, assistant, tool
/// </summary>
/// <example>user</example>
public string? role { get; set; } = string.Empty;
/// <summary>
/// 对话内容
/// </summary>
/// <example>你好</example>
public string content { get; set; } = string.Empty;
public string? content { get; set; }

/// <summary>
/// 工具调用信息
/// </summary>
/// <example>null</example>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public ToolMeaasge[]? tool_calls { get; set; }

/// <summary>
/// 调用工具的 ID
/// role 为 tool 时必填
/// </summary>
/// <example>null</example>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? tool_call_id { get; set; }
}

/// <summary>
Expand Down
58 changes: 57 additions & 1 deletion LLamaWorker.OpenAIModels/ToolModels.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
namespace LLamaWorker.OpenAIModels
using System.Text.Json.Serialization;

namespace LLamaWorker.OpenAIModels
{
/// <summary>
/// 推理完成令牌信息
Expand Down Expand Up @@ -29,11 +31,13 @@ public class FunctionInfo
/// <summary>
/// 函数作用的描述,由模型用来选择何时以及如何调用函数。
/// </summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? description { get; set; }

/// <summary>
/// 函数接受的参数,描述为JSON模式对象。
/// </summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public Parameters? parameters { get; set; }
}

Expand All @@ -56,6 +60,7 @@ public class Parameters
/// <summary>
/// 必需的参数名称列表。
/// </summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string[]? required { get; set; }
}

Expand All @@ -72,12 +77,63 @@ public class ParameterInfo
/// <summary>
/// 参数的描述
/// </summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string? description { get; set; }

/// <summary>
/// 参数的可选值
/// </summary>
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string[]? @enum { get; set; }
}

/// <summary>
/// 工具消息块,流式处理
/// </summary>
public class ToolMeaasgeChunk : ToolMeaasge
{
/// <summary>
/// 工具调用的索引
/// </summary>
public int index { get; set; }
}

/// <summary>
/// 工具消息块
/// </summary>
public class ToolMeaasge
{
/// <summary>
/// 工具调用的 ID
/// </summary>
public string id { get; set; }

/// <summary>
/// 工具类型,当前固定 function
/// </summary>
public string type { get; set; } = "function";

/// <summary>
/// 调用的函数信息
/// </summary>
public ToolMeaasgeFuntion function { get; set; }
}

/// <summary>
/// 调用工具的响应选择
/// </summary>
public class ToolMeaasgeFuntion
{
/// <summary>
/// 函数名称
/// </summary>
public string name { get; set; }

/// <summary>
/// 调用函数的参数,由JSON格式的模型生成。
/// 请注意,该模型并不总是生成有效的JSON,并且可能会产生函数模式未定义的参数的幻觉。
/// 在调用函数之前,验证代码中的参数。
/// </summary>
public string? arguments { get; set; }
}
}
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ English | [中文](README_CN.md)
- **Embedding Support**: Provides text embedding functionality with support for various embedding models.
- **chat templates**: Provides some common chat templates.
- **Auto-Release**: Supports automatic release of loaded models.
- **Function Call**: Supports function calls.
- **API Key Authentication**: Supports API Key authentication.
- **Gradio UI Demo**: Provides a UI demo based on Gradio.NET.

Expand Down
1 change: 1 addition & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ LLamaWorker 是一个基于 [LLamaSharp](https://github.com/SciSharp/LLamaSharp?
- **嵌入支持**: 提供文本嵌入功能,支持多种嵌入模型。
- **对话模版**: 提供了一些常见的对话模版。
- **自动释放**: 支持自动释放已加载模型。
- **函数调用**: 支持函数调用。
- **API Key 认证**: 支持 API Key 认证。
- **Gradio UI Demo**: 提供了一个基于 Gradio.NET 的 UI 演示。

Expand Down
114 changes: 102 additions & 12 deletions src/FunctionCall/ToolPromptGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,31 +1,123 @@
using LLamaWorker.OpenAIModels;
using LLamaWorker.OpenAIModels;
using Microsoft.Extensions.Options;
using System.Collections.Generic;
using System.Text.Encodings.Web;
using System.Text.Json;
using System.Text.RegularExpressions;
using System.Text.Unicode;

namespace LLamaWorker.FunctionCall
{
/// <summary>
/// 基础工具提示生成器
/// 基础工具提示生成器
/// </summary>
public class ToolPromptGenerator
{
private readonly List<ToolPromptConfig> _config;

private readonly string[] _nullWords = new string[] { "null", "{}", "[]" };
/// <summary>
/// 基础工具提示生成器
/// </summary>
/// <param name="config">工具配置信息</param>
public ToolPromptGenerator(IOptions<List<ToolPromptConfig>> config)
{
_config = config.Value;
}

/// <summary>
/// 生成工具提示词
/// 获取工具停用词
/// </summary>
/// <param name="tpl">模版序号</param>
/// <returns></returns>
public string[] GetToolStopWords(int tpl = 0)
{
return _config[tpl].FN_STOP_WORDS;
}

/// <summary>
/// 获取工具提示配置
/// </summary>
/// <param name="tpl">模版序号</param>
/// <returns></returns>
public ToolPromptConfig GetToolPromptConfig(int tpl = 0)
{
return _config[tpl];
}

/// <summary>
/// 生成工具调用
/// </summary>
/// <param name="tool">工具调用消息</param>
/// <param name="tpl">模版序号</param>
/// <returns></returns>
public string GenerateToolCall(ToolMeaasge tool, int tpl = 0)
{
return $"{_config[tpl].FN_NAME}: {tool.function.name}\n{_config[tpl].FN_ARGS}: {tool.function.arguments}";
}

/// <summary>
/// 生成工具返回结果
/// </summary>
/// <param name="res">工具调用结果</param>
/// <param name="tpl">模版序号</param>
/// <returns></returns>
public string GenerateToolCallResult(string? res, int tpl = 0)
{
return $"{_config[tpl].FN_RESULT}: {res}";
}

/// <summary>
/// 生成工具推理结果
/// </summary>
/// <param name="res">工具推理结果</param>
/// <param name="tpl">模版序号</param>
/// <returns></returns>
public string GenerateToolCallReturn(string? res, int tpl = 0)
{
return $"{_config[tpl].FN_EXIT}: {res}";
}

/// <summary>
/// 检查并生成工具调用
/// </summary>
/// <param name="req">原始对话生成请求</param>
/// <param name="tpl">模版序号</param>
/// <param name="lang">语言</param>
/// <param name="input">推理输出</param>
/// <param name="tpl">模版序号</param>
/// <returns></returns>
public List<ToolMeaasgeFuntion> GenerateToolCall(string input, int tpl = 0)
{
string pattern = @$"{_config[tpl].FN_NAME}:? (.*?)\s*({_config[tpl].FN_ARGS}:? (.*?)\s*)(?={_config[tpl].FN_NAME}|$|\n)";
Regex regex = new Regex(pattern, RegexOptions.Singleline);
MatchCollection matches = regex.Matches(input);
List<ToolMeaasgeFuntion> results = new();
foreach (Match match in matches)
{
string functionName = match.Groups[1].Value;
string arguments = match.Groups[3].Success ? match.Groups[3].Value : "";
if (string.IsNullOrWhiteSpace(arguments) || _nullWords.Contains(arguments))
{
arguments = null;
}
results.Add(new ToolMeaasgeFuntion
{
name = functionName,
arguments = arguments,
});
}
return results;
}


/// <summary>
/// 生成工具提示词
/// </summary>
/// <param name="req">原始对话生成请求</param>
/// <param name="tpl">模版序号</param>
/// <param name="lang">语言</param>
/// <returns></returns>
public string GenerateToolPrompt(ChatCompletionRequest req, int tpl = 0, string lang = "zh")
{
// 如果没有工具或者工具选择为 none,则返回空字符串
// 如果没有工具或者工具选择为 none,则返回空字符串
if (req.tools == null || req.tools.Length == 0 || (req.tool_choice != null && req.tool_choice.ToString() == "none"))
{
return string.Empty;
Expand All @@ -42,18 +134,16 @@ public string GenerateToolPrompt(ChatCompletionRequest req, int tpl = 0, string

var parallelFunctionCalls = req.tool_choice?.ToString() == "parallel";
var toolTemplate = parallelFunctionCalls ? config.FN_CALL_TEMPLATE_FMT_PARA[lang] : config.FN_CALL_TEMPLATE_FMT[lang];
var toolPrompt = string.Format(toolTemplate, config.FN_NAME, config.FN_ARGS, config.FN_RESULT, config.FN_EXIT)
.Replace("{tool_names}", toolNames);

return $"{toolSystem}\n\n{toolPrompt}";
var toolPrompt = string.Format(toolTemplate, config.FN_NAME, config.FN_ARGS, config.FN_RESULT, config.FN_EXIT, toolNames);
return $"\n\n{toolSystem}\n\n{toolPrompt}";
}

private string GetFunctionDescription(FunctionInfo function, string toolDescTemplate)
{
var nameForHuman = function.name;
var nameForModel = function.name;
var descriptionForModel = function.description ?? string.Empty;
var parameters = JsonSerializer.Serialize(function.parameters, new JsonSerializerOptions { WriteIndented = true });
var parameters = JsonSerializer.Serialize(function.parameters, new JsonSerializerOptions { Encoder = JavaScriptEncoder.Create(UnicodeRanges.All) });

return string.Format(toolDescTemplate, nameForHuman, nameForModel, descriptionForModel, parameters).Trim();
}
Expand Down
Loading

0 comments on commit b3b3998

Please sign in to comment.