1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
import json
from typing import List

import numpy as np
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
class MilvusDB:
    """Thin wrapper around a single Milvus collection of text embeddings.

    The wrapper connects to the server on construction and afterwards
    operates on the collection bound by :meth:`create_collection`.
    """

    # Metric used when searching; overwritten by create_collection() so that
    # search parameters always agree with the metric the index was built with.
    _metric_type = "IP"

    def __init__(self, host="localhost", port="19530"):
        """Connect to the Milvus server at ``host``:``port``.

        No collection is bound until :meth:`create_collection` is called.
        """
        connections.connect(host=host, port=port)
        self.collection = None

    def _require_collection(self):
        """Return the bound collection or raise ``RuntimeError`` if none is set."""
        if self.collection is None:
            raise RuntimeError(
                "No collection is bound; call create_collection() first."
            )
        return self.collection

    def create_collection(self, name: str, dim: int, metric_type="IP", nlist: int = 128):
        """Create (or recreate) a collection, index it and load it.

        WARNING: if a collection called ``name`` already exists it is dropped
        first, so all data previously stored in it is lost.

        Args:
            name: Collection name.
            dim: Dimensionality of the ``embedding`` vector field.
            metric_type: Metric for the index; also reused by :meth:`search`.
            nlist: ``nlist`` parameter of the IVF_FLAT index.
        """
        if utility.has_collection(name):
            utility.drop_collection(name)

        fields = [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=65535),
        ]
        schema = CollectionSchema(fields=fields, description=f"Collection: {name}")
        self.collection = Collection(name=name, schema=schema)
        # Remember the metric so search() does not fall back to a mismatching one.
        self._metric_type = metric_type

        index_params = {
            "metric_type": metric_type,
            "index_type": "IVF_FLAT",
            "params": {"nlist": nlist},
        }
        self.collection.create_index(
            field_name="embedding",
            index_params=index_params,
        )
        self.collection.load()

    def insert(self, embeddings: np.ndarray, texts: List[str], metadata: List[dict]):
        """Insert rows into the bound collection and flush them.

        Args:
            embeddings: 2-D array-like with one vector per row.
            texts: One text per row.
            metadata: One dict per row; stored as a JSON string.

        Raises:
            RuntimeError: If no collection is bound.
            ValueError: If ``embeddings`` is not 2-D or the three inputs do
                not contain the same number of rows.
        """
        collection = self._require_collection()

        vectors = np.asarray(embeddings)
        if vectors.size == 0 and len(texts) == 0 and len(metadata) == 0:
            # Nothing to insert; skip the round trip to the server.
            return
        if vectors.ndim != 2:
            raise ValueError(
                f"embeddings must be 2-D (rows x dim), got {vectors.ndim}-D input"
            )
        if not (len(vectors) == len(texts) == len(metadata)):
            raise ValueError(
                "embeddings, texts and metadata must have the same length, got "
                f"{len(vectors)}, {len(texts)} and {len(metadata)}"
            )

        # Column-ordered payload; the auto-generated "id" field is omitted.
        entities = [
            vectors.tolist(),
            list(texts),
            [json.dumps(m) for m in metadata],
        ]
        collection.insert(entities)
        collection.flush()

    def search(self, query_embedding: np.ndarray, top_k: int = 5, nprobe: int = 10) -> List[dict]:
        """Run a similarity search for a single query vector.

        Args:
            query_embedding: The query vector; a ``(dim,)`` or ``(1, dim)``
                array-like is accepted and flattened.
            top_k: Maximum number of hits to return.
            nprobe: ``nprobe`` search parameter.

        Returns:
            One dict per hit with the keys ``id``, ``distance``, ``text`` and
            ``metadata`` (the stored JSON string decoded again).

        Raises:
            RuntimeError: If no collection is bound.
        """
        collection = self._require_collection()

        # Flatten so a (1, dim) input does not turn into a 3-D payload.
        query_vector = np.asarray(query_embedding).reshape(-1).tolist()
        search_params = {
            "metric_type": self._metric_type,
            "params": {"nprobe": nprobe},
        }
        results = collection.search(
            data=[query_vector],
            anns_field="embedding",
            param=search_params,
            limit=top_k,
            output_fields=["text", "metadata"],
        )
        # Exactly one query was sent, so only the first result set is relevant.
        return [
            {
                "id": hit.id,
                "distance": hit.distance,
                "text": hit.entity.get("text"),
                "metadata": json.loads(hit.entity.get("metadata")),
            }
            for hit in results[0]
        ]

    def delete_by_ids(self, ids: List[int]):
        """Delete the rows whose primary key is in ``ids``.

        An empty ``ids`` is a no-op.

        Raises:
            RuntimeError: If no collection is bound.
        """
        # Coerce to plain ints so tuples and numpy integers still render as a
        # valid "[1, 2, 3]" list inside the filter expression.
        id_list = [int(i) for i in ids]
        if not id_list:
            return
        self._require_collection().delete(f"id in {id_list}")

    def drop(self):
        """Drop the bound collection, if any, and forget the handle."""
        if self.collection is not None:
            self.collection.drop()
            # Avoid reusing a handle to a collection that no longer exists.
            self.collection = None
|