mmeb/web_app_local.py
2025-09-22 10:13:11 +00:00

467 lines
16 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
本地多模态检索系统Web应用
集成本地模型和FAISS向量数据库
支持文搜文、文搜图、图搜文、图搜图四种检索模式
"""
import base64
import json
import logging
import mimetypes
import os
import sys
import threading
import time
import uuid
from io import BytesIO
from pathlib import Path

import numpy as np
import torch
from flask import Flask, request, jsonify, render_template, send_from_directory
from PIL import Image
from werkzeug.utils import secure_filename
# Put Hugging Face transformers into offline mode. This must happen *before*
# the local modules below are imported, in case they import transformers at
# import time (presumably they do — verify in multimodal_retrieval_local).
os.environ['TRANSFORMERS_OFFLINE'] = '1'

# Local project modules.
from multimodal_retrieval_local import MultimodalRetrievalLocal
from optimized_file_handler import OptimizedFileHandler

# Logging.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Flask application.
app = Flask(__name__)

# Configuration.
app.config['UPLOAD_FOLDER'] = '/tmp/mmeb_uploads'
# Flask rejects request bodies larger than this (16 MB).
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
app.config['MODEL_PATH'] = '/root/models/Ops-MM-embedding-v1-7B'
app.config['INDEX_PATH'] = '/root/mmeb/local_faiss_index'
# NOTE(review): this set is also used to vet uploads on the image endpoints,
# where 'txt' / 'pdf' files pass the extension check but then fail to decode.
app.config['ALLOWED_EXTENSIONS'] = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'}

# Ensure the upload directory exists. exist_ok=True already makes this
# idempotent, so no separate existence check is needed.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# Shared file handler rooted at the upload directory.
file_handler = OptimizedFileHandler(local_storage_dir=app.config['UPLOAD_FOLDER'])

# Retrieval system, created lazily by init_retrieval_system().
retrieval_system = None
def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* ends in an allowed extension.

    The extension is the text after the last '.', compared case-insensitively.
    Names without a '.' (and empty or None names) are rejected.

    Args:
        filename: Client-supplied file name; may be empty or None.
        allowed_extensions: Optional collection of lowercase extensions without
            the leading dot. Defaults to app.config['ALLOWED_EXTENSIONS'], so
            existing callers behave as before; callers may pass a narrower set
            (e.g. image-only).

    Returns:
        bool: whether the extension is in the allowed set.
    """
    if not filename:
        return False
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    if allowed_extensions is None:
        allowed_extensions = app.config['ALLOWED_EXTENSIONS']
    return extension.lower() in allowed_extensions
# Serialises the one-time construction of the retrieval system. Flask's
# app.run() serves requests on multiple threads by default, so without this two
# concurrent first requests could both see None and load the model twice.
_retrieval_init_lock = threading.Lock()


def init_retrieval_system():
    """Return the shared MultimodalRetrievalLocal instance, creating it on first use.

    The instance is stored in the module-level ``retrieval_system`` and reused
    by every later call. Construction happens under a lock so it runs at most
    once even when several requests arrive at the same time.

    Returns:
        MultimodalRetrievalLocal: the shared retrieval system.

    Raises:
        FileNotFoundError: if app.config['MODEL_PATH'] does not exist.
        Exception: anything raised by the MultimodalRetrievalLocal constructor;
            ``retrieval_system`` then stays None, so a later call retries.
    """
    global retrieval_system
    # Fast path: already built, no locking needed.
    if retrieval_system is not None:
        return retrieval_system
    with _retrieval_init_lock:
        # Re-check under the lock: another thread may have finished building
        # the system while this one was waiting.
        if retrieval_system is not None:
            return retrieval_system
        logger.info("初始化多模态检索系统...")
        model_path = app.config['MODEL_PATH']
        if not os.path.exists(model_path):
            logger.error(f"模型路径不存在: {model_path}")
            raise FileNotFoundError(f"模型路径不存在: {model_path}")
        retrieval_system = MultimodalRetrievalLocal(
            model_path=model_path,
            use_all_gpus=True,
            index_path=app.config['INDEX_PATH'],
        )
        logger.info("多模态检索系统初始化完成")
        return retrieval_system
def get_image_base64(image_path):
    """Read an image file and return it as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension (e.g. ``.png`` ->
    ``image/png``). When no image type can be guessed, ``image/jpeg`` is used,
    which is what this function previously emitted for every file.

    Args:
        image_path: Path to an image file on disk.

    Returns:
        str: ``"data:<mime>;base64,<payload>"``.
    """
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type or not mime_type.startswith('image/'):
        mime_type = 'image/jpeg'
    with open(image_path, 'rb') as image_file:
        encoded = base64.b64encode(image_file.read()).decode('ascii')
    return f"data:{mime_type};base64,{encoded}"
@app.route('/')
def index():
    """Render the front-end page ``local_index.html``."""
    page = 'local_index.html'
    return render_template(page)
@app.route('/api/stats', methods=['GET'])
def get_stats():
    """Return the retrieval system's statistics as JSON.

    Builds the retrieval system first if it does not exist yet. Any failure
    (including serialising the stats) yields HTTP 500 with the error text.
    """
    try:
        stats = init_retrieval_system().get_stats()
        payload = {"success": True, "stats": stats}
        return jsonify(payload)
    except Exception as e:
        reason = str(e)
        logger.error(f"获取统计信息失败: {reason}")
        return jsonify({"success": False, "error": reason}), 500
@app.route('/api/add_text', methods=['POST'])
def add_text():
    """Add one text entry to the retrieval index, then save the index.

    Request body (JSON): ``{"text": "<non-empty string>"}``.

    The text is indexed straight from memory. This endpoint creates no
    temporary files, so it neither writes one nor calls the handler-wide
    ``file_handler.cleanup_all_temp_files()``.

    Responses:
        200 ``{"success": true, "message": ..., "text_id": <id or null>}``
        400 if the body is not a JSON object, or ``text`` is missing, empty
            or not a string.
        500 if building the retrieval system, indexing or saving fails.
    """
    try:
        # get_json(silent=True) returns None for a missing/malformed JSON body
        # instead of raising, so bad client input becomes a 400, not a 500.
        data = request.get_json(silent=True)
        text = data.get('text') if isinstance(data, dict) else None
        if not text:
            return jsonify({"success": False, "error": "文本不能为空"}), 400
        if not isinstance(text, str):
            return jsonify({"success": False, "error": "文本必须是字符串"}), 400

        logger.info(f"处理文本: {len(text)} 字符")
        retrieval = init_retrieval_system()
        metadata = {
            "timestamp": time.time(),
            "source": "web_upload",
        }
        text_ids = retrieval.add_texts([text], [metadata])
        # Persist after every addition so the entry survives a restart.
        retrieval.save_index()
        return jsonify({
            "success": True,
            "message": "文本添加成功",
            "text_id": text_ids[0] if text_ids else None,
        })
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"添加文本失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
def _decode_upload_image(image_data):
    """Decode uploaded bytes into a fully loaded RGB PIL image.

    Args:
        image_data: Raw bytes of the uploaded file.

    Returns:
        tuple: ``(image, source_format)``. ``image`` is an RGB PIL image;
        ``source_format`` is the format PIL detected (e.g. ``"PNG"``), captured
        before conversion because ``convert()`` returns an image whose
        ``format`` is None.

    Raises:
        Whatever PIL raises for data it cannot identify or decode.
    """
    image = Image.open(BytesIO(image_data))
    source_format = image.format
    # Image.open() is lazy; load() forces the full decode now, so corrupt or
    # truncated uploads are rejected here rather than later during indexing.
    image.load()
    if image.mode != 'RGB':
        logger.info(f"将图像从 {image.mode} 转换为 RGB")
        image = image.convert('RGB')
    return image, source_format


def _save_upload(image_data, original_name):
    """Write accepted upload bytes into UPLOAD_FOLDER under a unique name.

    secure_filename() drops non-ASCII characters ('图片.jpg' becomes 'jpg') and
    may return an empty string. Using its output directly as the on-disk name
    let unrelated uploads overwrite files that earlier index entries still
    point at, or failed outright. A random suffix keeps every stored file
    distinct. The extension is taken from *original_name*, which the caller
    has already checked with allowed_file().

    Returns:
        tuple: ``(stored_name, image_path)``.
    """
    stem, _, extension = original_name.rpartition('.')
    safe_stem = secure_filename(stem) or 'image'
    stored_name = f"{safe_stem}_{uuid.uuid4().hex[:12]}.{extension.lower()}"
    image_path = os.path.join(app.config['UPLOAD_FOLDER'], stored_name)
    with open(image_path, 'wb') as f:
        f.write(image_data)
    return stored_name, image_path


@app.route('/api/add_image', methods=['POST'])
def add_image():
    """Add one uploaded image to the retrieval index, then save the index.

    Request: multipart form with an ``image`` file field.

    The image is decoded and validated *before* anything is written to disk,
    so rejected uploads leave no files behind. Accepted images are stored in
    UPLOAD_FOLDER under a unique name (see _save_upload).

    Responses:
        200 ``{"success": true, "message": ..., "image_id": <id or null>}``
        400 if no file was sent, the name is empty, the extension is not
            allowed, or PIL cannot decode the data.
        500 if building the retrieval system, indexing or saving fails.
    """
    try:
        if 'image' not in request.files:
            return jsonify({"success": False, "error": "没有上传文件"}), 400
        file = request.files['image']
        if file.filename == '':
            return jsonify({"success": False, "error": "没有选择文件"}), 400
        if not (file and allowed_file(file.filename)):
            return jsonify({"success": False, "error": "不支持的文件类型"}), 400

        image_data = file.read()
        file_size = len(image_data)
        logger.info(f"处理图像: {file.filename} ({file_size} 字节)")

        try:
            image, source_format = _decode_upload_image(image_data)
        except Exception as e:
            logger.error(f"加载图像失败: {file.filename}, 错误: {str(e)}")
            return jsonify({"success": False, "error": f"图像格式不支持: {str(e)}"}), 400

        retrieval = init_retrieval_system()
        filename, image_path = _save_upload(image_data, file.filename)
        logger.info(f"成功加载图像: {filename}, 格式: {source_format}, 模式: {image.mode}, 大小: {image.size}")

        metadata = {
            "filename": filename,
            "timestamp": time.time(),
            "source": "web_upload",
            "size": file_size,
            "local_path": image_path,
        }
        image_ids = retrieval.add_images([image], [metadata], [image_path])
        # Persist after every addition so the entry survives a restart.
        retrieval.save_index()
        return jsonify({
            "success": True,
            "message": "图像添加成功",
            "image_id": image_ids[0] if image_ids else None,
        })
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"添加图像失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/search_by_text', methods=['POST'])
def search_by_text():
    """Search the index with a text query.

    Request body (JSON)::

        {"query": "<text>", "k": 5, "filter_type": "text" | "image" | null}

    ``k`` defaults to 5 and must be a positive integer. ``filter_type`` is
    passed to the retrieval system unchanged.

    Each result carries ``score`` and ``type``:

    - Text hits also carry ``text``.
    - Image hits carry ``image`` (a data URI) and ``filename`` when the stored
      file still exists, plus ``description`` when the hit has one.

    Responses:
        200 ``{"success": true, "results": [...], "query": ..., "filter_type": ...}``
        400 if the body is not a JSON object, ``query`` is empty, or ``k`` is
            not a positive integer.
        500 on any other failure.
    """
    try:
        # silent=True: a missing/malformed JSON body yields None (-> 400)
        # instead of an exception (-> 500).
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        query = data.get('query')
        filter_type = data.get('filter_type')  # "text", "image" or None
        if not query:
            return jsonify({"success": False, "error": "查询文本不能为空"}), 400
        try:
            k = int(data.get('k', 5))
        except (TypeError, ValueError):
            k = 0
        if k <= 0:
            return jsonify({"success": False, "error": "k 必须是正整数"}), 400

        results = init_retrieval_system().search_by_text(query, k, filter_type)

        processed_results = []
        for result in results:
            item = {
                "score": result.get("score", 0),
                "type": result.get("type"),
            }
            if result.get("type") == "text":
                item["text"] = result.get("text", "")
            elif result.get("type") == "image":
                path = result.get("path")
                # Skip the payload for images whose file has since disappeared.
                if path and os.path.exists(path):
                    item["image"] = get_image_base64(path)
                    item["filename"] = os.path.basename(path)
                if "description" in result:
                    item["description"] = result["description"]
            processed_results.append(item)

        return jsonify({
            "success": True,
            "results": processed_results,
            "query": query,
            "filter_type": filter_type,
        })
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"文本搜索失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/search_by_image', methods=['POST'])
def search_by_image():
    """Search the index with an uploaded query image.

    Request: multipart form with an ``image`` file field, plus optional form
    fields ``k`` (default 5, must be a positive integer) and ``filter_type``
    ("text", "image" or omitted).

    The upload is decoded directly from memory. ``file.read()`` has already
    loaded the whole body, which MAX_CONTENT_LENGTH caps, so round-tripping
    large uploads through a temporary file only added disk I/O. Because no
    temporary files are created, the handler-wide
    ``file_handler.cleanup_all_temp_files()`` is not needed here.

    The image is converted to RGB, as add_image does before indexing, so
    queries and indexed images are prepared the same way.

    Responses:
        200 ``{"success": true, "results": [...], "filter_type": ...}``
            (same result format as /api/search_by_text).
        400 if no file was sent, the name is empty, the extension is not
            allowed, ``k`` is invalid, or PIL cannot decode the data.
        500 on any other failure.
    """
    try:
        if 'image' not in request.files:
            return jsonify({"success": False, "error": "没有上传文件"}), 400
        file = request.files['image']
        # A non-numeric or non-positive k is a client error, not a server failure.
        try:
            k = int(request.form.get('k', 5))
        except (TypeError, ValueError):
            k = 0
        if k <= 0:
            return jsonify({"success": False, "error": "k 必须是正整数"}), 400
        filter_type = request.form.get('filter_type')  # "text", "image" or None
        if file.filename == '':
            return jsonify({"success": False, "error": "没有选择文件"}), 400
        if not (file and allowed_file(file.filename)):
            return jsonify({"success": False, "error": "不支持的文件类型"}), 400

        image_data = file.read()
        logger.info(f"使用内存处理搜索图像: {file.filename} ({len(image_data)} 字节)")
        try:
            image = Image.open(BytesIO(image_data))
            # Image.open() is lazy; load() forces the decode so bad data is a 400.
            image.load()
            if image.mode != 'RGB':
                image = image.convert('RGB')
        except Exception as e:
            logger.error(f"加载图像失败: {file.filename}, 错误: {str(e)}")
            return jsonify({"success": False, "error": f"图像格式不支持: {str(e)}"}), 400

        results = init_retrieval_system().search_by_image(image, k, filter_type)

        processed_results = []
        for result in results:
            item = {
                "score": result.get("score", 0),
                "type": result.get("type"),
            }
            if result.get("type") == "text":
                item["text"] = result.get("text", "")
            elif result.get("type") == "image":
                path = result.get("path")
                # Skip the payload for images whose file has since disappeared.
                if path and os.path.exists(path):
                    item["image"] = get_image_base64(path)
                    item["filename"] = os.path.basename(path)
                if "description" in result:
                    item["description"] = result["description"]
            processed_results.append(item)

        return jsonify({
            "success": True,
            "results": processed_results,
            "filter_type": filter_type,
        })
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"图像搜索失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve *filename* from the configured upload directory.

    send_from_directory refuses paths that would escape that directory and
    answers 404 for files that do not exist.
    """
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename)
@app.route('/temp/<filename>')
def temp_file(filename):
    """Serve *filename* under the /temp/ URL prefix.

    Files are resolved from the same directory (app.config['UPLOAD_FOLDER'])
    as /uploads/<filename>; only the URL prefix differs.
    """
    folder = app.config['UPLOAD_FOLDER']
    response = send_from_directory(folder, filename)
    return response
@app.route('/api/save_index', methods=['POST'])
def save_index():
    """Ask the retrieval system to save its index.

    Builds the retrieval system first if necessary. Responds with a success
    message, or HTTP 500 with the error text on failure.
    """
    try:
        init_retrieval_system().save_index()
        body = {"success": True, "message": "索引保存成功"}
        return jsonify(body)
    except Exception as e:
        reason = str(e)
        logger.error(f"保存索引失败: {reason}")
        return jsonify({"success": False, "error": reason}), 500
@app.route('/api/clear_index', methods=['POST'])
def clear_index():
    """Ask the retrieval system to clear its index.

    Builds the retrieval system first if necessary. Responds with a success
    message, or HTTP 500 with the error text on failure.
    """
    try:
        init_retrieval_system().clear_index()
        body = {"success": True, "message": "索引已清空"}
        return jsonify(body)
    except Exception as e:
        reason = str(e)
        logger.error(f"清空索引失败: {reason}")
        return jsonify({"success": False, "error": reason}), 500
@app.route('/api/list_items', methods=['GET'])
def list_items():
    """Return every item the retrieval system reports via list_items().

    Builds the retrieval system first if necessary. Any failure (including
    serialising the items) yields HTTP 500 with the error text.
    """
    try:
        entries = init_retrieval_system().list_items()
        body = {"success": True, "items": entries}
        return jsonify(body)
    except Exception as e:
        reason = str(e)
        logger.error(f"列出索引项失败: {reason}")
        return jsonify({"success": False, "error": reason}), 500
@app.route('/api/system_info', methods=['GET', 'POST'])
def system_info():
    """Report per-GPU memory figures and retrieval statistics.

    Unlike the other API endpoints this does not build the retrieval system:
    ``retrieval_info`` is ``{}`` until something else has created it. Memory
    values are torch.cuda byte counts divided by 1024**3 (GiB). Any failure
    yields HTTP 500 with the error text.
    """
    gib = 1024 ** 3
    try:
        gpu_info = []
        if torch.cuda.is_available():
            gpu_info = [
                {
                    "id": idx,
                    "name": torch.cuda.get_device_name(idx),
                    "memory_total": torch.cuda.get_device_properties(idx).total_memory / gib,
                    "memory_allocated": torch.cuda.memory_allocated(idx) / gib,
                    "memory_reserved": torch.cuda.memory_reserved(idx) / gib,
                }
                for idx in range(torch.cuda.device_count())
            ]

        retrieval_info = (
            retrieval_system.get_stats() if retrieval_system is not None else {}
        )

        return jsonify({
            "success": True,
            "gpu_info": gpu_info,
            "retrieval_info": retrieval_info,
            "model_path": app.config['MODEL_PATH'],
            "index_path": app.config['INDEX_PATH'],
        })
    except Exception as e:
        reason = str(e)
        logger.error(f"获取系统信息失败: {reason}")
        return jsonify({"success": False, "error": reason}), 500
if __name__ == '__main__':
    try:
        # Build the retrieval system before serving, so a bad model path or a
        # model-loading error stops start-up instead of surfacing on the first
        # request.
        init_retrieval_system()
        # NOTE(review): binding 0.0.0.0 exposes every route on all network
        # interfaces; no authentication is visible in this module (including
        # for the destructive /api/clear_index).
        app.run(host='0.0.0.0', port=5000, debug=False)
    except Exception as e:
        # logger.exception also records the traceback (logger.error dropped
        # it), which is essential for diagnosing model-loading failures.
        logger.exception(f"启动Web应用失败: {str(e)}")
        sys.exit(1)