#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Smoke test for the FAISS multimodal retrieval system using a local model.

Loads a locally downloaded embedding model (tokenizer, processor and weights)
in offline mode, then runs one text-encoding pass and one best-effort
image-encoding pass, logging the resulting embedding shapes.
"""

import os
import sys
import logging
from pathlib import Path
import numpy as np
import faiss
from typing import List, Dict, Any, Optional, Union
import json

# Configure logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Force offline mode. This is set before `transformers` is imported (the
# import happens lazily inside test_local_model).
os.environ['TRANSFORMERS_OFFLINE'] = '1'

# Default location of the downloaded model; override via the function
# argument or the first command-line argument.
DEFAULT_MODEL_PATH = "/root/models/Ops-MM-embedding-v1-7B"


def test_local_model(local_model_path: str = DEFAULT_MODEL_PATH) -> None:
    """Load the local model and run a text and an image encoding smoke test.

    The function never raises for load/inference problems; every outcome is
    reported through the module logger. Import errors for ``transformers``,
    ``torch`` or ``PIL`` are not caught and propagate to the caller.

    Args:
        local_model_path: Directory containing the downloaded model files.
            Has a default, so pytest does not treat it as a fixture request.

    Returns:
        None. Returns early (after logging) if the path does not exist, if
        loading fails, or if the text-encoding pass fails.
    """
    from transformers import AutoModel, AutoTokenizer, AutoProcessor
    import torch
    from PIL import Image

    if not os.path.exists(local_model_path):
        logger.error(f"模型路径不存在: {local_model_path}")
        logger.info("请先下载模型到指定路径")
        return

    logger.info(f"加载本地模型: {local_model_path}")
    use_cuda = torch.cuda.is_available()

    # Phase 1: loading. Only failures in this phase are reported as load
    # failures, so the "check your download" hint is not shown for
    # unrelated inference errors.
    try:
        # Load the tokenizer.
        logger.info("加载tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(local_model_path)

        # Load the processor.
        logger.info("加载processor...")
        processor = AutoProcessor.from_pretrained(local_model_path)

        # Load the model: half precision on GPU, full precision on CPU.
        logger.info("加载模型...")
        model = AutoModel.from_pretrained(
            local_model_path,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if torch.cuda.device_count() > 0 else None,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which the message alone loses.
        logger.exception("模型加载失败: %s", e)
        logger.error("请确保模型文件已正确下载")
        return

    logger.info("模型加载成功!")

    # Phase 2: text encoding.
    try:
        logger.info("测试文本编码...")
        text = "这是一个测试文本"
        inputs = tokenizer(text, return_tensors="pt")
        if use_cuda:
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            # Use the hidden state of the first token as the text embedding.
            text_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
    except Exception as e:
        logger.exception("文本编码测试失败: %s", e)
        return

    logger.info(f"文本编码维度: {text_embedding.shape}")

    # Phase 3: image encoding. Best-effort: the model may not expose image
    # support through this interface, so a failure here is logged without
    # aborting the run.
    try:
        logger.info("测试图像编码...")
        # Create a simple solid-colour test image.
        image = Image.new('RGB', (224, 224), color='red')
        image_inputs = processor(images=image, return_tensors="pt")
        if use_cuda:
            image_inputs = {k: v.to("cuda") for k, v in image_inputs.items()}

        with torch.no_grad():
            image_outputs = model.vision_model(**image_inputs)
            image_embedding = image_outputs.pooler_output.cpu().numpy()

        logger.info(f"图像编码维度: {image_embedding.shape}")
    except Exception as e:
        logger.error(f"图像编码测试失败: {str(e)}")

    logger.info("本地模型测试完成!")


if __name__ == "__main__":
    # Optional first CLI argument overrides the default model path;
    # with no arguments the behaviour matches the original script.
    test_local_model(*sys.argv[1:2])