From a46c4d2db30c2f534400a6179cd82f7beb07a29d Mon Sep 17 00:00:00 2001
From: Administrator <15274802129@163.com>
Date: Mon, 18 Aug 2025 09:54:23 +0800
Subject: [PATCH] feat(ai): 优化 AI 陪练问答流程

---
 src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkItemServiceImpl.java |   45 ++++++++++++++-
 src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java               |   37 +++++++++--
 src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkItemService.java          |    4 +
 src/main/java/cc/mrbird/febs/ai/req/ai/AiRequest.java                         |    3 +
 src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java     |   23 +++++--
 src/main/java/cc/mrbird/febs/ai/req/ai/AiMessage.java                         |   16 +++++
 src/main/java/cc/mrbird/febs/ai/service/AiService.java                        |    4 +
 src/main/java/cc/mrbird/febs/ai/controller/enumerates/AiTypeEnum.java         |   33 +++++++++++
 8 files changed, 147 insertions(+), 18 deletions(-)

diff --git a/src/main/java/cc/mrbird/febs/ai/controller/enumerates/AiTypeEnum.java b/src/main/java/cc/mrbird/febs/ai/controller/enumerates/AiTypeEnum.java
new file mode 100644
index 0000000..4dc7440
--- /dev/null
+++ b/src/main/java/cc/mrbird/febs/ai/controller/enumerates/AiTypeEnum.java
@@ -0,0 +1,33 @@
+package cc.mrbird.febs.ai.controller.enumerates;
+
+import lombok.Getter;
+
+/**
+ * @author Administrator
+ */
+
+@Getter
+public enum AiTypeEnum {
+
+    /**
+     * 1:AI提问
+     * 2:用户回答
+     * 3:生成答案解析
+     */
+    QUESTION_ANSWER(1,"AI提问"),
+    USER_ANSWER(2,"用户回答"),
+    ANSWER_ANALYSIS(3,"生成答案解析"),
+
+
+    QUESTION(1,"<strong>\"生成题目\"</strong>"),
+    ANSWER(2,"生成答案解析");
+
+    private final int code;
+    private final String name;
+
+    AiTypeEnum(int code,String name) {
+
+        this.code = code;
+        this.name = name;
+    }
+}
diff --git a/src/main/java/cc/mrbird/febs/ai/req/ai/AiMessage.java b/src/main/java/cc/mrbird/febs/ai/req/ai/AiMessage.java
new file mode 100644
index 0000000..a741700
--- /dev/null
+++ b/src/main/java/cc/mrbird/febs/ai/req/ai/AiMessage.java
@@ -0,0 +1,16 @@
+package cc.mrbird.febs.ai.req.ai;
+
+import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
+import io.swagger.annotations.ApiModel;
+import lombok.Data;
+
+/**
+ * @author Administrator
+ */
+@Data
+@ApiModel(value = "AiMessage", description = "参数")
+public class AiMessage {
+
+    private ChatMessageRole role;
+    private String content;
+}
diff --git a/src/main/java/cc/mrbird/febs/ai/req/ai/AiRequest.java b/src/main/java/cc/mrbird/febs/ai/req/ai/AiRequest.java
index b97f61a..6bf0e42 100644
--- a/src/main/java/cc/mrbird/febs/ai/req/ai/AiRequest.java
+++ b/src/main/java/cc/mrbird/febs/ai/req/ai/AiRequest.java
@@ -3,6 +3,8 @@
 import io.swagger.annotations.ApiModel;
 import lombok.Data;
 
+import java.util.List;
+
 /**
  * @author Administrator
  */
@@ -14,4 +16,5 @@
     private String jsonTemplate;
     private String linkId;
     private String content;
+    private List<AiMessage> aiMessageDtoList;
 }
diff --git a/src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkItemService.java b/src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkItemService.java
index 4da6a08..9b61d68 100644
--- a/src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkItemService.java
+++ b/src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkItemService.java
@@ -1,6 +1,7 @@
 package cc.mrbird.febs.ai.service;
 
 import cc.mrbird.febs.ai.entity.AiMemberTalkItem;
+import cc.mrbird.febs.ai.req.ai.AiMessage;
 import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkItemPageDto;
 import cc.mrbird.febs.common.entity.FebsResponse;
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
@@ -23,4 +24,7 @@
     AiMemberTalkItem getByQuery(LambdaQueryWrapper<AiMemberTalkItem> memberTalkItemQuery);
 
     FebsResponse historyPage(ApiMemberTalkItemPageDto dto);
+
+    List<AiMessage> getQuestionUpDownContext(String memberTalkId, int code);
+
 }
diff --git a/src/main/java/cc/mrbird/febs/ai/service/AiService.java b/src/main/java/cc/mrbird/febs/ai/service/AiService.java
index f94fea8..f27c3f2 100644
--- a/src/main/java/cc/mrbird/febs/ai/service/AiService.java
+++ b/src/main/java/cc/mrbird/febs/ai/service/AiService.java
@@ -1,9 +1,11 @@
 package cc.mrbird.febs.ai.service;
 
+import cc.mrbird.febs.ai.req.ai.AiMessage;
 import cc.mrbird.febs.ai.req.ai.AiRequest;
 import cc.mrbird.febs.ai.res.ai.AiResponse;
 import cc.mrbird.febs.ai.res.ai.Report;
 
+import java.util.List;
 import java.util.function.Consumer;
 
 /**
@@ -12,7 +14,7 @@
 public interface AiService {
 
 
-    AiResponse start(String productRoleId, String content);
+    AiResponse start(List<AiMessage> aiMessageDtoList,Integer type, String productRoleId, String answer, String question);
 
     AiResponse question(AiRequest aiRequest);
 
diff --git a/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkItemServiceImpl.java b/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkItemServiceImpl.java
index 0b4a3a0..ac1a089 100644
--- a/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkItemServiceImpl.java
+++ b/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkItemServiceImpl.java
@@ -1,21 +1,26 @@
 package cc.mrbird.febs.ai.service.impl;
 
+import cc.mrbird.febs.ai.controller.enumerates.AiTypeEnum;
+import cc.mrbird.febs.ai.entity.AiMemberTalk;
 import cc.mrbird.febs.ai.entity.AiMemberTalkItem;
 import cc.mrbird.febs.ai.mapper.AiMemberTalkItemMapper;
+import cc.mrbird.febs.ai.req.ai.AiMessage;
 import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkItemPageDto;
 import cc.mrbird.febs.ai.res.memberTalk.ApiMemberTalkItemVo;
-import cc.mrbird.febs.ai.res.product.ApiProductVo;
 import cc.mrbird.febs.ai.service.AiMemberTalkItemService;
 import cc.mrbird.febs.ai.utils.UUID;
 import cc.mrbird.febs.common.entity.FebsResponse;
+import cn.hutool.core.collection.CollUtil;
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
 import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
+import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Service;
-import org.springframework.transaction.annotation.Transactional;
 
+import java.util.ArrayList;
 import java.util.Date;
 import java.util.List;
 
@@ -31,7 +36,6 @@
 public class AiMemberTalkItemServiceImpl extends ServiceImpl<AiMemberTalkItemMapper, AiMemberTalkItem> implements AiMemberTalkItemService {
 
     private final AiMemberTalkItemMapper aiMemberTalkItemMapper;
-
 
     @Override
     public void add(String memberUuid, String id, int type, String resContext,Date createdTime) {
@@ -57,4 +61,39 @@
         Page<ApiMemberTalkItemVo> pageListByQuery = aiMemberTalkItemMapper.getPageListByQuery(page, dto);
         return new FebsResponse().success().data(pageListByQuery);
     }
+
+    @Override
+    public List<AiMessage> getQuestionUpDownContext(String memberTalkId, int code) {
+        List<AiMessage> aiMessages = new ArrayList<>();
+
+        LambdaQueryWrapper<AiMemberTalkItem> query = Wrappers.lambdaQuery(AiMemberTalkItem.class);
+            if (AiTypeEnum.QUESTION.getCode() == code){
+                query.eq(AiMemberTalkItem::getType,AiTypeEnum.QUESTION_ANSWER.getCode());
+                query.eq(AiMemberTalkItem::getMemberTalkId,memberTalkId);
+                query.orderByDesc(AiMemberTalkItem::getCreatedTime);
+                query.last("limit 1");
+                List<AiMemberTalkItem> aiMemberTalkItems = aiMemberTalkItemMapper.selectList(query);
+                if (CollUtil.isNotEmpty(aiMemberTalkItems)){
+                    AiMessage assistantMessage = new AiMessage();
+                    assistantMessage.setRole(ChatMessageRole.ASSISTANT);
+                    assistantMessage.setContent(aiMemberTalkItems.get(0).getContext());
+                    aiMessages.add(assistantMessage);
+                }
+            }
+            if (AiTypeEnum.ANSWER.getCode() == code){
+                query.eq(AiMemberTalkItem::getType,AiTypeEnum.ANSWER_ANALYSIS.getCode());
+                query.eq(AiMemberTalkItem::getMemberTalkId,memberTalkId);
+                query.orderByDesc(AiMemberTalkItem::getCreatedTime);
+                query.last("limit 1");
+                List<AiMemberTalkItem> aiMemberTalkItems = aiMemberTalkItemMapper.selectList(query);
+                if (CollUtil.isNotEmpty(aiMemberTalkItems)){
+                    AiMessage assistantMessage = new AiMessage();
+                    assistantMessage.setRole(ChatMessageRole.ASSISTANT);
+                    assistantMessage.setContent(aiMemberTalkItems.get(0).getContext());
+                    aiMessages.add(assistantMessage);
+                }
+            }
+        return aiMessages;
+    }
+
 }
diff --git a/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java b/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java
index 96ceed5..690d80c 100644
--- a/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java
+++ b/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java
@@ -1,17 +1,17 @@
 package cc.mrbird.febs.ai.service.impl;
 
+import cc.mrbird.febs.ai.controller.enumerates.AiTypeEnum;
 import cc.mrbird.febs.ai.entity.AiMemberTalk;
 import cc.mrbird.febs.ai.entity.AiMemberTalkItem;
 import cc.mrbird.febs.ai.entity.AiProductRoleLink;
 import cc.mrbird.febs.ai.mapper.AiMemberTalkMapper;
+import cc.mrbird.febs.ai.req.ai.AiMessage;
 import cc.mrbird.febs.ai.req.ai.AiRequest;
 import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkAnswerDto;
 import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkDto;
 import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkItemPageDto;
 import cc.mrbird.febs.ai.res.ai.AiResponse;
-import cc.mrbird.febs.ai.res.memberTalk.ApiMemberTalkItemVo;
 import cc.mrbird.febs.ai.res.memberTalk.ApiMemberTalkVo;
-import cc.mrbird.febs.ai.res.product.ApiProductVo;
 import cc.mrbird.febs.ai.service.AiMemberTalkItemService;
 import cc.mrbird.febs.ai.service.AiMemberTalkService;
 import cc.mrbird.febs.ai.service.AiProductRoleLinkService;
@@ -25,13 +25,14 @@
 import cn.hutool.json.JSONUtil;
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
 import com.baomidou.mybatisplus.core.toolkit.Wrappers;
-import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
 import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
 
 import java.util.Date;
+import java.util.List;
 import java.util.function.Consumer;
 
 /**
@@ -57,6 +58,7 @@
     }
 
     @Override
+    @Transactional
     public FebsResponse start(ApiMemberTalkDto dto) {
 
         String memberUuid = LoginUserUtil.getLoginUser().getMemberUuid();
@@ -80,7 +82,9 @@
             aiMemberTalk = this.add(memberUuid,productId,nowTime);
         }
 
-        AiResponse aiResponse = aiService.start(aiProductRoleLink.getProductRoleId(),"<strong>\"生成题目\"</strong>");
+        List<AiMessage> aiMessageDtoList = aiMemberTalkItemService.getQuestionUpDownContext(aiMemberTalk.getId(),AiTypeEnum.QUESTION.getCode());
+
+        AiResponse aiResponse = aiService.start(aiMessageDtoList,AiTypeEnum.QUESTION.getCode(),aiProductRoleLink.getProductRoleId(),AiTypeEnum.QUESTION.getName(), null);
         if(aiResponse.getCode().equals("200")){
             aiMemberTalkItemService.add(memberUuid,aiMemberTalk.getId(),1,aiResponse.getResContext(),nowTime);
             this.updateTimeUpdate(nowTime,aiMemberTalk.getId());
@@ -179,8 +183,9 @@
     }
 
 
-    public static final String ANSWER_FORMAT = "{}/n[回答]{}/n";
+    public static final String ANSWER_FORMAT = "###题目:{}###用户回答:{}";
     @Override
+    @Transactional
     public FebsResponse answer(ApiMemberTalkAnswerDto dto) {
         String memberUuid = LoginUserUtil.getLoginUser().getMemberUuid();
         String memberTalkId = dto.getId();
@@ -211,11 +216,17 @@
         String format = StrUtil.format(ANSWER_FORMAT, aiMemberTalkItem.getContext(), reqContext);
         log.info("format:{}",format);
 //        AiResponse aiResponse = aiService.start(aiProductRoleLink.getProductRoleId(), format);
-        AiResponse aiResponse = aiService.start(aiProductRoleLink.getProductRoleId(), reqContext);
+
+
+        List<AiMessage> aiMessageDtoList = aiMemberTalkItemService.getQuestionUpDownContext(aiMemberTalk.getId(),AiTypeEnum.ANSWER.getCode());
+        AiResponse aiResponse = aiService.start(aiMessageDtoList,AiTypeEnum.ANSWER.getCode(),aiProductRoleLink.getProductRoleId(), reqContext,aiMemberTalkItem.getContext());
         String context = null;
         if(aiResponse.getCode().equals("200")){
             Date nowTime = new Date();
             context = String.valueOf(JSONUtil.parse(aiResponse.getReport()));
+            if ("null".equals( context)){
+                context = aiResponse.getResContext();
+            }
             aiMemberTalkItemService.add(memberUuid,aiMemberTalk.getId(),3, context,nowTime);
             this.updateTimeUpdate(nowTime,aiMemberTalk.getId());
         }else{
diff --git a/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java b/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
index 9639b0b..fb2b2e7 100644
--- a/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
+++ b/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
@@ -1,12 +1,15 @@
 package cc.mrbird.febs.ai.service.impl;
 
+import cc.mrbird.febs.ai.controller.enumerates.AiTypeEnum;
 import cc.mrbird.febs.ai.entity.AiProductRole;
+import cc.mrbird.febs.ai.req.ai.AiMessage;
 import cc.mrbird.febs.ai.req.ai.AiRequest;
 import cc.mrbird.febs.ai.res.ai.AiResponse;
 import cc.mrbird.febs.ai.res.ai.RadarDataItem;
 import cc.mrbird.febs.ai.res.ai.Report;
 import cc.mrbird.febs.ai.service.AiProductRoleService;
 import cc.mrbird.febs.ai.service.AiService;
+import cn.hutool.core.collection.CollUtil;
 import cn.hutool.json.JSONUtil;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.JsonNode;
@@ -100,7 +103,7 @@
     }
 
     @Override
-    public AiResponse start(String productRoleId, String content) {
+    public AiResponse start(List<AiMessage> aiMessageDtoList,Integer type,String productRoleId, String content, String question) {
         if (!StringUtils.hasText(productRoleId)) {
             log.warn("productRoleId 不能为空");
             return buildErrorResponse(CODE_NOT_FOUND, "AI陪练不存在");
@@ -112,7 +115,15 @@
             return buildErrorResponse(CODE_NOT_FOUND, "AI陪练不存在");
         }
 
-        String promptTemplate = aiProductRole.getPromptTemplate();
+
+        String promptTemplate = "作为一个智能助手,请回答我提出的问题。";
+        if (AiTypeEnum.QUESTION.getCode() ==  type){
+            promptTemplate = aiProductRole.getPromptHead();
+        }
+        if (AiTypeEnum.ANSWER.getCode() ==  type){
+            promptTemplate = aiProductRole.getPromptTemplate()+question;
+        }
+        log.info("promptTemplate: {}", promptTemplate);
         String linkId = aiProductRole.getLinkId();
         String jsonTemplate = aiProductRole.getJsonTemplate();
 
@@ -130,7 +141,9 @@
         aiRequest.setJsonTemplate(jsonTemplate);
         aiRequest.setLinkId(linkId);
         aiRequest.setContent(content);
-
+        if (CollUtil.isNotEmpty(aiMessageDtoList)){
+            aiRequest.setAiMessageDtoList(aiMessageDtoList);
+        }
         return this.question(aiRequest);
     }
 
@@ -150,10 +163,19 @@
             return buildErrorResponse(CODE_ERROR, "请求参数不完整");
         }
 
-        final List<ChatMessage> messages = new ArrayList<>();
-        final ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(promptTemplate).build();
-        final ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build();
+        List<ChatMessage> messages = new ArrayList<>();
+        ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(promptTemplate).build();
+        ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build();
         messages.add(systemMessage);
+        if (CollUtil.isNotEmpty(aiRequest.getAiMessageDtoList())){
+            aiRequest.getAiMessageDtoList().forEach(aiMessageDto -> {
+                ChatMessage message = ChatMessage.builder()
+                        .role(aiMessageDto.getRole())
+                        .content(aiMessageDto.getContent())
+                        .build();
+                messages.add(message);
+            });
+        }
         messages.add(userMessage);
 
         try {
@@ -184,7 +206,6 @@
                     .filter(contentObj -> contentObj != null)
                     .map(Object::toString)
                     .collect(Collectors.joining());
-
             Report report = this.extractReportData(result);
             return buildSuccessResponse(report, result);
         } catch (JsonProcessingException e) {
@@ -313,7 +334,7 @@
     public Report extractReportData(String modelOutput) {
         Matcher matcher = JSON_PATTERN.matcher(modelOutput);
         if (!matcher.find()) {
-            log.warn("未匹配到FunctionCall内容,原始输出长度: {}", modelOutput.length());
+            log.warn("未匹配到FunctionCall内容,原始输出长度: {}", modelOutput);
             return null;
         }
 

--
Gitblit v1.9.1