Administrator
yesterday 49aa4adcf769bb96abcb102e0216a4f32ab3fe92
src/main/java/cc/mrbird/febs/ai/service/impl/AiAgentServiceImpl.java
@@ -1,15 +1,278 @@
package cc.mrbird.febs.ai.service.impl;
import cc.mrbird.febs.ai.entity.AiAgent;
import cc.mrbird.febs.ai.mapper.AiAgentMapper;
import cc.mrbird.febs.ai.entity.*;
import cc.mrbird.febs.ai.enumerates.AiCommonEnum;
import cc.mrbird.febs.ai.enumerates.ProductCategoryLevelEnum;
import cc.mrbird.febs.ai.mapper.*;
import cc.mrbird.febs.ai.req.agent.*;
import cc.mrbird.febs.ai.res.agent.AiAgentInitVo;
import cc.mrbird.febs.ai.res.agent.ApiAgentCategoryVo;
import cc.mrbird.febs.ai.res.agent.ApiAgentVo;
import cc.mrbird.febs.ai.res.product.ApiProductVo;
import cc.mrbird.febs.ai.res.productCategory.ApiProductCategoryVo;
import cc.mrbird.febs.ai.service.AiAgentService;
import cc.mrbird.febs.ai.strategy.LlmStrategyFactory;
import cc.mrbird.febs.ai.strategy.enumerates.LlmStrategyEnum;
import cc.mrbird.febs.ai.utils.UUID;
import cc.mrbird.febs.common.entity.FebsResponse;
import cc.mrbird.febs.common.exception.FebsException;
import cc.mrbird.febs.common.utils.LoginUserUtil;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
@Slf4j
@Service
@RequiredArgsConstructor
public class AiAgentServiceImpl extends ServiceImpl<AiAgentMapper, AiAgent> implements AiAgentService {
    private final AiAgentMapper aiAgentMapper;
    private final AiAgentCategoryMapper aiAgentCategoryMapper;
    private final AiAgentStartQuestionMapper aiAgentStartQuestionMapper;
    private final AiTalkMapper aiTalkMapper;
    private final AiTalkItemMapper aiTalkItemMapper;
    private final AiCompanyMapper aiCompanyMapper;
    private final AiAgentKnowledgeMapper aiAgentKnowledgeMapper;
    private final AiKnowledgeFileMapper aiKnowledgeFileMapper;
    private final LlmStrategyFactory llmStrategyFactory;
    @Override
    public FebsResponse allCategoryList(ApiAgentCategoryAllDto dto) {
        List<ApiAgentCategoryVo> list = new ArrayList<>();
        LambdaQueryWrapper<AiAgentCategory> query = Wrappers.lambdaQuery(AiAgentCategory.class);
        if (StrUtil.isEmpty(dto.getCompanyId())){
            dto.setCompanyId(AiCommonEnum.COMPANY_ID.getPrompt());
        }
        query.eq(AiAgentCategory::getCompanyId, dto.getCompanyId());
        query.eq(AiAgentCategory::getState, 1);
        query.orderByAsc(AiAgentCategory::getSort);
        List<AiAgentCategory> listByQuery = aiAgentCategoryMapper.selectList(query);
        if (CollUtil.isNotEmpty(listByQuery)){
            for (AiAgentCategory entity : listByQuery){
                ApiAgentCategoryVo vo = new ApiAgentCategoryVo();
                vo.setId(entity.getId());
                vo.setName(entity.getName());
                list.add(vo);
            }
        }
        return new FebsResponse().success().data(list);
    }
    @Override
    public FebsResponse agentList(ApiAgentPageDto dto) {
        // 创建分页对象,传入当前页和每页大小
        if (StrUtil.isEmpty(dto.getCompanyId())){
            dto.setCompanyId(AiCommonEnum.COMPANY_ID.getPrompt());
        }
        Page<ApiAgentVo> page = new Page<>(dto.getPageNow(), dto.getPageSize());
        Page<ApiAgentVo> pageListByQuery = aiAgentMapper.getPageListByQuery(page, dto);
        return new FebsResponse().success().data(pageListByQuery);
    }
    @Override
    public FebsResponse initAgent(AiAgentInitDto dto) {
        String id = dto.getId();
        AiAgent aiAgent = aiAgentMapper.selectById(id);
        if (aiAgent == null) {
            throw new FebsException("智能体异常");
        }
        AiAgentInitVo vo = new AiAgentInitVo();
        //将chatWebPlugin复制给apiInitPluginVo
        BeanUtil.copyProperties(aiAgent, vo);
        List<AiAgentStartQuestion> aiAgentStartQuestions = aiAgentStartQuestionMapper.selectList(
                Wrappers.lambdaQuery(AiAgentStartQuestion.class)
                        .select(AiAgentStartQuestion::getTitle)
                        .eq(AiAgentStartQuestion::getAgentId, id)
        );
        if (CollUtil.isNotEmpty(aiAgentStartQuestions)){
            List<String> items = new ArrayList<>();
            for (AiAgentStartQuestion aiAgentStartQuestion : aiAgentStartQuestions) {
                items.add(aiAgentStartQuestion.getTitle());
            }
            vo.setItems( items);
        }
        return new FebsResponse().success().data(vo);
    }
    @Override
    public FebsResponse initSend(AgentInitDto dto) {
        String memberUuid = LoginUserUtil.getLoginUser().getMemberUuid();
        AgentSendInitVo vo = new AgentSendInitVo();
        String agentId = dto.getId();
        //获取智能体信息
        AiAgent aiAgent = aiAgentMapper.selectById(agentId);
        if (aiAgent == null) {
            throw new FebsException("智能体不存在");
        }
        if (aiAgent.getState() != 1){
            throw new FebsException("智能体未启用");
        }
        String companyId = aiAgent.getCompanyId();
        /**
         * 新增一个会话记录
         */
        AiTalk entity = new AiTalk();
        entity.setId(UUID.getSimpleUUIDString());
        entity.setCompanyId(companyId);
        entity.setMemberId(memberUuid);
        entity.setAgentId(agentId);
        entity.setCreatedTime(new Date());
        aiTalkMapper.insert(entity);
        vo.setTalkId(entity.getId());
        return new FebsResponse().success().data(vo);
    }
    @Override
    public FebsResponse saveContext(AgentSaveContextDto dto) {
        String talkId = dto.getTalkId();
        String type = dto.getType();
        String content = dto.getContent();
        AiTalk aiTalk = aiTalkMapper.selectById(talkId);
        if (aiTalk == null) {
            throw new FebsException("会话不存在");
        }
        //保存会话记录
        AiTalkItem aiTalkItem = new AiTalkItem();
        aiTalkItem.setId(UUID.getSimpleUUIDString());
        aiTalkItem.setCompanyId(aiTalk.getCompanyId());
        aiTalkItem.setTalkId(aiTalk.getId());
        aiTalkItem.setType(type);
        aiTalkItem.setContext(content);
        aiTalkItem.setCreatedTime(new Date());
        aiTalkItemMapper.insert(aiTalkItem);
        return new FebsResponse().success();
    }
    @Override
    public Flux<FebsResponse> aiAnswer(AitalkItemStreamDto dto) {
        String talkId = dto.getTalkId();
        String reqContext = dto.getReqContext();
        AiTalk aiTalk = aiTalkMapper.selectById(talkId);
        if (aiTalk == null) {
            throw new FebsException("会话不存在");
        }
        String agentId = aiTalk.getAgentId();
        AiAgent aiAgent = aiAgentMapper.selectById(agentId);
        //判断字符是否足够
        String companyId = aiTalk.getCompanyId();
        AiCompany aiCompany = aiCompanyMapper.selectById(companyId);
        if (aiCompany == null) {
            throw new FebsException("知识库异常");
        }
        //获取智能体绑定的知识库
        List<String> knowledgeIds = new ArrayList<>();
        String knowledgeId = aiCompany.getKnowledgeId();
        knowledgeIds.add(knowledgeId);
        //获取智能体绑定的查询文件
        List<String> fileIds = new ArrayList<>();
        List<AiAgentKnowledge> aiAgentKnowledges = aiAgentKnowledgeMapper.selectList(
                Wrappers.lambdaQuery(AiAgentKnowledge.class)
                        .select(AiAgentKnowledge::getKnowledgeId)
                        .eq(AiAgentKnowledge::getAgentId, agentId)
                        .eq(AiAgentKnowledge::getCompanyId, companyId)
        );
        if (CollUtil.isNotEmpty(aiAgentKnowledges)){
            List<String> aiKnowledgeIds = new ArrayList<>();
            for (AiAgentKnowledge aiAgentKnowledge : aiAgentKnowledges){
                aiKnowledgeIds.add(aiAgentKnowledge.getKnowledgeId());
            }
            if (CollUtil.isNotEmpty(aiKnowledgeIds)){
                List<AiKnowledgeFile> aiKnowledges = aiKnowledgeFileMapper.selectList(
                        Wrappers.lambdaQuery(AiKnowledgeFile.class)
                                .select(AiKnowledgeFile::getFileId)
                                .in(AiKnowledgeFile::getId, aiKnowledgeIds)
                );
                if (CollUtil.isNotEmpty(aiKnowledges)){
                    for (AiKnowledgeFile aiKnowledge : aiKnowledges){
                        fileIds.add(aiKnowledge.getFileId());
                    }
                }
            }
        }
        AiRequestDto aiRequestDto = new AiRequestDto();
        aiRequestDto.setTalkId(talkId);
        String prompt = aiAgent.getPrompt();
        aiRequestDto.setRolePrompt(prompt);
//        List<Message> messages = new ArrayList<>();
//        messages.add(Message.builder().role(Role.SYSTEM.getValue()).content(prompt).build());
        //获取对话记录
        List<Message> messages = new ArrayList<>();
        List<AiTalkItem> aiTalkItemList = aiTalkItemMapper.selectList(
                Wrappers.lambdaQuery(AiTalkItem.class)
                        .eq(AiTalkItem::getTalkId, talkId)
                        .orderByAsc(AiTalkItem::getCreatedTime)
        );
        if (CollUtil.isNotEmpty(aiTalkItemList)){
            messages = getMessages(messages,aiTalkItemList);
        }
        for (
                Message message : messages
        ){
            log.info("上下文内容:{},{}", message.getRole(),message.getContent());
        }
        aiRequestDto.setMessages(messages);
        aiRequestDto.setKnowledgeIds(knowledgeIds);
        aiRequestDto.setFileIds(fileIds);
        aiRequestDto.setPrompt(reqContext);
        aiRequestDto.setCompanyId(companyId);
        //日志输出详细的请求参数的每一个属性
        log.info("请求参数:{}", aiRequestDto.getPrompt());
        log.info("请求参数:{}", aiRequestDto);
        String modelName = LlmStrategyEnum.getName(2);
        return llmStrategyFactory.getCalculationStrategyMap().get(modelName).llmInvokeStreamingNoThink(aiRequestDto);
    }
    private List<Message> getMessages(List<Message> messages, List<AiTalkItem> aiTalkItemList) {
        for (AiTalkItem item : aiTalkItemList){
            if (StrUtil.equals(item.getType(), Role.USER.getValue())){
                messages.add(Message.builder()
                        .role(Role.USER.getValue())
                        .content(item.getContext())
                        .build());
            }
            if (StrUtil.equals(item.getType(),Role.ASSISTANT.getValue())){
                messages.add(Message.builder()
                        .role(Role.ASSISTANT.getValue())
                        .content(item.getContext())
                        .build());
            }
        }
        return messages;
    }
}