#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
多GPU多模态检索系统 - Web应用
专为双GPU部署优化
"""

import base64
import glob
import io
import json
import logging
import os
import time
import traceback
import uuid

from flask import Flask, render_template, request, jsonify, send_file, send_from_directory, url_for
from PIL import Image
from werkzeug.utils import secure_filename

# Environment variable that tunes PyTorch's CUDA allocator for GPU memory use.
# setdefault keeps any value the operator already exported instead of
# silently overwriting it; it must run before torch initializes CUDA.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
# Prefer a SECRET_KEY from the environment: a key committed to source code is
# readable by anyone with the code. The old literal stays as the fallback so
# existing deployments keep working unchanged.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'multigpu_multimodal_retrieval_2024')
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max request size

# Storage folders (relative paths, resolved against the working directory)
UPLOAD_FOLDER = 'uploads'  # query images posted to the image search endpoints
SAMPLE_IMAGES_FOLDER = 'sample_images'  # images that get indexed
TEXT_DATA_FOLDER = 'text_data'  # .json / .txt text corpora that get indexed
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}

# Make sure the folders exist
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)

# Global retrieval system instance; None until the system is initialized.
retrieval_system = None

def allowed_file(filename):
    """Return True if *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The comparison is case-insensitive ('A.JPG' is accepted). Falsy values
    (None or '') are rejected instead of raising TypeError, since an uploaded
    file's filename attribute is not guaranteed to be a non-empty string.
    """
    if not filename:
        return False
    _, dot, extension = filename.rpartition('.')
    # rpartition yields an empty separator when there is no '.' at all.
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS

def image_to_base64(image_path):
    """Read the file at *image_path* and return its bytes as a base64 string.

    Returns None (after logging the error) if the file cannot be read.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode('utf-8')
    except Exception as e:
        logger.error(f"图片转换失败: {e}")
        return None

@app.route('/')
def index():
    """Serve the front-end page."""
    page = 'index.html'
    return render_template(page)

@app.route('/api/status')
def get_status():
    """Report initialization state, model state and visible GPU count.

    Response keys: 'initialized', 'gpu_count', 'model_loaded', plus
    'device_ids' when a model is loaded. GPU probing failures are logged and
    leave 'gpu_count' at 0.
    """
    system = retrieval_system
    status = {
        'initialized': system is not None,
        'gpu_count': 0,
        'model_loaded': False
    }

    # Model state does not depend on CUDA being available or on torch
    # importing cleanly, so report it before touching torch. Use an explicit
    # None check: model objects may define __len__/__bool__.
    if system is not None and getattr(system, 'model', None) is not None:
        status['model_loaded'] = True
        status['device_ids'] = getattr(system, 'device_ids', [])

    try:
        import torch
        if torch.cuda.is_available():
            status['gpu_count'] = torch.cuda.device_count()
    except Exception as e:
        logger.error(f"获取状态失败: {e}")

    return jsonify(status)

@app.route('/api/init', methods=['POST'])
def initialize_system():
    """(Re)initialize the global multi-GPU retrieval system.

    The new instance is built in a local variable and only published to the
    module-level ``retrieval_system`` once its model has loaded. A failed
    attempt therefore never leaves a half-constructed system that the search
    endpoints would treat as initialized. On failure, whatever instance
    existed before is kept.

    Returns:
        JSON with 'device_ids' and 'gpu_count' on success, or
        ``(json, 500)`` with an error message on failure.
    """
    global retrieval_system

    try:
        logger.info("正在初始化多GPU检索系统...")

        # Imported lazily so the web app can start without loading the model code.
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        system = MultiGPUMultimodalRetrieval()
        if system.model is None:
            raise Exception("模型加载失败")

        retrieval_system = system
        logger.info("✅ 多GPU系统初始化成功")

        return jsonify({
            'success': True,
            'message': '多GPU系统初始化成功',
            'device_ids': system.device_ids,
            'gpu_count': len(system.device_ids)
        })

    except Exception as e:
        error_msg = f"系统初始化失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())

        return jsonify({
            'success': False,
            'message': error_msg
        }), 500

@app.route('/api/search/text_to_text', methods=['POST'])
def search_text_to_text():
    """Search the text index with a text query."""
    search_mode = 'text_to_text'
    return handle_search(search_mode)

@app.route('/api/search/text_to_image', methods=['POST'])
def search_text_to_image():
    """Search the image index with a text query."""
    search_mode = 'text_to_image'
    return handle_search(search_mode)

@app.route('/api/search/image_to_text', methods=['POST'])
def search_image_to_text():
    """Search the text index with an uploaded query image."""
    search_mode = 'image_to_text'
    return handle_search(search_mode)

@app.route('/api/search/image_to_image', methods=['POST'])
def search_image_to_image():
    """Search the image index with an uploaded query image."""
    search_mode = 'image_to_image'
    return handle_search(search_mode)

@app.route('/api/search', methods=['POST'])
def search():
    """Generic search endpoint kept for backward compatibility.

    The mode is taken from the 'mode' form field if non-empty, otherwise
    from a JSON object body, defaulting to 'text_to_text'.
    """
    # request.json is None for multipart/form requests (and raises 415 on
    # newer Flask), so a form request without 'mode' used to crash with a
    # 500. get_json(silent=True) returns None instead of raising.
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    mode = request.form.get('mode') or body.get('mode', 'text_to_text')
    return handle_search(mode)

def handle_search(mode):
    """Run one search request and return a Flask JSON response.

    Args:
        mode: 'text_to_text' / 'text_to_image', which read a 'query' string
            from the form or a JSON body; or 'image_to_text' /
            'image_to_image', which read an uploaded 'image' file.

    Request parameters:
        top_k: positive integer result limit (form or JSON body, default 5).

    Returns:
        The JSON response, or a ``(response, status)`` tuple: 400 for an
        uninitialized system or invalid input, 500 if the search raises.
    """
    # A system whose model failed to load cannot serve searches either.
    if retrieval_system is None or getattr(retrieval_system, 'model', None) is None:
        return jsonify({
            'success': False,
            'message': '系统未初始化,请先点击初始化按钮'
        }), 400

    # A malformed top_k is a client error (400), not a server error (500).
    try:
        top_k = int(_search_param('top_k', 5))
    except (TypeError, ValueError, OverflowError):
        top_k = 0
    if top_k < 1:
        return jsonify({
            'success': False,
            'message': 'top_k必须是正整数'
        }), 400

    try:
        if mode in ('text_to_text', 'text_to_image'):
            return _search_by_text(mode, top_k)
        if mode in ('image_to_text', 'image_to_image'):
            return _search_by_image(mode, top_k)
        return jsonify({
            'success': False,
            'message': f'不支持的搜索模式: {mode}'
        }), 400

    except Exception as e:
        error_msg = f"搜索失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())

        return jsonify({
            'success': False,
            'message': error_msg
        }), 500


def _search_param(name, default=None):
    """Return request parameter *name*: a non-empty form value, else the JSON body's value, else *default*."""
    value = request.form.get(name)
    if value:
        return value
    # silent=True: non-JSON bodies give None instead of raising.
    body = request.get_json(silent=True)
    if isinstance(body, dict):
        return body.get(name, default)
    return default


def _search_by_text(mode, top_k):
    """Handle the text_to_text / text_to_image modes."""
    query = _search_param('query', '')
    if not isinstance(query, str) or not query.strip():
        return jsonify({
            'success': False,
            'message': '请输入查询文本'
        }), 400

    logger.info(f"执行{mode}搜索: {query}")

    if mode == 'text_to_text':
        raw_results = retrieval_system.search_text_to_text(query, top_k=top_k)
        results = _format_text_results(raw_results)
    else:  # text_to_image
        raw_results = retrieval_system.search_text_to_image(query, top_k=top_k)
        results = _format_image_results(raw_results)

    return jsonify({
        'success': True,
        'mode': mode,
        'query': query,
        'results': results,
        'result_count': len(results)
    })


def _search_by_image(mode, top_k):
    """Handle the image_to_text / image_to_image modes: save the query image, search, then delete it."""
    file = request.files.get('image')
    if file is None:
        return jsonify({
            'success': False,
            'message': '请上传查询图片'
        }), 400

    if not file.filename or not allowed_file(file.filename):
        return jsonify({
            'success': False,
            'message': '请上传有效的图片文件'
        }), 400

    extension = file.filename.rsplit('.', 1)[1].lower()
    safe_name = secure_filename(file.filename)
    stem, dot, safe_ext = safe_name.rpartition('.')
    if dot and safe_ext.lower() == extension:
        safe_name = f"{stem}.{extension}"
    else:
        # secure_filename strips non-ASCII characters ('图片.jpg' -> 'jpg'),
        # so re-append the already validated extension.
        safe_name = f"{safe_name or 'image'}.{extension}"
    # The random token keeps same-named uploads within one second apart.
    filename = f"query_{int(time.time())}_{uuid.uuid4().hex[:8]}_{safe_name}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    file.save(filepath)

    logger.info(f"执行{mode}搜索,图片: {filename}")

    try:
        if mode == 'image_to_text':
            raw_results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
            results = _format_text_results(raw_results)
        else:  # image_to_image
            raw_results = retrieval_system.search_image_to_image(filepath, top_k=top_k)
            results = _format_image_results(raw_results)

        # The query image is echoed back inline so the client can display it.
        query_image_b64 = image_to_base64(filepath)
    finally:
        # The query image is only needed for this request and no route in
        # this module serves UPLOAD_FOLDER; remove it so the folder does not
        # grow without bound.
        try:
            os.remove(filepath)
        except OSError as e:
            logger.warning(f"删除查询图片失败 {filepath}: {e}")

    return jsonify({
        'success': True,
        'mode': mode,
        'query_image': query_image_b64,
        'results': results,
        'result_count': len(results)
    })


def _format_text_results(raw_results):
    """Convert (text, score) pairs into JSON-serializable dicts."""
    return [{'text': text, 'score': float(score)} for text, score in raw_results]


def _format_image_results(raw_results):
    """Convert (image_path, score) pairs into dicts carrying a base64 copy of each image.

    An unreadable image is logged and returned with an empty 'image_base64',
    so one bad file does not fail the whole search.
    """
    results = []
    for image_path, score in raw_results:
        try:
            with open(image_path, 'rb') as img_file:
                image_base64 = base64.b64encode(img_file.read()).decode('utf-8')
        except OSError as e:
            logger.error(f"读取图像失败 {image_path}: {e}")
            image_base64 = ''

        results.append({
            'filename': os.path.basename(image_path),
            'image_path': image_path,
            'image_base64': image_base64,
            'score': float(score)
        })
    return results

@app.route('/api/upload/images', methods=['POST'])
def upload_images():
    """Save every uploaded image in the multipart field 'images' into SAMPLE_IMAGES_FOLDER.

    Files without an allowed extension are skipped. Each stored name is
    '<unix time>_<random token>_<sanitized name>.<lowercase ext>'. The token
    stops files that sanitize to the same name (e.g. several Chinese-named
    JPEGs, which secure_filename reduces to 'jpg') from overwriting each
    other within one request or one second. The lowercase extension keeps
    them visible to extension-based scans.

    Returns:
        JSON with 'uploaded_count' and the stored 'files'; 400 if the field
        is missing, 500 on unexpected errors.
    """
    try:
        if 'images' not in request.files:
            return jsonify({'success': False, 'message': '没有选择文件'}), 400

        uploaded_files = []
        timestamp = str(int(time.time()))

        for file in request.files.getlist('images'):
            if not file or not file.filename or not allowed_file(file.filename):
                continue

            extension = file.filename.rsplit('.', 1)[1].lower()
            safe_name = secure_filename(file.filename)
            stem, dot, safe_ext = safe_name.rpartition('.')
            if dot and safe_ext.lower() == extension:
                safe_name = f"{stem}.{extension}"
            else:
                # secure_filename dropped the extension (non-ASCII name).
                safe_name = f"{safe_name or 'image'}.{extension}"

            filename = f"{timestamp}_{uuid.uuid4().hex[:8]}_{safe_name}"
            file.save(os.path.join(SAMPLE_IMAGES_FOLDER, filename))
            uploaded_files.append(filename)

        return jsonify({
            'success': True,
            'message': f'成功上传 {len(uploaded_files)} 个图片文件',
            'uploaded_count': len(uploaded_files),
            'files': uploaded_files
        })

    except Exception as e:
        logger.error(f"上传失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500

@app.route('/api/upload/texts', methods=['POST'])
def upload_texts():
    """Store a JSON body of the form {"texts": [...]} as a new .json file in TEXT_DATA_FOLDER.

    The file is named 'texts_<unix time>_<random token>.json' so two uploads
    in the same second do not overwrite each other.

    Returns:
        JSON with 'uploaded_count'; 400 for a missing or malformed body,
        500 on unexpected errors.
    """
    try:
        # silent=True: a non-JSON body becomes None (-> 400) instead of raising.
        data = request.get_json(silent=True)

        if not isinstance(data, dict) or 'texts' not in data:
            return jsonify({'success': False, 'message': '没有提供文本数据'}), 400

        texts = data['texts']
        if not isinstance(texts, list):
            return jsonify({'success': False, 'message': '文本数据格式错误'}), 400

        filename = f"texts_{int(time.time())}_{uuid.uuid4().hex[:8]}.json"
        filepath = os.path.join(TEXT_DATA_FOLDER, filename)

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(texts, f, ensure_ascii=False, indent=2)

        return jsonify({
            'success': True,
            'message': f'成功上传 {len(texts)} 条文本',
            'uploaded_count': len(texts)
        })

    except Exception as e:
        logger.error(f"上传失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500

@app.route('/api/upload/file', methods=['POST'])
def upload_single_file():
    """Save one uploaded image from the multipart field 'file' into SAMPLE_IMAGES_FOLDER.

    The stored name is '<unix time>_<random token>_<sanitized name>.<lowercase ext>'.
    The extension is re-appended when secure_filename strips it (non-ASCII
    names), and the token prevents same-second overwrites.

    Returns:
        JSON with the stored 'filename'; 400 when no file or an unsupported
        type is sent.
    """
    file = request.files.get('file')
    if file is None or not file.filename:
        return jsonify({'success': False, 'message': '没有选择文件'}), 400

    if not allowed_file(file.filename):
        return jsonify({'success': False, 'message': '不支持的文件类型'}), 400

    extension = file.filename.rsplit('.', 1)[1].lower()
    safe_name = secure_filename(file.filename)
    stem, dot, safe_ext = safe_name.rpartition('.')
    if dot and safe_ext.lower() == extension:
        safe_name = f"{stem}.{extension}"
    else:
        # secure_filename dropped the extension (non-ASCII name).
        safe_name = f"{safe_name or 'image'}.{extension}"

    filename = f"{int(time.time())}_{uuid.uuid4().hex[:8]}_{safe_name}"
    file.save(os.path.join(SAMPLE_IMAGES_FOLDER, filename))

    return jsonify({
        'success': True,
        'message': '文件上传成功',
        'filename': filename
    })

@app.route('/api/data/list', methods=['GET'])
def list_data():
    """List stored image files and text corpora with their size and modification time.

    Images are files in SAMPLE_IMAGES_FOLDER accepted by allowed_file. Texts
    are the .json and .txt files in TEXT_DATA_FOLDER, the same set
    /api/build_index reads. Entries are sorted by filename.
    """
    try:
        images = _list_file_entries(SAMPLE_IMAGES_FOLDER, allowed_file)
        texts = _list_file_entries(
            TEXT_DATA_FOLDER, lambda name: name.endswith(('.json', '.txt'))
        )

        return jsonify({
            'success': True,
            'data': {
                'images': images,
                'texts': texts
            }
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'获取数据列表失败: {str(e)}'
        }), 500


def _list_file_entries(folder, accept):
    """Return {'filename', 'size', 'modified'} dicts for regular files in *folder* whose name passes *accept*."""
    entries = []
    if not os.path.isdir(folder):
        return entries
    for filename in sorted(os.listdir(folder)):
        filepath = os.path.join(folder, filename)
        if not accept(filename) or not os.path.isfile(filepath):
            continue
        try:
            stat = os.stat(filepath)
        except OSError:
            # The file may have been deleted after listdir (e.g. by a clear request).
            continue
        entries.append({
            'filename': filename,
            'size': stat.st_size,
            'modified': stat.st_mtime
        })
    return entries

@app.route('/api/gpu_status')
def gpu_status():
    """Return the GPU memory information reported by smart_gpu_launcher."""
    try:
        from smart_gpu_launcher import get_gpu_memory_info
        info = get_gpu_memory_info()
        payload = {'success': True, 'gpu_info': info}
        return jsonify(payload)
    except Exception as exc:
        failure = {
            'success': False,
            'message': f"获取GPU状态失败: {str(exc)}"
        }
        return jsonify(failure), 500

@app.route('/api/build_index', methods=['POST'])
def build_index():
    """Build the image and/or text indexes from everything currently on disk.

    Images come from _collect_index_images and texts from
    _collect_index_texts. Each index is only rebuilt when its data is
    non-empty.

    Returns:
        JSON with 'image_count' and 'text_count'; 400 if the system is not
        initialized or no data exists, 500 if index building raises.
    """
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化'
        }), 400

    try:
        image_files = _collect_index_images()
        text_data = _collect_index_texts()

        if not image_files and not text_data:
            return jsonify({
                'success': False,
                'message': '没有找到可用的图片或文本数据,请先上传数据'
            }), 400

        if image_files:
            logger.info(f"构建图片索引,共 {len(image_files)} 张图片")
            retrieval_system.build_image_index_parallel(image_files)

        if text_data:
            logger.info(f"构建文本索引,共 {len(text_data)} 条文本")
            retrieval_system.build_text_index_parallel(text_data)

        return jsonify({
            'success': True,
            'message': f'索引构建完成!图片: {len(image_files)} 张,文本: {len(text_data)} 条',
            'image_count': len(image_files),
            'text_count': len(text_data)
        })

    except Exception as e:
        logger.error(f"构建索引失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': f'构建索引失败: {str(e)}'
        }), 500


def _collect_index_images():
    """Return sorted paths of regular files in SAMPLE_IMAGES_FOLDER accepted by allowed_file.

    allowed_file is case-insensitive, so 'A.JPG' is included; the previous
    per-extension glob missed it on case-sensitive filesystems. Sorting
    makes index order reproducible, whereas iterating the extension set did
    not.
    """
    if not os.path.isdir(SAMPLE_IMAGES_FOLDER):
        return []
    paths = []
    for name in sorted(os.listdir(SAMPLE_IMAGES_FOLDER)):
        path = os.path.join(SAMPLE_IMAGES_FOLDER, name)
        if allowed_file(name) and os.path.isfile(path):
            paths.append(path)
    return paths


def _collect_index_texts():
    """Return the non-empty texts from the .json and .txt files in TEXT_DATA_FOLDER.

    A JSON list contributes each non-empty item (stringified). Any other
    JSON value counts as one text. A .txt file contributes each non-empty
    line. Unreadable or malformed files are logged and skipped.
    """
    text_files = sorted(
        glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        + glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
    )
    texts = []
    for text_file in text_files:
        try:
            with open(text_file, 'r', encoding='utf-8') as f:
                if text_file.endswith('.json'):
                    data = json.load(f)
                    items = data if isinstance(data, list) else [data]
                    candidates = [str(item).strip() for item in items]
                else:
                    candidates = [line.strip() for line in f]
        except (OSError, ValueError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"读取文本文件失败 {text_file}: {e}")
            continue
        texts.extend(text for text in candidates if text)
    return texts

@app.route('/api/data/stats', methods=['GET'])
def get_data_stats():
    """Count the stored images and texts using the same rules /api/build_index applies.

    Images: regular files in SAMPLE_IMAGES_FOLDER accepted by allowed_file
    (case-insensitive). Texts: non-empty items of each .json list (a non-list
    JSON value counts once) plus non-empty lines of each .txt file. The old
    version counted only .txt files, so texts uploaded through
    /api/upload/texts (stored as .json) never showed up.
    """
    try:
        image_count = 0
        if os.path.isdir(SAMPLE_IMAGES_FOLDER):
            image_count = sum(
                1 for name in os.listdir(SAMPLE_IMAGES_FOLDER)
                if allowed_file(name)
                and os.path.isfile(os.path.join(SAMPLE_IMAGES_FOLDER, name))
            )

        text_count = 0
        text_files = (
            glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
            + glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
        )
        for text_file in text_files:
            try:
                with open(text_file, 'r', encoding='utf-8') as f:
                    if text_file.endswith('.json'):
                        data = json.load(f)
                        items = data if isinstance(data, list) else [data]
                        text_count += sum(1 for item in items if str(item).strip())
                    else:
                        text_count += sum(1 for line in f if line.strip())
            except (OSError, ValueError) as e:
                # Best effort: an unreadable file is skipped, not fatal.
                logger.warning(f"读取文本文件失败 {text_file}: {e}")
                continue

        return jsonify({
            'success': True,
            'image_count': image_count,
            'text_count': text_count
        })

    except Exception as e:
        logger.error(f"获取数据统计失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'获取统计失败: {str(e)}'
        }), 500

@app.route('/api/data/clear', methods=['POST'])
def clear_data():
    """Delete all stored images and text corpora, then reset the in-memory indexes.

    Removes every regular file in SAMPLE_IMAGES_FOLDER accepted by
    allowed_file (case-insensitive) and every .json and .txt file in
    TEXT_DATA_FOLDER. Previously the .json files written by
    /api/upload/texts were left behind and re-indexed on the next build.
    Per-file deletion failures are logged and skipped.
    """
    try:
        # Remove image files
        if os.path.isdir(SAMPLE_IMAGES_FOLDER):
            for name in os.listdir(SAMPLE_IMAGES_FOLDER):
                file_path = os.path.join(SAMPLE_IMAGES_FOLDER, name)
                if not allowed_file(name) or not os.path.isfile(file_path):
                    continue
                try:
                    os.remove(file_path)
                except OSError as e:
                    logger.warning(f"删除图片文件失败 {file_path}: {e}")

        # Remove text corpora (both formats build_index reads)
        text_files = (
            glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
            + glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
        )
        for text_file in text_files:
            try:
                os.remove(text_file)
            except OSError as e:
                logger.warning(f"删除文本文件失败 {text_file}: {e}")

        # Reset the indexes. Only attributes are assigned, so no `global` is needed.
        system = retrieval_system
        if system:
            system.text_index = None
            system.image_index = None
            system.text_data = []
            system.image_data = []

        return jsonify({
            'success': True,
            'message': '数据已清空'
        })

    except Exception as e:
        logger.error(f"清空数据失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'清空数据失败: {str(e)}'
        }), 500

@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve a stored image by filename.

    Despite the '/uploads' URL, files are served from SAMPLE_IMAGES_FOLDER,
    which is where the image upload endpoints save files.
    send_from_directory safe-joins the name, so names that would escape the
    folder are rejected; the '<filename>' converter blocks '/' but not '\\'
    on Windows. A missing file yields 404 instead of a 500. The folder is
    made absolute so it resolves against the working directory, where
    os.makedirs and file.save put it, rather than the application root.
    """
    return send_from_directory(os.path.abspath(SAMPLE_IMAGES_FOLDER), filename)

def print_startup_info():
    """Print a startup banner: URL, supported features and detected GPUs."""
    separator = "=" * 60
    banner_lines = (
        "🚀 启动多GPU多模态检索Web应用",
        separator,
        "访问地址: http://localhost:5000",
        "支持功能:",
        " 📝 文搜文 - 文本查找相似文本",
        " 🖼️ 文搜图 - 文本查找相关图片",
        " 📝 图搜文 - 图片查找相关文本",
        " 🖼️ 图搜图 - 图片查找相似图片",
        " 📤 批量上传 - 图片和文本数据管理",
        "GPU配置:",
    )
    for banner_line in banner_lines:
        print(banner_line)

    try:
        import torch
        if not torch.cuda.is_available():
            print(" ❌ CUDA不可用")
        else:
            device_total = torch.cuda.device_count()
            print(f" 🖥️ 检测到 {device_total} 个GPU")
            for idx in range(device_total):
                gpu_name = torch.cuda.get_device_name(idx)
                total_gb = torch.cuda.get_device_properties(idx).total_memory / 1024**3
                print(f" GPU {idx}: {gpu_name} ({total_gb:.1f}GB)")
    except Exception as err:
        print(f" ❌ GPU检查失败: {err}")

    print(separator)

def auto_initialize():
    """Initialize the global retrieval system at startup.

    The instance is only published to the module-level ``retrieval_system``
    after its model loads. If loading fails, the global keeps its previous
    value (None at startup) instead of a broken instance that the search
    endpoints would treat as ready.

    Returns:
        True on success, False on failure (the error is logged).
    """
    global retrieval_system

    try:
        logger.info("🚀 启动时自动初始化多GPU检索系统...")

        # Imported lazily so the web app can start without loading the model code.
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        system = MultiGPUMultimodalRetrieval()
        if system.model is None:
            raise Exception("模型加载失败")

        retrieval_system = system
        logger.info("✅ 系统自动初始化成功")
        return True

    except Exception as e:
        logger.error(f"❌ 系统自动初始化失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False

if __name__ == '__main__':
    print_startup_info()

    # Load the model before serving. A failure is logged and the app still
    # starts; /api/init can retry.
    auto_initialize()

    # Start the Flask development server on all interfaces.
    run_options = dict(host='0.0.0.0', port=5000, debug=False, threaded=True)
    app.run(**run_options)