mmeb/test_faiss_with_proxy.py
2025-09-22 10:13:11 +00:00

165 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Simple test of the FAISS multimodal retrieval system, with proxy settings.
"""
import sys
import os
import logging

# Proxy settings -- adjust to your environment. setdefault() keeps any proxy
# the user already exported instead of silently overwriting it; the values
# below are only fallbacks.
os.environ.setdefault('HTTP_PROXY', 'http://127.0.0.1:7890')
os.environ.setdefault('HTTPS_PROXY', 'http://127.0.0.1:7890')

# Logging configuration.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Enable transformers offline mode to avoid downloading models (can be
# overridden by exporting TRANSFORMERS_OFFLINE=0). It is set before the
# sentence_transformers import below, presumably because the flag is read at
# import time.
# NOTE(review): offline mode and the proxy settings above work against each
# other -- with offline mode on, the proxy is not used for model downloads.
# Confirm which behaviour is intended.
os.environ.setdefault('TRANSFORMERS_OFFLINE', '1')

# Make modules that live next to this script importable.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Use a lightweight sentence-embedding model instead of a large multimodal model.
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
class SimpleFaissRetrieval:
    """Minimal FAISS retrieval system backed by sentence-transformers.

    Texts are embedded with a ``SentenceTransformer`` model and stored in an
    exact L2 index (``faiss.IndexFlatL2``). If the model cannot be loaded,
    random vectors are used instead so the pipeline can still be exercised
    end to end; search results are then meaningless.
    """

    def __init__(self, model_name="paraphrase-multilingual-MiniLM-L12-v2",
                 index_path="simple_faiss_index"):
        """Initialise the retrieval system.

        Args:
            model_name: Name or path of the (lightweight) sentence-transformers
                model to load.
            index_path: Index file path. It is stored on the instance but not
                read or written by any method of this class.
        """
        self.model_name = model_name
        self.index_path = index_path
        logger.info(f"加载模型: {model_name}")
        try:
            # Loading can fail (e.g. model not cached while offline); in that
            # case fall back to random vectors rather than aborting.
            self.model = SentenceTransformer(model_name)
            self.dimension = self.model.get_sentence_embedding_dimension()
            logger.info(f"模型加载成功,向量维度: {self.dimension}")
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            logger.info("使用随机向量模拟...")
            self.model = None
            # Fallback dimension; presumably chosen to match the default
            # MiniLM model -- confirm if model_name is changed.
            self.dimension = 384
        # Exact (brute-force) L2 index. add_texts() relies on faiss numbering
        # vectors sequentially from 0 in insertion order.
        self.index = faiss.IndexFlatL2(self.dimension)
        # Maps str(faiss vector id) -> metadata dict for that vector.
        self.metadata = {}
        logger.info("检索系统初始化完成")

    def encode_text(self, text):
        """Encode a string or a list of strings into vectors.

        Args:
            text: A single string or a list of strings.

        Returns:
            A 1-D array of length ``self.dimension`` for a single string, or a
            2-D ``(len(text), self.dimension)`` array for a list. Without a
            loaded model the values are uniform random float32 numbers.
        """
        if self.model is None:
            # No model available: simulate embeddings with random vectors.
            if isinstance(text, list):
                return np.random.rand(len(text), self.dimension).astype('float32')
            return np.random.rand(self.dimension).astype('float32')
        return self.model.encode(text, convert_to_numpy=True)

    def add_texts(self, texts, metadatas=None):
        """Embed texts and add them to the index.

        Args:
            texts: List (or other iterable) of strings. A single string is
                treated as a one-element list.
            metadatas: Optional list of dicts, one per text. Each dict is
                merged into the stored metadata; its keys override the
                default "text" and "type" entries.

        Returns:
            List of string ids assigned to the added texts; empty when
            ``texts`` is empty or None.

        Raises:
            ValueError: If ``metadatas`` is given and its length differs from
                the number of texts.
        """
        if not texts:
            return []
        texts = [texts] if isinstance(texts, str) else list(texts)
        if metadatas is None:
            metadatas = [{} for _ in texts]
        elif len(metadatas) != len(texts):
            # Validate before touching the index so a bad call can never
            # leave vectors in the index without matching metadata.
            raise ValueError(
                f"metadatas length ({len(metadatas)}) does not match "
                f"texts length ({len(texts)})"
            )

        # faiss requires a C-contiguous float32 matrix.
        vectors = np.ascontiguousarray(self.encode_text(texts), dtype='float32')

        # faiss numbers new vectors sequentially starting at the current size.
        start_id = self.index.ntotal
        self.index.add(vectors)
        ids = [str(i) for i in range(start_id, start_id + len(texts))]

        for vector_id, text, extra in zip(ids, texts, metadatas):
            self.metadata[vector_id] = {"text": text, "type": "text", **extra}

        logger.info(f"添加了{len(ids)}条文本,当前索引大小: {self.index.ntotal}")
        return ids

    def search(self, query, k=5):
        """Return up to ``k`` stored entries most similar to ``query``.

        Args:
            query: Query string. If a list of strings is passed, only the
                results for the first query are returned.
            k: Maximum number of results.

        Returns:
            List of copies of the stored metadata dicts, best match first, each
            with an added ``'score'`` of ``1 / (1 + distance)``, where
            ``distance`` is the value reported by the index (squared Euclidean
            distance for ``IndexFlatL2``). Higher means more similar. Empty when
            the index is empty or ``k <= 0``.
        """
        # faiss rejects k <= 0, and an empty index has nothing to return.
        if k <= 0 or self.index.ntotal == 0:
            return []

        query_vector = np.ascontiguousarray(self.encode_text(query), dtype='float32')
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)

        distances, indices = self.index.search(query_vector, k)

        results = []
        for distance, idx in zip(distances[0], indices[0]):
            # faiss pads with -1 when fewer than k vectors are available.
            if idx < 0:
                continue
            entry = self.metadata.get(str(idx))
            if entry is None:
                continue
            result = entry.copy()
            result['score'] = float(1.0 / (1.0 + distance))
            results.append(result)
        return results
def test_simple_retrieval():
    """Smoke-test SimpleFaissRetrieval.

    Indexes a handful of sample sentences, then prints the top-2 hits (text
    and score) for each of several queries.
    """
    print("=== 测试简化版FAISS检索系统 ===")

    print("初始化检索系统...")
    retriever = SimpleFaissRetrieval()

    sample_texts = [
        "一只可爱的橘色猫咪在沙发上睡觉",
        "城市夜景中的高楼大厦和车流",
        "阳光明媚的海滩上,人们在冲浪和晒太阳",
        "美味的意大利面配红酒和沙拉",
        "雪山上滑雪的运动员",
    ]

    print("\n添加文本到检索系统...")
    added_ids = retriever.add_texts(sample_texts)
    print(f"添加了{len(added_ids)}条文本")

    print("\n测试文本搜索...")
    for query in ("一只猫在睡觉", "都市风光", "海边的景色"):
        print(f"\n查询: {query}")
        hits = retriever.search(query, k=2)
        for rank, hit in enumerate(hits, start=1):
            text = hit.get('text', 'N/A')
            score = hit.get('score', 0)
            print(f" 结果 {rank}: {text} (分数: {score:.4f})")

    print("\n测试完成!")


if __name__ == "__main__":
    test_simple_retrieval()