mmeb/web_app_multigpu.py
2025-08-20 10:01:03 +00:00

618 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
多GPU多模态检索系统 - Web应用
专为双GPU部署优化
"""
import os
import json
import time
from flask import Flask, render_template, request, jsonify, send_file, url_for
from werkzeug.utils import secure_filename
from PIL import Image
import base64
import io
import logging
import traceback
import glob
# Enable expandable segments in PyTorch's CUDA caching allocator to reduce
# memory fragmentation. setdefault() respects a value the operator exported
# explicitly instead of silently overwriting it. The variable has to be in
# place before CUDA is first used; torch is only imported inside functions in
# this module.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
# Overridable through FLASK_SECRET_KEY so a deployment does not have to use the
# key committed in source; the original value remains the fallback.
app.config['SECRET_KEY'] = os.environ.get('FLASK_SECRET_KEY', 'multigpu_multimodal_retrieval_2024')
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # reject request bodies larger than 16 MB

# Storage folders, relative to the process working directory.
UPLOAD_FOLDER = 'uploads'               # query images saved by image-based searches
SAMPLE_IMAGES_FOLDER = 'sample_images'  # gallery images scanned when building the image index
TEXT_DATA_FOLDER = 'text_data'          # .json / .txt corpora scanned when building the text index
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}  # lowercase, without the dot

# Make sure the folders exist before any request tries to write into them.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)

# Process-wide retrieval system instance; set by auto_initialize() or
# POST /api/init, None before that.
retrieval_system = None
def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* ends in an allowed extension.

    The check is case-insensitive and only looks at the text after the last
    dot, so ``'photo.JPG'`` is accepted while ``'archive.png.zip'`` is not.

    Args:
        filename: Client-supplied file name. ``None`` or ``''`` is rejected
            (returns False) instead of raising ``TypeError``.
        allowed_extensions: Optional collection of lowercase extensions
            (without the dot) to check against. Defaults to the module-level
            ``ALLOWED_EXTENSIONS``.

    Returns:
        bool: whether the extension is allowed.
    """
    if not filename or '.' not in filename:
        return False
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    return filename.rsplit('.', 1)[1].lower() in allowed_extensions
def image_to_base64(image_path):
    """Read the file at *image_path* and return its bytes base64-encoded as a str.

    Returns ``None`` (after logging the error) if the file cannot be read.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        return base64.b64encode(raw_bytes).decode('utf-8')
    except Exception as err:
        logger.error(f"图片转换失败: {err}")
        return None
@app.route('/')
def index():
    """Render the home page from the ``index.html`` template."""
    page = 'index.html'
    return render_template(page)
@app.route('/api/status')
def get_status():
    """Report the readiness of the retrieval backend as JSON.

    Keys:
        initialized: a retrieval system object exists.
        gpu_count: number of visible CUDA devices (0 when CUDA is unavailable).
        model_loaded: the system has a truthy ``model``.
        device_ids: only present once a model is loaded.

    Failures while probing torch or the system are logged and the defaults
    collected so far are returned.
    """
    global retrieval_system
    status = dict(
        initialized=retrieval_system is not None,
        gpu_count=0,
        model_loaded=False,
    )
    try:
        import torch
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            status['gpu_count'] = torch.cuda.device_count()
        if retrieval_system and retrieval_system.model:
            status['model_loaded'] = True
            status['device_ids'] = retrieval_system.device_ids
    except Exception as err:
        logger.error(f"获取状态失败: {err}")
    return jsonify(status)
@app.route('/api/init', methods=['POST'])
def initialize_system():
    """Load the multi-GPU retrieval system and install it as the global instance.

    Returns:
        On success: JSON ``{'success': True, 'message', 'device_ids', 'gpu_count'}``.
        On failure: JSON ``{'success': False, 'message'}`` with HTTP 500; the
        error and traceback are logged.

    The new instance is published to ``retrieval_system`` only after its
    ``model`` has been checked. On failure, whatever instance was installed
    before is kept, rather than being replaced by one whose ``model`` is None
    (which other endpoints would then treat as initialized).
    """
    global retrieval_system
    try:
        logger.info("正在初始化多GPU检索系统...")
        # Imported inside the function, so a missing or broken model package
        # only fails this request instead of the module import.
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        candidate = MultiGPUMultimodalRetrieval()
        if candidate.model is None:
            raise RuntimeError("模型加载失败")
        retrieval_system = candidate

        logger.info("✅ 多GPU系统初始化成功")
        return jsonify({
            'success': True,
            'message': '多GPU系统初始化成功',
            'device_ids': retrieval_system.device_ids,
            'gpu_count': len(retrieval_system.device_ids)
        })
    except Exception as e:
        error_msg = f"系统初始化失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500
@app.route('/api/search/text_to_text', methods=['POST'])
def search_text_to_text():
    """Text query -> similar texts; delegates to ``handle_search``."""
    search_mode = 'text_to_text'
    return handle_search(search_mode)
@app.route('/api/search/text_to_image', methods=['POST'])
def search_text_to_image():
    """Text query -> related images; delegates to ``handle_search``."""
    search_mode = 'text_to_image'
    return handle_search(search_mode)
@app.route('/api/search/image_to_text', methods=['POST'])
def search_image_to_text():
    """Image query -> related texts; delegates to ``handle_search``."""
    search_mode = 'image_to_text'
    return handle_search(search_mode)
@app.route('/api/search/image_to_image', methods=['POST'])
def search_image_to_image():
    """Image query -> similar images; delegates to ``handle_search``."""
    search_mode = 'image_to_image'
    return handle_search(search_mode)
@app.route('/api/search', methods=['POST'])
def search():
    """Generic search endpoint, kept for compatibility with older clients.

    The search mode is read from the ``mode`` form field. When that is absent
    or empty it is read from a JSON body, defaulting to ``'text_to_text'``.
    ``get_json(silent=True)`` is used so a form/multipart request without
    ``mode`` does not crash on ``request.json``: depending on the Flask
    version that either raises 415 or returns None.
    """
    mode = request.form.get('mode')
    if not mode:
        payload = request.get_json(silent=True)
        mode = payload.get('mode', 'text_to_text') if isinstance(payload, dict) else 'text_to_text'
    return handle_search(mode)
def _search_request_value(key, default=None):
    """Return request parameter *key*: a non-empty form value wins, else the JSON body's value.

    Uses ``get_json(silent=True)`` so form/multipart requests that lack the
    key yield *default* instead of raising on ``request.json``.
    """
    value = request.form.get(key)
    if value:
        return value
    payload = request.get_json(silent=True)
    if isinstance(payload, dict):
        return payload.get(key, default)
    return default


def _search_by_text(mode, top_k):
    """Handle ``text_to_text`` / ``text_to_image``: validate the query, search, build the JSON response."""
    # ``or ''`` + str(): a JSON null or non-string query must not crash on .strip().
    query = str(_search_request_value('query', '') or '')
    if not query.strip():
        return jsonify({
            'success': False,
            'message': '请输入查询文本'
        }), 400

    logger.info(f"执行{mode}搜索: {query}")
    if mode == 'text_to_text':
        results = retrieval_system.search_text_to_text(query, top_k=top_k)
    else:  # text_to_image
        results = retrieval_system.search_text_to_image(query, top_k=top_k)

    return jsonify({
        'success': True,
        'mode': mode,
        'query': query,
        'results': results,
        'result_count': len(results)
    })


def _search_by_image(mode, top_k):
    """Handle ``image_to_text`` / ``image_to_image``: save the query image, search, build the JSON response."""
    import uuid

    if 'image' not in request.files:
        return jsonify({
            'success': False,
            'message': '请上传查询图片'
        }), 400
    upload = request.files['image']
    if not upload.filename or not allowed_file(upload.filename):
        return jsonify({
            'success': False,
            'message': '请上传有效的图片文件'
        }), 400

    # secure_filename() strips non-ASCII characters and can swallow the dot
    # ('图片.jpg' -> 'jpg'). Sanitize only the stem and re-attach the already
    # validated extension. The random part keeps uploads that land in the same
    # second (the server runs threaded) from overwriting each other.
    stem, ext = upload.filename.rsplit('.', 1)
    safe_stem = secure_filename(stem) or 'image'
    filename = f"query_{int(time.time())}_{uuid.uuid4().hex[:8]}_{safe_stem}.{ext.lower()}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    upload.save(filepath)

    logger.info(f"执行{mode}搜索,图片: {filename}")
    if mode == 'image_to_text':
        results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
    else:  # image_to_image
        results = retrieval_system.search_image_to_image(filepath, top_k=top_k)

    return jsonify({
        'success': True,
        'mode': mode,
        'query_image': image_to_base64(filepath),
        'results': results,
        'result_count': len(results)
    })


def handle_search(mode):
    """Shared implementation behind every ``/api/search*`` endpoint.

    Args:
        mode: One of ``'text_to_text'``, ``'text_to_image'``,
            ``'image_to_text'`` or ``'image_to_image'``.

    Request parameters (a form field, or a JSON body value when the form field is absent or empty):
        top_k: number of results (default 5); must be a positive integer.
        query: query text, for the text modes.
    Image modes additionally require a multipart file field ``image``.

    Returns:
        A Flask JSON response. HTTP 400 for an uninitialized system, an
        invalid ``top_k``, a missing or invalid query, or an unsupported mode.
        HTTP 500 (error and traceback logged) if the search itself raises.
    """
    global retrieval_system
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化,请先点击初始化按钮'
        }), 400
    try:
        try:
            top_k = int(_search_request_value('top_k', 5))
        except (TypeError, ValueError):
            top_k = 0
        if top_k <= 0:
            return jsonify({
                'success': False,
                'message': 'top_k 必须是正整数'
            }), 400

        if mode in ('text_to_text', 'text_to_image'):
            return _search_by_text(mode, top_k)
        if mode in ('image_to_text', 'image_to_image'):
            return _search_by_image(mode, top_k)
        return jsonify({
            'success': False,
            'message': f'不支持的搜索模式: {mode}'
        }), 400
    except Exception as e:
        error_msg = f"搜索失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500
@app.route('/api/upload/images', methods=['POST'])
def upload_images():
    """Save a batch of uploaded images into ``SAMPLE_IMAGES_FOLDER``.

    Expects the multipart field ``images`` (possibly repeated). Entries with an
    empty name or a disallowed extension are skipped.

    Each file is stored as ``<unix-seconds>_<random8hex>_<safe stem>.<lowercase ext>``:
      * The extension is re-attached explicitly because ``secure_filename``
        drops non-ASCII characters and can strip the dot (``'图片.jpg'`` ->
        ``'jpg'``). An extension-less file would be skipped by the ``*.ext``
        scans used for indexing.
      * The random part prevents files from overwriting each other when they
        arrive in the same second or sanitize to the same stem.

    Returns:
        JSON with ``uploaded_count`` and the stored ``files`` names. HTTP 400
        if no ``images`` field is present; HTTP 500 if saving fails.
    """
    import uuid

    try:
        if 'images' not in request.files:
            return jsonify({'success': False, 'message': '没有选择文件'}), 400

        uploaded_files = []
        timestamp = str(int(time.time()))
        for upload in request.files.getlist('images'):
            if not upload or not upload.filename or not allowed_file(upload.filename):
                continue
            stem, ext = upload.filename.rsplit('.', 1)
            safe_stem = secure_filename(stem) or 'image'
            filename = f"{timestamp}_{uuid.uuid4().hex[:8]}_{safe_stem}.{ext.lower()}"
            upload.save(os.path.join(SAMPLE_IMAGES_FOLDER, filename))
            uploaded_files.append(filename)

        return jsonify({
            'success': True,
            'message': f'成功上传 {len(uploaded_files)} 个图片文件',
            'uploaded_count': len(uploaded_files),
            'files': uploaded_files
        })
    except Exception as e:
        logger.error(f"上传失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500
@app.route('/api/upload/texts', methods=['POST'])
def upload_texts():
    """Persist a JSON list of texts as a new file in ``TEXT_DATA_FOLDER``.

    Expects a JSON body ``{"texts": [...]}``. Each upload is written to its own
    ``texts_<unix-seconds>_<random8hex>.json`` file. The random part prevents
    two uploads within the same second from overwriting each other, which the
    plain timestamp name allowed.

    Returns:
        JSON with ``uploaded_count`` on success. HTTP 400 if the body is
        missing, not a JSON object, lacks ``texts``, or ``texts`` is not a
        list. HTTP 500 if writing the file fails.
    """
    import uuid

    try:
        # silent=True: a missing/non-JSON body yields None (-> 400) rather than
        # an exception that would surface as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'texts' not in data:
            return jsonify({'success': False, 'message': '没有提供文本数据'}), 400

        texts = data['texts']
        if not isinstance(texts, list):
            return jsonify({'success': False, 'message': '文本数据格式错误'}), 400

        filename = f"texts_{int(time.time())}_{uuid.uuid4().hex[:8]}.json"
        filepath = os.path.join(TEXT_DATA_FOLDER, filename)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(texts, f, ensure_ascii=False, indent=2)

        return jsonify({
            'success': True,
            'message': f'成功上传 {len(texts)} 条文本',
            'uploaded_count': len(texts)
        })
    except Exception as e:
        logger.error(f"上传失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500
@app.route('/api/upload/file', methods=['POST'])
def upload_single_file():
    """Save one uploaded image (multipart field ``file``) into ``SAMPLE_IMAGES_FOLDER``.

    The stored name is ``<unix-seconds>_<random8hex>_<safe stem>.<lowercase ext>``.
    The extension is re-attached because ``secure_filename`` can strip the dot
    from non-ASCII names (``'图片.jpg'`` -> ``'jpg'``). The random part avoids
    overwriting a same-named file uploaded in the same second.

    Returns:
        JSON with the stored ``filename``. HTTP 400 when no file is selected
        or the type is not allowed; HTTP 500 (JSON) if saving fails.
    """
    import uuid

    if 'file' not in request.files:
        return jsonify({'success': False, 'message': '没有选择文件'}), 400
    upload = request.files['file']
    if not upload.filename:
        return jsonify({'success': False, 'message': '没有选择文件'}), 400
    if not allowed_file(upload.filename):
        return jsonify({'success': False, 'message': '不支持的文件类型'}), 400

    stem, ext = upload.filename.rsplit('.', 1)
    safe_stem = secure_filename(stem) or 'image'
    filename = f"{int(time.time())}_{uuid.uuid4().hex[:8]}_{safe_stem}.{ext.lower()}"
    try:
        upload.save(os.path.join(SAMPLE_IMAGES_FOLDER, filename))
    except OSError as e:
        # Answer in the same JSON shape as the other upload endpoints instead
        # of letting Flask render an HTML 500 page.
        logger.error(f"上传失败: {str(e)}")
        return jsonify({'success': False, 'message': f'上传失败: {str(e)}'}), 500

    return jsonify({
        'success': True,
        'message': '文件上传成功',
        'filename': filename
    })
@app.route('/api/data/list', methods=['GET'])
def list_data():
    """List the uploaded images and JSON text files with their size and mtime.

    Response: ``{'success': True, 'data': {'images': [...], 'texts': [...]}}``,
    where every entry is ``{'filename', 'size', 'modified'}``. Images are the
    names accepted by ``allowed_file``; texts are names ending in ``.json``.
    If the scan fails, HTTP 500 is returned with an error message.
    """
    def describe(folder, keep):
        # Collect {'filename', 'size', 'modified'} for every kept name in *folder*.
        entries = []
        if not os.path.exists(folder):
            return entries
        for name in os.listdir(folder):
            if not keep(name):
                continue
            info = os.stat(os.path.join(folder, name))
            entries.append({
                'filename': name,
                'size': info.st_size,
                'modified': info.st_mtime
            })
        return entries

    try:
        image_entries = describe(SAMPLE_IMAGES_FOLDER, allowed_file)
        text_entries = describe(TEXT_DATA_FOLDER, lambda name: name.endswith('.json'))
        return jsonify({
            'success': True,
            'data': {
                'images': image_entries,
                'texts': text_entries
            }
        })
    except Exception as err:
        return jsonify({
            'success': False,
            'message': f'获取数据列表失败: {str(err)}'
        }), 500
@app.route('/api/gpu_status')
def gpu_status():
    """Return whatever ``smart_gpu_launcher.get_gpu_memory_info()`` reports.

    The payload is presumably per-GPU memory stats (not visible from this
    file). Any failure, including the import, yields HTTP 500 with a message.
    """
    try:
        from smart_gpu_launcher import get_gpu_memory_info
        info = get_gpu_memory_info()
        return jsonify({'success': True, 'gpu_info': info})
    except Exception as err:
        message = f"获取GPU状态失败: {str(err)}"
        return jsonify({'success': False, 'message': message}), 500
def _scan_image_files():
    """Return the sorted paths of indexable images in ``SAMPLE_IMAGES_FOLDER``.

    A name qualifies when ``allowed_file`` accepts it, so the extension check
    is case-insensitive (``photo.JPG`` counts, as it does in ``/api/data/list``).
    A per-extension ``glob('*.jpg')`` would miss such files on case-sensitive
    file systems. Hidden files (leading dot, which ``glob`` also skips) and
    directories are ignored. Sorting keeps the index order deterministic.
    """
    if not os.path.isdir(SAMPLE_IMAGES_FOLDER):
        return []
    paths = []
    for name in os.listdir(SAMPLE_IMAGES_FOLDER):
        path = os.path.join(SAMPLE_IMAGES_FOLDER, name)
        if not name.startswith('.') and allowed_file(name) and os.path.isfile(path):
            paths.append(path)
    return sorted(paths)


def _load_text_corpus():
    """Collect text entries from every ``.json`` / ``.txt`` file in ``TEXT_DATA_FOLDER``.

    * ``.json``: a list contributes each item as ``str(item).strip()``, with
      empty results dropped. Any other JSON value contributes
      ``str(value).strip()``.
    * ``.txt``: one stripped entry per non-blank line.

    Files are read in sorted order (``.json`` first, then ``.txt``). A file
    that cannot be read or parsed is logged and skipped (best effort).
    """
    texts = []
    json_files = sorted(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json")))
    txt_files = sorted(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
    for text_file in json_files + txt_files:
        try:
            with open(text_file, 'r', encoding='utf-8') as f:
                if text_file.endswith('.json'):
                    data = json.load(f)
                    if isinstance(data, list):
                        cleaned = (str(item).strip() for item in data)
                        texts.extend(entry for entry in cleaned if entry)
                    else:
                        texts.append(str(data).strip())
                else:
                    texts.extend(line.strip() for line in f if line.strip())
        except Exception as e:
            logger.warning(f"读取文本文件失败 {text_file}: {e}")
    return texts


@app.route('/api/build_index', methods=['POST'])
def build_index():
    """Build the image and/or text retrieval indexes from the uploaded data.

    The image index is built from ``_scan_image_files()`` and the text index
    from ``_load_text_corpus()``. Each is only (re)built when its input is
    non-empty.

    Returns:
        JSON with ``image_count`` / ``text_count``. HTTP 400 if the system is
        not initialized or there is no data at all. HTTP 500 (error and
        traceback logged) if index construction raises.
    """
    global retrieval_system
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化'
        }), 400
    try:
        image_files = _scan_image_files()
        text_data = _load_text_corpus()

        if not image_files and not text_data:
            return jsonify({
                'success': False,
                'message': '没有找到可用的图片或文本数据,请先上传数据'
            }), 400

        if image_files:
            logger.info(f"构建图片索引,共 {len(image_files)} 张图片")
            retrieval_system.build_image_index_parallel(image_files)
        if text_data:
            logger.info(f"构建文本索引,共 {len(text_data)} 条文本")
            retrieval_system.build_text_index_parallel(text_data)

        return jsonify({
            'success': True,
            'message': f'索引构建完成!图片: {len(image_files)} 张,文本: {len(text_data)}',
            'image_count': len(image_files),
            'text_count': len(text_data)
        })
    except Exception as e:
        logger.error(f"构建索引失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': f'构建索引失败: {str(e)}'
        }), 500
@app.route('/api/data/stats', methods=['GET'])
def get_data_stats():
    """Return how many images and text entries are available for indexing.

    * Images: files in ``SAMPLE_IMAGES_FOLDER`` accepted by ``allowed_file``
      (case-insensitive extension, hidden files and directories ignored).
    * Texts: the same parsing ``/api/build_index`` applies.
        - ``.json`` files, which is where ``/api/upload/texts`` stores
          uploads: a list counts its items whose ``str(item).strip()`` is
          non-empty; any other JSON value counts as one entry.
        - ``.txt`` files: non-blank lines.
      Unreadable or malformed files are skipped (logged at debug level).

    Returns:
        JSON ``{'success': True, 'image_count', 'text_count'}``, or HTTP 500
        with a message on unexpected failure.
    """
    try:
        image_count = 0
        if os.path.isdir(SAMPLE_IMAGES_FOLDER):
            for name in os.listdir(SAMPLE_IMAGES_FOLDER):
                path = os.path.join(SAMPLE_IMAGES_FOLDER, name)
                if not name.startswith('.') and allowed_file(name) and os.path.isfile(path):
                    image_count += 1

        text_count = 0
        for text_file in glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json")):
            try:
                with open(text_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except Exception as e:
                logger.debug(f"读取文本文件失败 {text_file}: {e}")
                continue
            if isinstance(data, list):
                text_count += sum(1 for item in data if str(item).strip())
            else:
                text_count += 1
        for text_file in glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")):
            try:
                with open(text_file, 'r', encoding='utf-8') as f:
                    text_count += sum(1 for line in f if line.strip())
            except Exception as e:
                logger.debug(f"读取文本文件失败 {text_file}: {e}")
                continue

        return jsonify({
            'success': True,
            'image_count': image_count,
            'text_count': text_count
        })
    except Exception as e:
        logger.error(f"获取数据统计失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'获取统计失败: {str(e)}'
        }), 500
@app.route('/api/data/clear', methods=['POST'])
def clear_data():
    """Delete all indexable data and reset the in-memory indexes.

    * Removes image files accepted by ``allowed_file`` (case-insensitive
      extension; hidden files and directories untouched) from
      ``SAMPLE_IMAGES_FOLDER``.
    * Removes all ``.json`` and ``.txt`` files from ``TEXT_DATA_FOLDER``.
      ``.json`` must be included because ``/api/upload/texts`` stores uploads
      in that format; otherwise "cleared" texts would be re-indexed by the
      next build.
    * Query images in ``UPLOAD_FOLDER`` are not touched.

    Individual deletion failures are logged and skipped. If a retrieval system
    exists, its ``text_index`` / ``image_index`` are set to None and its
    ``text_data`` / ``image_data`` to empty lists.
    """
    global retrieval_system
    try:
        # Remove image files
        if os.path.isdir(SAMPLE_IMAGES_FOLDER):
            for name in os.listdir(SAMPLE_IMAGES_FOLDER):
                file_path = os.path.join(SAMPLE_IMAGES_FOLDER, name)
                if name.startswith('.') or not allowed_file(name) or not os.path.isfile(file_path):
                    continue
                try:
                    os.remove(file_path)
                except OSError as e:
                    logger.warning(f"删除图片文件失败 {file_path}: {e}")

        # Remove text files (both formats the index builder reads)
        text_files = (glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
                      + glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
        for text_file in text_files:
            try:
                os.remove(text_file)
            except OSError as e:
                logger.warning(f"删除文本文件失败 {text_file}: {e}")

        # Reset the in-memory indexes
        if retrieval_system:
            retrieval_system.text_index = None
            retrieval_system.image_index = None
            retrieval_system.text_data = []
            retrieval_system.image_data = []

        return jsonify({
            'success': True,
            'message': '数据已清空'
        })
    except Exception as e:
        logger.error(f"清空数据失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'清空数据失败: {str(e)}'
        }), 500
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve a file from ``SAMPLE_IMAGES_FOLDER`` by name.

    Despite the ``/uploads`` URL, files come from the sample-image folder that
    the upload endpoints write to.

    ``send_from_directory`` refuses names that would resolve outside the
    folder and answers 404 for missing files, instead of the 500 that a raw
    ``send_file`` on a nonexistent path produced. The directory is made
    absolute because Flask resolves relative directories against
    ``app.root_path``, whereas the files were written relative to the process
    working directory.
    """
    from flask import send_from_directory
    return send_from_directory(os.path.abspath(SAMPLE_IMAGES_FOLDER), filename)
def print_startup_info():
    """Print the startup banner: access URL, supported features and detected GPUs.

    torch is imported here. If that import or any CUDA query fails, the error
    is printed rather than propagated.
    """
    separator = "=" * 60
    header_lines = (
        "🚀 启动多GPU多模态检索Web应用",
        separator,
        "访问地址: http://localhost:5000",
        "支持功能:",
        " 📝 文搜文 - 文本查找相似文本",
        " 🖼️ 文搜图 - 文本查找相关图片",
        " 📝 图搜文 - 图片查找相关文本",
        " 🖼️ 图搜图 - 图片查找相似图片",
        " 📤 批量上传 - 图片和文本数据管理",
        "GPU配置:",
    )
    for text in header_lines:
        print(text)

    try:
        import torch
        if not torch.cuda.is_available():
            print(" ❌ CUDA不可用")
        else:
            device_total = torch.cuda.device_count()
            print(f" 🖥️ 检测到 {device_total} 个GPU")
            for gpu_idx in range(device_total):
                device_name = torch.cuda.get_device_name(gpu_idx)
                total_gb = torch.cuda.get_device_properties(gpu_idx).total_memory / 1024**3
                print(f" GPU {gpu_idx}: {device_name} ({total_gb:.1f}GB)")
    except Exception as exc:
        print(f" ❌ GPU检查失败: {exc}")

    print(separator)
def auto_initialize():
    """Try to load the retrieval system once at startup.

    The new instance is installed into the global ``retrieval_system`` only
    after its ``model`` has been checked. A failed load therefore does not
    leave a model-less instance behind that other endpoints would treat as
    initialized.

    Returns:
        True if a system with a loaded model was installed. False otherwise;
        the error and traceback are logged and the global is left unchanged.
    """
    global retrieval_system
    try:
        logger.info("🚀 启动时自动初始化多GPU检索系统...")
        # Imported inside the function, so a missing or broken model package
        # is reported as a failed initialization instead of breaking module import.
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        candidate = MultiGPUMultimodalRetrieval()
        if candidate.model is None:
            raise RuntimeError("模型加载失败")
        retrieval_system = candidate

        logger.info("✅ 系统自动初始化成功")
        return True
    except Exception as e:
        logger.error(f"❌ 系统自动初始化失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False
if __name__ == '__main__':
    # Entry point: banner -> eager model load -> blocking Flask server.
    print_startup_info()
    # The return value is not checked: if loading fails the server still
    # starts, and the system can be initialized later via POST /api/init.
    auto_initialize()
    run_kwargs = {
        'host': '0.0.0.0',   # listen on all interfaces
        'port': 5000,
        'debug': False,
        'threaded': True,    # serve requests on separate threads
    }
    app.run(**run_kwargs)