import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple

import faiss
import numpy as np


class FaissVectorStore:
    """Persistent vector store backed by a flat L2 FAISS index.

    Vectors live in a FAISS index stored at ``<index_path>.index``; the
    metadata dict of each vector lives in ``<index_path>_metadata.json``,
    keyed by the vector's position in the index rendered as a string.
    """

    def __init__(self, index_path: str = "faiss_index", dimension: int = 3584):
        """Initialise the store, loading an existing index when one is on disk.

        Args:
            index_path: Path prefix for the index and metadata files.
            dimension: Vector dimensionality used when a new index is created.
        """
        self.index_path = index_path
        self.dimension = dimension
        self.index: Optional[Any] = None
        self.metadata: Dict[str, Dict[str, Any]] = {}
        self.metadata_path = f"{index_path}_metadata.json"

        # Load the existing index or create a new one.
        self._load_or_create_index()

    def _load_or_create_index(self) -> None:
        """Load the index (and metadata) from disk, or create an empty index."""
        index_file = f"{self.index_path}.index"
        if os.path.exists(index_file):
            logging.info(f"加载现有索引: {self.index_path}")
            self.index = faiss.read_index(index_file)
            self._load_metadata()
        else:
            logging.info(f"创建新索引,维度: {self.dimension}")
            self.index = faiss.IndexFlatL2(self.dimension)  # L2 distance

    def _load_metadata(self) -> None:
        """Load the metadata JSON file if it exists; otherwise keep current state."""
        if os.path.exists(self.metadata_path):
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                self.metadata = json.load(f)

    def _save_metadata(self) -> None:
        """Write the metadata mapping to its JSON file."""
        with open(self.metadata_path, 'w', encoding='utf-8') as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)

    def save_index(self) -> None:
        """Persist the index and its metadata; does nothing if no index exists."""
        if self.index is not None:
            faiss.write_index(self.index, f"{self.index_path}.index")
            self._save_metadata()
            logging.info(f"索引已保存到 {self.index_path}.index")

    def add_vectors(
        self,
        vectors: np.ndarray,
        metadatas: List[Dict[str, Any]]
    ) -> List[str]:
        """Add vectors with their metadata and persist the store.

        Args:
            vectors: Array of shape ``(n, dimension)``. A single 1-D vector
                is accepted and treated as a batch of one.
            metadatas: One metadata dict per vector, in the same order.

        Returns:
            The string IDs assigned to the added vectors (empty for an
            empty batch, in which case nothing is written).

        Raises:
            ValueError: If the number of vectors and metadatas differ.
        """
        # FAISS requires C-contiguous float32 input.
        matrix = np.ascontiguousarray(vectors, dtype=np.float32)
        if matrix.ndim == 1 and matrix.size > 0:
            matrix = matrix.reshape(1, -1)

        if len(matrix) != len(metadatas):
            raise ValueError("vectors和metadatas长度必须相同")
        if len(matrix) == 0:
            return []

        # IDs must mirror the positions FAISS assigns, which continue from
        # ntotal. len(self.metadata) would be wrong after delete_vectors()
        # has dropped entries: new IDs would collide with or mislabel
        # vectors already in the index.
        start_id = int(self.index.ntotal)
        self.index.add(matrix)

        new_ids = [str(i) for i in range(start_id, start_id + len(matrix))]
        self.metadata.update(zip(new_ids, metadatas))

        # Persist both the index and the metadata.
        self.save_index()

        return new_ids

    def search(
        self,
        query_vector: np.ndarray,
        k: int = 5
    ) -> Tuple[List[Dict[str, Any]], List[float]]:
        """Return the nearest stored vectors to ``query_vector``.

        Only the first query row is used. Hits without metadata (logically
        deleted vectors) and FAISS padding entries are skipped, so fewer
        than ``k`` results may be returned.

        Args:
            query_vector: Query of shape ``(dimension,)`` or ``(1, dimension)``.
            k: Maximum number of neighbours to retrieve.

        Returns:
            ``(results, distances)``: ``results`` holds a copy of each hit's
            metadata with an added ``'distance'`` key, and ``distances[i]``
            is the distance of ``results[i]``.
        """
        if self.index is None or k <= 0 or self.index.ntotal == 0:
            return [], []

        # FAISS expects a 2-D, C-contiguous float32 array.
        query = np.ascontiguousarray(query_vector, dtype=np.float32)
        if query.ndim == 1:
            query = query.reshape(1, -1)

        distances, indices = self.index.search(query, k)

        results: List[Dict[str, Any]] = []
        hit_distances: List[float] = []
        for distance, idx in zip(distances[0], indices[0]):
            if idx < 0:  # FAISS pads with -1 when fewer than k hits exist
                continue
            entry = self.metadata.get(str(int(idx)))
            if entry is None:  # logically deleted vector
                continue
            hit = entry.copy()
            hit['distance'] = float(distance)
            results.append(hit)
            # Keep distances parallel to results; returning the raw FAISS
            # row would misalign the two lists whenever a hit is skipped.
            hit_distances.append(float(distance))

        return results, hit_distances

    def get_vector_count(self) -> int:
        """Return the number of vectors in the index (0 if there is no index).

        Logically deleted vectors are still counted because they remain in
        the FAISS index.
        """
        return self.index.ntotal if self.index is not None else 0

    def delete_vectors(self, vector_ids: List[str]) -> bool:
        """Logically delete vectors by removing their metadata.

        The vectors themselves stay in the FAISS index; they are filtered
        out of search results because their metadata is missing.

        Args:
            vector_ids: IDs of the vectors to delete; unknown IDs are ignored.

        Returns:
            True if at least one ID was removed, False otherwise.
        """
        deleted_count = 0
        for vector_id in vector_ids:
            if self.metadata.pop(vector_id, None) is not None:
                deleted_count += 1

        if deleted_count > 0:
            self._save_metadata()
            logging.warning("FAISS不支持直接删除向量,已从元数据中移除,但索引中仍保留")

        return deleted_count > 0