mmeb/faiss_vector_store.py
2025-09-22 10:13:11 +00:00

148 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import numpy as np
import faiss
from typing import List, Dict, Any, Optional, Tuple
import logging
class FaissVectorStore:
    """Persistent vector store backed by a FAISS index plus a JSON metadata sidecar.

    The vectors live in ``{index_path}.index``. The per-vector metadata lives in
    ``{index_path}_metadata.json``, keyed by ``str(vector_id)``. A vector's ID is
    its insertion position in the FAISS index, so IDs are assigned sequentially
    and are never reused.

    Deletion is logical only: ``delete_vectors`` removes the metadata entry. The
    vector stays in the index and is skipped by ``search``.
    """

    def __init__(self, index_path: str = "faiss_index", dimension: int = 3584):
        """Initialize the store, loading an existing index from disk if present.

        Args:
            index_path: Path prefix for the index file (``.index`` is appended)
                and the metadata file (``_metadata.json`` is appended).
            dimension: Vector dimensionality used when creating a new index.
                When an existing index is loaded, the index's own
                dimensionality takes precedence.
        """
        self.index_path = index_path
        self.dimension = dimension
        self.index = None
        # Maps str(vector_id) -> metadata dict for every live (non-deleted) vector.
        self.metadata: Dict[str, Dict[str, Any]] = {}
        self.metadata_path = f"{index_path}_metadata.json"
        # Load the existing index or create a new one.
        self._load_or_create_index()

    def _load_or_create_index(self) -> None:
        """Load ``{index_path}.index`` (and its metadata) if it exists, else create an empty index."""
        index_file = f"{self.index_path}.index"
        if os.path.exists(index_file):
            logging.info(f"加载现有索引: {self.index_path}")
            self.index = faiss.read_index(index_file)
            # The stored index fixes the dimensionality. Keep self.dimension in
            # sync so the input validation in add_vectors/search is correct.
            if self.index.d != self.dimension:
                logging.warning(
                    f"索引维度 {self.index.d} 与指定维度 {self.dimension} 不一致, 以索引文件为准"
                )
                self.dimension = self.index.d
            self._load_metadata()
        else:
            logging.info(f"创建新索引,维度: {self.dimension}")
            self.index = faiss.IndexFlatL2(self.dimension)  # exact search, L2 distance

    def _load_metadata(self) -> None:
        """Load the metadata JSON file into ``self.metadata`` if it exists."""
        if os.path.exists(self.metadata_path):
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                self.metadata = json.load(f)

    def _save_metadata(self) -> None:
        """Write ``self.metadata`` to the metadata JSON file atomically."""
        # Write to a temp file and then rename over the target, so that a crash
        # mid-write cannot leave a truncated or corrupt metadata file.
        tmp_path = f"{self.metadata_path}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.metadata_path)

    def save_index(self) -> None:
        """Persist the FAISS index (atomically, via a temp file) and the metadata."""
        if self.index is not None:
            index_file = f"{self.index_path}.index"
            tmp_file = f"{index_file}.tmp"
            faiss.write_index(self.index, tmp_file)
            os.replace(tmp_file, index_file)
        self._save_metadata()
        logging.info(f"索引已保存到 {self.index_path}.index")

    def add_vectors(
        self,
        vectors: np.ndarray,
        metadatas: List[Dict[str, Any]]
    ) -> List[str]:
        """Add vectors with their metadata, then persist the store.

        Args:
            vectors: Array of shape ``(n, dimension)``. A single 1-D vector of
                length ``dimension`` is also accepted. The input is converted
                to C-contiguous float32, which FAISS requires.
            metadatas: One metadata dict per vector, in the same order.

        Returns:
            The string IDs assigned to the added vectors (empty for empty input).

        Raises:
            ValueError: If the numbers of vectors and metadatas differ, or if
                the vector width does not match ``self.dimension``.
        """
        vectors = np.ascontiguousarray(vectors, dtype=np.float32)
        if vectors.ndim == 1 and vectors.size:
            # A single vector: promote it to a batch of one.
            vectors = vectors.reshape(1, -1)
        if len(vectors) != len(metadatas):
            raise ValueError("vectors和metadatas长度必须相同")
        if len(vectors) == 0:
            return []
        if vectors.ndim != 2 or vectors.shape[1] != self.dimension:
            raise ValueError(f"向量维度必须为 {self.dimension}, 实际形状: {vectors.shape}")

        # The index assigns IDs sequentially starting at ntotal (true for the
        # IndexFlatL2 created here). Using len(self.metadata) instead would
        # reuse IDs after a logical delete and overwrite live metadata.
        start_id = self.index.ntotal
        ids = [str(i) for i in range(start_id, start_id + len(vectors))]

        self.index.add(vectors)
        for vector_id, meta in zip(ids, metadatas):
            self.metadata[vector_id] = meta

        self.save_index()
        return ids

    def search(
        self,
        query_vector: np.ndarray,
        k: int = 5
    ) -> Tuple[List[Dict[str, Any]], List[float]]:
        """Return up to ``k`` live vectors nearest to ``query_vector``.

        Args:
            query_vector: A 1-D vector of length ``dimension``, or a 2-D array.
                Only the first row of a 2-D array is searched.
            k: Maximum number of results to return.

        Returns:
            ``(results, distances)``, two lists of equal length.
            ``results[i]`` is a shallow copy of the vector's metadata with an
            added ``'distance'`` key. ``distances[i]`` is that same distance,
            as reported by the index (squared L2 for IndexFlatL2).

        Raises:
            ValueError: If the query width does not match ``self.dimension``.
        """
        if self.index is None or k <= 0 or self.index.ntotal == 0:
            return [], []

        query = np.ascontiguousarray(query_vector, dtype=np.float32)
        if query.ndim == 1:
            # Make sure the input is a 2-D (batch) array.
            query = query.reshape(1, -1)
        if query.ndim != 2 or query.shape[1] != self.dimension:
            raise ValueError(f"查询向量维度必须为 {self.dimension}, 实际形状: {query.shape}")

        ntotal = self.index.ntotal
        # Logically deleted vectors still occupy slots in the index. Fetch that
        # many extra neighbours so up to k live results can still be returned.
        num_deleted = max(0, ntotal - len(self.metadata))
        fetch_k = min(ntotal, k + num_deleted)
        distances, indices = self.index.search(query[:1], fetch_k)

        results: List[Dict[str, Any]] = []
        result_distances: List[float] = []
        for idx, dist in zip(indices[0], distances[0]):
            if idx < 0:  # FAISS pads with -1 when there are not enough hits
                continue
            vector_id = str(int(idx))
            if vector_id not in self.metadata:  # logically deleted
                continue
            result = self.metadata[vector_id].copy()
            result['distance'] = float(dist)
            results.append(result)
            result_distances.append(float(dist))
            if len(results) == k:
                break
        return results, result_distances

    def get_vector_count(self) -> int:
        """Return the number of vectors in the index, including logically deleted ones."""
        return self.index.ntotal if self.index is not None else 0

    def delete_vectors(self, vector_ids: List[str]) -> bool:
        """Logically delete vectors by ID.

        Only the metadata entries are removed, and the metadata file is
        rewritten. The vectors stay in the FAISS index, which keeps every
        remaining vector ID equal to its index position. ``search`` skips
        vectors without metadata.

        Args:
            vector_ids: IDs to delete. Integers are also accepted; each ID is
                normalized with ``str()``. Unknown IDs are ignored.

        Returns:
            True if at least one vector was deleted.
        """
        deleted_count = 0
        for vector_id in vector_ids:
            key = str(vector_id)
            if key in self.metadata:
                del self.metadata[key]
                deleted_count += 1
        if deleted_count > 0:
            self._save_metadata()
            logging.warning("FAISS不支持直接删除向量已从元数据中移除但索引中仍保留")
        return deleted_count > 0