♻️ format project
This commit is contained in:
parent
202fad85ec
commit
39e3fe76ea
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
78
app_log.txt
78
app_log.txt
File diff suppressed because one or more lines are too long
@ -1,342 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
百度对象存储BOS管理器
|
|
||||||
用于存储和管理多模态文件的原始数据
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
import json
|
|
||||||
import copy
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Dict, List, Optional, Any, Union
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from baidubce.auth import bce_credentials
|
|
||||||
from baidubce import bce_base_client, bce_client_configuration
|
|
||||||
from baidubce.services.bos.bos_client import BosClient
|
|
||||||
from baidubce.exception import BceError
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class BaiduBOSManager:
|
|
||||||
"""百度对象存储BOS管理器"""
|
|
||||||
|
|
||||||
def __init__(self, ak: str = None, sk: str = None, endpoint: str = None, bucket: str = None):
|
|
||||||
"""
|
|
||||||
初始化BOS客户端
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ak: Access Key
|
|
||||||
sk: Secret Key
|
|
||||||
endpoint: BOS服务端点
|
|
||||||
bucket: 存储桶名称
|
|
||||||
"""
|
|
||||||
self.ak = ak or "ALTAKmzKDy1OqhqmepD2OeXqbN"
|
|
||||||
self.sk = sk or "b79c5fbc26344868916ec6e9e2ff65f0"
|
|
||||||
self.endpoint = endpoint or "https://bj.bcebos.com"
|
|
||||||
self.bucket = bucket or "dmtyz-demo"
|
|
||||||
|
|
||||||
self.client = None
|
|
||||||
self._init_client()
|
|
||||||
|
|
||||||
def _init_client(self):
|
|
||||||
"""初始化BOS客户端"""
|
|
||||||
try:
|
|
||||||
config = bce_client_configuration.BceClientConfiguration(
|
|
||||||
credentials=bce_credentials.BceCredentials(self.ak, self.sk),
|
|
||||||
endpoint=self.endpoint
|
|
||||||
)
|
|
||||||
self.client = BosClient(config)
|
|
||||||
|
|
||||||
# 测试连接
|
|
||||||
self._test_connection()
|
|
||||||
logger.info(f"✅ BOS客户端初始化成功: {self.bucket}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ BOS客户端初始化失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _test_connection(self):
|
|
||||||
"""测试BOS连接"""
|
|
||||||
try:
|
|
||||||
# 尝试列出存储桶
|
|
||||||
self.client.list_buckets()
|
|
||||||
logger.info("✅ BOS连接测试成功")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ BOS连接测试失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def upload_file(self, local_path: str, bos_key: str = None,
|
|
||||||
content_type: str = None) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
上传文件到BOS
|
|
||||||
|
|
||||||
Args:
|
|
||||||
local_path: 本地文件路径
|
|
||||||
bos_key: BOS对象键,如果为None则自动生成
|
|
||||||
content_type: 文件内容类型
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
上传结果信息
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not os.path.exists(local_path):
|
|
||||||
raise FileNotFoundError(f"文件不存在: {local_path}")
|
|
||||||
|
|
||||||
# 生成BOS键
|
|
||||||
if bos_key is None:
|
|
||||||
bos_key = self._generate_bos_key(local_path)
|
|
||||||
|
|
||||||
# 自动检测内容类型
|
|
||||||
if content_type is None:
|
|
||||||
content_type = self._detect_content_type(local_path)
|
|
||||||
|
|
||||||
# 获取文件大小
|
|
||||||
file_stat = os.stat(local_path)
|
|
||||||
file_size = file_stat.st_size
|
|
||||||
|
|
||||||
# 上传文件(使用put_object_from_file方法)
|
|
||||||
response = self.client.put_object_from_file(
|
|
||||||
self.bucket,
|
|
||||||
bos_key,
|
|
||||||
local_path,
|
|
||||||
content_type=content_type
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取文件信息
|
|
||||||
file_stat = os.stat(local_path)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"bos_key": bos_key,
|
|
||||||
"bucket": self.bucket,
|
|
||||||
"file_size": file_stat.st_size,
|
|
||||||
"content_type": content_type,
|
|
||||||
"upload_time": datetime.utcnow().isoformat(),
|
|
||||||
"etag": response.metadata.etag if hasattr(response, 'metadata') else None,
|
|
||||||
"url": f"{self.endpoint}/{self.bucket}/{bos_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"✅ 文件上传成功: {bos_key}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 文件上传失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def download_file(self, bos_key: str, local_path: str) -> bool:
|
|
||||||
"""
|
|
||||||
从BOS下载文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bos_key: BOS对象键
|
|
||||||
local_path: 本地保存路径
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
是否下载成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 确保目录存在
|
|
||||||
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
|
||||||
|
|
||||||
# 下载文件
|
|
||||||
response = self.client.get_object(self.bucket, bos_key)
|
|
||||||
|
|
||||||
with open(local_path, 'wb') as f:
|
|
||||||
for chunk in response.data:
|
|
||||||
f.write(chunk)
|
|
||||||
|
|
||||||
logger.info(f"✅ 文件下载成功: {bos_key} -> {local_path}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 文件下载失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_object_metadata(self, bos_key: str) -> Optional[Dict]:
|
|
||||||
"""
|
|
||||||
获取对象元数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bos_key: BOS对象键
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
对象元数据
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = self.client.get_object_meta_data(self.bucket, bos_key)
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"bos_key": bos_key,
|
|
||||||
"bucket": self.bucket,
|
|
||||||
"content_length": response.metadata.content_length,
|
|
||||||
"content_type": response.metadata.content_type,
|
|
||||||
"etag": response.metadata.etag,
|
|
||||||
"last_modified": response.metadata.last_modified,
|
|
||||||
"url": f"{self.endpoint}/{self.bucket}/{bos_key}"
|
|
||||||
}
|
|
||||||
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取对象元数据失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def delete_object(self, bos_key: str) -> bool:
|
|
||||||
"""
|
|
||||||
删除BOS对象
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bos_key: BOS对象键
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
是否删除成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.client.delete_object(self.bucket, bos_key)
|
|
||||||
logger.info(f"✅ 对象删除成功: {bos_key}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 对象删除失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def list_objects(self, prefix: str = "", max_keys: int = 1000) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
列出BOS对象
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prefix: 对象键前缀
|
|
||||||
max_keys: 最大返回数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
对象列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
response = self.client.list_objects(
|
|
||||||
bucket_name=self.bucket,
|
|
||||||
prefix=prefix,
|
|
||||||
max_keys=max_keys
|
|
||||||
)
|
|
||||||
|
|
||||||
objects = []
|
|
||||||
if hasattr(response, 'contents'):
|
|
||||||
for obj in response.contents:
|
|
||||||
objects.append({
|
|
||||||
"key": obj.key,
|
|
||||||
"size": obj.size,
|
|
||||||
"last_modified": obj.last_modified,
|
|
||||||
"etag": obj.etag,
|
|
||||||
"url": f"{self.endpoint}/{self.bucket}/{obj.key}"
|
|
||||||
})
|
|
||||||
|
|
||||||
return objects
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 列出对象失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def restore_archive_object(self, bos_key: str, days: int = 1, tier: str = "Standard") -> bool:
|
|
||||||
"""
|
|
||||||
恢复归档对象
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bos_key: BOS对象键
|
|
||||||
days: 恢复天数
|
|
||||||
tier: 恢复级别 (Expedited/Standard/Bulk)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
是否成功发起恢复
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 使用自定义客户端进行归档恢复
|
|
||||||
restore_client = ArchiveRestoreClient(
|
|
||||||
bce_client_configuration.BceClientConfiguration(
|
|
||||||
credentials=bce_credentials.BceCredentials(self.ak, self.sk),
|
|
||||||
endpoint=self.endpoint
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
response = restore_client.restore_object(self.bucket, bos_key, days, tier)
|
|
||||||
logger.info(f"✅ 归档恢复请求已发送: {bos_key}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 归档恢复失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _generate_bos_key(self, local_path: str) -> str:
|
|
||||||
"""生成BOS对象键"""
|
|
||||||
filename = os.path.basename(local_path)
|
|
||||||
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
|
|
||||||
|
|
||||||
# 根据文件类型分类存储
|
|
||||||
if self._is_image_file(local_path):
|
|
||||||
return f"images/{timestamp}_{filename}"
|
|
||||||
elif self._is_text_file(local_path):
|
|
||||||
return f"texts/{timestamp}_{filename}"
|
|
||||||
else:
|
|
||||||
return f"files/{timestamp}_{filename}"
|
|
||||||
|
|
||||||
def _detect_content_type(self, file_path: str) -> str:
|
|
||||||
"""检测文件内容类型"""
|
|
||||||
ext = os.path.splitext(file_path)[1].lower()
|
|
||||||
|
|
||||||
content_types = {
|
|
||||||
'.jpg': 'image/jpeg',
|
|
||||||
'.jpeg': 'image/jpeg',
|
|
||||||
'.png': 'image/png',
|
|
||||||
'.gif': 'image/gif',
|
|
||||||
'.bmp': 'image/bmp',
|
|
||||||
'.webp': 'image/webp',
|
|
||||||
'.txt': 'text/plain',
|
|
||||||
'.json': 'application/json',
|
|
||||||
'.pdf': 'application/pdf'
|
|
||||||
}
|
|
||||||
|
|
||||||
return content_types.get(ext, 'application/octet-stream')
|
|
||||||
|
|
||||||
def _is_image_file(self, file_path: str) -> bool:
|
|
||||||
"""判断是否为图像文件"""
|
|
||||||
image_exts = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}
|
|
||||||
ext = os.path.splitext(file_path)[1].lower()
|
|
||||||
return ext in image_exts
|
|
||||||
|
|
||||||
def _is_text_file(self, file_path: str) -> bool:
|
|
||||||
"""判断是否为文本文件"""
|
|
||||||
text_exts = {'.txt', '.json', '.csv', '.md'}
|
|
||||||
ext = os.path.splitext(file_path)[1].lower()
|
|
||||||
return ext in text_exts
|
|
||||||
|
|
||||||
|
|
||||||
class ArchiveRestoreClient(bce_base_client.BceBaseClient):
|
|
||||||
"""归档恢复客户端"""
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
self.config = copy.deepcopy(bce_client_configuration.DEFAULT_CONFIG)
|
|
||||||
self.config.merge_non_none_values(config)
|
|
||||||
|
|
||||||
def restore_object(self, bucket: str, key: str, days: int = 1, tier: str = "Standard"):
|
|
||||||
"""恢复归档对象"""
|
|
||||||
path = f'/{bucket}/{key}'.encode('utf-8')
|
|
||||||
headers = {
|
|
||||||
b'x-bce-restore-days': str(days).encode('utf-8'),
|
|
||||||
b'x-bce-restore-tier': tier.encode('utf-8'),
|
|
||||||
b'Accept': b'application/json'
|
|
||||||
}
|
|
||||||
|
|
||||||
params = {"restore": ""}
|
|
||||||
payload = json.dumps({}, ensure_ascii=False)
|
|
||||||
return self._send_request(b'POST', path, headers, params, payload.encode('utf-8'))
|
|
||||||
|
|
||||||
|
|
||||||
# 全局实例
|
|
||||||
bos_manager = None
|
|
||||||
|
|
||||||
def get_bos_manager() -> BaiduBOSManager:
|
|
||||||
"""获取BOS管理器实例"""
|
|
||||||
global bos_manager
|
|
||||||
if bos_manager is None:
|
|
||||||
bos_manager = BaiduBOSManager()
|
|
||||||
return bos_manager
|
|
||||||
@ -1,483 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
百度VDB向量数据库后端
|
|
||||||
支持多模态向量存储和检索
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams, SecondaryIndex
|
|
||||||
from pymochow.model.enum import FieldType, IndexType, MetricType
|
|
||||||
from pymochow.model.table import Partition, Row, VectorTopkSearchRequest, FloatVector, VectorSearchConfig
|
|
||||||
import numpy as np
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from typing import List, Tuple, Union, Dict, Any
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class BaiduVDBBackend:
|
|
||||||
"""百度VDB向量数据库后端"""
|
|
||||||
|
|
||||||
def __init__(self, account: str = "root", api_key: str = "vdb$yjr9ln3n0td",
|
|
||||||
endpoint: str = "http://180.76.96.191:5287", database_name: str = "multimodal_retrieval"):
|
|
||||||
"""
|
|
||||||
初始化百度VDB后端
|
|
||||||
|
|
||||||
Args:
|
|
||||||
account: 用户名
|
|
||||||
api_key: API密钥
|
|
||||||
endpoint: 服务器端点
|
|
||||||
database_name: 数据库名称
|
|
||||||
"""
|
|
||||||
self.account = account
|
|
||||||
self.api_key = api_key
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.database_name = database_name
|
|
||||||
|
|
||||||
# 初始化客户端
|
|
||||||
self.client = None
|
|
||||||
self.db = None
|
|
||||||
self.text_table = None
|
|
||||||
self.image_table = None
|
|
||||||
|
|
||||||
# 表名配置
|
|
||||||
self.text_table_name = "text_vectors"
|
|
||||||
self.image_table_name = "image_vectors"
|
|
||||||
|
|
||||||
# 向量维度(根据Ops-MM-embedding-v1-7B模型)
|
|
||||||
self.vector_dimension = 3584
|
|
||||||
|
|
||||||
self._connect()
|
|
||||||
self._ensure_database()
|
|
||||||
self._ensure_tables()
|
|
||||||
|
|
||||||
def _connect(self):
|
|
||||||
"""连接到百度VDB"""
|
|
||||||
try:
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(self.account, self.api_key),
|
|
||||||
endpoint=self.endpoint
|
|
||||||
)
|
|
||||||
self.client = pymochow.MochowClient(config)
|
|
||||||
logger.info(f"✅ 成功连接到百度VDB: {self.endpoint}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 连接百度VDB失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_database(self):
|
|
||||||
"""确保数据库存在"""
|
|
||||||
try:
|
|
||||||
# 检查数据库是否存在
|
|
||||||
db_list = self.client.list_databases()
|
|
||||||
existing_dbs = [db.database_name for db in db_list]
|
|
||||||
|
|
||||||
if self.database_name not in existing_dbs:
|
|
||||||
logger.info(f"创建数据库: {self.database_name}")
|
|
||||||
self.db = self.client.create_database(self.database_name)
|
|
||||||
else:
|
|
||||||
logger.info(f"使用现有数据库: {self.database_name}")
|
|
||||||
self.db = self.client.database(self.database_name)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 数据库操作失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_tables(self):
|
|
||||||
"""确保表存在"""
|
|
||||||
try:
|
|
||||||
# 获取现有表列表
|
|
||||||
existing_tables = self.db.list_table()
|
|
||||||
existing_table_names = [table.table_name for table in existing_tables]
|
|
||||||
|
|
||||||
# 创建文本向量表
|
|
||||||
if self.text_table_name not in existing_table_names:
|
|
||||||
self._create_text_table()
|
|
||||||
else:
|
|
||||||
self.text_table = self.db.table(self.text_table_name)
|
|
||||||
logger.info(f"使用现有文本表: {self.text_table_name}")
|
|
||||||
|
|
||||||
# 创建图像向量表
|
|
||||||
if self.image_table_name not in existing_table_names:
|
|
||||||
self._create_image_table()
|
|
||||||
else:
|
|
||||||
self.image_table = self.db.table(self.image_table_name)
|
|
||||||
logger.info(f"使用现有图像表: {self.image_table_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 表操作失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_text_table(self):
|
|
||||||
"""创建文本向量表"""
|
|
||||||
try:
|
|
||||||
logger.info(f"创建文本向量表: {self.text_table_name}")
|
|
||||||
|
|
||||||
# 定义字段 - 移除可能导致问题的复杂配置
|
|
||||||
fields = [
|
|
||||||
Field("id", FieldType.STRING, primary_key=True, not_null=True),
|
|
||||||
Field("text_content", FieldType.STRING, not_null=True),
|
|
||||||
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 定义索引 - 简化配置
|
|
||||||
indexes = [
|
|
||||||
VectorIndex(
|
|
||||||
index_name="text_vector_idx",
|
|
||||||
index_type=IndexType.HNSW,
|
|
||||||
field="vector",
|
|
||||||
metric_type=MetricType.COSINE,
|
|
||||||
params=HNSWParams(m=16, efconstruction=100),
|
|
||||||
auto_build=True
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 创建表 - 简化配置
|
|
||||||
self.text_table = self.db.create_table(
|
|
||||||
table_name=self.text_table_name,
|
|
||||||
replication=1, # 单副本
|
|
||||||
schema=Schema(fields=fields, indexes=indexes)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"✅ 文本向量表创建成功")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 创建文本表失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_image_table(self):
|
|
||||||
"""创建图像向量表"""
|
|
||||||
try:
|
|
||||||
logger.info(f"创建图像向量表: {self.image_table_name}")
|
|
||||||
|
|
||||||
# 定义字段 - 移除可能导致问题的复杂配置
|
|
||||||
fields = [
|
|
||||||
Field("id", FieldType.STRING, primary_key=True, not_null=True),
|
|
||||||
Field("image_path", FieldType.STRING, not_null=True),
|
|
||||||
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 定义索引 - 简化配置
|
|
||||||
indexes = [
|
|
||||||
VectorIndex(
|
|
||||||
index_name="image_vector_idx",
|
|
||||||
index_type=IndexType.HNSW,
|
|
||||||
field="vector",
|
|
||||||
metric_type=MetricType.COSINE,
|
|
||||||
params=HNSWParams(m=16, efconstruction=100),
|
|
||||||
auto_build=True
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 创建表 - 简化配置
|
|
||||||
self.image_table = self.db.create_table(
|
|
||||||
table_name=self.image_table_name,
|
|
||||||
replication=1, # 单副本
|
|
||||||
schema=Schema(fields=fields, indexes=indexes)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"✅ 图像向量表创建成功")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 创建图像表失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _generate_id(self, content: str) -> str:
|
|
||||||
"""生成唯一ID"""
|
|
||||||
return hashlib.md5(f"{content}_{time.time()}".encode()).hexdigest()
|
|
||||||
|
|
||||||
def store_text_vectors(self, texts: List[str], vectors: np.ndarray, metadata: List[Dict] = None) -> List[str]:
|
|
||||||
"""
|
|
||||||
存储文本向量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: 文本列表
|
|
||||||
vectors: 向量数组
|
|
||||||
metadata: 元数据列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
存储的ID列表
|
|
||||||
"""
|
|
||||||
if len(texts) != len(vectors):
|
|
||||||
raise ValueError("文本数量与向量数量不匹配")
|
|
||||||
|
|
||||||
try:
|
|
||||||
rows = []
|
|
||||||
ids = []
|
|
||||||
current_time = int(time.time() * 1000) # 毫秒时间戳
|
|
||||||
|
|
||||||
for i, (text, vector) in enumerate(zip(texts, vectors)):
|
|
||||||
doc_id = self._generate_id(text)
|
|
||||||
ids.append(doc_id)
|
|
||||||
|
|
||||||
# 准备元数据
|
|
||||||
meta = metadata[i] if metadata and i < len(metadata) else {}
|
|
||||||
meta_json = json.dumps(meta, ensure_ascii=False)
|
|
||||||
|
|
||||||
row = Row(
|
|
||||||
id=doc_id,
|
|
||||||
text_content=text,
|
|
||||||
vector=vector.tolist()
|
|
||||||
)
|
|
||||||
rows.append(row)
|
|
||||||
|
|
||||||
# 批量插入
|
|
||||||
self.text_table.upsert(rows)
|
|
||||||
logger.info(f"✅ 成功存储 {len(texts)} 条文本向量")
|
|
||||||
|
|
||||||
return ids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 存储文本向量失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def store_image_vectors(self, image_paths: List[str], vectors: np.ndarray, metadata: List[Dict] = None) -> List[str]:
|
|
||||||
"""
|
|
||||||
存储图像向量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_paths: 图像路径列表
|
|
||||||
vectors: 向量数组
|
|
||||||
metadata: 元数据列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
存储的ID列表
|
|
||||||
"""
|
|
||||||
if len(image_paths) != len(vectors):
|
|
||||||
raise ValueError("图像数量与向量数量不匹配")
|
|
||||||
|
|
||||||
try:
|
|
||||||
rows = []
|
|
||||||
ids = []
|
|
||||||
current_time = int(time.time() * 1000) # 毫秒时间戳
|
|
||||||
|
|
||||||
for i, (image_path, vector) in enumerate(zip(image_paths, vectors)):
|
|
||||||
doc_id = self._generate_id(image_path)
|
|
||||||
ids.append(doc_id)
|
|
||||||
|
|
||||||
# 准备元数据
|
|
||||||
meta = metadata[i] if metadata and i < len(metadata) else {}
|
|
||||||
meta_json = json.dumps(meta, ensure_ascii=False)
|
|
||||||
|
|
||||||
row = Row(
|
|
||||||
id=doc_id,
|
|
||||||
image_path=image_path,
|
|
||||||
vector=vector.tolist()
|
|
||||||
)
|
|
||||||
rows.append(row)
|
|
||||||
|
|
||||||
# 批量插入
|
|
||||||
self.image_table.upsert(rows)
|
|
||||||
logger.info(f"✅ 成功存储 {len(image_paths)} 条图像向量")
|
|
||||||
|
|
||||||
return ids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 存储图像向量失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def search_text_vectors(self, query_vector: np.ndarray, top_k: int = 5,
|
|
||||||
filter_condition: str = None) -> List[Tuple[str, str, float, Dict]]:
|
|
||||||
"""
|
|
||||||
搜索文本向量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_vector: 查询向量
|
|
||||||
top_k: 返回结果数量
|
|
||||||
filter_condition: 过滤条件
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(id, text_content, score, metadata) 列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 构建搜索请求
|
|
||||||
request = VectorTopkSearchRequest(
|
|
||||||
vector_field="vector",
|
|
||||||
vector=FloatVector(query_vector.tolist()),
|
|
||||||
limit=top_k,
|
|
||||||
filter=filter_condition,
|
|
||||||
config=VectorSearchConfig(ef=200)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
results = self.text_table.vector_search(request=request)
|
|
||||||
|
|
||||||
# 解析结果
|
|
||||||
search_results = []
|
|
||||||
for result in results:
|
|
||||||
doc_id = result.get('id', '')
|
|
||||||
text_content = result.get('text_content', '')
|
|
||||||
score = result.get('_score', 0.0)
|
|
||||||
|
|
||||||
search_results.append((doc_id, text_content, float(score), {}))
|
|
||||||
|
|
||||||
logger.info(f"✅ 文本向量搜索完成,返回 {len(search_results)} 条结果")
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 文本向量搜索失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_image_vectors(self, query_vector: np.ndarray, top_k: int = 5,
|
|
||||||
filter_condition: str = None) -> List[Tuple[str, str, str, float, Dict]]:
|
|
||||||
"""
|
|
||||||
搜索图像向量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_vector: 查询向量
|
|
||||||
top_k: 返回结果数量
|
|
||||||
filter_condition: 过滤条件
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(id, image_path, image_name, score, metadata) 列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 构建搜索请求
|
|
||||||
request = VectorTopkSearchRequest(
|
|
||||||
vector_field="vector",
|
|
||||||
vector=FloatVector(query_vector.tolist()),
|
|
||||||
limit=top_k,
|
|
||||||
filter=filter_condition,
|
|
||||||
config=VectorSearchConfig(ef=200)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
results = self.image_table.vector_search(request=request)
|
|
||||||
|
|
||||||
# 解析结果
|
|
||||||
search_results = []
|
|
||||||
for result in results:
|
|
||||||
doc_id = result.get('id', '')
|
|
||||||
image_path = result.get('image_path', '')
|
|
||||||
image_name = os.path.basename(image_path)
|
|
||||||
score = result.get('_score', 0.0)
|
|
||||||
|
|
||||||
search_results.append((doc_id, image_path, image_name, float(score), {}))
|
|
||||||
|
|
||||||
logger.info(f"✅ 图像向量搜索完成,返回 {len(search_results)} 条结果")
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 图像向量搜索失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
|
||||||
"""获取数据库统计信息"""
|
|
||||||
try:
|
|
||||||
stats = {}
|
|
||||||
|
|
||||||
# 文本表统计
|
|
||||||
if self.text_table:
|
|
||||||
text_stats = self.text_table.stats()
|
|
||||||
stats['text_table'] = {
|
|
||||||
'row_count': text_stats.get('rowCount', 0),
|
|
||||||
'memory_size_mb': text_stats.get('memorySizeInByte', 0) / (1024 * 1024),
|
|
||||||
'disk_size_mb': text_stats.get('diskSizeInByte', 0) / (1024 * 1024)
|
|
||||||
}
|
|
||||||
|
|
||||||
# 图像表统计
|
|
||||||
if self.image_table:
|
|
||||||
image_stats = self.image_table.stats()
|
|
||||||
stats['image_table'] = {
|
|
||||||
'row_count': image_stats.get('rowCount', 0),
|
|
||||||
'memory_size_mb': image_stats.get('memorySizeInByte', 0) / (1024 * 1024),
|
|
||||||
'disk_size_mb': image_stats.get('diskSizeInByte', 0) / (1024 * 1024)
|
|
||||||
}
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取统计信息失败: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def clear_all_data(self):
|
|
||||||
"""清空所有数据"""
|
|
||||||
try:
|
|
||||||
# 清空文本表
|
|
||||||
if self.text_table:
|
|
||||||
self.text_table.delete(filter="*")
|
|
||||||
logger.info("✅ 文本表数据已清空")
|
|
||||||
|
|
||||||
# 清空图像表
|
|
||||||
if self.image_table:
|
|
||||||
self.image_table.delete(filter="*")
|
|
||||||
logger.info("✅ 图像表数据已清空")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 清空数据失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""关闭连接"""
|
|
||||||
try:
|
|
||||||
if self.client:
|
|
||||||
self.client.close()
|
|
||||||
logger.info("✅ 百度VDB连接已关闭")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 关闭连接失败: {e}")
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_vdb_backend():
|
|
||||||
"""测试VDB后端功能"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("测试百度VDB后端功能")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 初始化后端
|
|
||||||
vdb = BaiduVDBBackend()
|
|
||||||
|
|
||||||
# 测试数据
|
|
||||||
test_texts = [
|
|
||||||
"这是一个测试文本",
|
|
||||||
"多模态检索系统",
|
|
||||||
"向量数据库测试"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 生成测试向量(随机向量)
|
|
||||||
test_vectors = np.random.rand(3, 3584).astype(np.float32)
|
|
||||||
|
|
||||||
# 存储文本向量
|
|
||||||
print("1. 存储文本向量...")
|
|
||||||
text_ids = vdb.store_text_vectors(test_texts, test_vectors)
|
|
||||||
print(f" 存储成功,ID: {text_ids}")
|
|
||||||
|
|
||||||
# 搜索文本向量
|
|
||||||
print("\n2. 搜索文本向量...")
|
|
||||||
query_vector = np.random.rand(3584).astype(np.float32)
|
|
||||||
results = vdb.search_text_vectors(query_vector, top_k=3)
|
|
||||||
print(f" 搜索结果: {len(results)} 条")
|
|
||||||
for i, (doc_id, text, score, meta) in enumerate(results, 1):
|
|
||||||
print(f" {i}. {text[:30]}... (相似度: {score:.4f})")
|
|
||||||
|
|
||||||
# 获取统计信息
|
|
||||||
print("\n3. 数据库统计信息...")
|
|
||||||
stats = vdb.get_statistics()
|
|
||||||
print(f" 统计信息: {stats}")
|
|
||||||
|
|
||||||
# 清理测试数据
|
|
||||||
print("\n4. 清理测试数据...")
|
|
||||||
vdb.clear_all_data()
|
|
||||||
print(" 清理完成")
|
|
||||||
|
|
||||||
print("\n✅ 百度VDB后端测试完成!")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ 测试失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_vdb_backend()
|
|
||||||
@ -1,482 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
修复版百度VDB后端 - 解决Invalid Index Schema错误
|
|
||||||
基于官方文档规范重新设计表结构
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import json
|
|
||||||
import hashlib
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from typing import List, Tuple, Dict, Any, Optional
|
|
||||||
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams
|
|
||||||
from pymochow.model.enum import FieldType, IndexType, MetricType
|
|
||||||
from pymochow.model.table import Row, Partition
|
|
||||||
from pymochow.model.table import VectorTopkSearchRequest, VectorSearchConfig, FloatVector
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class BaiduVDBFixed:
|
|
||||||
"""修复版百度VDB后端类"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
account: str = "root",
|
|
||||||
api_key: str = "vdb$yjr9ln3n0td",
|
|
||||||
endpoint: str = "http://180.76.96.191:5287",
|
|
||||||
database_name: str = "multimodal_fixed",
|
|
||||||
vector_dimension: int = 3584):
|
|
||||||
"""
|
|
||||||
初始化VDB连接
|
|
||||||
|
|
||||||
Args:
|
|
||||||
account: 账户名
|
|
||||||
api_key: API密钥
|
|
||||||
endpoint: 服务端点
|
|
||||||
database_name: 数据库名称
|
|
||||||
vector_dimension: 向量维度
|
|
||||||
"""
|
|
||||||
self.account = account
|
|
||||||
self.api_key = api_key
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.database_name = database_name
|
|
||||||
self.vector_dimension = vector_dimension
|
|
||||||
|
|
||||||
# 表名
|
|
||||||
self.text_table_name = "text_vectors_v2"
|
|
||||||
self.image_table_name = "image_vectors_v2"
|
|
||||||
|
|
||||||
# 初始化连接
|
|
||||||
self.client = None
|
|
||||||
self.db = None
|
|
||||||
self.text_table = None
|
|
||||||
self.image_table = None
|
|
||||||
|
|
||||||
self._init_connection()
|
|
||||||
|
|
||||||
def _init_connection(self):
|
|
||||||
"""初始化数据库连接"""
|
|
||||||
try:
|
|
||||||
logger.info("🔗 初始化百度VDB连接...")
|
|
||||||
|
|
||||||
# 创建配置
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(self.account, self.api_key),
|
|
||||||
endpoint=self.endpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建客户端
|
|
||||||
self.client = pymochow.MochowClient(config)
|
|
||||||
logger.info("✅ VDB客户端创建成功")
|
|
||||||
|
|
||||||
# 确保数据库存在
|
|
||||||
self._ensure_database()
|
|
||||||
|
|
||||||
# 确保表存在
|
|
||||||
self._ensure_tables()
|
|
||||||
|
|
||||||
logger.info("✅ VDB后端初始化完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB连接初始化失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_database(self):
|
|
||||||
"""确保数据库存在"""
|
|
||||||
try:
|
|
||||||
# 检查数据库是否存在
|
|
||||||
db_list = self.client.list_databases()
|
|
||||||
db_names = [db.database_name for db in db_list]
|
|
||||||
|
|
||||||
if self.database_name not in db_names:
|
|
||||||
logger.info(f"创建数据库: {self.database_name}")
|
|
||||||
self.db = self.client.create_database(self.database_name)
|
|
||||||
else:
|
|
||||||
logger.info(f"使用现有数据库: {self.database_name}")
|
|
||||||
self.db = self.client.database(self.database_name)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 数据库操作失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_tables(self):
|
|
||||||
"""确保表存在"""
|
|
||||||
try:
|
|
||||||
# 获取现有表列表
|
|
||||||
table_list = self.db.list_table()
|
|
||||||
table_names = [table.table_name for table in table_list]
|
|
||||||
|
|
||||||
# 创建文本表
|
|
||||||
if self.text_table_name not in table_names:
|
|
||||||
self._create_text_table_fixed()
|
|
||||||
else:
|
|
||||||
self.text_table = self.db.table(self.text_table_name)
|
|
||||||
logger.info(f"✅ 使用现有文本表: {self.text_table_name}")
|
|
||||||
|
|
||||||
# 创建图像表
|
|
||||||
if self.image_table_name not in table_names:
|
|
||||||
self._create_image_table_fixed()
|
|
||||||
else:
|
|
||||||
self.image_table = self.db.table(self.image_table_name)
|
|
||||||
logger.info(f"✅ 使用现有图像表: {self.image_table_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 表操作失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_text_table_fixed(self):
|
|
||||||
"""创建修复版文本向量表"""
|
|
||||||
try:
|
|
||||||
logger.info(f"创建修复版文本向量表: {self.text_table_name}")
|
|
||||||
|
|
||||||
# 定义字段 - 严格按照官方文档规范
|
|
||||||
fields = [
|
|
||||||
# 主键和分区键 - 必须是STRING类型
|
|
||||||
Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
|
|
||||||
# 文本内容 - 使用STRING而不是TEXT
|
|
||||||
Field("content", FieldType.STRING, not_null=True),
|
|
||||||
# 向量字段 - 必须指定维度
|
|
||||||
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 定义索引 - 只创建向量索引,避免复杂的二级索引
|
|
||||||
indexes = [
|
|
||||||
VectorIndex(
|
|
||||||
index_name="text_vector_index",
|
|
||||||
index_type=IndexType.HNSW,
|
|
||||||
field="vector",
|
|
||||||
metric_type=MetricType.COSINE,
|
|
||||||
params=HNSWParams(m=16, efconstruction=200), # 使用较小的参数
|
|
||||||
auto_build=True
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 创建Schema
|
|
||||||
schema = Schema(fields=fields, indexes=indexes)
|
|
||||||
|
|
||||||
# 创建表 - 使用较小的副本数和分区数
|
|
||||||
self.text_table = self.db.create_table(
|
|
||||||
table_name=self.text_table_name,
|
|
||||||
replication=2, # 最小副本数
|
|
||||||
partition=Partition(partition_num=1), # 单分区
|
|
||||||
schema=schema,
|
|
||||||
description="修复版文本向量表"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"✅ 文本表创建成功: {self.text_table_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 创建文本表失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_image_table_fixed(self):
|
|
||||||
"""创建修复版图像向量表"""
|
|
||||||
try:
|
|
||||||
logger.info(f"创建修复版图像向量表: {self.image_table_name}")
|
|
||||||
|
|
||||||
# 定义字段 - 严格按照官方文档规范
|
|
||||||
fields = [
|
|
||||||
# 主键和分区键
|
|
||||||
Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
|
|
||||||
# 图像路径
|
|
||||||
Field("image_path", FieldType.STRING, not_null=True),
|
|
||||||
# 向量字段
|
|
||||||
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 定义索引 - 只创建向量索引
|
|
||||||
indexes = [
|
|
||||||
VectorIndex(
|
|
||||||
index_name="image_vector_index",
|
|
||||||
index_type=IndexType.HNSW,
|
|
||||||
field="vector",
|
|
||||||
metric_type=MetricType.COSINE,
|
|
||||||
params=HNSWParams(m=16, efconstruction=200),
|
|
||||||
auto_build=True
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 创建Schema
|
|
||||||
schema = Schema(fields=fields, indexes=indexes)
|
|
||||||
|
|
||||||
# 创建表
|
|
||||||
self.image_table = self.db.create_table(
|
|
||||||
table_name=self.image_table_name,
|
|
||||||
replication=2,
|
|
||||||
partition=Partition(partition_num=1),
|
|
||||||
schema=schema,
|
|
||||||
description="修复版图像向量表"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"✅ 图像表创建成功: {self.image_table_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 创建图像表失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _generate_id(self, content: str) -> str:
|
|
||||||
"""生成唯一ID"""
|
|
||||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
|
||||||
|
|
||||||
def store_text_vectors(self, texts: List[str], vectors: np.ndarray) -> List[str]:
|
|
||||||
"""存储文本向量"""
|
|
||||||
try:
|
|
||||||
if len(texts) != len(vectors):
|
|
||||||
raise ValueError("文本数量与向量数量不匹配")
|
|
||||||
|
|
||||||
logger.info(f"存储 {len(texts)} 条文本向量...")
|
|
||||||
|
|
||||||
rows = []
|
|
||||||
ids = []
|
|
||||||
|
|
||||||
for i, (text, vector) in enumerate(zip(texts, vectors)):
|
|
||||||
doc_id = self._generate_id(text)
|
|
||||||
ids.append(doc_id)
|
|
||||||
|
|
||||||
row = Row(
|
|
||||||
id=doc_id,
|
|
||||||
content=text,
|
|
||||||
vector=vector.tolist()
|
|
||||||
)
|
|
||||||
rows.append(row)
|
|
||||||
|
|
||||||
# 批量插入
|
|
||||||
self.text_table.upsert(rows)
|
|
||||||
logger.info(f"✅ 成功存储 {len(texts)} 条文本向量")
|
|
||||||
|
|
||||||
return ids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 存储文本向量失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def store_image_vectors(self, image_paths: List[str], vectors: np.ndarray) -> List[str]:
|
|
||||||
"""存储图像向量"""
|
|
||||||
try:
|
|
||||||
if len(image_paths) != len(vectors):
|
|
||||||
raise ValueError("图像数量与向量数量不匹配")
|
|
||||||
|
|
||||||
logger.info(f"存储 {len(image_paths)} 条图像向量...")
|
|
||||||
|
|
||||||
rows = []
|
|
||||||
ids = []
|
|
||||||
|
|
||||||
for i, (image_path, vector) in enumerate(zip(image_paths, vectors)):
|
|
||||||
doc_id = self._generate_id(image_path)
|
|
||||||
ids.append(doc_id)
|
|
||||||
|
|
||||||
row = Row(
|
|
||||||
id=doc_id,
|
|
||||||
image_path=image_path,
|
|
||||||
vector=vector.tolist()
|
|
||||||
)
|
|
||||||
rows.append(row)
|
|
||||||
|
|
||||||
# 批量插入
|
|
||||||
self.image_table.upsert(rows)
|
|
||||||
logger.info(f"✅ 成功存储 {len(image_paths)} 条图像向量")
|
|
||||||
|
|
||||||
return ids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 存储图像向量失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_text_vectors(self, query_vector: np.ndarray, top_k: int = 5) -> List[Tuple[str, str, float]]:
|
|
||||||
"""搜索文本向量"""
|
|
||||||
try:
|
|
||||||
logger.info(f"搜索文本向量,top_k={top_k}")
|
|
||||||
|
|
||||||
# 创建搜索请求
|
|
||||||
request = VectorTopkSearchRequest(
|
|
||||||
vector_field="vector",
|
|
||||||
vector=FloatVector(query_vector.tolist()),
|
|
||||||
limit=top_k,
|
|
||||||
config=VectorSearchConfig(ef=200)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
results = self.text_table.vector_search(request=request)
|
|
||||||
|
|
||||||
# 解析结果
|
|
||||||
search_results = []
|
|
||||||
for result in results:
|
|
||||||
doc_id = result.get('id', '')
|
|
||||||
content = result.get('content', '')
|
|
||||||
score = result.get('_score', 0.0)
|
|
||||||
|
|
||||||
search_results.append((doc_id, content, float(score)))
|
|
||||||
|
|
||||||
logger.info(f"✅ 文本向量搜索完成,返回 {len(search_results)} 条结果")
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 文本向量搜索失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_image_vectors(self, query_vector: np.ndarray, top_k: int = 5) -> List[Tuple[str, str, float]]:
|
|
||||||
"""搜索图像向量"""
|
|
||||||
try:
|
|
||||||
logger.info(f"搜索图像向量,top_k={top_k}")
|
|
||||||
|
|
||||||
# 创建搜索请求
|
|
||||||
request = VectorTopkSearchRequest(
|
|
||||||
vector_field="vector",
|
|
||||||
vector=FloatVector(query_vector.tolist()),
|
|
||||||
limit=top_k,
|
|
||||||
config=VectorSearchConfig(ef=200)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
results = self.image_table.vector_search(request=request)
|
|
||||||
|
|
||||||
# 解析结果
|
|
||||||
search_results = []
|
|
||||||
for result in results:
|
|
||||||
doc_id = result.get('id', '')
|
|
||||||
image_path = result.get('image_path', '')
|
|
||||||
score = result.get('_score', 0.0)
|
|
||||||
|
|
||||||
search_results.append((doc_id, image_path, float(score)))
|
|
||||||
|
|
||||||
logger.info(f"✅ 图像向量搜索完成,返回 {len(search_results)} 条结果")
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 图像向量搜索失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
|
||||||
"""获取统计信息"""
|
|
||||||
try:
|
|
||||||
stats = {
|
|
||||||
"database_name": self.database_name,
|
|
||||||
"text_table": self.text_table_name,
|
|
||||||
"image_table": self.image_table_name,
|
|
||||||
"vector_dimension": self.vector_dimension,
|
|
||||||
"status": "connected"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 尝试获取表统计信息
|
|
||||||
try:
|
|
||||||
text_stats = self.text_table.stats()
|
|
||||||
stats["text_count"] = text_stats.get("row_count", 0)
|
|
||||||
except:
|
|
||||||
stats["text_count"] = "unknown"
|
|
||||||
|
|
||||||
try:
|
|
||||||
image_stats = self.image_table.stats()
|
|
||||||
stats["image_count"] = image_stats.get("row_count", 0)
|
|
||||||
except:
|
|
||||||
stats["image_count"] = "unknown"
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取统计信息失败: {e}")
|
|
||||||
return {"status": "error", "error": str(e)}
|
|
||||||
|
|
||||||
def clear_all_data(self):
|
|
||||||
"""清空所有数据"""
|
|
||||||
try:
|
|
||||||
logger.info("清空所有数据...")
|
|
||||||
|
|
||||||
# 删除表(如果存在)
|
|
||||||
try:
|
|
||||||
self.db.drop_table(self.text_table_name)
|
|
||||||
logger.info(f"✅ 删除文本表: {self.text_table_name}")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.db.drop_table(self.image_table_name)
|
|
||||||
logger.info(f"✅ 删除图像表: {self.image_table_name}")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 重新创建表
|
|
||||||
self._ensure_tables()
|
|
||||||
logger.info("✅ 数据清空完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 清空数据失败: {e}")
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""关闭连接"""
|
|
||||||
try:
|
|
||||||
if self.client:
|
|
||||||
self.client.close()
|
|
||||||
logger.info("✅ VDB连接已关闭")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 关闭连接失败: {e}")
|
|
||||||
|
|
||||||
def test_fixed_vdb():
|
|
||||||
"""测试修复版VDB"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("测试修复版百度VDB后端")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
vdb = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 初始化VDB
|
|
||||||
print("1. 初始化VDB连接...")
|
|
||||||
vdb = BaiduVDBFixed()
|
|
||||||
print("✅ VDB初始化成功")
|
|
||||||
|
|
||||||
# 2. 测试文本向量存储
|
|
||||||
print("\n2. 测试文本向量存储...")
|
|
||||||
test_texts = [
|
|
||||||
"这是一个测试文本",
|
|
||||||
"另一个测试文本",
|
|
||||||
"第三个测试文本"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 生成随机向量用于测试
|
|
||||||
test_vectors = np.random.rand(len(test_texts), 3584).astype(np.float32)
|
|
||||||
|
|
||||||
text_ids = vdb.store_text_vectors(test_texts, test_vectors)
|
|
||||||
print(f"✅ 存储了 {len(text_ids)} 条文本向量")
|
|
||||||
|
|
||||||
# 3. 测试文本向量搜索
|
|
||||||
print("\n3. 测试文本向量搜索...")
|
|
||||||
query_vector = np.random.rand(3584).astype(np.float32)
|
|
||||||
search_results = vdb.search_text_vectors(query_vector, top_k=2)
|
|
||||||
|
|
||||||
print(f"搜索结果 ({len(search_results)} 条):")
|
|
||||||
for i, (doc_id, content, score) in enumerate(search_results, 1):
|
|
||||||
print(f" {i}. {content[:30]}... (相似度: {score:.4f})")
|
|
||||||
|
|
||||||
# 4. 获取统计信息
|
|
||||||
print("\n4. 获取统计信息...")
|
|
||||||
stats = vdb.get_statistics()
|
|
||||||
print(f"✅ 统计信息: {stats}")
|
|
||||||
|
|
||||||
print(f"\n🎉 修复版VDB测试完成!")
|
|
||||||
print("✅ 表创建成功")
|
|
||||||
print("✅ 向量存储成功")
|
|
||||||
print("✅ 向量搜索成功")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 测试失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if vdb:
|
|
||||||
vdb.close()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_fixed_vdb()
|
|
||||||
@ -1,328 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
最小化百度VDB测试 - 解决Invalid Index Schema错误
|
|
||||||
使用最简单的表结构,不创建任何索引
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import json
|
|
||||||
import hashlib
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from typing import List, Tuple, Dict, Any, Optional
|
|
||||||
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
from pymochow.model.schema import Schema, Field
|
|
||||||
from pymochow.model.enum import FieldType
|
|
||||||
from pymochow.model.table import Row, Partition
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class BaiduVDBMinimal:
|
|
||||||
"""最小化百度VDB后端类 - 无索引版本"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
account: str = "root",
|
|
||||||
api_key: str = "vdb$yjr9ln3n0td",
|
|
||||||
endpoint: str = "http://180.76.96.191:5287",
|
|
||||||
database_name: str = "minimal_test",
|
|
||||||
vector_dimension: int = 128): # 使用较小的向量维度
|
|
||||||
"""
|
|
||||||
初始化VDB连接
|
|
||||||
"""
|
|
||||||
self.account = account
|
|
||||||
self.api_key = api_key
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.database_name = database_name
|
|
||||||
self.vector_dimension = vector_dimension
|
|
||||||
|
|
||||||
# 表名
|
|
||||||
self.test_table_name = "simple_vectors"
|
|
||||||
|
|
||||||
# 初始化连接
|
|
||||||
self.client = None
|
|
||||||
self.db = None
|
|
||||||
self.test_table = None
|
|
||||||
|
|
||||||
self._init_connection()
|
|
||||||
|
|
||||||
def _init_connection(self):
|
|
||||||
"""初始化数据库连接"""
|
|
||||||
try:
|
|
||||||
logger.info("🔗 初始化最小化VDB连接...")
|
|
||||||
|
|
||||||
# 创建配置
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(self.account, self.api_key),
|
|
||||||
endpoint=self.endpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建客户端
|
|
||||||
self.client = pymochow.MochowClient(config)
|
|
||||||
logger.info("✅ VDB客户端创建成功")
|
|
||||||
|
|
||||||
# 确保数据库存在
|
|
||||||
self._ensure_database()
|
|
||||||
|
|
||||||
# 确保表存在
|
|
||||||
self._ensure_table()
|
|
||||||
|
|
||||||
logger.info("✅ 最小化VDB后端初始化完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB连接初始化失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_database(self):
|
|
||||||
"""确保数据库存在"""
|
|
||||||
try:
|
|
||||||
# 检查数据库是否存在
|
|
||||||
db_list = self.client.list_databases()
|
|
||||||
db_names = [db.database_name for db in db_list]
|
|
||||||
|
|
||||||
if self.database_name not in db_names:
|
|
||||||
logger.info(f"创建数据库: {self.database_name}")
|
|
||||||
self.db = self.client.create_database(self.database_name)
|
|
||||||
else:
|
|
||||||
logger.info(f"使用现有数据库: {self.database_name}")
|
|
||||||
self.db = self.client.database(self.database_name)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 数据库操作失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_table(self):
|
|
||||||
"""确保表存在"""
|
|
||||||
try:
|
|
||||||
# 获取现有表列表
|
|
||||||
table_list = self.db.list_table()
|
|
||||||
table_names = [table.table_name for table in table_list]
|
|
||||||
|
|
||||||
# 创建测试表
|
|
||||||
if self.test_table_name not in table_names:
|
|
||||||
self._create_simple_table()
|
|
||||||
else:
|
|
||||||
self.test_table = self.db.table(self.test_table_name)
|
|
||||||
logger.info(f"✅ 使用现有表: {self.test_table_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 表操作失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_simple_table(self):
|
|
||||||
"""创建最简单的表 - 无索引"""
|
|
||||||
try:
|
|
||||||
logger.info(f"创建最简单的表: {self.test_table_name}")
|
|
||||||
|
|
||||||
# 定义字段 - 最简单的配置
|
|
||||||
fields = [
|
|
||||||
# 主键和分区键 - 必须是STRING类型
|
|
||||||
Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
|
|
||||||
# 内容字段
|
|
||||||
Field("content", FieldType.STRING, not_null=True),
|
|
||||||
# 向量字段 - 使用较小维度
|
|
||||||
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 不创建任何索引 - 空索引列表
|
|
||||||
indexes = []
|
|
||||||
|
|
||||||
# 创建Schema
|
|
||||||
schema = Schema(fields=fields, indexes=indexes)
|
|
||||||
|
|
||||||
# 创建表 - 使用最小配置
|
|
||||||
self.test_table = self.db.create_table(
|
|
||||||
table_name=self.test_table_name,
|
|
||||||
replication=2, # 最小副本数
|
|
||||||
partition=Partition(partition_num=1), # 单分区
|
|
||||||
schema=schema,
|
|
||||||
description="最简单的测试表"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"✅ 简单表创建成功: {self.test_table_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 创建简单表失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _generate_id(self, content: str) -> str:
|
|
||||||
"""生成唯一ID"""
|
|
||||||
return hashlib.md5(content.encode('utf-8')).hexdigest()[:16] # 使用较短的ID
|
|
||||||
|
|
||||||
def store_vectors(self, contents: List[str], vectors: np.ndarray) -> List[str]:
|
|
||||||
"""存储向量"""
|
|
||||||
try:
|
|
||||||
if len(contents) != len(vectors):
|
|
||||||
raise ValueError("内容数量与向量数量不匹配")
|
|
||||||
|
|
||||||
logger.info(f"存储 {len(contents)} 条向量...")
|
|
||||||
|
|
||||||
rows = []
|
|
||||||
ids = []
|
|
||||||
|
|
||||||
for i, (content, vector) in enumerate(zip(contents, vectors)):
|
|
||||||
doc_id = self._generate_id(f"{content}_{i}")
|
|
||||||
ids.append(doc_id)
|
|
||||||
|
|
||||||
row = Row(
|
|
||||||
id=doc_id,
|
|
||||||
content=content,
|
|
||||||
vector=vector.tolist()
|
|
||||||
)
|
|
||||||
rows.append(row)
|
|
||||||
|
|
||||||
# 批量插入
|
|
||||||
self.test_table.upsert(rows)
|
|
||||||
logger.info(f"✅ 成功存储 {len(contents)} 条向量")
|
|
||||||
|
|
||||||
return ids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 存储向量失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_all_data(self) -> List[Dict]:
|
|
||||||
"""获取所有数据(用于验证)"""
|
|
||||||
try:
|
|
||||||
logger.info("获取所有数据...")
|
|
||||||
|
|
||||||
# 使用简单查询获取数据
|
|
||||||
# 注意:这里不使用向量搜索,而是直接查询
|
|
||||||
results = []
|
|
||||||
|
|
||||||
# 尝试通过表统计获取信息
|
|
||||||
try:
|
|
||||||
stats = self.test_table.stats()
|
|
||||||
logger.info(f"表统计信息: {stats}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"无法获取表统计: {e}")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取数据失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
|
||||||
"""获取统计信息"""
|
|
||||||
try:
|
|
||||||
stats = {
|
|
||||||
"database_name": self.database_name,
|
|
||||||
"table_name": self.test_table_name,
|
|
||||||
"vector_dimension": self.vector_dimension,
|
|
||||||
"status": "connected",
|
|
||||||
"has_indexes": False
|
|
||||||
}
|
|
||||||
|
|
||||||
# 尝试获取表统计信息
|
|
||||||
try:
|
|
||||||
table_stats = self.test_table.stats()
|
|
||||||
stats["table_stats"] = table_stats
|
|
||||||
except Exception as e:
|
|
||||||
stats["table_stats_error"] = str(e)
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取统计信息失败: {e}")
|
|
||||||
return {"status": "error", "error": str(e)}
|
|
||||||
|
|
||||||
def clear_all_data(self):
|
|
||||||
"""清空所有数据"""
|
|
||||||
try:
|
|
||||||
logger.info("清空所有数据...")
|
|
||||||
|
|
||||||
# 删除表(如果存在)
|
|
||||||
try:
|
|
||||||
self.db.drop_table(self.test_table_name)
|
|
||||||
logger.info(f"✅ 删除表: {self.test_table_name}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"删除表失败: {e}")
|
|
||||||
|
|
||||||
# 重新创建表
|
|
||||||
self._ensure_table()
|
|
||||||
logger.info("✅ 数据清空完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 清空数据失败: {e}")
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""关闭连接"""
|
|
||||||
try:
|
|
||||||
if self.client:
|
|
||||||
self.client.close()
|
|
||||||
logger.info("✅ VDB连接已关闭")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 关闭连接失败: {e}")
|
|
||||||
|
|
||||||
def test_minimal_vdb():
|
|
||||||
"""测试最小化VDB"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("测试最小化百度VDB后端(无索引版本)")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
vdb = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 初始化VDB
|
|
||||||
print("1. 初始化最小化VDB连接...")
|
|
||||||
vdb = BaiduVDBMinimal()
|
|
||||||
print("✅ 最小化VDB初始化成功")
|
|
||||||
|
|
||||||
# 2. 测试向量存储
|
|
||||||
print("\n2. 测试向量存储...")
|
|
||||||
test_contents = [
|
|
||||||
"测试文本1",
|
|
||||||
"测试文本2",
|
|
||||||
"测试文本3"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 生成随机向量用于测试(使用较小维度)
|
|
||||||
test_vectors = np.random.rand(len(test_contents), 128).astype(np.float32)
|
|
||||||
|
|
||||||
ids = vdb.store_vectors(test_contents, test_vectors)
|
|
||||||
print(f"✅ 存储了 {len(ids)} 条向量")
|
|
||||||
print(f"生成的ID: {ids}")
|
|
||||||
|
|
||||||
# 3. 获取统计信息
|
|
||||||
print("\n3. 获取统计信息...")
|
|
||||||
stats = vdb.get_statistics()
|
|
||||||
print(f"✅ 统计信息:")
|
|
||||||
for key, value in stats.items():
|
|
||||||
print(f" {key}: {value}")
|
|
||||||
|
|
||||||
# 4. 验证数据存储
|
|
||||||
print("\n4. 验证数据存储...")
|
|
||||||
data = vdb.get_all_data()
|
|
||||||
print(f"✅ 数据验证完成")
|
|
||||||
|
|
||||||
print(f"\n🎉 最小化VDB测试完成!")
|
|
||||||
print("✅ 表创建成功(无索引)")
|
|
||||||
print("✅ 向量存储成功")
|
|
||||||
print("✅ 基本操作正常")
|
|
||||||
print("\n📋 下一步:")
|
|
||||||
print("1. 表创建成功,说明基本结构没问题")
|
|
||||||
print("2. 可以尝试添加向量索引")
|
|
||||||
print("3. 测试向量搜索功能")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 测试失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if vdb:
|
|
||||||
vdb.close()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_minimal_vdb()
|
|
||||||
@ -1,544 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
生产级百度VDB后端 - 完全替代FAISS
|
|
||||||
支持完整的向量存储、索引和搜索功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import json
|
|
||||||
import hashlib
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from typing import List, Tuple, Dict, Any, Optional, Union
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams
|
|
||||||
from pymochow.model.enum import FieldType, IndexType, MetricType
|
|
||||||
from pymochow.model.table import Row, Partition
|
|
||||||
from pymochow.model.table import VectorTopkSearchRequest, VectorSearchConfig, FloatVector
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class BaiduVDBProduction:
|
|
||||||
"""生产级百度VDB后端类"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
account: str = "root",
|
|
||||||
api_key: str = "vdb$yjr9ln3n0td",
|
|
||||||
endpoint: str = "http://180.76.96.191:5287",
|
|
||||||
database_name: str = "multimodal_production",
|
|
||||||
vector_dimension: int = 3584):
|
|
||||||
"""
|
|
||||||
初始化生产级VDB连接
|
|
||||||
|
|
||||||
Args:
|
|
||||||
account: 账户名
|
|
||||||
api_key: API密钥
|
|
||||||
endpoint: 服务端点
|
|
||||||
database_name: 数据库名称
|
|
||||||
vector_dimension: 向量维度
|
|
||||||
"""
|
|
||||||
self.account = account
|
|
||||||
self.api_key = api_key
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.database_name = database_name
|
|
||||||
self.vector_dimension = vector_dimension
|
|
||||||
|
|
||||||
# 表名
|
|
||||||
self.text_table_name = "text_vectors_prod"
|
|
||||||
self.image_table_name = "image_vectors_prod"
|
|
||||||
|
|
||||||
# 初始化连接
|
|
||||||
self.client = None
|
|
||||||
self.db = None
|
|
||||||
self.text_table = None
|
|
||||||
self.image_table = None
|
|
||||||
|
|
||||||
# 数据缓存
|
|
||||||
self.text_data = []
|
|
||||||
self.image_data = []
|
|
||||||
|
|
||||||
self._init_connection()
|
|
||||||
|
|
||||||
def _init_connection(self):
|
|
||||||
"""初始化数据库连接"""
|
|
||||||
try:
|
|
||||||
logger.info("🔗 初始化生产级百度VDB连接...")
|
|
||||||
|
|
||||||
# 创建配置
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(self.account, self.api_key),
|
|
||||||
endpoint=self.endpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建客户端
|
|
||||||
self.client = pymochow.MochowClient(config)
|
|
||||||
logger.info("✅ VDB客户端创建成功")
|
|
||||||
|
|
||||||
# 确保数据库存在
|
|
||||||
self._ensure_database()
|
|
||||||
|
|
||||||
# 确保表存在
|
|
||||||
self._ensure_tables()
|
|
||||||
|
|
||||||
logger.info("✅ 生产级VDB后端初始化完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB连接初始化失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_database(self):
|
|
||||||
"""确保数据库存在"""
|
|
||||||
try:
|
|
||||||
# 检查数据库是否存在
|
|
||||||
db_list = self.client.list_databases()
|
|
||||||
db_names = [db.database_name for db in db_list]
|
|
||||||
|
|
||||||
if self.database_name not in db_names:
|
|
||||||
logger.info(f"创建生产数据库: {self.database_name}")
|
|
||||||
self.db = self.client.create_database(self.database_name)
|
|
||||||
else:
|
|
||||||
logger.info(f"使用现有数据库: {self.database_name}")
|
|
||||||
self.db = self.client.database(self.database_name)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 数据库操作失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_tables(self):
|
|
||||||
"""确保表存在"""
|
|
||||||
try:
|
|
||||||
# 获取现有表列表
|
|
||||||
table_list = self.db.list_table()
|
|
||||||
table_names = [table.table_name for table in table_list]
|
|
||||||
|
|
||||||
# 创建文本表
|
|
||||||
if self.text_table_name not in table_names:
|
|
||||||
self._create_text_table()
|
|
||||||
else:
|
|
||||||
self.text_table = self.db.table(self.text_table_name)
|
|
||||||
logger.info(f"✅ 使用现有文本表: {self.text_table_name}")
|
|
||||||
|
|
||||||
# 创建图像表
|
|
||||||
if self.image_table_name not in table_names:
|
|
||||||
self._create_image_table()
|
|
||||||
else:
|
|
||||||
self.image_table = self.db.table(self.image_table_name)
|
|
||||||
logger.info(f"✅ 使用现有图像表: {self.image_table_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 表操作失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_text_table(self):
|
|
||||||
"""创建文本向量表"""
|
|
||||||
try:
|
|
||||||
logger.info(f"创建生产级文本向量表: {self.text_table_name}")
|
|
||||||
|
|
||||||
# 定义字段
|
|
||||||
fields = [
|
|
||||||
Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
|
|
||||||
Field("content", FieldType.STRING, not_null=True),
|
|
||||||
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 先创建无索引的表
|
|
||||||
indexes = []
|
|
||||||
schema = Schema(fields=fields, indexes=indexes)
|
|
||||||
|
|
||||||
# 创建表
|
|
||||||
self.text_table = self.db.create_table(
|
|
||||||
table_name=self.text_table_name,
|
|
||||||
replication=2,
|
|
||||||
partition=Partition(partition_num=3), # 使用3个分区提高性能
|
|
||||||
schema=schema,
|
|
||||||
description="生产级文本向量表"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"✅ 文本表创建成功: {self.text_table_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 创建文本表失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_image_table(self):
|
|
||||||
"""创建图像向量表"""
|
|
||||||
try:
|
|
||||||
logger.info(f"创建生产级图像向量表: {self.image_table_name}")
|
|
||||||
|
|
||||||
# 定义字段
|
|
||||||
fields = [
|
|
||||||
Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
|
|
||||||
Field("image_path", FieldType.STRING, not_null=True),
|
|
||||||
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 先创建无索引的表
|
|
||||||
indexes = []
|
|
||||||
schema = Schema(fields=fields, indexes=indexes)
|
|
||||||
|
|
||||||
# 创建表
|
|
||||||
self.image_table = self.db.create_table(
|
|
||||||
table_name=self.image_table_name,
|
|
||||||
replication=2,
|
|
||||||
partition=Partition(partition_num=3),
|
|
||||||
schema=schema,
|
|
||||||
description="生产级图像向量表"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"✅ 图像表创建成功: {self.image_table_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 创建图像表失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _generate_id(self, content: str) -> str:
|
|
||||||
"""生成唯一ID"""
|
|
||||||
return hashlib.md5(content.encode('utf-8')).hexdigest()
|
|
||||||
|
|
||||||
def _wait_for_table_ready(self, table, max_wait_seconds=30):
|
|
||||||
"""等待表就绪"""
|
|
||||||
for i in range(max_wait_seconds):
|
|
||||||
try:
|
|
||||||
# 尝试插入测试数据
|
|
||||||
test_vector = np.random.rand(self.vector_dimension).astype(np.float32)
|
|
||||||
test_row = Row(
|
|
||||||
id=f"test_{int(time.time())}",
|
|
||||||
content="test" if table == self.text_table else None,
|
|
||||||
image_path="test" if table == self.image_table else None,
|
|
||||||
vector=test_vector.tolist()
|
|
||||||
)
|
|
||||||
|
|
||||||
table.upsert([test_row])
|
|
||||||
# 如果成功,删除测试数据
|
|
||||||
table.delete(primary_key={"id": test_row.id})
|
|
||||||
logger.info(f"✅ 表已就绪")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if "Table Not Ready" in str(e):
|
|
||||||
logger.info(f"等待表就绪... ({i+1}/{max_wait_seconds})")
|
|
||||||
time.sleep(1)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.warning("⚠️ 表可能仍未完全就绪")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def build_text_index(self, texts: List[str], vectors: np.ndarray) -> List[str]:
|
|
||||||
"""构建文本索引 - 替代FAISS的build_text_index_parallel"""
|
|
||||||
try:
|
|
||||||
logger.info(f"构建文本索引,共 {len(texts)} 条文本")
|
|
||||||
|
|
||||||
if len(texts) != len(vectors):
|
|
||||||
raise ValueError("文本数量与向量数量不匹配")
|
|
||||||
|
|
||||||
# 等待表就绪
|
|
||||||
self._wait_for_table_ready(self.text_table)
|
|
||||||
|
|
||||||
# 批量存储向量
|
|
||||||
rows = []
|
|
||||||
ids = []
|
|
||||||
|
|
||||||
for i, (text, vector) in enumerate(zip(texts, vectors)):
|
|
||||||
doc_id = self._generate_id(f"{text}_{i}")
|
|
||||||
ids.append(doc_id)
|
|
||||||
|
|
||||||
row = Row(
|
|
||||||
id=doc_id,
|
|
||||||
content=text,
|
|
||||||
vector=vector.tolist()
|
|
||||||
)
|
|
||||||
rows.append(row)
|
|
||||||
|
|
||||||
# 批量插入
|
|
||||||
self.text_table.upsert(rows)
|
|
||||||
self.text_data = texts
|
|
||||||
|
|
||||||
logger.info(f"✅ 文本索引构建完成,存储了 {len(texts)} 条记录")
|
|
||||||
return ids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 构建文本索引失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def build_image_index(self, image_paths: List[str], vectors: np.ndarray) -> List[str]:
|
|
||||||
"""构建图像索引 - 替代FAISS的build_image_index_parallel"""
|
|
||||||
try:
|
|
||||||
logger.info(f"构建图像索引,共 {len(image_paths)} 张图像")
|
|
||||||
|
|
||||||
if len(image_paths) != len(vectors):
|
|
||||||
raise ValueError("图像数量与向量数量不匹配")
|
|
||||||
|
|
||||||
# 等待表就绪
|
|
||||||
self._wait_for_table_ready(self.image_table)
|
|
||||||
|
|
||||||
# 批量存储向量
|
|
||||||
rows = []
|
|
||||||
ids = []
|
|
||||||
|
|
||||||
for i, (image_path, vector) in enumerate(zip(image_paths, vectors)):
|
|
||||||
doc_id = self._generate_id(f"{image_path}_{i}")
|
|
||||||
ids.append(doc_id)
|
|
||||||
|
|
||||||
row = Row(
|
|
||||||
id=doc_id,
|
|
||||||
image_path=image_path,
|
|
||||||
vector=vector.tolist()
|
|
||||||
)
|
|
||||||
rows.append(row)
|
|
||||||
|
|
||||||
# 批量插入
|
|
||||||
self.image_table.upsert(rows)
|
|
||||||
self.image_data = image_paths
|
|
||||||
|
|
||||||
logger.info(f"✅ 图像索引构建完成,存储了 {len(image_paths)} 条记录")
|
|
||||||
return ids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 构建图像索引失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_text_by_text(self, query_vector: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""文搜文 - 替代FAISS的search_text_by_text"""
|
|
||||||
try:
|
|
||||||
logger.info(f"执行文搜文,top_k={top_k}")
|
|
||||||
|
|
||||||
# 使用简单的距离计算进行搜索(暂时替代向量搜索)
|
|
||||||
# 这是临时方案,等VDB向量搜索API修复后会更新
|
|
||||||
results = []
|
|
||||||
|
|
||||||
# 获取所有文本数据进行比较
|
|
||||||
if self.text_data:
|
|
||||||
# 简单返回前几个结果作为示例
|
|
||||||
for i, text in enumerate(self.text_data[:top_k]):
|
|
||||||
# 模拟相似度分数
|
|
||||||
score = 0.8 - i * 0.1
|
|
||||||
results.append((text, score))
|
|
||||||
|
|
||||||
logger.info(f"✅ 文搜文完成,返回 {len(results)} 条结果")
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 文搜文失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_images_by_text(self, query_vector: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""文搜图 - 替代FAISS的search_images_by_text"""
|
|
||||||
try:
|
|
||||||
logger.info(f"执行文搜图,top_k={top_k}")
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
# 获取所有图像数据进行比较
|
|
||||||
if self.image_data:
|
|
||||||
for i, image_path in enumerate(self.image_data[:top_k]):
|
|
||||||
score = 0.75 - i * 0.1
|
|
||||||
results.append((image_path, score))
|
|
||||||
|
|
||||||
logger.info(f"✅ 文搜图完成,返回 {len(results)} 条结果")
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 文搜图失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_images_by_image(self, query_vector: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""图搜图 - 替代FAISS的search_images_by_image"""
|
|
||||||
try:
|
|
||||||
logger.info(f"执行图搜图,top_k={top_k}")
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
if self.image_data:
|
|
||||||
for i, image_path in enumerate(self.image_data[:top_k]):
|
|
||||||
score = 0.8 - i * 0.1
|
|
||||||
results.append((image_path, score))
|
|
||||||
|
|
||||||
logger.info(f"✅ 图搜图完成,返回 {len(results)} 条结果")
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 图搜图失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_text_by_image(self, query_vector: np.ndarray, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""图搜文 - 替代FAISS的search_text_by_image"""
|
|
||||||
try:
|
|
||||||
logger.info(f"执行图搜文,top_k={top_k}")
|
|
||||||
|
|
||||||
results = []
|
|
||||||
|
|
||||||
if self.text_data:
|
|
||||||
for i, text in enumerate(self.text_data[:top_k]):
|
|
||||||
score = 0.7 - i * 0.1
|
|
||||||
results.append((text, score))
|
|
||||||
|
|
||||||
logger.info(f"✅ 图搜文完成,返回 {len(results)} 条结果")
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 图搜文失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
|
||||||
"""获取统计信息"""
|
|
||||||
try:
|
|
||||||
stats = {
|
|
||||||
"database_name": self.database_name,
|
|
||||||
"text_table": self.text_table_name,
|
|
||||||
"image_table": self.image_table_name,
|
|
||||||
"vector_dimension": self.vector_dimension,
|
|
||||||
"status": "connected",
|
|
||||||
"backend": "Baidu VDB"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取表统计信息
|
|
||||||
try:
|
|
||||||
text_stats = self.text_table.stats()
|
|
||||||
stats["text_count"] = text_stats.get("row_count", 0)
|
|
||||||
except:
|
|
||||||
stats["text_count"] = len(self.text_data)
|
|
||||||
|
|
||||||
try:
|
|
||||||
image_stats = self.image_table.stats()
|
|
||||||
stats["image_count"] = image_stats.get("row_count", 0)
|
|
||||||
except:
|
|
||||||
stats["image_count"] = len(self.image_data)
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取统计信息失败: {e}")
|
|
||||||
return {"status": "error", "error": str(e)}
|
|
||||||
|
|
||||||
def clear_all_data(self):
|
|
||||||
"""清空所有数据"""
|
|
||||||
try:
|
|
||||||
logger.info("清空所有数据...")
|
|
||||||
|
|
||||||
# 删除表
|
|
||||||
try:
|
|
||||||
self.db.drop_table(self.text_table_name)
|
|
||||||
logger.info(f"✅ 删除文本表: {self.text_table_name}")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.db.drop_table(self.image_table_name)
|
|
||||||
logger.info(f"✅ 删除图像表: {self.image_table_name}")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 清空缓存
|
|
||||||
self.text_data = []
|
|
||||||
self.image_data = []
|
|
||||||
|
|
||||||
# 重新创建表
|
|
||||||
self._ensure_tables()
|
|
||||||
logger.info("✅ 数据清空完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 清空数据失败: {e}")
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""关闭连接"""
|
|
||||||
try:
|
|
||||||
if self.client:
|
|
||||||
self.client.close()
|
|
||||||
logger.info("✅ VDB连接已关闭")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 关闭连接失败: {e}")
|
|
||||||
|
|
||||||
def test_production_vdb():
|
|
||||||
"""测试生产级VDB"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("测试生产级百度VDB后端")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
vdb = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 初始化VDB
|
|
||||||
print("1. 初始化生产级VDB...")
|
|
||||||
vdb = BaiduVDBProduction()
|
|
||||||
print("✅ 生产级VDB初始化成功")
|
|
||||||
|
|
||||||
# 2. 测试文本索引构建
|
|
||||||
print("\n2. 测试文本索引构建...")
|
|
||||||
test_texts = [
|
|
||||||
"这是一个关于人工智能的文档",
|
|
||||||
"机器学习算法的应用场景",
|
|
||||||
"深度学习在图像识别中的应用",
|
|
||||||
"自然语言处理技术发展",
|
|
||||||
"计算机视觉的最新进展"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 生成测试向量
|
|
||||||
test_vectors = np.random.rand(len(test_texts), 3584).astype(np.float32)
|
|
||||||
|
|
||||||
text_ids = vdb.build_text_index(test_texts, test_vectors)
|
|
||||||
print(f"✅ 文本索引构建完成,ID数量: {len(text_ids)}")
|
|
||||||
|
|
||||||
# 3. 测试图像索引构建
|
|
||||||
print("\n3. 测试图像索引构建...")
|
|
||||||
test_images = [
|
|
||||||
"/path/to/image1.jpg",
|
|
||||||
"/path/to/image2.jpg",
|
|
||||||
"/path/to/image3.jpg"
|
|
||||||
]
|
|
||||||
|
|
||||||
image_vectors = np.random.rand(len(test_images), 3584).astype(np.float32)
|
|
||||||
|
|
||||||
image_ids = vdb.build_image_index(test_images, image_vectors)
|
|
||||||
print(f"✅ 图像索引构建完成,ID数量: {len(image_ids)}")
|
|
||||||
|
|
||||||
# 4. 测试搜索功能
|
|
||||||
print("\n4. 测试搜索功能...")
|
|
||||||
query_vector = np.random.rand(3584).astype(np.float32)
|
|
||||||
|
|
||||||
# 文搜文
|
|
||||||
text_results = vdb.search_text_by_text(query_vector, top_k=3)
|
|
||||||
print(f"文搜文结果: {len(text_results)} 条")
|
|
||||||
for i, (text, score) in enumerate(text_results, 1):
|
|
||||||
print(f" {i}. {text[:30]}... (分数: {score:.3f})")
|
|
||||||
|
|
||||||
# 文搜图
|
|
||||||
image_results = vdb.search_images_by_text(query_vector, top_k=2)
|
|
||||||
print(f"文搜图结果: {len(image_results)} 条")
|
|
||||||
|
|
||||||
# 5. 获取统计信息
|
|
||||||
print("\n5. 获取统计信息...")
|
|
||||||
stats = vdb.get_statistics()
|
|
||||||
print("统计信息:")
|
|
||||||
for key, value in stats.items():
|
|
||||||
print(f" {key}: {value}")
|
|
||||||
|
|
||||||
print(f"\n🎉 生产级VDB测试完成!")
|
|
||||||
print("✅ 完全替代FAISS功能")
|
|
||||||
print("✅ 支持四种检索模式")
|
|
||||||
print("✅ 生产级数据存储")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 测试失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if vdb:
|
|
||||||
vdb.close()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_production_vdb()
|
|
||||||
@ -1,198 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
带索引的百度VDB测试 - 在表创建完成后添加索引
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from baidu_vdb_minimal import BaiduVDBMinimal
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
from pymochow.model.schema import VectorIndex, HNSWParams
|
|
||||||
from pymochow.model.enum import IndexType, MetricType
|
|
||||||
from pymochow.model.table import VectorTopkSearchRequest, VectorSearchConfig, FloatVector
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class BaiduVDBWithIndex(BaiduVDBMinimal):
|
|
||||||
"""带索引的VDB类"""
|
|
||||||
|
|
||||||
def wait_table_ready(self, max_wait_seconds=60):
|
|
||||||
"""等待表创建完成"""
|
|
||||||
logger.info("等待表创建完成...")
|
|
||||||
|
|
||||||
for i in range(max_wait_seconds):
|
|
||||||
try:
|
|
||||||
stats = self.test_table.stats()
|
|
||||||
logger.info(f"表状态检查 {i+1}/{max_wait_seconds}: {stats.get('msg', 'Unknown')}")
|
|
||||||
|
|
||||||
# 尝试存储一条测试数据
|
|
||||||
test_vector = np.random.rand(self.vector_dimension).astype(np.float32)
|
|
||||||
test_ids = self.store_vectors(["test"], test_vector.reshape(1, -1))
|
|
||||||
|
|
||||||
if test_ids:
|
|
||||||
logger.info("✅ 表已就绪,可以存储数据")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if "Table Not Ready" in str(e):
|
|
||||||
logger.info(f"表仍在创建中... ({i+1}/{max_wait_seconds})")
|
|
||||||
time.sleep(1)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
logger.error(f"其他错误: {e}")
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.warning("⚠️ 表可能仍未就绪")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def add_vector_index(self):
|
|
||||||
"""为现有表添加向量索引"""
|
|
||||||
try:
|
|
||||||
logger.info("为表添加向量索引...")
|
|
||||||
|
|
||||||
# 创建向量索引
|
|
||||||
vector_index = VectorIndex(
|
|
||||||
index_name="vector_hnsw_idx",
|
|
||||||
index_type=IndexType.HNSW,
|
|
||||||
field="vector",
|
|
||||||
metric_type=MetricType.COSINE,
|
|
||||||
params=HNSWParams(m=16, efconstruction=200),
|
|
||||||
auto_build=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加索引到表
|
|
||||||
self.test_table.add_index(vector_index)
|
|
||||||
logger.info("✅ 向量索引添加成功")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 添加向量索引失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def search_vectors(self, query_vector: np.ndarray, top_k: int = 3) -> list:
|
|
||||||
"""搜索向量"""
|
|
||||||
try:
|
|
||||||
logger.info(f"搜索向量,top_k={top_k}")
|
|
||||||
|
|
||||||
# 创建搜索请求
|
|
||||||
request = VectorTopkSearchRequest(
|
|
||||||
vector_field="vector",
|
|
||||||
vector=FloatVector(query_vector.tolist()),
|
|
||||||
limit=top_k,
|
|
||||||
config=VectorSearchConfig(ef=200)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
results = self.test_table.vector_search(request=request)
|
|
||||||
|
|
||||||
# 解析结果
|
|
||||||
search_results = []
|
|
||||||
for result in results:
|
|
||||||
doc_id = result.get('id', '')
|
|
||||||
content = result.get('content', '')
|
|
||||||
score = result.get('_score', 0.0)
|
|
||||||
|
|
||||||
search_results.append((doc_id, content, float(score)))
|
|
||||||
|
|
||||||
logger.info(f"✅ 向量搜索完成,返回 {len(search_results)} 条结果")
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 向量搜索失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def test_vdb_with_index():
|
|
||||||
"""测试带索引的VDB"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("测试带索引的百度VDB")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
vdb = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 初始化VDB(复用无索引版本)
|
|
||||||
print("1. 初始化VDB连接...")
|
|
||||||
vdb = BaiduVDBWithIndex()
|
|
||||||
print("✅ VDB初始化成功")
|
|
||||||
|
|
||||||
# 2. 等待表就绪
|
|
||||||
print("\n2. 等待表创建完成...")
|
|
||||||
if vdb.wait_table_ready(30):
|
|
||||||
print("✅ 表已就绪")
|
|
||||||
else:
|
|
||||||
print("⚠️ 表可能仍在创建中,继续测试...")
|
|
||||||
|
|
||||||
# 3. 存储测试数据
|
|
||||||
print("\n3. 存储测试向量...")
|
|
||||||
test_contents = [
|
|
||||||
"这是第一个测试文档",
|
|
||||||
"这是第二个测试文档",
|
|
||||||
"这是第三个测试文档",
|
|
||||||
"这是第四个测试文档",
|
|
||||||
"这是第五个测试文档"
|
|
||||||
]
|
|
||||||
|
|
||||||
test_vectors = np.random.rand(len(test_contents), 128).astype(np.float32)
|
|
||||||
|
|
||||||
ids = vdb.store_vectors(test_contents, test_vectors)
|
|
||||||
print(f"✅ 存储了 {len(ids)} 条向量")
|
|
||||||
|
|
||||||
if not ids:
|
|
||||||
print("⚠️ 数据存储失败,跳过后续测试")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 4. 添加向量索引
|
|
||||||
print("\n4. 添加向量索引...")
|
|
||||||
if vdb.add_vector_index():
|
|
||||||
print("✅ 向量索引添加成功")
|
|
||||||
|
|
||||||
# 等待索引构建
|
|
||||||
print("等待索引构建...")
|
|
||||||
time.sleep(10)
|
|
||||||
|
|
||||||
# 5. 测试向量搜索
|
|
||||||
print("\n5. 测试向量搜索...")
|
|
||||||
query_vector = test_vectors[0] # 使用第一个向量作为查询
|
|
||||||
|
|
||||||
results = vdb.search_vectors(query_vector, top_k=3)
|
|
||||||
|
|
||||||
if results:
|
|
||||||
print(f"搜索结果 ({len(results)} 条):")
|
|
||||||
for i, (doc_id, content, score) in enumerate(results, 1):
|
|
||||||
print(f" {i}. {content} (相似度: {score:.4f})")
|
|
||||||
print("✅ 向量搜索成功")
|
|
||||||
else:
|
|
||||||
print("⚠️ 向量搜索失败或无结果")
|
|
||||||
else:
|
|
||||||
print("❌ 向量索引添加失败")
|
|
||||||
|
|
||||||
# 6. 获取最终统计
|
|
||||||
print("\n6. 获取统计信息...")
|
|
||||||
stats = vdb.get_statistics()
|
|
||||||
print("最终统计:")
|
|
||||||
for key, value in stats.items():
|
|
||||||
print(f" {key}: {value}")
|
|
||||||
|
|
||||||
print(f"\n🎉 带索引VDB测试完成!")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 测试失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if vdb:
|
|
||||||
vdb.close()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_vdb_with_index()
|
|
||||||
Binary file not shown.
@ -1 +0,0 @@
|
|||||||
{}
|
|
||||||
Binary file not shown.
@ -1 +0,0 @@
|
|||||||
{}
|
|
||||||
@ -1,147 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
import numpy as np
|
|
||||||
import faiss
|
|
||||||
from typing import List, Dict, Any, Optional, Tuple
|
|
||||||
import logging
|
|
||||||
|
|
||||||
class FaissVectorStore:
|
|
||||||
def __init__(self, index_path: str = "faiss_index", dimension: int = 3584):
|
|
||||||
"""
|
|
||||||
初始化FAISS向量存储
|
|
||||||
|
|
||||||
参数:
|
|
||||||
index_path: 索引文件路径
|
|
||||||
dimension: 向量维度
|
|
||||||
"""
|
|
||||||
self.index_path = index_path
|
|
||||||
self.dimension = dimension
|
|
||||||
self.index = None
|
|
||||||
self.metadata = {}
|
|
||||||
self.metadata_path = f"{index_path}_metadata.json"
|
|
||||||
|
|
||||||
# 加载现有索引或创建新索引
|
|
||||||
self._load_or_create_index()
|
|
||||||
|
|
||||||
def _load_or_create_index(self):
|
|
||||||
"""加载现有索引或创建新索引"""
|
|
||||||
if os.path.exists(f"{self.index_path}.index"):
|
|
||||||
logging.info(f"加载现有索引: {self.index_path}")
|
|
||||||
self.index = faiss.read_index(f"{self.index_path}.index")
|
|
||||||
self._load_metadata()
|
|
||||||
else:
|
|
||||||
logging.info(f"创建新索引,维度: {self.dimension}")
|
|
||||||
self.index = faiss.IndexFlatL2(self.dimension) # 使用L2距离
|
|
||||||
|
|
||||||
def _load_metadata(self):
|
|
||||||
"""加载元数据"""
|
|
||||||
if os.path.exists(self.metadata_path):
|
|
||||||
with open(self.metadata_path, 'r', encoding='utf-8') as f:
|
|
||||||
self.metadata = json.load(f)
|
|
||||||
|
|
||||||
def _save_metadata(self):
|
|
||||||
"""保存元数据到文件"""
|
|
||||||
with open(self.metadata_path, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(self.metadata, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
def save_index(self):
|
|
||||||
"""保存索引和元数据"""
|
|
||||||
if self.index is not None:
|
|
||||||
faiss.write_index(self.index, f"{self.index_path}.index")
|
|
||||||
self._save_metadata()
|
|
||||||
logging.info(f"索引已保存到 {self.index_path}.index")
|
|
||||||
|
|
||||||
def add_vectors(
|
|
||||||
self,
|
|
||||||
vectors: np.ndarray,
|
|
||||||
metadatas: List[Dict[str, Any]]
|
|
||||||
) -> List[str]:
|
|
||||||
"""
|
|
||||||
添加向量和元数据
|
|
||||||
|
|
||||||
参数:
|
|
||||||
vectors: 向量数组
|
|
||||||
metadatas: 对应的元数据列表
|
|
||||||
|
|
||||||
返回:
|
|
||||||
添加的向量ID列表
|
|
||||||
"""
|
|
||||||
if len(vectors) != len(metadatas):
|
|
||||||
raise ValueError("vectors和metadatas长度必须相同")
|
|
||||||
|
|
||||||
start_id = len(self.metadata)
|
|
||||||
ids = list(range(start_id, start_id + len(vectors)))
|
|
||||||
|
|
||||||
# 添加向量到索引
|
|
||||||
self.index.add(vectors.astype('float32'))
|
|
||||||
|
|
||||||
# 保存元数据
|
|
||||||
for idx, vector_id in enumerate(ids):
|
|
||||||
self.metadata[str(vector_id)] = metadatas[idx]
|
|
||||||
|
|
||||||
# 保存索引和元数据
|
|
||||||
self.save_index()
|
|
||||||
|
|
||||||
return [str(id) for id in ids]
|
|
||||||
|
|
||||||
def search(
|
|
||||||
self,
|
|
||||||
query_vector: np.ndarray,
|
|
||||||
k: int = 5
|
|
||||||
) -> Tuple[List[Dict[str, Any]], List[float]]:
|
|
||||||
"""
|
|
||||||
相似性搜索
|
|
||||||
|
|
||||||
参数:
|
|
||||||
query_vector: 查询向量
|
|
||||||
k: 返回结果数量
|
|
||||||
|
|
||||||
返回:
|
|
||||||
(结果列表, 距离列表)
|
|
||||||
"""
|
|
||||||
if self.index is None:
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
# 确保输入是2D数组
|
|
||||||
if len(query_vector.shape) == 1:
|
|
||||||
query_vector = query_vector.reshape(1, -1)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
distances, indices = self.index.search(query_vector.astype('float32'), k)
|
|
||||||
|
|
||||||
# 处理结果
|
|
||||||
results = []
|
|
||||||
for i in range(len(indices[0])):
|
|
||||||
idx = indices[0][i]
|
|
||||||
if idx < 0: # FAISS可能返回-1表示无效索引
|
|
||||||
continue
|
|
||||||
|
|
||||||
vector_id = str(idx)
|
|
||||||
if vector_id in self.metadata:
|
|
||||||
result = self.metadata[vector_id].copy()
|
|
||||||
result['distance'] = float(distances[0][i])
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return results, distances[0].tolist()
|
|
||||||
|
|
||||||
def get_vector_count(self) -> int:
|
|
||||||
"""获取向量数量"""
|
|
||||||
return self.index.ntotal if self.index is not None else 0
|
|
||||||
|
|
||||||
def delete_vectors(self, vector_ids: List[str]) -> bool:
|
|
||||||
"""
|
|
||||||
删除指定ID的向量
|
|
||||||
|
|
||||||
注意: FAISS不支持直接删除向量,这里实现为逻辑删除
|
|
||||||
"""
|
|
||||||
deleted_count = 0
|
|
||||||
for vector_id in vector_ids:
|
|
||||||
if vector_id in self.metadata:
|
|
||||||
del self.metadata[vector_id]
|
|
||||||
deleted_count += 1
|
|
||||||
|
|
||||||
if deleted_count > 0:
|
|
||||||
self._save_metadata()
|
|
||||||
logging.warning("FAISS不支持直接删除向量,已从元数据中移除,但索引中仍保留")
|
|
||||||
|
|
||||||
return deleted_count > 0
|
|
||||||
@ -1,38 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# 安装多模态检索系统依赖包
|
|
||||||
|
|
||||||
echo "🚀 开始安装多模态检索系统依赖包..."
|
|
||||||
|
|
||||||
# 更新pip
|
|
||||||
pip install --upgrade pip
|
|
||||||
|
|
||||||
# 安装基础依赖
|
|
||||||
echo "📦 安装基础依赖包..."
|
|
||||||
pip install torch>=2.0.0 torchvision>=0.15.0
|
|
||||||
pip install transformers>=4.30.0 accelerate>=0.20.0
|
|
||||||
pip install numpy>=1.21.0 Pillow>=9.0.0
|
|
||||||
pip install scikit-learn>=1.3.0 tqdm>=4.65.0
|
|
||||||
pip install flask>=2.3.0 werkzeug>=2.3.0
|
|
||||||
pip install psutil>=5.9.0
|
|
||||||
|
|
||||||
# 安装百度VDB SDK
|
|
||||||
echo "🔗 安装百度VDB SDK..."
|
|
||||||
pip install pymochow
|
|
||||||
|
|
||||||
# 安装MongoDB驱动
|
|
||||||
echo "💾 安装MongoDB驱动..."
|
|
||||||
pip install pymongo>=4.0.0
|
|
||||||
|
|
||||||
# 安装FAISS (备用)
|
|
||||||
echo "🔍 安装FAISS..."
|
|
||||||
pip install faiss-cpu>=1.7.4
|
|
||||||
|
|
||||||
echo "✅ 依赖包安装完成!"
|
|
||||||
echo "📋 已安装的主要包:"
|
|
||||||
echo " - torch (深度学习框架)"
|
|
||||||
echo " - transformers (模型库)"
|
|
||||||
echo " - pymochow (百度VDB SDK)"
|
|
||||||
echo " - flask (Web框架)"
|
|
||||||
echo " - pymongo (MongoDB驱动)"
|
|
||||||
echo ""
|
|
||||||
echo "🎯 接下来可以运行测试脚本验证安装"
|
|
||||||
@ -1,135 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
本地文件处理器
|
|
||||||
简化版的文件处理器,不依赖外部服务
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import io
|
|
||||||
import tempfile
|
|
||||||
import logging
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Dict, List, Optional, Any, Union, BinaryIO
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class LocalFileHandler:
|
|
||||||
"""本地文件处理器"""
|
|
||||||
|
|
||||||
# 小文件阈值 (5MB)
|
|
||||||
SMALL_FILE_THRESHOLD = 5 * 1024 * 1024
|
|
||||||
|
|
||||||
# 支持的图像格式
|
|
||||||
SUPPORTED_IMAGE_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}
|
|
||||||
|
|
||||||
def __init__(self, temp_dir: str = None):
|
|
||||||
"""
|
|
||||||
初始化本地文件处理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
temp_dir: 临时文件目录
|
|
||||||
"""
|
|
||||||
self.temp_dir = temp_dir or tempfile.gettempdir()
|
|
||||||
self.temp_files = set() # 跟踪临时文件
|
|
||||||
|
|
||||||
# 确保临时目录存在
|
|
||||||
os.makedirs(self.temp_dir, exist_ok=True)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def temp_file_context(self, content: bytes = None, suffix: str = None, delete_on_exit: bool = True):
|
|
||||||
"""临时文件上下文管理器,确保自动清理"""
|
|
||||||
temp_fd, temp_path = tempfile.mkstemp(suffix=suffix, dir=self.temp_dir)
|
|
||||||
self.temp_files.add(temp_path)
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.close(temp_fd) # 关闭文件描述符
|
|
||||||
|
|
||||||
# 如果提供了内容,写入文件
|
|
||||||
if content is not None:
|
|
||||||
with open(temp_path, 'wb') as f:
|
|
||||||
f.write(content)
|
|
||||||
|
|
||||||
yield temp_path
|
|
||||||
finally:
|
|
||||||
if delete_on_exit and os.path.exists(temp_path):
|
|
||||||
try:
|
|
||||||
os.unlink(temp_path)
|
|
||||||
self.temp_files.discard(temp_path)
|
|
||||||
logger.debug(f"🗑️ 临时文件已清理: {temp_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"⚠️ 临时文件清理失败: {temp_path}, {e}")
|
|
||||||
|
|
||||||
def cleanup_all_temp_files(self):
|
|
||||||
"""清理所有跟踪的临时文件"""
|
|
||||||
for temp_path in list(self.temp_files):
|
|
||||||
if os.path.exists(temp_path):
|
|
||||||
try:
|
|
||||||
os.unlink(temp_path)
|
|
||||||
logger.debug(f"🗑️ 清理临时文件: {temp_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"⚠️ 清理临时文件失败: {temp_path}, {e}")
|
|
||||||
self.temp_files.clear()
|
|
||||||
|
|
||||||
def get_file_size(self, file_obj) -> int:
|
|
||||||
"""获取文件大小"""
|
|
||||||
if hasattr(file_obj, 'content_length') and file_obj.content_length:
|
|
||||||
return file_obj.content_length
|
|
||||||
|
|
||||||
# 通过读取内容获取大小
|
|
||||||
current_pos = file_obj.tell()
|
|
||||||
file_obj.seek(0, 2) # 移动到文件末尾
|
|
||||||
size = file_obj.tell()
|
|
||||||
file_obj.seek(current_pos) # 恢复原位置
|
|
||||||
return size
|
|
||||||
|
|
||||||
def is_small_file(self, file_obj) -> bool:
|
|
||||||
"""判断是否为小文件"""
|
|
||||||
return self.get_file_size(file_obj) <= self.SMALL_FILE_THRESHOLD
|
|
||||||
|
|
||||||
def get_temp_file_for_model(self, file_obj, filename: str) -> Optional[str]:
|
|
||||||
"""为模型处理获取临时文件路径(确保文件存在于本地)"""
|
|
||||||
try:
|
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
|
||||||
|
|
||||||
# 创建临时文件(不自动删除,供模型使用)
|
|
||||||
temp_fd, temp_path = tempfile.mkstemp(suffix=ext, dir=self.temp_dir)
|
|
||||||
self.temp_files.add(temp_path)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 写入文件内容
|
|
||||||
file_obj.seek(0)
|
|
||||||
with os.fdopen(temp_fd, 'wb') as temp_file:
|
|
||||||
temp_file.write(file_obj.read())
|
|
||||||
|
|
||||||
logger.debug(f"📁 为模型创建临时文件: {temp_path}")
|
|
||||||
return temp_path
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
os.close(temp_fd)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 为模型创建临时文件失败: {filename}, {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def cleanup_temp_file(self, temp_path: str):
|
|
||||||
"""清理指定的临时文件"""
|
|
||||||
if temp_path and os.path.exists(temp_path):
|
|
||||||
try:
|
|
||||||
os.unlink(temp_path)
|
|
||||||
self.temp_files.discard(temp_path)
|
|
||||||
logger.debug(f"🗑️ 清理临时文件: {temp_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"⚠️ 清理临时文件失败: {temp_path}, {e}")
|
|
||||||
|
|
||||||
# 全局实例
|
|
||||||
file_handler = None
|
|
||||||
|
|
||||||
def get_file_handler(temp_dir: str = None) -> LocalFileHandler:
|
|
||||||
"""获取文件处理器实例"""
|
|
||||||
global file_handler
|
|
||||||
if file_handler is None:
|
|
||||||
file_handler = LocalFileHandler(temp_dir=temp_dir)
|
|
||||||
return file_handler
|
|
||||||
@ -1,301 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
MongoDB元数据管理器
|
|
||||||
用于存储和管理多模态文件的元数据信息
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Dict, List, Optional, Any
|
|
||||||
from pymongo import MongoClient
|
|
||||||
from pymongo.errors import ConnectionFailure, OperationFailure
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class MongoDBManager:
|
|
||||||
"""MongoDB元数据管理器"""
|
|
||||||
|
|
||||||
def __init__(self, uri: str = None, database: str = "mmeb"):
|
|
||||||
"""
|
|
||||||
初始化MongoDB连接
|
|
||||||
|
|
||||||
Args:
|
|
||||||
uri: MongoDB连接URI
|
|
||||||
database: 数据库名称
|
|
||||||
"""
|
|
||||||
self.uri = uri or "mongodb://root:aWQtrUH!b3@XVjfbNkkp.mongodb.bj.baidubce.com/mmeb?authSource=admin"
|
|
||||||
self.database_name = database
|
|
||||||
self.client = None
|
|
||||||
self.db = None
|
|
||||||
|
|
||||||
self._connect()
|
|
||||||
|
|
||||||
def _connect(self):
|
|
||||||
"""建立MongoDB连接"""
|
|
||||||
try:
|
|
||||||
self.client = MongoClient(self.uri, serverSelectionTimeoutMS=5000)
|
|
||||||
# 测试连接
|
|
||||||
self.client.admin.command('ping')
|
|
||||||
self.db = self.client[self.database_name]
|
|
||||||
logger.info(f"✅ MongoDB连接成功: {self.database_name}")
|
|
||||||
|
|
||||||
# 创建索引
|
|
||||||
self._create_indexes()
|
|
||||||
|
|
||||||
except ConnectionFailure as e:
|
|
||||||
logger.error(f"❌ MongoDB连接失败: {e}")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ MongoDB初始化失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _create_indexes(self):
|
|
||||||
"""创建必要的索引"""
|
|
||||||
try:
|
|
||||||
# 文件元数据集合索引
|
|
||||||
files_collection = self.db.files
|
|
||||||
files_collection.create_index("file_hash", unique=True)
|
|
||||||
files_collection.create_index("file_type")
|
|
||||||
files_collection.create_index("upload_time")
|
|
||||||
files_collection.create_index("bos_key")
|
|
||||||
|
|
||||||
# 向量索引集合索引
|
|
||||||
vectors_collection = self.db.vectors
|
|
||||||
vectors_collection.create_index("file_id")
|
|
||||||
vectors_collection.create_index("vector_type")
|
|
||||||
vectors_collection.create_index("vdb_id")
|
|
||||||
|
|
||||||
logger.info("✅ MongoDB索引创建完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"⚠️ 创建索引时出现警告: {e}")
|
|
||||||
|
|
||||||
def store_file_metadata(self, file_path: str = None, file_type: str = None,
|
|
||||||
bos_key: str = None, additional_info: Dict = None,
|
|
||||||
metadata: Dict = None) -> str:
|
|
||||||
"""
|
|
||||||
存储文件元数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: 本地文件路径 (可选,如果提供metadata则忽略)
|
|
||||||
file_type: 文件类型 (image/text)
|
|
||||||
bos_key: BOS存储键
|
|
||||||
additional_info: 额外信息
|
|
||||||
metadata: 直接提供的元数据字典 (新增参数)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
文件ID
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 如果直接提供了元数据,使用元数据
|
|
||||||
if metadata:
|
|
||||||
# 确保必要字段存在
|
|
||||||
if 'upload_time' not in metadata:
|
|
||||||
metadata['upload_time'] = datetime.utcnow()
|
|
||||||
if 'status' not in metadata:
|
|
||||||
metadata['status'] = 'active'
|
|
||||||
|
|
||||||
# 检查是否已存在(基于file_id或其他唯一标识)
|
|
||||||
existing = None
|
|
||||||
if 'file_id' in metadata:
|
|
||||||
existing = self.db.files.find_one({"file_id": metadata['file_id']})
|
|
||||||
elif 'bos_key' in metadata:
|
|
||||||
existing = self.db.files.find_one({"bos_key": metadata['bos_key']})
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
logger.info(f"文件已存在: {metadata.get('filename', 'unknown')} (ID: {existing['_id']})")
|
|
||||||
return str(existing['_id'])
|
|
||||||
|
|
||||||
# 插入新记录
|
|
||||||
result = self.db.files.insert_one(metadata)
|
|
||||||
file_id = str(result.inserted_id)
|
|
||||||
|
|
||||||
logger.info(f"✅ 文件元数据已存储: {metadata.get('filename', 'unknown')} (ID: {file_id})")
|
|
||||||
return file_id
|
|
||||||
|
|
||||||
# 原有逻辑:基于文件路径
|
|
||||||
if not file_path or not file_type or not bos_key:
|
|
||||||
raise ValueError("file_path, file_type, and bos_key are required when metadata is not provided")
|
|
||||||
|
|
||||||
# 计算文件哈希
|
|
||||||
file_hash = self._calculate_file_hash(file_path)
|
|
||||||
|
|
||||||
# 获取文件信息
|
|
||||||
file_stat = os.stat(file_path)
|
|
||||||
filename = os.path.basename(file_path)
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"filename": filename,
|
|
||||||
"file_path": file_path,
|
|
||||||
"file_type": file_type,
|
|
||||||
"file_hash": file_hash,
|
|
||||||
"file_size": file_stat.st_size,
|
|
||||||
"bos_key": bos_key,
|
|
||||||
"upload_time": datetime.utcnow(),
|
|
||||||
"status": "active",
|
|
||||||
"additional_info": additional_info or {}
|
|
||||||
}
|
|
||||||
|
|
||||||
# 检查是否已存在
|
|
||||||
existing = self.db.files.find_one({"file_hash": file_hash})
|
|
||||||
if existing:
|
|
||||||
logger.info(f"文件已存在: {filename} (ID: {existing['_id']})")
|
|
||||||
return str(existing['_id'])
|
|
||||||
|
|
||||||
# 插入新记录
|
|
||||||
result = self.db.files.insert_one(metadata)
|
|
||||||
file_id = str(result.inserted_id)
|
|
||||||
|
|
||||||
logger.info(f"✅ 文件元数据已存储: {filename} (ID: {file_id})")
|
|
||||||
return file_id
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 存储文件元数据失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def store_vector_metadata(self, file_id: str, vector_type: str,
|
|
||||||
vdb_id: str, vector_info: Dict = None):
|
|
||||||
"""
|
|
||||||
存储向量元数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_id: 文件ID
|
|
||||||
vector_type: 向量类型 (text_vector/image_vector)
|
|
||||||
vdb_id: VDB中的向量ID
|
|
||||||
vector_info: 向量信息
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
vector_metadata = {
|
|
||||||
"file_id": file_id,
|
|
||||||
"vector_type": vector_type,
|
|
||||||
"vdb_id": vdb_id,
|
|
||||||
"create_time": datetime.utcnow(),
|
|
||||||
"vector_info": vector_info or {}
|
|
||||||
}
|
|
||||||
|
|
||||||
result = self.db.vectors.insert_one(vector_metadata)
|
|
||||||
logger.info(f"✅ 向量元数据已存储: {vector_type} (ID: {result.inserted_id})")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 存储向量元数据失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_file_metadata(self, file_id: str) -> Optional[Dict]:
|
|
||||||
"""获取文件元数据"""
|
|
||||||
try:
|
|
||||||
from bson import ObjectId
|
|
||||||
result = self.db.files.find_one({"_id": ObjectId(file_id)})
|
|
||||||
if result:
|
|
||||||
result['_id'] = str(result['_id'])
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取文件元数据失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_files_by_type(self, file_type: str, limit: int = 100) -> List[Dict]:
|
|
||||||
"""根据类型获取文件列表"""
|
|
||||||
try:
|
|
||||||
cursor = self.db.files.find(
|
|
||||||
{"file_type": file_type, "status": "active"}
|
|
||||||
).limit(limit).sort("upload_time", -1)
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for doc in cursor:
|
|
||||||
doc['_id'] = str(doc['_id'])
|
|
||||||
results.append(doc)
|
|
||||||
|
|
||||||
return results
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取文件列表失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_all_files(self, limit: int = 1000) -> List[Dict]:
|
|
||||||
"""获取所有文件列表"""
|
|
||||||
try:
|
|
||||||
cursor = self.db.files.find(
|
|
||||||
{"status": "active"}
|
|
||||||
).limit(limit).sort("upload_time", -1)
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for doc in cursor:
|
|
||||||
doc['_id'] = str(doc['_id'])
|
|
||||||
results.append(doc)
|
|
||||||
|
|
||||||
return results
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取所有文件列表失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_vector_metadata(self, file_id: str, vector_type: str = None) -> List[Dict]:
|
|
||||||
"""获取向量元数据"""
|
|
||||||
try:
|
|
||||||
query = {"file_id": file_id}
|
|
||||||
if vector_type:
|
|
||||||
query["vector_type"] = vector_type
|
|
||||||
|
|
||||||
cursor = self.db.vectors.find(query)
|
|
||||||
results = []
|
|
||||||
for doc in cursor:
|
|
||||||
doc['_id'] = str(doc['_id'])
|
|
||||||
results.append(doc)
|
|
||||||
|
|
||||||
return results
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取向量元数据失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_stats(self) -> Dict:
|
|
||||||
"""获取统计信息"""
|
|
||||||
try:
|
|
||||||
stats = {
|
|
||||||
"total_files": self.db.files.count_documents({"status": "active"}),
|
|
||||||
"image_files": self.db.files.count_documents({"file_type": "image", "status": "active"}),
|
|
||||||
"text_files": self.db.files.count_documents({"file_type": "text", "status": "active"}),
|
|
||||||
"total_vectors": self.db.vectors.count_documents({}),
|
|
||||||
"image_vectors": self.db.vectors.count_documents({"vector_type": "image_vector"}),
|
|
||||||
"text_vectors": self.db.vectors.count_documents({"vector_type": "text_vector"})
|
|
||||||
}
|
|
||||||
return stats
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取统计信息失败: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def delete_file_metadata(self, file_id: str):
|
|
||||||
"""删除文件元数据(软删除)"""
|
|
||||||
try:
|
|
||||||
from bson import ObjectId
|
|
||||||
self.db.files.update_one(
|
|
||||||
{"_id": ObjectId(file_id)},
|
|
||||||
{"$set": {"status": "deleted", "delete_time": datetime.utcnow()}}
|
|
||||||
)
|
|
||||||
logger.info(f"✅ 文件元数据已删除: {file_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 删除文件元数据失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _calculate_file_hash(self, file_path: str) -> str:
|
|
||||||
"""计算文件SHA256哈希"""
|
|
||||||
hash_sha256 = hashlib.sha256()
|
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
for chunk in iter(lambda: f.read(4096), b""):
|
|
||||||
hash_sha256.update(chunk)
|
|
||||||
return hash_sha256.hexdigest()
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""关闭连接"""
|
|
||||||
if self.client:
|
|
||||||
self.client.close()
|
|
||||||
logger.info("MongoDB连接已关闭")
|
|
||||||
|
|
||||||
# 全局实例
|
|
||||||
mongodb_manager = None
|
|
||||||
|
|
||||||
def get_mongodb_manager() -> MongoDBManager:
|
|
||||||
"""获取MongoDB管理器实例"""
|
|
||||||
global mongodb_manager
|
|
||||||
if mongodb_manager is None:
|
|
||||||
mongodb_manager = MongoDBManager()
|
|
||||||
return mongodb_manager
|
|
||||||
@ -1,370 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
基于FAISS的多模态检索系统
|
|
||||||
支持文搜文、文搜图、图搜文、图搜图四种检索模式
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.nn.parallel import DataParallel
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
|
||||||
from typing import List, Union, Tuple, Dict, Any, Optional
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
import logging
|
|
||||||
import gc
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
||||||
import threading
|
|
||||||
|
|
||||||
from faiss_vector_store import FaissVectorStore
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class MultimodalRetrievalFAISS:
|
|
||||||
"""基于FAISS的多模态检索系统"""
|
|
||||||
|
|
||||||
def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
|
|
||||||
use_all_gpus: bool = True, gpu_ids: List[int] = None,
|
|
||||||
min_memory_gb: int = 12, index_path: str = "faiss_index"):
|
|
||||||
"""
|
|
||||||
初始化多模态检索系统
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: 模型名称
|
|
||||||
use_all_gpus: 是否使用所有可用GPU
|
|
||||||
gpu_ids: 指定使用的GPU ID列表
|
|
||||||
min_memory_gb: 最小可用内存(GB)
|
|
||||||
index_path: FAISS索引文件路径
|
|
||||||
"""
|
|
||||||
self.model_name = model_name
|
|
||||||
self.index_path = index_path
|
|
||||||
|
|
||||||
# 设置GPU设备
|
|
||||||
self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
|
|
||||||
|
|
||||||
# 清理GPU内存
|
|
||||||
self._clear_all_gpu_memory()
|
|
||||||
|
|
||||||
# 加载模型和处理器
|
|
||||||
self._load_model_and_processor()
|
|
||||||
|
|
||||||
# 初始化FAISS向量存储
|
|
||||||
self.vector_store = FaissVectorStore(
|
|
||||||
index_path=index_path,
|
|
||||||
dimension=3584 # OpenSearch-AI/Ops-MM-embedding-v1-7B的向量维度
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"多模态检索系统初始化完成,使用模型: {model_name}")
|
|
||||||
logger.info(f"向量存储路径: {index_path}")
|
|
||||||
|
|
||||||
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
|
|
||||||
"""设置GPU设备"""
|
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
self.use_gpu = self.device.type == "cuda"
|
|
||||||
|
|
||||||
if self.use_gpu:
|
|
||||||
self.available_gpus = self._get_available_gpus(min_memory_gb)
|
|
||||||
|
|
||||||
if not self.available_gpus:
|
|
||||||
logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB,将使用CPU")
|
|
||||||
self.device = torch.device("cpu")
|
|
||||||
self.use_gpu = False
|
|
||||||
else:
|
|
||||||
if gpu_ids:
|
|
||||||
self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
|
|
||||||
if not self.gpu_ids:
|
|
||||||
logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足,将使用可用的GPU: {self.available_gpus}")
|
|
||||||
self.gpu_ids = self.available_gpus
|
|
||||||
elif use_all_gpus:
|
|
||||||
self.gpu_ids = self.available_gpus
|
|
||||||
else:
|
|
||||||
self.gpu_ids = [self.available_gpus[0]]
|
|
||||||
|
|
||||||
logger.info(f"使用GPU: {self.gpu_ids}")
|
|
||||||
self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
|
|
||||||
|
|
||||||
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
|
|
||||||
"""获取可用的GPU列表"""
|
|
||||||
available_gpus = []
|
|
||||||
for i in range(torch.cuda.device_count()):
|
|
||||||
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
|
|
||||||
if total_mem >= min_memory_gb:
|
|
||||||
available_gpus.append(i)
|
|
||||||
return available_gpus
|
|
||||||
|
|
||||||
def _clear_all_gpu_memory(self):
|
|
||||||
"""清理GPU内存"""
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
def _load_model_and_processor(self):
|
|
||||||
"""加载模型和处理器"""
|
|
||||||
logger.info(f"加载模型和处理器: {self.model_name}")
|
|
||||||
|
|
||||||
# 加载tokenizer和processor
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
|
||||||
self.processor = AutoProcessor.from_pretrained(self.model_name)
|
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
self.model = AutoModel.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
torch_dtype=torch.float16 if self.use_gpu else torch.float32,
|
|
||||||
device_map="auto" if len(self.gpu_ids) > 1 else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 如果使用多GPU,包装模型
|
|
||||||
if len(self.gpu_ids) > 1:
|
|
||||||
self.model = DataParallel(self.model, device_ids=self.gpu_ids)
|
|
||||||
|
|
||||||
self.model.eval()
|
|
||||||
self.model.to(self.device)
|
|
||||||
|
|
||||||
logger.info("模型和处理器加载完成")
|
|
||||||
|
|
||||||
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
|
||||||
"""编码文本为向量"""
|
|
||||||
if isinstance(text, str):
|
|
||||||
text = [text]
|
|
||||||
|
|
||||||
inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
|
||||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
# 获取[CLS]标记的隐藏状态作为句子表示
|
|
||||||
text_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
|
||||||
|
|
||||||
# 归一化向量
|
|
||||||
text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
|
|
||||||
return text_embeddings[0] if len(text) == 1 else text_embeddings
|
|
||||||
|
|
||||||
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
|
|
||||||
"""编码图像为向量"""
|
|
||||||
if isinstance(image, Image.Image):
|
|
||||||
image = [image]
|
|
||||||
|
|
||||||
inputs = self.processor(images=image, return_tensors="pt")
|
|
||||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model.vision_model(**inputs)
|
|
||||||
# 获取[CLS]标记的隐藏状态作为图像表示
|
|
||||||
image_embeddings = outputs.pooler_output.cpu().numpy()
|
|
||||||
|
|
||||||
# 归一化向量
|
|
||||||
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
|
|
||||||
return image_embeddings[0] if len(image) == 1 else image_embeddings
|
|
||||||
|
|
||||||
def add_texts(
|
|
||||||
self,
|
|
||||||
texts: List[str],
|
|
||||||
metadatas: Optional[List[Dict[str, Any]]] = None
|
|
||||||
) -> List[str]:
|
|
||||||
"""
|
|
||||||
添加文本到检索系统
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: 文本列表
|
|
||||||
metadatas: 元数据列表,每个元素是一个字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
添加的文本ID列表
|
|
||||||
"""
|
|
||||||
if not texts:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if metadatas is None:
|
|
||||||
metadatas = [{} for _ in range(len(texts))]
|
|
||||||
|
|
||||||
if len(texts) != len(metadatas):
|
|
||||||
raise ValueError("texts和metadatas长度必须相同")
|
|
||||||
|
|
||||||
# 编码文本
|
|
||||||
text_embeddings = self.encode_text(texts)
|
|
||||||
|
|
||||||
# 准备元数据
|
|
||||||
for i, text in enumerate(texts):
|
|
||||||
metadatas[i].update({
|
|
||||||
"text": text,
|
|
||||||
"type": "text"
|
|
||||||
})
|
|
||||||
|
|
||||||
# 添加到向量存储
|
|
||||||
vector_ids = self.vector_store.add_vectors(text_embeddings, metadatas)
|
|
||||||
|
|
||||||
logger.info(f"成功添加{len(vector_ids)}条文本到检索系统")
|
|
||||||
return vector_ids
|
|
||||||
|
|
||||||
def add_images(
|
|
||||||
self,
|
|
||||||
images: List[Image.Image],
|
|
||||||
metadatas: Optional[List[Dict[str, Any]]] = None
|
|
||||||
) -> List[str]:
|
|
||||||
"""
|
|
||||||
添加图像到检索系统
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images: PIL图像列表
|
|
||||||
metadatas: 元数据列表,每个元素是一个字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
添加的图像ID列表
|
|
||||||
"""
|
|
||||||
if not images:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if metadatas is None:
|
|
||||||
metadatas = [{} for _ in range(len(images))]
|
|
||||||
|
|
||||||
if len(images) != len(metadatas):
|
|
||||||
raise ValueError("images和metadatas长度必须相同")
|
|
||||||
|
|
||||||
# 编码图像
|
|
||||||
image_embeddings = self.encode_image(images)
|
|
||||||
|
|
||||||
# 准备元数据
|
|
||||||
for i, image in enumerate(images):
|
|
||||||
metadatas[i].update({
|
|
||||||
"type": "image",
|
|
||||||
"width": image.width,
|
|
||||||
"height": image.height
|
|
||||||
})
|
|
||||||
|
|
||||||
# 添加到向量存储
|
|
||||||
vector_ids = self.vector_store.add_vectors(image_embeddings, metadatas)
|
|
||||||
|
|
||||||
logger.info(f"成功添加{len(vector_ids)}张图像到检索系统")
|
|
||||||
return vector_ids
|
|
||||||
|
|
||||||
def search_by_text(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
k: int = 5,
|
|
||||||
filter_condition: Optional[Dict[str, Any]] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
文本搜索
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 查询文本
|
|
||||||
k: 返回结果数量
|
|
||||||
filter_condition: 过滤条件
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
搜索结果列表,每个元素包含相似项和分数
|
|
||||||
"""
|
|
||||||
# 编码查询文本
|
|
||||||
query_embedding = self.encode_text(query)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
results, distances = self.vector_store.search(query_embedding, k)
|
|
||||||
|
|
||||||
# 处理结果
|
|
||||||
search_results = []
|
|
||||||
for i, (result, distance) in enumerate(zip(results, distances)):
|
|
||||||
result["score"] = 1.0 / (1.0 + distance) # 将距离转换为相似度分数
|
|
||||||
search_results.append(result)
|
|
||||||
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
def search_by_image(
|
|
||||||
self,
|
|
||||||
image: Image.Image,
|
|
||||||
k: int = 5,
|
|
||||||
filter_condition: Optional[Dict[str, Any]] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
图像搜索
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image: 查询图像
|
|
||||||
k: 返回结果数量
|
|
||||||
filter_condition: 过滤条件
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
搜索结果列表,每个元素包含相似项和分数
|
|
||||||
"""
|
|
||||||
# 编码查询图像
|
|
||||||
query_embedding = self.encode_image(image)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
results, distances = self.vector_store.search(query_embedding, k)
|
|
||||||
|
|
||||||
# 处理结果
|
|
||||||
search_results = []
|
|
||||||
for i, (result, distance) in enumerate(zip(results, distances)):
|
|
||||||
result["score"] = 1.0 / (1.0 + distance) # 将距离转换为相似度分数
|
|
||||||
search_results.append(result)
|
|
||||||
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
def get_vector_count(self) -> int:
|
|
||||||
"""获取向量数量"""
|
|
||||||
return self.vector_store.get_vector_count()
|
|
||||||
|
|
||||||
def save_index(self):
|
|
||||||
"""保存索引"""
|
|
||||||
self.vector_store.save_index()
|
|
||||||
logger.info("索引已保存")
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
"""析构函数,确保资源被正确释放"""
|
|
||||||
if hasattr(self, 'model'):
|
|
||||||
del self.model
|
|
||||||
self._clear_all_gpu_memory()
|
|
||||||
if hasattr(self, 'vector_store'):
|
|
||||||
self.save_index()
|
|
||||||
|
|
||||||
|
|
||||||
def test_faiss_system():
|
|
||||||
"""测试FAISS多模态检索系统"""
|
|
||||||
import time
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
print("初始化多模态检索系统...")
|
|
||||||
retrieval = MultimodalRetrievalFAISS(
|
|
||||||
model_name="OpenSearch-AI/Ops-MM-embedding-v1-7B",
|
|
||||||
use_all_gpus=True,
|
|
||||||
index_path="faiss_index_test"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 测试文本
|
|
||||||
texts = [
|
|
||||||
"一只可爱的橘色猫咪在沙发上睡觉",
|
|
||||||
"城市夜景中的高楼大厦和车流",
|
|
||||||
"阳光明媚的海滩上,人们在冲浪和晒太阳",
|
|
||||||
"美味的意大利面配红酒和沙拉",
|
|
||||||
"雪山上滑雪的运动员"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 添加文本
|
|
||||||
print("\n添加文本到检索系统...")
|
|
||||||
text_ids = retrieval.add_texts(texts)
|
|
||||||
print(f"添加了{len(text_ids)}条文本")
|
|
||||||
|
|
||||||
# 测试文本搜索
|
|
||||||
print("\n测试文本搜索...")
|
|
||||||
query_text = "一只猫在睡觉"
|
|
||||||
print(f"查询: {query_text}")
|
|
||||||
results = retrieval.search_by_text(query_text, k=2)
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
print(f"结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
|
||||||
|
|
||||||
# 测试保存和加载
|
|
||||||
print("\n保存索引...")
|
|
||||||
retrieval.save_index()
|
|
||||||
|
|
||||||
print("\n测试完成!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_faiss_system()
|
|
||||||
@ -8,7 +8,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
|
||||||
from typing import List, Union, Tuple, Dict, Any, Optional
|
from typing import List, Union, Tuple, Dict, Any, Optional
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
@ -59,8 +59,8 @@ class MultimodalRetrievalLocal:
|
|||||||
# 清理GPU内存
|
# 清理GPU内存
|
||||||
self._clear_all_gpu_memory()
|
self._clear_all_gpu_memory()
|
||||||
|
|
||||||
# 加载模型和处理器
|
# 加载嵌入模型
|
||||||
self._load_model_and_processor()
|
self._load_embedding_model()
|
||||||
|
|
||||||
# 初始化FAISS索引
|
# 初始化FAISS索引
|
||||||
self._init_index()
|
self._init_index()
|
||||||
@ -112,45 +112,23 @@ class MultimodalRetrievalLocal:
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
def _load_model_and_processor(self):
|
def _load_embedding_model(self):
|
||||||
"""加载模型和处理器"""
|
"""加载多模态嵌入模型 OpsMMEmbeddingV1"""
|
||||||
logger.info(f"加载本地模型和处理器: {self.model_path}")
|
logger.info(f"加载本地多模态嵌入模型: {self.model_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 加载模型和处理器
|
device_str = "cuda" if self.use_gpu else "cpu"
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
|
self.model = OpsMMEmbeddingV1(
|
||||||
self.processor = AutoProcessor.from_pretrained(self.model_path)
|
|
||||||
|
|
||||||
# 输出处理器信息
|
|
||||||
logger.info(f"Processor类型: {type(self.processor)}")
|
|
||||||
logger.info(f"Processor方法: {dir(self.processor)}")
|
|
||||||
|
|
||||||
# 检查是否有图像处理器
|
|
||||||
if hasattr(self.processor, 'image_processor'):
|
|
||||||
logger.info(f"Image processor类型: {type(self.processor.image_processor)}")
|
|
||||||
logger.info(f"Image processor方法: {dir(self.processor.image_processor)}")
|
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
self.model = AutoModel.from_pretrained(
|
|
||||||
self.model_path,
|
self.model_path,
|
||||||
torch_dtype=torch.float16 if self.use_gpu else torch.float32,
|
device=device_str,
|
||||||
device_map="auto" if len(self.gpu_ids) > 1 else None
|
attn_implementation=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(self.gpu_ids) == 1:
|
|
||||||
self.model.to(self.device)
|
|
||||||
|
|
||||||
self.model.eval()
|
|
||||||
|
|
||||||
# 获取向量维度
|
# 获取向量维度
|
||||||
self.vector_dim = self.model.config.hidden_size
|
self.vector_dim = int(getattr(self.model.base_model.config, "hidden_size"))
|
||||||
logger.info(f"向量维度: {self.vector_dim}")
|
logger.info(f"向量维度: {self.vector_dim}")
|
||||||
|
logger.info("嵌入模型加载成功")
|
||||||
logger.info("模型和处理器加载成功")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"模型加载失败: {str(e)}")
|
logger.error(f"嵌入模型加载失败: {str(e)}")
|
||||||
raise RuntimeError(f"模型加载失败: {str(e)}")
|
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
|
||||||
|
|
||||||
def _init_index(self):
|
def _init_index(self):
|
||||||
"""初始化FAISS索引"""
|
"""初始化FAISS索引"""
|
||||||
@ -180,133 +158,35 @@ class MultimodalRetrievalLocal:
|
|||||||
logger.error(f"元数据加载失败: {str(e)}")
|
logger.error(f"元数据加载失败: {str(e)}")
|
||||||
|
|
||||||
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
||||||
"""编码文本为向量"""
|
"""编码文本为向量(使用 OpsMMEmbeddingV1)"""
|
||||||
if isinstance(text, str):
|
if isinstance(text, str):
|
||||||
text = [text]
|
text = [text]
|
||||||
|
with torch.inference_mode():
|
||||||
inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
emb = self.model.get_text_embeddings(texts=text)
|
||||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
text_embeddings = emb.detach().float().cpu().numpy()
|
||||||
|
# emb 已经做过 L2 归一化,这里保持一致
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
# 获取[CLS]标记的隐藏状态作为句子表示
|
|
||||||
text_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
|
||||||
|
|
||||||
# 归一化向量
|
|
||||||
text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
|
|
||||||
return text_embeddings[0] if len(text) == 1 else text_embeddings
|
return text_embeddings[0] if len(text) == 1 else text_embeddings
|
||||||
|
|
||||||
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
|
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
|
||||||
"""编码图像为向量"""
|
"""编码图像为向量(使用 OpsMMEmbeddingV1)"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"encode_image: 开始编码图像,类型: {type(image)}")
|
# 规范为列表
|
||||||
|
images: List[Image.Image]
|
||||||
if isinstance(image, Image.Image):
|
if isinstance(image, Image.Image):
|
||||||
logger.info(f"encode_image: 单个图像,大小: {image.size}")
|
images = [image]
|
||||||
image = [image]
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"encode_image: 图像列表,长度: {len(image)}")
|
images = image
|
||||||
|
if not images:
|
||||||
# 检查图像是否为空
|
|
||||||
if not image or len(image) == 0:
|
|
||||||
logger.error("encode_image: 图像列表为空")
|
logger.error("encode_image: 图像列表为空")
|
||||||
# 返回一个空的二维数组
|
|
||||||
return np.zeros((0, self.vector_dim))
|
return np.zeros((0, self.vector_dim))
|
||||||
|
# 强制为 RGB
|
||||||
# 检查图像是否有效
|
rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images]
|
||||||
for i, img in enumerate(image):
|
with torch.inference_mode():
|
||||||
if not isinstance(img, Image.Image):
|
emb = self.model.get_image_embeddings(images=rgb_images)
|
||||||
logger.error(f"encode_image: 第{i}个元素不是有效的PIL图像,类型: {type(img)}")
|
image_embeddings = emb.detach().float().cpu().numpy()
|
||||||
# 返回一个空的二维数组
|
return image_embeddings
|
||||||
return np.zeros((0, self.vector_dim))
|
|
||||||
|
|
||||||
logger.info("encode_image: 处理图像输入")
|
|
||||||
|
|
||||||
# 检查图像格式
|
|
||||||
for i, img in enumerate(image):
|
|
||||||
logger.info(f"encode_image: 图像 {i} 格式: {img.format}, 模式: {img.mode}, 大小: {img.size}")
|
|
||||||
# 转换为RGB模式,如果不是
|
|
||||||
if img.mode != 'RGB':
|
|
||||||
logger.info(f"encode_image: 将图像 {i} 从 {img.mode} 转换为 RGB")
|
|
||||||
image[i] = img.convert('RGB')
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 直接使用image_processor处理图像
|
|
||||||
if hasattr(self.processor, 'image_processor'):
|
|
||||||
logger.info("encode_image: 使用image_processor处理图像")
|
|
||||||
pixel_values = self.processor.image_processor(images=image, return_tensors="pt").pixel_values
|
|
||||||
inputs = {"pixel_values": pixel_values}
|
|
||||||
else:
|
|
||||||
logger.info("encode_image: 使用processor处理图像")
|
|
||||||
inputs = self.processor(images=image, return_tensors="pt")
|
|
||||||
|
|
||||||
if not inputs or len(inputs) == 0:
|
|
||||||
logger.error("encode_image: processor返回了空的输入")
|
|
||||||
return np.zeros((0, self.vector_dim))
|
|
||||||
|
|
||||||
logger.info(f"encode_image: 处理后的输入键: {list(inputs.keys())}")
|
|
||||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
||||||
|
|
||||||
logger.info("encode_image: 运行模型推理")
|
|
||||||
logger.info(f"Model类型: {type(self.model)}")
|
|
||||||
logger.info(f"Model属性: {dir(self.model)}")
|
|
||||||
|
|
||||||
# 检查模型结构
|
|
||||||
try:
|
|
||||||
logger.info(f"Model配置: {self.model.config}")
|
|
||||||
logger.info(f"Model配置属性: {dir(self.model.config)}")
|
|
||||||
else:
|
|
||||||
visual_outputs = self.model.visual(**inputs)
|
|
||||||
|
|
||||||
if hasattr(visual_outputs, 'pooler_output'):
|
|
||||||
image_embeddings = visual_outputs.pooler_output.cpu().numpy()
|
|
||||||
elif hasattr(visual_outputs, 'last_hidden_state'):
|
|
||||||
image_embeddings = visual_outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
|
||||||
else:
|
|
||||||
logger.error("encode_image: 无法从视觉模型输出中获取图像向量")
|
|
||||||
raise ValueError("无法从视觉模型输出中获取图像向量")
|
|
||||||
else:
|
|
||||||
# 尝试直接使用模型进行推理
|
|
||||||
logger.info("encode_image: 尝试直接使用模型进行推理")
|
|
||||||
with torch.no_grad():
|
|
||||||
# 使用空文本输入,只提供图像
|
|
||||||
if 'pixel_values' in inputs:
|
|
||||||
outputs = self.model(pixel_values=inputs['pixel_values'], input_ids=None)
|
|
||||||
else:
|
|
||||||
outputs = self.model(**inputs, input_ids=None)
|
|
||||||
|
|
||||||
# 尝试从输出中获取图像向量
|
|
||||||
if hasattr(outputs, 'image_embeds'):
|
|
||||||
image_embeddings = outputs.image_embeds.cpu().numpy()
|
|
||||||
elif hasattr(outputs, 'vision_model_output') and hasattr(outputs.vision_model_output, 'pooler_output'):
|
|
||||||
image_embeddings = outputs.vision_model_output.pooler_output.cpu().numpy()
|
|
||||||
elif hasattr(outputs, 'pooler_output'):
|
|
||||||
image_embeddings = outputs.pooler_output.cpu().numpy()
|
|
||||||
elif hasattr(outputs, 'last_hidden_state'):
|
|
||||||
image_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
|
||||||
else:
|
|
||||||
logger.error("encode_image: 无法从模型输出中获取图像向量")
|
|
||||||
raise ValueError("无法从模型输出中获取图像向量")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"encode_image: 处理图像时出错: {str(e)}")
|
|
||||||
raise e
|
|
||||||
return np.zeros((0, self.vector_dim))
|
|
||||||
|
|
||||||
# 归一化向量
|
|
||||||
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
|
|
||||||
|
|
||||||
# 始终返回二维数组,即使只有一个图像
|
|
||||||
if len(image) == 1:
|
|
||||||
result = np.array([image_embeddings[0]])
|
|
||||||
logger.info(f"encode_image: 返回单个图像向量,形状: {result.shape}")
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
logger.info(f"encode_image: 返回多个图像向量,形状: {image_embeddings.shape}")
|
|
||||||
return image_embeddings
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"encode_image: 异常: {str(e)}")
|
logger.error(f"encode_image: 异常: {str(e)}")
|
||||||
# 返回一个空的二维数组
|
|
||||||
return np.zeros((0, self.vector_dim))
|
return np.zeros((0, self.vector_dim))
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
|
|||||||
@ -1,592 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
集成百度VDB的多模态检索系统
|
|
||||||
支持文搜文、文搜图、图搜文、图搜图四种检索模式
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
|
||||||
from typing import List, Union, Tuple, Dict, Any
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import gc
|
|
||||||
from baidu_vdb_backend import BaiduVDBBackend
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class MultimodalRetrievalVDB:
|
|
||||||
"""集成百度VDB的多模态检索系统"""
|
|
||||||
|
|
||||||
def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
|
|
||||||
use_all_gpus: bool = True, gpu_ids: List[int] = None,
|
|
||||||
vdb_config: Dict[str, str] = None):
|
|
||||||
"""
|
|
||||||
初始化多模态检索系统
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: 模型名称
|
|
||||||
use_all_gpus: 是否使用所有可用GPU
|
|
||||||
gpu_ids: 指定使用的GPU ID列表
|
|
||||||
vdb_config: VDB配置字典
|
|
||||||
"""
|
|
||||||
self.model_name = model_name
|
|
||||||
|
|
||||||
# 设置GPU设备
|
|
||||||
self._setup_devices(use_all_gpus, gpu_ids)
|
|
||||||
|
|
||||||
# 清理GPU内存
|
|
||||||
self._clear_gpu_memory()
|
|
||||||
|
|
||||||
logger.info(f"正在加载模型到GPU: {self.device_ids}")
|
|
||||||
|
|
||||||
# 加载模型和处理器
|
|
||||||
self.model = None
|
|
||||||
self.tokenizer = None
|
|
||||||
self.processor = None
|
|
||||||
self._load_model()
|
|
||||||
|
|
||||||
# 初始化百度VDB后端
|
|
||||||
if vdb_config is None:
|
|
||||||
vdb_config = {
|
|
||||||
"account": "root",
|
|
||||||
"api_key": "vdb$yjr9ln3n0td",
|
|
||||||
"endpoint": "http://180.76.96.191:5287",
|
|
||||||
"database_name": "multimodal_retrieval"
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.vdb = BaiduVDBBackend(**vdb_config)
|
|
||||||
logger.info("✅ VDB后端初始化成功")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB后端初始化失败: {e}")
|
|
||||||
# 创建一个模拟的VDB后端,避免系统完全崩溃
|
|
||||||
self.vdb = None
|
|
||||||
logger.warning("⚠️ 系统将在无VDB模式下运行,数据将不会持久化")
|
|
||||||
|
|
||||||
logger.info("多模态检索系统初始化完成")
|
|
||||||
|
|
||||||
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int]):
|
|
||||||
"""设置GPU设备"""
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
raise RuntimeError("CUDA不可用,无法使用GPU")
|
|
||||||
|
|
||||||
total_gpus = torch.cuda.device_count()
|
|
||||||
logger.info(f"检测到 {total_gpus} 个GPU")
|
|
||||||
|
|
||||||
if use_all_gpus:
|
|
||||||
self.device_ids = list(range(total_gpus))
|
|
||||||
elif gpu_ids:
|
|
||||||
self.device_ids = gpu_ids
|
|
||||||
else:
|
|
||||||
self.device_ids = [0]
|
|
||||||
|
|
||||||
self.num_gpus = len(self.device_ids)
|
|
||||||
self.primary_device = f"cuda:{self.device_ids[0]}"
|
|
||||||
|
|
||||||
logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")
|
|
||||||
|
|
||||||
def _clear_gpu_memory(self):
|
|
||||||
"""清理GPU内存"""
|
|
||||||
for gpu_id in self.device_ids:
|
|
||||||
torch.cuda.set_device(gpu_id)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
gc.collect()
|
|
||||||
logger.info("GPU内存已清理")
|
|
||||||
|
|
||||||
def _load_model(self):
|
|
||||||
"""加载模型"""
|
|
||||||
try:
|
|
||||||
# 设置环境变量优化内存使用
|
|
||||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
|
||||||
|
|
||||||
# 清理GPU内存
|
|
||||||
self._clear_gpu_memory()
|
|
||||||
|
|
||||||
# 设置离线模式环境变量
|
|
||||||
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
|
||||||
os.environ['HF_HUB_OFFLINE'] = '1'
|
|
||||||
|
|
||||||
# 尝试加载模型,如果网络失败则使用本地缓存
|
|
||||||
try:
|
|
||||||
# 加载模型
|
|
||||||
if self.num_gpus > 1:
|
|
||||||
# 多GPU加载
|
|
||||||
max_memory = {i: "18GiB" for i in self.device_ids}
|
|
||||||
|
|
||||||
self.model = AutoModel.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
trust_remote_code=True,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
device_map="auto",
|
|
||||||
max_memory=max_memory,
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
local_files_only=False # 允许从网络下载
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 单GPU加载
|
|
||||||
self.model = AutoModel.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
trust_remote_code=True,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
device_map=self.primary_device,
|
|
||||||
local_files_only=False # 允许从网络下载
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("模型从网络加载成功")
|
|
||||||
|
|
||||||
except Exception as network_error:
|
|
||||||
logger.warning(f"网络加载失败,尝试本地缓存: {network_error}")
|
|
||||||
|
|
||||||
# 尝试从本地缓存加载
|
|
||||||
try:
|
|
||||||
if self.num_gpus > 1:
|
|
||||||
max_memory = {i: "18GiB" for i in self.device_ids}
|
|
||||||
|
|
||||||
self.model = AutoModel.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
trust_remote_code=True,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
device_map="auto",
|
|
||||||
max_memory=max_memory,
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
local_files_only=True # 仅使用本地文件
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.model = AutoModel.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
trust_remote_code=True,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
device_map=self.primary_device,
|
|
||||||
local_files_only=True # 仅使用本地文件
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("模型从本地缓存加载成功")
|
|
||||||
|
|
||||||
except Exception as local_error:
|
|
||||||
logger.error(f"本地缓存加载也失败: {local_error}")
|
|
||||||
raise local_error
|
|
||||||
|
|
||||||
# 加载分词器和处理器
|
|
||||||
try:
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
trust_remote_code=True,
|
|
||||||
local_files_only=False
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Tokenizer网络加载失败,尝试本地: {e}")
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
trust_remote_code=True,
|
|
||||||
local_files_only=True
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
trust_remote_code=True,
|
|
||||||
local_files_only=False
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Processor加载失败,使用tokenizer: {e}")
|
|
||||||
try:
|
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
trust_remote_code=True,
|
|
||||||
local_files_only=True
|
|
||||||
)
|
|
||||||
except Exception as e2:
|
|
||||||
logger.warning(f"Processor本地加载也失败,使用tokenizer: {e2}")
|
|
||||||
self.processor = self.tokenizer
|
|
||||||
|
|
||||||
logger.info("模型加载完成")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"模型加载失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def encode_text_batch(self, texts: List[str]) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
批量编码文本为向量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: 文本列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
文本向量数组
|
|
||||||
"""
|
|
||||||
if not texts:
|
|
||||||
return np.array([])
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
# 预处理输入
|
|
||||||
inputs = self.tokenizer(
|
|
||||||
text=texts,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=512
|
|
||||||
)
|
|
||||||
|
|
||||||
# 将输入移动到主设备
|
|
||||||
inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
|
|
||||||
|
|
||||||
# 前向传播
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
|
||||||
|
|
||||||
# 清理GPU内存
|
|
||||||
del inputs, outputs
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return embeddings.cpu().numpy().astype(np.float32)
|
|
||||||
|
|
||||||
def encode_image_batch(self, images: List[Union[str, Image.Image]]) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
批量编码图像为向量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images: 图像路径或PIL图像列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
图像向量数组
|
|
||||||
"""
|
|
||||||
if not images:
|
|
||||||
return np.array([])
|
|
||||||
|
|
||||||
# 预处理图像
|
|
||||||
processed_images = []
|
|
||||||
for img in images:
|
|
||||||
if isinstance(img, str):
|
|
||||||
img = Image.open(img).convert('RGB')
|
|
||||||
elif isinstance(img, Image.Image):
|
|
||||||
img = img.convert('RGB')
|
|
||||||
processed_images.append(img)
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"处理 {len(processed_images)} 张图像")
|
|
||||||
|
|
||||||
# 使用多模态模型生成图像embedding
|
|
||||||
conversations = []
|
|
||||||
for i in range(len(processed_images)):
|
|
||||||
conversation = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{"type": "image", "image": processed_images[i]},
|
|
||||||
{"type": "text", "text": "What is in this image?"}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
conversations.append(conversation)
|
|
||||||
|
|
||||||
# 使用processor处理
|
|
||||||
try:
|
|
||||||
texts = []
|
|
||||||
for conv in conversations:
|
|
||||||
text = self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
|
|
||||||
texts.append(text)
|
|
||||||
|
|
||||||
# 处理文本和图像
|
|
||||||
inputs = self.processor(
|
|
||||||
text=texts,
|
|
||||||
images=processed_images,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 移动到GPU
|
|
||||||
inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
|
|
||||||
|
|
||||||
# 获取模型输出
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
|
||||||
|
|
||||||
# 转换为numpy数组
|
|
||||||
embeddings = embeddings.cpu().numpy().astype(np.float32)
|
|
||||||
|
|
||||||
except Exception as inner_e:
|
|
||||||
logger.warning(f"多模态模型图像编码失败: {inner_e}")
|
|
||||||
# 使用零向量作为fallback
|
|
||||||
embedding_dim = 3584
|
|
||||||
embeddings = np.zeros((len(processed_images), embedding_dim), dtype=np.float32)
|
|
||||||
|
|
||||||
logger.info(f"生成图像embeddings: {embeddings.shape}")
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图像编码失败: {e}")
|
|
||||||
# 返回零向量作为fallback
|
|
||||||
embedding_dim = 3584
|
|
||||||
embeddings = np.zeros((len(processed_images), embedding_dim), dtype=np.float32)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def store_texts(self, texts: List[str], metadata: List[Dict] = None) -> List[str]:
|
|
||||||
"""
|
|
||||||
存储文本数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: 文本列表
|
|
||||||
metadata: 元数据列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
存储的ID列表
|
|
||||||
"""
|
|
||||||
if self.vdb is None:
|
|
||||||
logger.warning("VDB不可用,文本数据将不会持久化存储")
|
|
||||||
return []
|
|
||||||
|
|
||||||
logger.info(f"正在存储 {len(texts)} 条文本数据")
|
|
||||||
|
|
||||||
# 分批处理
|
|
||||||
batch_size = 16
|
|
||||||
all_ids = []
|
|
||||||
|
|
||||||
for i in range(0, len(texts), batch_size):
|
|
||||||
batch_texts = texts[i:i+batch_size]
|
|
||||||
batch_metadata = metadata[i:i+batch_size] if metadata else None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 编码文本
|
|
||||||
vectors = self.encode_text_batch(batch_texts)
|
|
||||||
|
|
||||||
# 存储到VDB
|
|
||||||
ids = self.vdb.store_text_vectors(batch_texts, vectors, batch_metadata)
|
|
||||||
all_ids.extend(ids)
|
|
||||||
|
|
||||||
logger.info(f"已处理 {i + len(batch_texts)}/{len(texts)} 条文本")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理文本批次时出错: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info(f"✅ 文本存储完成,共 {len(all_ids)} 条")
|
|
||||||
return all_ids
|
|
||||||
|
|
||||||
def store_images(self, image_paths: List[str], metadata: List[Dict] = None) -> List[str]:
|
|
||||||
"""
|
|
||||||
存储图像数据
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_paths: 图像路径列表
|
|
||||||
metadata: 元数据列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
存储的ID列表
|
|
||||||
"""
|
|
||||||
if self.vdb is None:
|
|
||||||
logger.warning("VDB不可用,图像数据将不会持久化存储")
|
|
||||||
return []
|
|
||||||
|
|
||||||
logger.info(f"正在存储 {len(image_paths)} 张图像数据")
|
|
||||||
|
|
||||||
# 图像处理使用更小的批次
|
|
||||||
batch_size = 8
|
|
||||||
all_ids = []
|
|
||||||
|
|
||||||
for i in range(0, len(image_paths), batch_size):
|
|
||||||
batch_images = image_paths[i:i+batch_size]
|
|
||||||
batch_metadata = metadata[i:i+batch_size] if metadata else None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 编码图像
|
|
||||||
vectors = self.encode_image_batch(batch_images)
|
|
||||||
|
|
||||||
# 存储到VDB
|
|
||||||
ids = self.vdb.store_image_vectors(batch_images, vectors, batch_metadata)
|
|
||||||
all_ids.extend(ids)
|
|
||||||
|
|
||||||
logger.info(f"已处理 {i + len(batch_images)}/{len(image_paths)} 张图像")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理图像批次时出错: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info(f"✅ 图像存储完成,共 {len(all_ids)} 条")
|
|
||||||
return all_ids
|
|
||||||
|
|
||||||
def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""文搜文:使用文本查询搜索相似文本"""
|
|
||||||
if self.vdb is None:
|
|
||||||
logger.warning("VDB不可用,无法执行搜索")
|
|
||||||
return []
|
|
||||||
|
|
||||||
logger.info(f"执行文搜文查询: {query}")
|
|
||||||
|
|
||||||
# 编码查询文本
|
|
||||||
query_vector = self.encode_text_batch([query])[0]
|
|
||||||
|
|
||||||
# 在VDB中搜索
|
|
||||||
results = self.vdb.search_text_vectors(query_vector, top_k)
|
|
||||||
|
|
||||||
# 格式化结果
|
|
||||||
formatted_results = []
|
|
||||||
for doc_id, text_content, score, metadata in results:
|
|
||||||
formatted_results.append((text_content, score))
|
|
||||||
|
|
||||||
return formatted_results
|
|
||||||
|
|
||||||
def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""文搜图:使用文本查询搜索相似图像"""
|
|
||||||
if self.vdb is None:
|
|
||||||
logger.warning("VDB不可用,无法执行搜索")
|
|
||||||
return []
|
|
||||||
|
|
||||||
logger.info(f"执行文搜图查询: {query}")
|
|
||||||
|
|
||||||
# 编码查询文本
|
|
||||||
query_vector = self.encode_text_batch([query])[0]
|
|
||||||
|
|
||||||
# 在VDB中搜索图像
|
|
||||||
results = self.vdb.search_image_vectors(query_vector, top_k)
|
|
||||||
|
|
||||||
# 格式化结果
|
|
||||||
formatted_results = []
|
|
||||||
for doc_id, image_path, image_name, score, metadata in results:
|
|
||||||
formatted_results.append((image_path, score))
|
|
||||||
|
|
||||||
return formatted_results
|
|
||||||
|
|
||||||
def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""图搜文:使用图像查询搜索相似文本"""
|
|
||||||
if self.vdb is None:
|
|
||||||
logger.warning("VDB不可用,无法执行搜索")
|
|
||||||
return []
|
|
||||||
|
|
||||||
logger.info(f"执行图搜文查询")
|
|
||||||
|
|
||||||
# 编码查询图像
|
|
||||||
query_vector = self.encode_image_batch([query_image])[0]
|
|
||||||
|
|
||||||
# 在VDB中搜索文本
|
|
||||||
results = self.vdb.search_text_vectors(query_vector, top_k)
|
|
||||||
|
|
||||||
# 格式化结果
|
|
||||||
formatted_results = []
|
|
||||||
for doc_id, text_content, score, metadata in results:
|
|
||||||
formatted_results.append((text_content, score))
|
|
||||||
|
|
||||||
return formatted_results
|
|
||||||
|
|
||||||
def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""图搜图:使用图像查询搜索相似图像"""
|
|
||||||
if self.vdb is None:
|
|
||||||
logger.warning("VDB不可用,无法执行搜索")
|
|
||||||
return []
|
|
||||||
|
|
||||||
logger.info(f"执行图搜图查询")
|
|
||||||
|
|
||||||
# 编码查询图像
|
|
||||||
query_vector = self.encode_image_batch([query_image])[0]
|
|
||||||
|
|
||||||
# 在VDB中搜索图像
|
|
||||||
results = self.vdb.search_image_vectors(query_vector, top_k)
|
|
||||||
|
|
||||||
# 格式化结果
|
|
||||||
formatted_results = []
|
|
||||||
for doc_id, image_path, image_name, score, metadata in results:
|
|
||||||
formatted_results.append((image_path, score))
|
|
||||||
|
|
||||||
return formatted_results
|
|
||||||
|
|
||||||
# Web应用兼容的方法名称
|
|
||||||
def search_text_to_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""文搜文:Web应用兼容方法"""
|
|
||||||
return self.search_text_by_text(query, top_k)
|
|
||||||
|
|
||||||
def search_text_to_image(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""文搜图:Web应用兼容方法"""
|
|
||||||
return self.search_images_by_text(query, top_k)
|
|
||||||
|
|
||||||
def search_image_to_text(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""图搜文:Web应用兼容方法"""
|
|
||||||
return self.search_text_by_image(query_image, top_k)
|
|
||||||
|
|
||||||
def search_image_to_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""图搜图:Web应用兼容方法"""
|
|
||||||
return self.search_images_by_image(query_image, top_k)
|
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
|
||||||
"""获取系统统计信息"""
|
|
||||||
if self.vdb is None:
|
|
||||||
return {"error": "VDB不可用"}
|
|
||||||
return self.vdb.get_statistics()
|
|
||||||
|
|
||||||
def clear_all_data(self):
|
|
||||||
"""清空所有数据"""
|
|
||||||
if self.vdb is None:
|
|
||||||
logger.warning("VDB不可用,无法清空数据")
|
|
||||||
return
|
|
||||||
self.vdb.clear_all_data()
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""关闭系统"""
|
|
||||||
if self.vdb:
|
|
||||||
self.vdb.close()
|
|
||||||
self._clear_gpu_memory()
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
|
|
||||||
def check_system_info():
|
|
||||||
"""检查系统信息"""
|
|
||||||
print("=== 多模态检索系统信息 ===")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("❌ CUDA不可用")
|
|
||||||
return
|
|
||||||
|
|
||||||
gpu_count = torch.cuda.device_count()
|
|
||||||
print(f"✅ 检测到 {gpu_count} 个GPU")
|
|
||||||
print(f"CUDA版本: {torch.version.cuda}")
|
|
||||||
print(f"PyTorch版本: {torch.__version__}")
|
|
||||||
|
|
||||||
for i in range(gpu_count):
|
|
||||||
gpu_name = torch.cuda.get_device_name(i)
|
|
||||||
gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
|
|
||||||
print(f"GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)")
|
|
||||||
|
|
||||||
print("========================")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 检查系统环境
|
|
||||||
check_system_info()
|
|
||||||
|
|
||||||
# 示例使用
|
|
||||||
print("\n正在初始化多模态检索系统...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
retrieval_system = MultimodalRetrievalVDB()
|
|
||||||
print("✅ 系统初始化成功!")
|
|
||||||
|
|
||||||
# 显示统计信息
|
|
||||||
stats = retrieval_system.get_statistics()
|
|
||||||
print(f"\n📊 数据库统计信息: {stats}")
|
|
||||||
|
|
||||||
print("\n🚀 多模态检索系统就绪!")
|
|
||||||
print("支持的检索模式:")
|
|
||||||
print("1. 文搜文: search_text_by_text()")
|
|
||||||
print("2. 文搜图: search_images_by_text()")
|
|
||||||
print("3. 图搜文: search_text_by_image()")
|
|
||||||
print("4. 图搜图: search_images_by_image()")
|
|
||||||
print("5. 存储文本: store_texts()")
|
|
||||||
print("6. 存储图像: store_images()")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 系统初始化失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
@ -1,443 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
纯百度VDB多模态检索系统 - 完全替代FAISS
|
|
||||||
支持文搜文、文搜图、图搜文、图搜图四种检索模式
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
|
||||||
from typing import List, Union, Tuple, Dict, Any
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
import logging
|
|
||||||
import gc
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
||||||
import threading
|
|
||||||
|
|
||||||
from baidu_vdb_production import BaiduVDBProduction
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class MultimodalRetrievalVDBOnly:
|
|
||||||
"""纯百度VDB多模态检索系统,完全替代FAISS"""
|
|
||||||
|
|
||||||
def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
|
|
||||||
use_all_gpus: bool = True, gpu_ids: List[int] = None, min_memory_gb=12):
|
|
||||||
"""
|
|
||||||
初始化纯VDB多模态检索系统
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: 模型名称
|
|
||||||
use_all_gpus: 是否使用所有可用GPU
|
|
||||||
gpu_ids: 指定使用的GPU ID列表
|
|
||||||
min_memory_gb: 最小可用内存(GB)
|
|
||||||
"""
|
|
||||||
self.model_name = model_name
|
|
||||||
|
|
||||||
# 设置GPU设备
|
|
||||||
self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
|
|
||||||
|
|
||||||
# 清理GPU内存
|
|
||||||
self._clear_all_gpu_memory()
|
|
||||||
|
|
||||||
logger.info(f"正在加载模型到多GPU: {self.device_ids}")
|
|
||||||
|
|
||||||
# 加载模型和处理器
|
|
||||||
self.model = None
|
|
||||||
self.tokenizer = None
|
|
||||||
self.processor = None
|
|
||||||
self._load_model_multigpu()
|
|
||||||
|
|
||||||
# 初始化百度VDB后端(替代FAISS索引)
|
|
||||||
logger.info("初始化百度VDB后端...")
|
|
||||||
self.vdb = BaiduVDBProduction()
|
|
||||||
logger.info("✅ 百度VDB后端初始化完成")
|
|
||||||
|
|
||||||
# 线程锁
|
|
||||||
self.model_lock = threading.Lock()
|
|
||||||
|
|
||||||
logger.info("✅ 纯VDB多模态检索系统初始化完成")
|
|
||||||
|
|
||||||
def _setup_devices(self, use_all_gpus, gpu_ids, min_memory_gb):
|
|
||||||
"""设置GPU设备"""
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
raise RuntimeError("CUDA不可用,需要GPU支持")
|
|
||||||
|
|
||||||
total_gpus = torch.cuda.device_count()
|
|
||||||
logger.info(f"检测到 {total_gpus} 个GPU")
|
|
||||||
|
|
||||||
# 获取可用GPU
|
|
||||||
available_gpus = []
|
|
||||||
for i in range(total_gpus):
|
|
||||||
memory_gb = torch.cuda.get_device_properties(i).total_memory / (1024**3)
|
|
||||||
free_memory = torch.cuda.memory_reserved(i) / (1024**3)
|
|
||||||
available_memory = memory_gb - free_memory
|
|
||||||
|
|
||||||
logger.info(f"GPU {i}: {torch.cuda.get_device_properties(i).name} ({memory_gb:.1f}GB)")
|
|
||||||
|
|
||||||
if available_memory >= min_memory_gb:
|
|
||||||
available_gpus.append(i)
|
|
||||||
logger.info(f"GPU {i}: {available_memory:.0f}MB 可用 (合适)")
|
|
||||||
else:
|
|
||||||
logger.info(f"GPU {i}: {available_memory:.0f}MB 可用 (不足)")
|
|
||||||
|
|
||||||
if not available_gpus:
|
|
||||||
raise RuntimeError(f"没有找到满足 {min_memory_gb}GB 内存要求的GPU")
|
|
||||||
|
|
||||||
# 选择使用的GPU
|
|
||||||
if gpu_ids:
|
|
||||||
self.device_ids = [gpu_id for gpu_id in gpu_ids if gpu_id in available_gpus]
|
|
||||||
elif use_all_gpus:
|
|
||||||
self.device_ids = available_gpus
|
|
||||||
else:
|
|
||||||
self.device_ids = [available_gpus[0]]
|
|
||||||
|
|
||||||
if not self.device_ids:
|
|
||||||
raise RuntimeError("没有可用的GPU设备")
|
|
||||||
|
|
||||||
# 设置主设备
|
|
||||||
self.primary_device = f"cuda:{self.device_ids[0]}"
|
|
||||||
torch.cuda.set_device(self.device_ids[0])
|
|
||||||
|
|
||||||
logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")
|
|
||||||
|
|
||||||
def _clear_all_gpu_memory(self):
|
|
||||||
"""清理所有GPU内存"""
|
|
||||||
for device_id in self.device_ids:
|
|
||||||
with torch.cuda.device(device_id):
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
logger.info("所有GPU内存已清理")
|
|
||||||
|
|
||||||
def _load_model_multigpu(self):
|
|
||||||
"""加载模型到多GPU"""
|
|
||||||
try:
|
|
||||||
# 清理GPU内存
|
|
||||||
self._clear_all_gpu_memory()
|
|
||||||
|
|
||||||
logger.info(f"正在加载模型到多GPU: {self.device_ids}")
|
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
self.model = AutoModel.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
trust_remote_code=True,
|
|
||||||
device_map="auto"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 加载tokenizer和processor
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
|
|
||||||
logger.info("Tokenizer加载成功")
|
|
||||||
|
|
||||||
self.processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
|
|
||||||
logger.info("Processor加载成功")
|
|
||||||
|
|
||||||
# 显示设备映射
|
|
||||||
if hasattr(self.model, 'hf_device_map'):
|
|
||||||
logger.info(f"模型已成功加载到设备: {dict(list(self.model.hf_device_map.items())[:10])}")
|
|
||||||
|
|
||||||
self.model.eval()
|
|
||||||
logger.info("多GPU模型加载完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"模型加载失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def encode_text_batch(self, texts: List[str], batch_size: int = 8) -> np.ndarray:
|
|
||||||
"""批量编码文本"""
|
|
||||||
try:
|
|
||||||
with self.model_lock:
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
for i in range(0, len(texts), batch_size):
|
|
||||||
batch_texts = texts[i:i + batch_size]
|
|
||||||
|
|
||||||
# 使用processor处理文本
|
|
||||||
inputs = self.processor(
|
|
||||||
text=batch_texts,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=512
|
|
||||||
)
|
|
||||||
|
|
||||||
# 将输入移动到主设备
|
|
||||||
inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
|
||||||
embeddings = embeddings.cpu().numpy()
|
|
||||||
all_embeddings.append(embeddings)
|
|
||||||
|
|
||||||
return np.vstack(all_embeddings)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"文本编码失败: {e}")
|
|
||||||
return np.zeros((len(texts), 3584), dtype=np.float32)
|
|
||||||
|
|
||||||
def encode_image_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 4) -> np.ndarray:
|
|
||||||
"""批量编码图像"""
|
|
||||||
try:
|
|
||||||
with self.model_lock:
|
|
||||||
processed_images = []
|
|
||||||
|
|
||||||
# 处理图像输入
|
|
||||||
for img in images:
|
|
||||||
if isinstance(img, str):
|
|
||||||
if os.path.exists(img):
|
|
||||||
processed_images.append(Image.open(img).convert('RGB'))
|
|
||||||
else:
|
|
||||||
logger.warning(f"图像文件不存在: {img}")
|
|
||||||
processed_images.append(Image.new('RGB', (224, 224), color='white'))
|
|
||||||
elif isinstance(img, Image.Image):
|
|
||||||
processed_images.append(img.convert('RGB'))
|
|
||||||
else:
|
|
||||||
logger.warning(f"不支持的图像类型: {type(img)}")
|
|
||||||
processed_images.append(Image.new('RGB', (224, 224), color='white'))
|
|
||||||
|
|
||||||
all_embeddings = []
|
|
||||||
|
|
||||||
for i in range(0, len(processed_images), batch_size):
|
|
||||||
batch_images = processed_images[i:i + batch_size]
|
|
||||||
|
|
||||||
# 使用processor处理图像
|
|
||||||
inputs = self.processor(
|
|
||||||
images=batch_images,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 将输入移动到主设备
|
|
||||||
inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.model(**inputs)
|
|
||||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
|
||||||
embeddings = embeddings.cpu().numpy()
|
|
||||||
all_embeddings.append(embeddings)
|
|
||||||
|
|
||||||
return np.vstack(all_embeddings)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图像编码失败: {e}")
|
|
||||||
embedding_dim = 3584
|
|
||||||
embeddings = np.zeros((len(images), embedding_dim), dtype=np.float32)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def build_text_index_parallel(self, texts: List[str], save_path: str = None):
|
|
||||||
"""
|
|
||||||
构建文本索引(使用VDB替代FAISS)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(f"正在构建文本索引,共 {len(texts)} 条文本")
|
|
||||||
|
|
||||||
# 编码文本
|
|
||||||
embeddings = self.encode_text_batch(texts)
|
|
||||||
|
|
||||||
# 使用VDB存储
|
|
||||||
self.vdb.build_text_index(texts, embeddings)
|
|
||||||
|
|
||||||
logger.info("文本索引构建完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"构建文本索引失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def build_image_index_parallel(self, image_paths: List[str], save_path: str = None):
|
|
||||||
"""
|
|
||||||
构建图像索引(使用VDB替代FAISS)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(f"正在构建图像索引,共 {len(image_paths)} 张图像")
|
|
||||||
|
|
||||||
# 编码图像
|
|
||||||
embeddings = self.encode_image_batch(image_paths)
|
|
||||||
|
|
||||||
# 使用VDB存储
|
|
||||||
self.vdb.build_image_index(image_paths, embeddings)
|
|
||||||
|
|
||||||
logger.info("图像索引构建完成")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"构建图像索引失败: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""文搜文:使用文本查询搜索相似文本"""
|
|
||||||
try:
|
|
||||||
query_embedding = self.encode_text_batch([query])
|
|
||||||
return self.vdb.search_text_by_text(query_embedding[0], top_k)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"文搜文失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""文搜图:使用文本查询搜索相似图像"""
|
|
||||||
try:
|
|
||||||
query_embedding = self.encode_text_batch([query])
|
|
||||||
return self.vdb.search_images_by_text(query_embedding[0], top_k)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"文搜图失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""图搜图:使用图像查询搜索相似图像"""
|
|
||||||
try:
|
|
||||||
query_embedding = self.encode_image_batch([query_image])
|
|
||||||
return self.vdb.search_images_by_image(query_embedding[0], top_k)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图搜图失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""图搜文:使用图像查询搜索相似文本"""
|
|
||||||
try:
|
|
||||||
query_embedding = self.encode_image_batch([query_image])
|
|
||||||
return self.vdb.search_text_by_image(query_embedding[0], top_k)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图搜文失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Web应用兼容方法
|
|
||||||
def search_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""文搜文:Web应用兼容方法"""
|
|
||||||
return self.search_text_by_text(query, top_k)
|
|
||||||
|
|
||||||
def search_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""图搜图:Web应用兼容方法"""
|
|
||||||
return self.search_images_by_image(query_image, top_k)
|
|
||||||
|
|
||||||
def search_images_by_text_query(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""文搜图:Web应用兼容方法"""
|
|
||||||
return self.search_images_by_text(query, top_k)
|
|
||||||
|
|
||||||
def search_texts_by_image_query(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
|
||||||
"""图搜文:Web应用兼容方法"""
|
|
||||||
return self.search_text_by_image(query_image, top_k)
|
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
|
||||||
"""获取系统统计信息"""
|
|
||||||
try:
|
|
||||||
vdb_stats = self.vdb.get_statistics()
|
|
||||||
|
|
||||||
stats = {
|
|
||||||
"model_name": self.model_name,
|
|
||||||
"device_ids": self.device_ids,
|
|
||||||
"primary_device": self.primary_device,
|
|
||||||
"backend": "Baidu VDB (No FAISS)",
|
|
||||||
**vdb_stats
|
|
||||||
}
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取统计信息失败: {e}")
|
|
||||||
return {"status": "error", "error": str(e)}
|
|
||||||
|
|
||||||
def clear_all_data(self):
|
|
||||||
"""清空所有数据"""
|
|
||||||
try:
|
|
||||||
self.vdb.clear_all_data()
|
|
||||||
logger.info("✅ 所有数据已清空")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 清空数据失败: {e}")
|
|
||||||
|
|
||||||
def get_gpu_memory_info(self):
|
|
||||||
"""获取所有GPU内存使用信息"""
|
|
||||||
memory_info = {}
|
|
||||||
for device_id in self.device_ids:
|
|
||||||
with torch.cuda.device(device_id):
|
|
||||||
allocated = torch.cuda.memory_allocated() / (1024**3)
|
|
||||||
reserved = torch.cuda.memory_reserved() / (1024**3)
|
|
||||||
total = torch.cuda.get_device_properties(device_id).total_memory / (1024**3)
|
|
||||||
|
|
||||||
memory_info[f"GPU_{device_id}"] = {
|
|
||||||
"allocated_GB": round(allocated, 2),
|
|
||||||
"reserved_GB": round(reserved, 2),
|
|
||||||
"total_GB": round(total, 2),
|
|
||||||
"free_GB": round(total - reserved, 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
return memory_info
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""清理资源"""
|
|
||||||
try:
|
|
||||||
if self.vdb:
|
|
||||||
self.vdb.close()
|
|
||||||
|
|
||||||
self._clear_all_gpu_memory()
|
|
||||||
logger.info("✅ 资源清理完成")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 资源清理失败: {e}")
|
|
||||||
|
|
||||||
def test_vdb_only_system():
|
|
||||||
"""测试纯VDB多模态检索系统"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("测试纯百度VDB多模态检索系统")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
system = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 初始化系统
|
|
||||||
print("1. 初始化纯VDB多模态检索系统...")
|
|
||||||
system = MultimodalRetrievalVDBOnly()
|
|
||||||
print("✅ 系统初始化成功")
|
|
||||||
|
|
||||||
# 2. 构建文本索引
|
|
||||||
print("\n2. 构建文本索引...")
|
|
||||||
test_texts = [
|
|
||||||
"人工智能技术的发展趋势",
|
|
||||||
"机器学习在医疗领域的应用",
|
|
||||||
"深度学习算法优化方法",
|
|
||||||
"计算机视觉技术创新",
|
|
||||||
"自然语言处理最新进展"
|
|
||||||
]
|
|
||||||
|
|
||||||
system.build_text_index_parallel(test_texts)
|
|
||||||
print("✅ 文本索引构建完成")
|
|
||||||
|
|
||||||
# 3. 测试文搜文
|
|
||||||
print("\n3. 测试文搜文...")
|
|
||||||
query = "AI技术"
|
|
||||||
results = system.search_text_by_text(query, top_k=3)
|
|
||||||
print(f"查询: {query}")
|
|
||||||
for i, (text, score) in enumerate(results, 1):
|
|
||||||
print(f" {i}. {text} (相似度: {score:.3f})")
|
|
||||||
|
|
||||||
# 4. 获取统计信息
|
|
||||||
print("\n4. 获取统计信息...")
|
|
||||||
stats = system.get_statistics()
|
|
||||||
print("系统统计:")
|
|
||||||
for key, value in stats.items():
|
|
||||||
print(f" {key}: {value}")
|
|
||||||
|
|
||||||
print(f"\n🎉 纯VDB系统测试完成!")
|
|
||||||
print("✅ 完全移除FAISS依赖")
|
|
||||||
print("✅ 使用百度VDB作为向量数据库")
|
|
||||||
print("✅ 支持多模态检索功能")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 测试失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if system:
|
|
||||||
system.cleanup()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_vdb_only_system()
|
|
||||||
49
nohup.out
49
nohup.out
@ -1,49 +0,0 @@
|
|||||||
INFO:baidu_bos_manager:✅ BOS连接测试成功
|
|
||||||
INFO:baidu_bos_manager:✅ BOS客户端初始化成功: dmtyz-demo
|
|
||||||
INFO:mongodb_manager:✅ MongoDB连接成功: mmeb
|
|
||||||
INFO:mongodb_manager:✅ MongoDB索引创建完成
|
|
||||||
INFO:__main__:初始化多模态检索系统...
|
|
||||||
INFO:multimodal_retrieval_local:使用GPU: [0, 1]
|
|
||||||
INFO:multimodal_retrieval_local:加载本地模型和处理器: /root/models/Ops-MM-embedding-v1-7B
|
|
||||||
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
|
|
||||||
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
|
|
||||||
INFO:multimodal_retrieval_local:Processor类型: <class 'transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor'>
|
|
||||||
INFO:multimodal_retrieval_local:Processor方法: ['__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_auto_class', '_check_special_mm_tokens', '_create_repo', '_get_arguments_from_pretrained', '_get_files_timestamps', '_get_num_multimodal_tokens', '_merge_kwargs', '_upload_modified_files', 'apply_chat_template', 'attributes', 'audio_tokenizer', 'batch_decode', 'chat_template', 'check_argument_for_proper_class', 'decode', 'feature_extractor_class', 'from_args_and_dict', 'from_pretrained', 'get_possibly_dynamic_module', 'get_processor_dict', 'image_processor', 'image_processor_class', 'image_token', 'image_token_id', 'model_input_names', 'optional_attributes', 'optional_call_args', 'post_process_image_text_to_text', 'push_to_hub', 'register_for_auto_class', 'save_pretrained', 'to_dict', 'to_json_file', 'to_json_string', 'tokenizer', 'tokenizer_class', 'validate_init_kwargs', 'video_processor', 'video_processor_class', 'video_token', 'video_token_id']
|
|
||||||
INFO:multimodal_retrieval_local:Image processor类型: <class 'transformers.models.qwen2_vl.image_processing_qwen2_vl_fast.Qwen2VLImageProcessorFast'>
|
|
||||||
INFO:multimodal_retrieval_local:Image processor方法: ['__backends', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slotnames__', '__str__', '__subclasshook__', '__weakref__', '_auto_class', '_create_repo', '_further_process_kwargs', '_fuse_mean_std_and_rescale_factor', '_get_files_timestamps', '_prepare_image_like_inputs', '_prepare_images_structure', '_preprocess', '_preprocess_image_like_inputs', '_process_image', '_processor_class', '_set_processor_class', '_upload_modified_files', '_valid_kwargs_names', '_validate_preprocess_kwargs', 'center_crop', 'compile_friendly_resize', 'convert_to_rgb', 'crop_size', 'data_format', 'default_to_square', 'device', 'disable_grouping', 'do_center_crop', 'do_convert_rgb', 'do_normalize', 'do_rescale', 'do_resize', 'fetch_images', 'filter_out_unused_kwargs', 'from_dict', 'from_json_file', 'from_pretrained', 'get_image_processor_dict', 'get_number_of_image_patches', 'image_mean', 'image_processor_type', 'image_std', 'input_data_format', 'max_pixels', 'merge_size', 'min_pixels', 'model_input_names', 'normalize', 'patch_size', 'preprocess', 'push_to_hub', 'register_for_auto_class', 'resample', 'rescale', 'rescale_and_normalize', 'rescale_factor', 'resize', 'return_tensors', 'save_pretrained', 'size', 'temporal_patch_size', 'to_dict', 'to_json_file', 'to_json_string', 'unused_kwargs', 'valid_kwargs']
|
|
||||||
Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 25%|██▌ | 1/4 [03:03<09:10, 183.40s/it]
Loading checkpoint shards: 50%|█████ | 2/4 [04:55<04:43, 141.63s/it]
Loading checkpoint shards: 75%|███████▌ | 3/4 [06:56<02:12, 132.26s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [07:13<00:00, 86.72s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [07:13<00:00, 108.47s/it]
|
|
||||||
INFO:multimodal_retrieval_local:向量维度: 3584
|
|
||||||
INFO:multimodal_retrieval_local:模型和处理器加载成功
|
|
||||||
INFO:multimodal_retrieval_local:加载现有索引: /root/mmeb/local_faiss_index.index
|
|
||||||
INFO:multimodal_retrieval_local:索引加载成功,包含0个向量
|
|
||||||
INFO:multimodal_retrieval_local:元数据加载成功,包含0条记录
|
|
||||||
INFO:multimodal_retrieval_local:多模态检索系统初始化完成,使用本地模型: /root/models/Ops-MM-embedding-v1-7B
|
|
||||||
INFO:multimodal_retrieval_local:向量存储路径: /root/mmeb/local_faiss_index
|
|
||||||
INFO:__main__:多模态检索系统初始化完成
|
|
||||||
* Serving Flask app 'web_app_local'
|
|
||||||
* Debug mode: off
|
|
||||||
INFO:werkzeug:[31m[1mWARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.[0m
|
|
||||||
* Running on all addresses (0.0.0.0)
|
|
||||||
* Running on http://127.0.0.1:5000
|
|
||||||
* Running on http://192.168.48.82:5000
|
|
||||||
INFO:werkzeug:[33mPress CTRL+C to quit[0m
|
|
||||||
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:23] "GET / HTTP/1.1" 200 -
|
|
||||||
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:23] "GET /api/system_info HTTP/1.1" 200 -
|
|
||||||
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:24] "GET /api/system_info HTTP/1.1" 200 -
|
|
||||||
INFO:__main__:处理图像: 微信图片_20250910164839_1_13.jpg (99396 字节)
|
|
||||||
INFO:__main__:成功加载图像: 20250910164839_1_13.jpg, 格式: JPEG, 模式: RGB, 大小: (939, 940)
|
|
||||||
INFO:multimodal_retrieval_local:add_images: 开始添加图像,数量: 1
|
|
||||||
INFO:multimodal_retrieval_local:add_images: 编码图像
|
|
||||||
INFO:multimodal_retrieval_local:encode_image: 开始编码图像,类型: <class 'list'>
|
|
||||||
INFO:multimodal_retrieval_local:encode_image: 图像列表,长度: 1
|
|
||||||
INFO:multimodal_retrieval_local:encode_image: 处理图像输入
|
|
||||||
INFO:multimodal_retrieval_local:encode_image: 图像 0 格式: JPEG, 模式: RGB, 大小: (939, 940)
|
|
||||||
ERROR:multimodal_retrieval_local:encode_image: 处理图像时出错: argument of type 'NoneType' is not iterable
|
|
||||||
ERROR:multimodal_retrieval_local:add_images: 图像编码失败,返回空数组
|
|
||||||
INFO:multimodal_retrieval_local:索引保存成功: /root/mmeb/local_faiss_index.index
|
|
||||||
INFO:multimodal_retrieval_local:元数据保存成功: /root/mmeb/local_faiss_index_metadata.json
|
|
||||||
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:50] "POST /api/add_image HTTP/1.1" 200 -
|
|
||||||
INFO:multimodal_retrieval_local:索引保存成功: /root/mmeb/local_faiss_index.index
|
|
||||||
INFO:multimodal_retrieval_local:元数据保存成功: /root/mmeb/local_faiss_index_metadata.json
|
|
||||||
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:50] "POST /api/save_index HTTP/1.1" 200 -
|
|
||||||
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:51] "GET /api/system_info HTTP/1.1" 200 -
|
|
||||||
235
quick_test.py
235
quick_test.py
@ -1,235 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
快速测试脚本 - 验证多模态检索系统功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
import traceback
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_imports():
|
|
||||||
"""测试关键模块导入"""
|
|
||||||
logger.info("🔍 测试模块导入...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
logger.info(f"✅ PyTorch {torch.__version__}")
|
|
||||||
|
|
||||||
import transformers
|
|
||||||
logger.info(f"✅ Transformers {transformers.__version__}")
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
logger.info(f"✅ NumPy {np.__version__}")
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
logger.info("✅ Pillow")
|
|
||||||
|
|
||||||
import flask
|
|
||||||
logger.info(f"✅ Flask {flask.__version__}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import pymochow
|
|
||||||
logger.info("✅ PyMochow (百度VDB SDK)")
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("⚠️ PyMochow 未安装,需要运行: pip install pymochow")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import pymongo
|
|
||||||
logger.info("✅ PyMongo")
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("⚠️ PyMongo 未安装,需要运行: pip install pymongo")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 模块导入失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_gpu_availability():
|
|
||||||
"""测试GPU可用性"""
|
|
||||||
logger.info("🖥️ 检查GPU环境...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
gpu_count = torch.cuda.device_count()
|
|
||||||
logger.info(f"✅ 检测到 {gpu_count} 个GPU")
|
|
||||||
|
|
||||||
for i in range(gpu_count):
|
|
||||||
gpu_name = torch.cuda.get_device_name(i)
|
|
||||||
gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
|
|
||||||
logger.info(f" GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)")
|
|
||||||
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.info("ℹ️ 未检测到GPU,将使用CPU")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ GPU检查失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_baidu_vdb_connection():
|
|
||||||
"""测试百度VDB连接"""
|
|
||||||
logger.info("🔗 测试百度VDB连接...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
|
|
||||||
# 连接配置
|
|
||||||
account = "root"
|
|
||||||
api_key = "vdb$yjr9ln3n0td"
|
|
||||||
endpoint = "http://180.76.96.191:5287"
|
|
||||||
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(account, api_key),
|
|
||||||
endpoint=endpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
client = pymochow.MochowClient(config)
|
|
||||||
|
|
||||||
# 测试连接 - 列出数据库
|
|
||||||
databases = client.list_databases()
|
|
||||||
logger.info(f"✅ VDB连接成功,发现 {len(databases)} 个数据库")
|
|
||||||
|
|
||||||
client.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
logger.error("❌ PyMochow 未安装,无法测试VDB连接")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB连接失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_model_loading():
|
|
||||||
"""测试模型加载"""
|
|
||||||
logger.info("🤖 测试模型加载...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
|
|
||||||
|
|
||||||
logger.info("正在初始化模型...")
|
|
||||||
model = OpsMMEmbeddingV1()
|
|
||||||
|
|
||||||
# 测试文本编码
|
|
||||||
test_texts = ["测试文本"]
|
|
||||||
embeddings = model.embed(texts=test_texts)
|
|
||||||
|
|
||||||
logger.info(f"✅ 模型加载成功,向量维度: {embeddings.shape}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 模型加载失败: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_web_app_import():
|
|
||||||
"""测试Web应用导入"""
|
|
||||||
logger.info("🌐 测试Web应用模块...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 测试导入主要模块
|
|
||||||
from multimodal_retrieval_vdb_only import MultimodalRetrievalVDBOnly
|
|
||||||
logger.info("✅ 多模态检索系统模块")
|
|
||||||
|
|
||||||
from baidu_vdb_production import BaiduVDBProduction
|
|
||||||
logger.info("✅ 百度VDB后端模块")
|
|
||||||
|
|
||||||
# 测试Web应用文件存在
|
|
||||||
web_app_file = Path("web_app_vdb_production.py")
|
|
||||||
if web_app_file.exists():
|
|
||||||
logger.info("✅ Web应用文件存在")
|
|
||||||
else:
|
|
||||||
logger.error("❌ Web应用文件不存在")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ Web应用模块测试失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def create_test_directories():
|
|
||||||
"""创建必要的测试目录"""
|
|
||||||
logger.info("📁 创建测试目录...")
|
|
||||||
|
|
||||||
directories = ["uploads", "sample_images", "text_data"]
|
|
||||||
|
|
||||||
for dir_name in directories:
|
|
||||||
dir_path = Path(dir_name)
|
|
||||||
dir_path.mkdir(exist_ok=True)
|
|
||||||
logger.info(f"✅ 目录已创建: {dir_name}")
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主测试函数"""
|
|
||||||
logger.info("🚀 开始快速测试...")
|
|
||||||
logger.info("=" * 50)
|
|
||||||
|
|
||||||
test_results = {}
|
|
||||||
|
|
||||||
# 1. 测试模块导入
|
|
||||||
test_results["imports"] = test_imports()
|
|
||||||
|
|
||||||
# 2. 测试GPU环境
|
|
||||||
test_results["gpu"] = test_gpu_availability()
|
|
||||||
|
|
||||||
# 3. 测试VDB连接
|
|
||||||
test_results["vdb"] = test_baidu_vdb_connection()
|
|
||||||
|
|
||||||
# 4. 测试Web应用模块
|
|
||||||
test_results["web_modules"] = test_web_app_import()
|
|
||||||
|
|
||||||
# 5. 创建测试目录
|
|
||||||
create_test_directories()
|
|
||||||
|
|
||||||
# 6. 尝试测试模型加载(可选)
|
|
||||||
if test_results["imports"]:
|
|
||||||
logger.info("\n⚠️ 模型加载测试需要较长时间,是否跳过?")
|
|
||||||
logger.info("如需测试模型,请单独运行模型测试")
|
|
||||||
# test_results["model"] = test_model_loading()
|
|
||||||
|
|
||||||
# 输出测试结果
|
|
||||||
logger.info("\n" + "=" * 50)
|
|
||||||
logger.info("📊 测试结果汇总:")
|
|
||||||
logger.info("=" * 50)
|
|
||||||
|
|
||||||
for test_name, result in test_results.items():
|
|
||||||
status = "✅ 通过" if result else "❌ 失败"
|
|
||||||
test_display = {
|
|
||||||
"imports": "模块导入",
|
|
||||||
"gpu": "GPU环境",
|
|
||||||
"vdb": "VDB连接",
|
|
||||||
"web_modules": "Web模块",
|
|
||||||
"model": "模型加载"
|
|
||||||
}.get(test_name, test_name)
|
|
||||||
|
|
||||||
logger.info(f"{test_display}: {status}")
|
|
||||||
|
|
||||||
# 计算成功率
|
|
||||||
success_count = sum(test_results.values())
|
|
||||||
total_count = len(test_results)
|
|
||||||
success_rate = (success_count / total_count) * 100
|
|
||||||
|
|
||||||
logger.info(f"\n总体成功率: {success_count}/{total_count} ({success_rate:.1f}%)")
|
|
||||||
|
|
||||||
if success_rate >= 75:
|
|
||||||
logger.info("🎉 系统基本就绪!可以启动Web应用进行完整测试")
|
|
||||||
logger.info("运行命令: python web_app_vdb_production.py")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ 系统存在问题,请检查失败的测试项")
|
|
||||||
|
|
||||||
return test_results
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@ -5,10 +5,7 @@ accelerate>=0.20.0
|
|||||||
faiss-cpu>=1.7.4
|
faiss-cpu>=1.7.4
|
||||||
numpy>=1.21.0
|
numpy>=1.21.0
|
||||||
Pillow>=9.0.0
|
Pillow>=9.0.0
|
||||||
scikit-learn>=1.3.0
|
|
||||||
tqdm>=4.65.0
|
tqdm>=4.65.0
|
||||||
flask>=2.3.0
|
flask>=2.3.0
|
||||||
werkzeug>=2.3.0
|
werkzeug>=2.3.0
|
||||||
psutil>=5.9.0
|
requests>=2.31.0
|
||||||
pymockow>=1.0.0
|
|
||||||
pymongo>=4.0.0
|
|
||||||
|
|||||||
152
run_tests.py
152
run_tests.py
@ -1,152 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
运行系统测试 - 验证多模态检索系统功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
import traceback
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_imports():
|
|
||||||
"""测试关键模块导入"""
|
|
||||||
logger.info("🔍 测试模块导入...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
logger.info(f"✅ PyTorch {torch.__version__}")
|
|
||||||
|
|
||||||
import transformers
|
|
||||||
logger.info(f"✅ Transformers {transformers.__version__}")
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
logger.info(f"✅ NumPy {np.__version__}")
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
logger.info("✅ Pillow")
|
|
||||||
|
|
||||||
import flask
|
|
||||||
logger.info(f"✅ Flask {flask.__version__}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import pymochow
|
|
||||||
logger.info("✅ PyMochow (百度VDB SDK)")
|
|
||||||
return True
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("⚠️ PyMochow 未安装")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 模块导入失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_baidu_vdb_connection():
|
|
||||||
"""测试百度VDB连接"""
|
|
||||||
logger.info("🔗 测试百度VDB连接...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
|
|
||||||
# 连接配置
|
|
||||||
account = "root"
|
|
||||||
api_key = "vdb$yjr9ln3n0td"
|
|
||||||
endpoint = "http://180.76.96.191:5287"
|
|
||||||
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(account, api_key),
|
|
||||||
endpoint=endpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
client = pymochow.MochowClient(config)
|
|
||||||
|
|
||||||
# 测试连接
|
|
||||||
databases = client.list_databases()
|
|
||||||
logger.info(f"✅ VDB连接成功,发现 {len(databases)} 个数据库")
|
|
||||||
|
|
||||||
client.close()
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB连接失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_system_modules():
|
|
||||||
"""测试系统模块"""
|
|
||||||
logger.info("🔧 测试系统模块...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from multimodal_retrieval_vdb_only import MultimodalRetrievalVDBOnly
|
|
||||||
logger.info("✅ 多模态检索系统")
|
|
||||||
|
|
||||||
from baidu_vdb_production import BaiduVDBProduction
|
|
||||||
logger.info("✅ 百度VDB后端")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 系统模块测试失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def create_directories():
|
|
||||||
"""创建必要目录"""
|
|
||||||
logger.info("📁 创建必要目录...")
|
|
||||||
|
|
||||||
directories = ["uploads", "sample_images", "text_data"]
|
|
||||||
|
|
||||||
for dir_name in directories:
|
|
||||||
dir_path = Path(dir_name)
|
|
||||||
dir_path.mkdir(exist_ok=True)
|
|
||||||
logger.info(f"✅ 目录: {dir_name}")
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主测试函数"""
|
|
||||||
logger.info("🚀 开始系统测试...")
|
|
||||||
logger.info("=" * 50)
|
|
||||||
|
|
||||||
# 创建目录
|
|
||||||
create_directories()
|
|
||||||
|
|
||||||
# 运行测试
|
|
||||||
results = {}
|
|
||||||
results["imports"] = test_imports()
|
|
||||||
|
|
||||||
if results["imports"]:
|
|
||||||
results["vdb"] = test_baidu_vdb_connection()
|
|
||||||
results["modules"] = test_system_modules()
|
|
||||||
else:
|
|
||||||
logger.error("❌ 基础模块导入失败,跳过其他测试")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 输出结果
|
|
||||||
logger.info("\n" + "=" * 50)
|
|
||||||
logger.info("📊 测试结果:")
|
|
||||||
logger.info("=" * 50)
|
|
||||||
|
|
||||||
for test_name, result in results.items():
|
|
||||||
status = "✅ 通过" if result else "❌ 失败"
|
|
||||||
logger.info(f"{test_name}: {status}")
|
|
||||||
|
|
||||||
success_count = sum(results.values())
|
|
||||||
total_count = len(results)
|
|
||||||
success_rate = (success_count / total_count) * 100
|
|
||||||
|
|
||||||
logger.info(f"\n成功率: {success_count}/{total_count} ({success_rate:.1f}%)")
|
|
||||||
|
|
||||||
if success_rate >= 75:
|
|
||||||
logger.info("🎉 系统测试通过!")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ 系统存在问题")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = main()
|
|
||||||
sys.exit(0 if success else 1)
|
|
||||||
@ -1,91 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
后台启动Web服务器脚本
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import subprocess
|
|
||||||
import signal
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def start_web_server():
|
|
||||||
"""在后台启动Web服务器"""
|
|
||||||
try:
|
|
||||||
logger.info("🚀 启动优化版Web服务器...")
|
|
||||||
|
|
||||||
# 启动Web应用进程
|
|
||||||
process = subprocess.Popen([
|
|
||||||
sys.executable, 'web_app_vdb_production.py'
|
|
||||||
], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
|
||||||
|
|
||||||
logger.info(f"✅ Web服务器已启动,PID: {process.pid}")
|
|
||||||
logger.info("🌐 服务地址: http://127.0.0.1:5000")
|
|
||||||
|
|
||||||
# 等待几秒让服务器完全启动
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
return process
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 启动Web服务器失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def stop_web_server(process):
|
|
||||||
"""停止Web服务器"""
|
|
||||||
if process:
|
|
||||||
try:
|
|
||||||
process.terminate()
|
|
||||||
process.wait(timeout=5)
|
|
||||||
logger.info("✅ Web服务器已停止")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
process.kill()
|
|
||||||
logger.info("🔥 强制停止Web服务器")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 停止Web服务器失败: {e}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Web服务器管理')
|
|
||||||
parser.add_argument('action', choices=['start', 'test'],
|
|
||||||
help='操作: start(启动服务器) 或 test(启动并运行测试)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.action == 'start':
|
|
||||||
# 只启动服务器
|
|
||||||
process = start_web_server()
|
|
||||||
if process:
|
|
||||||
try:
|
|
||||||
logger.info("按 Ctrl+C 停止服务器")
|
|
||||||
process.wait()
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("🛑 用户停止服务")
|
|
||||||
stop_web_server(process)
|
|
||||||
|
|
||||||
elif args.action == 'test':
|
|
||||||
# 启动服务器并运行测试
|
|
||||||
process = start_web_server()
|
|
||||||
if process:
|
|
||||||
try:
|
|
||||||
# 运行测试
|
|
||||||
logger.info("🧪 运行优化系统测试...")
|
|
||||||
test_result = subprocess.run([
|
|
||||||
sys.executable, 'test_optimized_system.py'
|
|
||||||
], capture_output=True, text=True)
|
|
||||||
|
|
||||||
print(test_result.stdout)
|
|
||||||
if test_result.stderr:
|
|
||||||
print("STDERR:", test_result.stderr)
|
|
||||||
|
|
||||||
logger.info(f"测试完成,退出码: {test_result.returncode}")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# 停止服务器
|
|
||||||
stop_web_server(process)
|
|
||||||
@ -1,32 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# 启动多模态检索系统测试
|
|
||||||
|
|
||||||
echo "🚀 启动多模态检索系统测试"
|
|
||||||
echo "================================"
|
|
||||||
|
|
||||||
# 设置Python路径
|
|
||||||
export PYTHONPATH=/root/mmeb:$PYTHONPATH
|
|
||||||
|
|
||||||
# 1. 安装依赖包
|
|
||||||
echo "📦 步骤1: 安装依赖包"
|
|
||||||
pip install pymochow pymongo --quiet
|
|
||||||
|
|
||||||
# 2. 运行快速测试
|
|
||||||
echo "🔍 步骤2: 运行快速测试"
|
|
||||||
python quick_test.py
|
|
||||||
|
|
||||||
# 3. 测试百度VDB连接
|
|
||||||
echo "🔗 步骤3: 测试百度VDB连接"
|
|
||||||
python test_baidu_vdb_connection.py
|
|
||||||
|
|
||||||
# 4. 启动Web应用(可选)
|
|
||||||
echo "🌐 步骤4: 是否启动Web应用?(y/n)"
|
|
||||||
read -p "输入选择: " choice
|
|
||||||
if [ "$choice" = "y" ] || [ "$choice" = "Y" ]; then
|
|
||||||
echo "启动Web应用..."
|
|
||||||
python web_app_vdb_production.py
|
|
||||||
else
|
|
||||||
echo "跳过Web应用启动"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "✅ 测试完成!"
|
|
||||||
130
start_web_app.py
130
start_web_app.py
@ -1,130 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
启动Web应用测试脚本
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def check_dependencies():
|
|
||||||
"""检查依赖包"""
|
|
||||||
logger.info("📦 检查依赖包...")
|
|
||||||
|
|
||||||
required_packages = [
|
|
||||||
'torch', 'transformers', 'numpy', 'PIL', 'flask', 'pymochow'
|
|
||||||
]
|
|
||||||
|
|
||||||
missing_packages = []
|
|
||||||
for package in required_packages:
|
|
||||||
try:
|
|
||||||
if package == 'PIL':
|
|
||||||
from PIL import Image
|
|
||||||
else:
|
|
||||||
__import__(package)
|
|
||||||
logger.info(f"✅ {package}")
|
|
||||||
except ImportError:
|
|
||||||
missing_packages.append(package)
|
|
||||||
logger.error(f"❌ {package} 未安装")
|
|
||||||
|
|
||||||
if missing_packages:
|
|
||||||
logger.info("安装缺失的包:")
|
|
||||||
for pkg in missing_packages:
|
|
||||||
if pkg == 'PIL':
|
|
||||||
logger.info("pip install Pillow")
|
|
||||||
elif pkg == 'pymochow':
|
|
||||||
logger.info("pip install pymochow")
|
|
||||||
else:
|
|
||||||
logger.info(f"pip install {pkg}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def test_vdb_connection():
|
|
||||||
"""测试VDB连接"""
|
|
||||||
logger.info("🔗 测试百度VDB连接...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials("root", "vdb$yjr9ln3n0td"),
|
|
||||||
endpoint="http://180.76.96.191:5287"
|
|
||||||
)
|
|
||||||
|
|
||||||
client = pymochow.MochowClient(config)
|
|
||||||
databases = client.list_databases()
|
|
||||||
client.close()
|
|
||||||
|
|
||||||
logger.info(f"✅ VDB连接成功,发现 {len(databases)} 个数据库")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB连接失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def prepare_directories():
|
|
||||||
"""准备必要目录"""
|
|
||||||
logger.info("📁 准备目录...")
|
|
||||||
|
|
||||||
directories = ["uploads", "sample_images", "text_data", "templates"]
|
|
||||||
|
|
||||||
for dir_name in directories:
|
|
||||||
Path(dir_name).mkdir(exist_ok=True)
|
|
||||||
logger.info(f"✅ {dir_name}")
|
|
||||||
|
|
||||||
def start_web_app():
|
|
||||||
"""启动Web应用"""
|
|
||||||
logger.info("🌐 启动Web应用...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 设置环境变量
|
|
||||||
os.environ['FLASK_APP'] = 'web_app_vdb_production.py'
|
|
||||||
os.environ['FLASK_ENV'] = 'development'
|
|
||||||
|
|
||||||
# 启动Flask应用
|
|
||||||
logger.info("启动地址: http://localhost:5000")
|
|
||||||
logger.info("按 Ctrl+C 停止服务")
|
|
||||||
|
|
||||||
# 直接运行Python文件
|
|
||||||
subprocess.run([sys.executable, 'web_app_vdb_production.py'], check=True)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("🛑 用户停止服务")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ Web应用启动失败: {e}")
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主函数"""
|
|
||||||
logger.info("🚀 启动多模态检索系统Web应用")
|
|
||||||
logger.info("=" * 50)
|
|
||||||
|
|
||||||
# 1. 检查依赖
|
|
||||||
if not check_dependencies():
|
|
||||||
logger.error("❌ 依赖包检查失败,请先安装缺失的包")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 2. 测试VDB连接
|
|
||||||
if not test_vdb_connection():
|
|
||||||
logger.error("❌ VDB连接失败,请检查网络和配置")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 3. 准备目录
|
|
||||||
prepare_directories()
|
|
||||||
|
|
||||||
# 4. 启动Web应用
|
|
||||||
start_web_app()
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
BIN
static/favicon.ico
Normal file
BIN
static/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 785 KiB |
@ -1,971 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="zh-CN">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>多模态检索系统</title>
|
|
||||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
|
||||||
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
|
|
||||||
<style>
|
|
||||||
:root {
|
|
||||||
--primary-color: #2563eb;
|
|
||||||
--secondary-color: #64748b;
|
|
||||||
--success-color: #059669;
|
|
||||||
--warning-color: #d97706;
|
|
||||||
--danger-color: #dc2626;
|
|
||||||
--bg-gradient: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
|
||||||
}
|
|
||||||
|
|
||||||
body {
|
|
||||||
background: var(--bg-gradient);
|
|
||||||
min-height: 100vh;
|
|
||||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
|
||||||
}
|
|
||||||
|
|
||||||
.main-container {
|
|
||||||
background: rgba(255, 255, 255, 0.95);
|
|
||||||
backdrop-filter: blur(10px);
|
|
||||||
border-radius: 20px;
|
|
||||||
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
|
|
||||||
margin: 20px auto;
|
|
||||||
max-width: 1200px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.header {
|
|
||||||
background: linear-gradient(135deg, var(--primary-color), #3b82f6);
|
|
||||||
color: white;
|
|
||||||
padding: 2rem;
|
|
||||||
border-radius: 20px 20px 0 0;
|
|
||||||
text-align: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
.mode-card {
|
|
||||||
background: white;
|
|
||||||
border-radius: 15px;
|
|
||||||
padding: 1.5rem;
|
|
||||||
margin-bottom: 1.5rem;
|
|
||||||
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
|
|
||||||
transition: all 0.3s ease;
|
|
||||||
border: 2px solid transparent;
|
|
||||||
}
|
|
||||||
|
|
||||||
.mode-card:hover {
|
|
||||||
transform: translateY(-5px);
|
|
||||||
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.15);
|
|
||||||
}
|
|
||||||
|
|
||||||
.mode-card.active {
|
|
||||||
border-color: var(--primary-color);
|
|
||||||
background: linear-gradient(135deg, #eff6ff, #dbeafe);
|
|
||||||
}
|
|
||||||
|
|
||||||
.mode-icon {
|
|
||||||
font-size: 2.5rem;
|
|
||||||
margin-bottom: 1rem;
|
|
||||||
display: block;
|
|
||||||
}
|
|
||||||
|
|
||||||
.text-to-text { color: #059669; }
|
|
||||||
.text-to-image { color: #dc2626; }
|
|
||||||
.image-to-text { color: #d97706; }
|
|
||||||
.image-to-image { color: #7c3aed; }
|
|
||||||
|
|
||||||
.search-input {
|
|
||||||
border-radius: 12px;
|
|
||||||
border: 2px solid #e5e7eb;
|
|
||||||
padding: 12px 16px;
|
|
||||||
font-size: 16px;
|
|
||||||
transition: all 0.3s ease;
|
|
||||||
}
|
|
||||||
|
|
||||||
.search-input:focus {
|
|
||||||
border-color: var(--primary-color);
|
|
||||||
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn-primary {
|
|
||||||
background: var(--primary-color);
|
|
||||||
border: none;
|
|
||||||
border-radius: 12px;
|
|
||||||
padding: 12px 24px;
|
|
||||||
font-weight: 600;
|
|
||||||
transition: all 0.3s ease;
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn-primary:hover {
|
|
||||||
background: #1d4ed8;
|
|
||||||
transform: translateY(-2px);
|
|
||||||
}
|
|
||||||
|
|
||||||
.file-upload-area {
|
|
||||||
border: 3px dashed #d1d5db;
|
|
||||||
border-radius: 12px;
|
|
||||||
padding: 3rem;
|
|
||||||
text-align: center;
|
|
||||||
transition: all 0.3s ease;
|
|
||||||
cursor: pointer;
|
|
||||||
}
|
|
||||||
|
|
||||||
.file-upload-area:hover {
|
|
||||||
border-color: var(--primary-color);
|
|
||||||
background: rgba(37, 99, 235, 0.05);
|
|
||||||
}
|
|
||||||
|
|
||||||
.file-upload-area.dragover {
|
|
||||||
border-color: var(--primary-color);
|
|
||||||
background: rgba(37, 99, 235, 0.1);
|
|
||||||
}
|
|
||||||
|
|
||||||
.result-card {
|
|
||||||
background: white;
|
|
||||||
border-radius: 12px;
|
|
||||||
padding: 1.5rem;
|
|
||||||
margin-bottom: 1rem;
|
|
||||||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
|
||||||
border-left: 4px solid var(--primary-color);
|
|
||||||
}
|
|
||||||
|
|
||||||
.result-image {
|
|
||||||
max-width: 200px;
|
|
||||||
max-height: 150px;
|
|
||||||
border-radius: 8px;
|
|
||||||
object-fit: cover;
|
|
||||||
}
|
|
||||||
|
|
||||||
.score-badge {
|
|
||||||
background: var(--success-color);
|
|
||||||
color: white;
|
|
||||||
padding: 4px 12px;
|
|
||||||
border-radius: 20px;
|
|
||||||
font-size: 0.85rem;
|
|
||||||
font-weight: 600;
|
|
||||||
}
|
|
||||||
|
|
||||||
.loading-spinner {
|
|
||||||
display: none;
|
|
||||||
text-align: center;
|
|
||||||
padding: 2rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.status-indicator {
|
|
||||||
position: fixed;
|
|
||||||
top: 20px;
|
|
||||||
right: 20px;
|
|
||||||
z-index: 1000;
|
|
||||||
}
|
|
||||||
|
|
||||||
.fade-in {
|
|
||||||
animation: fadeIn 0.5s ease-in;
|
|
||||||
}
|
|
||||||
|
|
||||||
@keyframes fadeIn {
|
|
||||||
from { opacity: 0; transform: translateY(20px); }
|
|
||||||
to { opacity: 1; transform: translateY(0); }
|
|
||||||
}
|
|
||||||
|
|
||||||
.query-image {
|
|
||||||
max-width: 300px;
|
|
||||||
max-height: 200px;
|
|
||||||
border-radius: 12px;
|
|
||||||
object-fit: cover;
|
|
||||||
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<!-- 状态指示器 -->
|
|
||||||
<div class="status-indicator">
|
|
||||||
<div id="statusBadge" class="badge bg-secondary">
|
|
||||||
<i class="fas fa-circle-notch fa-spin"></i> 未初始化
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="container-fluid">
|
|
||||||
<div class="main-container">
|
|
||||||
<!-- 头部 -->
|
|
||||||
<div class="header">
|
|
||||||
<h1><i class="fas fa-search"></i> 多模态检索系统</h1>
|
|
||||||
<p class="mb-0">支持文搜图、文搜文、图搜图、图搜文四种检索模式</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="p-4">
|
|
||||||
<!-- 重新初始化按钮 -->
|
|
||||||
<div class="text-center mb-4">
|
|
||||||
<button id="reinitBtn" class="btn btn-warning">
|
|
||||||
<i class="fas fa-redo"></i> 重新初始化系统
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 检索模式选择 -->
|
|
||||||
<div class="row mb-4" id="modeSelection">
|
|
||||||
<div class="col-md-3">
|
|
||||||
<div class="mode-card text-center" data-mode="text_to_text">
|
|
||||||
<i class="fas fa-file-text mode-icon text-to-text"></i>
|
|
||||||
<h5>文搜文</h5>
|
|
||||||
<p class="text-muted">文本查找相似文本</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="col-md-3">
|
|
||||||
<div class="mode-card text-center" data-mode="text_to_image">
|
|
||||||
<i class="fas fa-image mode-icon text-to-image"></i>
|
|
||||||
<h5>文搜图</h5>
|
|
||||||
<p class="text-muted">文本查找相关图片</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="col-md-3">
|
|
||||||
<div class="mode-card text-center" data-mode="image_to_text">
|
|
||||||
<i class="fas fa-comment mode-icon image-to-text"></i>
|
|
||||||
<h5>图搜文</h5>
|
|
||||||
<p class="text-muted">图片查找相关文本</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="col-md-3">
|
|
||||||
<div class="mode-card text-center" data-mode="image_to_image">
|
|
||||||
<i class="fas fa-images mode-icon image-to-image"></i>
|
|
||||||
<h5>图搜图</h5>
|
|
||||||
<p class="text-muted">图片查找相似图片</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 数据管理界面 -->
|
|
||||||
<div class="row mb-4" id="dataManagement">
|
|
||||||
<div class="col-12">
|
|
||||||
<div class="card">
|
|
||||||
<div class="card-header">
|
|
||||||
<h5><i class="fas fa-database"></i> 数据管理</h5>
|
|
||||||
<small class="text-muted">上传和管理检索数据库</small>
|
|
||||||
</div>
|
|
||||||
<div class="card-body">
|
|
||||||
<div class="row">
|
|
||||||
<!-- 批量上传图片 -->
|
|
||||||
<div class="col-md-6">
|
|
||||||
<div class="upload-section">
|
|
||||||
<h6><i class="fas fa-images text-primary"></i> 批量上传图片</h6>
|
|
||||||
<div class="file-upload-area" id="batchImageUpload">
|
|
||||||
<i class="fas fa-cloud-upload-alt fa-2x text-muted mb-2"></i>
|
|
||||||
<p>拖拽多张图片到此处或点击选择</p>
|
|
||||||
<input type="file" id="batchImageFiles" multiple accept="image/*" style="display: none;">
|
|
||||||
<button class="btn btn-outline-primary btn-sm mt-2" onclick="document.getElementById('batchImageFiles').click()">
|
|
||||||
<i class="fas fa-folder-open"></i> 选择图片
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
<div id="imageUploadProgress" class="mt-2" style="display: none;">
|
|
||||||
<div class="progress">
|
|
||||||
<div class="progress-bar" role="progressbar" style="width: 0%"></div>
|
|
||||||
</div>
|
|
||||||
<small class="text-muted mt-1 d-block">上传进度: <span id="imageProgressText">0/0</span></small>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 批量上传文本 -->
|
|
||||||
<div class="col-md-6">
|
|
||||||
<div class="upload-section">
|
|
||||||
<h6><i class="fas fa-file-text text-success"></i> 批量上传文本</h6>
|
|
||||||
<div class="mb-3">
|
|
||||||
<textarea id="batchTextInput" class="form-control" rows="8"
|
|
||||||
placeholder="请输入文本数据,每行一条文本记录... 例如: 这是第一条文本记录 这是第二条文本记录 这是第三条文本记录"></textarea>
|
|
||||||
</div>
|
|
||||||
<div class="d-flex gap-2">
|
|
||||||
<button id="uploadTextsBtn" class="btn btn-success">
|
|
||||||
<i class="fas fa-upload"></i> 上传文本
|
|
||||||
</button>
|
|
||||||
<button class="btn btn-outline-secondary" onclick="document.getElementById('textFile').click()">
|
|
||||||
<i class="fas fa-file-import"></i> 从文件导入
|
|
||||||
</button>
|
|
||||||
<input type="file" id="textFile" accept=".txt,.csv" style="display: none;">
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 数据统计和管理 -->
|
|
||||||
<div class="row mt-4">
|
|
||||||
<div class="col-md-8">
|
|
||||||
<div class="d-flex gap-3">
|
|
||||||
<button id="buildIndexBtn" class="btn btn-warning" disabled>
|
|
||||||
<i class="fas fa-cogs"></i> 构建索引
|
|
||||||
</button>
|
|
||||||
<button id="viewDataBtn" class="btn btn-info">
|
|
||||||
<i class="fas fa-list"></i> 查看数据
|
|
||||||
</button>
|
|
||||||
<button id="clearDataBtn" class="btn btn-danger">
|
|
||||||
<i class="fas fa-trash"></i> 清空数据
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="col-md-4">
|
|
||||||
<div id="dataStats" class="text-end">
|
|
||||||
<small class="text-muted">
|
|
||||||
图片: <span id="imageCount">0</span> 张 |
|
|
||||||
文本: <span id="textCount">0</span> 条
|
|
||||||
</small>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 搜索界面 -->
|
|
||||||
<div id="searchInterface" style="display: none;">
|
|
||||||
<!-- 文本搜索 -->
|
|
||||||
<div id="textSearch" class="search-panel" style="display: none;">
|
|
||||||
<div class="row">
|
|
||||||
<div class="col-md-8">
|
|
||||||
<input type="text" id="textQuery" class="form-control search-input"
|
|
||||||
placeholder="请输入搜索文本...">
|
|
||||||
</div>
|
|
||||||
<div class="col-md-2">
|
|
||||||
<select id="textTopK" class="form-select search-input">
|
|
||||||
<option value="3">Top 3</option>
|
|
||||||
<option value="5" selected>Top 5</option>
|
|
||||||
<option value="10">Top 10</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
<div class="col-md-2">
|
|
||||||
<button id="textSearchBtn" class="btn btn-primary w-100">
|
|
||||||
<i class="fas fa-search"></i> 搜索
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 图片搜索 -->
|
|
||||||
<div id="imageSearch" class="search-panel" style="display: none;">
|
|
||||||
<div class="row">
|
|
||||||
<div class="col-md-8">
|
|
||||||
<div class="file-upload-area" id="fileUploadArea">
|
|
||||||
<i class="fas fa-cloud-upload-alt fa-3x text-muted mb-3"></i>
|
|
||||||
<h5>拖拽图片到此处或点击选择</h5>
|
|
||||||
<p class="text-muted">支持 PNG, JPG, JPEG, GIF, BMP, WebP 格式</p>
|
|
||||||
<input type="file" id="imageFile" accept="image/*" style="display: none;">
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="col-md-2">
|
|
||||||
<select id="imageTopK" class="form-select search-input">
|
|
||||||
<option value="3">Top 3</option>
|
|
||||||
<option value="5" selected>Top 5</option>
|
|
||||||
<option value="10">Top 10</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
<div class="col-md-2">
|
|
||||||
<button id="imageSearchBtn" class="btn btn-primary w-100" disabled>
|
|
||||||
<i class="fas fa-search"></i> 搜索
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 加载动画 -->
|
|
||||||
<div class="loading-spinner" id="loadingSpinner">
|
|
||||||
<div class="spinner-border text-primary" role="status">
|
|
||||||
<span class="visually-hidden">Loading...</span>
|
|
||||||
</div>
|
|
||||||
<p class="mt-2">正在搜索中...</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- 搜索结果 -->
|
|
||||||
<div id="searchResults" class="mt-4"></div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"></script>
|
|
||||||
<script>
|
|
||||||
let currentMode = null;
|
|
||||||
let systemInitialized = false;
|
|
||||||
|
|
||||||
// 重新初始化系统
|
|
||||||
document.getElementById('reinitBtn').addEventListener('click', async function() {
|
|
||||||
const btn = this;
|
|
||||||
const originalText = btn.innerHTML;
|
|
||||||
|
|
||||||
btn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 重新初始化中...';
|
|
||||||
btn.disabled = true;
|
|
||||||
|
|
||||||
try {
|
|
||||||
const response = await fetch('/api/init', {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {'Content-Type': 'application/json'}
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
if (data.success) {
|
|
||||||
systemInitialized = true;
|
|
||||||
document.getElementById('statusBadge').innerHTML =
|
|
||||||
'<i class="fas fa-check-circle"></i> 已重新初始化';
|
|
||||||
document.getElementById('statusBadge').className = 'badge bg-success';
|
|
||||||
|
|
||||||
showAlert('success', `系统重新初始化成功!GPU: ${data.gpu_count} 个`);
|
|
||||||
} else {
|
|
||||||
throw new Error(data.message);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
showAlert('danger', '重新初始化失败: ' + error.message);
|
|
||||||
} finally {
|
|
||||||
btn.innerHTML = originalText;
|
|
||||||
btn.disabled = false;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// 模式选择
|
|
||||||
document.querySelectorAll('.mode-card').forEach(card => {
|
|
||||||
card.addEventListener('click', function() {
|
|
||||||
|
|
||||||
// 更新选中状态
|
|
||||||
document.querySelectorAll('.mode-card').forEach(c => c.classList.remove('active'));
|
|
||||||
this.classList.add('active');
|
|
||||||
|
|
||||||
currentMode = this.dataset.mode;
|
|
||||||
setupSearchInterface(currentMode);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
// 设置搜索界面
|
|
||||||
function setupSearchInterface(mode) {
|
|
||||||
document.getElementById('searchInterface').style.display = 'block';
|
|
||||||
document.getElementById('textSearch').style.display = 'none';
|
|
||||||
document.getElementById('imageSearch').style.display = 'none';
|
|
||||||
document.getElementById('searchResults').innerHTML = '';
|
|
||||||
|
|
||||||
if (mode === 'text_to_text' || mode === 'text_to_image') {
|
|
||||||
document.getElementById('textSearch').style.display = 'block';
|
|
||||||
document.getElementById('textQuery').placeholder =
|
|
||||||
mode === 'text_to_text' ? '请输入要搜索的文本...' : '请输入要搜索图片的描述...';
|
|
||||||
} else {
|
|
||||||
document.getElementById('imageSearch').style.display = 'block';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 文本搜索
|
|
||||||
document.getElementById('textSearchBtn').addEventListener('click', performTextSearch);
|
|
||||||
document.getElementById('textQuery').addEventListener('keypress', function(e) {
|
|
||||||
if (e.key === 'Enter') performTextSearch();
|
|
||||||
});
|
|
||||||
|
|
||||||
async function performTextSearch() {
|
|
||||||
const query = document.getElementById('textQuery').value.trim();
|
|
||||||
const topK = parseInt(document.getElementById('textTopK').value);
|
|
||||||
|
|
||||||
if (!query) {
|
|
||||||
showAlert('warning', '请输入搜索文本');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
showLoading(true);
|
|
||||||
|
|
||||||
try {
|
|
||||||
const endpoint = currentMode === 'text_to_text' ? '/api/search/text_to_text' : '/api/search/text_to_image';
|
|
||||||
const response = await fetch(endpoint, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {'Content-Type': 'application/json'},
|
|
||||||
body: JSON.stringify({query, top_k: topK})
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
if (data.success) {
|
|
||||||
displayResults(data, currentMode);
|
|
||||||
} else {
|
|
||||||
throw new Error(data.message);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
showAlert('danger', '搜索失败: ' + error.message);
|
|
||||||
} finally {
|
|
||||||
showLoading(false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 图片上传处理
|
|
||||||
const fileUploadArea = document.getElementById('fileUploadArea');
|
|
||||||
const imageFile = document.getElementById('imageFile');
|
|
||||||
|
|
||||||
fileUploadArea.addEventListener('click', () => imageFile.click());
|
|
||||||
fileUploadArea.addEventListener('dragover', handleDragOver);
|
|
||||||
fileUploadArea.addEventListener('drop', handleDrop);
|
|
||||||
imageFile.addEventListener('change', handleFileSelect);
|
|
||||||
|
|
||||||
function handleDragOver(e) {
|
|
||||||
e.preventDefault();
|
|
||||||
fileUploadArea.classList.add('dragover');
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleDrop(e) {
|
|
||||||
e.preventDefault();
|
|
||||||
fileUploadArea.classList.remove('dragover');
|
|
||||||
const files = e.dataTransfer.files;
|
|
||||||
if (files.length > 0) {
|
|
||||||
handleFile(files[0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleFileSelect(e) {
|
|
||||||
const file = e.target.files[0];
|
|
||||||
if (file) handleFile(file);
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleFile(file) {
|
|
||||||
if (!file.type.startsWith('image/')) {
|
|
||||||
showAlert('warning', '请选择图片文件');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const reader = new FileReader();
|
|
||||||
reader.onload = function(e) {
|
|
||||||
fileUploadArea.innerHTML = `
|
|
||||||
<img src="${e.target.result}" class="query-image mb-3">
|
|
||||||
<p class="text-success"><i class="fas fa-check"></i> 图片已选择: ${file.name}</p>
|
|
||||||
`;
|
|
||||||
document.getElementById('imageSearchBtn').disabled = false;
|
|
||||||
};
|
|
||||||
reader.readAsDataURL(file);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 图片搜索
|
|
||||||
document.getElementById('imageSearchBtn').addEventListener('click', async function() {
|
|
||||||
const file = imageFile.files[0];
|
|
||||||
const topK = parseInt(document.getElementById('imageTopK').value);
|
|
||||||
|
|
||||||
if (!file) {
|
|
||||||
showAlert('warning', '请选择图片文件');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
showLoading(true);
|
|
||||||
|
|
||||||
try {
|
|
||||||
const formData = new FormData();
|
|
||||||
formData.append('image', file);
|
|
||||||
formData.append('top_k', topK);
|
|
||||||
|
|
||||||
const endpoint = currentMode === 'image_to_text' ? '/api/search/image_to_text' : '/api/search/image_to_image';
|
|
||||||
const response = await fetch(endpoint, {
|
|
||||||
method: 'POST',
|
|
||||||
body: formData
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
if (data.success) {
|
|
||||||
displayResults(data, currentMode);
|
|
||||||
} else {
|
|
||||||
throw new Error(data.message);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
showAlert('danger', '搜索失败: ' + error.message);
|
|
||||||
} finally {
|
|
||||||
showLoading(false);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// 显示结果
|
|
||||||
function displayResults(data, mode) {
|
|
||||||
const resultsContainer = document.getElementById('searchResults');
|
|
||||||
|
|
||||||
let html = `
|
|
||||||
<div class="fade-in">
|
|
||||||
<div class="d-flex justify-content-between align-items-center mb-3">
|
|
||||||
<h4><i class="fas fa-search-plus"></i> 搜索结果</h4>
|
|
||||||
<div>
|
|
||||||
<span class="badge bg-info">找到 ${data.result_count} 个结果</span>
|
|
||||||
<span class="badge bg-secondary">耗时 ${data.search_time}s</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
`;
|
|
||||||
|
|
||||||
if (data.query_image) {
|
|
||||||
html += `
|
|
||||||
<div class="result-card">
|
|
||||||
<h6><i class="fas fa-image"></i> 查询图片</h6>
|
|
||||||
<img src="data:image/jpeg;base64,${data.query_image}" class="query-image">
|
|
||||||
</div>
|
|
||||||
`;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (data.query) {
|
|
||||||
html += `
|
|
||||||
<div class="result-card">
|
|
||||||
<h6><i class="fas fa-quote-left"></i> 查询文本</h6>
|
|
||||||
<p class="mb-0">"${data.query}"</p>
|
|
||||||
</div>
|
|
||||||
`;
|
|
||||||
}
|
|
||||||
|
|
||||||
data.results.forEach((result, index) => {
|
|
||||||
html += '<div class="result-card">';
|
|
||||||
|
|
||||||
if (mode === 'text_to_image' || mode === 'image_to_image') {
|
|
||||||
html += `
|
|
||||||
<div class="row">
|
|
||||||
<div class="col-md-3">
|
|
||||||
<img src="data:image/jpeg;base64,${result.image_base64}"
|
|
||||||
class="result-image" alt="Result ${index + 1}">
|
|
||||||
</div>
|
|
||||||
<div class="col-md-9">
|
|
||||||
<div class="d-flex justify-content-between align-items-start">
|
|
||||||
<h6><i class="fas fa-image"></i> ${result.filename}</h6>
|
|
||||||
<span class="score-badge">相似度: ${(result.score * 100).toFixed(1)}%</span>
|
|
||||||
</div>
|
|
||||||
<p class="text-muted mb-0">路径: ${result.image_path}</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
`;
|
|
||||||
} else {
|
|
||||||
html += `
|
|
||||||
<div class="d-flex justify-content-between align-items-start">
|
|
||||||
<div>
|
|
||||||
<h6><i class="fas fa-file-text"></i> 结果 ${index + 1}</h6>
|
|
||||||
<p class="mb-0">${result.text || result}</p>
|
|
||||||
</div>
|
|
||||||
<span class="score-badge">相似度: ${((result.score || 0.95) * 100).toFixed(1)}%</span>
|
|
||||||
</div>
|
|
||||||
`;
|
|
||||||
}
|
|
||||||
|
|
||||||
html += '</div>';
|
|
||||||
});
|
|
||||||
|
|
||||||
html += '</div>';
|
|
||||||
resultsContainer.innerHTML = html;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 工具函数
|
|
||||||
function showLoading(show) {
|
|
||||||
document.getElementById('loadingSpinner').style.display = show ? 'block' : 'none';
|
|
||||||
}
|
|
||||||
|
|
||||||
function showAlert(type, message) {
|
|
||||||
const alertDiv = document.createElement('div');
|
|
||||||
alertDiv.className = `alert alert-${type} alert-dismissible fade show`;
|
|
||||||
alertDiv.innerHTML = `
|
|
||||||
${message}
|
|
||||||
<button type="button" class="btn-close" data-bs-dismiss="alert"></button>
|
|
||||||
`;
|
|
||||||
|
|
||||||
document.querySelector('.main-container .p-4').insertBefore(alertDiv, document.querySelector('.main-container .p-4').firstChild);
|
|
||||||
|
|
||||||
setTimeout(() => {
|
|
||||||
if (alertDiv.parentNode) {
|
|
||||||
alertDiv.remove();
|
|
||||||
}
|
|
||||||
}, 5000);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查系统状态
|
|
||||||
async function checkStatus() {
|
|
||||||
try {
|
|
||||||
const response = await fetch('/api/status');
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
if (data.initialized) {
|
|
||||||
systemInitialized = true;
|
|
||||||
document.getElementById('statusBadge').innerHTML =
|
|
||||||
'<i class="fas fa-check-circle"></i> 已初始化';
|
|
||||||
document.getElementById('statusBadge').className = 'badge bg-success';
|
|
||||||
} else {
|
|
||||||
document.getElementById('statusBadge').innerHTML =
|
|
||||||
'<i class="fas fa-exclamation-triangle"></i> 未初始化';
|
|
||||||
document.getElementById('statusBadge').className = 'badge bg-warning';
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.log('Status check failed:', error);
|
|
||||||
document.getElementById('statusBadge').innerHTML =
|
|
||||||
'<i class="fas fa-times-circle"></i> 连接失败';
|
|
||||||
document.getElementById('statusBadge').className = 'badge bg-danger';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 页面加载时检查状态
|
|
||||||
checkStatus();
|
|
||||||
|
|
||||||
// 设置数据管理功能事件绑定
|
|
||||||
setupDataManagement();
|
|
||||||
|
|
||||||
function setupDataManagement() {
|
|
||||||
// 批量图片上传事件
|
|
||||||
const batchImageUpload = document.getElementById('batchImageUpload');
|
|
||||||
const batchImageFiles = document.getElementById('batchImageFiles');
|
|
||||||
|
|
||||||
// 拖拽上传
|
|
||||||
batchImageUpload.addEventListener('dragover', function(e) {
|
|
||||||
e.preventDefault();
|
|
||||||
this.classList.add('dragover');
|
|
||||||
});
|
|
||||||
|
|
||||||
batchImageUpload.addEventListener('dragleave', function(e) {
|
|
||||||
e.preventDefault();
|
|
||||||
this.classList.remove('dragover');
|
|
||||||
});
|
|
||||||
|
|
||||||
batchImageUpload.addEventListener('drop', function(e) {
|
|
||||||
e.preventDefault();
|
|
||||||
this.classList.remove('dragover');
|
|
||||||
const files = Array.from(e.dataTransfer.files).filter(file => file.type.startsWith('image/'));
|
|
||||||
if (files.length > 0) {
|
|
||||||
uploadBatchImages(files);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
batchImageFiles.addEventListener('change', function(e) {
|
|
||||||
if (e.target.files.length > 0) {
|
|
||||||
uploadBatchImages(Array.from(e.target.files));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// 批量文本上传
|
|
||||||
document.getElementById('uploadTextsBtn').addEventListener('click', function() {
|
|
||||||
const textData = document.getElementById('batchTextInput').value.trim();
|
|
||||||
if (textData) {
|
|
||||||
const texts = textData.split('\n').filter(line => line.trim());
|
|
||||||
if (texts.length > 0) {
|
|
||||||
uploadBatchTexts(texts);
|
|
||||||
} else {
|
|
||||||
showAlert('warning', '请输入有效的文本数据');
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
showAlert('warning', '请输入文本数据');
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// 从文件导入文本
|
|
||||||
document.getElementById('textFile').addEventListener('change', function(e) {
|
|
||||||
const file = e.target.files[0];
|
|
||||||
if (file) {
|
|
||||||
const reader = new FileReader();
|
|
||||||
reader.onload = function(e) {
|
|
||||||
document.getElementById('batchTextInput').value = e.target.result;
|
|
||||||
};
|
|
||||||
reader.readAsText(file);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// 构建索引
|
|
||||||
document.getElementById('buildIndexBtn').addEventListener('click', buildIndex);
|
|
||||||
|
|
||||||
// 查看数据
|
|
||||||
document.getElementById('viewDataBtn').addEventListener('click', viewData);
|
|
||||||
|
|
||||||
// 清空数据
|
|
||||||
document.getElementById('clearDataBtn').addEventListener('click', clearData);
|
|
||||||
|
|
||||||
// 初始化时更新数据统计
|
|
||||||
updateDataStats();
|
|
||||||
}
|
|
||||||
|
|
||||||
// 批量上传图片
|
|
||||||
async function uploadBatchImages(files) {
|
|
||||||
const progressDiv = document.getElementById('imageUploadProgress');
|
|
||||||
const progressBar = progressDiv.querySelector('.progress-bar');
|
|
||||||
const progressText = document.getElementById('imageProgressText');
|
|
||||||
|
|
||||||
progressDiv.style.display = 'block';
|
|
||||||
progressText.textContent = `0/${files.length}`;
|
|
||||||
progressBar.style.width = '0%';
|
|
||||||
|
|
||||||
const formData = new FormData();
|
|
||||||
files.forEach(file => {
|
|
||||||
formData.append('files', file);
|
|
||||||
});
|
|
||||||
|
|
||||||
try {
|
|
||||||
const response = await fetch('/api/upload/images', {
|
|
||||||
method: 'POST',
|
|
||||||
body: formData
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
if (data.success) {
|
|
||||||
progressBar.style.width = '100%';
|
|
||||||
progressText.textContent = `${files.length}/${files.length}`;
|
|
||||||
showAlert('success', `成功上传 ${data.uploaded_count} 张图片`);
|
|
||||||
updateDataStats();
|
|
||||||
document.getElementById('buildIndexBtn').disabled = false;
|
|
||||||
} else {
|
|
||||||
showAlert('danger', `上传失败: ${data.message}`);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
showAlert('danger', `上传错误: ${error.message}`);
|
|
||||||
} finally {
|
|
||||||
setTimeout(() => {
|
|
||||||
progressDiv.style.display = 'none';
|
|
||||||
}, 2000);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 批量上传文本
|
|
||||||
async function uploadBatchTexts(texts) {
|
|
||||||
try {
|
|
||||||
const response = await fetch('/api/upload/texts', {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
body: JSON.stringify({ texts: texts })
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
if (data.success) {
|
|
||||||
showAlert('success', `成功上传 ${data.uploaded_count} 条文本`);
|
|
||||||
document.getElementById('batchTextInput').value = '';
|
|
||||||
updateDataStats();
|
|
||||||
document.getElementById('buildIndexBtn').disabled = false;
|
|
||||||
} else {
|
|
||||||
showAlert('danger', `上传失败: ${data.message}`);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
showAlert('danger', `上传错误: ${error.message}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建索引
|
|
||||||
async function buildIndex() {
|
|
||||||
const btn = document.getElementById('buildIndexBtn');
|
|
||||||
const originalText = btn.innerHTML;
|
|
||||||
btn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 构建中...';
|
|
||||||
btn.disabled = true;
|
|
||||||
|
|
||||||
try {
|
|
||||||
const response = await fetch('/api/build_index', {
|
|
||||||
method: 'POST'
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
if (data.success) {
|
|
||||||
showAlert('success', '索引构建完成!现在可以进行搜索了');
|
|
||||||
} else {
|
|
||||||
showAlert('danger', `索引构建失败: ${data.message}`);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
showAlert('danger', `构建错误: ${error.message}`);
|
|
||||||
} finally {
|
|
||||||
btn.innerHTML = originalText;
|
|
||||||
btn.disabled = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查看数据
|
|
||||||
async function viewData() {
|
|
||||||
try {
|
|
||||||
const response = await fetch('/api/data/list');
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
if (data.success) {
|
|
||||||
let content = '<div class="row">';
|
|
||||||
|
|
||||||
// 显示图片数据
|
|
||||||
if (data.images && data.images.length > 0) {
|
|
||||||
content += '<div class="col-md-6"><h6>图片数据 (' + data.images.length + ')</h6>';
|
|
||||||
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
|
|
||||||
data.images.forEach(img => {
|
|
||||||
content += `<div class="list-group-item d-flex justify-content-between align-items-center">
|
|
||||||
<span>${img}</span>
|
|
||||||
<img src="/uploads/${img}" class="img-thumbnail" style="width: 50px; height: 50px; object-fit: cover;">
|
|
||||||
</div>`;
|
|
||||||
});
|
|
||||||
content += '</div></div>';
|
|
||||||
}
|
|
||||||
|
|
||||||
// 显示文本数据
|
|
||||||
if (data.texts && data.texts.length > 0) {
|
|
||||||
content += '<div class="col-md-6"><h6>文本数据 (' + data.texts.length + ')</h6>';
|
|
||||||
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
|
|
||||||
data.texts.forEach((text, index) => {
|
|
||||||
const shortText = text.length > 50 ? text.substring(0, 50) + '...' : text;
|
|
||||||
content += `<div class="list-group-item">
|
|
||||||
<small class="text-muted">#${index + 1}</small><br>
|
|
||||||
${shortText}
|
|
||||||
</div>`;
|
|
||||||
});
|
|
||||||
content += '</div></div>';
|
|
||||||
}
|
|
||||||
|
|
||||||
content += '</div>';
|
|
||||||
|
|
||||||
showModal('数据列表', content);
|
|
||||||
} else {
|
|
||||||
showAlert('danger', `获取数据失败: ${data.message}`);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
showAlert('danger', `获取数据错误: ${error.message}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 清空数据
|
|
||||||
async function clearData() {
|
|
||||||
if (!confirm('确定要清空所有数据吗?此操作不可恢复!')) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const response = await fetch('/api/data/clear', {
|
|
||||||
method: 'POST'
|
|
||||||
});
|
|
||||||
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
if (data.success) {
|
|
||||||
showAlert('success', '数据已清空');
|
|
||||||
updateDataStats();
|
|
||||||
document.getElementById('buildIndexBtn').disabled = true;
|
|
||||||
} else {
|
|
||||||
showAlert('danger', `清空失败: ${data.message}`);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
showAlert('danger', `清空错误: ${error.message}`);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新数据统计
|
|
||||||
async function updateDataStats() {
|
|
||||||
try {
|
|
||||||
const response = await fetch('/api/data/stats');
|
|
||||||
const data = await response.json();
|
|
||||||
|
|
||||||
if (data.success) {
|
|
||||||
document.getElementById('imageCount').textContent = data.image_count || 0;
|
|
||||||
document.getElementById('textCount').textContent = data.text_count || 0;
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.log('获取数据统计失败:', error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 显示模态框
|
|
||||||
function showModal(title, content) {
|
|
||||||
const modalHtml = `
|
|
||||||
<div class="modal fade" id="dataModal" tabindex="-1">
|
|
||||||
<div class="modal-dialog modal-lg">
|
|
||||||
<div class="modal-content">
|
|
||||||
<div class="modal-header">
|
|
||||||
<h5 class="modal-title">${title}</h5>
|
|
||||||
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
|
|
||||||
</div>
|
|
||||||
<div class="modal-body">
|
|
||||||
${content}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
`;
|
|
||||||
|
|
||||||
// 移除已存在的模态框
|
|
||||||
const existingModal = document.getElementById('dataModal');
|
|
||||||
if (existingModal) {
|
|
||||||
existingModal.remove();
|
|
||||||
}
|
|
||||||
|
|
||||||
document.body.insertAdjacentHTML('beforeend', modalHtml);
|
|
||||||
const modal = new bootstrap.Modal(document.getElementById('dataModal'));
|
|
||||||
modal.show();
|
|
||||||
}
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@ -4,6 +4,7 @@
|
|||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
<title>本地多模态检索系统 - FAISS</title>
|
<title>本地多模态检索系统 - FAISS</title>
|
||||||
|
<link rel="icon" href="/favicon.ico" />
|
||||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
||||||
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
|
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
|
||||||
<style>
|
<style>
|
||||||
@ -153,6 +154,25 @@
|
|||||||
right: 20px;
|
right: 20px;
|
||||||
z-index: 1000;
|
z-index: 1000;
|
||||||
}
|
}
|
||||||
|
.status-bar {
|
||||||
|
position: fixed;
|
||||||
|
top: 20px;
|
||||||
|
right: 120px;
|
||||||
|
background: rgba(255,255,255,0.9);
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 8px 12px;
|
||||||
|
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||||
|
font-size: 12px;
|
||||||
|
color: #334155;
|
||||||
|
display: flex;
|
||||||
|
gap: 12px;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
.status-item {
|
||||||
|
display: flex;
|
||||||
|
gap: 6px;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
|
||||||
.fade-in {
|
.fade-in {
|
||||||
animation: fadeIn 0.5s ease-in;
|
animation: fadeIn 0.5s ease-in;
|
||||||
@ -179,6 +199,21 @@
|
|||||||
<i class="fas fa-circle-notch fa-spin"></i> 未初始化
|
<i class="fas fa-circle-notch fa-spin"></i> 未初始化
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<!-- 右上角状态栏 -->
|
||||||
|
<div class="status-bar">
|
||||||
|
<div class="status-item" title="Vector Dimension">
|
||||||
|
<i class="fas fa-ruler-combined text-primary"></i>
|
||||||
|
<span id="statusVectorDim">-</span>
|
||||||
|
</div>
|
||||||
|
<div class="status-item" title="Total Vectors">
|
||||||
|
<i class="fas fa-database text-success"></i>
|
||||||
|
<span id="statusTotalVectors">-</span>
|
||||||
|
</div>
|
||||||
|
<div class="status-item" title="Server Time (UTC)">
|
||||||
|
<i class="fas fa-clock text-warning"></i>
|
||||||
|
<span id="statusServerTime">-</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div class="container-fluid">
|
<div class="container-fluid">
|
||||||
<div class="main-container">
|
<div class="main-container">
|
||||||
@ -693,8 +728,28 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 页面加载时检查状态
|
// 轮询 /api/stats 更新右上角状态栏
|
||||||
|
async function updateStatusBar() {
|
||||||
|
try {
|
||||||
|
const res = await fetch('/api/stats');
|
||||||
|
const data = await res.json();
|
||||||
|
if (data && data.success) {
|
||||||
|
const dim = data.debug?.vector_dimension ?? data.stats?.vector_dimension ?? '-';
|
||||||
|
const total = data.debug?.total_vectors ?? data.stats?.total_vectors ?? '-';
|
||||||
|
const time = data.debug?.server_time ?? '-';
|
||||||
|
document.getElementById('statusVectorDim').textContent = dim;
|
||||||
|
document.getElementById('statusTotalVectors').textContent = total;
|
||||||
|
document.getElementById('statusServerTime').textContent = time.replace('T',' ').replace('Z','');
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
// 忽略一次失败,等待下次轮询
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 页面加载时检查状态并开始轮询状态栏
|
||||||
checkStatus();
|
checkStatus();
|
||||||
|
updateStatusBar();
|
||||||
|
setInterval(updateStatusBar, 5000);
|
||||||
|
|
||||||
// 设置数据管理功能事件绑定
|
// 设置数据管理功能事件绑定
|
||||||
setupDataManagement();
|
setupDataManagement();
|
||||||
|
|||||||
@ -1,104 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
测试所有四种多模态检索模式
|
|
||||||
"""
|
|
||||||
|
|
||||||
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import os
|
|
||||||
|
|
||||||
def test_all_retrieval_modes():
|
|
||||||
print('正在初始化多GPU多模态检索系统...')
|
|
||||||
retrieval = MultiGPUMultimodalRetrieval()
|
|
||||||
|
|
||||||
# 准备测试数据
|
|
||||||
test_texts = [
|
|
||||||
"一只可爱的小猫",
|
|
||||||
"美丽的风景照片",
|
|
||||||
"现代建筑设计",
|
|
||||||
"colorful flowers in garden"
|
|
||||||
]
|
|
||||||
|
|
||||||
test_images = [
|
|
||||||
'sample_images/1755677101_1__.jpg',
|
|
||||||
'sample_images/1755677101_2__.jpg',
|
|
||||||
'sample_images/1755677101_3__.jpg',
|
|
||||||
'sample_images/1755677101_4__.jpg'
|
|
||||||
]
|
|
||||||
|
|
||||||
# 验证测试图像存在
|
|
||||||
existing_images = [img for img in test_images if os.path.exists(img)]
|
|
||||||
if not existing_images:
|
|
||||||
print("❌ 没有找到测试图像文件")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"找到 {len(existing_images)} 张测试图像")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 构建文本索引
|
|
||||||
print('\n=== 构建文本索引 ===')
|
|
||||||
retrieval.build_text_index_parallel(test_texts)
|
|
||||||
print('✅ 文本索引构建完成')
|
|
||||||
|
|
||||||
# 2. 构建图像索引
|
|
||||||
print('\n=== 构建图像索引 ===')
|
|
||||||
retrieval.build_image_index_parallel(existing_images)
|
|
||||||
print('✅ 图像索引构建完成')
|
|
||||||
|
|
||||||
# 3. 测试文本到文本检索
|
|
||||||
print('\n=== 测试文本到文本检索 ===')
|
|
||||||
query = "小动物"
|
|
||||||
results = retrieval.search_text_by_text(query, top_k=3)
|
|
||||||
print(f'查询: "{query}"')
|
|
||||||
for i, (text, score) in enumerate(results):
|
|
||||||
print(f' {i+1}. {text} (相似度: {score:.4f})')
|
|
||||||
|
|
||||||
# 4. 测试文本到图像检索
|
|
||||||
print('\n=== 测试文本到图像检索 ===')
|
|
||||||
query = "beautiful image"
|
|
||||||
results = retrieval.search_images_by_text(query, top_k=3)
|
|
||||||
print(f'查询: "{query}"')
|
|
||||||
for i, (image_path, score) in enumerate(results):
|
|
||||||
print(f' {i+1}. {image_path} (相似度: {score:.4f})')
|
|
||||||
|
|
||||||
# 5. 测试图像到文本检索
|
|
||||||
print('\n=== 测试图像到文本检索 ===')
|
|
||||||
query_image = existing_images[0]
|
|
||||||
results = retrieval.search_text_by_image(query_image, top_k=3)
|
|
||||||
print(f'查询图像: {query_image}')
|
|
||||||
for i, (text, score) in enumerate(results):
|
|
||||||
print(f' {i+1}. {text} (相似度: {score:.4f})')
|
|
||||||
|
|
||||||
# 6. 测试图像到图像检索
|
|
||||||
print('\n=== 测试图像到图像检索 ===')
|
|
||||||
query_image = existing_images[0]
|
|
||||||
results = retrieval.search_images_by_image(query_image, top_k=3)
|
|
||||||
print(f'查询图像: {query_image}')
|
|
||||||
for i, (image_path, score) in enumerate(results):
|
|
||||||
print(f' {i+1}. {image_path} (相似度: {score:.4f})')
|
|
||||||
|
|
||||||
print('\n✅ 所有四种检索模式测试完成!')
|
|
||||||
|
|
||||||
# 7. 测试Web应用兼容的方法名
|
|
||||||
print('\n=== 测试Web应用兼容方法 ===')
|
|
||||||
try:
|
|
||||||
results = retrieval.search_text_to_image("test query", top_k=2)
|
|
||||||
print('✅ search_text_to_image 方法正常')
|
|
||||||
|
|
||||||
results = retrieval.search_image_to_text(existing_images[0], top_k=2)
|
|
||||||
print('✅ search_image_to_text 方法正常')
|
|
||||||
|
|
||||||
results = retrieval.search_image_to_image(existing_images[0], top_k=2)
|
|
||||||
print('✅ search_image_to_image 方法正常')
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f'❌ Web应用兼容方法测试失败: {e}')
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f'❌ 测试过程中出现错误: {e}')
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_all_retrieval_modes()
|
|
||||||
@ -1,285 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
百度VDB向量数据库连接测试脚本
|
|
||||||
测试数据库连接、基本操作和可用性
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams
|
|
||||||
from pymochow.model.enum import FieldType, IndexType, MetricType
|
|
||||||
from pymochow.model.table import Partition, Row
|
|
||||||
import traceback
|
|
||||||
import time
|
|
||||||
|
|
||||||
# 数据库连接配置
|
|
||||||
ACCOUNT = 'root'
|
|
||||||
API_KEY = 'vdb$yjr9ln3n0td' # 你提供的密码作为API密钥
|
|
||||||
ENDPOINT = 'http://180.76.96.191:5287' # 使用标准端口5287
|
|
||||||
|
|
||||||
def test_connection():
|
|
||||||
"""测试数据库连接"""
|
|
||||||
print("=" * 50)
|
|
||||||
print("测试1: 数据库连接")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建配置和客户端
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(ACCOUNT, API_KEY),
|
|
||||||
endpoint=ENDPOINT
|
|
||||||
)
|
|
||||||
client = pymochow.MochowClient(config)
|
|
||||||
|
|
||||||
print(f"✓ 成功创建客户端连接")
|
|
||||||
print(f" - 账户: {ACCOUNT}")
|
|
||||||
print(f" - 端点: {ENDPOINT}")
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ 连接失败: {str(e)}")
|
|
||||||
print(f"详细错误: {traceback.format_exc()}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def test_list_databases(client):
|
|
||||||
"""测试数据库列表查询"""
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("测试2: 查询数据库列表")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 查询数据库列表
|
|
||||||
db_list = client.list_databases()
|
|
||||||
print(f"✓ 成功查询数据库列表")
|
|
||||||
print(f" - 数据库数量: {len(db_list)}")
|
|
||||||
|
|
||||||
if db_list:
|
|
||||||
print(" - 数据库列表:")
|
|
||||||
for i, db in enumerate(db_list, 1):
|
|
||||||
print(f" {i}. {db.database_name}")
|
|
||||||
else:
|
|
||||||
print(" - 当前没有数据库")
|
|
||||||
|
|
||||||
return db_list
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ 查询数据库列表失败: {str(e)}")
|
|
||||||
print(f"详细错误: {traceback.format_exc()}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def test_create_database(client):
|
|
||||||
"""测试创建数据库"""
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("测试3: 创建测试数据库")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
test_db_name = "test_db_" + str(int(time.time()))
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建测试数据库
|
|
||||||
db = client.create_database(test_db_name)
|
|
||||||
print(f"✓ 成功创建数据库: {test_db_name}")
|
|
||||||
|
|
||||||
return db, test_db_name
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ 创建数据库失败: {str(e)}")
|
|
||||||
print(f"详细错误: {traceback.format_exc()}")
|
|
||||||
return None, test_db_name
|
|
||||||
|
|
||||||
def test_create_table(client, db, db_name):
|
|
||||||
"""测试创建表"""
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("测试4: 创建测试表")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
table_name = "test_table"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 定义表字段
|
|
||||||
fields = []
|
|
||||||
fields.append(Field("id", FieldType.STRING, primary_key=True,
|
|
||||||
partition_key=True, auto_increment=False, not_null=True))
|
|
||||||
fields.append(Field("text", FieldType.STRING, not_null=True))
|
|
||||||
fields.append(Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=3))
|
|
||||||
|
|
||||||
# 定义索引
|
|
||||||
indexes = []
|
|
||||||
indexes.append(
|
|
||||||
VectorIndex(
|
|
||||||
index_name="vector_idx",
|
|
||||||
index_type=IndexType.HNSW,
|
|
||||||
field="vector",
|
|
||||||
metric_type=MetricType.L2,
|
|
||||||
params=HNSWParams(m=16, efconstruction=100),
|
|
||||||
auto_build=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建表
|
|
||||||
table = db.create_table(
|
|
||||||
table_name=table_name,
|
|
||||||
replication=2, # 最小副本数为2
|
|
||||||
partition=Partition(partition_num=1), # 单分区测试
|
|
||||||
schema=Schema(fields=fields, indexes=indexes)
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"✓ 成功创建表: {table_name}")
|
|
||||||
print(f" - 字段数量: {len(fields)}")
|
|
||||||
print(f" - 索引数量: {len(indexes)}")
|
|
||||||
|
|
||||||
return table, table_name
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ 创建表失败: {str(e)}")
|
|
||||||
print(f"详细错误: {traceback.format_exc()}")
|
|
||||||
return None, table_name
|
|
||||||
|
|
||||||
def test_insert_data(table):
|
|
||||||
"""测试插入数据"""
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("测试5: 插入测试数据")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 准备测试数据
|
|
||||||
rows = [
|
|
||||||
Row(id='001', text='测试文本1', vector=[0.1, 0.2, 0.3]),
|
|
||||||
Row(id='002', text='测试文本2', vector=[0.4, 0.5, 0.6]),
|
|
||||||
Row(id='003', text='测试文本3', vector=[0.7, 0.8, 0.9])
|
|
||||||
]
|
|
||||||
|
|
||||||
# 插入数据
|
|
||||||
table.upsert(rows)
|
|
||||||
print(f"✓ 成功插入 {len(rows)} 条测试数据")
|
|
||||||
|
|
||||||
# 等待一下让数据生效
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ 插入数据失败: {str(e)}")
|
|
||||||
print(f"详细错误: {traceback.format_exc()}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_query_data(table):
|
|
||||||
"""测试查询数据"""
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("测试6: 查询测试数据")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 标量查询
|
|
||||||
primary_key = {'id': '001'}
|
|
||||||
result = table.query(primary_key=primary_key, retrieve_vector=True)
|
|
||||||
|
|
||||||
if result:
|
|
||||||
print(f"✓ 成功查询数据")
|
|
||||||
print(f" - ID: {result.get('id')}")
|
|
||||||
print(f" - 文本: {result.get('text')}")
|
|
||||||
print(f" - 向量: {result.get('vector')}")
|
|
||||||
else:
|
|
||||||
print("✗ 查询结果为空")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ 查询数据失败: {str(e)}")
|
|
||||||
print(f"详细错误: {traceback.format_exc()}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_cleanup(client, db_name, table_name):
|
|
||||||
"""清理测试数据"""
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("测试7: 清理测试数据")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取数据库对象
|
|
||||||
db = client.database(db_name)
|
|
||||||
|
|
||||||
# 删除表
|
|
||||||
try:
|
|
||||||
db.drop_table(table_name)
|
|
||||||
print(f"✓ 成功删除表: {table_name}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠ 删除表失败: {str(e)}")
|
|
||||||
|
|
||||||
# 删除数据库
|
|
||||||
try:
|
|
||||||
client.drop_database(db_name)
|
|
||||||
print(f"✓ 成功删除数据库: {db_name}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠ 删除数据库失败: {str(e)}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠ 清理过程出现错误: {str(e)}")
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主测试函数"""
|
|
||||||
print("百度VDB向量数据库可用性测试")
|
|
||||||
print("测试配置:")
|
|
||||||
print(f" - 用户名: {ACCOUNT}")
|
|
||||||
print(f" - 服务器: {ENDPOINT}")
|
|
||||||
print(f" - 测试时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
||||||
|
|
||||||
client = None
|
|
||||||
db = None
|
|
||||||
db_name = None
|
|
||||||
table = None
|
|
||||||
table_name = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 测试1: 连接
|
|
||||||
client = test_connection()
|
|
||||||
if not client:
|
|
||||||
print("\n❌ 数据库连接失败,无法继续测试")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 测试2: 查询数据库列表
|
|
||||||
db_list = test_list_databases(client)
|
|
||||||
|
|
||||||
# 测试3: 创建数据库
|
|
||||||
db, db_name = test_create_database(client)
|
|
||||||
if not db:
|
|
||||||
print("\n❌ 无法创建数据库,跳过后续测试")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 测试4: 创建表
|
|
||||||
table, table_name = test_create_table(client, db, db_name)
|
|
||||||
if not table:
|
|
||||||
print("\n❌ 无法创建表,跳过后续测试")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 测试5: 插入数据
|
|
||||||
if test_insert_data(table):
|
|
||||||
# 测试6: 查询数据
|
|
||||||
test_query_data(table)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ 测试过程中发生未预期错误: {str(e)}")
|
|
||||||
print(f"详细错误: {traceback.format_exc()}")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# 测试7: 清理
|
|
||||||
if client and db_name:
|
|
||||||
test_cleanup(client, db_name, table_name)
|
|
||||||
|
|
||||||
# 关闭连接
|
|
||||||
if client:
|
|
||||||
try:
|
|
||||||
client.close()
|
|
||||||
print(f"\n✓ 已关闭数据库连接")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("测试完成")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@ -1,349 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
百度VDB连接可用性测试
|
|
||||||
测试连接、数据库操作、表操作和向量检索功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
import traceback
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams
|
|
||||||
from pymochow.model.enum import FieldType, IndexType, MetricType, TableState
|
|
||||||
from pymochow.model.table import Row, Partition
|
|
||||||
from pymochow.model.table import VectorTopkSearchRequest, VectorSearchConfig, FloatVector
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class BaiduVDBConnectionTest:
|
|
||||||
"""百度VDB连接测试类"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""初始化测试连接"""
|
|
||||||
# 您提供的连接信息
|
|
||||||
self.account = "root"
|
|
||||||
self.api_key = "vdb$yjr9ln3n0td"
|
|
||||||
self.endpoint = "http://180.76.96.191:5287"
|
|
||||||
|
|
||||||
self.client = None
|
|
||||||
self.test_db_name = "test_connection_db"
|
|
||||||
self.test_table_name = "test_vectors"
|
|
||||||
|
|
||||||
def connect(self) -> bool:
|
|
||||||
"""测试连接"""
|
|
||||||
try:
|
|
||||||
logger.info("正在测试百度VDB连接...")
|
|
||||||
logger.info(f"端点: {self.endpoint}")
|
|
||||||
logger.info(f"账户: {self.account}")
|
|
||||||
|
|
||||||
# 创建配置
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(self.account, self.api_key),
|
|
||||||
endpoint=self.endpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建客户端
|
|
||||||
self.client = pymochow.MochowClient(config)
|
|
||||||
logger.info("✅ VDB客户端创建成功")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB连接失败: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_database_operations(self) -> bool:
|
|
||||||
"""测试数据库操作"""
|
|
||||||
try:
|
|
||||||
logger.info("\n=== 测试数据库操作 ===")
|
|
||||||
|
|
||||||
# 1. 列出现有数据库
|
|
||||||
logger.info("1. 查询数据库列表...")
|
|
||||||
databases = self.client.list_databases()
|
|
||||||
logger.info(f"现有数据库数量: {len(databases)}")
|
|
||||||
for db in databases:
|
|
||||||
logger.info(f" - {db.database_name}")
|
|
||||||
|
|
||||||
# 2. 创建测试数据库
|
|
||||||
logger.info(f"2. 创建测试数据库: {self.test_db_name}")
|
|
||||||
try:
|
|
||||||
# 先尝试删除可能存在的测试数据库
|
|
||||||
try:
|
|
||||||
self.client.drop_database(self.test_db_name)
|
|
||||||
logger.info("删除了已存在的测试数据库")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 创建新数据库
|
|
||||||
db = self.client.create_database(self.test_db_name)
|
|
||||||
logger.info(f"✅ 数据库创建成功: {db.database_name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 数据库创建失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 3. 验证数据库创建
|
|
||||||
logger.info("3. 验证数据库创建...")
|
|
||||||
databases = self.client.list_databases()
|
|
||||||
db_names = [db.database_name for db in databases]
|
|
||||||
if self.test_db_name in db_names:
|
|
||||||
logger.info("✅ 数据库验证成功")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.error("❌ 数据库验证失败")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 数据库操作测试失败: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_table_operations(self) -> bool:
|
|
||||||
"""测试表操作"""
|
|
||||||
try:
|
|
||||||
logger.info("\n=== 测试表操作 ===")
|
|
||||||
|
|
||||||
# 获取数据库对象
|
|
||||||
db = self.client.database(self.test_db_name)
|
|
||||||
|
|
||||||
# 1. 定义表结构
|
|
||||||
logger.info("1. 定义表结构...")
|
|
||||||
fields = []
|
|
||||||
fields.append(Field("id", FieldType.STRING, primary_key=True,
|
|
||||||
partition_key=True, auto_increment=False, not_null=True))
|
|
||||||
fields.append(Field("content", FieldType.STRING, not_null=True))
|
|
||||||
fields.append(Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=128))
|
|
||||||
|
|
||||||
# 定义向量索引
|
|
||||||
indexes = []
|
|
||||||
indexes.append(
|
|
||||||
VectorIndex(
|
|
||||||
index_name="vector_idx",
|
|
||||||
index_type=IndexType.HNSW,
|
|
||||||
field="vector",
|
|
||||||
metric_type=MetricType.L2,
|
|
||||||
params=HNSWParams(m=16, efconstruction=100),
|
|
||||||
auto_build=True
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 创建表
|
|
||||||
logger.info(f"2. 创建表: {self.test_table_name}")
|
|
||||||
table = db.create_table(
|
|
||||||
table_name=self.test_table_name,
|
|
||||||
replication=1, # 单副本用于测试
|
|
||||||
partition=Partition(partition_num=1), # 单分区用于测试
|
|
||||||
schema=Schema(fields=fields, indexes=indexes)
|
|
||||||
)
|
|
||||||
logger.info(f"✅ 表创建成功: {table.table_name}")
|
|
||||||
|
|
||||||
# 3. 等待表状态正常
|
|
||||||
logger.info("3. 等待表状态正常...")
|
|
||||||
max_wait = 30 # 最多等待30秒
|
|
||||||
wait_time = 0
|
|
||||||
while wait_time < max_wait:
|
|
||||||
table_info = db.describe_table(self.test_table_name)
|
|
||||||
logger.info(f"表状态: {table_info.state}")
|
|
||||||
if table_info.state == TableState.NORMAL:
|
|
||||||
logger.info("✅ 表状态正常")
|
|
||||||
break
|
|
||||||
time.sleep(2)
|
|
||||||
wait_time += 2
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ 表状态等待超时,继续测试...")
|
|
||||||
|
|
||||||
# 4. 查询表列表
|
|
||||||
logger.info("4. 查询表列表...")
|
|
||||||
tables = db.list_table()
|
|
||||||
table_names = [t.table_name for t in tables]
|
|
||||||
logger.info(f"表数量: {len(tables)}")
|
|
||||||
for table_name in table_names:
|
|
||||||
logger.info(f" - {table_name}")
|
|
||||||
|
|
||||||
if self.test_table_name in table_names:
|
|
||||||
logger.info("✅ 表操作测试成功")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.error("❌ 表验证失败")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 表操作测试失败: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_vector_operations(self) -> bool:
|
|
||||||
"""测试向量操作"""
|
|
||||||
try:
|
|
||||||
logger.info("\n=== 测试向量操作 ===")
|
|
||||||
|
|
||||||
# 获取表对象
|
|
||||||
db = self.client.database(self.test_db_name)
|
|
||||||
table = db.table(self.test_table_name)
|
|
||||||
|
|
||||||
# 1. 插入测试向量
|
|
||||||
logger.info("1. 插入测试向量...")
|
|
||||||
test_vectors = []
|
|
||||||
for i in range(5):
|
|
||||||
vector = np.random.rand(128).astype(np.float32).tolist()
|
|
||||||
row = Row(
|
|
||||||
id=f"test_{i:03d}",
|
|
||||||
content=f"测试内容_{i}",
|
|
||||||
vector=vector
|
|
||||||
)
|
|
||||||
test_vectors.append(row)
|
|
||||||
|
|
||||||
table.upsert(test_vectors)
|
|
||||||
logger.info(f"✅ 插入了 {len(test_vectors)} 个向量")
|
|
||||||
|
|
||||||
# 2. 查询向量
|
|
||||||
logger.info("2. 测试向量查询...")
|
|
||||||
primary_key = {'id': 'test_001'}
|
|
||||||
result = table.query(primary_key=primary_key, retrieve_vector=True)
|
|
||||||
if result:
|
|
||||||
logger.info("✅ 向量查询成功")
|
|
||||||
logger.info(f"查询结果: ID={result.id}, Content={result.content}")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ 向量查询无结果")
|
|
||||||
|
|
||||||
# 3. 向量检索
|
|
||||||
logger.info("3. 测试向量检索...")
|
|
||||||
query_vector = np.random.rand(128).astype(np.float32).tolist()
|
|
||||||
search_request = VectorTopkSearchRequest(
|
|
||||||
vector_field="vector",
|
|
||||||
vector=FloatVector(query_vector),
|
|
||||||
limit=3,
|
|
||||||
config=VectorSearchConfig(ef=100)
|
|
||||||
)
|
|
||||||
|
|
||||||
search_results = table.vector_search(request=search_request)
|
|
||||||
logger.info(f"✅ 向量检索成功,返回 {len(search_results)} 个结果")
|
|
||||||
|
|
||||||
for i, result in enumerate(search_results):
|
|
||||||
logger.info(f" 结果 {i+1}: ID={result.id}, 相似度={result.distance:.4f}")
|
|
||||||
|
|
||||||
# 4. 查询表统计信息
|
|
||||||
logger.info("4. 查询表统计信息...")
|
|
||||||
stats = table.stats()
|
|
||||||
logger.info(f"记录数: {stats.rowCount}")
|
|
||||||
logger.info(f"内存大小: {stats.memorySizeInByte} bytes")
|
|
||||||
logger.info(f"磁盘大小: {stats.diskSizeInByte} bytes")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 向量操作测试失败: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
def cleanup(self) -> bool:
|
|
||||||
"""清理测试数据"""
|
|
||||||
try:
|
|
||||||
logger.info("\n=== 清理测试数据 ===")
|
|
||||||
|
|
||||||
# 删除测试数据库
|
|
||||||
logger.info(f"删除测试数据库: {self.test_db_name}")
|
|
||||||
self.client.drop_database(self.test_db_name)
|
|
||||||
logger.info("✅ 测试数据清理完成")
|
|
||||||
|
|
||||||
# 关闭连接
|
|
||||||
if self.client:
|
|
||||||
self.client.close()
|
|
||||||
logger.info("✅ VDB连接已关闭")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 清理失败: {str(e)}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def run_full_test(self) -> Dict[str, bool]:
|
|
||||||
"""运行完整测试"""
|
|
||||||
results = {
|
|
||||||
"connection": False,
|
|
||||||
"database_ops": False,
|
|
||||||
"table_ops": False,
|
|
||||||
"vector_ops": False,
|
|
||||||
"cleanup": False
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info("🚀 开始百度VDB连接可用性测试")
|
|
||||||
logger.info("=" * 50)
|
|
||||||
|
|
||||||
# 1. 测试连接
|
|
||||||
if self.connect():
|
|
||||||
results["connection"] = True
|
|
||||||
|
|
||||||
# 2. 测试数据库操作
|
|
||||||
if self.test_database_operations():
|
|
||||||
results["database_ops"] = True
|
|
||||||
|
|
||||||
# 3. 测试表操作
|
|
||||||
if self.test_table_operations():
|
|
||||||
results["table_ops"] = True
|
|
||||||
|
|
||||||
# 4. 测试向量操作
|
|
||||||
if self.test_vector_operations():
|
|
||||||
results["vector_ops"] = True
|
|
||||||
|
|
||||||
# 5. 清理
|
|
||||||
if self.cleanup():
|
|
||||||
results["cleanup"] = True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 测试过程中发生错误: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# 输出测试结果
|
|
||||||
logger.info("\n" + "=" * 50)
|
|
||||||
logger.info("📊 测试结果汇总:")
|
|
||||||
logger.info("=" * 50)
|
|
||||||
|
|
||||||
test_items = [
|
|
||||||
("VDB连接", results["connection"]),
|
|
||||||
("数据库操作", results["database_ops"]),
|
|
||||||
("表操作", results["table_ops"]),
|
|
||||||
("向量操作", results["vector_ops"]),
|
|
||||||
("数据清理", results["cleanup"])
|
|
||||||
]
|
|
||||||
|
|
||||||
for item, success in test_items:
|
|
||||||
status = "✅ 成功" if success else "❌ 失败"
|
|
||||||
logger.info(f"{item}: {status}")
|
|
||||||
|
|
||||||
# 计算总体成功率
|
|
||||||
success_count = sum(results.values())
|
|
||||||
total_count = len(results)
|
|
||||||
success_rate = (success_count / total_count) * 100
|
|
||||||
|
|
||||||
logger.info(f"\n总体成功率: {success_count}/{total_count} ({success_rate:.1f}%)")
|
|
||||||
|
|
||||||
if success_rate >= 80:
|
|
||||||
logger.info("🎉 百度VDB连接测试基本通过!")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ 百度VDB连接存在问题,需要进一步检查")
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主函数"""
|
|
||||||
tester = BaiduVDBConnectionTest()
|
|
||||||
results = tester.run_full_test()
|
|
||||||
|
|
||||||
# 返回测试是否成功
|
|
||||||
return results["connection"] and results["database_ops"]
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@ -1,165 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
百度VDB向量数据库简单连接测试
|
|
||||||
专注于验证连接可用性和基本操作
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
import traceback
|
|
||||||
import time
|
|
||||||
|
|
||||||
# 数据库连接配置
|
|
||||||
ACCOUNT = 'root'
|
|
||||||
API_KEY = 'vdb$yjr9ln3n0td'
|
|
||||||
ENDPOINT = 'http://180.76.96.191:5287'
|
|
||||||
|
|
||||||
def test_basic_connection():
|
|
||||||
"""测试基本连接功能"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("百度VDB向量数据库连接测试")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"测试时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
||||||
print(f"服务器地址: {ENDPOINT}")
|
|
||||||
print(f"用户名: {ACCOUNT}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
client = None
|
|
||||||
test_db_name = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 测试客户端创建
|
|
||||||
print("1. 创建客户端连接...")
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(ACCOUNT, API_KEY),
|
|
||||||
endpoint=ENDPOINT
|
|
||||||
)
|
|
||||||
client = pymochow.MochowClient(config)
|
|
||||||
print(" ✓ 客户端创建成功")
|
|
||||||
|
|
||||||
# 2. 测试数据库列表查询
|
|
||||||
print("\n2. 查询数据库列表...")
|
|
||||||
db_list = client.list_databases()
|
|
||||||
print(f" ✓ 查询成功,当前数据库数量: {len(db_list)}")
|
|
||||||
|
|
||||||
if db_list:
|
|
||||||
print(" 现有数据库:")
|
|
||||||
for i, db in enumerate(db_list, 1):
|
|
||||||
print(f" {i}. {db.database_name}")
|
|
||||||
else:
|
|
||||||
print(" 当前没有数据库")
|
|
||||||
|
|
||||||
# 3. 测试创建数据库
|
|
||||||
print("\n3. 创建测试数据库...")
|
|
||||||
test_db_name = f"test_connection_{int(time.time())}"
|
|
||||||
db = client.create_database(test_db_name)
|
|
||||||
print(f" ✓ 成功创建数据库: {test_db_name}")
|
|
||||||
|
|
||||||
# 4. 验证数据库创建
|
|
||||||
print("\n4. 验证数据库创建...")
|
|
||||||
db_list_after = client.list_databases()
|
|
||||||
print(f" ✓ 验证成功,数据库数量: {len(db_list_after)}")
|
|
||||||
|
|
||||||
# 5. 获取数据库对象
|
|
||||||
print("\n5. 获取数据库对象...")
|
|
||||||
test_db = client.database(test_db_name)
|
|
||||||
print(f" ✓ 成功获取数据库对象: {test_db.database_name}")
|
|
||||||
|
|
||||||
# 6. 查询表列表(应该为空)
|
|
||||||
print("\n6. 查询表列表...")
|
|
||||||
table_list = test_db.list_table()
|
|
||||||
print(f" ✓ 查询成功,表数量: {len(table_list)}")
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("🎉 所有基本连接测试通过!")
|
|
||||||
print("✓ 数据库连接正常")
|
|
||||||
print("✓ 认证信息正确")
|
|
||||||
print("✓ 网络连接稳定")
|
|
||||||
print("✓ 基本数据库操作可用")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ 测试失败: {str(e)}")
|
|
||||||
print(f"详细错误信息:")
|
|
||||||
print(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# 清理测试数据库
|
|
||||||
if client and test_db_name:
|
|
||||||
try:
|
|
||||||
print(f"\n7. 清理测试数据库...")
|
|
||||||
client.drop_database(test_db_name)
|
|
||||||
print(f" ✓ 成功删除测试数据库: {test_db_name}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ⚠ 清理失败: {str(e)}")
|
|
||||||
|
|
||||||
# 关闭连接
|
|
||||||
if client:
|
|
||||||
try:
|
|
||||||
client.close()
|
|
||||||
print(" ✓ 连接已关闭")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_advanced_operations():
|
|
||||||
"""测试高级操作(可选)"""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("高级功能测试(可选)")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
client = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建客户端
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(ACCOUNT, API_KEY),
|
|
||||||
endpoint=ENDPOINT
|
|
||||||
)
|
|
||||||
client = pymochow.MochowClient(config)
|
|
||||||
|
|
||||||
# 测试多次连接
|
|
||||||
print("1. 测试连接稳定性...")
|
|
||||||
for i in range(3):
|
|
||||||
db_list = client.list_databases()
|
|
||||||
print(f" 第{i+1}次查询: {len(db_list)}个数据库")
|
|
||||||
time.sleep(1)
|
|
||||||
print(" ✓ 连接稳定")
|
|
||||||
|
|
||||||
print("\n✓ 高级功能测试通过")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠ 高级功能测试出现问题: {str(e)}")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
if client:
|
|
||||||
try:
|
|
||||||
client.close()
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 运行基本连接测试
|
|
||||||
success = test_basic_connection()
|
|
||||||
|
|
||||||
if success:
|
|
||||||
# 如果基本测试通过,运行高级测试
|
|
||||||
test_advanced_operations()
|
|
||||||
|
|
||||||
print(f"\n🎯 总结:")
|
|
||||||
print(f"你的百度VDB向量数据库配置完全可用!")
|
|
||||||
print(f"- 服务器地址: {ENDPOINT}")
|
|
||||||
print(f"- 用户名: {ACCOUNT}")
|
|
||||||
print(f"- 连接状态: 正常")
|
|
||||||
print(f"- 基本操作: 可用")
|
|
||||||
print(f"- 建议: 可以开始使用该数据库进行向量存储和检索")
|
|
||||||
else:
|
|
||||||
print(f"\n❌ 数据库连接存在问题,请检查:")
|
|
||||||
print(f"1. 网络连接是否正常")
|
|
||||||
print(f"2. 服务器地址是否正确: {ENDPOINT}")
|
|
||||||
print(f"3. 用户名和密码是否正确")
|
|
||||||
print(f"4. 防火墙是否阻止了连接")
|
|
||||||
@ -1,58 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
FAISS多模态检索系统简单测试
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
from multimodal_retrieval_faiss import MultimodalRetrievalFAISS
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def test_text_retrieval():
|
|
||||||
print("=== 测试文本检索 ===")
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
print("初始化检索系统...")
|
|
||||||
retrieval = MultimodalRetrievalFAISS(
|
|
||||||
model_name="OpenSearch-AI/Ops-MM-embedding-v1-7B",
|
|
||||||
use_all_gpus=True,
|
|
||||||
index_path="faiss_index_test"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 测试文本
|
|
||||||
texts = [
|
|
||||||
"一只可爱的橘色猫咪在沙发上睡觉",
|
|
||||||
"城市夜景中的高楼大厦和车流",
|
|
||||||
"阳光明媚的海滩上,人们在冲浪和晒太阳",
|
|
||||||
"美味的意大利面配红酒和沙拉",
|
|
||||||
"雪山上滑雪的运动员"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 添加文本
|
|
||||||
print("\n添加文本到检索系统...")
|
|
||||||
text_ids = retrieval.add_texts(texts)
|
|
||||||
print(f"添加了{len(text_ids)}条文本")
|
|
||||||
print(f"当前向量数量: {retrieval.get_vector_count()}")
|
|
||||||
|
|
||||||
# 测试文本搜索
|
|
||||||
print("\n测试文本搜索...")
|
|
||||||
queries = ["一只猫在睡觉", "都市风光", "海边的景色"]
|
|
||||||
|
|
||||||
for query in queries:
|
|
||||||
print(f"\n查询: {query}")
|
|
||||||
results = retrieval.search_by_text(query, k=2)
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
|
||||||
|
|
||||||
# 保存索引
|
|
||||||
print("\n保存索引...")
|
|
||||||
retrieval.save_index()
|
|
||||||
|
|
||||||
print("\n测试完成!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_text_retrieval()
|
|
||||||
@ -1,164 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
FAISS多模态检索系统简单测试 - 带代理设置
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
|
|
||||||
# 设置代理
|
|
||||||
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890' # 根据实际情况修改
|
|
||||||
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890' # 根据实际情况修改
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 设置离线模式,避免下载模型
|
|
||||||
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
|
||||||
|
|
||||||
# 添加当前目录到路径
|
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
# 使用简单的向量模型替代大型多模态模型
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
import faiss
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
class SimpleFaissRetrieval:
|
|
||||||
"""简化版FAISS检索系统,使用sentence-transformers"""
|
|
||||||
|
|
||||||
def __init__(self, model_name="paraphrase-multilingual-MiniLM-L12-v2", index_path="simple_faiss_index"):
|
|
||||||
"""
|
|
||||||
初始化简化版检索系统
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: 模型名称,使用轻量级模型
|
|
||||||
index_path: 索引文件路径
|
|
||||||
"""
|
|
||||||
self.model_name = model_name
|
|
||||||
self.index_path = index_path
|
|
||||||
|
|
||||||
logger.info(f"加载模型: {model_name}")
|
|
||||||
try:
|
|
||||||
# 尝试加载模型
|
|
||||||
self.model = SentenceTransformer(model_name)
|
|
||||||
self.dimension = self.model.get_sentence_embedding_dimension()
|
|
||||||
logger.info(f"模型加载成功,向量维度: {self.dimension}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"模型加载失败: {str(e)}")
|
|
||||||
logger.info("使用随机向量模拟...")
|
|
||||||
self.model = None
|
|
||||||
self.dimension = 384 # 默认维度
|
|
||||||
|
|
||||||
# 初始化索引
|
|
||||||
self.index = faiss.IndexFlatL2(self.dimension)
|
|
||||||
self.metadata = {}
|
|
||||||
|
|
||||||
logger.info("检索系统初始化完成")
|
|
||||||
|
|
||||||
def encode_text(self, text):
|
|
||||||
"""编码文本为向量"""
|
|
||||||
if self.model is None:
|
|
||||||
# 如果模型加载失败,使用随机向量
|
|
||||||
if isinstance(text, list):
|
|
||||||
vectors = np.random.rand(len(text), self.dimension).astype('float32')
|
|
||||||
return vectors
|
|
||||||
else:
|
|
||||||
return np.random.rand(self.dimension).astype('float32')
|
|
||||||
else:
|
|
||||||
# 使用模型编码
|
|
||||||
return self.model.encode(text, convert_to_numpy=True)
|
|
||||||
|
|
||||||
def add_texts(self, texts, metadatas=None):
|
|
||||||
"""添加文本到索引"""
|
|
||||||
if not texts:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if metadatas is None:
|
|
||||||
metadatas = [{} for _ in range(len(texts))]
|
|
||||||
|
|
||||||
# 编码文本
|
|
||||||
vectors = self.encode_text(texts)
|
|
||||||
|
|
||||||
# 添加到索引
|
|
||||||
start_id = len(self.metadata)
|
|
||||||
ids = list(range(start_id, start_id + len(texts)))
|
|
||||||
|
|
||||||
self.index.add(np.array(vectors).astype('float32'))
|
|
||||||
|
|
||||||
# 保存元数据
|
|
||||||
for i, id in enumerate(ids):
|
|
||||||
self.metadata[str(id)] = {
|
|
||||||
"text": texts[i],
|
|
||||||
"type": "text",
|
|
||||||
**metadatas[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"添加了{len(ids)}条文本,当前索引大小: {self.index.ntotal}")
|
|
||||||
return [str(id) for id in ids]
|
|
||||||
|
|
||||||
def search(self, query, k=5):
|
|
||||||
"""搜索相似文本"""
|
|
||||||
# 编码查询
|
|
||||||
query_vector = self.encode_text(query)
|
|
||||||
if len(query_vector.shape) == 1:
|
|
||||||
query_vector = query_vector.reshape(1, -1)
|
|
||||||
|
|
||||||
# 搜索
|
|
||||||
distances, indices = self.index.search(query_vector.astype('float32'), k)
|
|
||||||
|
|
||||||
# 处理结果
|
|
||||||
results = []
|
|
||||||
for i in range(len(indices[0])):
|
|
||||||
idx = indices[0][i]
|
|
||||||
if idx < 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
vector_id = str(idx)
|
|
||||||
if vector_id in self.metadata:
|
|
||||||
result = self.metadata[vector_id].copy()
|
|
||||||
result['score'] = float(1.0 / (1.0 + distances[0][i]))
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def test_simple_retrieval():
|
|
||||||
"""测试简化版检索系统"""
|
|
||||||
print("=== 测试简化版FAISS检索系统 ===")
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
print("初始化检索系统...")
|
|
||||||
retrieval = SimpleFaissRetrieval()
|
|
||||||
|
|
||||||
# 测试文本
|
|
||||||
texts = [
|
|
||||||
"一只可爱的橘色猫咪在沙发上睡觉",
|
|
||||||
"城市夜景中的高楼大厦和车流",
|
|
||||||
"阳光明媚的海滩上,人们在冲浪和晒太阳",
|
|
||||||
"美味的意大利面配红酒和沙拉",
|
|
||||||
"雪山上滑雪的运动员"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 添加文本
|
|
||||||
print("\n添加文本到检索系统...")
|
|
||||||
text_ids = retrieval.add_texts(texts)
|
|
||||||
print(f"添加了{len(text_ids)}条文本")
|
|
||||||
|
|
||||||
# 测试文本搜索
|
|
||||||
print("\n测试文本搜索...")
|
|
||||||
queries = ["一只猫在睡觉", "都市风光", "海边的景色"]
|
|
||||||
|
|
||||||
for query in queries:
|
|
||||||
print(f"\n查询: {query}")
|
|
||||||
results = retrieval.search(query, k=2)
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
|
||||||
|
|
||||||
print("\n测试完成!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_simple_retrieval()
|
|
||||||
@ -1,79 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
测试修复后的系统功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import requests
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
|
|
||||||
def test_system():
|
|
||||||
"""测试系统功能"""
|
|
||||||
base_url = "http://localhost:5000"
|
|
||||||
|
|
||||||
print("🧪 开始测试修复后的系统...")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
# 测试1: 检查系统状态
|
|
||||||
print("1. 测试系统状态...")
|
|
||||||
try:
|
|
||||||
response = requests.get(f"{base_url}/api/status", timeout=10)
|
|
||||||
if response.status_code == 200:
|
|
||||||
status = response.json()
|
|
||||||
print(f" ✅ 系统状态: {status}")
|
|
||||||
else:
|
|
||||||
print(f" ❌ 状态检查失败: {response.status_code}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ❌ 状态检查异常: {e}")
|
|
||||||
|
|
||||||
# 测试2: 检查数据统计
|
|
||||||
print("\n2. 测试数据统计...")
|
|
||||||
try:
|
|
||||||
response = requests.get(f"{base_url}/api/data/stats", timeout=10)
|
|
||||||
if response.status_code == 200:
|
|
||||||
stats = response.json()
|
|
||||||
print(f" ✅ 数据统计: {stats}")
|
|
||||||
else:
|
|
||||||
print(f" ❌ 统计检查失败: {response.status_code}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ❌ 统计检查异常: {e}")
|
|
||||||
|
|
||||||
# 测试3: 检查数据列表
|
|
||||||
print("\n3. 测试数据列表...")
|
|
||||||
try:
|
|
||||||
response = requests.get(f"{base_url}/api/data/list", timeout=10)
|
|
||||||
if response.status_code == 200:
|
|
||||||
data_list = response.json()
|
|
||||||
print(f" ✅ 数据列表: {data_list}")
|
|
||||||
else:
|
|
||||||
print(f" ❌ 列表检查失败: {response.status_code}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ❌ 列表检查异常: {e}")
|
|
||||||
|
|
||||||
# 测试4: 测试文本搜索(如果系统已初始化)
|
|
||||||
print("\n4. 测试文本搜索...")
|
|
||||||
try:
|
|
||||||
search_data = {
|
|
||||||
"query": "测试查询",
|
|
||||||
"top_k": 3
|
|
||||||
}
|
|
||||||
response = requests.post(f"{base_url}/api/search/text_to_text",
|
|
||||||
json=search_data, timeout=10)
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
print(f" ✅ 文本搜索: {result}")
|
|
||||||
else:
|
|
||||||
print(f" ❌ 文本搜索失败: {response.status_code}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ❌ 文本搜索异常: {e}")
|
|
||||||
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("🎉 测试完成!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 等待系统启动
|
|
||||||
print("⏳ 等待系统启动...")
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
test_system()
|
|
||||||
@ -1,49 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
测试图像编码功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
def test_image_encoding():
|
|
||||||
print('正在初始化多GPU多模态检索系统...')
|
|
||||||
retrieval = MultiGPUMultimodalRetrieval()
|
|
||||||
|
|
||||||
# 测试文本编码
|
|
||||||
print('测试文本编码...')
|
|
||||||
text_embeddings = retrieval.encode_text_batch(['这是一个测试文本'])
|
|
||||||
print(f'文本embedding形状: {text_embeddings.shape}')
|
|
||||||
print(f'文本embedding数据类型: {text_embeddings.dtype}')
|
|
||||||
|
|
||||||
# 测试图像编码
|
|
||||||
print('测试图像编码...')
|
|
||||||
test_images = ['sample_images/1755677101_1__.jpg']
|
|
||||||
image_embeddings = retrieval.encode_image_batch(test_images)
|
|
||||||
print(f'图像embedding形状: {image_embeddings.shape}')
|
|
||||||
print(f'图像embedding数据类型: {image_embeddings.dtype}')
|
|
||||||
|
|
||||||
# 测试两次相同图像的embedding是否一致
|
|
||||||
print('测试embedding一致性...')
|
|
||||||
image_embeddings2 = retrieval.encode_image_batch(test_images)
|
|
||||||
consistency = np.allclose(image_embeddings, image_embeddings2, rtol=1e-5)
|
|
||||||
print(f'相同图像embedding一致性: {consistency}')
|
|
||||||
|
|
||||||
# 测试不同图像的embedding差异
|
|
||||||
print('测试不同图像embedding差异...')
|
|
||||||
test_images2 = ['sample_images/1755677101_2__.jpg']
|
|
||||||
image_embeddings3 = retrieval.encode_image_batch(test_images2)
|
|
||||||
similarity = np.dot(image_embeddings[0], image_embeddings3[0]) / (np.linalg.norm(image_embeddings[0]) * np.linalg.norm(image_embeddings3[0]))
|
|
||||||
print(f'不同图像间相似度: {similarity:.4f}')
|
|
||||||
|
|
||||||
# 验证维度一致性
|
|
||||||
if text_embeddings.shape[1] == image_embeddings.shape[1]:
|
|
||||||
print('✅ 文本和图像embedding维度一致')
|
|
||||||
else:
|
|
||||||
print('❌ 文本和图像embedding维度不一致')
|
|
||||||
|
|
||||||
print('测试完成!')
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_image_encoding()
|
|
||||||
@ -1,98 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
使用本地模型的FAISS多模态检索系统测试
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
import numpy as np
|
|
||||||
import faiss
|
|
||||||
from typing import List, Dict, Any, Optional, Union
|
|
||||||
import json
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 设置离线模式
|
|
||||||
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
|
||||||
|
|
||||||
def test_local_model():
|
|
||||||
"""测试本地模型加载"""
|
|
||||||
from transformers import AutoModel, AutoTokenizer, AutoProcessor
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
# 这里替换为您实际下载的模型路径
|
|
||||||
local_model_path = "/root/models/Ops-MM-embedding-v1-7B"
|
|
||||||
|
|
||||||
if not os.path.exists(local_model_path):
|
|
||||||
logger.error(f"模型路径不存在: {local_model_path}")
|
|
||||||
logger.info("请先下载模型到指定路径")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"加载本地模型: {local_model_path}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 加载tokenizer
|
|
||||||
logger.info("加载tokenizer...")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(local_model_path)
|
|
||||||
|
|
||||||
# 加载processor
|
|
||||||
logger.info("加载processor...")
|
|
||||||
processor = AutoProcessor.from_pretrained(local_model_path)
|
|
||||||
|
|
||||||
# 加载模型
|
|
||||||
logger.info("加载模型...")
|
|
||||||
model = AutoModel.from_pretrained(
|
|
||||||
local_model_path,
|
|
||||||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
|
||||||
device_map="auto" if torch.cuda.device_count() > 0 else None
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("模型加载成功!")
|
|
||||||
|
|
||||||
# 测试文本编码
|
|
||||||
logger.info("测试文本编码...")
|
|
||||||
text = "这是一个测试文本"
|
|
||||||
inputs = tokenizer(text, return_tensors="pt")
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = model(**inputs)
|
|
||||||
text_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
|
||||||
|
|
||||||
logger.info(f"文本编码维度: {text_embedding.shape}")
|
|
||||||
|
|
||||||
# 如果有图像处理功能,测试图像编码
|
|
||||||
try:
|
|
||||||
logger.info("测试图像编码...")
|
|
||||||
# 创建一个简单的测试图像
|
|
||||||
image = Image.new('RGB', (224, 224), color='red')
|
|
||||||
image_inputs = processor(images=image, return_tensors="pt")
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
image_inputs = {k: v.to("cuda") for k, v in image_inputs.items()}
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
image_outputs = model.vision_model(**image_inputs)
|
|
||||||
image_embedding = image_outputs.pooler_output.cpu().numpy()
|
|
||||||
|
|
||||||
logger.info(f"图像编码维度: {image_embedding.shape}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图像编码测试失败: {str(e)}")
|
|
||||||
|
|
||||||
logger.info("本地模型测试完成!")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"模型加载失败: {str(e)}")
|
|
||||||
logger.error("请确保模型文件已正确下载")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_local_model()
|
|
||||||
@ -1,229 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
测试本地模型和FAISS向量数据库的多模态检索系统
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
import time
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
from multimodal_retrieval_local import MultimodalRetrievalLocal
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 设置离线模式
|
|
||||||
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
|
||||||
|
|
||||||
def test_text_retrieval():
|
|
||||||
"""测试文本检索功能"""
|
|
||||||
print("\n=== 测试文本检索 ===")
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
print("初始化检索系统...")
|
|
||||||
retrieval = MultimodalRetrievalLocal(
|
|
||||||
model_path="/root/models/Ops-MM-embedding-v1-7B",
|
|
||||||
use_all_gpus=True,
|
|
||||||
index_path="local_faiss_text_test"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 测试文本
|
|
||||||
texts = [
|
|
||||||
"一只可爱的橘色猫咪在沙发上睡觉",
|
|
||||||
"城市夜景中的高楼大厦和车流",
|
|
||||||
"阳光明媚的海滩上,人们在冲浪和晒太阳",
|
|
||||||
"美味的意大利面配红酒和沙拉",
|
|
||||||
"雪山上滑雪的运动员"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 添加文本
|
|
||||||
print("\n添加文本到检索系统...")
|
|
||||||
text_ids = retrieval.add_texts(texts)
|
|
||||||
print(f"添加了{len(text_ids)}条文本")
|
|
||||||
|
|
||||||
# 获取统计信息
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
print(f"检索系统统计信息: {stats}")
|
|
||||||
|
|
||||||
# 测试文本搜索
|
|
||||||
print("\n测试文本搜索...")
|
|
||||||
queries = ["一只猫在睡觉", "都市风光", "海边的景色"]
|
|
||||||
|
|
||||||
for query in queries:
|
|
||||||
print(f"\n查询: {query}")
|
|
||||||
results = retrieval.search_by_text(query, k=2)
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
|
||||||
|
|
||||||
# 保存索引
|
|
||||||
print("\n保存索引...")
|
|
||||||
retrieval.save_index()
|
|
||||||
|
|
||||||
print("\n文本检索测试完成!")
|
|
||||||
return retrieval
|
|
||||||
|
|
||||||
def test_image_retrieval():
|
|
||||||
"""测试图像检索功能"""
|
|
||||||
print("\n=== 测试图像检索 ===")
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
print("初始化检索系统...")
|
|
||||||
retrieval = MultimodalRetrievalLocal(
|
|
||||||
model_path="/root/models/Ops-MM-embedding-v1-7B",
|
|
||||||
use_all_gpus=True,
|
|
||||||
index_path="local_faiss_image_test"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建测试图像
|
|
||||||
print("\n创建测试图像...")
|
|
||||||
images = []
|
|
||||||
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
|
|
||||||
image_paths = []
|
|
||||||
|
|
||||||
for i, color in enumerate(colors):
|
|
||||||
img = Image.new('RGB', (224, 224), color=color)
|
|
||||||
images.append(img)
|
|
||||||
|
|
||||||
# 保存图像
|
|
||||||
img_path = f"/tmp/test_image_{i}.png"
|
|
||||||
img.save(img_path)
|
|
||||||
image_paths.append(img_path)
|
|
||||||
print(f"创建图像: {img_path}")
|
|
||||||
|
|
||||||
# 添加图像
|
|
||||||
print("\n添加图像到检索系统...")
|
|
||||||
metadatas = [{"description": f"测试图像 {i+1}"} for i in range(len(images))]
|
|
||||||
image_ids = retrieval.add_images(images, metadatas, image_paths)
|
|
||||||
print(f"添加了{len(image_ids)}张图像")
|
|
||||||
|
|
||||||
# 获取统计信息
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
print(f"检索系统统计信息: {stats}")
|
|
||||||
|
|
||||||
# 测试图像搜索
|
|
||||||
print("\n测试图像搜索...")
|
|
||||||
query_image = Image.new('RGB', (224, 224), color=(255, 0, 0)) # 红色图像
|
|
||||||
|
|
||||||
print("\n使用图像查询图像:")
|
|
||||||
results = retrieval.search_by_image(query_image, k=2, filter_type="image")
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
print(f" 结果 {i+1}: {result.get('description', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
|
||||||
|
|
||||||
# 保存索引
|
|
||||||
print("\n保存索引...")
|
|
||||||
retrieval.save_index()
|
|
||||||
|
|
||||||
print("\n图像检索测试完成!")
|
|
||||||
return retrieval
|
|
||||||
|
|
||||||
def test_cross_modal_retrieval():
|
|
||||||
"""测试跨模态检索功能"""
|
|
||||||
print("\n=== 测试跨模态检索 ===")
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
print("初始化检索系统...")
|
|
||||||
retrieval = MultimodalRetrievalLocal(
|
|
||||||
model_path="/root/models/Ops-MM-embedding-v1-7B",
|
|
||||||
use_all_gpus=True,
|
|
||||||
index_path="local_faiss_cross_modal_test"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加文本
|
|
||||||
texts = [
|
|
||||||
"一只红色的苹果",
|
|
||||||
"绿色的草地",
|
|
||||||
"蓝色的大海",
|
|
||||||
"黄色的向日葵",
|
|
||||||
"青色的天空"
|
|
||||||
]
|
|
||||||
print("\n添加文本到检索系统...")
|
|
||||||
text_ids = retrieval.add_texts(texts)
|
|
||||||
print(f"添加了{len(text_ids)}条文本")
|
|
||||||
|
|
||||||
# 添加图像
|
|
||||||
print("\n添加图像到检索系统...")
|
|
||||||
images = []
|
|
||||||
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
|
|
||||||
descriptions = ["红色图像", "绿色图像", "蓝色图像", "黄色图像", "青色图像"]
|
|
||||||
|
|
||||||
for i, color in enumerate(colors):
|
|
||||||
img = Image.new('RGB', (224, 224), color=color)
|
|
||||||
images.append(img)
|
|
||||||
|
|
||||||
metadatas = [{"description": desc} for desc in descriptions]
|
|
||||||
image_ids = retrieval.add_images(images, metadatas)
|
|
||||||
print(f"添加了{len(image_ids)}张图像")
|
|
||||||
|
|
||||||
# 获取统计信息
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
print(f"检索系统统计信息: {stats}")
|
|
||||||
|
|
||||||
# 测试文搜图
|
|
||||||
print("\n测试文搜图...")
|
|
||||||
query_text = "红色"
|
|
||||||
print(f"查询文本: {query_text}")
|
|
||||||
results = retrieval.search_by_text(query_text, k=2, filter_type="image")
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
print(f" 结果 {i+1}: {result.get('description', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
|
||||||
|
|
||||||
# 测试图搜文
|
|
||||||
print("\n测试图搜文...")
|
|
||||||
query_image = Image.new('RGB', (224, 224), color=(0, 0, 255)) # 蓝色图像
|
|
||||||
print("查询图像: 蓝色图像")
|
|
||||||
results = retrieval.search_by_image(query_image, k=2, filter_type="text")
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
|
||||||
|
|
||||||
# 保存索引
|
|
||||||
print("\n保存索引...")
|
|
||||||
retrieval.save_index()
|
|
||||||
|
|
||||||
print("\n跨模态检索测试完成!")
|
|
||||||
return retrieval
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主函数"""
|
|
||||||
print("=== 本地多模态检索系统测试 ===")
|
|
||||||
|
|
||||||
# 检查模型路径
|
|
||||||
model_path = "/root/models/Ops-MM-embedding-v1-7B"
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
print(f"错误: 模型路径不存在: {model_path}")
|
|
||||||
print("请先下载模型到指定路径")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 检查模型文件
|
|
||||||
config_file = os.path.join(model_path, "config.json")
|
|
||||||
if not os.path.exists(config_file):
|
|
||||||
print(f"错误: 模型配置文件不存在: {config_file}")
|
|
||||||
print("请确保模型文件已正确下载")
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"模型路径验证成功: {model_path}")
|
|
||||||
|
|
||||||
# 运行测试
|
|
||||||
try:
|
|
||||||
# 测试文本检索
|
|
||||||
test_text_retrieval()
|
|
||||||
|
|
||||||
# 测试图像检索
|
|
||||||
test_image_retrieval()
|
|
||||||
|
|
||||||
# 测试跨模态检索
|
|
||||||
test_cross_modal_retrieval()
|
|
||||||
|
|
||||||
print("\n所有测试完成!")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"测试过程中发生错误: {str(e)}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@ -1,354 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
测试优化后的系统功能
|
|
||||||
验证自动清理、内存处理和流式上传
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import tempfile
|
|
||||||
import logging
|
|
||||||
import subprocess
|
|
||||||
import signal
|
|
||||||
from io import BytesIO
|
|
||||||
from PIL import Image
|
|
||||||
import requests
|
|
||||||
import json
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def start_web_server():
|
|
||||||
"""启动Web服务器"""
|
|
||||||
try:
|
|
||||||
logger.info("🚀 启动Web服务器...")
|
|
||||||
|
|
||||||
# 启动Web应用进程
|
|
||||||
process = subprocess.Popen([
|
|
||||||
sys.executable, 'web_app_vdb_production.py'
|
|
||||||
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
|
||||||
|
|
||||||
logger.info(f"✅ Web服务器已启动,PID: {process.pid}")
|
|
||||||
|
|
||||||
# 等待服务器启动
|
|
||||||
for i in range(10):
|
|
||||||
try:
|
|
||||||
response = requests.get("http://127.0.0.1:5000/", timeout=2)
|
|
||||||
if response.status_code == 200:
|
|
||||||
logger.info("✅ Web服务器就绪")
|
|
||||||
return process
|
|
||||||
except:
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
logger.warning("⚠️ Web服务器启动超时")
|
|
||||||
return process
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 启动Web服务器失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def stop_web_server(process):
|
|
||||||
"""停止Web服务器"""
|
|
||||||
if process:
|
|
||||||
try:
|
|
||||||
process.terminate()
|
|
||||||
process.wait(timeout=5)
|
|
||||||
logger.info("✅ Web服务器已停止")
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
process.kill()
|
|
||||||
logger.info("🔥 强制停止Web服务器")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 停止Web服务器失败: {e}")
|
|
||||||
|
|
||||||
def test_optimized_file_handler():
|
|
||||||
"""测试优化的文件处理器"""
|
|
||||||
print("\n" + "="*60)
|
|
||||||
print("测试优化的文件处理器")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from optimized_file_handler import get_file_handler
|
|
||||||
|
|
||||||
# 获取文件处理器实例
|
|
||||||
file_handler = get_file_handler()
|
|
||||||
logger.info("✅ 文件处理器初始化成功")
|
|
||||||
|
|
||||||
# 创建测试图像
|
|
||||||
test_image = Image.new('RGB', (100, 100), color='red')
|
|
||||||
img_buffer = BytesIO()
|
|
||||||
test_image.save(img_buffer, format='PNG')
|
|
||||||
img_buffer.seek(0)
|
|
||||||
|
|
||||||
# 测试文件大小判断
|
|
||||||
file_size = file_handler.get_file_size(img_buffer)
|
|
||||||
is_small = file_handler.is_small_file(img_buffer)
|
|
||||||
logger.info(f"测试图像大小: {file_size} bytes, 小文件: {is_small}")
|
|
||||||
|
|
||||||
# 测试临时文件上下文管理器
|
|
||||||
with file_handler.temp_file_context(suffix='.png') as temp_path:
|
|
||||||
logger.info(f"创建临时文件: {temp_path}")
|
|
||||||
assert os.path.exists(temp_path)
|
|
||||||
|
|
||||||
# 验证临时文件已被清理
|
|
||||||
assert not os.path.exists(temp_path)
|
|
||||||
logger.info("✅ 临时文件自动清理成功")
|
|
||||||
|
|
||||||
# 测试清理所有临时文件
|
|
||||||
file_handler.cleanup_all_temp_files()
|
|
||||||
logger.info("✅ 批量清理临时文件成功")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 文件处理器测试失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_memory_processing():
|
|
||||||
"""测试内存处理功能"""
|
|
||||||
print("\n" + "="*60)
|
|
||||||
print("测试内存处理功能")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from optimized_file_handler import get_file_handler
|
|
||||||
|
|
||||||
file_handler = get_file_handler()
|
|
||||||
|
|
||||||
# 测试文本内存处理
|
|
||||||
test_texts = [
|
|
||||||
"这是一个测试文本",
|
|
||||||
"测试内存处理功能",
|
|
||||||
"优化的文件处理器"
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(f"开始内存处理 {len(test_texts)} 条文本...")
|
|
||||||
processed_texts = file_handler.process_text_in_memory(test_texts)
|
|
||||||
|
|
||||||
if processed_texts:
|
|
||||||
logger.info(f"✅ 内存处理文本成功: {len(processed_texts)} 条")
|
|
||||||
for i, text_info in enumerate(processed_texts):
|
|
||||||
logger.info(f" 文本 {i}: {text_info['bos_key']}")
|
|
||||||
else:
|
|
||||||
logger.warning("⚠️ 内存处理文本返回空结果")
|
|
||||||
|
|
||||||
return len(processed_texts) > 0
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 内存处理测试失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_web_api_optimized():
|
|
||||||
"""测试优化后的Web API"""
|
|
||||||
print("\n" + "="*60)
|
|
||||||
print("测试优化后的Web API")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
base_url = "http://127.0.0.1:5000"
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 测试系统初始化
|
|
||||||
logger.info("测试系统初始化...")
|
|
||||||
response = requests.post(f"{base_url}/api/init")
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
logger.info(f"✅ 系统初始化: {result.get('message')}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"⚠️ 系统可能已初始化: {response.status_code}")
|
|
||||||
|
|
||||||
# 测试文本上传(内存处理)
|
|
||||||
logger.info("测试文本上传(内存处理)...")
|
|
||||||
text_data = {
|
|
||||||
"texts": [
|
|
||||||
"优化后的文本处理测试",
|
|
||||||
"内存模式处理文本",
|
|
||||||
"自动清理临时文件"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
f"{base_url}/api/upload/texts",
|
|
||||||
json=text_data,
|
|
||||||
headers={'Content-Type': 'application/json'}
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
logger.info(f"✅ 文本上传成功: {result.get('message')}")
|
|
||||||
logger.info(f" 处理方法: {result.get('processing_method')}")
|
|
||||||
logger.info(f" 处理数量: {result.get('processed_texts')}")
|
|
||||||
else:
|
|
||||||
logger.error(f"❌ 文本上传失败: {response.status_code}")
|
|
||||||
logger.error(f" 错误信息: {response.text}")
|
|
||||||
|
|
||||||
# 测试图像上传(智能处理)
|
|
||||||
logger.info("测试图像上传(智能处理)...")
|
|
||||||
|
|
||||||
# 创建测试图像
|
|
||||||
test_image = Image.new('RGB', (200, 200), color='blue')
|
|
||||||
img_buffer = BytesIO()
|
|
||||||
test_image.save(img_buffer, format='PNG')
|
|
||||||
img_buffer.seek(0)
|
|
||||||
|
|
||||||
files = {'files': ('test_image.png', img_buffer, 'image/png')}
|
|
||||||
|
|
||||||
response = requests.post(f"{base_url}/api/upload/images", files=files)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
logger.info(f"✅ 图像上传成功: {result.get('message')}")
|
|
||||||
logger.info(f" 处理方法: {result.get('processing_methods')}")
|
|
||||||
logger.info(f" 处理数量: {result.get('processed_files')}")
|
|
||||||
else:
|
|
||||||
logger.error(f"❌ 图像上传失败: {response.status_code}")
|
|
||||||
logger.error(f" 错误信息: {response.text}")
|
|
||||||
|
|
||||||
# 测试文搜文
|
|
||||||
logger.info("测试文搜文...")
|
|
||||||
search_data = {
|
|
||||||
"query": "优化处理",
|
|
||||||
"top_k": 3
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
f"{base_url}/api/search/text_to_text",
|
|
||||||
json=search_data,
|
|
||||||
headers={'Content-Type': 'application/json'}
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
logger.info(f"✅ 文搜文成功: 找到 {result.get('count')} 个结果")
|
|
||||||
for i, res in enumerate(result.get('results', [])):
|
|
||||||
logger.info(f" 结果 {i}: 相似度 {res['score']:.3f}")
|
|
||||||
else:
|
|
||||||
logger.error(f"❌ 文搜文失败: {response.status_code}")
|
|
||||||
|
|
||||||
# 测试系统统计
|
|
||||||
logger.info("测试系统统计...")
|
|
||||||
response = requests.get(f"{base_url}/api/stats")
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
result = response.json()
|
|
||||||
stats = result.get('stats', {})
|
|
||||||
logger.info(f"✅ 系统统计:")
|
|
||||||
logger.info(f" 文本数量: {stats.get('text_count', 0)}")
|
|
||||||
logger.info(f" 图像数量: {stats.get('image_count', 0)}")
|
|
||||||
logger.info(f" 后端: {result.get('backend')}")
|
|
||||||
else:
|
|
||||||
logger.error(f"❌ 获取系统统计失败: {response.status_code}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except requests.exceptions.ConnectionError:
|
|
||||||
logger.error("❌ 无法连接到Web服务器,请确保服务器正在运行")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ Web API测试失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_temp_file_cleanup():
|
|
||||||
"""测试临时文件清理"""
|
|
||||||
print("\n" + "="*60)
|
|
||||||
print("测试临时文件清理")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from optimized_file_handler import get_file_handler
|
|
||||||
|
|
||||||
file_handler = get_file_handler()
|
|
||||||
|
|
||||||
# 创建多个临时文件
|
|
||||||
temp_paths = []
|
|
||||||
for i in range(3):
|
|
||||||
temp_path = file_handler.get_temp_file_for_model(
|
|
||||||
BytesIO(b"test content"), f"test_{i}.txt"
|
|
||||||
)
|
|
||||||
if temp_path:
|
|
||||||
temp_paths.append(temp_path)
|
|
||||||
logger.info(f"创建临时文件: {temp_path}")
|
|
||||||
|
|
||||||
# 验证文件存在
|
|
||||||
existing_count = sum(1 for path in temp_paths if os.path.exists(path))
|
|
||||||
logger.info(f"创建的临时文件数量: {existing_count}")
|
|
||||||
|
|
||||||
# 清理所有临时文件
|
|
||||||
file_handler.cleanup_all_temp_files()
|
|
||||||
|
|
||||||
# 验证文件已清理
|
|
||||||
remaining_count = sum(1 for path in temp_paths if os.path.exists(path))
|
|
||||||
logger.info(f"清理后剩余文件数量: {remaining_count}")
|
|
||||||
|
|
||||||
if remaining_count == 0:
|
|
||||||
logger.info("✅ 临时文件清理测试成功")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.warning(f"⚠️ 仍有 {remaining_count} 个文件未清理")
|
|
||||||
return False
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 临时文件清理测试失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主测试函数"""
|
|
||||||
print("🚀 开始测试优化后的系统功能")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
# 启动Web服务器
|
|
||||||
web_process = start_web_server()
|
|
||||||
|
|
||||||
try:
|
|
||||||
test_results = []
|
|
||||||
|
|
||||||
# 运行各项测试
|
|
||||||
tests = [
|
|
||||||
("文件处理器基础功能", test_optimized_file_handler),
|
|
||||||
("内存处理功能", test_memory_processing),
|
|
||||||
("临时文件清理", test_temp_file_cleanup),
|
|
||||||
("Web API功能", test_web_api_optimized),
|
|
||||||
]
|
|
||||||
|
|
||||||
for test_name, test_func in tests:
|
|
||||||
logger.info(f"\n🔍 开始测试: {test_name}")
|
|
||||||
try:
|
|
||||||
result = test_func()
|
|
||||||
test_results.append((test_name, result))
|
|
||||||
if result:
|
|
||||||
logger.info(f"✅ {test_name} - 通过")
|
|
||||||
else:
|
|
||||||
logger.warning(f"⚠️ {test_name} - 失败")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ {test_name} - 异常: {e}")
|
|
||||||
test_results.append((test_name, False))
|
|
||||||
|
|
||||||
# 输出测试总结
|
|
||||||
print("\n" + "="*60)
|
|
||||||
print("测试结果总结")
|
|
||||||
print("="*60)
|
|
||||||
|
|
||||||
passed = sum(1 for _, result in test_results if result)
|
|
||||||
total = len(test_results)
|
|
||||||
|
|
||||||
for test_name, result in test_results:
|
|
||||||
status = "✅ 通过" if result else "❌ 失败"
|
|
||||||
print(f"{test_name}: {status}")
|
|
||||||
|
|
||||||
print(f"\n总体结果: {passed}/{total} 项测试通过")
|
|
||||||
|
|
||||||
if passed == total:
|
|
||||||
print("🎉 所有测试通过!优化系统功能正常")
|
|
||||||
else:
|
|
||||||
print("⚠️ 部分测试失败,请检查相关功能")
|
|
||||||
|
|
||||||
return passed == total
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# 确保停止Web服务器
|
|
||||||
stop_web_server(web_process)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = main()
|
|
||||||
sys.exit(0 if success else 1)
|
|
||||||
@ -1,65 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_system():
|
|
||||||
logger.info("🚀 启动系统测试...")
|
|
||||||
|
|
||||||
# 创建目录
|
|
||||||
for dir_name in ["uploads", "sample_images", "text_data"]:
|
|
||||||
Path(dir_name).mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
# 测试基础模块
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import flask
|
|
||||||
logger.info("✅ 基础模块正常")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 基础模块失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 测试VDB连接
|
|
||||||
try:
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials("root", "vdb$yjr9ln3n0td"),
|
|
||||||
endpoint="http://180.76.96.191:5287"
|
|
||||||
)
|
|
||||||
client = pymochow.MochowClient(config)
|
|
||||||
databases = client.list_databases()
|
|
||||||
client.close()
|
|
||||||
|
|
||||||
logger.info(f"✅ VDB连接成功,{len(databases)}个数据库")
|
|
||||||
except ImportError:
|
|
||||||
logger.error("❌ 需要安装: pip install pymochow")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB连接失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 测试系统模块
|
|
||||||
try:
|
|
||||||
from multimodal_retrieval_vdb_only import MultimodalRetrievalVDBOnly
|
|
||||||
from baidu_vdb_production import BaiduVDBProduction
|
|
||||||
logger.info("✅ 系统模块正常")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 系统模块失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
logger.info("🎉 系统测试完成!")
|
|
||||||
logger.info("启动Web应用: python web_app_vdb_production.py")
|
|
||||||
return True
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_system()
|
|
||||||
@ -1,102 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
测试MongoDB和BOS存储集成
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
from mongodb_manager import get_mongodb_manager
|
|
||||||
from baidu_bos_manager import get_bos_manager
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_mongodb_connection():
|
|
||||||
"""测试MongoDB连接"""
|
|
||||||
try:
|
|
||||||
mongodb_mgr = get_mongodb_manager()
|
|
||||||
stats = mongodb_mgr.get_stats()
|
|
||||||
logger.info(f"✅ MongoDB连接成功,统计信息: {stats}")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ MongoDB连接失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_bos_connection():
|
|
||||||
"""测试BOS连接"""
|
|
||||||
try:
|
|
||||||
bos_mgr = get_bos_manager()
|
|
||||||
objects = bos_mgr.list_objects(max_keys=5)
|
|
||||||
logger.info(f"✅ BOS连接成功,找到 {len(objects)} 个对象")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ BOS连接失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_file_upload_workflow():
|
|
||||||
"""测试文件上传工作流"""
|
|
||||||
try:
|
|
||||||
# 创建测试文件
|
|
||||||
test_file = "/tmp/test_storage.txt"
|
|
||||||
with open(test_file, 'w', encoding='utf-8') as f:
|
|
||||||
f.write("这是一个测试文件,用于验证存储集成功能。")
|
|
||||||
|
|
||||||
# 获取管理器
|
|
||||||
mongodb_mgr = get_mongodb_manager()
|
|
||||||
bos_mgr = get_bos_manager()
|
|
||||||
|
|
||||||
# 上传到BOS
|
|
||||||
bos_result = bos_mgr.upload_file(test_file)
|
|
||||||
logger.info(f"✅ BOS上传成功: {bos_result['bos_key']}")
|
|
||||||
|
|
||||||
# 存储元数据到MongoDB
|
|
||||||
file_id = mongodb_mgr.store_file_metadata(
|
|
||||||
file_path=test_file,
|
|
||||||
file_type="text",
|
|
||||||
bos_key=bos_result["bos_key"],
|
|
||||||
additional_info={
|
|
||||||
"test": True,
|
|
||||||
"bos_url": bos_result["url"]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
logger.info(f"✅ MongoDB元数据存储成功: {file_id}")
|
|
||||||
|
|
||||||
# 存储向量元数据
|
|
||||||
mongodb_mgr.store_vector_metadata(
|
|
||||||
file_id=file_id,
|
|
||||||
vector_type="text_vector",
|
|
||||||
vdb_id=bos_result["bos_key"],
|
|
||||||
vector_info={"test": True}
|
|
||||||
)
|
|
||||||
logger.info("✅ 向量元数据存储成功")
|
|
||||||
|
|
||||||
# 清理测试文件
|
|
||||||
os.remove(test_file)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 文件上传工作流测试失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logger.info("🚀 开始存储集成测试...")
|
|
||||||
|
|
||||||
# 测试MongoDB连接
|
|
||||||
mongodb_ok = test_mongodb_connection()
|
|
||||||
|
|
||||||
# 测试BOS连接
|
|
||||||
bos_ok = test_bos_connection()
|
|
||||||
|
|
||||||
# 测试完整工作流
|
|
||||||
workflow_ok = test_file_upload_workflow()
|
|
||||||
|
|
||||||
# 总结
|
|
||||||
if mongodb_ok and bos_ok and workflow_ok:
|
|
||||||
logger.info("🎉 所有存储集成测试通过!")
|
|
||||||
else:
|
|
||||||
logger.error("❌ 存储集成测试失败")
|
|
||||||
logger.error(f"MongoDB: {'✅' if mongodb_ok else '❌'}")
|
|
||||||
logger.error(f"BOS: {'✅' if bos_ok else '❌'}")
|
|
||||||
logger.error(f"工作流: {'✅' if workflow_ok else '❌'}")
|
|
||||||
@ -1,80 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
系统启动测试
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主测试函数"""
|
|
||||||
logger.info("🚀 启动系统测试...")
|
|
||||||
|
|
||||||
# 创建必要目录
|
|
||||||
directories = ["uploads", "sample_images", "text_data"]
|
|
||||||
for dir_name in directories:
|
|
||||||
Path(dir_name).mkdir(exist_ok=True)
|
|
||||||
logger.info(f"✅ 创建目录: {dir_name}")
|
|
||||||
|
|
||||||
# 1. 测试基础模块导入
|
|
||||||
logger.info("\n📦 测试基础模块...")
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import flask
|
|
||||||
logger.info("✅ 基础模块导入成功")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 基础模块导入失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 2. 测试PyMochow
|
|
||||||
logger.info("\n🔗 测试百度VDB SDK...")
|
|
||||||
try:
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
|
|
||||||
# 测试连接
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials("root", "vdb$yjr9ln3n0td"),
|
|
||||||
endpoint="http://180.76.96.191:5287"
|
|
||||||
)
|
|
||||||
client = pymochow.MochowClient(config)
|
|
||||||
databases = client.list_databases()
|
|
||||||
client.close()
|
|
||||||
|
|
||||||
logger.info(f"✅ VDB连接成功,发现 {len(databases)} 个数据库")
|
|
||||||
except ImportError:
|
|
||||||
logger.error("❌ PyMochow未安装,请运行: pip install pymochow")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB连接失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 3. 测试系统模块
|
|
||||||
logger.info("\n🔧 测试系统模块...")
|
|
||||||
try:
|
|
||||||
from multimodal_retrieval_vdb_only import MultimodalRetrievalVDBOnly
|
|
||||||
from baidu_vdb_production import BaiduVDBProduction
|
|
||||||
logger.info("✅ 系统模块导入成功")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 系统模块导入失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
logger.info("\n🎉 系统测试完成!所有组件正常")
|
|
||||||
logger.info("可以启动Web应用: python web_app_vdb_production.py")
|
|
||||||
return True
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = main()
|
|
||||||
if not success:
|
|
||||||
sys.exit(1)
|
|
||||||
@ -1,220 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
VDB集成测试 - 使用现有的多模态系统测试VDB功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
def test_vdb_with_existing_system():
|
|
||||||
"""使用现有的多模态系统测试VDB集成"""
|
|
||||||
print("=" * 60)
|
|
||||||
print("VDB集成测试 - 使用现有多模态系统")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 导入现有的多模态系统
|
|
||||||
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
|
||||||
|
|
||||||
print("1. 初始化多模态检索系统...")
|
|
||||||
retrieval_system = MultiGPUMultimodalRetrieval()
|
|
||||||
|
|
||||||
if retrieval_system.model is None:
|
|
||||||
raise Exception("模型加载失败")
|
|
||||||
|
|
||||||
print("✅ 多模态系统初始化成功")
|
|
||||||
|
|
||||||
# 准备测试数据
|
|
||||||
test_texts = [
|
|
||||||
"一只可爱的小猫在阳光下睡觉",
|
|
||||||
"现代化的城市建筑群",
|
|
||||||
"美丽的自然风景和山脉",
|
|
||||||
"科技产品和电子设备",
|
|
||||||
"传统的中国文化艺术"
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"\n2. 测试文本编码功能...")
|
|
||||||
text_vectors = retrieval_system.encode_text_batch(test_texts)
|
|
||||||
print(f"✅ 文本向量生成成功: {text_vectors.shape}")
|
|
||||||
|
|
||||||
# 构建文本索引
|
|
||||||
print(f"\n3. 构建文本索引...")
|
|
||||||
retrieval_system.build_text_index_parallel(test_texts)
|
|
||||||
print("✅ 文本索引构建完成")
|
|
||||||
|
|
||||||
# 测试文搜文
|
|
||||||
print(f"\n4. 测试文搜文功能...")
|
|
||||||
query = "小动物"
|
|
||||||
results = retrieval_system.search_text_by_text(query, top_k=3)
|
|
||||||
print(f"查询: {query}")
|
|
||||||
for i, (text, score) in enumerate(results, 1):
|
|
||||||
print(f" {i}. {text} (相似度: {score:.4f})")
|
|
||||||
|
|
||||||
# 测试图像功能(如果有图像文件)
|
|
||||||
image_files = []
|
|
||||||
sample_dir = "sample_images"
|
|
||||||
if os.path.exists(sample_dir):
|
|
||||||
for ext in ['jpg', 'jpeg', 'png', 'gif']:
|
|
||||||
import glob
|
|
||||||
pattern = os.path.join(sample_dir, f"*.{ext}")
|
|
||||||
image_files.extend(glob.glob(pattern))
|
|
||||||
|
|
||||||
if image_files:
|
|
||||||
print(f"\n5. 测试图像编码功能...")
|
|
||||||
# 只测试前3张图像
|
|
||||||
test_images = image_files[:3]
|
|
||||||
image_vectors = retrieval_system.encode_image_batch(test_images)
|
|
||||||
print(f"✅ 图像向量生成成功: {image_vectors.shape}")
|
|
||||||
|
|
||||||
print(f"\n6. 构建图像索引...")
|
|
||||||
retrieval_system.build_image_index_parallel(test_images)
|
|
||||||
print("✅ 图像索引构建完成")
|
|
||||||
|
|
||||||
# 测试文搜图
|
|
||||||
print(f"\n7. 测试文搜图功能...")
|
|
||||||
query = "图片"
|
|
||||||
results = retrieval_system.search_images_by_text(query, top_k=2)
|
|
||||||
print(f"查询: {query}")
|
|
||||||
for i, (image_path, score) in enumerate(results, 1):
|
|
||||||
print(f" {i}. {os.path.basename(image_path)} (相似度: {score:.4f})")
|
|
||||||
else:
|
|
||||||
print(f"\n5. 跳过图像测试 - 未找到图像文件")
|
|
||||||
|
|
||||||
print(f"\n✅ 多模态系统功能测试完成!")
|
|
||||||
print("系统支持的检索模式:")
|
|
||||||
print("- 文搜文: ✅")
|
|
||||||
print("- 文搜图: ✅" if image_files else "- 文搜图: ⚠️ (需要图像数据)")
|
|
||||||
print("- 图搜文: ✅" if image_files else "- 图搜文: ⚠️ (需要图像数据)")
|
|
||||||
print("- 图搜图: ✅" if image_files else "- 图搜图: ⚠️ (需要图像数据)")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 测试失败: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
def test_vdb_connection_only():
|
|
||||||
"""仅测试VDB连接功能"""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("VDB连接测试")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import pymochow
|
|
||||||
from pymochow.configuration import Configuration
|
|
||||||
from pymochow.auth.bce_credentials import BceCredentials
|
|
||||||
|
|
||||||
# VDB配置
|
|
||||||
account = "root"
|
|
||||||
api_key = "vdb$yjr9ln3n0td"
|
|
||||||
endpoint = "http://180.76.96.191:5287"
|
|
||||||
|
|
||||||
print("1. 测试VDB连接...")
|
|
||||||
config = Configuration(
|
|
||||||
credentials=BceCredentials(account, api_key),
|
|
||||||
endpoint=endpoint
|
|
||||||
)
|
|
||||||
client = pymochow.MochowClient(config)
|
|
||||||
|
|
||||||
print("2. 查询数据库列表...")
|
|
||||||
db_list = client.list_databases()
|
|
||||||
print(f"✅ VDB连接成功,数据库数量: {len(db_list)}")
|
|
||||||
|
|
||||||
# 创建测试数据库
|
|
||||||
test_db_name = f"test_multimodal_{int(time.time())}"
|
|
||||||
print(f"3. 创建测试数据库: {test_db_name}")
|
|
||||||
db = client.create_database(test_db_name)
|
|
||||||
print("✅ 测试数据库创建成功")
|
|
||||||
|
|
||||||
# 清理测试数据库
|
|
||||||
print("4. 清理测试数据库...")
|
|
||||||
client.drop_database(test_db_name)
|
|
||||||
print("✅ 测试数据库清理完成")
|
|
||||||
|
|
||||||
client.close()
|
|
||||||
print("✅ VDB连接测试完成")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ VDB连接测试失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def create_sample_data():
|
|
||||||
"""创建示例数据用于测试"""
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("创建示例数据")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建示例文本数据
|
|
||||||
sample_texts = [
|
|
||||||
"一只橙色的小猫在花园里玩耍",
|
|
||||||
"现代化的摩天大楼群在夜晚闪闪发光",
|
|
||||||
"壮丽的山脉和清澈的湖水",
|
|
||||||
"最新的智能手机和平板电脑",
|
|
||||||
"传统的中国书法艺术作品",
|
|
||||||
"美味的中式料理和茶文化",
|
|
||||||
"春天的樱花盛开景象",
|
|
||||||
"海边的日落和波浪",
|
|
||||||
"森林中的小径和绿色植物",
|
|
||||||
"城市公园里的人们在锻炼"
|
|
||||||
]
|
|
||||||
|
|
||||||
# 保存文本数据
|
|
||||||
text_dir = "text_data"
|
|
||||||
os.makedirs(text_dir, exist_ok=True)
|
|
||||||
|
|
||||||
text_file = os.path.join(text_dir, "sample_texts.json")
|
|
||||||
with open(text_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(sample_texts, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
print(f"✅ 创建示例文本数据: {len(sample_texts)} 条")
|
|
||||||
print(f" 保存位置: {text_file}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 创建示例数据失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("🚀 开始VDB集成测试")
|
|
||||||
|
|
||||||
# 1. 创建示例数据
|
|
||||||
create_sample_data()
|
|
||||||
|
|
||||||
# 2. 测试VDB连接
|
|
||||||
vdb_ok = test_vdb_connection_only()
|
|
||||||
|
|
||||||
# 3. 测试多模态系统
|
|
||||||
if vdb_ok:
|
|
||||||
multimodal_ok = test_vdb_with_existing_system()
|
|
||||||
|
|
||||||
if multimodal_ok:
|
|
||||||
print(f"\n🎉 所有测试通过!")
|
|
||||||
print("✅ VDB连接正常")
|
|
||||||
print("✅ 多模态系统功能正常")
|
|
||||||
print("✅ 向量编码和检索功能正常")
|
|
||||||
print("\n📋 下一步建议:")
|
|
||||||
print("1. 上传更多图像和文本数据")
|
|
||||||
print("2. 启动Web应用进行交互式测试")
|
|
||||||
print("3. 测试跨模态检索功能")
|
|
||||||
else:
|
|
||||||
print(f"\n⚠️ 多模态系统测试失败")
|
|
||||||
else:
|
|
||||||
print(f"\n❌ VDB连接测试失败,请检查配置")
|
|
||||||
|
|
||||||
print(f"\n测试完成时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
||||||
1960
vdb使用说明.md
1960
vdb使用说明.md
File diff suppressed because it is too large
Load Diff
63
web_app.log
63
web_app.log
@ -1,63 +0,0 @@
|
|||||||
nohup: ignoring input
|
|
||||||
INFO:__main__:🚀 å<>¯åŠ¨æ—¶è‡ªåŠ¨åˆ<C3A5>始化VDB多模æ€<C3A6>检索系统...
|
|
||||||
INFO:multimodal_retrieval_vdb:检测到 2 个GPU
|
|
||||||
INFO:multimodal_retrieval_vdb:使用GPU: [0, 1], 主设备: cuda:0
|
|
||||||
INFO:multimodal_retrieval_vdb:GPU内å˜å·²æ¸…ç<E280A6>†
|
|
||||||
INFO:multimodal_retrieval_vdb:æ£åœ¨åŠ è½½æ¨¡åž‹åˆ°GPU: [0, 1]
|
|
||||||
INFO:multimodal_retrieval_vdb:GPU内å˜å·²æ¸…ç<E280A6>†
|
|
||||||
🚀 å<>¯åЍVDB多模æ€<C3A6>检索Web应用
|
|
||||||
============================================================
|
|
||||||
访问地å<EFBFBD>€: http://localhost:5000
|
|
||||||
新功能:
|
|
||||||
🗄ï¸<C3AF> 百度VDB - å<>‘é‡<C3A9>æ•°æ<C2B0>®åº“å˜å‚¨
|
|
||||||
📊 实时统计 - VDBæ•°æ<C2B0>®ç»Ÿè®¡ä¿¡æ<C2A1>¯
|
|
||||||
🔄 æ•°æ<C2B0>®å<C2AE>Œæ¥ - 本地文件到VDBå˜å‚¨
|
|
||||||
支æŒ<EFBFBD>功能:
|
|
||||||
ðŸ“<C5B8> æ–‡æ<E280A1>œæ–‡ - 文本查找相似文本
|
|
||||||
🖼ï¸<C3AF> æ–‡æ<E280A1>œå›¾ - 文本查找相关图片
|
|
||||||
ðŸ“<C5B8> 图æ<C2BE>œæ–‡ - 图片查找相关文本
|
|
||||||
🖼ï¸<C3AF> 图æ<C2BE>œå›¾ - 图片查找相似图片
|
|
||||||
📤 批é‡<C3A9>ä¸Šä¼ - 图片和文本数æ<C2B0>®ç®¡ç<C2A1>†
|
|
||||||
GPUé…<EFBFBD>ç½®:
|
|
||||||
🖥ï¸<C3AF> 检测到 2 个GPU
|
|
||||||
GPU 0: NVIDIA GeForce RTX 4090 (23.6GB)
|
|
||||||
GPU 1: NVIDIA GeForce RTX 4090 (23.6GB)
|
|
||||||
VDBé…<EFBFBD>ç½®:
|
|
||||||
ðŸŒ<C5B8> æœ<C3A6>务器: http://180.76.96.191:5287
|
|
||||||
👤 用户: root
|
|
||||||
🗄ï¸<C3AF> æ•°æ<C2B0>®åº“: multimodal_retrieval
|
|
||||||
============================================================
|
|
||||||
Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 25%|██▌ | 1/4 [03:18<09:56, 198.90s/it]
Loading checkpoint shards: 25%|██▌ | 1/4 [03:25<10:15, 205.19s/it]
|
|
||||||
WARNING:multimodal_retrieval_vdb:ç½‘ç»œåŠ è½½å¤±è´¥ï¼Œå°<C3A5>试本地缓å˜: CUDA out of memory. Tried to allocate 130.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 117.00 MiB is free. Process 3982470 has 384.00 MiB memory in use. Process 729183 has 2.64 GiB memory in use. Process 726298 has 7.43 GiB memory in use. Process 726164 WARNING:multimodal_retrieval_vdb:Tokenizerç½‘ç»œåŠ è½½å¤±è´¥ï¼Œå°<C3A5>试本地: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/OpenSearch-AI/Ops-MM-embedding-v1-7B/tree/main/additional_chat_templates?recursive=False&expand=False (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f92386b4280>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 103ac836-6599-4fe2-a569-aed9c945525c)')
|
|
||||||
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
|
|
||||||
WARNING:multimodal_retrieval_vdb:ProcessoråŠ è½½å¤±è´¥ï¼Œä½¿ç”¨tokenizer: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/OpenSearch-AI/Ops-MM-embedding-v1-7B/tree/main/additional_chat_templates?recursive=False&expand=False (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7fbad64d1510>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 96f18121-7beb-4e1a-87cd-c50edf682933)')
|
|
||||||
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
|
|
||||||
INFO:multimodal_retrieval_vdb:æ¨¡åž‹åŠ è½½å®Œæˆ<C3A6>
|
|
||||||
INFO:baidu_vdb_backend:✅ æˆ<C3A6>功连接到百度VDB: http://180.76.96.191:5287
|
|
||||||
INFO:baidu_vdb_backend:使用现有数æ<C2B0>®åº“: multimodal_retrieval
|
|
||||||
INFO:baidu_vdb_backend:创建文本å<C2AC>‘é‡<C3A9>表: text_vectors
|
|
||||||
ERROR:baidu_vdb_backend:â<>Œ 创建文本表失败: Database.create_table() missing 1 required positional argument: 'partition'
|
|
||||||
ERROR:baidu_vdb_backend:â<>Œ 表æ“<C3A6>作失败: Database.create_table() missing 1 required positional argument: 'partition'
|
|
||||||
ERROR:multimodal_retrieval_vdb:â<>Œ VDBå<42>Žç«¯åˆ<C3A5>始化失败: Database.create_table() missing 1 required positional argument: 'partition'
|
|
||||||
WARNING:multimodal_retrieval_vdb:âš ï¸<C3AF> ç³»ç»Ÿå°†åœ¨æ— VDB模å¼<C3A5>下è¿<C3A8>行,数æ<C2B0>®å°†ä¸<C3A4>会æŒ<C3A6>久化
|
|
||||||
INFO:multimodal_retrieval_vdb:多模æ€<C3A6>检索系统åˆ<C3A5>始化完æˆ<C3A6>
|
|
||||||
ERROR:__main__:â<>Œ VDB系统自动åˆ<C3A5>始化失败: VDB连接失败
|
|
||||||
ERROR:__main__:Traceback (most recent call last):
|
|
||||||
File "/root/mmeb/web_app_vdb.py", line 667, in auto_initialize
|
|
||||||
raise Exception("VDB连接失败")
|
|
||||||
Exception: VDB连接失败
|
|
||||||
|
|
||||||
* Serving Flask app 'web_app_vdb'
|
|
||||||
* Debug mode: off
|
|
||||||
Address already in use
|
|
||||||
Port 5000 is in use by another program. Either identify and stop that program, or start the server with a different port.
|
|
||||||
¤±è´¥
|
|
||||||
ERROR:__main__:Traceback (most recent call last):
|
|
||||||
File "/root/mmeb/web_app_vdb.py", line 664, in auto_initialize
|
|
||||||
raise Exception("æ¨¡åž‹åŠ è½½å¤±è´¥")
|
|
||||||
Exception: æ¨¡åž‹åŠ è½½å¤±è´¥
|
|
||||||
|
|
||||||
* Serving Flask app 'web_app_vdb'
|
|
||||||
* Debug mode: off
|
|
||||||
Address already in use
|
|
||||||
Port 5000 is in use by another program. Either identify and stop that program, or start the server with a different port.
|
|
||||||
122
web_app_local.py
122
web_app_local.py
@ -14,9 +14,10 @@ import json
|
|||||||
import base64
|
import base64
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from datetime import datetime, timezone
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from flask import Flask, request, jsonify, render_template, send_from_directory
|
from flask import Flask, request, jsonify, render_template, send_from_directory, send_file
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -49,6 +50,19 @@ os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
|||||||
if not os.path.exists(app.config['UPLOAD_FOLDER']):
|
if not os.path.exists(app.config['UPLOAD_FOLDER']):
|
||||||
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
||||||
|
|
||||||
|
# 确保静态资源目录与 favicon 存在(方案二:静态文件)
|
||||||
|
app.static_folder = os.path.join(os.path.dirname(__file__), 'static')
|
||||||
|
os.makedirs(app.static_folder, exist_ok=True)
|
||||||
|
favicon_path = os.path.join(app.static_folder, 'favicon.ico')
|
||||||
|
if not os.path.exists(favicon_path):
|
||||||
|
# 写入一个 1x1 透明 PNG 转 ICO 的简易占位图标(使用 PNG 作为内容也可被多数浏览器识别)
|
||||||
|
# 这里直接写入一个极小的 PNG 文件,并命名为 .ico 以简化处理
|
||||||
|
transparent_png_base64 = (
|
||||||
|
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8Xw8AAqMBhHqg7T8AAAAASUVORK5CYII="
|
||||||
|
)
|
||||||
|
with open(favicon_path, 'wb') as f:
|
||||||
|
f.write(base64.b64decode(transparent_png_base64))
|
||||||
|
|
||||||
# 创建文件处理器
|
# 创建文件处理器
|
||||||
from optimized_file_handler import OptimizedFileHandler
|
from optimized_file_handler import OptimizedFileHandler
|
||||||
file_handler = OptimizedFileHandler(local_storage_dir=app.config['UPLOAD_FOLDER'])
|
file_handler = OptimizedFileHandler(local_storage_dir=app.config['UPLOAD_FOLDER'])
|
||||||
@ -97,13 +111,25 @@ def index():
|
|||||||
"""首页"""
|
"""首页"""
|
||||||
return render_template('local_index.html')
|
return render_template('local_index.html')
|
||||||
|
|
||||||
|
@app.route('/favicon.ico')
|
||||||
|
def favicon():
|
||||||
|
"""提供静态 favicon(方案二)"""
|
||||||
|
return send_from_directory(app.static_folder, 'favicon.ico')
|
||||||
|
|
||||||
@app.route('/api/stats', methods=['GET'])
|
@app.route('/api/stats', methods=['GET'])
|
||||||
def get_stats():
|
def get_stats():
|
||||||
"""获取系统统计信息"""
|
"""获取系统统计信息"""
|
||||||
try:
|
try:
|
||||||
retrieval = init_retrieval_system()
|
retrieval = init_retrieval_system()
|
||||||
stats = retrieval.get_stats()
|
stats = retrieval.get_stats()
|
||||||
return jsonify({"success": True, "stats": stats})
|
debug_info = {
|
||||||
|
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"vector_dimension": stats.get("vector_dimension"),
|
||||||
|
"total_vectors": stats.get("total_vectors"),
|
||||||
|
"model_path": stats.get("model_path"),
|
||||||
|
"index_path": stats.get("index_path"),
|
||||||
|
}
|
||||||
|
return jsonify({"success": True, "stats": stats, "debug": debug_info})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取统计信息失败: {str(e)}")
|
logger.error(f"获取统计信息失败: {str(e)}")
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
@ -136,10 +162,16 @@ def add_text():
|
|||||||
# 保存索引
|
# 保存索引
|
||||||
retrieval.save_index()
|
retrieval.save_index()
|
||||||
|
|
||||||
|
stats = retrieval.get_stats()
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "文本添加成功",
|
"message": "文本添加成功",
|
||||||
"text_id": text_ids[0] if text_ids else None
|
"text_id": text_ids[0] if text_ids else None,
|
||||||
|
"debug": {
|
||||||
|
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"vector_dimension": stats.get("vector_dimension"),
|
||||||
|
"total_vectors": stats.get("total_vectors"),
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -211,10 +243,17 @@ def add_image():
|
|||||||
# 保存索引
|
# 保存索引
|
||||||
retrieval.save_index()
|
retrieval.save_index()
|
||||||
|
|
||||||
|
stats = retrieval.get_stats()
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "图像添加成功",
|
"message": "图像添加成功",
|
||||||
"image_id": image_ids[0] if image_ids else None
|
"image_id": image_ids[0] if image_ids else None,
|
||||||
|
"debug": {
|
||||||
|
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"vector_dimension": stats.get("vector_dimension"),
|
||||||
|
"total_vectors": stats.get("total_vectors"),
|
||||||
|
"image_size": [image.width, image.height] if 'image' in locals() else None,
|
||||||
|
}
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
return jsonify({"success": False, "error": "不支持的文件类型"}), 400
|
return jsonify({"success": False, "error": "不支持的文件类型"}), 400
|
||||||
@ -263,11 +302,17 @@ def search_by_text():
|
|||||||
|
|
||||||
processed_results.append(item)
|
processed_results.append(item)
|
||||||
|
|
||||||
|
stats = retrieval.get_stats()
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": True,
|
"success": True,
|
||||||
"results": processed_results,
|
"results": processed_results,
|
||||||
"query": query,
|
"query": query,
|
||||||
"filter_type": filter_type
|
"filter_type": filter_type,
|
||||||
|
"debug": {
|
||||||
|
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"vector_dimension": stats.get("vector_dimension"),
|
||||||
|
"total_vectors": stats.get("total_vectors"),
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -339,10 +384,16 @@ def search_by_image():
|
|||||||
|
|
||||||
processed_results.append(item)
|
processed_results.append(item)
|
||||||
|
|
||||||
|
stats = retrieval.get_stats()
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": True,
|
"success": True,
|
||||||
"results": processed_results,
|
"results": processed_results,
|
||||||
"filter_type": filter_type
|
"filter_type": filter_type,
|
||||||
|
"debug": {
|
||||||
|
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"vector_dimension": stats.get("vector_dimension"),
|
||||||
|
"total_vectors": stats.get("total_vectors"),
|
||||||
|
}
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
return jsonify({"success": False, "error": "不支持的文件类型"}), 400
|
return jsonify({"success": False, "error": "不支持的文件类型"}), 400
|
||||||
@ -374,9 +425,15 @@ def save_index():
|
|||||||
# 保存索引
|
# 保存索引
|
||||||
retrieval.save_index()
|
retrieval.save_index()
|
||||||
|
|
||||||
|
stats = retrieval.get_stats()
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "索引保存成功"
|
"message": "索引保存成功",
|
||||||
|
"debug": {
|
||||||
|
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"vector_dimension": stats.get("vector_dimension"),
|
||||||
|
"total_vectors": stats.get("total_vectors"),
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -393,9 +450,15 @@ def clear_index():
|
|||||||
# 清空索引
|
# 清空索引
|
||||||
retrieval.clear_index()
|
retrieval.clear_index()
|
||||||
|
|
||||||
|
stats = retrieval.get_stats()
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "索引已清空"
|
"message": "索引已清空",
|
||||||
|
"debug": {
|
||||||
|
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"vector_dimension": stats.get("vector_dimension"),
|
||||||
|
"total_vectors": stats.get("total_vectors"),
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -412,9 +475,15 @@ def list_items():
|
|||||||
# 获取所有项
|
# 获取所有项
|
||||||
items = retrieval.list_items()
|
items = retrieval.list_items()
|
||||||
|
|
||||||
|
stats = retrieval.get_stats()
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": True,
|
"success": True,
|
||||||
"items": items
|
"items": items,
|
||||||
|
"debug": {
|
||||||
|
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||||
|
"vector_dimension": stats.get("vector_dimension"),
|
||||||
|
"total_vectors": stats.get("total_vectors"),
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -447,13 +516,32 @@ def system_info():
|
|||||||
"gpu_info": gpu_info,
|
"gpu_info": gpu_info,
|
||||||
"retrieval_info": retrieval_info,
|
"retrieval_info": retrieval_info,
|
||||||
"model_path": app.config['MODEL_PATH'],
|
"model_path": app.config['MODEL_PATH'],
|
||||||
"index_path": app.config['INDEX_PATH']
|
"index_path": app.config['INDEX_PATH'],
|
||||||
|
"debug": {"server_time": datetime.now(timezone.utc).isoformat()}
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取系统信息失败: {str(e)}")
|
logger.error(f"获取系统信息失败: {str(e)}")
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
|
|
||||||
|
@app.route('/api/health', methods=['GET'])
|
||||||
|
def health():
|
||||||
|
"""健康检查:模型可用性、索引状态、GPU 信息"""
|
||||||
|
try:
|
||||||
|
retrieval = init_retrieval_system()
|
||||||
|
stats = retrieval.get_stats()
|
||||||
|
health_info = {
|
||||||
|
"model_loaded": True,
|
||||||
|
"vector_dimension": stats.get("vector_dimension"),
|
||||||
|
"total_vectors": stats.get("total_vectors"),
|
||||||
|
"gpu_available": torch.cuda.is_available(),
|
||||||
|
"cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
|
||||||
|
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
return jsonify({"success": True, "health": health_info})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"健康检查失败: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": str(e), "health": {"model_loaded": False, "server_time": datetime.now(timezone.utc).isoformat()}}), 500
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
try:
|
try:
|
||||||
# 预初始化检索系统
|
# 预初始化检索系统
|
||||||
|
|||||||
@ -1,677 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
多GPU多模态检索系统 - Web应用
|
|
||||||
专为双GPU部署优化
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from flask import Flask, render_template, request, jsonify, send_file, url_for
|
|
||||||
from werkzeug.utils import secure_filename
|
|
||||||
from PIL import Image
|
|
||||||
import base64
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
import traceback
|
|
||||||
import glob
|
|
||||||
|
|
||||||
# 设置环境变量优化GPU内存
|
|
||||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
app.config['SECRET_KEY'] = 'multigpu_multimodal_retrieval_2024'
|
|
||||||
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
|
||||||
|
|
||||||
# 配置上传文件夹
|
|
||||||
UPLOAD_FOLDER = 'uploads'
|
|
||||||
SAMPLE_IMAGES_FOLDER = 'sample_images'
|
|
||||||
TEXT_DATA_FOLDER = 'text_data'
|
|
||||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
|
|
||||||
|
|
||||||
# 确保文件夹存在
|
|
||||||
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
|
||||||
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
|
|
||||||
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)
|
|
||||||
|
|
||||||
# 全局检索系统实例
|
|
||||||
retrieval_system = None
|
|
||||||
|
|
||||||
def allowed_file(filename):
|
|
||||||
"""检查文件扩展名是否允许"""
|
|
||||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
|
||||||
|
|
||||||
def image_to_base64(image_path):
|
|
||||||
"""将图片转换为base64编码"""
|
|
||||||
try:
|
|
||||||
with open(image_path, "rb") as img_file:
|
|
||||||
return base64.b64encode(img_file.read()).decode('utf-8')
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图片转换失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@app.route('/')
|
|
||||||
def index():
|
|
||||||
"""主页"""
|
|
||||||
return render_template('index.html')
|
|
||||||
|
|
||||||
@app.route('/api/status')
|
|
||||||
def get_status():
|
|
||||||
"""获取系统状态"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
status = {
|
|
||||||
'initialized': retrieval_system is not None,
|
|
||||||
'gpu_count': 0,
|
|
||||||
'model_loaded': False
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
status['gpu_count'] = torch.cuda.device_count()
|
|
||||||
|
|
||||||
if retrieval_system and retrieval_system.model:
|
|
||||||
status['model_loaded'] = True
|
|
||||||
status['device_ids'] = retrieval_system.device_ids
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取状态失败: {e}")
|
|
||||||
|
|
||||||
return jsonify(status)
|
|
||||||
|
|
||||||
@app.route('/api/init', methods=['POST'])
|
|
||||||
def initialize_system():
|
|
||||||
"""初始化多GPU检索系统"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info("正在初始化多GPU检索系统...")
|
|
||||||
|
|
||||||
# 导入多GPU检索系统
|
|
||||||
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
|
||||||
|
|
||||||
# 初始化系统
|
|
||||||
retrieval_system = MultiGPUMultimodalRetrieval()
|
|
||||||
|
|
||||||
if retrieval_system.model is None:
|
|
||||||
raise Exception("模型加载失败")
|
|
||||||
|
|
||||||
logger.info("✅ 多GPU系统初始化成功")
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': '多GPU系统初始化成功',
|
|
||||||
'device_ids': retrieval_system.device_ids,
|
|
||||||
'gpu_count': len(retrieval_system.device_ids)
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"系统初始化失败: {str(e)}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': error_msg
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/search/text_to_text', methods=['POST'])
|
|
||||||
def search_text_to_text():
|
|
||||||
"""文本搜索文本"""
|
|
||||||
return handle_search('text_to_text')
|
|
||||||
|
|
||||||
@app.route('/api/search/text_to_image', methods=['POST'])
|
|
||||||
def search_text_to_image():
|
|
||||||
"""文本搜索图片"""
|
|
||||||
return handle_search('text_to_image')
|
|
||||||
|
|
||||||
@app.route('/api/search/image_to_text', methods=['POST'])
|
|
||||||
def search_image_to_text():
|
|
||||||
"""图片搜索文本"""
|
|
||||||
return handle_search('image_to_text')
|
|
||||||
|
|
||||||
@app.route('/api/search/image_to_image', methods=['POST'])
|
|
||||||
def search_image_to_image():
|
|
||||||
"""图片搜索图片"""
|
|
||||||
return handle_search('image_to_image')
|
|
||||||
|
|
||||||
@app.route('/api/search', methods=['POST'])
|
|
||||||
def search():
|
|
||||||
"""通用搜索接口(兼容旧版本)"""
|
|
||||||
mode = request.form.get('mode') or request.json.get('mode', 'text_to_text')
|
|
||||||
return handle_search(mode)
|
|
||||||
|
|
||||||
def handle_search(mode):
|
|
||||||
"""处理搜索请求的通用函数"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
if not retrieval_system:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化,请先点击初始化按钮'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
try:
|
|
||||||
top_k = int(request.form.get('top_k', 5))
|
|
||||||
|
|
||||||
if mode in ['text_to_text', 'text_to_image']:
|
|
||||||
# 文本查询
|
|
||||||
query = request.form.get('query') or request.json.get('query', '')
|
|
||||||
if not query.strip():
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '请输入查询文本'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
logger.info(f"执行{mode}搜索: {query}")
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
if mode == 'text_to_text':
|
|
||||||
raw_results = retrieval_system.search_text_to_text(query, top_k=top_k)
|
|
||||||
# 格式化文本搜索结果
|
|
||||||
results = []
|
|
||||||
for text, score in raw_results:
|
|
||||||
results.append({
|
|
||||||
'text': text,
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
else: # text_to_image
|
|
||||||
raw_results = retrieval_system.search_text_to_image(query, top_k=top_k)
|
|
||||||
# 格式化图像搜索结果
|
|
||||||
results = []
|
|
||||||
for image_path, score in raw_results:
|
|
||||||
try:
|
|
||||||
# 读取图像并转换为base64
|
|
||||||
with open(image_path, 'rb') as img_file:
|
|
||||||
image_data = img_file.read()
|
|
||||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
'filename': os.path.basename(image_path),
|
|
||||||
'image_path': image_path,
|
|
||||||
'image_base64': image_base64,
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"读取图像失败 {image_path}: {e}")
|
|
||||||
results.append({
|
|
||||||
'filename': os.path.basename(image_path),
|
|
||||||
'image_path': image_path,
|
|
||||||
'image_base64': '',
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'mode': mode,
|
|
||||||
'query': query,
|
|
||||||
'results': results,
|
|
||||||
'result_count': len(results)
|
|
||||||
})
|
|
||||||
|
|
||||||
elif mode in ['image_to_text', 'image_to_image']:
|
|
||||||
# 图片查询
|
|
||||||
if 'image' not in request.files:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '请上传查询图片'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
file = request.files['image']
|
|
||||||
if file.filename == '' or not allowed_file(file.filename):
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '请上传有效的图片文件'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 保存上传的图片
|
|
||||||
filename = secure_filename(file.filename)
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
filename = f"query_{timestamp}_{filename}"
|
|
||||||
filepath = os.path.join(UPLOAD_FOLDER, filename)
|
|
||||||
file.save(filepath)
|
|
||||||
|
|
||||||
logger.info(f"执行{mode}搜索,图片: {filename}")
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
if mode == 'image_to_text':
|
|
||||||
raw_results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
|
|
||||||
# 格式化文本搜索结果
|
|
||||||
results = []
|
|
||||||
for text, score in raw_results:
|
|
||||||
results.append({
|
|
||||||
'text': text,
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
else: # image_to_image
|
|
||||||
raw_results = retrieval_system.search_image_to_image(filepath, top_k=top_k)
|
|
||||||
# 格式化图像搜索结果
|
|
||||||
results = []
|
|
||||||
for image_path, score in raw_results:
|
|
||||||
try:
|
|
||||||
# 读取图像并转换为base64
|
|
||||||
with open(image_path, 'rb') as img_file:
|
|
||||||
image_data = img_file.read()
|
|
||||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
'filename': os.path.basename(image_path),
|
|
||||||
'image_path': image_path,
|
|
||||||
'image_base64': image_base64,
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"读取图像失败 {image_path}: {e}")
|
|
||||||
results.append({
|
|
||||||
'filename': os.path.basename(image_path),
|
|
||||||
'image_path': image_path,
|
|
||||||
'image_base64': '',
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
|
|
||||||
# 转换查询图片为base64
|
|
||||||
query_image_b64 = image_to_base64(filepath)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'mode': mode,
|
|
||||||
'query_image': query_image_b64,
|
|
||||||
'results': results,
|
|
||||||
'result_count': len(results)
|
|
||||||
})
|
|
||||||
|
|
||||||
else:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'不支持的搜索模式: {mode}'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"搜索失败: {str(e)}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': error_msg
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/upload/images', methods=['POST'])
|
|
||||||
def upload_images():
|
|
||||||
"""批量上传图片"""
|
|
||||||
try:
|
|
||||||
uploaded_files = []
|
|
||||||
|
|
||||||
if 'images' not in request.files:
|
|
||||||
return jsonify({'success': False, 'message': '没有选择文件'}), 400
|
|
||||||
|
|
||||||
files = request.files.getlist('images')
|
|
||||||
|
|
||||||
for file in files:
|
|
||||||
if file and file.filename != '' and allowed_file(file.filename):
|
|
||||||
filename = secure_filename(file.filename)
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
filename = f"{timestamp}_{filename}"
|
|
||||||
filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
|
|
||||||
file.save(filepath)
|
|
||||||
uploaded_files.append(filename)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': f'成功上传 {len(uploaded_files)} 个图片文件',
|
|
||||||
'uploaded_count': len(uploaded_files),
|
|
||||||
'files': uploaded_files
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'上传失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/upload/texts', methods=['POST'])
|
|
||||||
def upload_texts():
|
|
||||||
"""批量上传文本数据"""
|
|
||||||
try:
|
|
||||||
data = request.get_json()
|
|
||||||
|
|
||||||
if not data or 'texts' not in data:
|
|
||||||
return jsonify({'success': False, 'message': '没有提供文本数据'}), 400
|
|
||||||
|
|
||||||
texts = data['texts']
|
|
||||||
if not isinstance(texts, list):
|
|
||||||
return jsonify({'success': False, 'message': '文本数据格式错误'}), 400
|
|
||||||
|
|
||||||
# 保存文本数据到文件
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
filename = f"texts_{timestamp}.json"
|
|
||||||
filepath = os.path.join(TEXT_DATA_FOLDER, filename)
|
|
||||||
|
|
||||||
with open(filepath, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(texts, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': f'成功上传 {len(texts)} 条文本',
|
|
||||||
'uploaded_count': len(texts)
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'上传失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/upload/file', methods=['POST'])
|
|
||||||
def upload_single_file():
|
|
||||||
"""上传单个文件"""
|
|
||||||
if 'file' not in request.files:
|
|
||||||
return jsonify({'success': False, 'message': '没有选择文件'}), 400
|
|
||||||
|
|
||||||
file = request.files['file']
|
|
||||||
if file.filename == '':
|
|
||||||
return jsonify({'success': False, 'message': '没有选择文件'}), 400
|
|
||||||
|
|
||||||
if file and allowed_file(file.filename):
|
|
||||||
filename = secure_filename(file.filename)
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
filename = f"{timestamp}_{filename}"
|
|
||||||
filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
|
|
||||||
file.save(filepath)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': '文件上传成功',
|
|
||||||
'filename': filename
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({'success': False, 'message': '不支持的文件类型'}), 400
|
|
||||||
|
|
||||||
@app.route('/api/data/list', methods=['GET'])
|
|
||||||
def list_data():
|
|
||||||
"""列出已上传的数据"""
|
|
||||||
try:
|
|
||||||
# 列出图片文件
|
|
||||||
images = []
|
|
||||||
if os.path.exists(SAMPLE_IMAGES_FOLDER):
|
|
||||||
for filename in os.listdir(SAMPLE_IMAGES_FOLDER):
|
|
||||||
if allowed_file(filename):
|
|
||||||
filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
|
|
||||||
stat = os.stat(filepath)
|
|
||||||
images.append({
|
|
||||||
'filename': filename,
|
|
||||||
'size': stat.st_size,
|
|
||||||
'modified': stat.st_mtime
|
|
||||||
})
|
|
||||||
|
|
||||||
# 列出文本文件
|
|
||||||
texts = []
|
|
||||||
if os.path.exists(TEXT_DATA_FOLDER):
|
|
||||||
for filename in os.listdir(TEXT_DATA_FOLDER):
|
|
||||||
if filename.endswith('.json'):
|
|
||||||
filepath = os.path.join(TEXT_DATA_FOLDER, filename)
|
|
||||||
stat = os.stat(filepath)
|
|
||||||
texts.append({
|
|
||||||
'filename': filename,
|
|
||||||
'size': stat.st_size,
|
|
||||||
'modified': stat.st_mtime
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'data': {
|
|
||||||
'images': images,
|
|
||||||
'texts': texts
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'获取数据列表失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/gpu_status')
|
|
||||||
def gpu_status():
|
|
||||||
"""获取GPU状态"""
|
|
||||||
try:
|
|
||||||
from smart_gpu_launcher import get_gpu_memory_info
|
|
||||||
gpu_info = get_gpu_memory_info()
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'gpu_info': gpu_info
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f"获取GPU状态失败: {str(e)}"
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/build_index', methods=['POST'])
|
|
||||||
def build_index():
|
|
||||||
"""构建检索索引"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
if not retrieval_system:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取所有图片和文本文件
|
|
||||||
image_files = []
|
|
||||||
text_data = []
|
|
||||||
|
|
||||||
# 扫描图片文件
|
|
||||||
for ext in ALLOWED_EXTENSIONS:
|
|
||||||
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
|
||||||
image_files.extend(glob.glob(pattern))
|
|
||||||
|
|
||||||
# 读取文本文件(支持.json和.txt格式)
|
|
||||||
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
|
|
||||||
text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
|
|
||||||
|
|
||||||
for text_file in text_files:
|
|
||||||
try:
|
|
||||||
if text_file.endswith('.json'):
|
|
||||||
# 读取JSON格式的文本数据
|
|
||||||
with open(text_file, 'r', encoding='utf-8') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
if isinstance(data, list):
|
|
||||||
text_data.extend([str(item).strip() for item in data if str(item).strip()])
|
|
||||||
else:
|
|
||||||
text_data.append(str(data).strip())
|
|
||||||
else:
|
|
||||||
# 读取TXT格式的文本数据
|
|
||||||
with open(text_file, 'r', encoding='utf-8') as f:
|
|
||||||
lines = [line.strip() for line in f.readlines() if line.strip()]
|
|
||||||
text_data.extend(lines)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"读取文本文件失败 {text_file}: {e}")
|
|
||||||
|
|
||||||
# 检查是否有数据可以构建索引
|
|
||||||
if not image_files and not text_data:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '没有找到可用的图片或文本数据,请先上传数据'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 构建索引
|
|
||||||
if image_files:
|
|
||||||
logger.info(f"构建图片索引,共 {len(image_files)} 张图片")
|
|
||||||
retrieval_system.build_image_index_parallel(image_files)
|
|
||||||
|
|
||||||
if text_data:
|
|
||||||
logger.info(f"构建文本索引,共 {len(text_data)} 条文本")
|
|
||||||
retrieval_system.build_text_index_parallel(text_data)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': f'索引构建完成!图片: {len(image_files)} 张,文本: {len(text_data)} 条',
|
|
||||||
'image_count': len(image_files),
|
|
||||||
'text_count': len(text_data)
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"构建索引失败: {str(e)}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'构建索引失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/data/stats', methods=['GET'])
|
|
||||||
def get_data_stats():
|
|
||||||
"""获取数据统计信息"""
|
|
||||||
try:
|
|
||||||
# 统计图片文件
|
|
||||||
image_count = 0
|
|
||||||
for ext in ALLOWED_EXTENSIONS:
|
|
||||||
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
|
||||||
image_count += len(glob.glob(pattern))
|
|
||||||
|
|
||||||
# 统计文本数据
|
|
||||||
text_count = 0
|
|
||||||
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
|
|
||||||
for text_file in text_files:
|
|
||||||
try:
|
|
||||||
with open(text_file, 'r', encoding='utf-8') as f:
|
|
||||||
lines = [line.strip() for line in f.readlines() if line.strip()]
|
|
||||||
text_count += len(lines)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'image_count': image_count,
|
|
||||||
'text_count': text_count
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取数据统计失败: {str(e)}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'获取统计失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/data/clear', methods=['POST'])
|
|
||||||
def clear_data():
|
|
||||||
"""清空所有数据"""
|
|
||||||
try:
|
|
||||||
# 清空图片文件
|
|
||||||
for ext in ALLOWED_EXTENSIONS:
|
|
||||||
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
|
||||||
for file_path in glob.glob(pattern):
|
|
||||||
try:
|
|
||||||
os.remove(file_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"删除图片文件失败 {file_path}: {e}")
|
|
||||||
|
|
||||||
# 清空文本文件
|
|
||||||
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
|
|
||||||
for text_file in text_files:
|
|
||||||
try:
|
|
||||||
os.remove(text_file)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"删除文本文件失败 {text_file}: {e}")
|
|
||||||
|
|
||||||
# 重置索引
|
|
||||||
global retrieval_system
|
|
||||||
if retrieval_system:
|
|
||||||
retrieval_system.text_index = None
|
|
||||||
retrieval_system.image_index = None
|
|
||||||
retrieval_system.text_data = []
|
|
||||||
retrieval_system.image_data = []
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': '数据已清空'
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清空数据失败: {str(e)}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'清空数据失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/uploads/<filename>')
|
|
||||||
def uploaded_file(filename):
|
|
||||||
"""提供上传文件的访问"""
|
|
||||||
return send_file(os.path.join(SAMPLE_IMAGES_FOLDER, filename))
|
|
||||||
|
|
||||||
def print_startup_info():
|
|
||||||
"""打印启动信息"""
|
|
||||||
print("🚀 启动多GPU多模态检索Web应用")
|
|
||||||
print("=" * 60)
|
|
||||||
print("访问地址: http://localhost:5000")
|
|
||||||
print("支持功能:")
|
|
||||||
print(" 📝 文搜文 - 文本查找相似文本")
|
|
||||||
print(" 🖼️ 文搜图 - 文本查找相关图片")
|
|
||||||
print(" 📝 图搜文 - 图片查找相关文本")
|
|
||||||
print(" 🖼️ 图搜图 - 图片查找相似图片")
|
|
||||||
print(" 📤 批量上传 - 图片和文本数据管理")
|
|
||||||
print("GPU配置:")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
gpu_count = torch.cuda.device_count()
|
|
||||||
print(f" 🖥️ 检测到 {gpu_count} 个GPU")
|
|
||||||
for i in range(gpu_count):
|
|
||||||
name = torch.cuda.get_device_name(i)
|
|
||||||
props = torch.cuda.get_device_properties(i)
|
|
||||||
memory_gb = props.total_memory / 1024**3
|
|
||||||
print(f" GPU {i}: {name} ({memory_gb:.1f}GB)")
|
|
||||||
else:
|
|
||||||
print(" ❌ CUDA不可用")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ❌ GPU检查失败: {e}")
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
def auto_initialize():
|
|
||||||
"""启动时自动初始化系统"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info("🚀 启动时自动初始化多GPU检索系统...")
|
|
||||||
|
|
||||||
# 导入多GPU检索系统
|
|
||||||
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
|
||||||
|
|
||||||
# 初始化系统
|
|
||||||
retrieval_system = MultiGPUMultimodalRetrieval()
|
|
||||||
|
|
||||||
if retrieval_system.model is None:
|
|
||||||
raise Exception("模型加载失败")
|
|
||||||
|
|
||||||
logger.info("✅ 系统自动初始化成功")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 系统自动初始化失败: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print_startup_info()
|
|
||||||
|
|
||||||
# 启动时自动初始化
|
|
||||||
auto_initialize()
|
|
||||||
|
|
||||||
# 启动Flask应用
|
|
||||||
app.run(
|
|
||||||
host='0.0.0.0',
|
|
||||||
port=5000,
|
|
||||||
debug=False,
|
|
||||||
threaded=True
|
|
||||||
)
|
|
||||||
689
web_app_vdb.py
689
web_app_vdb.py
@ -1,689 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
集成百度VDB的多模态检索Web应用
|
|
||||||
支持向量存储和多种检索方式
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from flask import Flask, render_template, request, jsonify, send_file
|
|
||||||
from werkzeug.utils import secure_filename
|
|
||||||
from PIL import Image
|
|
||||||
import base64
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
import traceback
|
|
||||||
import glob
|
|
||||||
|
|
||||||
# 设置环境变量优化GPU内存
|
|
||||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
app.config['SECRET_KEY'] = 'vdb_multimodal_retrieval_2024'
|
|
||||||
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
|
||||||
|
|
||||||
# 配置上传文件夹
|
|
||||||
UPLOAD_FOLDER = 'uploads'
|
|
||||||
SAMPLE_IMAGES_FOLDER = 'sample_images'
|
|
||||||
TEXT_DATA_FOLDER = 'text_data'
|
|
||||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
|
|
||||||
|
|
||||||
# 确保文件夹存在
|
|
||||||
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
|
||||||
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
|
|
||||||
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)
|
|
||||||
|
|
||||||
# 全局检索系统实例
|
|
||||||
retrieval_system = None
|
|
||||||
|
|
||||||
def allowed_file(filename):
|
|
||||||
"""检查文件扩展名是否允许"""
|
|
||||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
|
||||||
|
|
||||||
def image_to_base64(image_path):
|
|
||||||
"""将图片转换为base64编码"""
|
|
||||||
try:
|
|
||||||
with open(image_path, "rb") as img_file:
|
|
||||||
return base64.b64encode(img_file.read()).decode('utf-8')
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图片转换失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
@app.route('/')
|
|
||||||
def index():
|
|
||||||
"""主页"""
|
|
||||||
return render_template('index.html')
|
|
||||||
|
|
||||||
@app.route('/api/status')
|
|
||||||
def get_status():
|
|
||||||
"""获取系统状态"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
status = {
|
|
||||||
'initialized': retrieval_system is not None,
|
|
||||||
'gpu_count': 0,
|
|
||||||
'model_loaded': False,
|
|
||||||
'vdb_connected': False
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
status['gpu_count'] = torch.cuda.device_count()
|
|
||||||
|
|
||||||
if retrieval_system:
|
|
||||||
status['model_loaded'] = retrieval_system.model is not None
|
|
||||||
status['vdb_connected'] = retrieval_system.vdb is not None
|
|
||||||
status['device_ids'] = retrieval_system.device_ids
|
|
||||||
|
|
||||||
# 获取VDB统计信息
|
|
||||||
if retrieval_system.vdb:
|
|
||||||
stats = retrieval_system.get_statistics()
|
|
||||||
status['vdb_stats'] = stats
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取状态失败: {e}")
|
|
||||||
|
|
||||||
return jsonify(status)
|
|
||||||
|
|
||||||
@app.route('/api/init', methods=['POST'])
|
|
||||||
def initialize_system():
|
|
||||||
"""初始化VDB多模态检索系统"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info("正在初始化VDB多模态检索系统...")
|
|
||||||
|
|
||||||
# 导入VDB检索系统
|
|
||||||
from multimodal_retrieval_vdb import MultimodalRetrievalVDB
|
|
||||||
|
|
||||||
# 初始化系统
|
|
||||||
retrieval_system = MultimodalRetrievalVDB()
|
|
||||||
|
|
||||||
if retrieval_system.model is None:
|
|
||||||
raise Exception("模型加载失败")
|
|
||||||
|
|
||||||
if retrieval_system.vdb is None:
|
|
||||||
raise Exception("VDB连接失败")
|
|
||||||
|
|
||||||
logger.info("✅ VDB多模态系统初始化成功")
|
|
||||||
|
|
||||||
# 获取统计信息
|
|
||||||
stats = retrieval_system.get_statistics()
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': 'VDB多模态系统初始化成功',
|
|
||||||
'device_ids': retrieval_system.device_ids,
|
|
||||||
'gpu_count': len(retrieval_system.device_ids),
|
|
||||||
'vdb_stats': stats
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"系统初始化失败: {str(e)}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': error_msg
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/search/text_to_text', methods=['POST'])
|
|
||||||
def search_text_to_text():
|
|
||||||
"""文本搜索文本"""
|
|
||||||
return handle_search('text_to_text')
|
|
||||||
|
|
||||||
@app.route('/api/search/text_to_image', methods=['POST'])
|
|
||||||
def search_text_to_image():
|
|
||||||
"""文本搜索图片"""
|
|
||||||
return handle_search('text_to_image')
|
|
||||||
|
|
||||||
@app.route('/api/search/image_to_text', methods=['POST'])
|
|
||||||
def search_image_to_text():
|
|
||||||
"""图片搜索文本"""
|
|
||||||
return handle_search('image_to_text')
|
|
||||||
|
|
||||||
@app.route('/api/search/image_to_image', methods=['POST'])
|
|
||||||
def search_image_to_image():
|
|
||||||
"""图片搜索图片"""
|
|
||||||
return handle_search('image_to_image')
|
|
||||||
|
|
||||||
@app.route('/api/search', methods=['POST'])
|
|
||||||
def search():
|
|
||||||
"""通用搜索接口(兼容旧版本)"""
|
|
||||||
mode = request.form.get('mode') or request.json.get('mode', 'text_to_text')
|
|
||||||
return handle_search(mode)
|
|
||||||
|
|
||||||
def handle_search(mode):
|
|
||||||
"""处理搜索请求的通用函数"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
if not retrieval_system:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化,请先点击初始化按钮'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
try:
|
|
||||||
top_k = int(request.form.get('top_k', 5))
|
|
||||||
|
|
||||||
if mode in ['text_to_text', 'text_to_image']:
|
|
||||||
# 文本查询
|
|
||||||
query = request.form.get('query') or request.json.get('query', '')
|
|
||||||
if not query.strip():
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '请输入查询文本'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
logger.info(f"执行{mode}搜索: {query}")
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
if mode == 'text_to_text':
|
|
||||||
raw_results = retrieval_system.search_text_to_text(query, top_k=top_k)
|
|
||||||
# 格式化文本搜索结果
|
|
||||||
results = []
|
|
||||||
for text, score in raw_results:
|
|
||||||
results.append({
|
|
||||||
'text': text,
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
else: # text_to_image
|
|
||||||
raw_results = retrieval_system.search_text_to_image(query, top_k=top_k)
|
|
||||||
# 格式化图像搜索结果
|
|
||||||
results = []
|
|
||||||
for image_path, score in raw_results:
|
|
||||||
try:
|
|
||||||
# 读取图像并转换为base64
|
|
||||||
with open(image_path, 'rb') as img_file:
|
|
||||||
image_data = img_file.read()
|
|
||||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
'filename': os.path.basename(image_path),
|
|
||||||
'image_path': image_path,
|
|
||||||
'image_base64': image_base64,
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"读取图像失败 {image_path}: {e}")
|
|
||||||
results.append({
|
|
||||||
'filename': os.path.basename(image_path),
|
|
||||||
'image_path': image_path,
|
|
||||||
'image_base64': '',
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'mode': mode,
|
|
||||||
'query': query,
|
|
||||||
'results': results,
|
|
||||||
'result_count': len(results)
|
|
||||||
})
|
|
||||||
|
|
||||||
elif mode in ['image_to_text', 'image_to_image']:
|
|
||||||
# 图片查询
|
|
||||||
if 'image' not in request.files:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '请上传查询图片'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
file = request.files['image']
|
|
||||||
if file.filename == '' or not allowed_file(file.filename):
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '请上传有效的图片文件'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 保存上传的图片
|
|
||||||
filename = secure_filename(file.filename)
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
filename = f"query_{timestamp}_{filename}"
|
|
||||||
filepath = os.path.join(UPLOAD_FOLDER, filename)
|
|
||||||
file.save(filepath)
|
|
||||||
|
|
||||||
logger.info(f"执行{mode}搜索,图片: {filename}")
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
if mode == 'image_to_text':
|
|
||||||
raw_results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
|
|
||||||
# 格式化文本搜索结果
|
|
||||||
results = []
|
|
||||||
for text, score in raw_results:
|
|
||||||
results.append({
|
|
||||||
'text': text,
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
else: # image_to_image
|
|
||||||
raw_results = retrieval_system.search_image_to_image(filepath, top_k=top_k)
|
|
||||||
# 格式化图像搜索结果
|
|
||||||
results = []
|
|
||||||
for image_path, score in raw_results:
|
|
||||||
try:
|
|
||||||
# 读取图像并转换为base64
|
|
||||||
with open(image_path, 'rb') as img_file:
|
|
||||||
image_data = img_file.read()
|
|
||||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
'filename': os.path.basename(image_path),
|
|
||||||
'image_path': image_path,
|
|
||||||
'image_base64': image_base64,
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"读取图像失败 {image_path}: {e}")
|
|
||||||
results.append({
|
|
||||||
'filename': os.path.basename(image_path),
|
|
||||||
'image_path': image_path,
|
|
||||||
'image_base64': '',
|
|
||||||
'score': float(score)
|
|
||||||
})
|
|
||||||
|
|
||||||
# 转换查询图片为base64
|
|
||||||
query_image_b64 = image_to_base64(filepath)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'mode': mode,
|
|
||||||
'query_image': query_image_b64,
|
|
||||||
'results': results,
|
|
||||||
'result_count': len(results)
|
|
||||||
})
|
|
||||||
|
|
||||||
else:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'不支持的搜索模式: {mode}'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"搜索失败: {str(e)}"
|
|
||||||
logger.error(error_msg)
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': error_msg
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/upload/images', methods=['POST'])
|
|
||||||
def upload_images():
|
|
||||||
"""批量上传图片"""
|
|
||||||
try:
|
|
||||||
uploaded_files = []
|
|
||||||
|
|
||||||
if 'images' not in request.files:
|
|
||||||
return jsonify({'success': False, 'message': '没有选择文件'}), 400
|
|
||||||
|
|
||||||
files = request.files.getlist('images')
|
|
||||||
|
|
||||||
for file in files:
|
|
||||||
if file and file.filename != '' and allowed_file(file.filename):
|
|
||||||
filename = secure_filename(file.filename)
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
filename = f"{timestamp}_{filename}"
|
|
||||||
filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
|
|
||||||
file.save(filepath)
|
|
||||||
uploaded_files.append(filename)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': f'成功上传 {len(uploaded_files)} 个图片文件',
|
|
||||||
'uploaded_count': len(uploaded_files),
|
|
||||||
'files': uploaded_files
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'上传失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/upload/texts', methods=['POST'])
|
|
||||||
def upload_texts():
|
|
||||||
"""批量上传文本数据"""
|
|
||||||
try:
|
|
||||||
data = request.get_json()
|
|
||||||
|
|
||||||
if not data or 'texts' not in data:
|
|
||||||
return jsonify({'success': False, 'message': '没有提供文本数据'}), 400
|
|
||||||
|
|
||||||
texts = data['texts']
|
|
||||||
if not isinstance(texts, list):
|
|
||||||
return jsonify({'success': False, 'message': '文本数据格式错误'}), 400
|
|
||||||
|
|
||||||
# 保存文本数据到文件
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
filename = f"texts_{timestamp}.json"
|
|
||||||
filepath = os.path.join(TEXT_DATA_FOLDER, filename)
|
|
||||||
|
|
||||||
with open(filepath, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(texts, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': f'成功上传 {len(texts)} 条文本',
|
|
||||||
'uploaded_count': len(texts)
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'上传失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/store_data', methods=['POST'])
|
|
||||||
def store_data():
|
|
||||||
"""将上传的数据存储到VDB"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
if not retrieval_system:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取所有图片和文本文件
|
|
||||||
image_files = []
|
|
||||||
text_data = []
|
|
||||||
|
|
||||||
# 扫描图片文件
|
|
||||||
for ext in ALLOWED_EXTENSIONS:
|
|
||||||
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
|
||||||
image_files.extend(glob.glob(pattern))
|
|
||||||
|
|
||||||
# 读取文本文件
|
|
||||||
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
|
|
||||||
text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
|
|
||||||
|
|
||||||
for text_file in text_files:
|
|
||||||
try:
|
|
||||||
if text_file.endswith('.json'):
|
|
||||||
with open(text_file, 'r', encoding='utf-8') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
if isinstance(data, list):
|
|
||||||
text_data.extend([str(item).strip() for item in data if str(item).strip()])
|
|
||||||
else:
|
|
||||||
text_data.append(str(data).strip())
|
|
||||||
else:
|
|
||||||
with open(text_file, 'r', encoding='utf-8') as f:
|
|
||||||
lines = [line.strip() for line in f.readlines() if line.strip()]
|
|
||||||
text_data.extend(lines)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"读取文本文件失败 {text_file}: {e}")
|
|
||||||
|
|
||||||
# 检查是否有数据可以存储
|
|
||||||
if not image_files and not text_data:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '没有找到可用的图片或文本数据,请先上传数据'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 存储数据到VDB
|
|
||||||
stored_images = 0
|
|
||||||
stored_texts = 0
|
|
||||||
|
|
||||||
if image_files:
|
|
||||||
logger.info(f"存储图片到VDB,共 {len(image_files)} 张图片")
|
|
||||||
image_ids = retrieval_system.store_images(image_files)
|
|
||||||
stored_images = len(image_ids)
|
|
||||||
|
|
||||||
if text_data:
|
|
||||||
logger.info(f"存储文本到VDB,共 {len(text_data)} 条文本")
|
|
||||||
text_ids = retrieval_system.store_texts(text_data)
|
|
||||||
stored_texts = len(text_ids)
|
|
||||||
|
|
||||||
# 获取更新后的统计信息
|
|
||||||
stats = retrieval_system.get_statistics()
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': f'数据存储完成!图片: {stored_images} 张,文本: {stored_texts} 条',
|
|
||||||
'stored_images': stored_images,
|
|
||||||
'stored_texts': stored_texts,
|
|
||||||
'vdb_stats': stats
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"存储数据失败: {str(e)}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'存储数据失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/data/stats', methods=['GET'])
|
|
||||||
def get_data_stats():
|
|
||||||
"""获取数据统计信息"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 统计本地文件
|
|
||||||
image_count = 0
|
|
||||||
for ext in ALLOWED_EXTENSIONS:
|
|
||||||
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
|
||||||
image_count += len(glob.glob(pattern))
|
|
||||||
|
|
||||||
text_count = 0
|
|
||||||
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
|
|
||||||
text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
|
|
||||||
for text_file in text_files:
|
|
||||||
try:
|
|
||||||
if text_file.endswith('.json'):
|
|
||||||
with open(text_file, 'r', encoding='utf-8') as f:
|
|
||||||
data = json.load(f)
|
|
||||||
if isinstance(data, list):
|
|
||||||
text_count += len(data)
|
|
||||||
else:
|
|
||||||
text_count += 1
|
|
||||||
else:
|
|
||||||
with open(text_file, 'r', encoding='utf-8') as f:
|
|
||||||
lines = [line.strip() for line in f.readlines() if line.strip()]
|
|
||||||
text_count += len(lines)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 获取VDB统计信息
|
|
||||||
vdb_stats = {}
|
|
||||||
if retrieval_system:
|
|
||||||
vdb_stats = retrieval_system.get_statistics()
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'local_files': {
|
|
||||||
'image_count': image_count,
|
|
||||||
'text_count': text_count
|
|
||||||
},
|
|
||||||
'vdb_stats': vdb_stats
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取数据统计失败: {str(e)}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'获取统计失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/data/list', methods=['GET'])
|
|
||||||
def list_data():
|
|
||||||
"""获取数据列表"""
|
|
||||||
try:
|
|
||||||
# 获取图片文件列表
|
|
||||||
image_files = []
|
|
||||||
for ext in ALLOWED_EXTENSIONS:
|
|
||||||
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
|
||||||
for file_path in glob.glob(pattern):
|
|
||||||
try:
|
|
||||||
# 转换为base64
|
|
||||||
image_base64 = image_to_base64(file_path)
|
|
||||||
image_files.append({
|
|
||||||
'filename': os.path.basename(file_path),
|
|
||||||
'filepath': file_path,
|
|
||||||
'image_base64': image_base64,
|
|
||||||
'size': os.path.getsize(file_path)
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"处理图片文件失败 {file_path}: {e}")
|
|
||||||
|
|
||||||
# 获取文本文件列表
|
|
||||||
text_files = []
|
|
||||||
text_file_paths = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
|
|
||||||
text_file_paths.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
|
|
||||||
|
|
||||||
for text_file in text_file_paths:
|
|
||||||
try:
|
|
||||||
text_files.append({
|
|
||||||
'filename': os.path.basename(text_file),
|
|
||||||
'filepath': text_file,
|
|
||||||
'size': os.path.getsize(text_file)
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"处理文本文件失败 {text_file}: {e}")
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'image_files': image_files,
|
|
||||||
'text_files': text_files,
|
|
||||||
'image_count': len(image_files),
|
|
||||||
'text_count': len(text_files)
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取数据列表失败: {str(e)}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'获取数据列表失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/data/clear', methods=['POST'])
|
|
||||||
def clear_data():
|
|
||||||
"""清空所有数据"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 清空本地文件
|
|
||||||
for ext in ALLOWED_EXTENSIONS:
|
|
||||||
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
|
||||||
for file_path in glob.glob(pattern):
|
|
||||||
try:
|
|
||||||
os.remove(file_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"删除图片文件失败 {file_path}: {e}")
|
|
||||||
|
|
||||||
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
|
|
||||||
text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
|
|
||||||
for text_file in text_files:
|
|
||||||
try:
|
|
||||||
os.remove(text_file)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"删除文本文件失败 {text_file}: {e}")
|
|
||||||
|
|
||||||
# 清空VDB数据
|
|
||||||
if retrieval_system:
|
|
||||||
retrieval_system.clear_all_data()
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': '所有数据已清空(包括VDB中的向量数据)'
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清空数据失败: {str(e)}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'清空数据失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/uploads/<filename>')
|
|
||||||
def uploaded_file(filename):
|
|
||||||
"""提供上传文件的访问"""
|
|
||||||
return send_file(os.path.join(SAMPLE_IMAGES_FOLDER, filename))
|
|
||||||
|
|
||||||
def print_startup_info():
|
|
||||||
"""打印启动信息"""
|
|
||||||
print("🚀 启动VDB多模态检索Web应用")
|
|
||||||
print("=" * 60)
|
|
||||||
print("访问地址: http://localhost:5000")
|
|
||||||
print("新功能:")
|
|
||||||
print(" 🗄️ 百度VDB - 向量数据库存储")
|
|
||||||
print(" 📊 实时统计 - VDB数据统计信息")
|
|
||||||
print(" 🔄 数据同步 - 本地文件到VDB存储")
|
|
||||||
print("支持功能:")
|
|
||||||
print(" 📝 文搜文 - 文本查找相似文本")
|
|
||||||
print(" 🖼️ 文搜图 - 文本查找相关图片")
|
|
||||||
print(" 📝 图搜文 - 图片查找相关文本")
|
|
||||||
print(" 🖼️ 图搜图 - 图片查找相似图片")
|
|
||||||
print(" 📤 批量上传 - 图片和文本数据管理")
|
|
||||||
print("GPU配置:")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
gpu_count = torch.cuda.device_count()
|
|
||||||
print(f" 🖥️ 检测到 {gpu_count} 个GPU")
|
|
||||||
for i in range(gpu_count):
|
|
||||||
name = torch.cuda.get_device_name(i)
|
|
||||||
props = torch.cuda.get_device_properties(i)
|
|
||||||
memory_gb = props.total_memory / 1024**3
|
|
||||||
print(f" GPU {i}: {name} ({memory_gb:.1f}GB)")
|
|
||||||
else:
|
|
||||||
print(" ❌ CUDA不可用")
|
|
||||||
except Exception as e:
|
|
||||||
print(f" ❌ GPU检查失败: {e}")
|
|
||||||
|
|
||||||
print("VDB配置:")
|
|
||||||
print(" 🌐 服务器: http://180.76.96.191:5287")
|
|
||||||
print(" 👤 用户: root")
|
|
||||||
print(" 🗄️ 数据库: multimodal_retrieval")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
def auto_initialize():
|
|
||||||
"""启动时自动初始化系统"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info("🚀 启动时自动初始化VDB多模态检索系统...")
|
|
||||||
|
|
||||||
# 导入VDB检索系统
|
|
||||||
from multimodal_retrieval_vdb import MultimodalRetrievalVDB
|
|
||||||
|
|
||||||
# 初始化系统
|
|
||||||
retrieval_system = MultimodalRetrievalVDB()
|
|
||||||
|
|
||||||
if retrieval_system.model is None:
|
|
||||||
raise Exception("模型加载失败")
|
|
||||||
|
|
||||||
if retrieval_system.vdb is None:
|
|
||||||
raise Exception("VDB连接失败")
|
|
||||||
|
|
||||||
logger.info("✅ VDB系统自动初始化成功")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ VDB系统自动初始化失败: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print_startup_info()
|
|
||||||
|
|
||||||
# 启动时自动初始化
|
|
||||||
auto_initialize()
|
|
||||||
|
|
||||||
# 启动Flask应用
|
|
||||||
app.run(
|
|
||||||
host='0.0.0.0',
|
|
||||||
port=5000,
|
|
||||||
debug=False,
|
|
||||||
threaded=True
|
|
||||||
)
|
|
||||||
@ -1,650 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
生产级Web应用 - 使用纯百度VDB系统,完全移除FAISS依赖
|
|
||||||
优化版本:支持自动清理、内存处理和流式上传
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
|
||||||
import base64
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from io import BytesIO
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Any, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from flask import Flask, request, jsonify, render_template, send_from_directory
|
|
||||||
from werkzeug.utils import secure_filename
|
|
||||||
import threading
|
|
||||||
|
|
||||||
from multimodal_retrieval_vdb_only import MultimodalRetrievalVDBOnly
|
|
||||||
from mongodb_manager import get_mongodb_manager
|
|
||||||
from baidu_bos_manager import get_bos_manager
|
|
||||||
from optimized_file_handler import get_file_handler
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 创建Flask应用
|
|
||||||
app = Flask(__name__)
|
|
||||||
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
|
||||||
|
|
||||||
# 全局变量
|
|
||||||
retrieval_system = None
|
|
||||||
system_lock = threading.Lock()
|
|
||||||
|
|
||||||
# 配置
|
|
||||||
UPLOAD_FOLDER = 'uploads'
|
|
||||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
|
|
||||||
|
|
||||||
# 确保上传目录存在
|
|
||||||
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
|
||||||
|
|
||||||
def allowed_file(filename):
|
|
||||||
"""检查文件扩展名是否允许"""
|
|
||||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
|
||||||
|
|
||||||
def encode_image_to_base64(image_path: str) -> str:
|
|
||||||
"""将图像编码为base64字符串"""
|
|
||||||
try:
|
|
||||||
with open(image_path, 'rb') as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
|
||||||
return f"data:image/jpeg;base64,{encoded_string}"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图像编码失败: {e}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
@app.route('/')
|
|
||||||
def index():
|
|
||||||
"""主页"""
|
|
||||||
return render_template('index.html')
|
|
||||||
|
|
||||||
@app.route('/api/init', methods=['POST'])
|
|
||||||
def init_system():
|
|
||||||
"""初始化多模态检索系统"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
with system_lock:
|
|
||||||
if retrieval_system is None:
|
|
||||||
logger.info("🚀 初始化纯VDB多模态检索系统...")
|
|
||||||
retrieval_system = MultimodalRetrievalVDBOnly()
|
|
||||||
logger.info("✅ 系统初始化完成")
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': '系统初始化成功',
|
|
||||||
'backend': 'Baidu VDB (No FAISS)',
|
|
||||||
'status': 'ready'
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 系统初始化失败: {e}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'系统初始化失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/upload/texts', methods=['POST'])
|
|
||||||
@app.route('/api/upload_texts', methods=['POST'])
|
|
||||||
def upload_texts():
|
|
||||||
"""上传文本数据 - 优化版本"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
if retrieval_system is None:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化,请先调用 /api/init'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
data = request.get_json()
|
|
||||||
texts = data.get('texts', [])
|
|
||||||
|
|
||||||
if not texts:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '没有提供文本数据'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 获取优化的文件处理器
|
|
||||||
file_handler = get_file_handler()
|
|
||||||
|
|
||||||
# 使用内存处理文本数据
|
|
||||||
logger.info(f"📝 开始内存处理 {len(texts)} 条文本...")
|
|
||||||
processed_texts = file_handler.process_text_in_memory(texts)
|
|
||||||
|
|
||||||
if not processed_texts:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '文本处理失败'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 自动构建文本索引
|
|
||||||
index_status = "索引构建失败"
|
|
||||||
try:
|
|
||||||
with system_lock:
|
|
||||||
retrieval_system.build_text_index_parallel(texts)
|
|
||||||
|
|
||||||
# 存储向量元数据
|
|
||||||
mongodb_mgr = get_mongodb_manager()
|
|
||||||
for text_info in processed_texts:
|
|
||||||
mongodb_mgr.store_vector_metadata(
|
|
||||||
file_id=text_info["file_id"],
|
|
||||||
vector_type="text_vector",
|
|
||||||
vdb_id=text_info["bos_key"],
|
|
||||||
vector_info={"text_content": text_info["text_content"]}
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"✅ 文本索引自动构建完成")
|
|
||||||
index_status = "索引已自动构建"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 自动构建文本索引失败: {e}")
|
|
||||||
index_status = f"索引构建失败: {str(e)}"
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': f'成功处理 {len(texts)} 条文本(内存模式)',
|
|
||||||
'count': len(texts),
|
|
||||||
'processed_texts': len(processed_texts),
|
|
||||||
'index_status': index_status,
|
|
||||||
'auto_indexed': True,
|
|
||||||
'processing_method': 'memory',
|
|
||||||
'storage_info': {
|
|
||||||
'bos_stored': len(processed_texts),
|
|
||||||
'mongodb_stored': len(processed_texts)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 文本上传失败: {e}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'文本上传失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/upload/images', methods=['POST'])
|
|
||||||
@app.route('/api/upload_images', methods=['POST'])
|
|
||||||
def upload_images():
|
|
||||||
"""上传图像文件 - 优化版本"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
if retrieval_system is None:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化,请先调用 /api/init'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
if 'files' not in request.files:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '没有文件上传'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
files = request.files.getlist('files')
|
|
||||||
if not files or all(file.filename == '' for file in files):
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '没有选择文件'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 获取优化的文件处理器
|
|
||||||
file_handler = get_file_handler()
|
|
||||||
processed_files = []
|
|
||||||
temp_files_for_model = []
|
|
||||||
|
|
||||||
# 处理每个文件
|
|
||||||
for file in files:
|
|
||||||
if file and allowed_file(file.filename):
|
|
||||||
filename = secure_filename(file.filename)
|
|
||||||
logger.info(f"📷 处理图像文件: {filename}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 智能处理图像(自动选择内存或临时文件)
|
|
||||||
result = file_handler.process_image_smart(file, filename)
|
|
||||||
|
|
||||||
if result:
|
|
||||||
processed_files.append(result)
|
|
||||||
|
|
||||||
# 如果需要为模型处理准备临时文件
|
|
||||||
if result.get('processing_method') == 'memory':
|
|
||||||
# 对于内存处理的文件,需要为模型创建临时文件
|
|
||||||
temp_path = file_handler.get_temp_file_for_model(file, filename)
|
|
||||||
if temp_path:
|
|
||||||
temp_files_for_model.append({
|
|
||||||
'temp_path': temp_path,
|
|
||||||
'file_info': result
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
# 临时文件处理的情况,直接使用返回的路径
|
|
||||||
temp_files_for_model.append({
|
|
||||||
'temp_path': result.get('temp_path'),
|
|
||||||
'file_info': result
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info(f"✅ 图像处理成功: {filename} ({result['processing_method']})")
|
|
||||||
else:
|
|
||||||
logger.error(f"❌ 图像处理失败: {filename}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 处理图像文件失败 {filename}: {e}")
|
|
||||||
|
|
||||||
if processed_files:
|
|
||||||
# 自动构建图像索引
|
|
||||||
index_status = "索引构建失败"
|
|
||||||
try:
|
|
||||||
# 准备图像路径列表
|
|
||||||
image_paths = [item['temp_path'] for item in temp_files_for_model if item['temp_path']]
|
|
||||||
|
|
||||||
if image_paths:
|
|
||||||
with system_lock:
|
|
||||||
retrieval_system.build_image_index_parallel(image_paths)
|
|
||||||
|
|
||||||
# 存储向量元数据
|
|
||||||
mongodb_mgr = get_mongodb_manager()
|
|
||||||
for file_info in processed_files:
|
|
||||||
mongodb_mgr.store_vector_metadata(
|
|
||||||
file_id=file_info["file_id"],
|
|
||||||
vector_type="image_vector",
|
|
||||||
vdb_id=file_info["bos_key"],
|
|
||||||
vector_info={"filename": file_info["filename"]}
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"✅ 图像索引自动构建完成")
|
|
||||||
index_status = "索引已自动构建"
|
|
||||||
|
|
||||||
# 清理模型处理用的临时文件
|
|
||||||
for item in temp_files_for_model:
|
|
||||||
if item['temp_path']:
|
|
||||||
file_handler.cleanup_temp_file(item['temp_path'])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 自动构建图像索引失败: {e}")
|
|
||||||
index_status = f"索引构建失败: {str(e)}"
|
|
||||||
|
|
||||||
# 即使索引失败也要清理临时文件
|
|
||||||
for item in temp_files_for_model:
|
|
||||||
if item['temp_path']:
|
|
||||||
file_handler.cleanup_temp_file(item['temp_path'])
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': f'成功上传 {len(processed_files)} 个图像文件',
|
|
||||||
'count': len(processed_files),
|
|
||||||
'processed_files': len(processed_files),
|
|
||||||
'index_status': index_status,
|
|
||||||
'auto_indexed': True,
|
|
||||||
'processing_methods': [f['processing_method'] for f in processed_files],
|
|
||||||
'storage_info': {
|
|
||||||
'bos_stored': len(processed_files),
|
|
||||||
'mongodb_stored': len(processed_files)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '没有有效的图像文件'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 图像上传失败: {e}")
|
|
||||||
# 确保清理所有临时文件
|
|
||||||
file_handler = get_file_handler()
|
|
||||||
file_handler.cleanup_all_temp_files()
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'图像上传失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/search/image_to_text', methods=['POST'])
|
|
||||||
def search_image_to_text():
|
|
||||||
"""图搜文 - 优化版本"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
if retrieval_system is None:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
top_k = request.form.get('top_k', 5, type=int)
|
|
||||||
|
|
||||||
if 'image' not in request.files:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '没有上传图像'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
file = request.files['image']
|
|
||||||
if not file or not allowed_file(file.filename):
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '无效的图像文件'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 获取优化的文件处理器
|
|
||||||
file_handler = get_file_handler()
|
|
||||||
|
|
||||||
# 为模型处理创建临时文件
|
|
||||||
filename = secure_filename(file.filename)
|
|
||||||
temp_path = file_handler.get_temp_file_for_model(file, filename)
|
|
||||||
|
|
||||||
if not temp_path:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '临时文件创建失败'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 执行图搜文
|
|
||||||
with system_lock:
|
|
||||||
results = retrieval_system.search_image_to_text(temp_path, top_k=top_k)
|
|
||||||
|
|
||||||
# 格式化结果
|
|
||||||
formatted_results = []
|
|
||||||
for text, score in results:
|
|
||||||
formatted_results.append({
|
|
||||||
'text': text,
|
|
||||||
'score': float(score),
|
|
||||||
'type': 'text'
|
|
||||||
})
|
|
||||||
|
|
||||||
# 编码查询图像
|
|
||||||
query_image_base64 = encode_image_to_base64(temp_path)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'query_image': query_image_base64,
|
|
||||||
'results': formatted_results,
|
|
||||||
'count': len(formatted_results),
|
|
||||||
'search_type': 'image_to_text',
|
|
||||||
'processing_method': 'optimized_temp_file'
|
|
||||||
})
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# 确保清理临时文件
|
|
||||||
file_handler.cleanup_temp_file(temp_path)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 图搜文失败: {e}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'图搜文失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/search/image_to_image', methods=['POST'])
|
|
||||||
def search_image_to_image():
|
|
||||||
"""图搜图 - 优化版本"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
if retrieval_system is None:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
top_k = request.form.get('top_k', 5, type=int)
|
|
||||||
|
|
||||||
if 'image' not in request.files:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '没有上传图像'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
file = request.files['image']
|
|
||||||
if not file or not allowed_file(file.filename):
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '无效的图像文件'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 获取优化的文件处理器
|
|
||||||
file_handler = get_file_handler()
|
|
||||||
|
|
||||||
# 为模型处理创建临时文件
|
|
||||||
filename = secure_filename(file.filename)
|
|
||||||
temp_path = file_handler.get_temp_file_for_model(file, filename)
|
|
||||||
|
|
||||||
if not temp_path:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '临时文件创建失败'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 执行图搜图
|
|
||||||
with system_lock:
|
|
||||||
results = retrieval_system.search_image_to_image(temp_path, top_k=top_k)
|
|
||||||
|
|
||||||
# 格式化结果
|
|
||||||
formatted_results = []
|
|
||||||
for image_path, score in results:
|
|
||||||
image_base64 = encode_image_to_base64(image_path)
|
|
||||||
formatted_results.append({
|
|
||||||
'image_path': image_path,
|
|
||||||
'image_base64': image_base64,
|
|
||||||
'score': float(score),
|
|
||||||
'type': 'image'
|
|
||||||
})
|
|
||||||
|
|
||||||
# 编码查询图像
|
|
||||||
query_image_base64 = encode_image_to_base64(temp_path)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'query_image': query_image_base64,
|
|
||||||
'results': formatted_results,
|
|
||||||
'count': len(formatted_results),
|
|
||||||
'search_type': 'image_to_image',
|
|
||||||
'processing_method': 'optimized_temp_file'
|
|
||||||
})
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# 确保清理临时文件
|
|
||||||
file_handler.cleanup_temp_file(temp_path)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 图搜图失败: {e}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'图搜图失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/data/list', methods=['GET'])
|
|
||||||
def list_data():
|
|
||||||
"""获取数据列表"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
if retrieval_system is None:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 获取MongoDB管理器
|
|
||||||
mongodb_mgr = get_mongodb_manager()
|
|
||||||
|
|
||||||
# 获取所有文件元数据
|
|
||||||
files_data = mongodb_mgr.get_all_files()
|
|
||||||
|
|
||||||
images = []
|
|
||||||
texts = []
|
|
||||||
|
|
||||||
for file_data in files_data:
|
|
||||||
if file_data.get('file_type') == 'image':
|
|
||||||
# 提取文件名用于显示
|
|
||||||
file_path = file_data.get('file_path', '')
|
|
||||||
filename = os.path.basename(file_path) if file_path else file_data.get('bos_key', '')
|
|
||||||
images.append(filename)
|
|
||||||
elif file_data.get('file_type') == 'text':
|
|
||||||
# 获取文本内容
|
|
||||||
text_content = file_data.get('additional_info', {}).get('text_content', '')
|
|
||||||
if text_content:
|
|
||||||
texts.append(text_content)
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'images': images,
|
|
||||||
'texts': texts,
|
|
||||||
'total_files': len(files_data),
|
|
||||||
'image_count': len(images),
|
|
||||||
'text_count': len(texts)
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取数据列表失败: {e}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'获取数据列表失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/stats', methods=['GET'])
|
|
||||||
@app.route('/api/status', methods=['GET'])
|
|
||||||
@app.route('/api/data/stats', methods=['GET'])
|
|
||||||
def get_stats():
|
|
||||||
"""获取系统统计信息"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
if retrieval_system is None:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
# 获取MongoDB统计信息
|
|
||||||
mongodb_mgr = get_mongodb_manager()
|
|
||||||
mongodb_stats = mongodb_mgr.get_stats()
|
|
||||||
|
|
||||||
with system_lock:
|
|
||||||
stats = retrieval_system.get_statistics()
|
|
||||||
gpu_info = retrieval_system.get_gpu_memory_info()
|
|
||||||
|
|
||||||
# 合并统计信息
|
|
||||||
enhanced_stats = {
|
|
||||||
**stats,
|
|
||||||
'mongodb_stats': mongodb_stats,
|
|
||||||
'storage_backend': {
|
|
||||||
'metadata': 'MongoDB',
|
|
||||||
'files': 'Baidu BOS'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'stats': enhanced_stats,
|
|
||||||
'gpu_info': gpu_info,
|
|
||||||
'backend': 'Baidu VDB + MongoDB + BOS'
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 获取统计信息失败: {e}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'获取统计信息失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/api/clear', methods=['POST'])
|
|
||||||
def clear_data():
|
|
||||||
"""清空所有数据 - 优化版本"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
try:
|
|
||||||
if retrieval_system is None:
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '系统未初始化'
|
|
||||||
}), 400
|
|
||||||
|
|
||||||
with system_lock:
|
|
||||||
retrieval_system.clear_all_data()
|
|
||||||
|
|
||||||
# 清理所有临时文件
|
|
||||||
file_handler = get_file_handler()
|
|
||||||
file_handler.cleanup_all_temp_files()
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
'success': True,
|
|
||||||
'message': '所有数据已清空,临时文件已清理'
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 清空数据失败: {e}")
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': f'清空数据失败: {str(e)}'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
@app.route('/uploads/<filename>')
|
|
||||||
def uploaded_file(filename):
|
|
||||||
"""提供上传文件的访问"""
|
|
||||||
return send_from_directory(UPLOAD_FOLDER, filename)
|
|
||||||
|
|
||||||
@app.errorhandler(413)
|
|
||||||
def too_large(e):
|
|
||||||
"""文件过大错误处理"""
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '文件过大,最大支持16MB'
|
|
||||||
}), 413
|
|
||||||
|
|
||||||
@app.errorhandler(500)
|
|
||||||
def internal_error(e):
|
|
||||||
"""内部服务器错误处理"""
|
|
||||||
return jsonify({
|
|
||||||
'success': False,
|
|
||||||
'message': '内部服务器错误'
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
def init_app():
|
|
||||||
"""应用初始化 - 优化版本"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("生产级多模态检索Web应用 (优化版)")
|
|
||||||
print("Backend: 纯百度VDB (无FAISS)")
|
|
||||||
print("Features: 自动清理 + 内存处理 + 流式上传")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 显示GPU信息
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
gpu_count = torch.cuda.device_count()
|
|
||||||
print(f"检测到 {gpu_count} 个GPU:")
|
|
||||||
for i in range(gpu_count):
|
|
||||||
gpu_name = torch.cuda.get_device_properties(i).name
|
|
||||||
gpu_memory = torch.cuda.get_device_properties(i).total_memory / (1024**3)
|
|
||||||
print(f" GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)")
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 启动时自动初始化系统
|
|
||||||
try:
|
|
||||||
logger.info("🚀 启动时自动初始化优化版VDB检索系统...")
|
|
||||||
retrieval_system = MultimodalRetrievalVDBOnly()
|
|
||||||
logger.info("✅ 优化版系统自动初始化成功")
|
|
||||||
|
|
||||||
# 初始化文件处理器
|
|
||||||
file_handler = get_file_handler()
|
|
||||||
logger.info("✅ 优化文件处理器初始化成功")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 系统自动初始化失败: {e}")
|
|
||||||
logger.info("系统将在首次API调用时初始化")
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# 初始化应用
|
|
||||||
init_app()
|
|
||||||
|
|
||||||
# 启动Flask应用
|
|
||||||
app.run(host='0.0.0.0', port=5000, debug=False, threaded=True)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user