#!/usr/bin/env python3
# -*- coding: utf-8 -*-
|
||
"""
MongoDB metadata manager.

Stores and manages metadata for multimodal files.
"""
|
||
|
||
import hashlib
import logging
import os
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any

from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, OperationFailure
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class MongoDBManager:
    """MongoDB-backed metadata manager.

    File metadata is kept in the ``files`` collection and vector metadata in
    the ``vectors`` collection of a single database. The connection is opened
    (and indexes are requested) as soon as an instance is constructed.
    """

    # Environment variable consulted when no URI is passed to the constructor.
    _URI_ENV_VAR = "MONGODB_URI"

    # SECURITY(review): this legacy fallback embeds a username and password
    # directly in source. It is kept only so that existing zero-argument
    # callers keep working; rotate the credential and supply the URI through
    # the ``uri`` argument or the MONGODB_URI environment variable instead.
    _DEFAULT_URI = "mongodb://root:aWQtrUH!b3@XVjfbNkkp.mongodb.bj.baidubce.com/mmeb?authSource=admin"

    # (collection name, indexed field, extra create_index options)
    # NOTE(review): the unique index on "file_hash" may reject a second
    # document that has no "file_hash" field at all (documents stored through
    # the ``metadata=`` path are not required to carry one) -- confirm against
    # the MongoDB unique-index documentation and consider a sparse/partial
    # index if that path is used in practice.
    _INDEX_SPECS = (
        ("files", "file_hash", {"unique": True}),
        ("files", "file_type", {}),
        ("files", "upload_time", {}),
        ("files", "bos_key", {}),
        ("vectors", "file_id", {}),
        ("vectors", "vector_type", {}),
        ("vectors", "vdb_id", {}),
    )

    # Read size used while hashing files.
    _HASH_CHUNK_SIZE = 64 * 1024

    def __init__(self, uri: Optional[str] = None, database: str = "mmeb"):
        """Open the MongoDB connection.

        Args:
            uri: MongoDB connection URI. When omitted, the MONGODB_URI
                environment variable is used if set, otherwise the built-in
                legacy default.
            database: Name of the database to use.

        Raises:
            Exception: whatever the driver raises when the client cannot be
                created or the server does not answer the initial ping.
        """
        self.uri = self._resolve_uri(uri)
        self.database_name = database
        self.client = None
        self.db = None

        self._connect()

    @classmethod
    def _resolve_uri(cls, uri: Optional[str]) -> str:
        """Pick the connection URI: explicit argument, then env var, then default."""
        return uri or os.environ.get(cls._URI_ENV_VAR) or cls._DEFAULT_URI

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        # datetime.utcnow() is deprecated; an aware UTC value denotes the same
        # instant.
        return datetime.now(timezone.utc)

    @staticmethod
    def _stringify_ids(documents) -> List[Dict]:
        """Collect documents into a list, converting each ``_id`` to ``str`` in place."""
        collected = []
        for doc in documents:
            doc['_id'] = str(doc['_id'])
            collected.append(doc)
        return collected

    def _discard_client(self):
        """Close and forget the current client (used when connecting fails)."""
        client, self.client, self.db = self.client, None, None
        if client is None:
            return
        try:
            client.close()
        except Exception as e:
            # Best effort only: the connection error being handled by the
            # caller is the one that must propagate.
            logger.debug(f"关闭MongoDB客户端时出错: {e}")

    def _connect(self):
        """Create the client, verify the server with a ping, and request indexes.

        Raises:
            ConnectionFailure: when the driver reports a connection failure.
            Exception: any other error raised while initialising.
        """
        try:
            self.client = MongoClient(self.uri, serverSelectionTimeoutMS=5000)
            # Ping so that a bad URI or unreachable server fails here rather
            # than on the first real query.
            self.client.admin.command('ping')
            self.db = self.client[self.database_name]
            logger.info(f"✅ MongoDB连接成功: {self.database_name}")

            self._create_indexes()

        except ConnectionFailure as e:
            logger.error(f"❌ MongoDB连接失败: {e}")
            # Do not leave a half-initialised client behind.
            self._discard_client()
            raise
        except Exception as e:
            logger.error(f"❌ MongoDB初始化失败: {e}")
            self._discard_client()
            raise

    def _create_indexes(self):
        """Request every index listed in ``_INDEX_SPECS``.

        Failures are logged as warnings and never raised. Each index is
        attempted independently so that one failure does not prevent the
        remaining indexes from being created.
        """
        failures = 0
        for collection_name, field, options in self._INDEX_SPECS:
            try:
                self.db[collection_name].create_index(field, **options)
            except Exception as e:
                failures += 1
                logger.warning(f"⚠️ 创建索引时出现警告: {e}")

        if not failures:
            logger.info("✅ MongoDB索引创建完成")

    def _insert_file_document(self, document: Dict, display_name: str) -> str:
        """Insert one document into ``files``, log it, and return its id as ``str``."""
        result = self.db.files.insert_one(document)
        file_id = str(result.inserted_id)
        logger.info(f"✅ 文件元数据已存储: {display_name} (ID: {file_id})")
        return file_id

    def _store_prepared_metadata(self, metadata: Dict) -> str:
        """Store a caller-supplied metadata dict, reusing an existing record if found."""
        # Fill in the fields the read methods filter and sort on. The caller's
        # dict is updated in place.
        metadata.setdefault('upload_time', self._utcnow())
        metadata.setdefault('status', 'active')

        # Duplicate check: prefer file_id, fall back to bos_key.
        existing = None
        if 'file_id' in metadata:
            existing = self.db.files.find_one({"file_id": metadata['file_id']})
        elif 'bos_key' in metadata:
            existing = self.db.files.find_one({"bos_key": metadata['bos_key']})

        if existing:
            logger.info(f"文件已存在: {metadata.get('filename', 'unknown')} (ID: {existing['_id']})")
            return str(existing['_id'])

        return self._insert_file_document(metadata, metadata.get('filename', 'unknown'))

    def _store_metadata_from_path(self, file_path: str, file_type: str,
                                  bos_key: str, additional_info: Optional[Dict]) -> str:
        """Build metadata from a local file and store it, deduplicating by content hash."""
        file_hash = self._calculate_file_hash(file_path)
        file_stat = os.stat(file_path)
        filename = os.path.basename(file_path)

        existing = self.db.files.find_one({"file_hash": file_hash})
        if existing:
            logger.info(f"文件已存在: {filename} (ID: {existing['_id']})")
            return str(existing['_id'])

        document = {
            "filename": filename,
            "file_path": file_path,
            "file_type": file_type,
            "file_hash": file_hash,
            "file_size": file_stat.st_size,
            "bos_key": bos_key,
            "upload_time": self._utcnow(),
            "status": "active",
            "additional_info": additional_info or {},
        }
        return self._insert_file_document(document, filename)

    def store_file_metadata(self, file_path: Optional[str] = None, file_type: Optional[str] = None,
                            bos_key: Optional[str] = None, additional_info: Optional[Dict] = None,
                            metadata: Optional[Dict] = None) -> str:
        """Store file metadata and return the record id.

        Args:
            file_path: Local file path (ignored when ``metadata`` is given).
            file_type: File type (image/text).
            bos_key: BOS storage key.
            additional_info: Extra information stored alongside the record.
            metadata: Ready-made metadata dict. When non-empty it is stored
                as-is after ``upload_time`` and ``status`` defaults are
                filled in; the dict is modified in place.

        Returns:
            The ``_id`` of the stored record as a string, or of the already
            existing record when a duplicate is detected.

        Raises:
            ValueError: when ``metadata`` is not provided and any of
                ``file_path``, ``file_type`` or ``bos_key`` is missing.
            Exception: any file-system or database error; it is logged and
                re-raised.
        """
        try:
            if metadata:
                return self._store_prepared_metadata(metadata)

            # Legacy flow: derive everything from the file on disk.
            if not file_path or not file_type or not bos_key:
                raise ValueError("file_path, file_type, and bos_key are required when metadata is not provided")

            return self._store_metadata_from_path(file_path, file_type, bos_key, additional_info)

        except Exception as e:
            logger.error(f"❌ 存储文件元数据失败: {e}")
            raise

    def store_vector_metadata(self, file_id: str, vector_type: str,
                              vdb_id: str, vector_info: Optional[Dict] = None):
        """Store one vector metadata record.

        Args:
            file_id: Id of the file the vector belongs to.
            vector_type: Vector type (text_vector/image_vector).
            vdb_id: Id of the vector inside the VDB.
            vector_info: Extra vector information.

        Raises:
            Exception: any database error; it is logged and re-raised.
        """
        try:
            vector_metadata = {
                "file_id": file_id,
                "vector_type": vector_type,
                "vdb_id": vdb_id,
                "create_time": self._utcnow(),
                "vector_info": vector_info or {},
            }

            result = self.db.vectors.insert_one(vector_metadata)
            logger.info(f"✅ 向量元数据已存储: {vector_type} (ID: {result.inserted_id})")

        except Exception as e:
            logger.error(f"❌ 存储向量元数据失败: {e}")
            raise

    def get_file_metadata(self, file_id: str) -> Optional[Dict]:
        """Fetch one file record by id.

        Args:
            file_id: String form of the record's ``_id``.

        Returns:
            The record with ``_id`` converted to ``str``, or ``None`` when it
            does not exist or the lookup fails (failures are logged).
        """
        try:
            from bson import ObjectId
            document = self.db.files.find_one({"_id": ObjectId(file_id)})
            if document:
                document['_id'] = str(document['_id'])
            return document
        except Exception as e:
            logger.error(f"❌ 获取文件元数据失败: {e}")
            return None

    def get_files_by_type(self, file_type: str, limit: int = 100) -> List[Dict]:
        """List active files of one type, newest upload first.

        Args:
            file_type: Value of the ``file_type`` field to match.
            limit: Maximum number of records requested.

        Returns:
            Matching records with ``_id`` as ``str``; an empty list on error.
        """
        try:
            # NOTE(review): sort/limit are written in the order the server is
            # expected to apply them; chaining order should not affect the
            # result -- confirm against the PyMongo cursor documentation.
            cursor = self.db.files.find(
                {"file_type": file_type, "status": "active"}
            ).sort("upload_time", -1).limit(limit)
            return self._stringify_ids(cursor)
        except Exception as e:
            logger.error(f"❌ 获取文件列表失败: {e}")
            return []

    def get_all_files(self, limit: int = 1000) -> List[Dict]:
        """List active files, newest upload first.

        Args:
            limit: Maximum number of records requested.

        Returns:
            Records with ``_id`` as ``str``; an empty list on error.
        """
        try:
            cursor = self.db.files.find(
                {"status": "active"}
            ).sort("upload_time", -1).limit(limit)
            return self._stringify_ids(cursor)
        except Exception as e:
            logger.error(f"❌ 获取所有文件列表失败: {e}")
            return []

    def get_vector_metadata(self, file_id: str, vector_type: Optional[str] = None) -> List[Dict]:
        """List vector records for a file, optionally restricted to one type.

        Args:
            file_id: Id of the file whose vectors are wanted.
            vector_type: When given (and non-empty), only this type is returned.

        Returns:
            Matching records with ``_id`` as ``str``; an empty list on error.
        """
        try:
            query = {"file_id": file_id}
            if vector_type:
                query["vector_type"] = vector_type

            return self._stringify_ids(self.db.vectors.find(query))
        except Exception as e:
            logger.error(f"❌ 获取向量元数据失败: {e}")
            return []

    def get_stats(self) -> Dict:
        """Return document counts for files and vectors.

        Returns:
            A dict with ``total_files``, ``image_files``, ``text_files``
            (active files only), ``total_vectors``, ``image_vectors`` and
            ``text_vectors``; an empty dict on error.
        """
        try:
            files = self.db.files
            vectors = self.db.vectors
            return {
                "total_files": files.count_documents({"status": "active"}),
                "image_files": files.count_documents({"file_type": "image", "status": "active"}),
                "text_files": files.count_documents({"file_type": "text", "status": "active"}),
                "total_vectors": vectors.count_documents({}),
                "image_vectors": vectors.count_documents({"vector_type": "image_vector"}),
                "text_vectors": vectors.count_documents({"vector_type": "text_vector"}),
            }
        except Exception as e:
            logger.error(f"❌ 获取统计信息失败: {e}")
            return {}

    def delete_file_metadata(self, file_id: str):
        """Soft-delete a file record.

        The record is kept; its ``status`` becomes ``"deleted"`` and a
        ``delete_time`` is recorded.

        Args:
            file_id: String form of the record's ``_id``.

        Raises:
            Exception: an invalid id or database error; it is logged and
                re-raised.
        """
        try:
            from bson import ObjectId
            result = self.db.files.update_one(
                {"_id": ObjectId(file_id)},
                {"$set": {"status": "deleted", "delete_time": self._utcnow()}}
            )
            # Only report success when a record was actually matched.
            if result.matched_count:
                logger.info(f"✅ 文件元数据已删除: {file_id}")
            else:
                logger.warning(f"⚠️ 未找到要删除的文件元数据: {file_id}")
        except Exception as e:
            logger.error(f"❌ 删除文件元数据失败: {e}")
            raise

    def _calculate_file_hash(self, file_path: str) -> str:
        """Return the SHA-256 hex digest of the file at ``file_path``.

        The file is read in fixed-size chunks so large files are not loaded
        into memory at once.
        """
        digest = hashlib.sha256()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(self._HASH_CHUNK_SIZE), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def close(self):
        """Close the client connection, if one was opened."""
        if self.client:
            self.client.close()
            logger.info("MongoDB连接已关闭")
|
||
|
||
# Module-level shared instance; stays None until get_mongodb_manager() is first called.
mongodb_manager = None


def get_mongodb_manager() -> MongoDBManager:
    """Return the shared MongoDBManager, constructing it on first use.

    The instance is built with MongoDBManager's default arguments and cached
    in the module-level ``mongodb_manager`` variable, so later calls return
    the same object.
    """
    global mongodb_manager
    if mongodb_manager is not None:
        return mongodb_manager
    mongodb_manager = MongoDBManager()
    return mongodb_manager
|