#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multi-GPU multimodal retrieval system - web application.

Optimized for dual-GPU deployment. Exposes a Flask JSON API for
text/image cross-modal search plus simple data-upload endpoints.
"""
import base64
import glob
import io
import json
import logging
import os
import time
import traceback

from flask import Flask, render_template, request, jsonify, send_file, url_for
from PIL import Image
from werkzeug.utils import secure_filename

# Ask the CUDA caching allocator to use expandable segments (reduces
# fragmentation). This must run before torch is first imported; torch is only
# imported lazily inside functions in this module. setdefault keeps any value
# the operator already exported instead of silently overriding it.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
# Prefer a secret supplied via the environment; the literal is only a
# development fallback and should not be relied on in production.
app.config['SECRET_KEY'] = os.environ.get(
    'FLASK_SECRET_KEY', 'multigpu_multimodal_retrieval_2024'
)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max request body

# Storage folders (relative to the current working directory)
UPLOAD_FOLDER = 'uploads'
SAMPLE_IMAGES_FOLDER = 'sample_images'
TEXT_DATA_FOLDER = 'text_data'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}

# Make sure the storage folders exist
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)

# Global retrieval system instance; stays None until an initialization
# attempt fully succeeds.
retrieval_system = None


def allowed_file(filename):
    """Return True if *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The extension comparison is case-insensitive. Empty or missing names
    (``''`` / ``None``) are rejected instead of raising.
    """
    if not filename or '.' not in filename:
        return False
    return filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def image_to_base64(image_path):
    """Read *image_path* and return its bytes base64-encoded as a str.

    Returns None (and logs the error) if the file cannot be read.
    """
    try:
        with open(image_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')
    except Exception as e:
        logger.error(f"图片转换失败: {e}")
        return None


@app.route('/')
def index():
    """Render the main page."""
    return render_template('index.html')


@app.route('/api/status')
def get_status():
    """Report initialization state and visible GPU count as JSON.

    Keys: ``initialized``, ``gpu_count``, ``model_loaded``; ``device_ids``
    is added only when the retrieval system has a loaded model.
    """
    status = {
        'initialized': retrieval_system is not None,
        'gpu_count': 0,
        'model_loaded': False
    }

    try:
        import torch
        if torch.cuda.is_available():
            status['gpu_count'] = torch.cuda.device_count()

        # Same "loaded" criterion as initialization (model is not None);
        # truthiness could misreport container-like model objects.
        if (retrieval_system is not None
                and getattr(retrieval_system, 'model', None) is not None):
            status['model_loaded'] = True
            status['device_ids'] = retrieval_system.device_ids
    except Exception as e:
        logger.error(f"获取状态失败: {e}")

    return jsonify(status)


@app.route('/api/init', methods=['POST'])
def initialize_system():
    """Create the multi-GPU retrieval system and publish it globally.

    The new instance is assigned to ``retrieval_system`` only after its model
    has been verified, so a failed attempt never leaves a half-built system
    that the search endpoints would treat as ready.

    Returns:
        JSON with ``success`` and ``device_ids``/``gpu_count`` on success,
        or ``success: False`` with an error message and HTTP 500 on failure.
    """
    global retrieval_system

    try:
        logger.info("正在初始化多GPU检索系统...")

        # Imported lazily, only when initialization is requested
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        system = MultiGPUMultimodalRetrieval()
        if system.model is None:
            raise RuntimeError("模型加载失败")

        # Publish only a fully working instance
        retrieval_system = system

        logger.info("✅ 多GPU系统初始化成功")
        return jsonify({
            'success': True,
            'message': '多GPU系统初始化成功',
            'device_ids': retrieval_system.device_ids,
            'gpu_count': len(retrieval_system.device_ids)
        })

    except Exception as e:
        error_msg = f"系统初始化失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500


@app.route('/api/search/text_to_text', methods=['POST'])
def search_text_to_text():
    """Text-to-text search."""
    return handle_search('text_to_text')


@app.route('/api/search/text_to_image', methods=['POST'])
def search_text_to_image():
    """Text-to-image search."""
    return handle_search('text_to_image')


@app.route('/api/search/image_to_text', methods=['POST'])
def search_image_to_text():
    """Image-to-text search."""
    return handle_search('image_to_text')
@app.route('/api/search/image_to_image', methods=['POST'])
def search_image_to_image():
    """Image-to-image search."""
    return handle_search('image_to_image')


def _get_request_value(name, default=None):
    """Return request parameter *name* from form data, else from a JSON body.

    ``get_json(silent=True)`` is used so form/multipart requests (which have no
    JSON body) fall back to *default* instead of failing the request.
    """
    value = request.form.get(name)
    if value:
        return value
    payload = request.get_json(silent=True)
    if isinstance(payload, dict):
        return payload.get(name, default)
    return default


@app.route('/api/search', methods=['POST'])
def search():
    """Generic search endpoint (kept for compatibility with older clients).

    ``mode`` is read from form data or a JSON body; defaults to text_to_text.
    """
    mode = _get_request_value('mode', 'text_to_text')
    return handle_search(mode)


def _parse_top_k(default=5):
    """Return the requested ``top_k`` as a positive int.

    Raises:
        ValueError / TypeError: if the value is not an integer or is < 1.
    """
    top_k = int(_get_request_value('top_k', default))
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    return top_k


def _unique_path(folder, stem, ext):
    """Return ``(name, path)`` for ``stem.ext`` in *folder*.

    Appends ``_1``, ``_2``, ... to the stem if the name is taken, so uploads
    that share a timestamp and name do not overwrite each other.
    """
    candidate = f"{stem}.{ext}"
    counter = 1
    while os.path.exists(os.path.join(folder, candidate)):
        candidate = f"{stem}_{counter}.{ext}"
        counter += 1
    return candidate, os.path.join(folder, candidate)


def _stored_image_name(folder, original_name, prefix):
    """Build a safe, unique ``(name, path)`` for an uploaded image.

    The extension is taken from *original_name* (lower-cased) and sanitized
    separately from the stem: ``secure_filename`` strips non-ASCII characters,
    so running it on e.g. ``图片.jpg`` would drop the stem AND the dot, leaving
    a file with no image extension that is never listed or indexed.
    Callers must have checked ``allowed_file(original_name)`` first.
    """
    stem, ext = original_name.rsplit('.', 1)
    safe_stem = secure_filename(stem) or 'image'
    return _unique_path(folder, f"{prefix}_{safe_stem}", ext.lower())


def _format_text_results(raw_results):
    """Convert ``(text, score)`` pairs into JSON-serializable dicts."""
    return [{'text': text, 'score': float(score)} for text, score in raw_results]


def _encode_image_result(image_path, score):
    """Build one image result dict, embedding the image as base64.

    If the file cannot be read the error is logged and ``image_base64`` is ''.
    """
    try:
        with open(image_path, 'rb') as img_file:
            image_base64 = base64.b64encode(img_file.read()).decode('utf-8')
    except Exception as e:
        logger.error(f"读取图像失败 {image_path}: {e}")
        image_base64 = ''
    return {
        'filename': os.path.basename(image_path),
        'image_path': image_path,
        'image_base64': image_base64,
        'score': float(score)
    }


def _format_image_results(raw_results):
    """Convert ``(image_path, score)`` pairs into result dicts."""
    return [_encode_image_result(path, score) for path, score in raw_results]


def _search_by_text(mode, top_k):
    """Run a text_to_text / text_to_image query and build the JSON response."""
    query = _get_request_value('query', '')
    query = '' if query is None else str(query)
    if not query.strip():
        return jsonify({
            'success': False,
            'message': '请输入查询文本'
        }), 400

    logger.info(f"执行{mode}搜索: {query}")

    if mode == 'text_to_text':
        raw_results = retrieval_system.search_text_to_text(query, top_k=top_k)
        results = _format_text_results(raw_results)
    else:  # text_to_image
        raw_results = retrieval_system.search_text_to_image(query, top_k=top_k)
        results = _format_image_results(raw_results)

    return jsonify({
        'success': True,
        'mode': mode,
        'query': query,
        'results': results,
        'result_count': len(results)
    })


def _search_by_image(mode, top_k):
    """Save the uploaded query image, run an image query, build the response."""
    file = request.files.get('image')
    if file is None:
        return jsonify({
            'success': False,
            'message': '请上传查询图片'
        }), 400

    if not file.filename or not allowed_file(file.filename):
        return jsonify({
            'success': False,
            'message': '请上传有效的图片文件'
        }), 400

    # Persist the query image; the retrieval system is given a file path
    timestamp = str(int(time.time()))
    filename, filepath = _stored_image_name(
        UPLOAD_FOLDER, file.filename, f"query_{timestamp}")
    file.save(filepath)

    logger.info(f"执行{mode}搜索,图片: {filename}")

    if mode == 'image_to_text':
        raw_results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
        results = _format_text_results(raw_results)
    else:  # image_to_image
        raw_results = retrieval_system.search_image_to_image(filepath, top_k=top_k)
        results = _format_image_results(raw_results)

    return jsonify({
        'success': True,
        'mode': mode,
        'query_image': image_to_base64(filepath),
        'results': results,
        'result_count': len(results)
    })


def handle_search(mode):
    """Shared handler for all search endpoints.

    Args:
        mode: one of text_to_text, text_to_image, image_to_text,
            image_to_image.

    Returns:
        A JSON response; 400 for missing system/invalid input/unknown mode,
        500 if the search itself fails.
    """
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化,请先点击初始化按钮'
        }), 400

    # A malformed top_k is a client error, not a server failure
    try:
        top_k = _parse_top_k()
    except (TypeError, ValueError):
        return jsonify({
            'success': False,
            'message': '参数 top_k 必须是正整数'
        }), 400

    try:
        if mode in ('text_to_text', 'text_to_image'):
            return _search_by_text(mode, top_k)
        if mode in ('image_to_text', 'image_to_image'):
            return _search_by_image(mode, top_k)
        return jsonify({
            'success': False,
            'message': f'不支持的搜索模式: {mode}'
        }), 400

    except Exception as e:
        error_msg = f"搜索失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500


@app.route('/api/upload/images', methods=['POST'])
def upload_images():
    """Save every valid image from the ``images`` field into the sample folder.

    Invalid or empty entries are skipped; the response lists the stored names.
    """
    try:
        if 'images' not in request.files:
            return jsonify({'success': False, 'message': '没有选择文件'}), 400

        timestamp = str(int(time.time()))
        uploaded_files = []
        for file in request.files.getlist('images'):
            if not file or not file.filename or not allowed_file(file.filename):
                continue
            filename, filepath = _stored_image_name(
                SAMPLE_IMAGES_FOLDER, file.filename, timestamp)
            file.save(filepath)
            uploaded_files.append(filename)

        return jsonify({
            'success': True,
            'message': f'成功上传 {len(uploaded_files)} 个图片文件',
            'uploaded_count': len(uploaded_files),
            'files': uploaded_files
        })

    except Exception as e:
        logger.error(f"上传失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500


@app.route('/api/upload/texts', methods=['POST'])
def upload_texts():
    """Store a JSON list of texts (``{"texts": [...]}``) as a new .json file."""
    try:
        data = request.get_json(silent=True)
        if not data or 'texts' not in data:
            return jsonify({'success': False, 'message': '没有提供文本数据'}), 400

        texts = data['texts']
        if not isinstance(texts, list):
            return jsonify({'success': False, 'message': '文本数据格式错误'}), 400

        # Unique name: two uploads in the same second must not overwrite
        timestamp = str(int(time.time()))
        _, filepath = _unique_path(TEXT_DATA_FOLDER, f"texts_{timestamp}", 'json')

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(texts, f, ensure_ascii=False, indent=2)

        return jsonify({
            'success': True,
            'message': f'成功上传 {len(texts)} 条文本',
            'uploaded_count': len(texts)
        })

    except Exception as e:
        logger.error(f"上传失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500


@app.route('/api/upload/file', methods=['POST'])
def upload_single_file():
    """Save a single image from the ``file`` field into the sample folder."""
    file = request.files.get('file')
    if file is None or not file.filename:
        return jsonify({'success': False, 'message': '没有选择文件'}), 400

    if not allowed_file(file.filename):
        return jsonify({'success': False, 'message': '不支持的文件类型'}), 400

    timestamp = str(int(time.time()))
    filename, filepath = _stored_image_name(
        SAMPLE_IMAGES_FOLDER, file.filename, timestamp)
    file.save(filepath)

    return jsonify({
        'success': True,
        'message': '文件上传成功',
        'filename': filename
    })


def _list_sample_images():
    """Return sorted paths of image files in SAMPLE_IMAGES_FOLDER.

    Filtering goes through allowed_file, so extension matching is
    case-insensitive; per-extension lowercase globs would miss ``.JPG`` files
    on case-sensitive filesystems even though uploads accept them.
    """
    if not os.path.isdir(SAMPLE_IMAGES_FOLDER):
        return []
    paths = (os.path.join(SAMPLE_IMAGES_FOLDER, name)
             for name in os.listdir(SAMPLE_IMAGES_FOLDER)
             if allowed_file(name))
    return sorted(p for p in paths if os.path.isfile(p))


def _list_text_files():
    """Return sorted paths of .json and .txt files in TEXT_DATA_FOLDER."""
    if not os.path.isdir(TEXT_DATA_FOLDER):
        return []
    paths = (os.path.join(TEXT_DATA_FOLDER, name)
             for name in os.listdir(TEXT_DATA_FOLDER)
             if name.endswith(('.json', '.txt')))
    return sorted(p for p in paths if os.path.isfile(p))


def _load_text_entries():
    """Load all non-empty text entries from the text data files.

    .json files may hold a list (each item becomes one entry) or a single
    value; .txt files contribute one entry per non-blank line. Unreadable
    files are logged and skipped; a file contributes nothing unless it is
    read completely.
    """
    texts = []
    for text_file in _list_text_files():
        try:
            with open(text_file, 'r', encoding='utf-8') as f:
                if text_file.endswith('.json'):
                    data = json.load(f)
                    items = data if isinstance(data, list) else [data]
                    entries = [str(item).strip() for item in items]
                else:
                    entries = [line.strip() for line in f]
        except Exception as e:
            logger.warning(f"读取文本文件失败 {text_file}: {e}")
            continue
        texts.extend(entry for entry in entries if entry)
    return texts


def _describe_file(path):
    """Return the listing record (name, size, mtime) for *path*."""
    stat = os.stat(path)
    return {
        'filename': os.path.basename(path),
        'size': stat.st_size,
        'modified': stat.st_mtime
    }


@app.route('/api/data/list', methods=['GET'])
def list_data():
    """List stored images and text data files with size and mtime."""
    try:
        images = [_describe_file(p) for p in _list_sample_images()]
        texts = [_describe_file(p) for p in _list_text_files()]

        return jsonify({
            'success': True,
            'data': {
                'images': images,
                'texts': texts
            }
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'获取数据列表失败: {str(e)}'
        }), 500


@app.route('/api/gpu_status')
def gpu_status():
    """Return GPU memory information from smart_gpu_launcher."""
    try:
        from smart_gpu_launcher import get_gpu_memory_info
        gpu_info = get_gpu_memory_info()
        return jsonify({
            'success': True,
            'gpu_info': gpu_info
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f"获取GPU状态失败: {str(e)}"
        }), 500


@app.route('/api/build_index', methods=['POST'])
def build_index():
    """Build the image and/or text indexes from all stored data."""
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化'
        }), 400

    try:
        image_files = _list_sample_images()
        text_data = _load_text_entries()

        if not image_files and not text_data:
            return jsonify({
                'success': False,
                'message': '没有找到可用的图片或文本数据,请先上传数据'
            }), 400

        if image_files:
            logger.info(f"构建图片索引,共 {len(image_files)} 张图片")
            retrieval_system.build_image_index_parallel(image_files)

        if text_data:
            logger.info(f"构建文本索引,共 {len(text_data)} 条文本")
            retrieval_system.build_text_index_parallel(text_data)

        return jsonify({
            'success': True,
            'message': f'索引构建完成!图片: {len(image_files)} 张,文本: {len(text_data)} 条',
            'image_count': len(image_files),
            'text_count': len(text_data)
        })

    except Exception as e:
        logger.error(f"构建索引失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'构建索引失败: {str(e)}'
        }), 500


@app.route('/api/data/stats', methods=['GET'])
def get_data_stats():
    """Return the number of stored images and text entries.

    Counts exactly what build_index would index, including JSON uploads
    (the previous version counted only .txt files).
    """
    try:
        return jsonify({
            'success': True,
            'image_count': len(_list_sample_images()),
            'text_count': len(_load_text_entries())
        })

    except Exception as e:
        logger.error(f"获取数据统计失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'获取统计失败: {str(e)}'
        }), 500


@app.route('/api/data/clear', methods=['POST'])
def clear_data():
    """Delete all stored images and text files and reset in-memory indexes."""
    try:
        for file_path in _list_sample_images():
            try:
                os.remove(file_path)
            except OSError as e:
                logger.warning(f"删除图片文件失败 {file_path}: {e}")

        # Covers .json uploads as well as .txt files
        for text_file in _list_text_files():
            try:
                os.remove(text_file)
            except OSError as e:
                logger.warning(f"删除文本文件失败 {text_file}: {e}")

        if retrieval_system:
            retrieval_system.text_index = None
            retrieval_system.image_index = None
            retrieval_system.text_data = []
            retrieval_system.image_data = []

        return jsonify({
            'success': True,
            'message': '数据已清空'
        })

    except Exception as e:
        logger.error(f"清空数据失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'清空数据失败: {str(e)}'
        }), 500


@app.route('/uploads/<path:filename>')
def uploaded_file(filename):
    """Serve a stored image from SAMPLE_IMAGES_FOLDER by name.

    The route previously had no ``<filename>`` converter, so every request
    failed. send_from_directory refuses paths that escape the folder
    (e.g. ``../``). The absolute path matches the CWD-relative folder used by
    the upload handlers.
    """
    from flask import send_from_directory
    return send_from_directory(os.path.abspath(SAMPLE_IMAGES_FOLDER), filename)


def print_startup_info():
    """Print the startup banner, feature list and detected GPUs."""
    print("🚀 启动多GPU多模态检索Web应用")
    print("=" * 60)
    print("访问地址: http://localhost:5000")
    print("支持功能:")
    print(" 📝 文搜文 - 文本查找相似文本")
    print(" 🖼️ 文搜图 - 文本查找相关图片")
    print(" 📝 图搜文 - 图片查找相关文本")
    print(" 🖼️ 图搜图 - 图片查找相似图片")
    print(" 📤 批量上传 - 图片和文本数据管理")
    print("GPU配置:")

    try:
        import torch
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f" 🖥️ 检测到 {gpu_count} 个GPU")
            for i in range(gpu_count):
                name = torch.cuda.get_device_name(i)
                props = torch.cuda.get_device_properties(i)
                memory_gb = props.total_memory / 1024**3
                print(f" GPU {i}: {name} ({memory_gb:.1f}GB)")
        else:
            print(" ❌ CUDA不可用")
    except Exception as e:
        print(f" ❌ GPU检查失败: {e}")

    print("=" * 60)


def auto_initialize():
    """Initialize the retrieval system at startup.

    Like the /api/init endpoint, the instance is published to the global only
    after its model has loaded.

    Returns:
        True on success, False (after logging) on failure.
    """
    global retrieval_system

    try:
        logger.info("🚀 启动时自动初始化多GPU检索系统...")

        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        system = MultiGPUMultimodalRetrieval()
        if system.model is None:
            raise RuntimeError("模型加载失败")
        retrieval_system = system

        logger.info("✅ 系统自动初始化成功")
        return True

    except Exception as e:
        logger.error(f"❌ 系统自动初始化失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False


if __name__ == '__main__':
    print_startup_info()

    # Try to initialize at startup; /api/init can retry later if this fails
    auto_initialize()

    app.run(
        host='0.0.0.0',
        port=5000,
        debug=False,
        threaded=True
    )