How RAG Works: A Detailed Look at Retrieval-Augmented Generation for Large Language Models

Introduction

Retrieval-Augmented Generation (RAG) is a technique that combines large language models with external knowledge retrieval. RAG addresses two core problems of large models: stale knowledge and hallucination. This article takes a deep look at RAG's technical principles, implementation approaches, and best practices.

Why RAG Is Needed

Limitations of Large Language Models

| Problem | Description | Impact |
| --- | --- | --- |
| Knowledge cutoff | Training data has a time limit | Cannot answer questions about recent events |
| Hallucination | May generate plausible-looking but wrong answers | Reduced reliability |
| Long-tail knowledge | Poor coverage of rare domain knowledge | Limits specialized applications |
| Domain knowledge | Lacks industry-specific expertise | Limits enterprise applications |
| Real-time information | No access to live data and news | Limits applicable scenarios |

The Core Idea of RAG

RAG addresses these problems through a "retrieve, augment, generate" pipeline:

```
User question → retrieve relevant documents → add documents to the prompt → LLM generates the answer
```

RAG System Architecture

Overall Pipeline

```python
class RAGSystem:
    """
    End-to-end RAG system architecture.
    """

    def __init__(self, config):
        self.config = config

        # 1. Document processing components
        self.document_loader = DocumentLoader(config['loader_type'])
        self.text_splitter = TextSplitter(
            chunk_size=config['chunk_size'],
            chunk_overlap=config['chunk_overlap']
        )

        # 2. Embedding component
        self.embedding_model = EmbeddingModel(config['embedding_model'])

        # 3. Vector store
        self.vector_store = VectorStore(
            config['vector_store_type'],
            config['vector_store_path']
        )

        # 4. Retriever (see the retrieval strategies section below)
        self.retriever = Retriever(
            top_k=config['top_k'],
            similarity_threshold=config['similarity_threshold']
        )

        # 5. Generator (LLM client wrapper)
        self.llm = LLMModel(config['llm_type'], config['llm_config'])

    def index_documents(self, documents):
        """Build the index."""
        # Load documents
        docs = self.document_loader.load(documents)

        # Split into chunks
        chunks = self.text_splitter.split(docs)

        # Embed the chunk texts
        embeddings = self.embedding_model.encode(
            [chunk['content'] for chunk in chunks]
        )

        # Store
        self.vector_store.add(chunks, embeddings)

    def retrieve(self, query, top_k=5):
        """Retrieve relevant documents."""
        # Embed the query
        query_embedding = self.embedding_model.encode([query])[0]

        # Similarity search
        results = self.vector_store.search(query_embedding, top_k)

        return results

    def generate(self, query, context_docs):
        """Augmented generation."""
        # Build the prompt
        prompt = self.build_prompt(query, context_docs)

        # Generate with the LLM
        response = self.llm.generate(prompt)

        return response

    def build_prompt(self, query, context_docs):
        """Build the RAG prompt."""
        context = "\n\n".join([
            f"Document {i + 1}:\n{doc['content']}"
            for i, doc in enumerate(context_docs)
        ])

        prompt = f"""Answer the question based on the reference documents below. If the documents contain no relevant information, say that you don't know.

Reference documents:
{context}

Question: {query}

Answer:"""

        return prompt

    def query(self, question):
        """Full RAG pipeline."""
        # 1. Retrieve
        retrieved_docs = self.retrieve(question, top_k=self.config['top_k'])

        # 2. Generate
        response = self.generate(question, retrieved_docs)

        return {
            'answer': response,
            'sources': retrieved_docs
        }
```
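
To make the moving parts concrete, here is a hypothetical usage sketch. It assumes the `Retriever` and `LLMModel` wrappers referenced above exist; the config keys match what `RAGSystem.__init__` reads, while the file name, model identifiers, and question are placeholders.

```python
# Hypothetical configuration: every key is one RAGSystem.__init__ reads above;
# the concrete values (file, models, question) are placeholders.
config = {
    'loader_type': 'pdf',
    'chunk_size': 500,
    'chunk_overlap': 50,
    'embedding_model': 'BAAI/bge-large-zh-v1.5',
    'vector_store_type': 'chroma',
    'vector_store_path': './index',
    'top_k': 5,
    'similarity_threshold': 0.7,
    'llm_type': 'openai',
    'llm_config': {'model': 'gpt-4o-mini'},
}

rag = RAGSystem(config)
rag.index_documents('product_manual.pdf')   # build the index once
result = rag.query('How do I reset the device to factory settings?')
print(result['answer'])
print(result['sources'])                    # retrieved chunks used as evidence
```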

Document Processing and Chunking Strategies

Document Loaders

```python
class DocumentLoader:
    """Loader for multiple document formats."""

    def __init__(self, loader_type='pdf'):
        self.loader_type = loader_type

    def load(self, file_path):
        if self.loader_type == 'pdf':
            return self._load_pdf(file_path)
        elif self.loader_type == 'docx':
            return self._load_docx(file_path)
        elif self.loader_type == 'html':
            return self._load_html(file_path)
        elif self.loader_type == 'markdown':
            return self._load_markdown(file_path)
        else:
            raise ValueError(f"Unsupported format: {self.loader_type}")

    def _load_pdf(self, file_path):
        """Load a PDF, producing one document per page."""
        from pypdf import PdfReader

        reader = PdfReader(file_path)
        documents = []

        for page_num, page in enumerate(reader.pages):
            text = page.extract_text()
            documents.append({
                'content': text,
                'metadata': {
                    'source': file_path,
                    'page': page_num + 1
                }
            })

        return documents

    # _load_docx, _load_html and _load_markdown follow the same pattern
    # and are omitted here.
```
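
A quick usage sketch (the file name is a placeholder): each PDF page becomes one document, and the page number is kept in the metadata so that answers can later cite exact pages.

```python
loader = DocumentLoader('pdf')
pages = loader.load('annual_report.pdf')   # placeholder path
print(pages[0]['metadata'])                # {'source': 'annual_report.pdf', 'page': 1}
```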

Chunking Strategies

```python
import re

import numpy as np


def cosine_similarity(a, b):
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


class TextSplitter:
    """Text chunker."""

    def __init__(self, chunk_size=500, chunk_overlap=50):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def split(self, documents):
        """Split documents into small chunks with a sliding window."""
        chunks = []

        for doc in documents:
            text = doc['content']
            metadata = doc['metadata']

            start = 0
            while start < len(text):
                end = min(start + self.chunk_size, len(text))
                chunk_text = text[start:end]

                chunks.append({
                    'content': chunk_text,
                    'metadata': metadata,
                    'start_char': start,
                    'end_char': end
                })

                start += (self.chunk_size - self.chunk_overlap)

        return chunks

    def split_by_semantic(self, text, model):
        """
        Semantic chunking: split at semantic boundaries.
        """
        # Sentence-level split
        sentences = self._split_sentences(text)
        if not sentences:
            return []

        # Sentence embeddings
        embeddings = model.encode(sentences)

        # Decide chunk boundaries from embedding similarity
        chunks = []
        current_chunk = [sentences[0]]

        for i in range(1, len(sentences)):
            # Similarity to the previous sentence
            similarity = cosine_similarity(embeddings[i - 1], embeddings[i])

            if similarity > 0.7:  # semantically close: merge
                current_chunk.append(sentences[i])
            else:  # semantic jump: start a new chunk
                chunks.append(' '.join(current_chunk))
                current_chunk = [sentences[i]]

        if current_chunk:
            chunks.append(' '.join(current_chunk))

        return chunks

    def _split_sentences(self, text):
        """Naive sentence splitter on common end-of-sentence punctuation."""
        return [s.strip() for s in re.split(r'(?<=[。!?.!?])\s*', text) if s.strip()]
```
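
With chunk_size=500 and chunk_overlap=50, the window advances by 450 characters per step, so consecutive chunks share a 50-character seam. A small sanity check of the boundaries:

```python
splitter = TextSplitter(chunk_size=500, chunk_overlap=50)
doc = {'content': 'x' * 1200, 'metadata': {'source': 'demo'}}

for chunk in splitter.split([doc]):
    print(chunk['start_char'], chunk['end_char'])
# 0 500
# 450 950
# 900 1200  (the final chunk is truncated at the end of the text)
```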

Vectorization and Embeddings

Choosing an Embedding Model

| Model | Dimensions | Strengths | Best For |
| --- | --- | --- | --- |
| OpenAI text-embedding-ada-002 | 1536 | Strong general-purpose quality | General use |
| BGE | 1024 | Bilingual (Chinese/English) | Multilingual |
| M3E | 768 | Optimized for Chinese | Chinese content |
| Instructor | 768 | Instruction-driven | Domain-specific tasks |
```python
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer


class EmbeddingModel:
    """Wrapper around a Hugging Face embedding model."""

    def __init__(self, model_name='BAAI/bge-large-zh-v1.5'):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def encode(self, texts, batch_size=32):
        """Encode texts in batches."""
        embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            inputs = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt'
            )

            with torch.no_grad():
                outputs = self.model(**inputs)
                # Mean pooling over non-padding tokens
                # (taking the [CLS] token is a common alternative)
                mask = inputs['attention_mask'].unsqueeze(-1)
                summed = (outputs.last_hidden_state * mask).sum(dim=1)
                batch_embeddings = summed / mask.sum(dim=1)

            embeddings.append(batch_embeddings.numpy())

        return np.vstack(embeddings)

    def encode_query(self, query):
        """Encode a query. Some models need special handling here: the
        Chinese BGE models, for example, recommend prepending a retrieval
        instruction to the query."""
        return self.encode([query])[0]
```
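
As a sanity check, semantically related sentences should score high under cosine similarity. A minimal sketch using the wrapper above (the model weights are downloaded on first use):

```python
import numpy as np

model = EmbeddingModel('BAAI/bge-large-zh-v1.5')
vecs = model.encode([
    'RAG combines retrieval with generation.',
    'Retrieval-augmented generation grounds an LLM in external knowledge.',
])

a, b = vecs[0], vecs[1]
cosine = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f'cosine similarity: {cosine:.3f}')  # expect a relatively high value
```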

Vector Databases

Comparison of Popular Vector Databases

| Database | Characteristics | Scale | Deployment |
| --- | --- | --- | --- |
| Milvus | High performance, open source | Billions of vectors | Cloud / self-hosted |
| Qdrant | Written in Rust, efficient | Hundreds of millions | Cloud / self-hosted |
| Chroma | Lightweight, easy to use | Millions | Embedded |
| Pinecone | Cloud-native | Any scale | Cloud only |
| Weaviate | Hybrid search | Tens of millions | Cloud / self-hosted |
```python
class VectorStore:
    """Thin wrapper over several vector databases."""

    def __init__(self, store_type='milvus', connection_params=None):
        self.store_type = store_type
        connection_params = connection_params or {}

        if store_type == 'milvus':
            from pymilvus import connections, Collection
            connections.connect(**{k: v for k, v in connection_params.items()
                                   if k != 'collection_name'})
            self.collection = Collection(connection_params['collection_name'])

        elif store_type == 'qdrant':
            from qdrant_client import QdrantClient
            self.client = QdrantClient(**connection_params)

        elif store_type == 'chroma':
            import chromadb
            client = chromadb.Client()
            self.collection = client.get_or_create_collection(
                connection_params.get('collection_name', 'documents')
            )

    def add(self, chunks, embeddings, ids=None):
        """Add vectors to the database."""
        if ids is None:
            ids = [str(i) for i in range(len(chunks))]

        if self.store_type == 'milvus':
            # Row-based insert (supported by recent pymilvus versions)
            self.collection.insert([
                {
                    'id': id_,
                    'vector': embedding,
                    'content': chunk['content'],
                    'metadata': chunk['metadata']
                }
                for id_, embedding, chunk in zip(ids, embeddings, chunks)
            ])

        elif self.store_type == 'chroma':
            self.collection.add(
                ids=ids,
                embeddings=embeddings,
                documents=[c['content'] for c in chunks],
                metadatas=[c['metadata'] for c in chunks]
            )

    def search(self, query_embedding, top_k=5,
               similarity_threshold=0.7):
        """Similarity search."""
        if self.store_type == 'milvus':
            results = self.collection.search(
                data=[query_embedding],
                anns_field='vector',
                param={'metric_type': 'IP'},  # inner-product similarity
                limit=top_k
            )

            return [r for r in results[0]
                    if r.score >= similarity_threshold]

        elif self.store_type == 'chroma':
            # Note: returns Chroma's raw result dict
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=top_k
            )

            return results
```
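
Of the three backends, Chroma is the easiest to try because it runs in-process with no server. A minimal standalone sketch, with toy 3-dimensional vectors standing in for real embeddings:

```python
import chromadb

client = chromadb.Client()                  # in-memory instance
collection = client.get_or_create_collection('demo')

collection.add(
    ids=['a', 'b'],
    embeddings=[[0.9, 0.1, 0.0], [0.0, 0.2, 0.9]],  # toy vectors
    documents=['doc about retrieval', 'doc about generation'],
)

hits = collection.query(query_embeddings=[[1.0, 0.0, 0.0]], n_results=1)
print(hits['documents'])                    # [['doc about retrieval']]
```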

Retrieval Strategies

Basic Retrieval

```python
class NaiveRetriever:
    """Plain vector retrieval."""

    def __init__(self, vector_store, embedding_model, top_k=5):
        self.vector_store = vector_store
        self.embedding_model = embedding_model
        self.top_k = top_k

    def retrieve(self, query):
        # Embed the query
        query_embedding = self.embedding_model.encode_query(query)

        # Search
        results = self.vector_store.search(query_embedding, self.top_k)

        return results
```

Hybrid Retrieval

```python
class HybridRetriever:
    """Hybrid retrieval: vectors + keywords."""

    def __init__(self, vector_store, bm25_index,
                 embedding_model, top_k=10, alpha=0.5):
        self.vector_store = vector_store
        self.bm25_index = bm25_index
        self.embedding_model = embedding_model
        self.top_k = top_k
        self.alpha = alpha  # weight of the vector ranking

    def retrieve(self, query):
        # 1. Vector retrieval
        query_embedding = self.embedding_model.encode_query(query)
        vector_results = self.vector_store.search(
            query_embedding, self.top_k * 2
        )

        # 2. BM25 retrieval
        bm25_results = self.bm25_index.search(query, self.top_k * 2)

        # 3. RRF fusion (vector list weighted by alpha, BM25 by 1 - alpha)
        fused_scores = self._reciprocal_rank_fusion(
            [vector_results, bm25_results],
            weights=[self.alpha, 1 - self.alpha]
        )

        # 4. Return the top_k
        return sorted(fused_scores, key=lambda x: x['score'],
                      reverse=True)[:self.top_k]

    def _reciprocal_rank_fusion(self, results_lists, weights=None, k=60):
        """Reciprocal Rank Fusion across several ranked result lists."""
        if weights is None:
            weights = [1.0] * len(results_lists)

        scores = {}

        for results, weight in zip(results_lists, weights):
            for rank, result in enumerate(results):
                doc_id = result['id']
                rrf_score = weight / (k + rank + 1)

                if doc_id not in scores:
                    scores[doc_id] = {
                        'doc': result,
                        'score': 0.0
                    }

                scores[doc_id]['score'] += rrf_score

        return list(scores.values())
```
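
For intuition: with the 0-based ranks used above and unit weights, RRF gives each document the score Σ 1/(k + rank + 1), summed over the result lists it appears in. With k = 60, a document ranked 1st by vector search and 3rd by BM25 scores 1/61 + 1/63 ≈ 0.0323, comfortably beating a document that is 2nd in only one list (1/62 ≈ 0.0161): appearing in both rankings counts for more than topping either one.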

Reranking

```python
# Uses the sentence-transformers CrossEncoder wrapper.
from sentence_transformers import CrossEncoder


class Reranker:
    """Cross-encoder reranking model."""

    def __init__(self, model_name='cross-encoder/ms-marco-MiniLM-L-6-v2'):
        self.model = CrossEncoder(model_name)

    def rerank(self, query, documents, top_k=5):
        """
        Rerank with a cross-encoder, which scores each
        (query, document) pair jointly.
        """
        # Build sentence pairs
        pairs = [(query, doc['content']) for doc in documents]

        # Relevance scores
        scores = self.model.predict(pairs)

        # Sort by score
        doc_scores = list(zip(documents, scores))
        doc_scores.sort(key=lambda x: x[1], reverse=True)

        # Return the top_k
        reranked = []
        for doc, score in doc_scores[:top_k]:
            doc['rerank_score'] = float(score)
            reranked.append(doc)

        return reranked
```
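
A hypothetical usage sketch; `cross-encoder/ms-marco-MiniLM-L-6-v2` is a commonly used public checkpoint, and the documents are toy examples:

```python
reranker = Reranker('cross-encoder/ms-marco-MiniLM-L-6-v2')

candidates = [
    {'content': 'RAG retrieves documents before generating an answer.'},
    {'content': 'Bananas are rich in potassium.'},
]

top = reranker.rerank('What is retrieval-augmented generation?', candidates, top_k=1)
print(top[0]['content'], top[0]['rerank_score'])
```

Because the cross-encoder reads the query and document together, it is considerably more accurate than the bi-encoder used for first-stage retrieval, but also much slower, which is why it is applied only to the short candidate list.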

Advanced RAG Techniques

1. Sub-Question Decomposition

```python
import re


class SubQuestionRetriever:
    """Retrieval via sub-question decomposition."""

    def __init__(self, llm, retriever):
        self.llm = llm
        self.retriever = retriever

    def decompose(self, query):
        """Break a complex question into simpler sub-questions."""
        prompt = f"""Decompose the following question into several simple sub-questions:

Question: {query}

Sub-questions:"""

        response = self.llm.generate(prompt)
        sub_questions = self._parse_sub_questions(response)

        return sub_questions

    def retrieve_with_decomposition(self, query):
        """Decompose, then retrieve."""
        # Decompose the question
        sub_questions = self.decompose(query)

        # Retrieve for each sub-question
        all_results = []
        for sq in sub_questions:
            results = self.retriever.retrieve(sq)
            all_results.extend(results)

        # Deduplicate and rank
        unique_results = self._deduplicate(all_results)

        return unique_results

    def _parse_sub_questions(self, response):
        """Parse one sub-question per line, stripping list numbering."""
        lines = [re.sub(r'^\s*\d+[.)、]?\s*', '', line).strip()
                 for line in response.splitlines()]
        return [line for line in lines if line]

    def _deduplicate(self, results):
        """Drop duplicate documents, keeping first occurrences."""
        seen, unique = set(), []
        for r in results:
            key = r.get('id', r['content'])
            if key not in seen:
                seen.add(key)
                unique.append(r)
        return unique
```
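
For example, a comparative question such as "Which has the larger context window, model A or model B?" decomposes into "What is the context window of model A?" and "What is the context window of model B?"; each sub-question retrieves cleanly against a single document, where the original phrasing might match neither.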

2. Adaptive Retrieval

```python
class AdaptiveRAG:
    """Adaptive retrieval strategy."""

    def __init__(self, llm, retrievers):
        self.llm = llm
        self.retrievers = retrievers  # retrievers implementing different strategies

    def decide_retrieval_strategy(self, query):
        """Pick a retrieval strategy."""
        prompt = f"""Analyze the query below and decide the best retrieval strategy:

Query: {query}

Strategy options:
1. web_search - needs up-to-date information
2. vector_search - needs retrieval from the knowledge base
3. hybrid - needs both
4. direct_answer - can be answered directly

Decision:"""

        decision = self.llm.generate(prompt)
        return self._parse_decision(decision)

    def query(self, question):
        """Adaptive query."""
        strategy = self.decide_retrieval_strategy(question)

        if strategy == 'vector_search':
            return self.retrievers['vector'].retrieve(question)
        elif strategy == 'web_search':
            return self.retrievers['web'].search(question)
        elif strategy == 'hybrid':
            return self.retrievers['hybrid'].retrieve(question)
        else:
            return []

    def _parse_decision(self, decision):
        """Match the first known strategy keyword in the LLM output."""
        for strategy in ('web_search', 'vector_search', 'hybrid', 'direct_answer'):
            if strategy in decision:
                return strategy
        return 'direct_answer'  # safe fallback
```

Evaluation Metrics

```python
class RAGEvaluator:
    """RAG system evaluation."""

    def evaluate(self, test_set, rag_system):
        """
        Evaluate a RAG system.

        Metrics computed below:
        - faithfulness: is the answer grounded in the retrieved sources?
        - answer relevancy: does the answer actually address the question?
        - context recall: did retrieval cover the expected answer?
        - context precision: are the retrieved sources relevant?
        """
        results = []

        for item in test_set:
            question = item['question']
            expected_answer = item['answer']

            # RAG query
            response = rag_system.query(question)

            # Score the response
            metrics = {
                'faithfulness': self.compute_faithfulness(
                    response['answer'], response['sources']
                ),
                'answer_relevancy': self.compute_answer_relevancy(
                    question, response['answer']
                ),
                'context_recall': self.compute_context_recall(
                    expected_answer, response['sources']
                ),
                'context_precision': self.compute_context_precision(
                    expected_answer, response['sources']
                )
            }

            results.append(metrics)

        # Aggregate over the test set
        return self._aggregate_metrics(results)

    def _aggregate_metrics(self, results):
        """Average each metric over the test set."""
        keys = results[0].keys() if results else []
        return {k: sum(r[k] for r in results) / len(results) for k in keys}
```
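
The four `compute_*` hooks above are left abstract; in practice they are usually scored by an LLM judge (the RAGAS library implements exactly these four metrics). Purely for illustration, here is a crude word-overlap stand-in for the two context metrics; the 0.2 threshold is an arbitrary assumption:

```python
def _word_overlap(a: str, b: str) -> float:
    """Fraction of a's words that also appear in b (crude relevance proxy)."""
    wa, wb = set(a.lower().split()), set(b.lower().split())
    return len(wa & wb) / max(len(wa), 1)

def compute_context_recall(expected_answer, sources):
    """How much of the expected answer is covered by the retrieved contexts?"""
    combined = ' '.join(s['content'] for s in sources)
    return _word_overlap(expected_answer, combined)

def compute_context_precision(expected_answer, sources, threshold=0.2):
    """What fraction of retrieved sources are relevant to the expected answer?"""
    relevant = sum(1 for s in sources
                   if _word_overlap(s['content'], expected_answer) >= threshold)
    return relevant / max(len(sources), 1)
```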

Summary

RAG is currently the mainstream architecture for LLM applications. By combining external knowledge retrieval with LLM generation, it effectively mitigates stale knowledge and hallucination. As vector databases, embedding models, and retrieval strategies keep improving, the performance of RAG systems continues to rise.


Recommended Reading:

  • "Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks" (Lewis et al., 2020)
  • LangChain official documentation
  • Vector database comparisons and analyses