Administrator
2025-08-18 a46c4d2db30c2f534400a6179cd82f7beb07a29d
feat(ai): 优化 AI 陪练问答流程

- 新增 AiMessage 类用于封装聊天消息
- 实现 getQuestionUpDownContext 方法获取上下文消息
- 修改 start 方法支持上下文消息传递
- 优化 answer 方法处理用户回答
- 调整 promptTemplate 生成逻辑
6 files modified
2 files added
165 ■■■■ changed files
src/main/java/cc/mrbird/febs/ai/controller/enumerates/AiTypeEnum.java 33 ●●●●● patch | view | raw | blame | history
src/main/java/cc/mrbird/febs/ai/req/ai/AiMessage.java 16 ●●●●● patch | view | raw | blame | history
src/main/java/cc/mrbird/febs/ai/req/ai/AiRequest.java 3 ●●●●● patch | view | raw | blame | history
src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkItemService.java 4 ●●●● patch | view | raw | blame | history
src/main/java/cc/mrbird/febs/ai/service/AiService.java 4 ●●● patch | view | raw | blame | history
src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkItemServiceImpl.java 45 ●●●●● patch | view | raw | blame | history
src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java 23 ●●●● patch | view | raw | blame | history
src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java 37 ●●●● patch | view | raw | blame | history
src/main/java/cc/mrbird/febs/ai/controller/enumerates/AiTypeEnum.java
New file
@@ -0,0 +1,33 @@
package cc.mrbird.febs.ai.controller.enumerates;
import lombok.Getter;
/**
 * @author Administrator
 */
@Getter
public enum AiTypeEnum {
    /**
     * 1:AI提问
     * 2:用户回答
     * 3:生成答案解析
     */
    QUESTION_ANSWER(1,"AI提问"),
    USER_ANSWER(2,"用户回答"),
    ANSWER_ANALYSIS(3,"生成答案解析"),
    QUESTION(1,"<strong>\"生成题目\"</strong>"),
    ANSWER(2,"生成答案解析");
    private final int code;
    private final String name;
    AiTypeEnum(int code,String name) {
        this.code = code;
        this.name = name;
    }
}
src/main/java/cc/mrbird/febs/ai/req/ai/AiMessage.java
New file
@@ -0,0 +1,16 @@
package cc.mrbird.febs.ai.req.ai;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import io.swagger.annotations.ApiModel;
import lombok.Data;
/**
 * @author Administrator
 */
@Data
@ApiModel(value = "AiMessage", description = "参数")
public class AiMessage {
    private ChatMessageRole role;
    private String content;
}
src/main/java/cc/mrbird/febs/ai/req/ai/AiRequest.java
@@ -3,6 +3,8 @@
import io.swagger.annotations.ApiModel;
import lombok.Data;
import java.util.List;
/**
 * @author Administrator
 */
@@ -14,4 +16,5 @@
    private String jsonTemplate;
    private String linkId;
    private String content;
    private List<AiMessage> aiMessageDtoList;
}
src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkItemService.java
@@ -1,6 +1,7 @@
package cc.mrbird.febs.ai.service;
import cc.mrbird.febs.ai.entity.AiMemberTalkItem;
import cc.mrbird.febs.ai.req.ai.AiMessage;
import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkItemPageDto;
import cc.mrbird.febs.common.entity.FebsResponse;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
@@ -23,4 +24,7 @@
    AiMemberTalkItem getByQuery(LambdaQueryWrapper<AiMemberTalkItem> memberTalkItemQuery);
    FebsResponse historyPage(ApiMemberTalkItemPageDto dto);
    /**
     * Collects previous conversation messages to send to the AI as context.
     *
     * <p>In the implementation shipped with this change, at most one assistant message is
     * returned: the most recent stored item of the conversation whose type corresponds to
     * the requested {@code code}.
     *
     * @param memberTalkId id of the conversation whose items are searched
     * @param code         request kind, compared against {@code AiTypeEnum.QUESTION} and
     *                     {@code AiTypeEnum.ANSWER} codes
     * @return the context messages; empty when nothing matches
     */
    List<AiMessage> getQuestionUpDownContext(String memberTalkId, int code);
}
src/main/java/cc/mrbird/febs/ai/service/AiService.java
@@ -1,9 +1,11 @@
package cc.mrbird.febs.ai.service;
import cc.mrbird.febs.ai.req.ai.AiMessage;
import cc.mrbird.febs.ai.req.ai.AiRequest;
import cc.mrbird.febs.ai.res.ai.AiResponse;
import cc.mrbird.febs.ai.res.ai.Report;
import java.util.List;
import java.util.function.Consumer;
/**
@@ -12,7 +14,7 @@
public interface AiService {
    AiResponse start(String productRoleId, String content);
    /**
     * Starts an AI request for the given product role, optionally with prior context.
     *
     * @param aiMessageDtoList earlier messages to place between the system prompt and the
     *                         user message; may be empty
     * @param type             request kind, compared against {@code AiTypeEnum.QUESTION} and
     *                         {@code AiTypeEnum.ANSWER} codes to choose the system prompt
     *                         (NOTE(review): the implementation compares with {@code ==},
     *                         which unboxes, so callers should not pass {@code null})
     * @param productRoleId    id of the product role whose prompt settings are used
     * @param answer           text sent as the user message
     * @param question         question text appended to the prompt template for answer
     *                         requests; callers pass {@code null} for question requests
     * @return the AI response
     */
    AiResponse start(List<AiMessage> aiMessageDtoList,Integer type, String productRoleId, String answer, String question);
    AiResponse question(AiRequest aiRequest);
src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkItemServiceImpl.java
@@ -1,21 +1,26 @@
package cc.mrbird.febs.ai.service.impl;
import cc.mrbird.febs.ai.controller.enumerates.AiTypeEnum;
import cc.mrbird.febs.ai.entity.AiMemberTalk;
import cc.mrbird.febs.ai.entity.AiMemberTalkItem;
import cc.mrbird.febs.ai.mapper.AiMemberTalkItemMapper;
import cc.mrbird.febs.ai.req.ai.AiMessage;
import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkItemPageDto;
import cc.mrbird.febs.ai.res.memberTalk.ApiMemberTalkItemVo;
import cc.mrbird.febs.ai.res.product.ApiProductVo;
import cc.mrbird.febs.ai.service.AiMemberTalkItemService;
import cc.mrbird.febs.ai.utils.UUID;
import cc.mrbird.febs.common.entity.FebsResponse;
import cn.hutool.core.collection.CollUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
@@ -31,7 +36,6 @@
public class AiMemberTalkItemServiceImpl extends ServiceImpl<AiMemberTalkItemMapper, AiMemberTalkItem> implements AiMemberTalkItemService {
    private final AiMemberTalkItemMapper aiMemberTalkItemMapper;
    @Override
    public void add(String memberUuid, String id, int type, String resContext,Date createdTime) {
@@ -57,4 +61,39 @@
        Page<ApiMemberTalkItemVo> pageListByQuery = aiMemberTalkItemMapper.getPageListByQuery(page, dto);
        return new FebsResponse().success().data(pageListByQuery);
    }
    @Override
    public List<AiMessage> getQuestionUpDownContext(String memberTalkId, int code) {
        List<AiMessage> aiMessages = new ArrayList<>();
        LambdaQueryWrapper<AiMemberTalkItem> query = Wrappers.lambdaQuery(AiMemberTalkItem.class);
            if (AiTypeEnum.QUESTION.getCode() == code){
                query.eq(AiMemberTalkItem::getType,AiTypeEnum.QUESTION_ANSWER.getCode());
                query.eq(AiMemberTalkItem::getMemberTalkId,memberTalkId);
                query.orderByDesc(AiMemberTalkItem::getCreatedTime);
                query.last("limit 1");
                List<AiMemberTalkItem> aiMemberTalkItems = aiMemberTalkItemMapper.selectList(query);
                if (CollUtil.isNotEmpty(aiMemberTalkItems)){
                    AiMessage assistantMessage = new AiMessage();
                    assistantMessage.setRole(ChatMessageRole.ASSISTANT);
                    assistantMessage.setContent(aiMemberTalkItems.get(0).getContext());
                    aiMessages.add(assistantMessage);
                }
            }
            if (AiTypeEnum.ANSWER.getCode() == code){
                query.eq(AiMemberTalkItem::getType,AiTypeEnum.ANSWER_ANALYSIS.getCode());
                query.eq(AiMemberTalkItem::getMemberTalkId,memberTalkId);
                query.orderByDesc(AiMemberTalkItem::getCreatedTime);
                query.last("limit 1");
                List<AiMemberTalkItem> aiMemberTalkItems = aiMemberTalkItemMapper.selectList(query);
                if (CollUtil.isNotEmpty(aiMemberTalkItems)){
                    AiMessage assistantMessage = new AiMessage();
                    assistantMessage.setRole(ChatMessageRole.ASSISTANT);
                    assistantMessage.setContent(aiMemberTalkItems.get(0).getContext());
                    aiMessages.add(assistantMessage);
                }
            }
        return aiMessages;
    }
}
src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java
@@ -1,17 +1,17 @@
package cc.mrbird.febs.ai.service.impl;
import cc.mrbird.febs.ai.controller.enumerates.AiTypeEnum;
import cc.mrbird.febs.ai.entity.AiMemberTalk;
import cc.mrbird.febs.ai.entity.AiMemberTalkItem;
import cc.mrbird.febs.ai.entity.AiProductRoleLink;
import cc.mrbird.febs.ai.mapper.AiMemberTalkMapper;
import cc.mrbird.febs.ai.req.ai.AiMessage;
import cc.mrbird.febs.ai.req.ai.AiRequest;
import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkAnswerDto;
import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkDto;
import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkItemPageDto;
import cc.mrbird.febs.ai.res.ai.AiResponse;
import cc.mrbird.febs.ai.res.memberTalk.ApiMemberTalkItemVo;
import cc.mrbird.febs.ai.res.memberTalk.ApiMemberTalkVo;
import cc.mrbird.febs.ai.res.product.ApiProductVo;
import cc.mrbird.febs.ai.service.AiMemberTalkItemService;
import cc.mrbird.febs.ai.service.AiMemberTalkService;
import cc.mrbird.febs.ai.service.AiProductRoleLinkService;
@@ -25,13 +25,14 @@
import cn.hutool.json.JSONUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.Date;
import java.util.List;
import java.util.function.Consumer;
/**
@@ -57,6 +58,7 @@
    }
    @Override
    @Transactional
    public FebsResponse start(ApiMemberTalkDto dto) {
        String memberUuid = LoginUserUtil.getLoginUser().getMemberUuid();
@@ -80,7 +82,9 @@
            aiMemberTalk = this.add(memberUuid,productId,nowTime);
        }
        AiResponse aiResponse = aiService.start(aiProductRoleLink.getProductRoleId(),"<strong>\"生成题目\"</strong>");
        List<AiMessage> aiMessageDtoList = aiMemberTalkItemService.getQuestionUpDownContext(aiMemberTalk.getId(),AiTypeEnum.QUESTION.getCode());
        AiResponse aiResponse = aiService.start(aiMessageDtoList,AiTypeEnum.QUESTION.getCode(),aiProductRoleLink.getProductRoleId(),AiTypeEnum.QUESTION.getName(), null);
        if(aiResponse.getCode().equals("200")){
            aiMemberTalkItemService.add(memberUuid,aiMemberTalk.getId(),1,aiResponse.getResContext(),nowTime);
            this.updateTimeUpdate(nowTime,aiMemberTalk.getId());
@@ -179,8 +183,9 @@
    }
    public static final String ANSWER_FORMAT = "{}/n[回答]{}/n";
    public static final String ANSWER_FORMAT = "###题目:{}###用户回答:{}";
    @Override
    @Transactional
    public FebsResponse answer(ApiMemberTalkAnswerDto dto) {
        String memberUuid = LoginUserUtil.getLoginUser().getMemberUuid();
        String memberTalkId = dto.getId();
@@ -211,11 +216,17 @@
        String format = StrUtil.format(ANSWER_FORMAT, aiMemberTalkItem.getContext(), reqContext);
        log.info("format:{}",format);
//        AiResponse aiResponse = aiService.start(aiProductRoleLink.getProductRoleId(), format);
        AiResponse aiResponse = aiService.start(aiProductRoleLink.getProductRoleId(), reqContext);
        List<AiMessage> aiMessageDtoList = aiMemberTalkItemService.getQuestionUpDownContext(aiMemberTalk.getId(),AiTypeEnum.ANSWER.getCode());
        AiResponse aiResponse = aiService.start(aiMessageDtoList,AiTypeEnum.ANSWER.getCode(),aiProductRoleLink.getProductRoleId(), reqContext,aiMemberTalkItem.getContext());
        String context = null;
        if(aiResponse.getCode().equals("200")){
            Date nowTime = new Date();
            context = String.valueOf(JSONUtil.parse(aiResponse.getReport()));
            if ("null".equals( context)){
                context = aiResponse.getResContext();
            }
            aiMemberTalkItemService.add(memberUuid,aiMemberTalk.getId(),3, context,nowTime);
            this.updateTimeUpdate(nowTime,aiMemberTalk.getId());
        }else{
src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
@@ -1,12 +1,15 @@
package cc.mrbird.febs.ai.service.impl;
import cc.mrbird.febs.ai.controller.enumerates.AiTypeEnum;
import cc.mrbird.febs.ai.entity.AiProductRole;
import cc.mrbird.febs.ai.req.ai.AiMessage;
import cc.mrbird.febs.ai.req.ai.AiRequest;
import cc.mrbird.febs.ai.res.ai.AiResponse;
import cc.mrbird.febs.ai.res.ai.RadarDataItem;
import cc.mrbird.febs.ai.res.ai.Report;
import cc.mrbird.febs.ai.service.AiProductRoleService;
import cc.mrbird.febs.ai.service.AiService;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.json.JSONUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
@@ -100,7 +103,7 @@
    }
    @Override
    public AiResponse start(String productRoleId, String content) {
    public AiResponse start(List<AiMessage> aiMessageDtoList,Integer type,String productRoleId, String content, String question) {
        if (!StringUtils.hasText(productRoleId)) {
            log.warn("productRoleId 不能为空");
            return buildErrorResponse(CODE_NOT_FOUND, "AI陪练不存在");
@@ -112,7 +115,15 @@
            return buildErrorResponse(CODE_NOT_FOUND, "AI陪练不存在");
        }
        String promptTemplate = aiProductRole.getPromptTemplate();
        String promptTemplate = "作为一个智能助手,请回答我提出的问题。";
        if (AiTypeEnum.QUESTION.getCode() ==  type){
            promptTemplate = aiProductRole.getPromptHead();
        }
        if (AiTypeEnum.ANSWER.getCode() ==  type){
            promptTemplate = aiProductRole.getPromptTemplate()+question;
        }
        log.info("promptTemplate: {}", promptTemplate);
        String linkId = aiProductRole.getLinkId();
        String jsonTemplate = aiProductRole.getJsonTemplate();
@@ -130,7 +141,9 @@
        aiRequest.setJsonTemplate(jsonTemplate);
        aiRequest.setLinkId(linkId);
        aiRequest.setContent(content);
        if (CollUtil.isNotEmpty(aiMessageDtoList)){
            aiRequest.setAiMessageDtoList(aiMessageDtoList);
        }
        return this.question(aiRequest);
    }
@@ -150,10 +163,19 @@
            return buildErrorResponse(CODE_ERROR, "请求参数不完整");
        }
        final List<ChatMessage> messages = new ArrayList<>();
        final ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(promptTemplate).build();
        final ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build();
        List<ChatMessage> messages = new ArrayList<>();
        ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(promptTemplate).build();
        ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build();
        messages.add(systemMessage);
        if (CollUtil.isNotEmpty(aiRequest.getAiMessageDtoList())){
            aiRequest.getAiMessageDtoList().forEach(aiMessageDto -> {
                ChatMessage message = ChatMessage.builder()
                        .role(aiMessageDto.getRole())
                        .content(aiMessageDto.getContent())
                        .build();
                messages.add(message);
            });
        }
        messages.add(userMessage);
        try {
@@ -184,7 +206,6 @@
                    .filter(contentObj -> contentObj != null)
                    .map(Object::toString)
                    .collect(Collectors.joining());
            Report report = this.extractReportData(result);
            return buildSuccessResponse(report, result);
        } catch (JsonProcessingException e) {
@@ -313,7 +334,7 @@
    public Report extractReportData(String modelOutput) {
        Matcher matcher = JSON_PATTERN.matcher(modelOutput);
        if (!matcher.find()) {
            log.warn("未匹配到FunctionCall内容,原始输出长度: {}", modelOutput.length());
            log.warn("未匹配到FunctionCall内容,原始输出长度: {}", modelOutput);
            return null;
        }