467 lines
16 KiB
Python
467 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
本地多模态检索系统Web应用
|
|
集成本地模型和FAISS向量数据库
|
|
支持文搜文、文搜图、图搜文、图搜图四种检索模式
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import logging
|
|
import time
|
|
import json
|
|
import base64
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
import numpy as np
|
|
from PIL import Image
|
|
from flask import Flask, request, jsonify, render_template, send_from_directory
|
|
from werkzeug.utils import secure_filename
|
|
import torch
|
|
|
|
# 设置离线模式
|
|
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
|
|
|
# 导入本地模块
|
|
from multimodal_retrieval_local import MultimodalRetrievalLocal
|
|
from optimized_file_handler import OptimizedFileHandler
|
|
|
|
# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Flask application
app = Flask(__name__)

# Application configuration
app.config['UPLOAD_FOLDER'] = '/tmp/mmeb_uploads'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB request body limit
app.config['MODEL_PATH'] = '/root/models/Ops-MM-embedding-v1-7B'
app.config['INDEX_PATH'] = '/root/mmeb/local_faiss_index'
app.config['ALLOWED_EXTENSIONS'] = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'}

# Make sure the upload directory exists. exist_ok=True already makes this
# call idempotent, so no separate os.path.exists() guard is needed.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# File handler rooted at the upload directory (OptimizedFileHandler is
# imported with the other local modules at the top of the file).
file_handler = OptimizedFileHandler(local_storage_dir=app.config['UPLOAD_FOLDER'])

# Lazily created retrieval system shared by all request handlers;
# populated by init_retrieval_system().
retrieval_system = None
|
|
|
|
def allowed_file(filename):
    """Tell whether *filename* carries an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared in lower case.
    A name without any dot is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in app.config['ALLOWED_EXTENSIONS']
|
|
|
|
def init_retrieval_system():
    """Return the shared retrieval system, constructing it on first use.

    The instance is cached in the module-level ``retrieval_system`` global,
    so subsequent calls hand back the same object.

    Raises:
        FileNotFoundError: if the configured MODEL_PATH does not exist.
    """
    global retrieval_system

    if retrieval_system is None:
        logger.info("初始化多模态检索系统...")

        # Refuse to start without the model directory.
        model_dir = app.config['MODEL_PATH']
        if not os.path.exists(model_dir):
            problem = f"模型路径不存在: {model_dir}"
            logger.error(problem)
            raise FileNotFoundError(problem)

        retrieval_system = MultimodalRetrievalLocal(
            model_path=model_dir,
            use_all_gpus=True,
            index_path=app.config['INDEX_PATH'],
        )
        logger.info("多模态检索系统初始化完成")

    return retrieval_system
|
|
|
|
def get_image_base64(image_path):
    """Read an image file and return its content as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension so that PNG and GIF
    files are labelled correctly. Unknown or non-image extensions fall back
    to ``image/jpeg``, the value this function previously emitted for every
    file.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        A string of the form ``data:<mime-type>;base64,<payload>``.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    import mimetypes

    mime_type, _ = mimetypes.guess_type(str(image_path))
    if not mime_type or not mime_type.startswith('image/'):
        mime_type = 'image/jpeg'

    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    return f"data:{mime_type};base64,{encoded_string}"
|
|
|
|
@app.route('/')
def index():
    """Serve the landing page rendered from the ``local_index.html`` template."""
    page = render_template('local_index.html')
    return page
|
|
|
|
@app.route('/api/stats', methods=['GET'])
def get_stats():
    """Report the retrieval system's statistics as JSON.

    Responds with ``{"success": True, "stats": ...}``, or with HTTP 500 and
    the error text when anything fails.
    """
    try:
        system = init_retrieval_system()
        return jsonify({"success": True, "stats": system.get_stats()})
    except Exception as exc:
        logger.error(f"获取统计信息失败: {exc}")
        return jsonify({"success": False, "error": str(exc)}), 500
|
|
|
|
@app.route('/api/add_text', methods=['POST'])
def add_text():
    """Add a single text entry to the retrieval index.

    Expects a JSON body of the form ``{"text": "..."}``. The text is passed
    to ``retrieval.add_texts`` together with a small metadata record and the
    index is saved afterwards.

    Returns:
        A JSON response with ``success``, ``message`` and ``text_id`` on
        success; HTTP 400 when the body is not a JSON object or ``text`` is
        missing, empty or not a string; HTTP 500 on any other failure.
    """
    try:
        # silent=True yields None instead of raising when the body is missing,
        # has the wrong content type or is not valid JSON, so malformed
        # requests get a 400 below rather than a 500.
        data = request.get_json(silent=True)
        text = data.get('text') if isinstance(data, dict) else None

        if not isinstance(text, str) or not text:
            return jsonify({"success": False, "error": "文本不能为空"}), 400

        # The handler receives the encoded text and yields a handle that is
        # only logged here (presumably a temp file path — confirm against
        # OptimizedFileHandler).
        with file_handler.temp_file_context(text.encode('utf-8'), suffix='.txt') as temp_file:
            logger.info(f"处理文本: {temp_file}")

            retrieval = init_retrieval_system()

            metadata = {
                "timestamp": time.time(),
                "source": "web_upload",
            }

            text_ids = retrieval.add_texts([text], [metadata])

            # Persist the index right away so the new entry survives a restart.
            retrieval.save_index()

            return jsonify({
                "success": True,
                "message": "文本添加成功",
                "text_id": text_ids[0] if text_ids else None,
            })

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"添加文本失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
    finally:
        # Remove temp resources created by the file handler.
        file_handler.cleanup_all_temp_files()
|
|
|
|
@app.route('/api/add_image', methods=['POST'])
def add_image():
    """Add an uploaded image to the retrieval index.

    Expects a multipart form upload with the file in the ``image`` field.
    The image is decoded first, then stored in UPLOAD_FOLDER, handed to
    ``retrieval.add_images`` and the index is saved.

    Returns:
        A JSON response with ``success``, ``message`` and ``image_id`` on
        success; HTTP 400 when no file is supplied, the extension is not
        allowed or the data cannot be decoded as an image; HTTP 500 on any
        other failure.
    """
    import uuid

    try:
        if 'image' not in request.files:
            return jsonify({"success": False, "error": "没有上传文件"}), 400

        file = request.files['image']

        if file.filename == '':
            return jsonify({"success": False, "error": "没有选择文件"}), 400

        if not (file and allowed_file(file.filename)):
            return jsonify({"success": False, "error": "不支持的文件类型"}), 400

        image_data = file.read()
        file_size = len(image_data)
        logger.info(f"处理图像: {file.filename} ({file_size} 字节)")

        filename = secure_filename(file.filename)

        # Decode before touching the disk so that an undecodable upload does
        # not leave a stray file behind in the upload folder.
        try:
            image = Image.open(BytesIO(image_data))
            # convert() returns a new image whose .format is None, so remember
            # the detected format for the log line below.
            image_format = image.format
            if image.mode != 'RGB':
                logger.info(f"将图像从 {image.mode} 转换为 RGB")
                image = image.convert('RGB')
            else:
                # Image.open is lazy; force decoding so corrupt data is
                # reported here as a client error.
                image.load()

            logger.info(f"成功加载图像: {filename}, 格式: {image_format}, 模式: {image.mode}, 大小: {image.size}")
        except Exception as e:
            logger.error(f"加载图像失败: {filename}, 错误: {str(e)}")
            return jsonify({"success": False, "error": f"图像格式不支持: {str(e)}"}), 400

        retrieval = init_retrieval_system()

        # allowed_file() guaranteed the original name contains a dot.
        extension = file.filename.rsplit('.', 1)[1].lower()
        if not filename.lower().endswith('.' + extension):
            # secure_filename drops non-ASCII characters, which can reduce a
            # name to just its extension or to nothing; generate one instead.
            filename = f"{uuid.uuid4().hex}.{extension}"

        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if os.path.exists(image_path):
            # Never overwrite a file that an earlier index entry may point to.
            stem, suffix = os.path.splitext(filename)
            filename = f"{stem}_{uuid.uuid4().hex[:8]}{suffix}"
            image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)

        with open(image_path, 'wb') as f:
            f.write(image_data)

        metadata = {
            "filename": filename,
            "timestamp": time.time(),
            "source": "web_upload",
            "size": file_size,
            "local_path": image_path,
        }

        image_ids = retrieval.add_images([image], [metadata], [image_path])

        # Persist the index right away so the new entry survives a restart.
        retrieval.save_index()

        return jsonify({
            "success": True,
            "message": "图像添加成功",
            "image_id": image_ids[0] if image_ids else None,
        })

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"添加图像失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
    finally:
        # Remove temp resources created by the file handler.
        file_handler.cleanup_all_temp_files()
|
|
|
|
@app.route('/api/search_by_text', methods=['POST'])
def search_by_text():
    """Search the index with a text query.

    Expects a JSON body with ``query`` (required), ``k`` (optional result
    count, default 5) and ``filter_type`` (optional: "text", "image" or
    null).

    Returns:
        A JSON response with ``success``, ``results``, ``query`` and
        ``filter_type``; HTTP 400 when ``query`` is missing or ``k`` is not
        an integer; HTTP 500 on any other failure.
    """
    try:
        # silent=True yields None instead of raising on a missing or invalid
        # JSON body; treat anything that is not an object as empty.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}

        query = data.get('query')
        filter_type = data.get('filter_type')  # "text", "image" or null

        if not query:
            return jsonify({"success": False, "error": "查询文本不能为空"}), 400

        # A malformed k is a client error, not a server failure.
        try:
            k = int(data.get('k', 5))
        except (TypeError, ValueError):
            return jsonify({"success": False, "error": "k 必须是整数"}), 400

        retrieval = init_retrieval_system()
        results = retrieval.search_by_text(query, k, filter_type)

        # Reduce each hit to the fields the front end consumes.
        processed_results = []
        for result in results:
            result_type = result.get("type")
            item = {
                "score": result.get("score", 0),
                "type": result_type,
            }

            if result_type == "text":
                item["text"] = result.get("text", "")
            elif result_type == "image":
                path = result.get("path")
                # Inline the image only when the stored file is still present.
                if path and os.path.exists(path):
                    item["image"] = get_image_base64(path)
                    item["filename"] = os.path.basename(path)
                if "description" in result:
                    item["description"] = result["description"]

            processed_results.append(item)

        return jsonify({
            "success": True,
            "results": processed_results,
            "query": query,
            "filter_type": filter_type,
        })

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"文本搜索失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
|
|
|
|
def _load_query_image(source):
    """Open *source* with PIL as an RGB image; return None if it cannot be decoded."""
    try:
        image = Image.open(source)
        if image.mode != 'RGB':
            # Same normalisation add_image applies before indexing.
            image = image.convert('RGB')
        else:
            # Image.open is lazy; force decoding so corrupt data fails here.
            image.load()
        return image
    except Exception as e:
        logger.error(f"加载图像失败: {str(e)}")
        return None


@app.route('/api/search_by_image', methods=['POST'])
def search_by_image():
    """Search the index with an uploaded query image.

    Expects a multipart form upload with the file in the ``image`` field and
    optional form fields ``k`` (result count, default 5) and ``filter_type``
    ("text", "image" or absent). Uploads of up to 5MB are decoded from
    memory; larger ones go through ``file_handler.temp_file_context``.

    Returns:
        A JSON response with ``success``, ``results`` and ``filter_type``;
        HTTP 400 when no file is supplied, the extension is not allowed,
        ``k`` is not an integer or the data cannot be decoded as an image;
        HTTP 500 on any other failure.
    """
    try:
        if 'image' not in request.files:
            return jsonify({"success": False, "error": "没有上传文件"}), 400

        file = request.files['image']
        filter_type = request.form.get('filter_type')  # "text", "image" or null

        # A malformed k is a client error, not a server failure.
        try:
            k = int(request.form.get('k', 5))
        except (TypeError, ValueError):
            return jsonify({"success": False, "error": "k 必须是整数"}), 400

        if file.filename == '':
            return jsonify({"success": False, "error": "没有选择文件"}), 400

        if not (file and allowed_file(file.filename)):
            return jsonify({"success": False, "error": "不支持的文件类型"}), 400

        image_data = file.read()
        file_size = len(image_data)

        if file_size <= 5 * 1024 * 1024:  # 5MB
            # Small uploads are decoded straight from memory.
            logger.info(f"使用内存处理搜索图像: {file.filename} ({file_size} 字节)")
            image = _load_query_image(BytesIO(image_data))
            if image is None:
                return jsonify({"success": False, "error": "图像格式不支持"}), 400

            retrieval = init_retrieval_system()
            results = retrieval.search_by_image(image, k, filter_type)
        else:
            # Large uploads go through the file handler's temp resource.
            suffix = os.path.splitext(file.filename)[1]
            with file_handler.temp_file_context(image_data, suffix=suffix) as temp_file:
                logger.info(f"使用临时文件处理搜索图像: {temp_file} ({file_size} 字节)")
                image = _load_query_image(temp_file)
                if image is None:
                    return jsonify({"success": False, "error": "图像格式不支持"}), 400

                retrieval = init_retrieval_system()
                results = retrieval.search_by_image(image, k, filter_type)

        # Reduce each hit to the fields the front end consumes.
        processed_results = []
        for result in results:
            result_type = result.get("type")
            item = {
                "score": result.get("score", 0),
                "type": result_type,
            }

            if result_type == "text":
                item["text"] = result.get("text", "")
            elif result_type == "image":
                path = result.get("path")
                # Inline the image only when the stored file is still present.
                if path and os.path.exists(path):
                    item["image"] = get_image_base64(path)
                    item["filename"] = os.path.basename(path)
                if "description" in result:
                    item["description"] = result["description"]

            processed_results.append(item)

        return jsonify({
            "success": True,
            "results": processed_results,
            "filter_type": filter_type,
        })

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"图像搜索失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
    finally:
        # Remove temp resources created by the file handler.
        file_handler.cleanup_all_temp_files()
|
|
|
|
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve *filename* from the configured upload folder."""
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename)
|
|
|
|
@app.route('/temp/<filename>')
def temp_file(filename):
    """Serve *filename* for the ``/temp/`` route; files come from the upload folder."""
    storage_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(storage_dir, filename)
|
|
|
|
@app.route('/api/save_index', methods=['POST'])
def save_index():
    """Ask the retrieval system to save its index.

    Responds with a success message, or with HTTP 500 and the error text
    when anything fails.
    """
    try:
        init_retrieval_system().save_index()
        return jsonify({"success": True, "message": "索引保存成功"})
    except Exception as exc:
        logger.error(f"保存索引失败: {exc}")
        return jsonify({"success": False, "error": str(exc)}), 500
|
|
|
|
@app.route('/api/clear_index', methods=['POST'])
def clear_index():
    """Ask the retrieval system to clear its index.

    Responds with a success message, or with HTTP 500 and the error text
    when anything fails.
    """
    try:
        init_retrieval_system().clear_index()
        return jsonify({"success": True, "message": "索引已清空"})
    except Exception as exc:
        logger.error(f"清空索引失败: {exc}")
        return jsonify({"success": False, "error": str(exc)}), 500
|
|
|
|
@app.route('/api/list_items', methods=['GET'])
def list_items():
    """Return every item the retrieval system lists, as JSON.

    Responds with ``{"success": True, "items": ...}``, or with HTTP 500 and
    the error text when anything fails.
    """
    try:
        system = init_retrieval_system()
        return jsonify({"success": True, "items": system.list_items()})
    except Exception as exc:
        logger.error(f"列出索引项失败: {exc}")
        return jsonify({"success": False, "error": str(exc)}), 500
|
|
|
|
@app.route('/api/system_info', methods=['GET', 'POST'])
def system_info():
    """Report GPU details, retrieval statistics and the configured paths.

    GPU memory figures are divided by 1024 ** 3. Retrieval statistics are
    included only when the retrieval system has already been created; this
    endpoint does not create it.
    """
    bytes_per_gib = 1024 ** 3
    try:
        # One record per visible CUDA device; empty when CUDA is unavailable.
        gpu_info = []
        if torch.cuda.is_available():
            gpu_info = [
                {
                    "id": device,
                    "name": torch.cuda.get_device_name(device),
                    "memory_total": torch.cuda.get_device_properties(device).total_memory / bytes_per_gib,
                    "memory_allocated": torch.cuda.memory_allocated(device) / bytes_per_gib,
                    "memory_reserved": torch.cuda.memory_reserved(device) / bytes_per_gib,
                }
                for device in range(torch.cuda.device_count())
            ]

        retrieval_info = retrieval_system.get_stats() if retrieval_system is not None else {}

        return jsonify({
            "success": True,
            "gpu_info": gpu_info,
            "retrieval_info": retrieval_info,
            "model_path": app.config['MODEL_PATH'],
            "index_path": app.config['INDEX_PATH'],
        })
    except Exception as exc:
        logger.error(f"获取系统信息失败: {exc}")
        return jsonify({"success": False, "error": str(exc)}), 500
|
|
|
|
if __name__ == '__main__':
    try:
        # Build the retrieval system before the server starts accepting requests.
        init_retrieval_system()

        # Listen on all interfaces, port 5000, with Flask debug mode off.
        app.run(debug=False, port=5000, host='0.0.0.0')
    except Exception as exc:
        logger.error(f"启动Web应用失败: {exc}")
        sys.exit(1)
|