Skip to main content

概述

RAG(Retrieval-Augmented Generation,检索增强生成)是LangChat Pro的核心功能之一,通过检索相关知识库内容来增强大语言模型的生成能力,提供更准确、更专业的回答。

架构设计

RAG整体流程

用户查询


┌──────────────────────┐
│  RagRetrieverBuilder │
│   构建RAG组件       │
└──────────┬───────────┘

           ├─────────────────────┬──────────────────┐
           │                     │                  │
           ▼                     ▼                  ▼
┌──────────────────┐  ┌──────────────┐  ┌─────────────────┐
│KnowledgeRetriever│  │SqlRetriever │  │ContentAggregator│
│  知识库检索      │  │ SQL检索      │  │   内容聚合器     │
└────────┬─────────┘  └──────┬───────┘  └────────┬────────┘
         │                    │                   │
         └────────────────────┼───────────────────┘


                    ┌──────────────────────┐
                    │  DefaultQueryRouter │
                    │    查询路由器       │
                    └──────────────────────┘


                    ┌──────────────────────┐
                    │RetrievalAugmentor  │
                    │   检索增强器       │
                    └──────────────────────┘


                    ┌──────────────────────┐
                    │    注入到Prompt     │
                    └──────────────────────┘


                    ┌──────────────────────┐
                    │       LLM         │
                    │    生成回答         │
                    └──────────────────────┘

核心组件

RagRetrieverBuilder

路径: langchat-core/src/main/java/cn/langchat/core/builder/RagRetrieverBuilder.java 职责: 构建RAG检索组件
@Slf4j
@Component
public class RagRetrieverBuilder {

    @Resource
    private KnowledgeRetrieverBuilder knowledgeRetrieverBuilder;

    @Resource
    private SqlRetrieverBuilder sqlRetrieverBuilder;

    @Resource
    private RerankModelFactory rerankModelFactory;
}

build() 方法

/**
 * 构建RAG组件
 *
 * @param req 聊天请求对象
 * @return RAG组件
 */
public RagComponents build(LcChatReq req) {
    log.debug("开始构建RAG组件: userId={}, conversationId={}", 
            req.getUserId(), req.getConversationId());

    var contentRetrievers = new ArrayList<ContentRetriever>();
    var rerankModelIds = new LinkedHashSet<String>();

    // 1. 构建知识库检索器
    var knowledgeResult = knowledgeRetrieverBuilder.build(req);
    contentRetrievers.addAll(knowledgeResult.retrievers());
    rerankModelIds.addAll(knowledgeResult.rerankModelIds());

    log.debug("知识库检索器构建结果: retrieverCount={}, rerankModelCount={}",
            knowledgeResult.retrievers().size(), 
            knowledgeResult.rerankModelIds().size());

    // 2. 构建SQL检索器
    var sqlRetrievers = sqlRetrieverBuilder.build(req);
    contentRetrievers.addAll(sqlRetrievers);

    log.debug("SQL检索器构建结果: retrieverCount={}", 
            sqlRetrievers.size());

    // 3. 构建ContentAggregator
    var aggregator = buildContentAggregator(rerankModelIds);

    log.info("RAG组件构建完成: userId={}, knowledgeRetrieverCount={}, sqlRetrieverCount={}, totalRetrieverCount={}, rerankEnabled={}",
            req.getUserId(),
            knowledgeResult.retrievers().size(),
            sqlRetrievers.size(),
            contentRetrievers.size(),
            aggregator != null);

    return new RagComponents(contentRetrievers, aggregator);
}

/**
 * 构建内容聚合器
 */
private ContentAggregator buildContentAggregator(Set<String> rerankModelIds) {
    if (CollUtil.isEmpty(rerankModelIds)) {
        log.debug("未启用Rerank模型");
        return null;
    }

    // 选择第一个Rerank模型
    var chosenId = rerankModelIds.iterator().next();

    if (rerankModelIds.size() > 1) {
        log.warn("检测到多个Rerank模型: {}, 将使用第一个: {}", 
                rerankModelIds, chosenId);
    }

    try {
        ScoringModel scoringModel = rerankModelFactory.getRerankModel(chosenId);
        var aggregator = ReRankingContentAggregator.builder()
            .scoringModel(scoringModel)
            .build();

        log.info("Rerank聚合器构建成功: modelId={}", chosenId);
        return aggregator;

    } catch (Exception e) {
        log.error("构建Rerank聚合器失败: modelId={}", chosenId, e);
        return null;
    }
}

/**
 * RAG组件
 *
 * @param retrievers 检索器列表
 * @param aggregator 内容聚合器
 */
public record RagComponents(
    List<ContentRetriever> retrievers,
    ContentAggregator aggregator
) {
}

KnowledgeRetriever(知识库检索)

KnowledgeRetrieverBuilder

职责: 构建知识库检索器 构建流程:
public KnowledgeRetrieverResult build(LcChatReq req) {
    var retrievers = new ArrayList<ContentRetriever>();
    var rerankModelIds = new LinkedHashSet<String>();

    // 获取关联的知识库
    var knowledgeIds = req.getKnowledgeIds();
    
    if (CollUtil.isEmpty(knowledgeIds)) {
        return new KnowledgeRetrieverResult(retrievers, rerankModelIds);
    }

    // 遍历每个知识库
    for (var knowledgeId : knowledgeIds) {
        var knowledge = knowledgeService.getById(knowledgeId);
        
        if (knowledge == null) {
            log.warn("知识库不存在: {}", knowledgeId);
            continue;
        }

        // 获取向量库配置
        var vectorStore = vectorStoreService.getById(knowledge.getVectorStoreId());
        
        if (vectorStore == null) {
            log.warn("向量库不存在: {}", knowledge.getVectorStoreId());
            continue;
        }

        // 创建向量检索器
        var retriever = createVectorRetriever(knowledge, vectorStore);
        retrievers.add(retriever);

        // 收集Rerank模型
        if (StrUtil.isNotBlank(knowledge.getRerankModelId())) {
            rerankModelIds.add(knowledge.getRerankModelId());
        }
    }

    return new KnowledgeRetrieverResult(retrievers, rerankModelIds);
}

向量检索流程

1. 用户查询

2. 向量化查询
   - 使用Embedding模型将查询文本转换为向量

3. 向量相似度检索
   - 在向量数据库中搜索最相似的文档片段
   - 支持多种向量库(PGVector, Milvus等)

4. 返回相关文档
   - 返回Top-K最相关的文档片段
   - 包含文档内容和元数据

支持的向量库

向量库说明适用场景
PGVectorPostgreSQL向量扩展中小规模,已有PostgreSQL
Milvus专业向量数据库大规模,高性能
Elasticsearch全文搜索+向量需要混合检索
Redis Vector高速缓存低延迟场景
Neo4j图数据库知识图谱

SqlRetriever(SQL检索)

SqlRetrieverBuilder

职责: 构建SQL检索器 构建流程:
public List<ContentRetriever> build(LcChatReq req) {
    var retrievers = new ArrayList<ContentRetriever>();

    // 获取关联的数据源
    var datasourceIds = req.getDatasources();
    
    if (CollUtil.isEmpty(datasourceIds)) {
        return retrievers;
    }

    // 遍历每个数据源
    for (var datasourceId : datasourceIds) {
        var datasource = datasourceService.getById(datasourceId);
        
        if (datasource == null) {
            log.warn("数据源不存在: {}", datasourceId);
            continue;
        }

        // 创建SQL检索器
        var retriever = createSqlRetriever(datasource);
        retrievers.add(retriever);
    }

    return retrievers;
}

SQL检索流程

1. 用户查询

2. Text2SQL生成
   - 使用LLM将自然语言转换为SQL

3. SQL执行
   - 执行SQL查询数据库

4. 结果处理
   - 将查询结果转换为文本

5. 返回检索结果

ContentAggregator(内容聚合)

ReRankingContentAggregator

职责: 对检索结果进行重排序和聚合 工作流程:
1. 接收多个检索器的结果

2. 合并结果
   - 去重
   - 保留元数据

3. Rerank重排序
   - 使用Rerank模型对结果重新打分
   - 根据相似度重新排序

4. 选择Top-N结果
   - 选择最相关的N个结果

5. 返回最终结果
代码示例:
var aggregator = ReRankingContentAggregator.builder()
    .scoringModel(scoringModel)
    .build();

DefaultQueryRouter(查询路由)

职责

  • 将查询分发到多个检索器
  • 合并所有检索器的结果
  • 调用ContentAggregator进行聚合
代码示例:
var queryRouter = new DefaultQueryRouter(contentRetrievers);

var augmentor = DefaultRetrievalAugmentor.builder()
    .contentAggregator(aggregator)
    .queryRouter(queryRouter)
    .build();

Rerank模型

RerankModelFactory

职责: 创建Rerank模型实例 支持的Rerank模型:
  • BGE-Reranker
  • Cohere Rerank
  • 其他兼容的Rerank模型
创建流程:
public ScoringModel getRerankModel(String modelId) {
    // 获取模型配置
    var model = aigcModelService.getById(modelId);
    
    // 根据供应商创建对应的Rerank模型
    return switch (model.getProvider().toLowerCase()) {
        case ProviderConst.bge -> createBgeRerankModel(model);
        case ProviderConst.cohere -> createCohereRerankModel(model);
        default -> createDefaultRerankModel(model);
    };
}

完整RAG流程示例

请求构建

// 1. 构建RAG组件
var ragComponents = ragRetrieverBuilder.build(req);

// 2. 配置到AI Service
if (CollUtil.isNotEmpty(ragComponents.retrievers())) {
    var augmentor = DefaultRetrievalAugmentor.builder()
        .contentAggregator(ragComponents.aggregator())
        .queryRouter(new DefaultQueryRouter(ragComponents.retrievers()))
        .build();
    builder.retrievalAugmentor(augmentor);
}

执行流程

用户提问:"LangChat是什么?"

[1] 向量化查询
    - Embedding模型: "text-embedding-ada-002"
    - 查询向量: [0.1, 0.2, 0.3, ...]

[2] 向量检索
    - 向量库: PGVector
    - 检索Top-5文档

[3] 检索结果
    - Doc1: "LangChat是一个AI应用平台..."
    - Doc2: "LangChat支持多种LLM模型..."
    - Doc3: "LangChat采用LangChain4j..."
    - Doc4: "LangChat支持知识库检索..."
    - Doc5: "LangChat提供工作流编排..."

[4] Rerank重排序
    - Rerank模型: bge-reranker-large
    - 重新打分和排序

[5] 选择Top-3
    - Doc1: score 0.95
    - Doc2: score 0.92
    - Doc4: score 0.88

[6] 注入Prompt
    - 系统提示词: "你是一个AI助手..."
    - 检索内容: Doc1 + Doc2 + Doc4
    - 用户问题: "LangChat是什么?"

[7] LLM生成
    - 根据上下文生成准确回答

[8] 返回答案
    "LangChat是一个基于Spring Boot和LangChain4j的AI应用平台..."

配置说明

知识库配置

public class AigcKnowledge {
    private String id;
    private String name;
    private String vectorStoreId;      // 向量库ID
    private String rerankModelId;      // Rerank模型ID
    private Integer topK;              // 检索数量
    private Double similarityThreshold;  // 相似度阈值
}

向量库配置

public class AigcVectorStore {
    private String id;
    private String name;
    private String type;              // 类型: pgvector, milvus, es等
    private String host;
    private Integer port;
    private String database;
    private String collection;          // 集合/表名
    private String apiKey;
}

application.yml

langchat:
  rag:
    default-top-k: 5               # 默认检索数量
    default-similarity: 0.7        # 默认相似度阈值
    enable-rerank: true            # 是否启用Rerank

扩展点

1. 添加新的向量库

步骤1: 实现向量检索器
public class CustomVectorRetriever implements ContentRetriever {
    
    private final VectorStore vectorStore;
    private final EmbeddingModel embeddingModel;
    
    public CustomVectorRetriever(VectorStore vectorStore, EmbeddingModel embeddingModel) {
        this.vectorStore = vectorStore;
        this.embeddingModel = embeddingModel;
    }
    
    @Override
    public List<Content> retrieve(Query query) {
        // 1. 向量化查询
        TextSegment textSegment = TextSegment.from(query.text());
        Embedding embedding = embeddingModel.embed(textSegment).content();
        
        // 2. 向量检索
        List<EmbeddingMatch<TextSegment>> matches = vectorStore.findRelevant(embedding, 5);
        
        // 3. 转换为Content
        return matches.stream()
            .map(match -> Content.from(match.embedded().text()))
            .collect(Collectors.toList());
    }
}
步骤2: 在KnowledgeRetrieverBuilder中注册
private ContentRetriever createVectorRetriever(
    AigcKnowledge knowledge, 
    AigcVectorStore vectorStore
) {
    // 根据向量库类型创建检索器
    return switch (vectorStore.getType().toLowerCase()) {
        case "custom" -> new CustomVectorRetriever(vectorStore, embeddingModel);
        default -> throw new IllegalArgumentException("不支持的向量库类型: " + vectorStore.getType());
    };
}

2. 自定义ContentAggregator

@Component
public class CustomContentAggregator implements ContentAggregator {
    
    @Resource
    private ScoringModel rerankModel;
    
    @Override
    public List<Content> aggregate(Query query, List<Retrieval> retrievals) {
        // 1. 合并所有检索结果
        List<Content> allContents = retrievals.stream()
            .flatMap(retrieval -> retrieval.contents().stream())
            .collect(Collectors.toList());
        
        // 2. 去重
        Set<String> seen = new HashSet<>();
        List<Content> uniqueContents = allContents.stream()
            .filter(content -> seen.add(content.textSegment().text()))
            .collect(Collectors.toList());
        
        // 3. Rerank
        List<ScoredText> scoredTexts = rerankModel.scoreAll(query.text(), uniqueContents);
        
        // 4. 排序并返回Top-N
        return scoredTexts.stream()
            .sorted((a, b) -> Double.compare(b.score(), a.score()))
            .limit(10)
            .map(scored -> Content.from(scored.text()))
            .collect(Collectors.toList());
    }
}

3. 自定义SqlRetriever

public class CustomSqlRetriever implements ContentRetriever {
    
    private final DataSource dataSource;
    private final ChatModel text2SqlModel;
    
    @Override
    public List<Content> retrieve(Query query) {
        // 1. 生成SQL
        String sql = generateSql(query.text());
        
        // 2. 执行SQL
        List<Map<String, Object>> results = executeSql(sql);
        
        // 3. 转换为Content
        return results.stream()
            .map(result -> Content.from(result.toString()))
            .collect(Collectors.toList());
    }
    
    private String generateSql(String question) {
        // 使用LLM生成SQL
        String prompt = String.format("""
            请将以下自然语言问题转换为SQL查询:
            问题:%s
            数据库schema:%s
            只返回SQL语句,不要其他内容。
            """, question, getDatabaseSchema());
        
        return text2SqlModel.chat(prompt);
    }
    
    private List<Map<String, Object>> executeSql(String sql) {
        // 执行SQL查询
        try (Connection conn = dataSource.getConnection();
             Statement stmt = conn.createStatement();
             ResultSet rs = stmt.executeQuery(sql)) {
            
            ResultSetMetaData metaData = rs.getMetaData();
            int columnCount = metaData.getColumnCount();
            List<Map<String, Object>> results = new ArrayList<>();
            
            while (rs.next()) {
                Map<String, Object> row = new HashMap<>();
                for (int i = 1; i <= columnCount; i++) {
                    row.put(metaData.getColumnName(i), rs.getObject(i));
                }
                results.add(row);
            }
            
            return results;
        } catch (SQLException e) {
            throw new RuntimeException("SQL执行失败", e);
        }
    }
}

最佳实践

1. 向量库选择

场景推荐向量库原因
中小规模PGVector简单,已有PostgreSQL
大规模Milvus高性能,专业向量数据库
混合检索Elasticsearch支持全文+向量
低延迟Redis Vector内存缓存,极快

2. 检索参数调优

  • Top-K: 通常5-10,过多影响性能
  • 相似度阈值: 0.7-0.8,根据实际情况调整
  • 分块大小: 500-1000字符,过大影响精度,过小丢失上下文

3. Rerank使用

  • 提高检索准确性
  • 增加延迟,根据场景选择
  • 适用于高准确性要求的场景

4. 混合检索

  • 结合向量检索和关键词检索
  • 提高召回率
  • Elasticsearch天然支持

性能优化

1. 向量检索优化

  • 建立向量索引
  • 使用近似检索(IVF, HNSW等)
  • 缓存热门查询结果

2. 并行检索

// 并行执行多个检索器
List<Retrieval> retrievals = contentRetrievers.parallelStream()
    .map(retriever -> Retrieval.from(retriever.retrieve(query)))
    .collect(Collectors.toList());

3. 缓存策略

  • 缓存Embedding结果
  • 缓存Rerank结果
  • 使用Redis缓存热门查询

4. 批量处理

  • 批量向量化
  • 批量检索
  • 减少网络往返

参考文档