概述
RAG(Retrieval-Augmented Generation,检索增强生成)是LangChat Pro的核心功能之一,通过检索相关知识库内容来增强大语言模型的生成能力,提供更准确、更专业的回答。架构设计
RAG整体流程
Copy
用户查询
│
▼
┌──────────────────────┐
│ RagRetrieverBuilder │
│ 构建RAG组件 │
└──────────┬───────────┘
│
├─────────────────────┬──────────────────┐
│ │ │
▼ ▼ ▼
┌──────────────────┐ ┌──────────────┐ ┌─────────────────┐
│KnowledgeRetriever│ │SqlRetriever │ │ContentAggregator│
│ 知识库检索 │ │ SQL检索 │ │ 内容聚合器 │
└────────┬─────────┘ └──────┬───────┘ └────────┬────────┘
│ │ │
└────────────────────┼───────────────────┘
│
▼
┌──────────────────────┐
│ DefaultQueryRouter │
│ 查询路由器 │
└──────────────────────┘
│
▼
┌──────────────────────┐
│RetrievalAugmentor │
│ 检索增强器 │
└──────────────────────┘
│
▼
┌──────────────────────┐
│ 注入到Prompt │
└──────────────────────┘
│
▼
┌──────────────────────┐
│ LLM │
│ 生成回答 │
└──────────────────────┘
核心组件
RagRetrieverBuilder
路径:langchat-core/src/main/java/cn/langchat/core/builder/RagRetrieverBuilder.java
职责: 构建RAG检索组件
Copy
@Slf4j
@Component
public class RagRetrieverBuilder {
@Resource
private KnowledgeRetrieverBuilder knowledgeRetrieverBuilder;
@Resource
private SqlRetrieverBuilder sqlRetrieverBuilder;
@Resource
private RerankModelFactory rerankModelFactory;
}
build() 方法
Copy
/**
* 构建RAG组件
*
* @param req 聊天请求对象
* @return RAG组件
*/
public RagComponents build(LcChatReq req) {
log.debug("开始构建RAG组件: userId={}, conversationId={}",
req.getUserId(), req.getConversationId());
var contentRetrievers = new ArrayList<ContentRetriever>();
var rerankModelIds = new LinkedHashSet<String>();
// 1. 构建知识库检索器
var knowledgeResult = knowledgeRetrieverBuilder.build(req);
contentRetrievers.addAll(knowledgeResult.retrievers());
rerankModelIds.addAll(knowledgeResult.rerankModelIds());
log.debug("知识库检索器构建结果: retrieverCount={}, rerankModelCount={}",
knowledgeResult.retrievers().size(),
knowledgeResult.rerankModelIds().size());
// 2. 构建SQL检索器
var sqlRetrievers = sqlRetrieverBuilder.build(req);
contentRetrievers.addAll(sqlRetrievers);
log.debug("SQL检索器构建结果: retrieverCount={}",
sqlRetrievers.size());
// 3. 构建ContentAggregator
var aggregator = buildContentAggregator(rerankModelIds);
log.info("RAG组件构建完成: userId={}, knowledgeRetrieverCount={}, sqlRetrieverCount={}, totalRetrieverCount={}, rerankEnabled={}",
req.getUserId(),
knowledgeResult.retrievers().size(),
sqlRetrievers.size(),
contentRetrievers.size(),
aggregator != null);
return new RagComponents(contentRetrievers, aggregator);
}
/**
* 构建内容聚合器
*/
private ContentAggregator buildContentAggregator(Set<String> rerankModelIds) {
if (CollUtil.isEmpty(rerankModelIds)) {
log.debug("未启用Rerank模型");
return null;
}
// 选择第一个Rerank模型
var chosenId = rerankModelIds.iterator().next();
if (rerankModelIds.size() > 1) {
log.warn("检测到多个Rerank模型: {}, 将使用第一个: {}",
rerankModelIds, chosenId);
}
try {
ScoringModel scoringModel = rerankModelFactory.getRerankModel(chosenId);
var aggregator = ReRankingContentAggregator.builder()
.scoringModel(scoringModel)
.build();
log.info("Rerank聚合器构建成功: modelId={}", chosenId);
return aggregator;
} catch (Exception e) {
log.error("构建Rerank聚合器失败: modelId={}", chosenId, e);
return null;
}
}
/**
* RAG组件
*
* @param retrievers 检索器列表
* @param aggregator 内容聚合器
*/
public record RagComponents(
List<ContentRetriever> retrievers,
ContentAggregator aggregator
) {
}
KnowledgeRetriever(知识库检索)
KnowledgeRetrieverBuilder
职责: 构建知识库检索器 构建流程:Copy
public KnowledgeRetrieverResult build(LcChatReq req) {
var retrievers = new ArrayList<ContentRetriever>();
var rerankModelIds = new LinkedHashSet<String>();
// 获取关联的知识库
var knowledgeIds = req.getKnowledgeIds();
if (CollUtil.isEmpty(knowledgeIds)) {
return new KnowledgeRetrieverResult(retrievers, rerankModelIds);
}
// 遍历每个知识库
for (var knowledgeId : knowledgeIds) {
var knowledge = knowledgeService.getById(knowledgeId);
if (knowledge == null) {
log.warn("知识库不存在: {}", knowledgeId);
continue;
}
// 获取向量库配置
var vectorStore = vectorStoreService.getById(knowledge.getVectorStoreId());
if (vectorStore == null) {
log.warn("向量库不存在: {}", knowledge.getVectorStoreId());
continue;
}
// 创建向量检索器
var retriever = createVectorRetriever(knowledge, vectorStore);
retrievers.add(retriever);
// 收集Rerank模型
if (StrUtil.isNotBlank(knowledge.getRerankModelId())) {
rerankModelIds.add(knowledge.getRerankModelId());
}
}
return new KnowledgeRetrieverResult(retrievers, rerankModelIds);
}
向量检索流程
Copy
1. 用户查询
↓
2. 向量化查询
- 使用Embedding模型将查询文本转换为向量
↓
3. 向量相似度检索
- 在向量数据库中搜索最相似的文档片段
- 支持多种向量库(PGVector, Milvus等)
↓
4. 返回相关文档
- 返回Top-K最相关的文档片段
- 包含文档内容和元数据
支持的向量库
| 向量库 | 说明 | 适用场景 |
|---|---|---|
| PGVector | PostgreSQL向量扩展 | 中小规模,已有PostgreSQL |
| Milvus | 专业向量数据库 | 大规模,高性能 |
| Elasticsearch | 全文搜索+向量 | 需要混合检索 |
| Redis Vector | 高速缓存 | 低延迟场景 |
| Neo4j | 图数据库 | 知识图谱 |
SqlRetriever(SQL检索)
SqlRetrieverBuilder
职责: 构建SQL检索器 构建流程:Copy
public List<ContentRetriever> build(LcChatReq req) {
var retrievers = new ArrayList<ContentRetriever>();
// 获取关联的数据源
var datasourceIds = req.getDatasources();
if (CollUtil.isEmpty(datasourceIds)) {
return retrievers;
}
// 遍历每个数据源
for (var datasourceId : datasourceIds) {
var datasource = datasourceService.getById(datasourceId);
if (datasource == null) {
log.warn("数据源不存在: {}", datasourceId);
continue;
}
// 创建SQL检索器
var retriever = createSqlRetriever(datasource);
retrievers.add(retriever);
}
return retrievers;
}
SQL检索流程
Copy
1. 用户查询
↓
2. Text2SQL生成
- 使用LLM将自然语言转换为SQL
↓
3. SQL执行
- 执行SQL查询数据库
↓
4. 结果处理
- 将查询结果转换为文本
↓
5. 返回检索结果
ContentAggregator(内容聚合)
ReRankingContentAggregator
职责: 对检索结果进行重排序和聚合 工作流程:Copy
1. 接收多个检索器的结果
↓
2. 合并结果
- 去重
- 保留元数据
↓
3. Rerank重排序
- 使用Rerank模型对结果重新打分
- 根据相似度重新排序
↓
4. 选择Top-N结果
- 选择最相关的N个结果
↓
5. 返回最终结果
Copy
var aggregator = ReRankingContentAggregator.builder()
.scoringModel(scoringModel)
.build();
DefaultQueryRouter(查询路由)
职责
- 将查询分发到多个检索器
- 合并所有检索器的结果
- 调用ContentAggregator进行聚合
Copy
var queryRouter = new DefaultQueryRouter(contentRetrievers);
var augmentor = DefaultRetrievalAugmentor.builder()
.contentAggregator(aggregator)
.queryRouter(queryRouter)
.build();
Rerank模型
RerankModelFactory
职责: 创建Rerank模型实例 支持的Rerank模型:- BGE-Reranker
- Cohere Rerank
- 其他兼容的Rerank模型
Copy
public ScoringModel getRerankModel(String modelId) {
// 获取模型配置
var model = aigcModelService.getById(modelId);
// 根据供应商创建对应的Rerank模型
return switch (model.getProvider().toLowerCase()) {
case ProviderConst.bge -> createBgeRerankModel(model);
case ProviderConst.cohere -> createCohereRerankModel(model);
default -> createDefaultRerankModel(model);
};
}
完整RAG流程示例
请求构建
Copy
// 1. 构建RAG组件
var ragComponents = ragRetrieverBuilder.build(req);
// 2. 配置到AI Service
if (CollUtil.isNotEmpty(ragComponents.retrievers())) {
var augmentor = DefaultRetrievalAugmentor.builder()
.contentAggregator(ragComponents.aggregator())
.queryRouter(new DefaultQueryRouter(ragComponents.retrievers()))
.build();
builder.retrievalAugmentor(augmentor);
}
执行流程
Copy
用户提问:"LangChat是什么?"
↓
[1] 向量化查询
- Embedding模型: "text-embedding-ada-002"
- 查询向量: [0.1, 0.2, 0.3, ...]
↓
[2] 向量检索
- 向量库: PGVector
- 检索Top-5文档
↓
[3] 检索结果
- Doc1: "LangChat是一个AI应用平台..."
- Doc2: "LangChat支持多种LLM模型..."
- Doc3: "LangChat采用LangChain4j..."
- Doc4: "LangChat支持知识库检索..."
- Doc5: "LangChat提供工作流编排..."
↓
[4] Rerank重排序
- Rerank模型: bge-reranker-large
- 重新打分和排序
↓
[5] 选择Top-3
- Doc1: score 0.95
- Doc2: score 0.92
- Doc4: score 0.88
↓
[6] 注入Prompt
- 系统提示词: "你是一个AI助手..."
- 检索内容: Doc1 + Doc2 + Doc4
- 用户问题: "LangChat是什么?"
↓
[7] LLM生成
- 根据上下文生成准确回答
↓
[8] 返回答案
"LangChat是一个基于Spring Boot和LangChain4j的AI应用平台..."
配置说明
知识库配置
Copy
public class AigcKnowledge {
private String id;
private String name;
private String vectorStoreId; // 向量库ID
private String rerankModelId; // Rerank模型ID
private Integer topK; // 检索数量
private Double similarityThreshold; // 相似度阈值
}
向量库配置
Copy
public class AigcVectorStore {
private String id;
private String name;
private String type; // 类型: pgvector, milvus, es等
private String host;
private Integer port;
private String database;
private String collection; // 集合/表名
private String apiKey;
}
application.yml
Copy
langchat:
rag:
default-top-k: 5 # 默认检索数量
default-similarity: 0.7 # 默认相似度阈值
enable-rerank: true # 是否启用Rerank
扩展点
1. 添加新的向量库
步骤1: 实现向量检索器Copy
public class CustomVectorRetriever implements ContentRetriever {
private final VectorStore vectorStore;
private final EmbeddingModel embeddingModel;
public CustomVectorRetriever(VectorStore vectorStore, EmbeddingModel embeddingModel) {
this.vectorStore = vectorStore;
this.embeddingModel = embeddingModel;
}
@Override
public List<Content> retrieve(Query query) {
// 1. 向量化查询
TextSegment textSegment = TextSegment.from(query.text());
Embedding embedding = embeddingModel.embed(textSegment).content();
// 2. 向量检索
List<EmbeddingMatch<TextSegment>> matches = vectorStore.findRelevant(embedding, 5);
// 3. 转换为Content
return matches.stream()
.map(match -> Content.from(match.embedded().text()))
.collect(Collectors.toList());
}
}
Copy
private ContentRetriever createVectorRetriever(
AigcKnowledge knowledge,
AigcVectorStore vectorStore
) {
// 根据向量库类型创建检索器
return switch (vectorStore.getType().toLowerCase()) {
case "custom" -> new CustomVectorRetriever(vectorStore, embeddingModel);
default -> throw new IllegalArgumentException("不支持的向量库类型: " + vectorStore.getType());
};
}
2. 自定义ContentAggregator
Copy
@Component
public class CustomContentAggregator implements ContentAggregator {
@Resource
private ScoringModel rerankModel;
@Override
public List<Content> aggregate(Query query, List<Retrieval> retrievals) {
// 1. 合并所有检索结果
List<Content> allContents = retrievals.stream()
.flatMap(retrieval -> retrieval.contents().stream())
.collect(Collectors.toList());
// 2. 去重
Set<String> seen = new HashSet<>();
List<Content> uniqueContents = allContents.stream()
.filter(content -> seen.add(content.textSegment().text()))
.collect(Collectors.toList());
// 3. Rerank
List<ScoredText> scoredTexts = rerankModel.scoreAll(query.text(), uniqueContents);
// 4. 排序并返回Top-N
return scoredTexts.stream()
.sorted((a, b) -> Double.compare(b.score(), a.score()))
.limit(10)
.map(scored -> Content.from(scored.text()))
.collect(Collectors.toList());
}
}
3. 自定义SqlRetriever
Copy
public class CustomSqlRetriever implements ContentRetriever {
private final DataSource dataSource;
private final ChatModel text2SqlModel;
@Override
public List<Content> retrieve(Query query) {
// 1. 生成SQL
String sql = generateSql(query.text());
// 2. 执行SQL
List<Map<String, Object>> results = executeSql(sql);
// 3. 转换为Content
return results.stream()
.map(result -> Content.from(result.toString()))
.collect(Collectors.toList());
}
private String generateSql(String question) {
// 使用LLM生成SQL
String prompt = String.format("""
请将以下自然语言问题转换为SQL查询:
问题:%s
数据库schema:%s
只返回SQL语句,不要其他内容。
""", question, getDatabaseSchema());
return text2SqlModel.chat(prompt);
}
private List<Map<String, Object>> executeSql(String sql) {
// 执行SQL查询
try (Connection conn = dataSource.getConnection();
Statement stmt = conn.createStatement();
ResultSet rs = stmt.executeQuery(sql)) {
ResultSetMetaData metaData = rs.getMetaData();
int columnCount = metaData.getColumnCount();
List<Map<String, Object>> results = new ArrayList<>();
while (rs.next()) {
Map<String, Object> row = new HashMap<>();
for (int i = 1; i <= columnCount; i++) {
row.put(metaData.getColumnName(i), rs.getObject(i));
}
results.add(row);
}
return results;
} catch (SQLException e) {
throw new RuntimeException("SQL执行失败", e);
}
}
}
最佳实践
1. 向量库选择
| 场景 | 推荐向量库 | 原因 |
|---|---|---|
| 中小规模 | PGVector | 简单,已有PostgreSQL |
| 大规模 | Milvus | 高性能,专业向量数据库 |
| 混合检索 | Elasticsearch | 支持全文+向量 |
| 低延迟 | Redis Vector | 内存缓存,极快 |
2. 检索参数调优
- Top-K: 通常5-10,过多影响性能
- 相似度阈值: 0.7-0.8,根据实际情况调整
- 分块大小: 500-1000字符,过大影响精度,过小丢失上下文
3. Rerank使用
- 提高检索准确性
- 增加延迟,根据场景选择
- 适用于高准确性要求的场景
4. 混合检索
- 结合向量检索和关键词检索
- 提高召回率
- Elasticsearch天然支持
性能优化
1. 向量检索优化
- 建立向量索引
- 使用近似检索(IVF, HNSW等)
- 缓存热门查询结果
2. 并行检索
Copy
// 并行执行多个检索器
List<Retrieval> retrievals = contentRetrievers.parallelStream()
.map(retriever -> Retrieval.from(retriever.retrieve(query)))
.collect(Collectors.toList());
3. 缓存策略
- 缓存Embedding结果
- 缓存Rerank结果
- 使用Redis缓存热门查询
4. 批量处理
- 批量向量化
- 批量检索
- 减少网络往返

