148 lines
4.6 KiB
Python
148 lines
4.6 KiB
Python
import os
|
||
import json
|
||
import numpy as np
|
||
import faiss
|
||
from typing import List, Dict, Any, Optional, Tuple
|
||
import logging
|
||
|
||
class FaissVectorStore:
    """Persisted FAISS vector index paired with a JSON metadata side-car.

    Vectors live in a flat L2 FAISS index stored at ``<index_path>.index``.
    Per-vector metadata lives in ``<index_path>_metadata.json`` and is keyed
    by the vector's position in the index, rendered as a string.
    """

    def __init__(self, index_path: str = "faiss_index", dimension: int = 3584):
        """
        Initialise the store, loading an existing index when one is on disk.

        Args:
            index_path: Path prefix for the index and metadata files.
            dimension: Dimensionality of the stored vectors.
        """
        self.index_path = index_path
        self.dimension = dimension
        self.index = None
        self.metadata = {}
        self.metadata_path = f"{index_path}_metadata.json"

        # Load the existing index or create a new one.
        self._load_or_create_index()

    def _load_or_create_index(self):
        """Load the index from disk if present, otherwise create an empty one."""
        if os.path.exists(f"{self.index_path}.index"):
            logging.info(f"加载现有索引: {self.index_path}")
            self.index = faiss.read_index(f"{self.index_path}.index")
            # The on-disk index is authoritative for the vector dimension;
            # keep self.dimension in sync so later validation is meaningful.
            loaded_dimension = getattr(self.index, "d", self.dimension)
            if loaded_dimension != self.dimension:
                logging.warning(
                    "Loaded index dimension %s differs from requested %s; "
                    "using the loaded value",
                    loaded_dimension,
                    self.dimension,
                )
                self.dimension = loaded_dimension
            self._load_metadata()
        else:
            logging.info(f"创建新索引,维度: {self.dimension}")
            self.index = faiss.IndexFlatL2(self.dimension)  # L2 distance

    def _load_metadata(self):
        """Read the metadata JSON file into ``self.metadata`` if it exists."""
        if os.path.exists(self.metadata_path):
            with open(self.metadata_path, 'r', encoding='utf-8') as f:
                self.metadata = json.load(f)

    def _save_metadata(self):
        """Write ``self.metadata`` to the metadata JSON file."""
        # Write to a temporary sibling and swap it in, so an interrupted
        # write cannot leave a truncated metadata file behind.
        tmp_path = f"{self.metadata_path}.tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.metadata_path)

    def save_index(self):
        """Persist the index (when one exists) together with its metadata."""
        if self.index is not None:
            faiss.write_index(self.index, f"{self.index_path}.index")
            self._save_metadata()
            logging.info(f"索引已保存到 {self.index_path}.index")

    def add_vectors(
        self,
        vectors: np.ndarray,
        metadatas: List[Dict[str, Any]]
    ) -> List[str]:
        """
        Add vectors and their metadata, then persist the store.

        Args:
            vectors: 2-D array of shape ``(n, dimension)``.
            metadatas: One metadata dict per vector, in the same order.

        Returns:
            The string IDs assigned to the added vectors.

        Raises:
            ValueError: If the lengths differ or the array shape does not
                match the index dimension.
        """
        if len(vectors) != len(metadatas):
            raise ValueError("vectors和metadatas长度必须相同")
        if len(metadatas) == 0:
            return []

        # FAISS expects a C-contiguous float32 matrix.
        matrix = np.ascontiguousarray(vectors, dtype=np.float32)
        if matrix.ndim != 2 or matrix.shape[1] != self.dimension:
            raise ValueError(
                f"expected vectors of shape (n, {self.dimension}), "
                f"got {matrix.shape}"
            )

        # IDs must mirror the positions FAISS assigns, which continue from
        # ntotal. Counting metadata entries instead would reuse IDs after a
        # logical delete and attach metadata to the wrong vectors.
        start_id = int(self.index.ntotal)
        ids = [str(start_id + offset) for offset in range(len(metadatas))]

        self.index.add(matrix)

        for vector_id, meta in zip(ids, metadatas):
            self.metadata[vector_id] = meta

        self.save_index()

        return ids

    def search(
        self,
        query_vector: np.ndarray,
        k: int = 5
    ) -> Tuple[List[Dict[str, Any]], List[float]]:
        """
        Run a nearest-neighbour search for a single query vector.

        Only the first row is used when a 2-D array is passed.

        Args:
            query_vector: Query vector (1-D) or a 2-D array with the query
                in its first row.
            k: Maximum number of neighbours to request.

        Returns:
            ``(results, distances)``. Each result is a shallow copy of the
            stored metadata with a ``'distance'`` key added, and
            ``distances[i]`` belongs to ``results[i]``. Hits without
            metadata (padding or logically deleted vectors) are omitted
            from both lists.
        """
        if self.index is None or k <= 0 or self.index.ntotal == 0:
            return [], []

        query = np.ascontiguousarray(query_vector, dtype=np.float32)
        # FAISS requires a 2-D batch of queries.
        if query.ndim == 1:
            query = query.reshape(1, -1)

        distances, indices = self.index.search(query, k)

        results = []
        kept_distances = []
        for distance, idx in zip(distances[0], indices[0]):
            if idx < 0:  # FAISS pads missing neighbours with -1
                continue

            entry = self.metadata.get(str(int(idx)))
            if entry is None:  # logically deleted, or never had metadata
                continue

            hit = entry.copy()
            hit['distance'] = float(distance)
            results.append(hit)
            kept_distances.append(float(distance))

        return results, kept_distances

    def get_vector_count(self) -> int:
        """Return the number of vectors held by the index (0 if no index)."""
        return self.index.ntotal if self.index is not None else 0

    def delete_vectors(self, vector_ids: List[str]) -> bool:
        """
        Logically delete vectors by removing their metadata entries.

        The vectors themselves stay in the FAISS index; ``search`` skips
        hits that no longer have metadata.

        Args:
            vector_ids: IDs to delete. Non-string IDs are converted with
                ``str()``.

        Returns:
            True if at least one ID was found and removed, False otherwise.
        """
        deleted_count = 0
        for vector_id in vector_ids:
            key = str(vector_id)
            if key in self.metadata:
                del self.metadata[key]
                deleted_count += 1

        if deleted_count > 0:
            self._save_metadata()
            logging.warning("FAISS不支持直接删除向量,已从元数据中移除,但索引中仍保留")

        return deleted_count > 0
|