165 lines
5.3 KiB
Python
165 lines
5.3 KiB
Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Simple smoke test for the FAISS multimodal retrieval system, with proxy settings.
"""

import sys
import os
import logging

# Route HTTP(S) traffic through a local proxy; adjust the address to match
# your own environment.
os.environ.update({
    'HTTP_PROXY': 'http://127.0.0.1:7890',
    'HTTPS_PROXY': 'http://127.0.0.1:7890',
})

# Configure logging for the whole script.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Enable offline mode to avoid downloading models. This is set before the
# sentence_transformers import below.
os.environ['TRANSFORMERS_OFFLINE'] = '1'

# Make the directory containing this script importable.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Use a simple embedding model instead of a large multimodal model.
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

class SimpleFaissRetrieval:
    """Minimal FAISS text-retrieval helper built on sentence-transformers.

    Texts are embedded, stored in a flat L2 FAISS index, and described by a
    metadata dict keyed by the stringified position of each vector in the
    index. If the embedding model cannot be loaded, random vectors are used
    instead so the pipeline can still be exercised (search results are then
    meaningless).
    """

    def __init__(self, model_name="paraphrase-multilingual-MiniLM-L12-v2", index_path="simple_faiss_index"):
        """Load the embedding model and create an empty index.

        Args:
            model_name: Model name passed to ``SentenceTransformer``; intended
                to be a lightweight model.
            index_path: Index file path. It is only stored on the instance;
                no method of this class reads or writes that path.
        """
        self.model_name = model_name
        self.index_path = index_path

        logger.info(f"加载模型: {model_name}")
        try:
            self.model = SentenceTransformer(model_name)
            self.dimension = self.model.get_sentence_embedding_dimension()
            logger.info(f"模型加载成功,向量维度: {self.dimension}")
        except Exception as e:
            # Deliberate best-effort: any load failure (e.g. offline mode with
            # no cached model) switches to the random-vector fallback instead
            # of aborting the test script.
            logger.error(f"模型加载失败: {e}")
            logger.info("使用随机向量模拟...")
            self.model = None
            self.dimension = 384  # default dimension used by the fallback

        # Exact (brute-force) L2 index plus the per-vector metadata store.
        self.index = faiss.IndexFlatL2(self.dimension)
        self.metadata = {}

        logger.info("检索系统初始化完成")

    def encode_text(self, text):
        """Encode a string or a list of strings into embedding vector(s).

        Args:
            text: A single string or a list of strings.

        Returns:
            With a loaded model, the result of
            ``self.model.encode(text, convert_to_numpy=True)``. In fallback
            mode, a random float32 array of shape ``(len(text), dimension)``
            for a list and ``(dimension,)`` for anything else.
        """
        if self.model is not None:
            return self.model.encode(text, convert_to_numpy=True)

        # Model failed to load: produce random vectors of the right shape.
        if isinstance(text, list):
            shape = (len(text), self.dimension)
        else:
            shape = (self.dimension,)
        return np.random.rand(*shape).astype('float32')

    def add_texts(self, texts, metadatas=None):
        """Embed texts and add them to the index together with metadata.

        Args:
            texts: Sequence of strings to index. A single string is treated
                as a one-element list.
            metadatas: Optional sequence of dicts, one per text. Its entries
                are merged over the default ``"text"`` / ``"type"`` fields,
                so they may override them.

        Returns:
            The list of string IDs assigned to the new texts (empty when
            ``texts`` is empty).

        Raises:
            ValueError: If ``metadatas`` is given but its length differs from
                ``texts``. This is checked before the index is modified so
                the index and metadata cannot get out of step.
        """
        if not texts:
            return []

        # A bare string would otherwise be handled as a sequence of characters.
        if isinstance(texts, str):
            texts = [texts]

        if metadatas is None:
            metadatas = [{} for _ in texts]
        elif len(metadatas) != len(texts):
            raise ValueError(
                f"metadatas has {len(metadatas)} entries but texts has {len(texts)}"
            )

        # FAISS expects a C-contiguous float32 matrix.
        vectors = np.ascontiguousarray(self.encode_text(texts), dtype=np.float32)

        # IDs must equal the vectors' positions in the index, because search()
        # maps the positions FAISS returns back to metadata keys.
        first_id = self.index.ntotal
        self.index.add(vectors)

        new_ids = [str(first_id + offset) for offset in range(len(texts))]
        for vector_id, text, extra in zip(new_ids, texts, metadatas):
            self.metadata[vector_id] = {
                "text": text,
                "type": "text",
                **extra,
            }

        logger.info(f"添加了{len(new_ids)}条文本,当前索引大小: {self.index.ntotal}")
        return new_ids

    def search(self, query, k=5):
        """Return up to ``k`` indexed texts most similar to ``query``.

        Args:
            query: Query string. If a list is passed, only the results for
                its first element are returned.
            k: Maximum number of results; must be at least 1.

        Returns:
            A list of metadata dict copies, each with an added ``"score"``
            equal to ``1 / (1 + distance)``, in the order FAISS returned them.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError("k must be a positive integer")

        query_vector = np.ascontiguousarray(self.encode_text(query), dtype=np.float32)
        if query_vector.ndim == 1:
            # FAISS searches a batch of queries; wrap the single vector.
            query_vector = query_vector.reshape(1, -1)

        distances, indices = self.index.search(query_vector, k)

        results = []
        for distance, idx in zip(distances[0], indices[0]):
            # Negative indices mark empty result slots (fewer than k vectors).
            if idx < 0:
                continue
            entry = self.metadata.get(str(int(idx)))
            if entry is None:
                continue
            hit = entry.copy()
            # Map the distance (>= 0) to a similarity score in (0, 1].
            hit['score'] = float(1.0 / (1.0 + distance))
            results.append(hit)

        return results

def test_simple_retrieval():
    """Smoke-test SimpleFaissRetrieval end to end.

    Indexes five sample sentences, runs three queries with ``k=2`` and
    prints each hit's text and score. Nothing is asserted or returned.
    """
    print("=== 测试简化版FAISS检索系统 ===")

    # Build the retrieval system with its default model.
    print("初始化检索系统...")
    retrieval = SimpleFaissRetrieval()

    # Sample corpus.
    sample_texts = [
        "一只可爱的橘色猫咪在沙发上睡觉",
        "城市夜景中的高楼大厦和车流",
        "阳光明媚的海滩上,人们在冲浪和晒太阳",
        "美味的意大利面配红酒和沙拉",
        "雪山上滑雪的运动员",
    ]

    # Index the corpus.
    print("\n添加文本到检索系统...")
    added_ids = retrieval.add_texts(sample_texts)
    print(f"添加了{len(added_ids)}条文本")

    # Run the sample queries and show the top two hits for each.
    print("\n测试文本搜索...")
    for query in ("一只猫在睡觉", "都市风光", "海边的景色"):
        print(f"\n查询: {query}")
        hits = retrieval.search(query, k=2)
        for rank, hit in enumerate(hits, start=1):
            print(f" 结果 {rank}: {hit.get('text', 'N/A')} (分数: {hit.get('score', 0):.4f})")

    print("\n测试完成!")

# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_simple_retrieval()