1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
| from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType
class MilvusVectorStore:
    """Thin wrapper around a Milvus collection holding document embeddings.

    Each row stores an auto-generated INT64 primary key, a float vector,
    the document text (``content``) and the document metadata serialised
    as a JSON string (``metadata``).

    NOTE(review): ``search`` instantiates ``Document`` and reads
    ``page_content`` / ``metadata`` from the documents passed to
    ``insert``; ``Document`` is not imported in the visible part of this
    file, so it must be provided by the surrounding module -- confirm.
    """

    # Annotations below are written as strings (forward references) so that
    # defining the class does not require `List`, `np` or `Document` to be
    # bound at import time.

    def __init__(self, host="localhost", port="19530",
                 collection_name="documents", dim=1536):
        """Connect to Milvus and make sure the collection is ready.

        Args:
            host: Milvus server host.
            port: Milvus server port.
            collection_name: Name of the collection to create or reuse.
            dim: Dimensionality of the stored embedding vectors.
        """
        connections.connect(host=host, port=port)
        self.collection_name = collection_name
        self.dim = dim
        self._create_collection()

    def _create_collection(self):
        """Create (or open) the collection, ensure its index, and load it."""
        fields = [
            FieldSchema(name="id", dtype=DataType.INT64,
                        is_primary=True, auto_id=True),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR,
                        dim=self.dim),
            FieldSchema(name="content", dtype=DataType.VARCHAR,
                        max_length=65535),
            FieldSchema(name="metadata", dtype=DataType.VARCHAR,
                        max_length=65535),
        ]
        schema = CollectionSchema(fields=fields,
                                  description="Document collection")
        self.collection = Collection(name=self.collection_name, schema=schema)

        # Only build the index when the collection does not have one yet, so
        # re-opening an existing collection does not re-issue index creation.
        if not self.collection.has_index():
            index_params = {
                "metric_type": "IP",
                "index_type": "IVF_FLAT",
                "params": {"nlist": 128},
            }
            self.collection.create_index(field_name="embedding",
                                         index_params=index_params)

        # Milvus only serves searches on collections loaded into memory.
        self.collection.load()

    @staticmethod
    def _to_list(vector):
        """Return *vector* as a plain Python list (ndarray or any iterable)."""
        return vector.tolist() if hasattr(vector, "tolist") else list(vector)

    def insert(self, embeddings: "List[np.ndarray]",
               documents: "List[Document]"):
        """Insert documents together with their embedding vectors.

        Args:
            embeddings: One vector per document (a list of arrays/lists, or
                a single 2-D array whose rows are the vectors).
            documents: Objects exposing ``page_content`` and ``metadata``.

        Raises:
            ValueError: If the number of embeddings and documents differ.
        """
        import json

        # Convert vector by vector: a Python list of arrays has no
        # ``tolist`` method of its own.
        vectors = [self._to_list(vec) for vec in embeddings]
        if len(vectors) != len(documents):
            raise ValueError(
                f"embeddings and documents differ in length: "
                f"{len(vectors)} != {len(documents)}"
            )
        if not vectors:
            # Nothing to write; skip the round trip to the server.
            return

        # Column-ordered data matching the non-auto-id schema fields:
        # embedding, content, metadata.
        entities = [
            vectors,
            [doc.page_content for doc in documents],
            [json.dumps(doc.metadata) for doc in documents],
        ]
        self.collection.insert(entities)
        self.collection.flush()

    def search(self, query_embedding: "np.ndarray",
               top_k: int = 5) -> "List[Document]":
        """Return the ``top_k`` stored documents closest to the query vector.

        Args:
            query_embedding: Query vector (array or plain sequence).
            top_k: Maximum number of hits to return.

        Returns:
            A list of ``Document`` objects; each one's metadata additionally
            carries the hit score under the ``"score"`` key.
        """
        import json

        # Same metric as the index built in ``_create_collection``.
        search_params = {"metric_type": "IP", "params": {"nprobe": 10}}
        results = self.collection.search(
            data=[self._to_list(query_embedding)],
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            output_fields=["content", "metadata"],
        )

        docs = []
        for hits in results:
            for hit in hits:
                raw_metadata = hit.entity.get("metadata")
                # Tolerate rows whose metadata column is empty or missing.
                metadata = json.loads(raw_metadata) if raw_metadata else {}
                metadata["score"] = hit.score
                docs.append(
                    Document(
                        page_content=hit.entity.get("content"),
                        metadata=metadata,
                    )
                )
        return docs
|