diff --git a/SakuraTranslator/SakuraTranslatorEndpoint.cs b/SakuraTranslator/SakuraTranslatorEndpoint.cs index 8d47da9..435d119 100644 --- a/SakuraTranslator/SakuraTranslatorEndpoint.cs +++ b/SakuraTranslator/SakuraTranslatorEndpoint.cs @@ -1,10 +1,18 @@ -using System; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; +using System; using System.Collections; +using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net; +using System.Reflection; using System.Text; using XUnity.AutoTranslator.Plugin.Core.Endpoints; +[assembly: AssemblyVersion("0.2.3")] +[assembly: AssemblyFileVersion("0.2.3")] + namespace SakuraTranslator { public class SakuraTranslatorEndpoint : ITranslateEndpoint @@ -20,11 +28,49 @@ public class SakuraTranslatorEndpoint : ITranslateEndpoint // params private string _endpoint; private string _apiType; + private bool _useDict; + private string _dictMode; + private Dictionary _dict; + + // local var + private string _fullDictStr; public void Initialize(IInitializationContext context) { _endpoint = context.GetOrCreateSetting("Sakura", "Endpoint", "http://127.0.0.1:8080/completion"); _apiType = context.GetOrCreateSetting("Sakura", "ApiType", string.Empty); + if (!bool.TryParse(context.GetOrCreateSetting("Sakura", "UseDict", string.Empty), out _useDict)) + { + _useDict = false; + } + _dictMode = context.GetOrCreateSetting("Sakura", "DictMode", "Full"); + var dictStr = context.GetOrCreateSetting("Sakura", "Dict", string.Empty); + if (!string.IsNullOrEmpty(dictStr)) + { + try + { + _dict = new Dictionary(); + JObject dictJObj = JsonConvert.DeserializeObject(dictStr) as JObject; + foreach (var item in dictJObj) + { + _dict.Add(item.Key, item.Value.ToString()); + } + if (_dict.Count == 0) + { + _useDict = false; + _fullDictStr = string.Empty; + } + else + { + _fullDictStr = string.Join("\n", _dict.Select(x => $"{x.Key}->{x.Value}").ToArray()); + } + } + catch + { + _useDict = false; + _fullDictStr = string.Empty; + } + } } public IEnumerator Translate(ITranslationContext context) @@ -118,9 +164,67 @@ private string MakeRequestJson(string line) } else if (_apiType == "OpenAI") { - json = $"{{" + - $"\"model\": \"sukinishiro\"," + - $"\"messages\": [" + + json = MakeOpenAIPrompt(line); + } + else + { + json = $"{{\"frequency_penalty\": 0.2, \"n_predict\": 1000, \"prompt\": \"将下面的日文文本翻译成中文:{line}\", \"repeat_penalty\": 1, \"temperature\": 0.1, \"top_k\": 40, \"top_p\": 0.3}}"; + } + + return json; + } + + private string MakeOpenAIPrompt(string line) + { + string messagesStr = string.Empty; + if (_useDict) + { + var messages = new List + { + new PromptMessage + { + Role = "system", + Content = "你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,注意不要擅自添加原文中没有的代词,也不要擅自增加或减少换行。" + } + }; + string dictStr; + if (_dictMode == "Full") + { + dictStr = _fullDictStr; + } + else + { + var usedDict = _dict.Where(x => line.Contains(x.Key)); + if (usedDict.Count() > 0) + { + dictStr = string.Join("\n", usedDict.Select(x => $"{x.Key}->{x.Value}").ToArray()); + } + else + { + dictStr = string.Empty; + } + } + if (string.IsNullOrEmpty(dictStr)) + { + messages.Add(new PromptMessage + { + Role = "user", + Content = $"将下面的日文文本翻译成中文:{line}" + }); + } + else + { + messages.Add(new PromptMessage + { + Role = "user", + Content = $"根据以下术语表:\n{dictStr}\n将下面的日文文本根据上述术语表的对应关系和注释翻译成中文:{line}" + }); + } + messagesStr = SerializePromptMessages(messages); + } + else + { + messagesStr = "[" + $"{{" + $"\"role\": \"system\"," + $"\"content\": \"你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。\"" + @@ -129,7 +233,13 @@ private string MakeRequestJson(string line) $"\"role\": \"user\"," + $"\"content\": \"将下面的日文文本翻译成中文:{line}\"" + $"}}" + - $"]," + + $"]"; + } + return $"{{" + + $"\"model\": \"sukinishiro\"," + + $"\"messages\": " + + messagesStr + + $"," + $"\"temperature\": 0.1," + $"\"top_p\": 0.3," + $"\"max_tokens\": 1000," + @@ -139,13 +249,35 @@ private string MakeRequestJson(string line) $"\"um_beams\": 1," + $"\"repetition_penalty\": 1.0" + $"}}"; - } - else - { - json = $"{{\"frequency_penalty\": 0.2, \"n_predict\": 1000, \"prompt\": \"将下面的日文文本翻译成中文:{line}\", \"repeat_penalty\": 1, \"temperature\": 0.1, \"top_k\": 40, \"top_p\": 0.3}}"; - } + } - return json; + private string SerializePromptMessages(List messages) + { + string result = "["; + result += string.Join(",", messages.Select(x => $"{{\"role\":\"{x.Role}\"," + + $"\"content\":\"{EscapeJsonString(x.Content)}\"}}").ToArray()); + result += "]"; + return result; + } + + private string EscapeJsonString(string str) + { + return str + .Replace("\\", "\\\\") + .Replace("/", "\\/") + .Replace("\b", "\\b") + .Replace("\f", "\\f") + .Replace("\n", "\\n") + .Replace("\r", "\\r") + .Replace("\t", "\\t") + .Replace("\v", "\\v") + .Replace("\"", "\\\""); + } + + class PromptMessage + { + public string Role { get; set; } + public string Content { get; set; } } } }