src/main/java/cc/mrbird/febs/ai/controller/enumerates/AiTypeEnum.java | ●●●●● patch | view | raw | blame | history | |
src/main/java/cc/mrbird/febs/ai/req/ai/AiMessage.java | ●●●●● patch | view | raw | blame | history | |
src/main/java/cc/mrbird/febs/ai/req/ai/AiRequest.java | ●●●●● patch | view | raw | blame | history | |
src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkItemService.java | ●●●●● patch | view | raw | blame | history | |
src/main/java/cc/mrbird/febs/ai/service/AiService.java | ●●●●● patch | view | raw | blame | history | |
src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkItemServiceImpl.java | ●●●●● patch | view | raw | blame | history | |
src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java | ●●●●● patch | view | raw | blame | history | |
src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java | ●●●●● patch | view | raw | blame | history |
src/main/java/cc/mrbird/febs/ai/controller/enumerates/AiTypeEnum.java
New file @@ -0,0 +1,33 @@
package cc.mrbird.febs.ai.controller.enumerates;

import lombok.Getter;

/**
 * Integer codes paired with a display name, used by the AI talk feature.
 *
 * <p>The constants reuse numeric codes: {@code QUESTION_ANSWER} and
 * {@code QUESTION} are both 1, {@code USER_ANSWER} and {@code ANSWER} are
 * both 2. A bare code therefore does not identify a single constant; callers
 * must know which group they are comparing against.
 *
 * @author Administrator
 */
@Getter
public enum AiTypeEnum {
    /**
     * 1: AI question
     * 2: user answer
     * 3: generated answer analysis
     */
    QUESTION_ANSWER(1,"AI提问"),
    USER_ANSWER(2,"用户回答"),
    ANSWER_ANALYSIS(3,"生成答案解析"),
    // NOTE(review): the two constants below appear to describe the kind of AI
    // request (generate a question / analyse an answer) rather than a stored
    // item type, and QUESTION's name looks like prompt text, not a label —
    // confirm against callers.
    QUESTION(1,"<strong>\"生成题目\"</strong>"),
    ANSWER(2,"生成答案解析");

    // Numeric code; not unique across constants (see class comment).
    private final int code;
    // Text associated with the constant; exposed through the Lombok getter.
    private final String name;

    AiTypeEnum(int code,String name) {
        this.code = code;
        this.name = name;
    }
}
src/main/java/cc/mrbird/febs/ai/req/ai/AiMessage.java
New file @@ -0,0 +1,16 @@
package cc.mrbird.febs.ai.req.ai;

import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import io.swagger.annotations.ApiModel;
import lombok.Data;

/**
 * One chat message: a role plus its text content.
 *
 * <p>Accessors, {@code equals}/{@code hashCode} and {@code toString} are
 * generated by Lombok's {@code @Data}.
 *
 * @author Administrator
 */
@Data
@ApiModel(value = "AiMessage", description = "参数")
public class AiMessage {
    // Role of the message author, expressed with the Ark SDK role type.
    private ChatMessageRole role;
    // Text of the message.
    private String content;
}
src/main/java/cc/mrbird/febs/ai/req/ai/AiRequest.java
@@ -3,6 +3,8 @@ import io.swagger.annotations.ApiModel; import lombok.Data; import java.util.List; /** * @author Administrator */ @@ -14,4 +16,5 @@ private String jsonTemplate; private String linkId; private String content; private List<AiMessage> aiMessageDtoList; } src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkItemService.java
@@ -1,6 +1,7 @@ package cc.mrbird.febs.ai.service; import cc.mrbird.febs.ai.entity.AiMemberTalkItem; import cc.mrbird.febs.ai.req.ai.AiMessage; import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkItemPageDto; import cc.mrbird.febs.common.entity.FebsResponse; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; @@ -23,4 +24,7 @@ AiMemberTalkItem getByQuery(LambdaQueryWrapper<AiMemberTalkItem> memberTalkItemQuery); FebsResponse historyPage(ApiMemberTalkItemPageDto dto); List<AiMessage> getQuestionUpDownContext(String memberTalkId, int code); } src/main/java/cc/mrbird/febs/ai/service/AiService.java
@@ -1,9 +1,11 @@ package cc.mrbird.febs.ai.service; import cc.mrbird.febs.ai.req.ai.AiMessage; import cc.mrbird.febs.ai.req.ai.AiRequest; import cc.mrbird.febs.ai.res.ai.AiResponse; import cc.mrbird.febs.ai.res.ai.Report; import java.util.List; import java.util.function.Consumer; /** @@ -12,7 +14,7 @@ public interface AiService { AiResponse start(String productRoleId, String content); AiResponse start(List<AiMessage> aiMessageDtoList,Integer type, String productRoleId, String answer, String question); AiResponse question(AiRequest aiRequest); src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkItemServiceImpl.java
@@ -1,21 +1,26 @@ package cc.mrbird.febs.ai.service.impl; import cc.mrbird.febs.ai.controller.enumerates.AiTypeEnum; import cc.mrbird.febs.ai.entity.AiMemberTalk; import cc.mrbird.febs.ai.entity.AiMemberTalkItem; import cc.mrbird.febs.ai.mapper.AiMemberTalkItemMapper; import cc.mrbird.febs.ai.req.ai.AiMessage; import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkItemPageDto; import cc.mrbird.febs.ai.res.memberTalk.ApiMemberTalkItemVo; import cc.mrbird.febs.ai.res.product.ApiProductVo; import cc.mrbird.febs.ai.service.AiMemberTalkItemService; import cc.mrbird.febs.ai.utils.UUID; import cc.mrbird.febs.common.entity.FebsResponse; import cn.hutool.core.collection.CollUtil; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.toolkit.Wrappers; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import java.util.ArrayList; import java.util.Date; import java.util.List; @@ -31,7 +36,6 @@ public class AiMemberTalkItemServiceImpl extends ServiceImpl<AiMemberTalkItemMapper, AiMemberTalkItem> implements AiMemberTalkItemService { private final AiMemberTalkItemMapper aiMemberTalkItemMapper; @Override public void add(String memberUuid, String id, int type, String resContext,Date createdTime) { @@ -57,4 +61,39 @@ Page<ApiMemberTalkItemVo> pageListByQuery = aiMemberTalkItemMapper.getPageListByQuery(page, dto); return new FebsResponse().success().data(pageListByQuery); }

    /**
     * Builds the prior-conversation context for an AI request on a talk.
     *
     * <p>The request kind is mapped to a stored item type
     * ({@code QUESTION} -> {@code QUESTION_ANSWER}, {@code ANSWER} ->
     * {@code ANSWER_ANALYSIS}); the newest item of that type for the talk
     * (ordered by created time, descending, limited to one row) is wrapped
     * as a single {@code ASSISTANT} message.
     *
     * @param memberTalkId id of the talk whose items are searched
     * @param code         request kind, compared against
     *                     {@code AiTypeEnum.QUESTION} / {@code AiTypeEnum.ANSWER} codes
     * @return a list holding at most one assistant message; empty when the
     *         code matches neither kind or no matching item exists
     */
    @Override
    public List<AiMessage> getQuestionUpDownContext(String memberTalkId, int code) {
        List<AiMessage> contextMessages = new ArrayList<>();

        // Pick the stored item type whose latest row is replayed as context.
        final int itemType;
        if (code == AiTypeEnum.QUESTION.getCode()) {
            itemType = AiTypeEnum.QUESTION_ANSWER.getCode();
        } else if (code == AiTypeEnum.ANSWER.getCode()) {
            itemType = AiTypeEnum.ANSWER_ANALYSIS.getCode();
        } else {
            // Unknown request kind: nothing to replay.
            return contextMessages;
        }

        LambdaQueryWrapper<AiMemberTalkItem> latestItemQuery = Wrappers.lambdaQuery(AiMemberTalkItem.class)
                .eq(AiMemberTalkItem::getType, itemType)
                .eq(AiMemberTalkItem::getMemberTalkId, memberTalkId)
                .orderByDesc(AiMemberTalkItem::getCreatedTime)
                .last("limit 1");
        List<AiMemberTalkItem> latestItems = aiMemberTalkItemMapper.selectList(latestItemQuery);

        if (CollUtil.isNotEmpty(latestItems)) {
            AiMemberTalkItem newest = latestItems.get(0);
            AiMessage assistantMessage = new AiMessage();
            assistantMessage.setRole(ChatMessageRole.ASSISTANT);
            assistantMessage.setContent(newest.getContext());
            contextMessages.add(assistantMessage);
        }
        return contextMessages;
    }
} src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java
@@ -1,17 +1,17 @@ package cc.mrbird.febs.ai.service.impl; import cc.mrbird.febs.ai.controller.enumerates.AiTypeEnum; import cc.mrbird.febs.ai.entity.AiMemberTalk; import cc.mrbird.febs.ai.entity.AiMemberTalkItem; import cc.mrbird.febs.ai.entity.AiProductRoleLink; import cc.mrbird.febs.ai.mapper.AiMemberTalkMapper; import cc.mrbird.febs.ai.req.ai.AiMessage; import cc.mrbird.febs.ai.req.ai.AiRequest; import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkAnswerDto; import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkDto; import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkItemPageDto; import cc.mrbird.febs.ai.res.ai.AiResponse; import cc.mrbird.febs.ai.res.memberTalk.ApiMemberTalkItemVo; import cc.mrbird.febs.ai.res.memberTalk.ApiMemberTalkVo; import cc.mrbird.febs.ai.res.product.ApiProductVo; import cc.mrbird.febs.ai.service.AiMemberTalkItemService; import cc.mrbird.febs.ai.service.AiMemberTalkService; import cc.mrbird.febs.ai.service.AiProductRoleLinkService; @@ -25,13 +25,14 @@ import cn.hutool.json.JSONUtil; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.toolkit.Wrappers; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import java.util.Date; import java.util.List; import java.util.function.Consumer; /** @@ -57,6 +58,7 @@ } @Override @Transactional public FebsResponse start(ApiMemberTalkDto dto) { String memberUuid = LoginUserUtil.getLoginUser().getMemberUuid(); @@ -80,7 +82,9 @@ aiMemberTalk = this.add(memberUuid,productId,nowTime); } AiResponse aiResponse = aiService.start(aiProductRoleLink.getProductRoleId(),"<strong>\"生成题目\"</strong>"); List<AiMessage> aiMessageDtoList = 
aiMemberTalkItemService.getQuestionUpDownContext(aiMemberTalk.getId(),AiTypeEnum.QUESTION.getCode()); AiResponse aiResponse = aiService.start(aiMessageDtoList,AiTypeEnum.QUESTION.getCode(),aiProductRoleLink.getProductRoleId(),AiTypeEnum.QUESTION.getName(), null); if(aiResponse.getCode().equals("200")){ aiMemberTalkItemService.add(memberUuid,aiMemberTalk.getId(),1,aiResponse.getResContext(),nowTime); this.updateTimeUpdate(nowTime,aiMemberTalk.getId()); @@ -179,8 +183,9 @@ } public static final String ANSWER_FORMAT = "{}/n[回答]{}/n"; public static final String ANSWER_FORMAT = "###题目:{}###用户回答:{}"; @Override @Transactional public FebsResponse answer(ApiMemberTalkAnswerDto dto) { String memberUuid = LoginUserUtil.getLoginUser().getMemberUuid(); String memberTalkId = dto.getId(); @@ -211,11 +216,17 @@ String format = StrUtil.format(ANSWER_FORMAT, aiMemberTalkItem.getContext(), reqContext); log.info("format:{}",format); // AiResponse aiResponse = aiService.start(aiProductRoleLink.getProductRoleId(), format); AiResponse aiResponse = aiService.start(aiProductRoleLink.getProductRoleId(), reqContext); List<AiMessage> aiMessageDtoList = aiMemberTalkItemService.getQuestionUpDownContext(aiMemberTalk.getId(),AiTypeEnum.ANSWER.getCode()); AiResponse aiResponse = aiService.start(aiMessageDtoList,AiTypeEnum.ANSWER.getCode(),aiProductRoleLink.getProductRoleId(), reqContext,aiMemberTalkItem.getContext()); String context = null; if(aiResponse.getCode().equals("200")){ Date nowTime = new Date(); context = String.valueOf(JSONUtil.parse(aiResponse.getReport())); if ("null".equals( context)){ context = aiResponse.getResContext(); } aiMemberTalkItemService.add(memberUuid,aiMemberTalk.getId(),3, context,nowTime); this.updateTimeUpdate(nowTime,aiMemberTalk.getId()); }else{ src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
@@ -1,12 +1,15 @@ package cc.mrbird.febs.ai.service.impl; import cc.mrbird.febs.ai.controller.enumerates.AiTypeEnum; import cc.mrbird.febs.ai.entity.AiProductRole; import cc.mrbird.febs.ai.req.ai.AiMessage; import cc.mrbird.febs.ai.req.ai.AiRequest; import cc.mrbird.febs.ai.res.ai.AiResponse; import cc.mrbird.febs.ai.res.ai.RadarDataItem; import cc.mrbird.febs.ai.res.ai.Report; import cc.mrbird.febs.ai.service.AiProductRoleService; import cc.mrbird.febs.ai.service.AiService; import cn.hutool.core.collection.CollUtil; import cn.hutool.json.JSONUtil; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; @@ -100,7 +103,7 @@ } @Override public AiResponse start(String productRoleId, String content) { public AiResponse start(List<AiMessage> aiMessageDtoList,Integer type,String productRoleId, String content, String question) { if (!StringUtils.hasText(productRoleId)) { log.warn("productRoleId 不能为空"); return buildErrorResponse(CODE_NOT_FOUND, "AI陪练不存在"); @@ -112,7 +115,15 @@ return buildErrorResponse(CODE_NOT_FOUND, "AI陪练不存在"); } String promptTemplate = aiProductRole.getPromptTemplate(); String promptTemplate = "作为一个智能助手,请回答我提出的问题。"; if (AiTypeEnum.QUESTION.getCode() == type){ promptTemplate = aiProductRole.getPromptHead(); } if (AiTypeEnum.ANSWER.getCode() == type){ promptTemplate = aiProductRole.getPromptTemplate()+question; } log.info("promptTemplate: {}", promptTemplate); String linkId = aiProductRole.getLinkId(); String jsonTemplate = aiProductRole.getJsonTemplate(); @@ -130,7 +141,9 @@ aiRequest.setJsonTemplate(jsonTemplate); aiRequest.setLinkId(linkId); aiRequest.setContent(content); if (CollUtil.isNotEmpty(aiMessageDtoList)){ aiRequest.setAiMessageDtoList(aiMessageDtoList); } return this.question(aiRequest); } @@ -150,10 +163,19 @@ return buildErrorResponse(CODE_ERROR, "请求参数不完整"); } final List<ChatMessage> messages = new ArrayList<>(); final ChatMessage systemMessage = 
ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(promptTemplate).build(); final ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build(); List<ChatMessage> messages = new ArrayList<>(); ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(promptTemplate).build(); ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build(); messages.add(systemMessage); if (CollUtil.isNotEmpty(aiRequest.getAiMessageDtoList())){ aiRequest.getAiMessageDtoList().forEach(aiMessageDto -> { ChatMessage message = ChatMessage.builder() .role(aiMessageDto.getRole()) .content(aiMessageDto.getContent()) .build(); messages.add(message); }); } messages.add(userMessage); try { @@ -184,7 +206,6 @@ .filter(contentObj -> contentObj != null) .map(Object::toString) .collect(Collectors.joining()); Report report = this.extractReportData(result); return buildSuccessResponse(report, result); } catch (JsonProcessingException e) { @@ -313,7 +334,7 @@ public Report extractReportData(String modelOutput) { Matcher matcher = JSON_PATTERN.matcher(modelOutput); if (!matcher.find()) { log.warn("未匹配到FunctionCall内容,原始输出长度: {}", modelOutput.length()); log.warn("未匹配到FunctionCall内容,原始输出长度: {}", modelOutput); return null; }