Skip to content

Commit

Permalink
[feature][headless-chat]Introduce ChatApp to support more flexible …
Browse files Browse the repository at this point in the history
…chat model config.#1739
  • Loading branch information
jerryjzhang committed Oct 12, 2024
1 parent 0cce0a7 commit 49d0638
Show file tree
Hide file tree
Showing 11 changed files with 17 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import lombok.Data;
import org.springframework.util.CollectionUtils;

Expand All @@ -22,12 +21,8 @@ public class Agent extends RecordInfo {
private Integer status;
private List<String> examples;
private Integer enableSearch;
private Integer enableMemoryReview;
private String toolConfig;
private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP;
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig;

public List<String> getTools(AgentToolType type) {
Expand All @@ -49,7 +44,7 @@ public boolean enableSearch() {
}

public boolean enableMemoryReview() {
return enableMemoryReview != null && enableMemoryReview == 1;
return false;
}

public static boolean containsAllModel(Set<Long> detectViewIds) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,19 @@

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.response.ChatModelTypeResp;
import com.tencent.supersonic.chat.server.config.ChatModelParameters;
import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.util.ChatAppManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

@RestController
@RequestMapping({"/api/chat/model", "/openapi/chat/model"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.PromptConfig;
import com.tencent.supersonic.chat.server.agent.VisualConfig;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
Expand All @@ -24,8 +22,6 @@
import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
package com.tencent.supersonic.common.util;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.tencent.supersonic.common.pojo.ChatApp;

import java.util.List;
import java.util.Map;

public class ChatAppManager {
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();

public static void register(ChatApp chatApp) {
if (chatApps.containsKey(chatApp.getKey())) {
throw new RuntimeException("Duplicate chat app key is disallowed.");
}
chatApps.put(chatApp.getKey(), chatApp);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo seman
return;
}

ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatApp.getChatModelConfig());
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig());
SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public PageInfo<DictValueResp> queryDictValue(String fileName, DictValueReq dict
}

private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
DictValueReq dictValueReq) {
DictValueReq dictValueReq) {
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize());
dictValueRespPageInfo.setPageNum(dictValueReq.getCurrent());
Expand Down Expand Up @@ -118,11 +118,12 @@ private PageInfo<DictValueResp> getDictValueRespPagWithKey(String fileName,
}

private PageInfo<DictValueResp> getDictValueRespPagWithoutKey(String fileName,
DictValueReq dictValueReq) {
DictValueReq dictValueReq) {
PageInfo<DictValueResp> dictValueRespPageInfo = new PageInfo<>();
String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
Long fileLineNum = getFileLineNum(filePath);
Integer startLine = Math.max(1, (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1);
Integer startLine =
Math.max(1, (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1);
Integer endLine = Integer.valueOf(
Math.min(dictValueReq.getCurrent() * dictValueReq.getPageSize(), fileLineNum) + "");
List<DictValueResp> dictValueRespList = getFileData(filePath, startLine, endLine);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ public class DictTaskServiceImpl implements DictTaskService {
private final DimensionService dimensionService;

public DictTaskServiceImpl(DictRepository dictRepository, DictUtils dictConverter,
DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService,
DimensionService dimensionService) {
DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService,
DimensionService dimensionService) {
this.dictRepository = dictRepository;
this.dictConverter = dictConverter;
this.dictUtils = dictUtils;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public static DimensionDO convert2DimensionDO(DimensionReq dimensionReq) {
}

public static DimensionResp convert2DimensionResp(DimensionDO dimensionDO,
Map<Long, ModelResp> modelRespMap) {
Map<Long, ModelResp> modelRespMap) {
DimensionResp dimensionResp = new DimensionResp();
BeanUtils.copyProperties(dimensionDO, dimensionResp);
dimensionResp.setModelName(
Expand Down Expand Up @@ -123,11 +123,11 @@ private static DimensionType getType(String type) {
}

public static List<DimensionResp> filterByDataSet(List<DimensionResp> dimensionResps,
DataSetResp dataSetResp) {
DataSetResp dataSetResp) {
return dimensionResps.stream()
.filter(dimensionResp -> dataSetResp.dimensionIds().contains(dimensionResp.getId())
|| dataSetResp.getAllIncludeAllModels()
.contains(dimensionResp.getModelId()))
.contains(dimensionResp.getModelId()))
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ com.tencent.supersonic.headless.chat.parser.SemanticParser=\
com.tencent.supersonic.headless.chat.parser.QueryTypeParser

com.tencent.supersonic.headless.chat.corrector.SemanticCorrector=\
com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector
com.tencent.supersonic.headless.chat.corrector.RuleSqlCorrector,\
com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector

com.tencent.supersonic.headless.chat.knowledge.file.FileHandler=\
com.tencent.supersonic.headless.chat.knowledge.file.FileHandlerImpl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class Text2SQLEval extends BaseTest {

@BeforeAll
public void init() {
Agent agent = agentService.createAgent(getLLMAgent(false), DataUtils.getUser());
Agent agent = agentService.createAgent(getLLMAgent(), DataUtils.getUser());
agentId = agent.getId();
}

Expand Down Expand Up @@ -133,7 +133,7 @@ public void test_second_calculation() throws Exception {
assert result.getTextResult().contains("3");
}

public Agent getLLMAgent(boolean enableMultiturn) {
public Agent getLLMAgent() {
Agent agent = new Agent();
agent.setName("Agent for Test");
ToolConfig toolConfig = new ToolConfig();
Expand All @@ -142,6 +142,7 @@ public Agent getLLMAgent(boolean enableMultiturn) {
ChatModel chatModel = new ChatModel();
chatModel.setName("Text2SQL LLM");
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
chatModel.setConfig(LocalLLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OPENAI_GPT));
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
Integer chatModelId = chatModel.getId();
// configure chat apps
Expand Down

0 comments on commit 49d0638

Please sign in to comment.