#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Simple smoke test for a FAISS-based retrieval system, with proxy settings.

A lightweight sentence-transformers model is used to embed texts, and a flat
L2 FAISS index is used for nearest-neighbour search.  If the model cannot be
loaded, random vectors are used instead so the pipeline can still be exercised.
"""

import logging
import os
import sys

# Proxy settings -- adjust to the local environment.
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

# Logging setup.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Offline mode, to avoid downloading models.  This is set before the
# sentence-transformers import below, as in the original script.
os.environ['TRANSFORMERS_OFFLINE'] = '1'

# Make modules that live next to this script importable.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# A small text-embedding model is used in place of a large multimodal model.
from sentence_transformers import SentenceTransformer  # noqa: E402
import faiss  # noqa: E402
import numpy as np  # noqa: E402


class SimpleFaissRetrieval:
    """Simplified FAISS retrieval system backed by sentence-transformers."""

    # Vector size used when the embedding model cannot be loaded.
    _FALLBACK_DIMENSION = 384

    def __init__(self, model_name="paraphrase-multilingual-MiniLM-L12-v2",
                 index_path="simple_faiss_index"):
        """Load the embedding model and create an empty flat L2 index.

        Args:
            model_name: Name of the sentence-transformers model to load.
            index_path: Index file path.  It is only stored on the instance;
                nothing in this class reads or writes that path.
        """
        self.model_name = model_name
        self.index_path = index_path

        logger.info("加载模型: %s", model_name)
        try:
            self.model = SentenceTransformer(model_name)
            self.dimension = self.model.get_sentence_embedding_dimension()
            logger.info("模型加载成功,向量维度: %s", self.dimension)
        except Exception as e:
            # Deliberate best-effort fallback: keep the pipeline usable with
            # random vectors when the model is unavailable.
            logger.error("模型加载失败: %s", e)
            logger.info("使用随机向量模拟...")
            self.model = None
            self.dimension = self._FALLBACK_DIMENSION

        self.index = faiss.IndexFlatL2(self.dimension)
        # Maps the stringified FAISS row id to that row's metadata dict.
        self.metadata = {}
        logger.info("检索系统初始化完成")

    def encode_text(self, text):
        """Encode text into float vectors.

        Args:
            text: A single string, or a list/tuple of strings.

        Returns:
            A numpy array.  In fallback mode (no model) this is a random
            float32 array of shape ``(dimension,)`` for a single string or
            ``(len(text), dimension)`` for a list/tuple.  Otherwise it is
            whatever ``model.encode(text, convert_to_numpy=True)`` returns.
        """
        if self.model is not None:
            return self.model.encode(text, convert_to_numpy=True)

        # Fallback: random vectors.  Tuples are treated as a batch too, so a
        # batch never collapses into a single vector.
        if isinstance(text, (list, tuple)):
            shape = (len(text), self.dimension)
        else:
            shape = (self.dimension,)
        return np.random.rand(*shape).astype('float32')

    def add_texts(self, texts, metadatas=None):
        """Add texts to the index.

        Args:
            texts: List of strings to index.  A single string is treated as
                a one-element list.
            metadatas: Optional list of dicts, one per text.  Its keys are
                merged over the default ``"text"`` and ``"type"`` keys.
                Entries beyond ``len(texts)`` are ignored.

        Returns:
            The list of assigned ids, as strings (empty if ``texts`` is empty).

        Raises:
            ValueError: If ``metadatas`` has fewer entries than ``texts``.
        """
        if not texts:
            return []
        if isinstance(texts, str):
            # Otherwise the string would be indexed character by character.
            texts = [texts]
        else:
            texts = list(texts)

        if metadatas is None:
            metadatas = [{} for _ in texts]
        elif len(metadatas) < len(texts):
            raise ValueError(
                f"metadatas has {len(metadatas)} entries "
                f"but {len(texts)} texts were given"
            )

        # Build every metadata record before touching the index, so a bad
        # entry cannot leave the index and the metadata out of sync.
        records = [
            {"text": text, "type": "text", **extra}
            for text, extra in zip(texts, metadatas)
        ]

        vectors = np.ascontiguousarray(self.encode_text(texts), dtype='float32')

        # A flat index numbers its rows sequentially, so the new ids start at
        # the current row count.
        start_id = self.index.ntotal
        self.index.add(vectors)

        new_ids = [str(start_id + offset) for offset in range(len(texts))]
        self.metadata.update(zip(new_ids, records))

        logger.info(
            "添加了%s条文本,当前索引大小: %s", len(new_ids), self.index.ntotal
        )
        return new_ids

    def search(self, query, k=5):
        """Search for texts similar to ``query``.

        Args:
            query: Query string.
            k: Maximum number of results to return.

        Returns:
            A list of metadata dict copies, nearest first, each with an added
            ``"score"`` equal to ``1 / (1 + distance)``.  Returns an empty
            list if ``k <= 0`` or the index is empty.
        """
        total = self.index.ntotal
        if k <= 0 or total == 0:
            return []

        query_vector = np.asarray(self.encode_text(query))
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        query_vector = np.ascontiguousarray(query_vector, dtype='float32')

        # Never ask for more neighbours than exist; the surplus slots would
        # only come back as -1 placeholders.
        distances, indices = self.index.search(query_vector, min(k, total))

        results = []
        for distance, idx in zip(distances[0], indices[0]):
            if idx < 0:
                # Empty result slot.
                continue
            record = self.metadata.get(str(int(idx)))
            if record is None:
                continue
            hit = record.copy()
            hit['score'] = float(1.0 / (1.0 + distance))
            results.append(hit)
        return results


def test_simple_retrieval():
    """Exercise the simplified retrieval system and print the results."""
    print("=== 测试简化版FAISS检索系统 ===")

    print("初始化检索系统...")
    retrieval = SimpleFaissRetrieval()

    # Sample documents.
    texts = [
        "一只可爱的橘色猫咪在沙发上睡觉",
        "城市夜景中的高楼大厦和车流",
        "阳光明媚的海滩上,人们在冲浪和晒太阳",
        "美味的意大利面配红酒和沙拉",
        "雪山上滑雪的运动员",
    ]

    print("\n添加文本到检索系统...")
    text_ids = retrieval.add_texts(texts)
    print(f"添加了{len(text_ids)}条文本")

    print("\n测试文本搜索...")
    queries = ["一只猫在睡觉", "都市风光", "海边的景色"]
    for query in queries:
        print(f"\n查询: {query}")
        results = retrieval.search(query, k=2)
        for rank, result in enumerate(results, start=1):
            print(f" 结果 {rank}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")

    print("\n测试完成!")


if __name__ == "__main__":
    test_simple_retrieval()