mmeb/quick_test.py
2025-09-01 11:24:01 +00:00

236 lines
7.0 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
快速测试脚本 - 验证多模态检索系统功能
"""
import os
import sys
import logging
import traceback
from pathlib import Path
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def test_imports():
"""测试关键模块导入"""
logger.info("🔍 测试模块导入...")
try:
import torch
logger.info(f"✅ PyTorch {torch.__version__}")
import transformers
logger.info(f"✅ Transformers {transformers.__version__}")
import numpy as np
logger.info(f"✅ NumPy {np.__version__}")
from PIL import Image
logger.info("✅ Pillow")
import flask
logger.info(f"✅ Flask {flask.__version__}")
try:
import pymochow
logger.info("✅ PyMochow (百度VDB SDK)")
except ImportError:
logger.warning("⚠️ PyMochow 未安装,需要运行: pip install pymochow")
try:
import pymongo
logger.info("✅ PyMongo")
except ImportError:
logger.warning("⚠️ PyMongo 未安装,需要运行: pip install pymongo")
return True
except Exception as e:
logger.error(f"❌ 模块导入失败: {str(e)}")
return False
def test_gpu_availability():
"""测试GPU可用性"""
logger.info("🖥️ 检查GPU环境...")
try:
import torch
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
logger.info(f"✅ 检测到 {gpu_count} 个GPU")
for i in range(gpu_count):
gpu_name = torch.cuda.get_device_name(i)
gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
logger.info(f" GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)")
return True
else:
logger.info(" 未检测到GPU将使用CPU")
return False
except Exception as e:
logger.error(f"❌ GPU检查失败: {str(e)}")
return False
def test_baidu_vdb_connection():
"""测试百度VDB连接"""
logger.info("🔗 测试百度VDB连接...")
try:
import pymochow
from pymochow.configuration import Configuration
from pymochow.auth.bce_credentials import BceCredentials
# 连接配置
account = "root"
api_key = "vdb$yjr9ln3n0td"
endpoint = "http://180.76.96.191:5287"
config = Configuration(
credentials=BceCredentials(account, api_key),
endpoint=endpoint
)
client = pymochow.MochowClient(config)
# 测试连接 - 列出数据库
databases = client.list_databases()
logger.info(f"✅ VDB连接成功发现 {len(databases)} 个数据库")
client.close()
return True
except ImportError:
logger.error("❌ PyMochow 未安装无法测试VDB连接")
return False
except Exception as e:
logger.error(f"❌ VDB连接失败: {str(e)}")
return False
def test_model_loading():
"""测试模型加载"""
logger.info("🤖 测试模型加载...")
try:
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
logger.info("正在初始化模型...")
model = OpsMMEmbeddingV1()
# 测试文本编码
test_texts = ["测试文本"]
embeddings = model.embed(texts=test_texts)
logger.info(f"✅ 模型加载成功,向量维度: {embeddings.shape}")
return True
except Exception as e:
logger.error(f"❌ 模型加载失败: {str(e)}")
logger.error(traceback.format_exc())
return False
def test_web_app_import():
"""测试Web应用导入"""
logger.info("🌐 测试Web应用模块...")
try:
# 测试导入主要模块
from multimodal_retrieval_vdb_only import MultimodalRetrievalVDBOnly
logger.info("✅ 多模态检索系统模块")
from baidu_vdb_production import BaiduVDBProduction
logger.info("✅ 百度VDB后端模块")
# 测试Web应用文件存在
web_app_file = Path("web_app_vdb_production.py")
if web_app_file.exists():
logger.info("✅ Web应用文件存在")
else:
logger.error("❌ Web应用文件不存在")
return False
return True
except Exception as e:
logger.error(f"❌ Web应用模块测试失败: {str(e)}")
return False
def create_test_directories():
"""创建必要的测试目录"""
logger.info("📁 创建测试目录...")
directories = ["uploads", "sample_images", "text_data"]
for dir_name in directories:
dir_path = Path(dir_name)
dir_path.mkdir(exist_ok=True)
logger.info(f"✅ 目录已创建: {dir_name}")
def main():
"""主测试函数"""
logger.info("🚀 开始快速测试...")
logger.info("=" * 50)
test_results = {}
# 1. 测试模块导入
test_results["imports"] = test_imports()
# 2. 测试GPU环境
test_results["gpu"] = test_gpu_availability()
# 3. 测试VDB连接
test_results["vdb"] = test_baidu_vdb_connection()
# 4. 测试Web应用模块
test_results["web_modules"] = test_web_app_import()
# 5. 创建测试目录
create_test_directories()
# 6. 尝试测试模型加载(可选)
if test_results["imports"]:
logger.info("\n⚠️ 模型加载测试需要较长时间,是否跳过?")
logger.info("如需测试模型,请单独运行模型测试")
# test_results["model"] = test_model_loading()
# 输出测试结果
logger.info("\n" + "=" * 50)
logger.info("📊 测试结果汇总:")
logger.info("=" * 50)
for test_name, result in test_results.items():
status = "✅ 通过" if result else "❌ 失败"
test_display = {
"imports": "模块导入",
"gpu": "GPU环境",
"vdb": "VDB连接",
"web_modules": "Web模块",
"model": "模型加载"
}.get(test_name, test_name)
logger.info(f"{test_display}: {status}")
# 计算成功率
success_count = sum(test_results.values())
total_count = len(test_results)
success_rate = (success_count / total_count) * 100
logger.info(f"\n总体成功率: {success_count}/{total_count} ({success_rate:.1f}%)")
if success_rate >= 75:
logger.info("🎉 系统基本就绪可以启动Web应用进行完整测试")
logger.info("运行命令: python web_app_vdb_production.py")
else:
logger.warning("⚠️ 系统存在问题,请检查失败的测试项")
return test_results
if __name__ == "__main__":
main()