355 lines
12 KiB
Python
355 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
测试优化后的系统功能
|
||
验证自动清理、内存处理和流式上传
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import time
|
||
import tempfile
|
||
import logging
|
||
import subprocess
|
||
import signal
|
||
from io import BytesIO
|
||
from PIL import Image
|
||
import requests
|
||
import json
|
||
|
||
# 设置日志
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
def start_web_server():
|
||
"""启动Web服务器"""
|
||
try:
|
||
logger.info("🚀 启动Web服务器...")
|
||
|
||
# 启动Web应用进程
|
||
process = subprocess.Popen([
|
||
sys.executable, 'web_app_vdb_production.py'
|
||
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||
|
||
logger.info(f"✅ Web服务器已启动,PID: {process.pid}")
|
||
|
||
# 等待服务器启动
|
||
for i in range(10):
|
||
try:
|
||
response = requests.get("http://127.0.0.1:5000/", timeout=2)
|
||
if response.status_code == 200:
|
||
logger.info("✅ Web服务器就绪")
|
||
return process
|
||
except:
|
||
time.sleep(2)
|
||
|
||
logger.warning("⚠️ Web服务器启动超时")
|
||
return process
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 启动Web服务器失败: {e}")
|
||
return None
|
||
|
||
def stop_web_server(process):
|
||
"""停止Web服务器"""
|
||
if process:
|
||
try:
|
||
process.terminate()
|
||
process.wait(timeout=5)
|
||
logger.info("✅ Web服务器已停止")
|
||
except subprocess.TimeoutExpired:
|
||
process.kill()
|
||
logger.info("🔥 强制停止Web服务器")
|
||
except Exception as e:
|
||
logger.error(f"❌ 停止Web服务器失败: {e}")
|
||
|
||
def test_optimized_file_handler():
|
||
"""测试优化的文件处理器"""
|
||
print("\n" + "="*60)
|
||
print("测试优化的文件处理器")
|
||
print("="*60)
|
||
|
||
try:
|
||
from optimized_file_handler import get_file_handler
|
||
|
||
# 获取文件处理器实例
|
||
file_handler = get_file_handler()
|
||
logger.info("✅ 文件处理器初始化成功")
|
||
|
||
# 创建测试图像
|
||
test_image = Image.new('RGB', (100, 100), color='red')
|
||
img_buffer = BytesIO()
|
||
test_image.save(img_buffer, format='PNG')
|
||
img_buffer.seek(0)
|
||
|
||
# 测试文件大小判断
|
||
file_size = file_handler.get_file_size(img_buffer)
|
||
is_small = file_handler.is_small_file(img_buffer)
|
||
logger.info(f"测试图像大小: {file_size} bytes, 小文件: {is_small}")
|
||
|
||
# 测试临时文件上下文管理器
|
||
with file_handler.temp_file_context(suffix='.png') as temp_path:
|
||
logger.info(f"创建临时文件: {temp_path}")
|
||
assert os.path.exists(temp_path)
|
||
|
||
# 验证临时文件已被清理
|
||
assert not os.path.exists(temp_path)
|
||
logger.info("✅ 临时文件自动清理成功")
|
||
|
||
# 测试清理所有临时文件
|
||
file_handler.cleanup_all_temp_files()
|
||
logger.info("✅ 批量清理临时文件成功")
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 文件处理器测试失败: {e}")
|
||
return False
|
||
|
||
def test_memory_processing():
|
||
"""测试内存处理功能"""
|
||
print("\n" + "="*60)
|
||
print("测试内存处理功能")
|
||
print("="*60)
|
||
|
||
try:
|
||
from optimized_file_handler import get_file_handler
|
||
|
||
file_handler = get_file_handler()
|
||
|
||
# 测试文本内存处理
|
||
test_texts = [
|
||
"这是一个测试文本",
|
||
"测试内存处理功能",
|
||
"优化的文件处理器"
|
||
]
|
||
|
||
logger.info(f"开始内存处理 {len(test_texts)} 条文本...")
|
||
processed_texts = file_handler.process_text_in_memory(test_texts)
|
||
|
||
if processed_texts:
|
||
logger.info(f"✅ 内存处理文本成功: {len(processed_texts)} 条")
|
||
for i, text_info in enumerate(processed_texts):
|
||
logger.info(f" 文本 {i}: {text_info['bos_key']}")
|
||
else:
|
||
logger.warning("⚠️ 内存处理文本返回空结果")
|
||
|
||
return len(processed_texts) > 0
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 内存处理测试失败: {e}")
|
||
return False
|
||
|
||
def test_web_api_optimized():
|
||
"""测试优化后的Web API"""
|
||
print("\n" + "="*60)
|
||
print("测试优化后的Web API")
|
||
print("="*60)
|
||
|
||
base_url = "http://127.0.0.1:5000"
|
||
|
||
try:
|
||
# 测试系统初始化
|
||
logger.info("测试系统初始化...")
|
||
response = requests.post(f"{base_url}/api/init")
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
logger.info(f"✅ 系统初始化: {result.get('message')}")
|
||
else:
|
||
logger.warning(f"⚠️ 系统可能已初始化: {response.status_code}")
|
||
|
||
# 测试文本上传(内存处理)
|
||
logger.info("测试文本上传(内存处理)...")
|
||
text_data = {
|
||
"texts": [
|
||
"优化后的文本处理测试",
|
||
"内存模式处理文本",
|
||
"自动清理临时文件"
|
||
]
|
||
}
|
||
|
||
response = requests.post(
|
||
f"{base_url}/api/upload/texts",
|
||
json=text_data,
|
||
headers={'Content-Type': 'application/json'}
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
logger.info(f"✅ 文本上传成功: {result.get('message')}")
|
||
logger.info(f" 处理方法: {result.get('processing_method')}")
|
||
logger.info(f" 处理数量: {result.get('processed_texts')}")
|
||
else:
|
||
logger.error(f"❌ 文本上传失败: {response.status_code}")
|
||
logger.error(f" 错误信息: {response.text}")
|
||
|
||
# 测试图像上传(智能处理)
|
||
logger.info("测试图像上传(智能处理)...")
|
||
|
||
# 创建测试图像
|
||
test_image = Image.new('RGB', (200, 200), color='blue')
|
||
img_buffer = BytesIO()
|
||
test_image.save(img_buffer, format='PNG')
|
||
img_buffer.seek(0)
|
||
|
||
files = {'files': ('test_image.png', img_buffer, 'image/png')}
|
||
|
||
response = requests.post(f"{base_url}/api/upload/images", files=files)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
logger.info(f"✅ 图像上传成功: {result.get('message')}")
|
||
logger.info(f" 处理方法: {result.get('processing_methods')}")
|
||
logger.info(f" 处理数量: {result.get('processed_files')}")
|
||
else:
|
||
logger.error(f"❌ 图像上传失败: {response.status_code}")
|
||
logger.error(f" 错误信息: {response.text}")
|
||
|
||
# 测试文搜文
|
||
logger.info("测试文搜文...")
|
||
search_data = {
|
||
"query": "优化处理",
|
||
"top_k": 3
|
||
}
|
||
|
||
response = requests.post(
|
||
f"{base_url}/api/search/text_to_text",
|
||
json=search_data,
|
||
headers={'Content-Type': 'application/json'}
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
logger.info(f"✅ 文搜文成功: 找到 {result.get('count')} 个结果")
|
||
for i, res in enumerate(result.get('results', [])):
|
||
logger.info(f" 结果 {i}: 相似度 {res['score']:.3f}")
|
||
else:
|
||
logger.error(f"❌ 文搜文失败: {response.status_code}")
|
||
|
||
# 测试系统统计
|
||
logger.info("测试系统统计...")
|
||
response = requests.get(f"{base_url}/api/stats")
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
stats = result.get('stats', {})
|
||
logger.info(f"✅ 系统统计:")
|
||
logger.info(f" 文本数量: {stats.get('text_count', 0)}")
|
||
logger.info(f" 图像数量: {stats.get('image_count', 0)}")
|
||
logger.info(f" 后端: {result.get('backend')}")
|
||
else:
|
||
logger.error(f"❌ 获取系统统计失败: {response.status_code}")
|
||
|
||
return True
|
||
|
||
except requests.exceptions.ConnectionError:
|
||
logger.error("❌ 无法连接到Web服务器,请确保服务器正在运行")
|
||
return False
|
||
except Exception as e:
|
||
logger.error(f"❌ Web API测试失败: {e}")
|
||
return False
|
||
|
||
def test_temp_file_cleanup():
|
||
"""测试临时文件清理"""
|
||
print("\n" + "="*60)
|
||
print("测试临时文件清理")
|
||
print("="*60)
|
||
|
||
try:
|
||
from optimized_file_handler import get_file_handler
|
||
|
||
file_handler = get_file_handler()
|
||
|
||
# 创建多个临时文件
|
||
temp_paths = []
|
||
for i in range(3):
|
||
temp_path = file_handler.get_temp_file_for_model(
|
||
BytesIO(b"test content"), f"test_{i}.txt"
|
||
)
|
||
if temp_path:
|
||
temp_paths.append(temp_path)
|
||
logger.info(f"创建临时文件: {temp_path}")
|
||
|
||
# 验证文件存在
|
||
existing_count = sum(1 for path in temp_paths if os.path.exists(path))
|
||
logger.info(f"创建的临时文件数量: {existing_count}")
|
||
|
||
# 清理所有临时文件
|
||
file_handler.cleanup_all_temp_files()
|
||
|
||
# 验证文件已清理
|
||
remaining_count = sum(1 for path in temp_paths if os.path.exists(path))
|
||
logger.info(f"清理后剩余文件数量: {remaining_count}")
|
||
|
||
if remaining_count == 0:
|
||
logger.info("✅ 临时文件清理测试成功")
|
||
return True
|
||
else:
|
||
logger.warning(f"⚠️ 仍有 {remaining_count} 个文件未清理")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 临时文件清理测试失败: {e}")
|
||
return False
|
||
|
||
def main():
|
||
"""主测试函数"""
|
||
print("🚀 开始测试优化后的系统功能")
|
||
print("="*60)
|
||
|
||
# 启动Web服务器
|
||
web_process = start_web_server()
|
||
|
||
try:
|
||
test_results = []
|
||
|
||
# 运行各项测试
|
||
tests = [
|
||
("文件处理器基础功能", test_optimized_file_handler),
|
||
("内存处理功能", test_memory_processing),
|
||
("临时文件清理", test_temp_file_cleanup),
|
||
("Web API功能", test_web_api_optimized),
|
||
]
|
||
|
||
for test_name, test_func in tests:
|
||
logger.info(f"\n🔍 开始测试: {test_name}")
|
||
try:
|
||
result = test_func()
|
||
test_results.append((test_name, result))
|
||
if result:
|
||
logger.info(f"✅ {test_name} - 通过")
|
||
else:
|
||
logger.warning(f"⚠️ {test_name} - 失败")
|
||
except Exception as e:
|
||
logger.error(f"❌ {test_name} - 异常: {e}")
|
||
test_results.append((test_name, False))
|
||
|
||
# 输出测试总结
|
||
print("\n" + "="*60)
|
||
print("测试结果总结")
|
||
print("="*60)
|
||
|
||
passed = sum(1 for _, result in test_results if result)
|
||
total = len(test_results)
|
||
|
||
for test_name, result in test_results:
|
||
status = "✅ 通过" if result else "❌ 失败"
|
||
print(f"{test_name}: {status}")
|
||
|
||
print(f"\n总体结果: {passed}/{total} 项测试通过")
|
||
|
||
if passed == total:
|
||
print("🎉 所有测试通过!优化系统功能正常")
|
||
else:
|
||
print("⚠️ 部分测试失败,请检查相关功能")
|
||
|
||
return passed == total
|
||
|
||
finally:
|
||
# 确保停止Web服务器
|
||
stop_web_server(web_process)
|
||
|
||
if __name__ == "__main__":
|
||
success = main()
|
||
sys.exit(0 if success else 1)
|