Administrator
2 days ago 89c8099e23ef5260ce6b3e46339064c559e9cc0f
src/main/java/cc/mrbird/febs/ai/service/impl/AiAgentServiceImpl.java
@@ -3,23 +3,25 @@
import cc.mrbird.febs.ai.entity.*;
import cc.mrbird.febs.ai.enumerates.AiCommonEnum;
import cc.mrbird.febs.ai.enumerates.ProductCategoryLevelEnum;
import cc.mrbird.febs.ai.mapper.AiAgentCategoryMapper;
import cc.mrbird.febs.ai.mapper.AiAgentMapper;
import cc.mrbird.febs.ai.mapper.AiAgentStartQuestionMapper;
import cc.mrbird.febs.ai.req.agent.AiAgentInitDto;
import cc.mrbird.febs.ai.req.agent.ApiAgentCategoryAllDto;
import cc.mrbird.febs.ai.req.agent.ApiAgentPageDto;
import cc.mrbird.febs.ai.mapper.*;
import cc.mrbird.febs.ai.req.agent.*;
import cc.mrbird.febs.ai.res.agent.AiAgentInitVo;
import cc.mrbird.febs.ai.res.agent.ApiAgentCategoryVo;
import cc.mrbird.febs.ai.res.agent.ApiAgentVo;
import cc.mrbird.febs.ai.res.product.ApiProductVo;
import cc.mrbird.febs.ai.res.productCategory.ApiProductCategoryVo;
import cc.mrbird.febs.ai.service.AiAgentService;
import cc.mrbird.febs.ai.strategy.LlmStrategyFactory;
import cc.mrbird.febs.ai.strategy.enumerates.LlmStrategyEnum;
import cc.mrbird.febs.ai.utils.UUID;
import cc.mrbird.febs.common.entity.FebsResponse;
import cc.mrbird.febs.common.exception.FebsException;
import cc.mrbird.febs.common.utils.LoginUserUtil;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
@@ -27,8 +29,10 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
@Slf4j
@@ -39,6 +43,12 @@
    private final AiAgentMapper aiAgentMapper;
    private final AiAgentCategoryMapper aiAgentCategoryMapper;
    private final AiAgentStartQuestionMapper aiAgentStartQuestionMapper;
    private final AiTalkMapper aiTalkMapper;
    private final AiTalkItemMapper aiTalkItemMapper;
    private final AiCompanyMapper aiCompanyMapper;
    private final AiAgentKnowledgeMapper aiAgentKnowledgeMapper;
    private final AiKnowledgeFileMapper aiKnowledgeFileMapper;
    private final LlmStrategyFactory llmStrategyFactory;
    @Override
    public FebsResponse allCategoryList(ApiAgentCategoryAllDto dto) {
@@ -98,4 +108,167 @@
        return new FebsResponse().success().data(vo);
    }
    @Override
    public FebsResponse initSend(AgentInitDto dto) {
        String memberUuid = LoginUserUtil.getLoginUser().getMemberUuid();
        AgentSendInitVo vo = new AgentSendInitVo();
        String agentId = dto.getId();
        //获取智能体信息
        AiAgent aiAgent = aiAgentMapper.selectById(agentId);
        if (aiAgent == null) {
            throw new FebsException("智能体不存在");
        }
        if (aiAgent.getState() != 1){
            throw new FebsException("智能体未启用");
        }
        String companyId = aiAgent.getCompanyId();
        /**
         * 新增一个会话记录
         */
        AiTalk entity = new AiTalk();
        entity.setId(UUID.getSimpleUUIDString());
        entity.setCompanyId(companyId);
        entity.setMemberId(memberUuid);
        entity.setAgentId(agentId);
        entity.setCreatedTime(new Date());
        aiTalkMapper.insert(entity);
        vo.setTalkId(entity.getId());
        return new FebsResponse().success().data(vo);
    }
    @Override
    public FebsResponse saveContext(AgentSaveContextDto dto) {
        String talkId = dto.getTalkId();
        String type = dto.getType();
        String content = dto.getContent();
        AiTalk aiTalk = aiTalkMapper.selectById(talkId);
        if (aiTalk == null) {
            throw new FebsException("会话不存在");
        }
        //保存会话记录
        AiTalkItem aiTalkItem = new AiTalkItem();
        aiTalkItem.setId(UUID.getSimpleUUIDString());
        aiTalkItem.setCompanyId(aiTalk.getCompanyId());
        aiTalkItem.setTalkId(aiTalk.getId());
        aiTalkItem.setType(type);
        aiTalkItem.setContext(content);
        aiTalkItem.setCreatedTime(new Date());
        aiTalkItemMapper.insert(aiTalkItem);
        return new FebsResponse().success();
    }
    @Override
    public Flux<FebsResponse> aiAnswer(AitalkItemStreamDto dto) {
        String talkId = dto.getTalkId();
        String reqContext = dto.getReqContext();
        AiTalk aiTalk = aiTalkMapper.selectById(talkId);
        if (aiTalk == null) {
            throw new FebsException("会话不存在");
        }
        String agentId = aiTalk.getAgentId();
        AiAgent aiAgent = aiAgentMapper.selectById(agentId);
        //判断字符是否足够
        String companyId = aiTalk.getCompanyId();
        AiCompany aiCompany = aiCompanyMapper.selectById(companyId);
        if (aiCompany == null) {
            throw new FebsException("知识库异常");
        }
        //获取智能体绑定的知识库
        List<String> knowledgeIds = new ArrayList<>();
        String knowledgeId = aiCompany.getKnowledgeId();
        knowledgeIds.add(knowledgeId);
        //获取智能体绑定的查询文件
        List<String> fileIds = new ArrayList<>();
        List<AiAgentKnowledge> aiAgentKnowledges = aiAgentKnowledgeMapper.selectList(
                Wrappers.lambdaQuery(AiAgentKnowledge.class)
                        .select(AiAgentKnowledge::getKnowledgeId)
                        .eq(AiAgentKnowledge::getAgentId, agentId)
                        .eq(AiAgentKnowledge::getCompanyId, companyId)
        );
        if (CollUtil.isNotEmpty(aiAgentKnowledges)){
            List<String> aiKnowledgeIds = new ArrayList<>();
            for (AiAgentKnowledge aiAgentKnowledge : aiAgentKnowledges){
                aiKnowledgeIds.add(aiAgentKnowledge.getKnowledgeId());
            }
            if (CollUtil.isNotEmpty(aiKnowledgeIds)){
                List<AiKnowledgeFile> aiKnowledges = aiKnowledgeFileMapper.selectList(
                        Wrappers.lambdaQuery(AiKnowledgeFile.class)
                                .select(AiKnowledgeFile::getFileId)
                                .in(AiKnowledgeFile::getId, aiKnowledgeIds)
                );
                if (CollUtil.isNotEmpty(aiKnowledges)){
                    for (AiKnowledgeFile aiKnowledge : aiKnowledges){
                        fileIds.add(aiKnowledge.getFileId());
                    }
                }
            }
        }
        AiRequestDto aiRequestDto = new AiRequestDto();
        aiRequestDto.setTalkId(talkId);
        String prompt = aiAgent.getPrompt();
        aiRequestDto.setRolePrompt(prompt);
//        List<Message> messages = new ArrayList<>();
//        messages.add(Message.builder().role(Role.SYSTEM.getValue()).content(prompt).build());
        //获取对话记录
        List<Message> messages = new ArrayList<>();
        List<AiTalkItem> aiTalkItemList = aiTalkItemMapper.selectList(
                Wrappers.lambdaQuery(AiTalkItem.class)
                        .eq(AiTalkItem::getTalkId, talkId)
                        .orderByAsc(AiTalkItem::getCreatedTime)
        );
        if (CollUtil.isNotEmpty(aiTalkItemList)){
            messages = getMessages(messages,aiTalkItemList);
        }
        for (
                Message message : messages
        ){
            log.info("上下文内容:{},{}", message.getRole(),message.getContent());
        }
        aiRequestDto.setMessages(messages);
        aiRequestDto.setKnowledgeIds(knowledgeIds);
        aiRequestDto.setFileIds(fileIds);
        aiRequestDto.setPrompt(reqContext);
        aiRequestDto.setCompanyId(companyId);
        //日志输出详细的请求参数的每一个属性
        log.info("请求参数:{}", aiRequestDto.getPrompt());
        log.info("请求参数:{}", aiRequestDto);
        String modelName = LlmStrategyEnum.getName(2);
        return llmStrategyFactory.getCalculationStrategyMap().get(modelName).llmInvokeStreamingNoThink(aiRequestDto);
    }
    private List<Message> getMessages(List<Message> messages, List<AiTalkItem> aiTalkItemList) {
        for (AiTalkItem item : aiTalkItemList){
            if (StrUtil.equals(item.getType(), Role.USER.getValue())){
                messages.add(Message.builder()
                        .role(Role.USER.getValue())
                        .content(item.getContext())
                        .build());
            }
            if (StrUtil.equals(item.getType(),Role.ASSISTANT.getValue())){
                messages.add(Message.builder()
                        .role(Role.ASSISTANT.getValue())
                        .content(item.getContext())
                        .build());
            }
        }
        return messages;
    }
}