From a46c4d2db30c2f534400a6179cd82f7beb07a29d Mon Sep 17 00:00:00 2001
From: Administrator <15274802129@163.com>
Date: Mon, 18 Aug 2025 09:54:23 +0800
Subject: [PATCH] feat(ai): 优化 AI 陪练问答流程

---
 src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java |   37 +++++++++++++++++++++++++++++--------
 1 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java b/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
index 9639b0b..fb2b2e7 100644
--- a/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
+++ b/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
@@ -1,12 +1,15 @@
 package cc.mrbird.febs.ai.service.impl;
 
+import cc.mrbird.febs.ai.controller.enumerates.AiTypeEnum;
 import cc.mrbird.febs.ai.entity.AiProductRole;
+import cc.mrbird.febs.ai.req.ai.AiMessage;
 import cc.mrbird.febs.ai.req.ai.AiRequest;
 import cc.mrbird.febs.ai.res.ai.AiResponse;
 import cc.mrbird.febs.ai.res.ai.RadarDataItem;
 import cc.mrbird.febs.ai.res.ai.Report;
 import cc.mrbird.febs.ai.service.AiProductRoleService;
 import cc.mrbird.febs.ai.service.AiService;
+import cn.hutool.core.collection.CollUtil;
 import cn.hutool.json.JSONUtil;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.JsonNode;
@@ -100,7 +103,7 @@
     }
 
     @Override
-    public AiResponse start(String productRoleId, String content) {
+    public AiResponse start(List<AiMessage> aiMessageDtoList,Integer type,String productRoleId, String content, String question) {
         if (!StringUtils.hasText(productRoleId)) {
             log.warn("productRoleId 不能为空");
             return buildErrorResponse(CODE_NOT_FOUND, "AI陪练不存在");
@@ -112,7 +115,15 @@
             return buildErrorResponse(CODE_NOT_FOUND, "AI陪练不存在");
         }
 
-        String promptTemplate = aiProductRole.getPromptTemplate();
+
+        String promptTemplate = "作为一个智能助手,请回答我提出的问题。";
+        if (AiTypeEnum.QUESTION.getCode() ==  type){
+            promptTemplate = aiProductRole.getPromptHead();
+        }
+        if (AiTypeEnum.ANSWER.getCode() ==  type){
+            promptTemplate = aiProductRole.getPromptTemplate()+question;
+        }
+        log.info("promptTemplate: {}", promptTemplate);
         String linkId = aiProductRole.getLinkId();
         String jsonTemplate = aiProductRole.getJsonTemplate();
 
@@ -130,7 +141,9 @@
         aiRequest.setJsonTemplate(jsonTemplate);
         aiRequest.setLinkId(linkId);
         aiRequest.setContent(content);
-
+        if (CollUtil.isNotEmpty(aiMessageDtoList)){
+            aiRequest.setAiMessageDtoList(aiMessageDtoList);
+        }
         return this.question(aiRequest);
     }
 
@@ -150,10 +163,19 @@
             return buildErrorResponse(CODE_ERROR, "请求参数不完整");
         }
 
-        final List<ChatMessage> messages = new ArrayList<>();
-        final ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(promptTemplate).build();
-        final ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build();
+        List<ChatMessage> messages = new ArrayList<>();
+        ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(promptTemplate).build();
+        ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build();
         messages.add(systemMessage);
+        if (CollUtil.isNotEmpty(aiRequest.getAiMessageDtoList())){
+            aiRequest.getAiMessageDtoList().forEach(aiMessageDto -> {
+                ChatMessage message = ChatMessage.builder()
+                        .role(aiMessageDto.getRole())
+                        .content(aiMessageDto.getContent())
+                        .build();
+                messages.add(message);
+            });
+        }
         messages.add(userMessage);
 
         try {
@@ -184,7 +206,6 @@
                     .filter(contentObj -> contentObj != null)
                     .map(Object::toString)
                     .collect(Collectors.joining());
-
             Report report = this.extractReportData(result);
             return buildSuccessResponse(report, result);
         } catch (JsonProcessingException e) {
@@ -313,7 +334,7 @@
     public Report extractReportData(String modelOutput) {
         Matcher matcher = JSON_PATTERN.matcher(modelOutput);
         if (!matcher.find()) {
-            log.warn("未匹配到FunctionCall内容,原始输出长度: {}", modelOutput.length());
+            log.warn("未匹配到FunctionCall内容,原始输出长度: {}", modelOutput);
             return null;
         }
 

--
Gitblit v1.9.1