Administrator
6 days ago 485c6557ae50afe6703c0b64169ce8eb634b1924
src/main/java/cc/mrbird/febs/yinhe/service/impl/YhAiServiceImpl.java
@@ -1,10 +1,15 @@
package cc.mrbird.febs.yinhe.service.impl;
import cc.mrbird.febs.ai.strategy.LlmStrategyFactory;
import cc.mrbird.febs.ai.strategy.enumerates.LlmStrategyEnum;
import cc.mrbird.febs.ai.utils.UUID;
import cc.mrbird.febs.common.entity.FebsResponse;
import cc.mrbird.febs.common.exception.FebsException;
import cc.mrbird.febs.common.utils.AppContants;
import cc.mrbird.febs.common.utils.LoginUserUtil;
import cc.mrbird.febs.common.utils.RedisUtils;
import cc.mrbird.febs.common.utils.YHLoginUserUtil;
import cc.mrbird.febs.rabbit.producter.AgentProducer;
import cc.mrbird.febs.yinhe.entity.*;
import cc.mrbird.febs.yinhe.mapper.*;
import cc.mrbird.febs.yinhe.req.*;
@@ -13,14 +18,18 @@
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
@Slf4j
@@ -34,12 +43,21 @@
    private final YHAiAgentCategoryMapper yhAiAgentCategoryMapper;
    private final YHAiAgentMapper yhAiAgentMapper;
    private final YHAiAgentStartQuestionMapper yhAiAgentStartQuestionMapper;
    private final YHAiAgentKnowledgeMapper yhAiAgentKnowledgeMapper;
    private final YHAiKnowledgeMapper yhAiKnowledgeMapper;
    private final YhSysAgentCategoryMapper yhSysAgentCategoryMapper;
    private final YHSysAgentStartQuestionMapper yhSysAgentStartQuestionMapper;
    private final YHSysAgentMapper yhSysAgentMapper;
    private final YHSysCompanyLevelMapper yhSysCompanyLevelMapper;
    private final YHAiTalkMapper yhAiTalkMapper;
    private final YHAiTalkItemMapper yhAiTalkItemMapper;
    private final RedisUtils redisUtils;
    private final AgentProducer agentProducer;
    private final LlmStrategyFactory llmStrategyFactory;
    @Override
    public FebsResponse memberInfo() {
@@ -246,4 +264,181 @@
        redisUtils.del(AppContants.XCX_LOGIN_PREFIX + memberUuid);
        return new FebsResponse().success().message("退出登录");
    }
    @Override
    public FebsResponse initSend(YHSendInitDto dto) {
        String memberUuid = LoginUserUtil.getLoginUser().getMemberUuid();
        YHSendInitVo vo = new YHSendInitVo();
        String agentId = dto.getId();
        //获取智能体信息
        YHAiAgent yhAiAgent = yhAiAgentMapper.selectById(agentId);
        if (yhAiAgent == null) {
            throw new FebsException("智能体不存在");
        }
        if (yhAiAgent.getState() != 1){
            throw new FebsException("智能体未启用");
        }
        String companyId = yhAiAgent.getCompanyId();
        /**
         * 新增一个会话记录
         */
        YHAiTalk entity = new YHAiTalk();
        entity.setId(UUID.getSimpleUUIDString());
        entity.setCompanyId(companyId);
        entity.setCustomerId(memberUuid);
        entity.setAgentId(yhAiAgent.getId());
        entity.setType(2);
        entity.setCreateTime(new Date());
        yhAiTalkMapper.insert(entity);
        vo.setTalkId(entity.getId());
        return new FebsResponse().success().data(vo);
    }
    @Override
    public FebsResponse saveContext(YHSaveContextDto dto) {
        String talkId = dto.getTalkId();
        String type = dto.getType();
        String content = dto.getContent();
        YHAiTalk aiTalk = yhAiTalkMapper.selectById(talkId);
        if (aiTalk == null) {
            throw new FebsException("会话不存在");
        }
        //保存会话记录
        YHAiTalkItem aiTalkItem = new YHAiTalkItem();
        aiTalkItem.setId(UUID.getSimpleUUIDString());
        aiTalkItem.setCompanyId(aiTalk.getCompanyId());
        aiTalkItem.setTalkId(aiTalk.getId());
        aiTalkItem.setType(type);
        aiTalkItem.setContext(content);
        aiTalkItem.setTokenNum(content.length());
        aiTalkItem.setCreateTime(new Date());
        yhAiTalkItemMapper.insert(aiTalkItem);
        agentProducer.sendAddCompanyToken(aiTalkItem.getId());
        return new FebsResponse().success();
    }
    @Override
    public Flux<FebsResponse> aiAnswer(YHAitalkItemStreamDto dto) {
        String talkId = dto.getTalkId();
        String reqContext = dto.getReqContext();
        YHAiTalk aiTalk = yhAiTalkMapper.selectById(talkId);
        if (aiTalk == null) {
            throw new FebsException("会话不存在");
        }
        //获取智能体的信息
        String agentId = aiTalk.getAgentId();
        YHAiAgent aiAgent = yhAiAgentMapper.selectById(agentId);
        if (aiAgent == null) {
            throw new FebsException("智能体异常");
        }
        //判断字符是否足够
        String companyId = aiTalk.getCompanyId();
        YHAiCompany aiCompany = yhAiCompanyMapper.selectById(companyId);
        if (aiCompany == null) {
            throw new FebsException("知识库异常");
        }
        Integer useToken = aiCompany.getUseToken();
        YHSysCompanyLevel sysCompanyLevel = yhSysCompanyLevelMapper.selectOne(
                Wrappers.lambdaQuery(YHSysCompanyLevel.class)
                        .select(YHSysCompanyLevel::getToken)
                        .eq(YHSysCompanyLevel::getCode, aiCompany.getLevelCode())
        );
        if (useToken > sysCompanyLevel.getToken()) {
            throw new FebsException("字符已消耗完");
        }
        //获取智能体绑定的知识库
        List<String> knowledgeIds = new ArrayList<>();
        String knowledgeId = aiCompany.getKnowledgeId();
        knowledgeIds.add(knowledgeId);
        //获取智能体绑定的查询文件
        List<String> fileIds = new ArrayList<>();
        List<YHAiAgentKnowledge> aiAgentKnowledges = yhAiAgentKnowledgeMapper.selectList(
                Wrappers.lambdaQuery(YHAiAgentKnowledge.class)
                        .select(YHAiAgentKnowledge::getKnowledgeId)
                        .eq(YHAiAgentKnowledge::getAgentId, agentId)
                        .eq(YHAiAgentKnowledge::getCompanyId, companyId)
        );
        if (CollUtil.isNotEmpty(aiAgentKnowledges)){
            List<String> aiKnowledgeIds = new ArrayList<>();
            for (YHAiAgentKnowledge aiAgentKnowledge : aiAgentKnowledges){
                aiKnowledgeIds.add(aiAgentKnowledge.getKnowledgeId());
            }
            if (CollUtil.isNotEmpty(aiKnowledgeIds)){
                List<YHAiKnowledge> aiKnowledges = yhAiKnowledgeMapper.selectList(
                        Wrappers.lambdaQuery(YHAiKnowledge.class)
                                .select(YHAiKnowledge::getFileId)
                                .in(YHAiKnowledge::getId, aiKnowledgeIds)
                );
                if (CollUtil.isNotEmpty(aiKnowledges)){
                    for (YHAiKnowledge aiKnowledge : aiKnowledges){
                        fileIds.add(aiKnowledge.getFileId());
                    }
                }
            }
        }
        AiRequestDto aiRequestDto = new AiRequestDto();
        aiRequestDto.setTalkId(talkId);
        String prompt = aiAgent.getPrompt();
        aiRequestDto.setRolePrompt(prompt);
//        List<Message> messages = new ArrayList<>();
//        messages.add(Message.builder().role(Role.SYSTEM.getValue()).content(prompt).build());
        //获取对话记录
        List<Message> messages = new ArrayList<>();
        List<YHAiTalkItem> aiTalkItemList = yhAiTalkItemMapper.selectList(
                Wrappers.lambdaQuery(YHAiTalkItem.class)
                        .eq(YHAiTalkItem::getTalkId, talkId)
                        .orderByAsc(YHAiTalkItem::getCreateTime)
        );
        if (CollUtil.isNotEmpty(aiTalkItemList)){
            messages = getMessages(messages,aiTalkItemList);
        }
        for (
                Message message : messages
        ){
            log.info("上下文内容:{},{}", message.getRole(),message.getContent());
        }
        aiRequestDto.setMessages(messages);
        aiRequestDto.setKnowledgeIds(knowledgeIds);
        aiRequestDto.setFileIds(fileIds);
        aiRequestDto.setPrompt(reqContext);
        aiRequestDto.setCompanyId(companyId);
        //日志输出详细的请求参数的每一个属性
        log.info("请求参数:{}", aiRequestDto.getPrompt());
        log.info("请求参数:{}", aiRequestDto);
        String modelName = LlmStrategyEnum.getName(2);
        return llmStrategyFactory.getCalculationStrategyMap().get(modelName).llmInvokeStreamingNoThink(aiRequestDto);
    }
    private List<Message> getMessages(List<Message> messages,List<YHAiTalkItem> aiTalkItemList) {
        for (YHAiTalkItem item : aiTalkItemList){
            if (StrUtil.equals(item.getType(), Role.USER.getValue())){
                messages.add(Message.builder()
                        .role(Role.USER.getValue())
                        .content(item.getContext())
                        .build());
            }
            if (StrUtil.equals(item.getType(),Role.ASSISTANT.getValue())){
                messages.add(Message.builder()
                        .role(Role.ASSISTANT.getValue())
                        .content(item.getContext())
                        .build());
            }
        }
        return messages;
    }
}