mmeb/web_app_vdb.py
2025-09-22 10:13:11 +00:00

690 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
集成百度VDB的多模态检索Web应用
支持向量存储和多种检索方式
"""
import os
import json
import time
from flask import Flask, render_template, request, jsonify, send_file
from werkzeug.utils import secure_filename
from PIL import Image
import base64
import io
import logging
import traceback
import glob
# Set an environment variable to tune PyTorch's CUDA memory allocator
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)
# NOTE(review): the secret key is hard-coded in source; consider loading it
# from the environment for real deployments.
app.config['SECRET_KEY'] = 'vdb_multimodal_retrieval_2024'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
# Folders used for uploads and local data
UPLOAD_FOLDER = 'uploads'
SAMPLE_IMAGES_FOLDER = 'sample_images'
TEXT_DATA_FOLDER = 'text_data'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
# Make sure the folders exist (relative paths, created at import time)
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)
# Global retrieval-system instance; assigned by initialize_system() / auto_initialize()
retrieval_system = None
def allowed_file(filename):
    """Return True when *filename* has an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot and is compared
    case-insensitively. Names without a dot are rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1]
    return extension.lower() in ALLOWED_EXTENSIONS
def image_to_base64(image_path):
    """Read the file at *image_path* and return its bytes as a base64 string.

    Returns None (after logging the error) when the file cannot be read
    or encoded.
    """
    try:
        with open(image_path, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode('utf-8')
    except Exception as e:
        logger.error(f"图片转换失败: {e}")
        return None
@app.route('/')
def index():
    """Serve the home page by rendering the ``index.html`` template."""
    page = render_template('index.html')
    return page
@app.route('/api/status')
def get_status():
    """Return a JSON snapshot of the system status.

    The payload always contains ``initialized``, ``gpu_count``,
    ``model_loaded`` and ``vdb_connected``. When a retrieval system exists,
    ``device_ids`` and (if a VDB handle is present) ``vdb_stats`` are added.
    Any error while probing is logged and the fields collected so far are
    returned.
    """
    global retrieval_system
    status = {
        'initialized': retrieval_system is not None,
        'gpu_count': 0,
        'model_loaded': False,
        'vdb_connected': False
    }
    try:
        # torch is imported lazily; a failure here is caught and logged below.
        import torch
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            status['gpu_count'] = torch.cuda.device_count()
        if retrieval_system:
            model_ready = retrieval_system.model is not None
            status['model_loaded'] = model_ready
            vdb_ready = retrieval_system.vdb is not None
            status['vdb_connected'] = vdb_ready
            status['device_ids'] = retrieval_system.device_ids
            # Attach VDB statistics when a VDB handle is present
            if retrieval_system.vdb:
                status['vdb_stats'] = retrieval_system.get_statistics()
    except Exception as e:
        logger.error(f"获取状态失败: {e}")
    return jsonify(status)
@app.route('/api/init', methods=['POST'])
def initialize_system():
    """Initialise the VDB multimodal retrieval system on request.

    A new ``MultimodalRetrievalVDB`` is built and validated *before* it is
    published to the module-level ``retrieval_system``. A failed
    initialisation therefore never leaves a half-built instance behind and
    never replaces an instance that was already working.

    Returns:
        JSON with ``success``, ``message``, ``device_ids``, ``gpu_count`` and
        ``vdb_stats`` on success; JSON with ``success=False`` and HTTP 500 on
        failure.
    """
    global retrieval_system
    try:
        logger.info("正在初始化VDB多模态检索系统...")
        # Imported lazily so the web app can start without the retrieval module loaded
        from multimodal_retrieval_vdb import MultimodalRetrievalVDB
        candidate = MultimodalRetrievalVDB()
        if candidate.model is None:
            raise RuntimeError("模型加载失败")
        if candidate.vdb is None:
            raise RuntimeError("VDB连接失败")
        # Publish only after both checks have passed
        retrieval_system = candidate
        logger.info("✅ VDB多模态系统初始化成功")
        stats = candidate.get_statistics()
        return jsonify({
            'success': True,
            'message': 'VDB多模态系统初始化成功',
            'device_ids': candidate.device_ids,
            'gpu_count': len(candidate.device_ids),
            'vdb_stats': stats
        })
    except Exception as e:
        error_msg = f"系统初始化失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500
@app.route('/api/search/text_to_text', methods=['POST'])
def search_text_to_text():
    """Text-to-text search endpoint; delegates to ``handle_search``."""
    mode = 'text_to_text'
    return handle_search(mode)
@app.route('/api/search/text_to_image', methods=['POST'])
def search_text_to_image():
    """Text-to-image search endpoint; delegates to ``handle_search``."""
    mode = 'text_to_image'
    return handle_search(mode)
@app.route('/api/search/image_to_text', methods=['POST'])
def search_image_to_text():
    """Image-to-text search endpoint; delegates to ``handle_search``."""
    mode = 'image_to_text'
    return handle_search(mode)
@app.route('/api/search/image_to_image', methods=['POST'])
def search_image_to_image():
    """Image-to-image search endpoint; delegates to ``handle_search``."""
    mode = 'image_to_image'
    return handle_search(mode)
@app.route('/api/search', methods=['POST'])
def search():
    """Generic search endpoint (kept for backward compatibility).

    The search mode is read from the form field ``mode`` first, then from the
    JSON body, and defaults to ``'text_to_text'``. The JSON body is parsed
    with ``silent=True`` so that a form post without ``mode`` falls back to
    the default instead of failing while accessing a non-existent JSON body.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    mode = request.form.get('mode') or payload.get('mode', 'text_to_text')
    return handle_search(mode)
def _search_json_payload():
    """Return the request's JSON body as a dict, or ``{}`` when absent or invalid."""
    body = request.get_json(silent=True)
    return body if isinstance(body, dict) else {}


def _search_text_hits(raw_results):
    """Convert ``(text, score)`` pairs into JSON-serialisable result dicts."""
    return [{'text': text, 'score': float(score)} for text, score in raw_results]


def _search_image_hits(raw_results):
    """Convert ``(image_path, score)`` pairs into result dicts with inline base64 data.

    An image that cannot be read is still reported, with an empty
    ``image_base64`` field, so the client sees every hit.
    """
    hits = []
    for image_path, score in raw_results:
        try:
            with open(image_path, 'rb') as img_file:
                image_base64 = base64.b64encode(img_file.read()).decode('utf-8')
        except Exception as e:
            logger.error(f"读取图像失败 {image_path}: {e}")
            image_base64 = ''
        hits.append({
            'filename': os.path.basename(image_path),
            'image_path': image_path,
            'image_base64': image_base64,
            'score': float(score)
        })
    return hits


def _search_save_query_image(file):
    """Save an uploaded query image under UPLOAD_FOLDER; return ``(filepath, filename)``.

    The caller must already have validated ``file.filename`` with
    ``allowed_file``.
    """
    import uuid

    safe_name = secure_filename(file.filename)
    if not allowed_file(safe_name):
        # secure_filename() strips non-ASCII characters, which can swallow the
        # dot (e.g. "图片.png" -> "png"); rebuild a name that keeps the extension.
        extension = file.filename.rsplit('.', 1)[1].lower()
        safe_name = f"image.{extension}"
    # The random suffix keeps uploads made within the same second from colliding.
    filename = f"query_{int(time.time())}_{uuid.uuid4().hex[:8]}_{safe_name}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    file.save(filepath)
    return filepath, filename


def handle_search(mode):
    """Run a search request for *mode* and return a Flask JSON response.

    Supported modes are ``text_to_text``, ``text_to_image``,
    ``image_to_text`` and ``image_to_image``. Text modes read ``query`` from
    the form or the JSON body; image modes expect an uploaded file in the
    ``image`` form field. ``top_k`` (default 5) may come from the form or the
    JSON body.

    Returns:
        JSON with ``success=True`` plus the results, or ``success=False`` with
        HTTP 400 for invalid input / uninitialised system and HTTP 500 for
        unexpected failures.
    """
    global retrieval_system
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化,请先点击初始化按钮'
        }), 400
    try:
        payload = _search_json_payload()
        raw_top_k = request.form.get('top_k', payload.get('top_k', 5))
        try:
            top_k = int(raw_top_k)
        except (TypeError, ValueError):
            top_k = 0
        if top_k < 1:
            # A malformed top_k is a client error, not a server failure.
            return jsonify({
                'success': False,
                'message': 'top_k 必须是正整数'
            }), 400

        if mode in ('text_to_text', 'text_to_image'):
            query = request.form.get('query') or payload.get('query', '')
            if not isinstance(query, str) or not query.strip():
                return jsonify({
                    'success': False,
                    'message': '请输入查询文本'
                }), 400
            logger.info(f"执行{mode}搜索: {query}")
            if mode == 'text_to_text':
                results = _search_text_hits(
                    retrieval_system.search_text_to_text(query, top_k=top_k))
            else:
                results = _search_image_hits(
                    retrieval_system.search_text_to_image(query, top_k=top_k))
            return jsonify({
                'success': True,
                'mode': mode,
                'query': query,
                'results': results,
                'result_count': len(results)
            })

        if mode in ('image_to_text', 'image_to_image'):
            file = request.files.get('image')
            if file is None:
                return jsonify({
                    'success': False,
                    'message': '请上传查询图片'
                }), 400
            if file.filename == '' or not allowed_file(file.filename):
                return jsonify({
                    'success': False,
                    'message': '请上传有效的图片文件'
                }), 400
            filepath, filename = _search_save_query_image(file)
            logger.info(f"执行{mode}搜索,图片: {filename}")
            if mode == 'image_to_text':
                results = _search_text_hits(
                    retrieval_system.search_image_to_text(filepath, top_k=top_k))
            else:
                results = _search_image_hits(
                    retrieval_system.search_image_to_image(filepath, top_k=top_k))
            # Echo the query image back so the client can display it
            return jsonify({
                'success': True,
                'mode': mode,
                'query_image': image_to_base64(filepath),
                'results': results,
                'result_count': len(results)
            })

        return jsonify({
            'success': False,
            'message': f'不支持的搜索模式: {mode}'
        }), 400
    except Exception as e:
        error_msg = f"搜索失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500
@app.route('/api/upload/images', methods=['POST'])
def upload_images():
    """Save a batch of uploaded images into SAMPLE_IMAGES_FOLDER.

    Files come from the multipart field ``images``. Entries with an empty
    name or a disallowed extension are skipped. Each stored name is prefixed
    with a timestamp and a short random token so that files with the same
    name (in one batch or within the same second) do not overwrite each
    other.

    Returns:
        JSON with ``success``, ``message``, ``uploaded_count`` and ``files``;
        HTTP 400 when no ``images`` field was sent, HTTP 500 on unexpected
        errors.
    """
    import uuid

    try:
        if 'images' not in request.files:
            return jsonify({'success': False, 'message': '没有选择文件'}), 400
        uploaded_files = []
        for file in request.files.getlist('images'):
            if not file or file.filename == '' or not allowed_file(file.filename):
                continue
            safe_name = secure_filename(file.filename)
            if not allowed_file(safe_name):
                # secure_filename() drops non-ASCII characters and can lose the
                # extension (e.g. "图片.png" -> "png"); keep the original one so
                # the file is still found by the extension-based folder scans.
                extension = file.filename.rsplit('.', 1)[1].lower()
                safe_name = f"image.{extension}"
            filename = f"{int(time.time())}_{uuid.uuid4().hex[:8]}_{safe_name}"
            file.save(os.path.join(SAMPLE_IMAGES_FOLDER, filename))
            uploaded_files.append(filename)
        return jsonify({
            'success': True,
            'message': f'成功上传 {len(uploaded_files)} 个图片文件',
            'uploaded_count': len(uploaded_files),
            'files': uploaded_files
        })
    except Exception as e:
        logger.error(f"上传图片失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500
@app.route('/api/upload/texts', methods=['POST'])
def upload_texts():
    """Persist a batch of texts sent as JSON ``{"texts": [...]}``.

    The list is written as a JSON file into TEXT_DATA_FOLDER. The body is
    parsed with ``silent=True`` so a missing or malformed JSON body yields
    HTTP 400 rather than being reported as a server error, and the file name
    carries a random token so two uploads within the same second do not
    overwrite each other.

    Returns:
        JSON with ``success``, ``message`` and ``uploaded_count``; HTTP 400
        for missing/invalid input, HTTP 500 on unexpected errors.
    """
    import uuid

    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'texts' not in data:
            return jsonify({'success': False, 'message': '没有提供文本数据'}), 400
        texts = data['texts']
        if not isinstance(texts, list):
            return jsonify({'success': False, 'message': '文本数据格式错误'}), 400
        # Save the texts to a uniquely named file
        filename = f"texts_{int(time.time())}_{uuid.uuid4().hex[:8]}.json"
        filepath = os.path.join(TEXT_DATA_FOLDER, filename)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(texts, f, ensure_ascii=False, indent=2)
        return jsonify({
            'success': True,
            'message': f'成功上传 {len(texts)} 条文本',
            'uploaded_count': len(texts)
        })
    except Exception as e:
        logger.error(f"上传文本失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500
@app.route('/api/store_data', methods=['POST'])
def store_data():
    """Push the locally uploaded images and texts into the VDB.

    Images are every file in SAMPLE_IMAGES_FOLDER accepted by
    ``allowed_file`` (case-insensitive extension match, so ``IMG.JPG`` is
    included just like at upload time). Texts are read from ``*.json`` and
    ``*.txt`` files in TEXT_DATA_FOLDER; blank entries are dropped.

    Returns:
        JSON with ``success``, ``message``, ``stored_images``,
        ``stored_texts`` and ``vdb_stats``; HTTP 400 when the system is not
        initialised or there is nothing to store, HTTP 500 on failure.
    """
    global retrieval_system
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化'
        }), 400
    try:
        # Filter with allowed_file() rather than per-extension glob patterns,
        # which are case-sensitive on most platforms.
        image_files = sorted(
            path
            for path in glob.glob(os.path.join(SAMPLE_IMAGES_FOLDER, '*'))
            if allowed_file(os.path.basename(path))
        )
        text_data = []
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
        for text_file in text_files:
            try:
                with open(text_file, 'r', encoding='utf-8') as f:
                    if text_file.endswith('.json'):
                        data = json.load(f)
                        if isinstance(data, list):
                            cleaned = (str(item).strip() for item in data)
                            text_data.extend(text for text in cleaned if text)
                        else:
                            text_data.append(str(data).strip())
                    else:
                        stripped = (line.strip() for line in f)
                        text_data.extend(line for line in stripped if line)
            except Exception as e:
                # Best effort: one unreadable file must not block the rest
                logger.warning(f"读取文本文件失败 {text_file}: {e}")
        # Nothing to do without any data
        if not image_files and not text_data:
            return jsonify({
                'success': False,
                'message': '没有找到可用的图片或文本数据,请先上传数据'
            }), 400
        stored_images = 0
        stored_texts = 0
        if image_files:
            logger.info(f"存储图片到VDB{len(image_files)} 张图片")
            image_ids = retrieval_system.store_images(image_files)
            stored_images = len(image_ids)
        if text_data:
            logger.info(f"存储文本到VDB{len(text_data)} 条文本")
            text_ids = retrieval_system.store_texts(text_data)
            stored_texts = len(text_ids)
        # Fetch the statistics after the update
        stats = retrieval_system.get_statistics()
        return jsonify({
            'success': True,
            'message': f'数据存储完成!图片: {stored_images} 张,文本: {stored_texts}',
            'stored_images': stored_images,
            'stored_texts': stored_texts,
            'vdb_stats': stats
        })
    except Exception as e:
        logger.error(f"存储数据失败: {str(e)}")
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': f'存储数据失败: {str(e)}'
        }), 500
@app.route('/api/data/stats', methods=['GET'])
def get_data_stats():
    """Return counts of local image/text data together with VDB statistics.

    Images are counted with ``allowed_file`` (case-insensitive extension
    match), so files such as ``IMG.JPG`` that were accepted at upload time
    are counted as well. For text data, a JSON list contributes its length,
    any other JSON document counts as one entry, and a ``.txt`` file
    contributes its number of non-blank lines. Unreadable text files are
    logged and skipped.

    Returns:
        JSON with ``success``, ``local_files`` and ``vdb_stats``; HTTP 500 on
        unexpected errors.
    """
    global retrieval_system
    try:
        # Count local image files
        image_count = sum(
            1
            for path in glob.glob(os.path.join(SAMPLE_IMAGES_FOLDER, '*'))
            if allowed_file(os.path.basename(path))
        )
        text_count = 0
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
        for text_file in text_files:
            try:
                with open(text_file, 'r', encoding='utf-8') as f:
                    if text_file.endswith('.json'):
                        data = json.load(f)
                        text_count += len(data) if isinstance(data, list) else 1
                    else:
                        text_count += sum(1 for line in f if line.strip())
            except Exception as e:
                # Keep counting, but leave a trace instead of failing silently
                logger.warning(f"读取文本文件失败 {text_file}: {e}")
        # VDB statistics are only available once the system exists
        vdb_stats = {}
        if retrieval_system:
            vdb_stats = retrieval_system.get_statistics()
        return jsonify({
            'success': True,
            'local_files': {
                'image_count': image_count,
                'text_count': text_count
            },
            'vdb_stats': vdb_stats
        })
    except Exception as e:
        logger.error(f"获取数据统计失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'获取统计失败: {str(e)}'
        }), 500
@app.route('/api/data/list', methods=['GET'])
def list_data():
    """List the locally stored image and text files.

    Image entries carry ``filename``, ``filepath``, ``image_base64`` and
    ``size``; text entries carry ``filename``, ``filepath`` and ``size``.
    Images are selected with ``allowed_file`` (case-insensitive extension
    match), so files such as ``IMG.JPG`` are listed too. Files that cannot
    be processed are logged and skipped.

    Returns:
        JSON with ``success``, ``image_files``, ``text_files``,
        ``image_count`` and ``text_count``; HTTP 500 on unexpected errors.
    """
    try:
        image_paths = sorted(
            path
            for path in glob.glob(os.path.join(SAMPLE_IMAGES_FOLDER, '*'))
            if allowed_file(os.path.basename(path))
        )
        image_files = []
        for file_path in image_paths:
            try:
                image_files.append({
                    'filename': os.path.basename(file_path),
                    'filepath': file_path,
                    # Inline base64 so the client can render without another request
                    'image_base64': image_to_base64(file_path),
                    'size': os.path.getsize(file_path)
                })
            except Exception as e:
                logger.warning(f"处理图片文件失败 {file_path}: {e}")
        text_file_paths = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        text_file_paths.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
        text_files = []
        for text_file in text_file_paths:
            try:
                text_files.append({
                    'filename': os.path.basename(text_file),
                    'filepath': text_file,
                    'size': os.path.getsize(text_file)
                })
            except Exception as e:
                logger.warning(f"处理文本文件失败 {text_file}: {e}")
        return jsonify({
            'success': True,
            'image_files': image_files,
            'text_files': text_files,
            'image_count': len(image_files),
            'text_count': len(text_files)
        })
    except Exception as e:
        logger.error(f"获取数据列表失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'获取数据列表失败: {str(e)}'
        }), 500
@app.route('/api/data/clear', methods=['POST'])
def clear_data():
    """Delete all local image/text data and clear the VDB.

    Images are selected with ``allowed_file`` (case-insensitive extension
    match), so files such as ``IMG.JPG`` are removed as well. Individual
    deletion failures are logged and do not abort the operation. When a
    retrieval system exists, its ``clear_all_data()`` is called afterwards.

    Returns:
        JSON with ``success`` and ``message``; HTTP 500 on unexpected errors.
    """
    global retrieval_system
    try:
        # Remove local image files
        for file_path in glob.glob(os.path.join(SAMPLE_IMAGES_FOLDER, '*')):
            if not allowed_file(os.path.basename(file_path)):
                continue
            try:
                os.remove(file_path)
            except Exception as e:
                logger.warning(f"删除图片文件失败 {file_path}: {e}")
        # Remove local text files
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
        for text_file in text_files:
            try:
                os.remove(text_file)
            except Exception as e:
                logger.warning(f"删除文本文件失败 {text_file}: {e}")
        # Clear the vectors held in the VDB
        if retrieval_system:
            retrieval_system.clear_all_data()
        return jsonify({
            'success': True,
            'message': '所有数据已清空包括VDB中的向量数据'
        })
    except Exception as e:
        logger.error(f"清空数据失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'清空数据失败: {str(e)}'
        }), 500
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve a stored file by name from SAMPLE_IMAGES_FOLDER.

    ``send_from_directory`` joins the name safely (a name that would escape
    the folder is rejected) and answers 404 for a missing file instead of
    raising from a plain ``os.path.join`` + ``send_file``.

    NOTE(review): the route is called ``/uploads/`` but serves
    SAMPLE_IMAGES_FOLDER, not UPLOAD_FOLDER; kept as in the original,
    confirm that this is intended.
    """
    from flask import send_from_directory

    return send_from_directory(SAMPLE_IMAGES_FOLDER, filename)
def print_startup_info():
    """Print the startup banner: URL, feature list, GPU summary and VDB settings.

    The GPU section imports torch lazily; any failure while probing is
    reported on stdout instead of being raised.
    """
    separator = "=" * 60
    header_lines = [
        "🚀 启动VDB多模态检索Web应用",
        separator,
        "访问地址: http://localhost:5000",
        "新功能:",
        " 🗄️ 百度VDB - 向量数据库存储",
        " 📊 实时统计 - VDB数据统计信息",
        " 🔄 数据同步 - 本地文件到VDB存储",
        "支持功能:",
        " 📝 文搜文 - 文本查找相似文本",
        " 🖼️ 文搜图 - 文本查找相关图片",
        " 📝 图搜文 - 图片查找相关文本",
        " 🖼️ 图搜图 - 图片查找相似图片",
        " 📤 批量上传 - 图片和文本数据管理",
        "GPU配置:",
    ]
    print("\n".join(header_lines))
    try:
        import torch
        if not torch.cuda.is_available():
            print(" ❌ CUDA不可用")
        else:
            gpu_count = torch.cuda.device_count()
            print(f" 🖥️ 检测到 {gpu_count} 个GPU")
            for i in range(gpu_count):
                name = torch.cuda.get_device_name(i)
                total_bytes = torch.cuda.get_device_properties(i).total_memory
                memory_gb = total_bytes / 1024**3
                print(f" GPU {i}: {name} ({memory_gb:.1f}GB)")
    except Exception as e:
        print(f" ❌ GPU检查失败: {e}")
    footer_lines = [
        "VDB配置:",
        " 🌐 服务器: http://180.76.96.191:5287",
        " 👤 用户: root",
        " 🗄️ 数据库: multimodal_retrieval",
        separator,
    ]
    print("\n".join(footer_lines))
def auto_initialize():
    """Try to initialise the retrieval system at startup.

    The new instance is validated before it is assigned to the module-level
    ``retrieval_system``, so a failed start-up leaves the global untouched
    instead of exposing a half-initialised system to the request handlers.

    Returns:
        True when the system was initialised, False otherwise (the error and
        traceback are logged).
    """
    global retrieval_system
    try:
        logger.info("🚀 启动时自动初始化VDB多模态检索系统...")
        # Imported lazily so an import failure is handled like any other init error
        from multimodal_retrieval_vdb import MultimodalRetrievalVDB
        candidate = MultimodalRetrievalVDB()
        if candidate.model is None:
            raise RuntimeError("模型加载失败")
        if candidate.vdb is None:
            raise RuntimeError("VDB连接失败")
        # Publish only after both checks have passed
        retrieval_system = candidate
        logger.info("✅ VDB系统自动初始化成功")
        return True
    except Exception as e:
        logger.error(f"❌ VDB系统自动初始化失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False
if __name__ == '__main__':
    # Show the startup banner, then attempt to bring the retrieval system up
    print_startup_info()
    auto_initialize()
    # Start the Flask application
    app.run(host='0.0.0.0', port=5000, debug=False, threaded=True)