Retrieval-Augmented Generation (RAG): Giving Large Language Models Up-to-Date Knowledge


Introduction

Retrieval-Augmented Generation (RAG) emerged in 2023 as one of the core techniques for building applications on large language models. By pairing a retrieval system with a generative model, RAG addresses key pain points of LLMs such as stale knowledge and hallucination. This article dissects RAG's technical principles, architecture, and engineering implementation.

RAG Technical Principles

1. Why RAG Is Needed

Large language models suffer from several inherent problems:

  1. Knowledge cutoff: the training data has a cutoff date, so the model cannot access newer information
  2. Hallucination: the model can generate content that looks plausible but is wrong
  3. Missing long-tail knowledge: imbalanced training data leads to weak performance on long-tail facts
  4. Thin domain expertise: coverage of specialized vertical domains is limited

RAG improves answer quality by retrieving the most recent and most relevant information at query time.

2. Core RAG Architecture

from dataclasses import dataclass
from typing import List

@dataclass
class Document:
    """Document data structure."""
    page_content: str
    metadata: dict

class RAGSystem:
    """
    Core components of a RAG system.
    """
    def __init__(self, embedding_model, llm, vector_store):
        self.embedding_model = embedding_model  # embedding model
        self.llm = llm                          # language model
        self.vector_store = vector_store        # vector database

    def retrieve(self, query: str, top_k: int = 5) -> List[Document]:
        """
        Retrieve relevant documents.
        """
        # Embed the query
        query_embedding = self.embedding_model.embed(query)

        # Search the vector store
        results = self.vector_store.similarity_search(
            query_embedding,
            k=top_k
        )

        return results

    def generate(self, query: str, context_docs: List[Document]) -> str:
        """
        Generate an answer grounded in the retrieved documents.
        """
        # Build the prompt
        context = "\n\n".join([
            f"[Document {i+1}] {doc.page_content}"
            for i, doc in enumerate(context_docs)
        ])

        prompt = f"""Answer the question using the reference documents below. If the documents do not contain the relevant information, say that you cannot answer.

Reference documents:
{context}

Question: {query}

Answer:"""

        # Call the language model to generate the answer
        response = self.llm.generate(prompt)
        return response

    def answer(self, query: str, top_k: int = 5) -> str:
        """
        Full RAG pipeline: retrieve, then generate.
        """
        # 1. Retrieve relevant documents
        docs = self.retrieve(query, top_k)

        # 2. Generate an answer from those documents
        response = self.generate(query, docs)

        return response
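
The three collaborators are duck-typed: any embedder exposing embed, any LLM client exposing generate, and any store exposing similarity_search will work. A minimal sketch of the call flow, with stand-in names (embedder, llm_client, and index are illustrative; concrete implementations follow later in this post):

rag = RAGSystem(
    embedding_model=embedder,   # e.g. the EmbeddingModel wrapper below
    llm=llm_client,             # any object with a .generate(prompt) method
    vector_store=index,         # e.g. the FAISS-backed VectorIndex below
)
print(rag.answer("What does the product manual say about battery life?", top_k=3))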

Vector Databases in Depth

1. Choosing an Embedding Model

from typing import List

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

class EmbeddingModel:
    """
    Wrapper around a sentence-transformers embedding model.
    """
    def __init__(self, model_name="moka-ai/m3e-base"):
        self.model = SentenceTransformer(model_name)
        self.model.eval()

        # Move to GPU if one is available
        if torch.cuda.is_available():
            self.model = self.model.to('cuda')

    def embed(self, text: str) -> np.ndarray:
        """Embed a single text."""
        return self.model.encode(text, normalize_embeddings=True)

    def embed_batch(self, texts: List[str]) -> np.ndarray:
        """Embed a batch of texts."""
        return self.model.encode(texts, normalize_embeddings=True, batch_size=32)

    def get_embedding_dim(self) -> int:
        return self.model.get_sentence_embedding_dimension()
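
One detail worth making explicit: because normalize_embeddings=True L2-normalizes the outputs, the inner product of two embeddings equals their cosine similarity, which is what justifies the inner-product FAISS indexes in the next subsection. A quick sanity check (the sentences are arbitrary examples):

embedder = EmbeddingModel()
a = embedder.embed("RAG combines retrieval with generation")
b = embedder.embed("Retrieval-augmented generation pairs a retriever with an LLM")
print(float(a @ b))  # cosine similarity; closer to 1 means more similar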

2. Implementing a Vector Index

from typing import List

import faiss
import numpy as np

class VectorIndex:
    """
    FAISS vector index.
    """
    def __init__(self, dim: int, index_type="IVF"):
        self.dim = dim
        self.index_type = index_type
        self.index = self._build_index()
        self.documents = []  # keep the original documents alongside the vectors

    def _build_index(self):
        """
        Build the index; several index types are supported.
        """
        if self.index_type == "Flat":
            # Exact search, highest accuracy
            return faiss.IndexFlatIP(self.dim)  # inner-product similarity

        elif self.index_type == "IVF":
            # Inverted-file index, approximate search
            quantizer = faiss.IndexFlatIP(self.dim)
            return faiss.IndexIVFFlat(quantizer, self.dim, 100)

        elif self.index_type == "HNSW":
            # Hierarchical Navigable Small World graph
            return faiss.IndexHNSWFlat(self.dim, 32)

        elif self.index_type == "PQ":
            # Product quantization, compressed storage
            return faiss.IndexPQ(self.dim, 16, 8)

        raise ValueError(f"Unknown index type: {self.index_type}")

    def add(self, embeddings: np.ndarray, documents: List[Document]):
        """
        Add vectors to the index.
        """
        # FAISS expects float32; L2-normalize before training so that inner
        # product equals cosine similarity and training sees the same
        # distribution the index will later search over
        embeddings = embeddings.astype(np.float32)
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

        if not self.index.is_trained:
            # IVF and PQ indexes need a training pass; Flat and HNSW do not
            self.index.train(embeddings)

        self.index.add(embeddings)
        self.documents.extend(documents)

    def search(self, query_embedding: np.ndarray, k: int = 5) -> List[Document]:
        """
        Similarity search.
        """
        # Normalize the query the same way
        query = query_embedding.astype(np.float32)
        query = query / np.linalg.norm(query)

        # Search
        distances, indices = self.index.search(query.reshape(1, -1), k)

        # Map result slots back to documents (FAISS pads with -1 when
        # fewer than k results exist)
        results = []
        for idx, dist in zip(indices[0], distances[0]):
            if 0 <= idx < len(self.documents):
                doc = self.documents[idx]
                doc.metadata['score'] = float(dist)
                results.append(doc)

        return results
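
A short end-to-end sketch of the two classes together. The Flat index is used here because IVF needs at least nlist vectors to train its quantizer; the sample texts are placeholders:

embedder = EmbeddingModel()
index = VectorIndex(dim=embedder.get_embedding_dim(), index_type="Flat")

texts = [
    "RAG retrieves documents before generating an answer.",
    "FAISS offers exact and approximate similarity indexes.",
]
docs = [Document(page_content=t, metadata={}) for t in texts]
index.add(embedder.embed_batch(texts), docs)

for doc in index.search(embedder.embed("What does RAG do?"), k=1):
    print(doc.metadata["score"], doc.page_content)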

3. Integrating the Milvus Vector Database

import json
from typing import List

import numpy as np
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType

class MilvusVectorStore:
    """
    Wrapper around the Milvus vector database.
    """
    def __init__(self, host="localhost", port="19530", collection_name="documents"):
        connections.connect(host=host, port=port)
        self.collection_name = collection_name
        self._create_collection()

    def _create_collection(self):
        """Create the collection."""
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            # dim must match your embedding model's output (1536 fits OpenAI's
            # text-embedding-ada-002; the m3e-base model used earlier outputs 768)
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=1536),
            FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=65535)
        ]
        schema = CollectionSchema(fields=fields, description="Document collection")

        self.collection = Collection(name=self.collection_name, schema=schema)

        # Create the vector index
        index_params = {
            "metric_type": "IP",  # inner product
            "index_type": "IVF_FLAT",
            "params": {"nlist": 128}
        }
        self.collection.create_index(field_name="embedding", index_params=index_params)

    def insert(self, embeddings: List[np.ndarray], documents: List[Document]):
        """Insert data."""
        entities = [
            [emb.tolist() for emb in embeddings],  # one vector per document
            [doc.page_content for doc in documents],
            [json.dumps(doc.metadata) for doc in documents]
        ]

        self.collection.insert(entities)
        self.collection.flush()

    def search(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Document]:
        """Search."""
        # The collection must be loaded into memory before searching
        self.collection.load()

        search_params = {"metric_type": "IP", "params": {"nprobe": 10}}

        results = self.collection.search(
            data=[query_embedding.tolist()],
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            output_fields=["content", "metadata"]
        )

        docs = []
        for hits in results:
            for hit in hits:
                doc = Document(
                    page_content=hit.entity.get("content"),
                    metadata=json.loads(hit.entity.get("metadata"))
                )
                doc.metadata['score'] = hit.score
                docs.append(doc)

        return docs
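
Usage mirrors the FAISS version, except the data lives in a Milvus server rather than in process. A sketch reusing embedder, texts, and docs from the FAISS example, assuming Milvus is reachable at localhost:19530 and the collection's dim has been set to match the embedder (768 for m3e-base):

store = MilvusVectorStore(collection_name="documents")
store.insert(list(embedder.embed_batch(texts)), docs)  # one vector per document
for doc in store.search(embedder.embed("What does RAG do?"), top_k=3):
    print(doc.metadata["score"], doc.page_content)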

RAG in Practice with LangChain

1. Document Loading and Processing

from typing import List

from langchain.document_loaders import PyPDFLoader, UnstructuredHTMLLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

class DocumentProcessor:
    """
    Document processing pipeline.
    """
    def __init__(self, chunk_size=500, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            # split on paragraph, line, and Chinese sentence boundaries first
            separators=["\n\n", "\n", "。", "!", "?", " "]
        )
        self.embeddings = HuggingFaceEmbeddings(
            model_name="moka-ai/m3e-base",
            model_kwargs={'device': 'cuda'}  # use 'cpu' if no GPU is available
        )

    def load_pdf(self, file_path: str) -> List[Document]:
        """Load a PDF."""
        loader = PyPDFLoader(file_path)
        pages = loader.load_and_split()

        # Split into chunks
        docs = self.text_splitter.split_documents(pages)
        return docs

    def load_html(self, file_path: str) -> List[Document]:
        """Load an HTML file."""
        loader = UnstructuredHTMLLoader(file_path)
        docs = loader.load()
        docs = self.text_splitter.split_documents(docs)
        return docs

    def load_webpage(self, url: str) -> List[Document]:
        """Load a web page."""
        from langchain.document_loaders import WebBaseLoader

        loader = WebBaseLoader(url)
        docs = loader.load()
        docs = self.text_splitter.split_documents(docs)
        return docs

    def create_vectorstore(self, documents: List[Document], persist_dir: str = None):
        """Create the vector store."""
        from langchain.vectorstores import Chroma

        # Use Chroma for storage
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=self.embeddings,
            persist_directory=persist_dir
        )

        return vectorstore
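
Putting the pipeline together takes three lines; the PDF path and persistence directory below are placeholders:

processor = DocumentProcessor(chunk_size=500, chunk_overlap=50)
docs = processor.load_pdf("product_manual.pdf")  # hypothetical file
vectorstore = processor.create_vectorstore(docs, persist_dir="./chroma_db")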

2. Building the RAG Chain

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI

class RAGChainBuilder:
    """
    Build RAG chains.
    """
    def __init__(self, vectorstore, llm_model="gpt-3.5-turbo"):
        self.vectorstore = vectorstore
        self.llm = ChatOpenAI(model=llm_model, temperature=0)

    def build_qa_chain(self):
        """
        Build a question-answering chain.
        """
        prompt_template = """Answer the question using the context below. If the context is insufficient, say you don't know; do not make up an answer.

Context:
{context}

Question: {question}

Answer:"""

        PROMPT = PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"]
        )

        chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",  # stuff all retrieved documents into one prompt
            retriever=self.vectorstore.as_retriever(
                search_kwargs={"k": 5}
            ),
            chain_type_kwargs={
                "prompt": PROMPT,
                "document_variable_name": "context"
            }
        )

        return chain

    def build_conversational_chain(self):
        """
        Build a conversational chain (supports multi-turn dialogue).
        """
        from langchain.chains import ConversationalRetrievalChain
        from langchain.memory import ConversationBufferMemory

        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True
        )

        chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.vectorstore.as_retriever(),
            memory=memory,
            condense_question_prompt=self._get_condense_prompt()
        )

        return chain

    def _get_condense_prompt(self):
        """
        Prompt for question condensation:
        rewrite chat history + follow-up into a standalone question.
        """
        template = """Given the following conversation history and a follow-up question, rephrase the follow-up as a standalone question that carries enough context on its own.

Conversation history:
{chat_history}

Follow-up question: {question}

Standalone question:"""

        return PromptTemplate.from_template(template)
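
Running either chain is a one-liner once the vector store exists. A sketch, assuming OPENAI_API_KEY is set in the environment and vectorstore comes from the pipeline above:

builder = RAGChainBuilder(vectorstore)

qa_chain = builder.build_qa_chain()
print(qa_chain.run("What are the core components of a RAG system?"))

chat_chain = builder.build_conversational_chain()
print(chat_chain({"question": "And how does reranking fit in?"})["answer"])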

3. HyDE: Hypothetical Document Embeddings

class HyDERetrieval:
    """
    HyDE (Hypothetical Document Embeddings):
    improve retrieval by embedding a generated hypothetical answer
    instead of the raw query.
    """
    def __init__(self, embedding_model, llm):
        self.embedding_model = embedding_model
        self.llm = llm

    def generate_hypothetical_document(self, query: str) -> str:
        """
        Generate a hypothetical answer document.
        """
        prompt = f"""Write a short passage that plausibly answers the question below, as if it were taken from a reference document.

Question: {query}

Hypothetical answer passage:"""

        hyp_doc = self.llm.generate(prompt)
        return hyp_doc

    def retrieve(self, query: str, vectorstore, top_k: int = 5):
        """
        HyDE retrieval flow.
        """
        # 1. Generate the hypothetical document
        hyp_doc = self.generate_hypothetical_document(query)

        # 2. Embed the hypothetical document
        hyp_embedding = self.embedding_model.embed(hyp_doc)

        # 3. Retrieve using the hypothetical document's embedding
        results = vectorstore.similarity_search(hyp_embedding, k=top_k)

        return results
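
The intuition: a short question and a long answer passage are asymmetric in embedding space, whereas a hypothetical answer, even a partially wrong one, tends to land much closer to the real answer passages. Factual errors in the hypothetical document are tolerable because only its embedding is used for retrieval; the final answer is still generated from the real documents that come back.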

Advanced RAG Techniques

1. Reranking

from typing import List

from sentence_transformers import CrossEncoder

class Reranker:
    """
    Rerank retrieval results with a cross-encoder.
    """
    def __init__(self, model_name="BAAI/bge-reranker-base"):
        self.model = CrossEncoder(model_name)

    def rerank(self, query: str, documents: List[Document], top_k: int = 3):
        """
        Rerank the retrieved results.
        """
        doc_texts = [doc.page_content for doc in documents]

        # Build (query, document) pairs
        pairs = [(query, doc) for doc in doc_texts]

        # Score each pair for relevance
        scores = self.model.predict(pairs)

        # Sort by score, descending
        scored_docs = list(zip(documents, scores))
        scored_docs.sort(key=lambda x: x[1], reverse=True)

        return [doc for doc, score in scored_docs[:top_k]]
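
The usual two-stage pattern is to over-retrieve cheaply with the bi-encoder and let the slower cross-encoder pick the final few. A sketch, reusing the RAGSystem instance from earlier as rag:

query = "How do I reset the device?"  # placeholder question
reranker = Reranker()
candidates = rag.retrieve(query, top_k=20)               # fast, high-recall stage
best_docs = reranker.rerank(query, candidates, top_k=3)  # slow, high-precision stage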

2. Hybrid Retrieval

class HybridRetrieval:
    """
    Hybrid retrieval: dense vector search + BM25 keyword search.
    """
    def __init__(self, vectorstore, bm25_index):
        self.vectorstore = vectorstore
        self.bm25_index = bm25_index

    @staticmethod
    def _normalize(results):
        """Scale scores to [0, 1] so the two retrievers are comparable."""
        if not results:
            return results
        max_score = max(score for _, score in results) or 1.0
        return [(doc, score / max_score) for doc, score in results]

    def retrieve(self, query: str, top_k: int = 5, vector_weight: float = 0.7):
        """
        Hybrid retrieval with weighted score fusion.
        """
        # Dense retrieval; note that some stores (e.g. Chroma) return
        # distances where lower is better, in which case flip the sign first
        vector_results = self._normalize(
            self.vectorstore.similarity_search_with_score(query, k=top_k * 2)
        )

        # BM25 keyword retrieval (assumed to return (doc, score) pairs)
        bm25_results = self._normalize(self.bm25_index.search(query, k=top_k * 2))

        # Weighted score fusion, keyed by document id
        fused_scores = {}
        id_to_doc = {}

        for doc, score in vector_results:
            doc_id = doc.metadata.get('id', id(doc))
            id_to_doc[doc_id] = doc
            fused_scores[doc_id] = fused_scores.get(doc_id, 0) + score * vector_weight

        for doc, score in bm25_results:
            doc_id = doc.metadata.get('id', id(doc))
            id_to_doc[doc_id] = doc
            fused_scores[doc_id] = fused_scores.get(doc_id, 0) + score * (1 - vector_weight)

        # Sort by fused score and map ids back to documents
        sorted_ids = sorted(fused_scores, key=fused_scores.get, reverse=True)[:top_k]
        return [id_to_doc[doc_id] for doc_id in sorted_ids]
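
The bm25_index above is assumed to expose a search(query, k) method returning (doc, score) pairs. One minimal way to build such an adapter is with the rank_bm25 package; whitespace tokenization is a simplification, and Chinese text would need a proper tokenizer such as jieba:

from typing import List

from rank_bm25 import BM25Okapi

class BM25Index:
    """Minimal BM25 adapter matching the interface HybridRetrieval expects."""

    def __init__(self, documents: List[Document]):
        self.documents = documents
        corpus = [doc.page_content.split() for doc in documents]  # naive tokenization
        self.bm25 = BM25Okapi(corpus)

    def search(self, query: str, k: int = 5):
        scores = self.bm25.get_scores(query.split())
        top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        return [(self.documents[i], float(scores[i])) for i in top]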

Conclusion

RAG is an effective remedy for the knowledge limitations of large language models. By combining vector retrieval with LLM generation, a RAG system can produce answers that are both up to date and grounded in source documents. This article covered RAG's core principles, vector database options, a LangChain implementation, and several advanced optimization techniques. In practice, the retrieval strategy and model configuration should be tuned to the specific scenario to achieve the best question-answering quality.
