mmeb/test_local_model.py
2025-09-22 10:13:11 +00:00

99 lines
3.2 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
使用本地模型的FAISS多模态检索系统测试
"""
import os
import sys
import logging
from pathlib import Path
import numpy as np
import faiss
from typing import List, Dict, Any, Optional, Union
import json
# Logging setup: INFO level, each record prefixed with timestamp, logger
# name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Enable offline mode. This is set at module import time, i.e. before
# test_local_model() performs its function-local `transformers` import.
os.environ["TRANSFORMERS_OFFLINE"] = "1"
def test_local_model():
    """Smoke-test a locally stored multimodal embedding model.

    The function checks that the hard-coded model directory exists, loads the
    tokenizer, processor and model from it, and then runs two independent
    probes: encoding a short text and encoding a synthetic red 224x224 image.
    The shape of each resulting embedding is logged.

    Failures while loading or encoding are logged together with their
    traceback and are not re-raised. A load failure aborts the run; a failure
    in one encoding probe does not prevent the other probe from running.
    Import errors for ``transformers``, ``torch`` or ``PIL`` are not caught.

    Returns:
        None. All results are reported through the module logger.
    """
    # Replace this with the directory the model was actually downloaded to.
    local_model_path = "/root/models/Ops-MM-embedding-v1-7B"

    # Check the path before the heavy imports below so that a missing model
    # directory is reported immediately.
    if not os.path.exists(local_model_path):
        logger.error(f"模型路径不存在: {local_model_path}")
        logger.info("请先下载模型到指定路径")
        return

    from transformers import AutoModel, AutoTokenizer, AutoProcessor
    import torch
    from PIL import Image

    # Decide once whether CUDA is used, so that dtype selection, device
    # placement and input transfer all agree with each other.
    use_cuda = torch.cuda.is_available()

    def _to_model_device(batch):
        """Move every tensor of *batch* to the GPU when CUDA is in use."""
        if use_cuda:
            return {key: tensor.to("cuda") for key, tensor in batch.items()}
        return batch

    # --- Loading -----------------------------------------------------------
    logger.info(f"加载本地模型: {local_model_path}")
    try:
        logger.info("加载tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(local_model_path)

        logger.info("加载processor...")
        processor = AutoProcessor.from_pretrained(local_model_path)

        logger.info("加载模型...")
        model = AutoModel.from_pretrained(
            local_model_path,
            # Half precision on GPU, full precision on CPU.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error message
        # would drop.
        logger.exception(f"模型加载失败: {e}")
        logger.error("请确保模型文件已正确下载")
        return
    logger.info("模型加载成功!")

    # --- Text encoding probe -------------------------------------------------
    try:
        logger.info("测试文本编码...")
        text = "这是一个测试文本"
        inputs = _to_model_device(tokenizer(text, return_tensors="pt"))
        with torch.no_grad():
            outputs = model(**inputs)
        # The hidden state at sequence position 0 is used as the embedding.
        # NOTE(review): whether position 0 is the appropriate pooling for
        # this checkpoint is not verified here - confirm.
        text_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        logger.info(f"文本编码维度: {text_embedding.shape}")
    except Exception as e:
        # Reported separately so an inference problem is not mislabelled as
        # a model loading failure.
        logger.exception(f"文本编码测试失败: {e}")

    # --- Image encoding probe ------------------------------------------------
    try:
        logger.info("测试图像编码...")
        # A plain red square is sufficient to exercise the image path.
        image = Image.new('RGB', (224, 224), color='red')
        image_inputs = _to_model_device(
            processor(images=image, return_tensors="pt")
        )
        with torch.no_grad():
            # NOTE(review): assumes the loaded model exposes `vision_model`
            # and that its output provides `pooler_output` - confirm for this
            # checkpoint. A mismatch is reported by the handler below.
            image_outputs = model.vision_model(**image_inputs)
        image_embedding = image_outputs.pooler_output.cpu().numpy()
        logger.info(f"图像编码维度: {image_embedding.shape}")
    except Exception as e:
        logger.exception(f"图像编码测试失败: {e}")

    logger.info("本地模型测试完成!")
# Script entry point: run the smoke test only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    test_local_model()