From 0be80a07e38f08dd40e20175d788a3bfad8e68ef Mon Sep 17 00:00:00 2001
From: Administrator <15274802129@163.com>
Date: Mon, 11 Aug 2025 11:41:24 +0800
Subject: [PATCH] feat(ai): 增加 AI 陪练报告数据解析功能 - 新增 Report、RadarData 和 Evaluation 类用于解析报告数据 - 在 AiService 接口中添加 extractReportData 方法 - 在 AiServiceImpl 中实现报告数据的提取和解析 - 更新 ApiMemberTalkVo,增加 report 字段用于存储解析后的报告数据 - 修改前端相关的回答格式和类型
---
src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java | 159 ++++++++++++++++++++++++-------
src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkService.java | 4
src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java | 72 ++++++++++++++
src/main/java/cc/mrbird/febs/ai/controller/memberTalk/ApiMemberTalkController.java | 36 ++++++
src/main/java/cc/mrbird/febs/ai/req/ai/AiRequestParam.java | 16 +++
src/main/java/cc/mrbird/febs/ai/service/AiService.java | 4
6 files changed, 250 insertions(+), 41 deletions(-)
diff --git a/src/main/java/cc/mrbird/febs/ai/controller/memberTalk/ApiMemberTalkController.java b/src/main/java/cc/mrbird/febs/ai/controller/memberTalk/ApiMemberTalkController.java
index 66212da..ca256ba 100644
--- a/src/main/java/cc/mrbird/febs/ai/controller/memberTalk/ApiMemberTalkController.java
+++ b/src/main/java/cc/mrbird/febs/ai/controller/memberTalk/ApiMemberTalkController.java
@@ -5,10 +5,7 @@
import cc.mrbird.febs.ai.res.memberTalk.ApiMemberTalkVo;
import cc.mrbird.febs.ai.service.AiMemberTalkService;
import cc.mrbird.febs.common.entity.FebsResponse;
-import io.swagger.annotations.Api;
-import io.swagger.annotations.ApiOperation;
-import io.swagger.annotations.ApiResponse;
-import io.swagger.annotations.ApiResponses;
+import io.swagger.annotations.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.validation.annotation.Validated;
@@ -16,6 +13,9 @@
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.io.IOException;
/**
* @author Administrator
@@ -51,4 +51,32 @@
return aiMemberTalkService.answer(dto);
}
+
+ @PostMapping("/start-stream")
+ @ApiOperation("开始AI对话(流式)")
+ @ApiResponses({
+ @ApiResponse(code = 200, message = "流式响应", response = ApiMemberTalkVo.class),
+ @ApiResponse(code = 500, message = "系统错误")
+ })
+ public SseEmitter startStream(
+ @ApiParam(value = "对话请求参数", required = true)
+ @RequestBody ApiMemberTalkDto dto) {
+
+ SseEmitter emitter = new SseEmitter(0L); // 0表示永不超时
+
+ aiMemberTalkService.startStream(dto, response -> {
+ try {
+ emitter.send(SseEmitter.event().data(response));
+ // 如果包含report,说明是最终结果,关闭连接
+ if (response.getCode() != null &&
+ "200".equals(response.getCode())) {
+ emitter.complete();
+ }
+ } catch (IOException e) {
+ emitter.completeWithError(e);
+ }
+ });
+
+ return emitter;
+ }
}
diff --git a/src/main/java/cc/mrbird/febs/ai/req/ai/AiRequestParam.java b/src/main/java/cc/mrbird/febs/ai/req/ai/AiRequestParam.java
new file mode 100644
index 0000000..efe9d13
--- /dev/null
+++ b/src/main/java/cc/mrbird/febs/ai/req/ai/AiRequestParam.java
@@ -0,0 +1,16 @@
+package cc.mrbird.febs.ai.req.ai;
+
+import io.swagger.annotations.ApiModel;
+import lombok.Data;
+
+/**
+ * @author Administrator
+ */
+@Data
+@ApiModel(value = "AiRequestParam", description = "参数")
+public class AiRequestParam {
+
+ private String productRoleId;
+ private String content;
+ private String assistantRole;
+}
diff --git a/src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkService.java b/src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkService.java
index 1101850..f533640 100644
--- a/src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkService.java
+++ b/src/main/java/cc/mrbird/febs/ai/service/AiMemberTalkService.java
@@ -3,11 +3,13 @@
import cc.mrbird.febs.ai.entity.AiMemberTalk;
import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkAnswerDto;
import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkDto;
+import cc.mrbird.febs.ai.res.ai.AiResponse;
import cc.mrbird.febs.common.entity.FebsResponse;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.service.IService;
import java.util.Date;
+import java.util.function.Consumer;
/**
* AI用户对话训练记录 Service接口
@@ -22,6 +24,8 @@
FebsResponse start(ApiMemberTalkDto dto);
+ void startStream(ApiMemberTalkDto dto, Consumer<AiResponse> callback);
+
AiMemberTalk getByQuery(LambdaQueryWrapper<AiMemberTalk> query);
void updateTimeUpdate(Date nowTime, String id);
diff --git a/src/main/java/cc/mrbird/febs/ai/service/AiService.java b/src/main/java/cc/mrbird/febs/ai/service/AiService.java
index 04a87e9..f94fea8 100644
--- a/src/main/java/cc/mrbird/febs/ai/service/AiService.java
+++ b/src/main/java/cc/mrbird/febs/ai/service/AiService.java
@@ -4,6 +4,8 @@
import cc.mrbird.febs.ai.res.ai.AiResponse;
import cc.mrbird.febs.ai.res.ai.Report;
+import java.util.function.Consumer;
+
/**
* @author Administrator
*/
@@ -14,6 +16,8 @@
AiResponse question(AiRequest aiRequest);
+ void streamQuestion(AiRequest aiRequest, Consumer<AiResponse> callback);
+
/**
* 从模型输出中提取并解析报告数据
* @param modelOutput 模型原始输出
diff --git a/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java b/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java
index 6c88c98..db3d4ba 100644
--- a/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java
+++ b/src/main/java/cc/mrbird/febs/ai/service/impl/AiMemberTalkServiceImpl.java
@@ -4,6 +4,7 @@
import cc.mrbird.febs.ai.entity.AiMemberTalkItem;
import cc.mrbird.febs.ai.entity.AiProductRoleLink;
import cc.mrbird.febs.ai.mapper.AiMemberTalkMapper;
+import cc.mrbird.febs.ai.req.ai.AiRequest;
import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkAnswerDto;
import cc.mrbird.febs.ai.req.memberTalk.ApiMemberTalkDto;
import cc.mrbird.febs.ai.res.ai.AiResponse;
@@ -26,6 +27,7 @@
import org.springframework.stereotype.Service;
import java.util.Date;
+import java.util.function.Consumer;
/**
* AI用户对话训练记录 Service实现类
@@ -90,6 +92,76 @@
}
@Override
+ public void startStream(ApiMemberTalkDto dto, Consumer<AiResponse> callback) {
+ try {
+ String memberUuid = LoginUserUtil.getLoginUser().getMemberUuid();
+ String productId = dto.getId();
+
+ LambdaQueryWrapper<AiProductRoleLink> productLinkQuery = Wrappers.lambdaQuery(AiProductRoleLink.class);
+ productLinkQuery.eq(AiProductRoleLink::getProductId, productId);
+ productLinkQuery.last("limit 1");
+ AiProductRoleLink aiProductRoleLink = aiProductRoleLinkService.getByQuery(productLinkQuery);
+ if (ObjectUtil.isNull(aiProductRoleLink)) {
+ AiResponse aiResponse = new AiResponse();
+ aiResponse.setCode("500");
+ aiResponse.setDescription("产品AI陪练不存在");
+ callback.accept(aiResponse);
+ return;
+ }
+
+ // 构造AI请求
+ AiRequest aiRequest = new AiRequest();
+ aiRequest.setLinkId(aiProductRoleLink.getProductRoleId());
+ aiRequest.setPromptTemplate(aiProductRoleLink.getProductRoleId());
+ aiRequest.setContent("<strong>\"生成题目\"</strong>");
+
+ // 定义AI服务回调处理
+ Consumer<AiResponse> aiCallback = aiResponse -> {
+ Date nowTime = new Date();
+ LambdaQueryWrapper<AiMemberTalk> query = Wrappers.lambdaQuery(AiMemberTalk.class);
+ query.eq(AiMemberTalk::getMemberId, memberUuid);
+ query.eq(AiMemberTalk::getProductId, productId);
+ query.last("limit 1");
+ AiMemberTalk aiMemberTalk = this.getByQuery(query);
+ if (ObjectUtil.isNull(aiMemberTalk)) {
+ aiMemberTalk = this.add(memberUuid, productId, nowTime);
+ }
+ try {
+ if (aiResponse.getCode().equals("200")) {
+ // 如果是最终结果(包含报告)
+ if (aiResponse.getReport() != null) {
+ // 保存完整响应到数据库
+ aiMemberTalkItemService.add(memberUuid, aiMemberTalk.getId(), 1, aiResponse.getResContext(), nowTime);
+ this.updateTimeUpdate(nowTime, aiMemberTalk.getId());
+
+ callback.accept(aiResponse);
+ } else {
+ // 流式响应片段
+ callback.accept(aiResponse);
+ }
+ } else {
+ callback.accept(aiResponse);
+ }
+ } catch (Exception e) {
+ log.error("处理AI响应异常", e);
+ callback.accept(aiResponse);
+ }
+ };
+
+ // 调用AI服务的流式接口
+ aiService.streamQuestion(aiRequest, aiCallback);
+
+ } catch (Exception e) {
+ log.error("流式调用start方法异常", e);
+
+ AiResponse aiResponse = new AiResponse();
+ aiResponse.setCode("500");
+ aiResponse.setDescription("流式调用start方法异常");
+ callback.accept(aiResponse);
+ }
+ }
+
+ @Override
public AiMemberTalk getByQuery(LambdaQueryWrapper<AiMemberTalk> query) {
return aiMemberTalkMapper.selectOne( query);
}
diff --git a/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java b/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
index 3ab9f2c..fca39c3 100644
--- a/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
+++ b/src/main/java/cc/mrbird/febs/ai/service/impl/AiServiceImpl.java
@@ -24,6 +24,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@@ -37,8 +38,25 @@
public class AiServiceImpl implements AiService {
private static final String CODE_SUCCESS = "200";
+ private static final String CODE_GOING_ON = "199";
private static final String CODE_NOT_FOUND = "201";
private static final String CODE_ERROR = "500";
+
+ private static final String SCHEMA_JSON = "{\n" +
+ " \"radar_data\": {\n" +
+ " \"problem_understanding\": \"object\",\n" +
+ " \"fluency\": \"object\",\n" +
+ " \"principle_adherence\": \"object\",\n" +
+ " \"logicality\": \"object\",\n" +
+ " \"knowledge_mastery\": \"object\"\n" +
+ " },\n" +
+ " \"evaluation\": {\n" +
+ " \"highlight\": \"object\",\n" +
+ " \"suggestion\": \"object\",\n" +
+ " \"reference_answer\": \"object\",\n" +
+ " \"key_knowledge\": \"object\"\n" +
+ " }\n" +
+ " }";
private final AiProductRoleService aiProductRoleService;
private final ObjectMapper objectMapper;
@@ -56,8 +74,13 @@
@PostConstruct
public void init() {
- ConnectionPool connectionPool = new ConnectionPool(10, 30, TimeUnit.SECONDS);
+ // 增加连接池大小和存活时间
+ ConnectionPool connectionPool = new ConnectionPool(32, 60, TimeUnit.SECONDS);
Dispatcher dispatcher = new Dispatcher();
+ // 增加并发请求数量
+ dispatcher.setMaxRequests(128);
+ dispatcher.setMaxRequestsPerHost(32);
+
this.service = ArkService.builder()
.dispatcher(dispatcher)
.connectionPool(connectionPool)
@@ -100,7 +123,7 @@
aiRequest.setLinkId(linkId);
aiRequest.setContent(content);
- return question(aiRequest);
+ return this.question(aiRequest);
}
@Override
@@ -119,25 +142,8 @@
messages.add(systemMessage);
messages.add(userMessage);
- // 生成 JSON Schema
- String schemaJson = "{\n" +
- " \"radar_data\": {\n" +
- " \"problem_understanding\": \"object\",\n" +
- " \"fluency\": \"object\",\n" +
- " \"principle_adherence\": \"object\",\n" +
- " \"logicality\": \"object\",\n" +
- " \"knowledge_mastery\": \"object\"\n" +
- " },\n" +
- " \"evaluation\": {\n" +
- " \"highlight\": \"object\",\n" +
- " \"suggestion\": \"object\",\n" +
- " \"reference_answer\": \"object\",\n" +
- " \"key_knowledge\": \"object\"\n" +
- " }\n" +
- " }";
try {
- JsonNode schemaNode = objectMapper.readTree(schemaJson);
- // 配置响应格式
+ JsonNode schemaNode = objectMapper.readTree(SCHEMA_JSON);
ChatCompletionRequest.ChatCompletionRequestResponseFormat responseFormat = new ChatCompletionRequest.ChatCompletionRequestResponseFormat(
"json_schema",
new ResponseFormatJSONSchemaJSONSchemaParam(
@@ -152,11 +158,12 @@
.messages(messages)
.stream(false)
.responseFormat(responseFormat)
- .temperature(1.0)
- .topP(0.7)
- .maxTokens(4096)
+ .temperature(0.7) // 降低温度参数,提高确定性,可能提升速度
+ .topP(0.9) // 调整topP参数
+ .maxTokens(2048) // 减少最大token数
.frequencyPenalty(0.0)
.build();
+
List<ChatCompletionChoice> choices = service.createChatCompletion(chatCompletionRequest).getChoices();
String result = choices.stream()
.map(choice -> choice.getMessage().getContent())
@@ -170,10 +177,97 @@
log.error("初始化AI服务失败,JSON格式化输出初始化失败", e);
return buildErrorResponse(CODE_ERROR, "AI服务调用失败");
} catch (Exception e) {
- log.error("调用AI服务失败,modelId: {}, content: {}", linkId, content, e);
+ log.error("调用AI服务失败,modelId: {}", linkId, e);
return buildErrorResponse(CODE_ERROR, "AI服务调用失败");
}
}
+
+ @Override
+ public void streamQuestion(AiRequest aiRequest, Consumer<AiResponse> callback) {
+
+ String promptTemplate = aiRequest.getPromptTemplate();
+ String linkId = aiRequest.getLinkId();
+ String content = aiRequest.getContent();
+ if (!StringUtils.hasText(promptTemplate) || !StringUtils.hasText(linkId) || !StringUtils.hasText(content)) {
+ log.warn("请求参数不完整,promptTemplate: {}, linkId: {}, content: {}", promptTemplate, linkId, content);
+ }
+
+ final List<ChatMessage> messages = new ArrayList<>();
+ final ChatMessage systemMessage = ChatMessage.builder().role(ChatMessageRole.SYSTEM).content(promptTemplate).build();
+ final ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(content).build();
+ messages.add(systemMessage);
+ messages.add(userMessage);
+
+ try {
+ JsonNode schemaNode = objectMapper.readTree(SCHEMA_JSON);
+ ChatCompletionRequest.ChatCompletionRequestResponseFormat responseFormat = new ChatCompletionRequest.ChatCompletionRequestResponseFormat(
+ "json_schema",
+ new ResponseFormatJSONSchemaJSONSchemaParam(
+ "ai_response",
+ "json数据响应",
+ schemaNode,
+ true
+ )
+ );
+ ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
+ .model(linkId)
+ .messages(messages)
+ .stream(true) // 启用流式响应
+ .responseFormat(responseFormat)
+ .temperature(0.7)
+ .topP(0.9)
+ .maxTokens(2048)
+ .build();
+
+ service.streamChatCompletion(chatCompletionRequest)
+ .doOnError(Throwable::printStackTrace) // 处理错误
+ .blockingForEach(response -> {
+ AiResponse partialResponse = new AiResponse();
+ if (response.getChoices() != null && !response.getChoices().isEmpty()) {
+ String responseStr = String.valueOf(response.getChoices().get(0).getMessage().getContent());
+ if (responseStr != null) {
+ // 构造部分响应并回调
+ partialResponse = buildGOINGONResponse(responseStr);
+ }
+ }else{
+ partialResponse = buildPartialResponse("成功");
+ }
+ callback.accept(partialResponse);
+ });
+// service.streamChatCompletion(chatCompletionRequest)
+// .doOnError(throwable -> {
+// log.error("流式调用AI服务失败", throwable);
+// callback.accept(buildErrorResponse(CODE_ERROR, "AI服务调用失败"));
+// })
+// .subscribe(chatCompletionChunk -> {
+// // 处理每个数据块
+// Object chunkContent = chatCompletionChunk.getChoices().get(0).getMessage().getContent();
+// // 构造部分响应并回调
+// AiResponse partialResponse = buildGOINGONResponse(chunkContent);
+// callback.accept(partialResponse);
+// });
+ } catch (Exception e) {
+ log.error("调用AI服务失败", e);
+ callback.accept(buildErrorResponse(CODE_ERROR, "AI服务调用失败"));
+ }
+ }
+
+ private AiResponse buildGOINGONResponse(Object chunkContent) {
+ AiResponse response = new AiResponse();
+ response.setCode(CODE_GOING_ON);
+ response.setDescription("成功");
+ response.setResContext(chunkContent.toString());
+ return response;
+ }
+
+ private AiResponse buildPartialResponse(Object chunkContent) {
+ AiResponse response = new AiResponse();
+ response.setCode(CODE_SUCCESS);
+ response.setDescription("成功");
+ response.setResContext(chunkContent.toString());
+ return response;
+ }
+
private static final Pattern JSON_PATTERN = Pattern.compile(
"<\\|FunctionCallBegin\\|>(.*?)<\\|FunctionCallEnd\\|>",
@@ -182,22 +276,19 @@
@Override
public Report extractReportData(String modelOutput) {
- // 提取JSON部分
Matcher matcher = JSON_PATTERN.matcher(modelOutput);
if (!matcher.find()) {
- log.warn("未匹配到FunctionCall内容,原始输出: {}", modelOutput);
+ log.warn("未匹配到FunctionCall内容,原始输出长度: {}", modelOutput.length());
return null;
}
String jsonContent = matcher.group(1);
- log.debug("提取到的JSON内容: {}", jsonContent);
+ log.debug("提取到的JSON内容长度: {}", jsonContent.length());
- // 解析JSON到Report对象
try {
return objectMapper.readValue(jsonContent, Report.class);
} catch (JsonProcessingException e) {
- log.error("JSON解析失败,原始内容: {}", jsonContent, e);
- // 尝试修复截断的JSON(可选)
+ log.error("JSON解析失败,原始内容长度: {}", jsonContent.length(), e);
Report repairedReport = tryRepairTruncatedJson(jsonContent);
if (repairedReport != null) {
log.info("成功修复截断的JSON");
@@ -207,13 +298,7 @@
}
}
- /**
- * 尝试修复截断的JSON字符串
- * @param truncatedJson 可能被截断的JSON字符串
- * @return 修复后的Report对象,如果无法修复则返回null
- */
private Report tryRepairTruncatedJson(String truncatedJson) {
- // 简单的修复策略:尝试添加缺失的结束括号
String[] repairAttempts = {
truncatedJson + "\"}}}",
truncatedJson + "}}}",
@@ -229,7 +314,7 @@
}
}
- log.warn("无法修复截断的JSON: {}", truncatedJson);
+ log.warn("无法修复截断的JSON,原始内容长度: {}", truncatedJson.length());
return null;
}
--
Gitblit v1.9.1