📝 tem push
This commit is contained in:
parent
36021e817c
commit
202fad85ec
Binary file not shown.
BIN
__pycache__/faiss_vector_store.cpython-310.pyc
Normal file
BIN
__pycache__/faiss_vector_store.cpython-310.pyc
Normal file
Binary file not shown.
BIN
__pycache__/multimodal_retrieval_faiss.cpython-310.pyc
Normal file
BIN
__pycache__/multimodal_retrieval_faiss.cpython-310.pyc
Normal file
Binary file not shown.
BIN
__pycache__/multimodal_retrieval_local.cpython-310.pyc
Normal file
BIN
__pycache__/multimodal_retrieval_local.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
__pycache__/proxy_utils.cpython-310.pyc
Normal file
BIN
__pycache__/proxy_utils.cpython-310.pyc
Normal file
Binary file not shown.
78
app_log.txt
Normal file
78
app_log.txt
Normal file
File diff suppressed because one or more lines are too long
@ -118,30 +118,29 @@ class BaiduVDBBackend:
|
|||||||
try:
|
try:
|
||||||
logger.info(f"创建文本向量表: {self.text_table_name}")
|
logger.info(f"创建文本向量表: {self.text_table_name}")
|
||||||
|
|
||||||
# 定义字段 - 使用最简单的配置
|
# 定义字段 - 移除可能导致问题的复杂配置
|
||||||
fields = [
|
fields = [
|
||||||
Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
|
Field("id", FieldType.STRING, primary_key=True, not_null=True),
|
||||||
Field("text_content", FieldType.STRING, not_null=True),
|
Field("text_content", FieldType.STRING, not_null=True),
|
||||||
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 定义索引
|
# 定义索引 - 简化配置
|
||||||
indexes = [
|
indexes = [
|
||||||
VectorIndex(
|
VectorIndex(
|
||||||
index_name="text_vector_idx",
|
index_name="text_vector_idx",
|
||||||
index_type=IndexType.HNSW,
|
index_type=IndexType.HNSW,
|
||||||
field="vector",
|
field="vector",
|
||||||
metric_type=MetricType.COSINE,
|
metric_type=MetricType.COSINE,
|
||||||
params=HNSWParams(m=32, efconstruction=200),
|
params=HNSWParams(m=16, efconstruction=100),
|
||||||
auto_build=True
|
auto_build=True
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 创建表
|
# 创建表 - 简化配置
|
||||||
self.text_table = self.db.create_table(
|
self.text_table = self.db.create_table(
|
||||||
table_name=self.text_table_name,
|
table_name=self.text_table_name,
|
||||||
replication=2, # 双副本
|
replication=1, # 单副本
|
||||||
partition=Partition(partition_num=3), # 3个分区
|
|
||||||
schema=Schema(fields=fields, indexes=indexes)
|
schema=Schema(fields=fields, indexes=indexes)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -156,30 +155,29 @@ class BaiduVDBBackend:
|
|||||||
try:
|
try:
|
||||||
logger.info(f"创建图像向量表: {self.image_table_name}")
|
logger.info(f"创建图像向量表: {self.image_table_name}")
|
||||||
|
|
||||||
# 定义字段 - 使用最简单的配置
|
# 定义字段 - 移除可能导致问题的复杂配置
|
||||||
fields = [
|
fields = [
|
||||||
Field("id", FieldType.STRING, primary_key=True, partition_key=True, not_null=True),
|
Field("id", FieldType.STRING, primary_key=True, not_null=True),
|
||||||
Field("image_path", FieldType.STRING, not_null=True),
|
Field("image_path", FieldType.STRING, not_null=True),
|
||||||
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
Field("vector", FieldType.FLOAT_VECTOR, not_null=True, dimension=self.vector_dimension)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 定义索引
|
# 定义索引 - 简化配置
|
||||||
indexes = [
|
indexes = [
|
||||||
VectorIndex(
|
VectorIndex(
|
||||||
index_name="image_vector_idx",
|
index_name="image_vector_idx",
|
||||||
index_type=IndexType.HNSW,
|
index_type=IndexType.HNSW,
|
||||||
field="vector",
|
field="vector",
|
||||||
metric_type=MetricType.COSINE,
|
metric_type=MetricType.COSINE,
|
||||||
params=HNSWParams(m=32, efconstruction=200),
|
params=HNSWParams(m=16, efconstruction=100),
|
||||||
auto_build=True
|
auto_build=True
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# 创建表
|
# 创建表 - 简化配置
|
||||||
self.image_table = self.db.create_table(
|
self.image_table = self.db.create_table(
|
||||||
table_name=self.image_table_name,
|
table_name=self.image_table_name,
|
||||||
replication=2, # 双副本
|
replication=1, # 单副本
|
||||||
partition=Partition(partition_num=3), # 3个分区
|
|
||||||
schema=Schema(fields=fields, indexes=indexes)
|
schema=Schema(fields=fields, indexes=indexes)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
BIN
faiss_index_local.index
Normal file
BIN
faiss_index_local.index
Normal file
Binary file not shown.
1
faiss_index_local_metadata.json
Normal file
1
faiss_index_local_metadata.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{}
|
||||||
BIN
faiss_index_test.index
Normal file
BIN
faiss_index_test.index
Normal file
Binary file not shown.
1
faiss_index_test_metadata.json
Normal file
1
faiss_index_test_metadata.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{}
|
||||||
147
faiss_vector_store.py
Normal file
147
faiss_vector_store.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import faiss
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
import logging
|
||||||
|
|
||||||
|
class FaissVectorStore:
|
||||||
|
def __init__(self, index_path: str = "faiss_index", dimension: int = 3584):
|
||||||
|
"""
|
||||||
|
初始化FAISS向量存储
|
||||||
|
|
||||||
|
参数:
|
||||||
|
index_path: 索引文件路径
|
||||||
|
dimension: 向量维度
|
||||||
|
"""
|
||||||
|
self.index_path = index_path
|
||||||
|
self.dimension = dimension
|
||||||
|
self.index = None
|
||||||
|
self.metadata = {}
|
||||||
|
self.metadata_path = f"{index_path}_metadata.json"
|
||||||
|
|
||||||
|
# 加载现有索引或创建新索引
|
||||||
|
self._load_or_create_index()
|
||||||
|
|
||||||
|
def _load_or_create_index(self):
|
||||||
|
"""加载现有索引或创建新索引"""
|
||||||
|
if os.path.exists(f"{self.index_path}.index"):
|
||||||
|
logging.info(f"加载现有索引: {self.index_path}")
|
||||||
|
self.index = faiss.read_index(f"{self.index_path}.index")
|
||||||
|
self._load_metadata()
|
||||||
|
else:
|
||||||
|
logging.info(f"创建新索引,维度: {self.dimension}")
|
||||||
|
self.index = faiss.IndexFlatL2(self.dimension) # 使用L2距离
|
||||||
|
|
||||||
|
def _load_metadata(self):
|
||||||
|
"""加载元数据"""
|
||||||
|
if os.path.exists(self.metadata_path):
|
||||||
|
with open(self.metadata_path, 'r', encoding='utf-8') as f:
|
||||||
|
self.metadata = json.load(f)
|
||||||
|
|
||||||
|
def _save_metadata(self):
|
||||||
|
"""保存元数据到文件"""
|
||||||
|
with open(self.metadata_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(self.metadata, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
def save_index(self):
|
||||||
|
"""保存索引和元数据"""
|
||||||
|
if self.index is not None:
|
||||||
|
faiss.write_index(self.index, f"{self.index_path}.index")
|
||||||
|
self._save_metadata()
|
||||||
|
logging.info(f"索引已保存到 {self.index_path}.index")
|
||||||
|
|
||||||
|
def add_vectors(
|
||||||
|
self,
|
||||||
|
vectors: np.ndarray,
|
||||||
|
metadatas: List[Dict[str, Any]]
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
添加向量和元数据
|
||||||
|
|
||||||
|
参数:
|
||||||
|
vectors: 向量数组
|
||||||
|
metadatas: 对应的元数据列表
|
||||||
|
|
||||||
|
返回:
|
||||||
|
添加的向量ID列表
|
||||||
|
"""
|
||||||
|
if len(vectors) != len(metadatas):
|
||||||
|
raise ValueError("vectors和metadatas长度必须相同")
|
||||||
|
|
||||||
|
start_id = len(self.metadata)
|
||||||
|
ids = list(range(start_id, start_id + len(vectors)))
|
||||||
|
|
||||||
|
# 添加向量到索引
|
||||||
|
self.index.add(vectors.astype('float32'))
|
||||||
|
|
||||||
|
# 保存元数据
|
||||||
|
for idx, vector_id in enumerate(ids):
|
||||||
|
self.metadata[str(vector_id)] = metadatas[idx]
|
||||||
|
|
||||||
|
# 保存索引和元数据
|
||||||
|
self.save_index()
|
||||||
|
|
||||||
|
return [str(id) for id in ids]
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
query_vector: np.ndarray,
|
||||||
|
k: int = 5
|
||||||
|
) -> Tuple[List[Dict[str, Any]], List[float]]:
|
||||||
|
"""
|
||||||
|
相似性搜索
|
||||||
|
|
||||||
|
参数:
|
||||||
|
query_vector: 查询向量
|
||||||
|
k: 返回结果数量
|
||||||
|
|
||||||
|
返回:
|
||||||
|
(结果列表, 距离列表)
|
||||||
|
"""
|
||||||
|
if self.index is None:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
# 确保输入是2D数组
|
||||||
|
if len(query_vector.shape) == 1:
|
||||||
|
query_vector = query_vector.reshape(1, -1)
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
distances, indices = self.index.search(query_vector.astype('float32'), k)
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
results = []
|
||||||
|
for i in range(len(indices[0])):
|
||||||
|
idx = indices[0][i]
|
||||||
|
if idx < 0: # FAISS可能返回-1表示无效索引
|
||||||
|
continue
|
||||||
|
|
||||||
|
vector_id = str(idx)
|
||||||
|
if vector_id in self.metadata:
|
||||||
|
result = self.metadata[vector_id].copy()
|
||||||
|
result['distance'] = float(distances[0][i])
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results, distances[0].tolist()
|
||||||
|
|
||||||
|
def get_vector_count(self) -> int:
|
||||||
|
"""获取向量数量"""
|
||||||
|
return self.index.ntotal if self.index is not None else 0
|
||||||
|
|
||||||
|
def delete_vectors(self, vector_ids: List[str]) -> bool:
|
||||||
|
"""
|
||||||
|
删除指定ID的向量
|
||||||
|
|
||||||
|
注意: FAISS不支持直接删除向量,这里实现为逻辑删除
|
||||||
|
"""
|
||||||
|
deleted_count = 0
|
||||||
|
for vector_id in vector_ids:
|
||||||
|
if vector_id in self.metadata:
|
||||||
|
del self.metadata[vector_id]
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
self._save_metadata()
|
||||||
|
logging.warning("FAISS不支持直接删除向量,已从元数据中移除,但索引中仍保留")
|
||||||
|
|
||||||
|
return deleted_count > 0
|
||||||
BIN
local_faiss_index.index
Normal file
BIN
local_faiss_index.index
Normal file
Binary file not shown.
1
local_faiss_index_metadata.json
Normal file
1
local_faiss_index_metadata.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{}
|
||||||
135
local_file_handler.py
Normal file
135
local_file_handler.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
本地文件处理器
|
||||||
|
简化版的文件处理器,不依赖外部服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import tempfile
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Dict, List, Optional, Any, Union, BinaryIO
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class LocalFileHandler:
|
||||||
|
"""本地文件处理器"""
|
||||||
|
|
||||||
|
# 小文件阈值 (5MB)
|
||||||
|
SMALL_FILE_THRESHOLD = 5 * 1024 * 1024
|
||||||
|
|
||||||
|
# 支持的图像格式
|
||||||
|
SUPPORTED_IMAGE_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}
|
||||||
|
|
||||||
|
def __init__(self, temp_dir: str = None):
|
||||||
|
"""
|
||||||
|
初始化本地文件处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temp_dir: 临时文件目录
|
||||||
|
"""
|
||||||
|
self.temp_dir = temp_dir or tempfile.gettempdir()
|
||||||
|
self.temp_files = set() # 跟踪临时文件
|
||||||
|
|
||||||
|
# 确保临时目录存在
|
||||||
|
os.makedirs(self.temp_dir, exist_ok=True)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def temp_file_context(self, content: bytes = None, suffix: str = None, delete_on_exit: bool = True):
|
||||||
|
"""临时文件上下文管理器,确保自动清理"""
|
||||||
|
temp_fd, temp_path = tempfile.mkstemp(suffix=suffix, dir=self.temp_dir)
|
||||||
|
self.temp_files.add(temp_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.close(temp_fd) # 关闭文件描述符
|
||||||
|
|
||||||
|
# 如果提供了内容,写入文件
|
||||||
|
if content is not None:
|
||||||
|
with open(temp_path, 'wb') as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
yield temp_path
|
||||||
|
finally:
|
||||||
|
if delete_on_exit and os.path.exists(temp_path):
|
||||||
|
try:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
self.temp_files.discard(temp_path)
|
||||||
|
logger.debug(f"🗑️ 临时文件已清理: {temp_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"⚠️ 临时文件清理失败: {temp_path}, {e}")
|
||||||
|
|
||||||
|
def cleanup_all_temp_files(self):
|
||||||
|
"""清理所有跟踪的临时文件"""
|
||||||
|
for temp_path in list(self.temp_files):
|
||||||
|
if os.path.exists(temp_path):
|
||||||
|
try:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
logger.debug(f"🗑️ 清理临时文件: {temp_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"⚠️ 清理临时文件失败: {temp_path}, {e}")
|
||||||
|
self.temp_files.clear()
|
||||||
|
|
||||||
|
def get_file_size(self, file_obj) -> int:
|
||||||
|
"""获取文件大小"""
|
||||||
|
if hasattr(file_obj, 'content_length') and file_obj.content_length:
|
||||||
|
return file_obj.content_length
|
||||||
|
|
||||||
|
# 通过读取内容获取大小
|
||||||
|
current_pos = file_obj.tell()
|
||||||
|
file_obj.seek(0, 2) # 移动到文件末尾
|
||||||
|
size = file_obj.tell()
|
||||||
|
file_obj.seek(current_pos) # 恢复原位置
|
||||||
|
return size
|
||||||
|
|
||||||
|
def is_small_file(self, file_obj) -> bool:
|
||||||
|
"""判断是否为小文件"""
|
||||||
|
return self.get_file_size(file_obj) <= self.SMALL_FILE_THRESHOLD
|
||||||
|
|
||||||
|
def get_temp_file_for_model(self, file_obj, filename: str) -> Optional[str]:
|
||||||
|
"""为模型处理获取临时文件路径(确保文件存在于本地)"""
|
||||||
|
try:
|
||||||
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
|
|
||||||
|
# 创建临时文件(不自动删除,供模型使用)
|
||||||
|
temp_fd, temp_path = tempfile.mkstemp(suffix=ext, dir=self.temp_dir)
|
||||||
|
self.temp_files.add(temp_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 写入文件内容
|
||||||
|
file_obj.seek(0)
|
||||||
|
with os.fdopen(temp_fd, 'wb') as temp_file:
|
||||||
|
temp_file.write(file_obj.read())
|
||||||
|
|
||||||
|
logger.debug(f"📁 为模型创建临时文件: {temp_path}")
|
||||||
|
return temp_path
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
os.close(temp_fd)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 为模型创建临时文件失败: {filename}, {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def cleanup_temp_file(self, temp_path: str):
|
||||||
|
"""清理指定的临时文件"""
|
||||||
|
if temp_path and os.path.exists(temp_path):
|
||||||
|
try:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
self.temp_files.discard(temp_path)
|
||||||
|
logger.debug(f"🗑️ 清理临时文件: {temp_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"⚠️ 清理临时文件失败: {temp_path}, {e}")
|
||||||
|
|
||||||
|
# 全局实例
|
||||||
|
file_handler = None
|
||||||
|
|
||||||
|
def get_file_handler(temp_dir: str = None) -> LocalFileHandler:
|
||||||
|
"""获取文件处理器实例"""
|
||||||
|
global file_handler
|
||||||
|
if file_handler is None:
|
||||||
|
file_handler = LocalFileHandler(temp_dir=temp_dir)
|
||||||
|
return file_handler
|
||||||
108
model_download_guide.md
Normal file
108
model_download_guide.md
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
# 多模态模型下载指南
|
||||||
|
|
||||||
|
## 下载 OpenSearch-AI/Ops-MM-embedding-v1-7B 模型
|
||||||
|
|
||||||
|
### 方法1:使用 git-lfs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 安装 git-lfs
|
||||||
|
apt-get install git-lfs
|
||||||
|
# 或
|
||||||
|
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash
|
||||||
|
apt-get install git-lfs
|
||||||
|
|
||||||
|
# 初始化 git-lfs
|
||||||
|
git lfs install
|
||||||
|
|
||||||
|
# 克隆模型仓库
|
||||||
|
mkdir -p ~/models
|
||||||
|
git clone https://huggingface.co/OpenSearch-AI/Ops-MM-embedding-v1-7B ~/models/Ops-MM-embedding-v1-7B
|
||||||
|
```
|
||||||
|
|
||||||
|
### 方法2:使用 huggingface-cli
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 安装 huggingface-hub
|
||||||
|
pip install huggingface-hub
|
||||||
|
|
||||||
|
# 下载模型
|
||||||
|
mkdir -p ~/models
|
||||||
|
huggingface-cli download OpenSearch-AI/Ops-MM-embedding-v1-7B --local-dir ~/models/Ops-MM-embedding-v1-7B
|
||||||
|
```
|
||||||
|
|
||||||
|
### 方法3:手动下载关键文件
|
||||||
|
|
||||||
|
如果上述方法不可行,可以手动下载以下关键文件:
|
||||||
|
|
||||||
|
1. 访问 https://huggingface.co/OpenSearch-AI/Ops-MM-embedding-v1-7B/tree/main
|
||||||
|
2. 下载以下文件:
|
||||||
|
- `config.json`
|
||||||
|
- `pytorch_model.bin` (或分片文件 `pytorch_model-00001-of-00002.bin` 等)
|
||||||
|
- `tokenizer.json`
|
||||||
|
- `tokenizer_config.json`
|
||||||
|
- `special_tokens_map.json`
|
||||||
|
- `vocab.txt`
|
||||||
|
|
||||||
|
## 下载替代轻量级模型
|
||||||
|
|
||||||
|
如果主模型太大,可以下载这些较小的替代模型:
|
||||||
|
|
||||||
|
### CLIP 模型
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir -p ~/models/clip-ViT-B-32
|
||||||
|
huggingface-cli download openai/clip-vit-base-patch32 --local-dir ~/models/clip-ViT-B-32
|
||||||
|
```
|
||||||
|
|
||||||
|
### 多语言CLIP模型
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir -p ~/models/clip-multilingual
|
||||||
|
huggingface-cli download sentence-transformers/clip-ViT-B-32-multilingual-v1 --local-dir ~/models/clip-multilingual
|
||||||
|
```
|
||||||
|
|
||||||
|
## 传输模型文件
|
||||||
|
|
||||||
|
下载完成后,使用以下方法将模型传输到目标服务器:
|
||||||
|
|
||||||
|
### 使用 scp
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 从当前机器传输到目标服务器
|
||||||
|
scp -r ~/models/Ops-MM-embedding-v1-7B user@target-server:/root/models/
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用压缩文件
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 压缩
|
||||||
|
tar -czvf model.tar.gz ~/models/Ops-MM-embedding-v1-7B
|
||||||
|
|
||||||
|
# 传输压缩文件
|
||||||
|
scp model.tar.gz user@target-server:/root/
|
||||||
|
|
||||||
|
# 在目标服务器上解压
|
||||||
|
ssh user@target-server
|
||||||
|
mkdir -p /root/models
|
||||||
|
tar -xzvf /root/model.tar.gz -C /root/models
|
||||||
|
```
|
||||||
|
|
||||||
|
## 验证模型文件
|
||||||
|
|
||||||
|
模型下载完成后,目录结构应类似于:
|
||||||
|
|
||||||
|
```
|
||||||
|
/root/models/Ops-MM-embedding-v1-7B/
|
||||||
|
├── config.json
|
||||||
|
├── pytorch_model.bin (或分片文件)
|
||||||
|
├── tokenizer.json
|
||||||
|
├── tokenizer_config.json
|
||||||
|
├── special_tokens_map.json
|
||||||
|
└── vocab.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
使用以下命令验证文件完整性:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ls -la /root/models/Ops-MM-embedding-v1-7B/
|
||||||
|
```
|
||||||
370
multimodal_retrieval_faiss.py
Normal file
370
multimodal_retrieval_faiss.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
基于FAISS的多模态检索系统
|
||||||
|
支持文搜文、文搜图、图搜文、图搜图四种检索模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DataParallel
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||||||
|
from typing import List, Union, Tuple, Dict, Any, Optional
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from faiss_vector_store import FaissVectorStore
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class MultimodalRetrievalFAISS:
|
||||||
|
"""基于FAISS的多模态检索系统"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
|
||||||
|
use_all_gpus: bool = True, gpu_ids: List[int] = None,
|
||||||
|
min_memory_gb: int = 12, index_path: str = "faiss_index"):
|
||||||
|
"""
|
||||||
|
初始化多模态检索系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: 模型名称
|
||||||
|
use_all_gpus: 是否使用所有可用GPU
|
||||||
|
gpu_ids: 指定使用的GPU ID列表
|
||||||
|
min_memory_gb: 最小可用内存(GB)
|
||||||
|
index_path: FAISS索引文件路径
|
||||||
|
"""
|
||||||
|
self.model_name = model_name
|
||||||
|
self.index_path = index_path
|
||||||
|
|
||||||
|
# 设置GPU设备
|
||||||
|
self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
|
||||||
|
|
||||||
|
# 清理GPU内存
|
||||||
|
self._clear_all_gpu_memory()
|
||||||
|
|
||||||
|
# 加载模型和处理器
|
||||||
|
self._load_model_and_processor()
|
||||||
|
|
||||||
|
# 初始化FAISS向量存储
|
||||||
|
self.vector_store = FaissVectorStore(
|
||||||
|
index_path=index_path,
|
||||||
|
dimension=3584 # OpenSearch-AI/Ops-MM-embedding-v1-7B的向量维度
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"多模态检索系统初始化完成,使用模型: {model_name}")
|
||||||
|
logger.info(f"向量存储路径: {index_path}")
|
||||||
|
|
||||||
|
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
|
||||||
|
"""设置GPU设备"""
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.use_gpu = self.device.type == "cuda"
|
||||||
|
|
||||||
|
if self.use_gpu:
|
||||||
|
self.available_gpus = self._get_available_gpus(min_memory_gb)
|
||||||
|
|
||||||
|
if not self.available_gpus:
|
||||||
|
logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB,将使用CPU")
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
self.use_gpu = False
|
||||||
|
else:
|
||||||
|
if gpu_ids:
|
||||||
|
self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
|
||||||
|
if not self.gpu_ids:
|
||||||
|
logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足,将使用可用的GPU: {self.available_gpus}")
|
||||||
|
self.gpu_ids = self.available_gpus
|
||||||
|
elif use_all_gpus:
|
||||||
|
self.gpu_ids = self.available_gpus
|
||||||
|
else:
|
||||||
|
self.gpu_ids = [self.available_gpus[0]]
|
||||||
|
|
||||||
|
logger.info(f"使用GPU: {self.gpu_ids}")
|
||||||
|
self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
|
||||||
|
|
||||||
|
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
|
||||||
|
"""获取可用的GPU列表"""
|
||||||
|
available_gpus = []
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
|
||||||
|
if total_mem >= min_memory_gb:
|
||||||
|
available_gpus.append(i)
|
||||||
|
return available_gpus
|
||||||
|
|
||||||
|
def _clear_all_gpu_memory(self):
|
||||||
|
"""清理GPU内存"""
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
def _load_model_and_processor(self):
|
||||||
|
"""加载模型和处理器"""
|
||||||
|
logger.info(f"加载模型和处理器: {self.model_name}")
|
||||||
|
|
||||||
|
# 加载tokenizer和processor
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
self.processor = AutoProcessor.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
self.model = AutoModel.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=torch.float16 if self.use_gpu else torch.float32,
|
||||||
|
device_map="auto" if len(self.gpu_ids) > 1 else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果使用多GPU,包装模型
|
||||||
|
if len(self.gpu_ids) > 1:
|
||||||
|
self.model = DataParallel(self.model, device_ids=self.gpu_ids)
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
logger.info("模型和处理器加载完成")
|
||||||
|
|
||||||
|
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
||||||
|
"""编码文本为向量"""
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
|
||||||
|
inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
# 获取[CLS]标记的隐藏状态作为句子表示
|
||||||
|
text_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||||
|
|
||||||
|
# 归一化向量
|
||||||
|
text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
|
||||||
|
return text_embeddings[0] if len(text) == 1 else text_embeddings
|
||||||
|
|
||||||
|
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
|
||||||
|
"""编码图像为向量"""
|
||||||
|
if isinstance(image, Image.Image):
|
||||||
|
image = [image]
|
||||||
|
|
||||||
|
inputs = self.processor(images=image, return_tensors="pt")
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model.vision_model(**inputs)
|
||||||
|
# 获取[CLS]标记的隐藏状态作为图像表示
|
||||||
|
image_embeddings = outputs.pooler_output.cpu().numpy()
|
||||||
|
|
||||||
|
# 归一化向量
|
||||||
|
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
|
||||||
|
return image_embeddings[0] if len(image) == 1 else image_embeddings
|
||||||
|
|
||||||
|
def add_texts(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
metadatas: Optional[List[Dict[str, Any]]] = None
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
添加文本到检索系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: 文本列表
|
||||||
|
metadatas: 元数据列表,每个元素是一个字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
添加的文本ID列表
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if metadatas is None:
|
||||||
|
metadatas = [{} for _ in range(len(texts))]
|
||||||
|
|
||||||
|
if len(texts) != len(metadatas):
|
||||||
|
raise ValueError("texts和metadatas长度必须相同")
|
||||||
|
|
||||||
|
# 编码文本
|
||||||
|
text_embeddings = self.encode_text(texts)
|
||||||
|
|
||||||
|
# 准备元数据
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
metadatas[i].update({
|
||||||
|
"text": text,
|
||||||
|
"type": "text"
|
||||||
|
})
|
||||||
|
|
||||||
|
# 添加到向量存储
|
||||||
|
vector_ids = self.vector_store.add_vectors(text_embeddings, metadatas)
|
||||||
|
|
||||||
|
logger.info(f"成功添加{len(vector_ids)}条文本到检索系统")
|
||||||
|
return vector_ids
|
||||||
|
|
||||||
|
def add_images(
|
||||||
|
self,
|
||||||
|
images: List[Image.Image],
|
||||||
|
metadatas: Optional[List[Dict[str, Any]]] = None
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
添加图像到检索系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: PIL图像列表
|
||||||
|
metadatas: 元数据列表,每个元素是一个字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
添加的图像ID列表
|
||||||
|
"""
|
||||||
|
if not images:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if metadatas is None:
|
||||||
|
metadatas = [{} for _ in range(len(images))]
|
||||||
|
|
||||||
|
if len(images) != len(metadatas):
|
||||||
|
raise ValueError("images和metadatas长度必须相同")
|
||||||
|
|
||||||
|
# 编码图像
|
||||||
|
image_embeddings = self.encode_image(images)
|
||||||
|
|
||||||
|
# 准备元数据
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
metadatas[i].update({
|
||||||
|
"type": "image",
|
||||||
|
"width": image.width,
|
||||||
|
"height": image.height
|
||||||
|
})
|
||||||
|
|
||||||
|
# 添加到向量存储
|
||||||
|
vector_ids = self.vector_store.add_vectors(image_embeddings, metadatas)
|
||||||
|
|
||||||
|
logger.info(f"成功添加{len(vector_ids)}张图像到检索系统")
|
||||||
|
return vector_ids
|
||||||
|
|
||||||
|
def search_by_text(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 5,
|
||||||
|
filter_condition: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
文本搜索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
k: 返回结果数量
|
||||||
|
filter_condition: 过滤条件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
搜索结果列表,每个元素包含相似项和分数
|
||||||
|
"""
|
||||||
|
# 编码查询文本
|
||||||
|
query_embedding = self.encode_text(query)
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
results, distances = self.vector_store.search(query_embedding, k)
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
search_results = []
|
||||||
|
for i, (result, distance) in enumerate(zip(results, distances)):
|
||||||
|
result["score"] = 1.0 / (1.0 + distance) # 将距离转换为相似度分数
|
||||||
|
search_results.append(result)
|
||||||
|
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
def search_by_image(
|
||||||
|
self,
|
||||||
|
image: Image.Image,
|
||||||
|
k: int = 5,
|
||||||
|
filter_condition: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
图像搜索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 查询图像
|
||||||
|
k: 返回结果数量
|
||||||
|
filter_condition: 过滤条件
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
搜索结果列表,每个元素包含相似项和分数
|
||||||
|
"""
|
||||||
|
# 编码查询图像
|
||||||
|
query_embedding = self.encode_image(image)
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
results, distances = self.vector_store.search(query_embedding, k)
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
search_results = []
|
||||||
|
for i, (result, distance) in enumerate(zip(results, distances)):
|
||||||
|
result["score"] = 1.0 / (1.0 + distance) # 将距离转换为相似度分数
|
||||||
|
search_results.append(result)
|
||||||
|
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
def get_vector_count(self) -> int:
|
||||||
|
"""获取向量数量"""
|
||||||
|
return self.vector_store.get_vector_count()
|
||||||
|
|
||||||
|
def save_index(self):
|
||||||
|
"""保存索引"""
|
||||||
|
self.vector_store.save_index()
|
||||||
|
logger.info("索引已保存")
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""析构函数,确保资源被正确释放"""
|
||||||
|
if hasattr(self, 'model'):
|
||||||
|
del self.model
|
||||||
|
self._clear_all_gpu_memory()
|
||||||
|
if hasattr(self, 'vector_store'):
|
||||||
|
self.save_index()
|
||||||
|
|
||||||
|
|
||||||
|
def test_faiss_system():
|
||||||
|
"""测试FAISS多模态检索系统"""
|
||||||
|
import time
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
print("初始化多模态检索系统...")
|
||||||
|
retrieval = MultimodalRetrievalFAISS(
|
||||||
|
model_name="OpenSearch-AI/Ops-MM-embedding-v1-7B",
|
||||||
|
use_all_gpus=True,
|
||||||
|
index_path="faiss_index_test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试文本
|
||||||
|
texts = [
|
||||||
|
"一只可爱的橘色猫咪在沙发上睡觉",
|
||||||
|
"城市夜景中的高楼大厦和车流",
|
||||||
|
"阳光明媚的海滩上,人们在冲浪和晒太阳",
|
||||||
|
"美味的意大利面配红酒和沙拉",
|
||||||
|
"雪山上滑雪的运动员"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 添加文本
|
||||||
|
print("\n添加文本到检索系统...")
|
||||||
|
text_ids = retrieval.add_texts(texts)
|
||||||
|
print(f"添加了{len(text_ids)}条文本")
|
||||||
|
|
||||||
|
# 测试文本搜索
|
||||||
|
print("\n测试文本搜索...")
|
||||||
|
query_text = "一只猫在睡觉"
|
||||||
|
print(f"查询: {query_text}")
|
||||||
|
results = retrieval.search_by_text(query_text, k=2)
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
print(f"结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
||||||
|
|
||||||
|
# 测试保存和加载
|
||||||
|
print("\n保存索引...")
|
||||||
|
retrieval.save_index()
|
||||||
|
|
||||||
|
print("\n测试完成!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_faiss_system()
|
||||||
607
multimodal_retrieval_local.py
Normal file
607
multimodal_retrieval_local.py
Normal file
@ -0,0 +1,607 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
使用本地模型的多模态检索系统
|
||||||
|
支持文搜文、文搜图、图搜文、图搜图四种检索模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||||||
|
from typing import List, Union, Tuple, Dict, Any, Optional
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
import faiss
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 设置离线模式
|
||||||
|
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
||||||
|
|
||||||
|
class MultimodalRetrievalLocal:
|
||||||
|
"""使用本地模型的多模态检索系统"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_path: str = "/root/models/Ops-MM-embedding-v1-7B",
|
||||||
|
use_all_gpus: bool = True,
|
||||||
|
gpu_ids: List[int] = None,
|
||||||
|
min_memory_gb: int = 12,
|
||||||
|
index_path: str = "local_faiss_index"):
|
||||||
|
"""
|
||||||
|
初始化多模态检索系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: 本地模型路径
|
||||||
|
use_all_gpus: 是否使用所有可用GPU
|
||||||
|
gpu_ids: 指定使用的GPU ID列表
|
||||||
|
min_memory_gb: 最小可用内存(GB)
|
||||||
|
index_path: FAISS索引文件路径
|
||||||
|
"""
|
||||||
|
self.model_path = model_path
|
||||||
|
self.index_path = index_path
|
||||||
|
|
||||||
|
# 检查模型路径
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
logger.error(f"模型路径不存在: {model_path}")
|
||||||
|
logger.info("请先下载模型到指定路径")
|
||||||
|
raise FileNotFoundError(f"模型路径不存在: {model_path}")
|
||||||
|
|
||||||
|
# 设置GPU设备
|
||||||
|
self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
|
||||||
|
|
||||||
|
# 清理GPU内存
|
||||||
|
self._clear_all_gpu_memory()
|
||||||
|
|
||||||
|
# 加载模型和处理器
|
||||||
|
self._load_model_and_processor()
|
||||||
|
|
||||||
|
# 初始化FAISS索引
|
||||||
|
self._init_index()
|
||||||
|
|
||||||
|
logger.info(f"多模态检索系统初始化完成,使用本地模型: {model_path}")
|
||||||
|
logger.info(f"向量存储路径: {index_path}")
|
||||||
|
|
||||||
|
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
|
||||||
|
"""设置GPU设备"""
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
self.use_gpu = self.device.type == "cuda"
|
||||||
|
|
||||||
|
if self.use_gpu:
|
||||||
|
self.available_gpus = self._get_available_gpus(min_memory_gb)
|
||||||
|
|
||||||
|
if not self.available_gpus:
|
||||||
|
logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB,将使用CPU")
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
self.use_gpu = False
|
||||||
|
else:
|
||||||
|
if gpu_ids:
|
||||||
|
self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
|
||||||
|
if not self.gpu_ids:
|
||||||
|
logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足,将使用可用的GPU: {self.available_gpus}")
|
||||||
|
self.gpu_ids = self.available_gpus
|
||||||
|
elif use_all_gpus:
|
||||||
|
self.gpu_ids = self.available_gpus
|
||||||
|
else:
|
||||||
|
self.gpu_ids = [self.available_gpus[0]]
|
||||||
|
|
||||||
|
logger.info(f"使用GPU: {self.gpu_ids}")
|
||||||
|
self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
|
||||||
|
else:
|
||||||
|
logger.warning("没有可用的GPU,将使用CPU")
|
||||||
|
self.gpu_ids = []
|
||||||
|
|
||||||
|
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
|
||||||
|
"""获取可用的GPU列表"""
|
||||||
|
available_gpus = []
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
|
||||||
|
if total_mem >= min_memory_gb:
|
||||||
|
available_gpus.append(i)
|
||||||
|
return available_gpus
|
||||||
|
|
||||||
|
def _clear_all_gpu_memory(self):
|
||||||
|
"""清理GPU内存"""
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
def _load_model_and_processor(self):
|
||||||
|
"""加载模型和处理器"""
|
||||||
|
logger.info(f"加载本地模型和处理器: {self.model_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载模型和处理器
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
|
||||||
|
self.processor = AutoProcessor.from_pretrained(self.model_path)
|
||||||
|
|
||||||
|
# 输出处理器信息
|
||||||
|
logger.info(f"Processor类型: {type(self.processor)}")
|
||||||
|
logger.info(f"Processor方法: {dir(self.processor)}")
|
||||||
|
|
||||||
|
# 检查是否有图像处理器
|
||||||
|
if hasattr(self.processor, 'image_processor'):
|
||||||
|
logger.info(f"Image processor类型: {type(self.processor.image_processor)}")
|
||||||
|
logger.info(f"Image processor方法: {dir(self.processor.image_processor)}")
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
self.model = AutoModel.from_pretrained(
|
||||||
|
self.model_path,
|
||||||
|
torch_dtype=torch.float16 if self.use_gpu else torch.float32,
|
||||||
|
device_map="auto" if len(self.gpu_ids) > 1 else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(self.gpu_ids) == 1:
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
# 获取向量维度
|
||||||
|
self.vector_dim = self.model.config.hidden_size
|
||||||
|
logger.info(f"向量维度: {self.vector_dim}")
|
||||||
|
|
||||||
|
logger.info("模型和处理器加载成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模型加载失败: {str(e)}")
|
||||||
|
raise RuntimeError(f"模型加载失败: {str(e)}")
|
||||||
|
|
||||||
|
def _init_index(self):
|
||||||
|
"""初始化FAISS索引"""
|
||||||
|
index_file = f"{self.index_path}.index"
|
||||||
|
if os.path.exists(index_file):
|
||||||
|
logger.info(f"加载现有索引: {index_file}")
|
||||||
|
try:
|
||||||
|
self.index = faiss.read_index(index_file)
|
||||||
|
logger.info(f"索引加载成功,包含{self.index.ntotal}个向量")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"索引加载失败: {str(e)}")
|
||||||
|
logger.info("创建新索引...")
|
||||||
|
self.index = faiss.IndexFlatL2(self.vector_dim)
|
||||||
|
else:
|
||||||
|
logger.info(f"创建新索引,维度: {self.vector_dim}")
|
||||||
|
self.index = faiss.IndexFlatL2(self.vector_dim)
|
||||||
|
|
||||||
|
# 加载元数据
|
||||||
|
self.metadata = {}
|
||||||
|
metadata_file = f"{self.index_path}_metadata.json"
|
||||||
|
if os.path.exists(metadata_file):
|
||||||
|
try:
|
||||||
|
with open(metadata_file, 'r', encoding='utf-8') as f:
|
||||||
|
self.metadata = json.load(f)
|
||||||
|
logger.info(f"元数据加载成功,包含{len(self.metadata)}条记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"元数据加载失败: {str(e)}")
|
||||||
|
|
||||||
|
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
||||||
|
"""编码文本为向量"""
|
||||||
|
if isinstance(text, str):
|
||||||
|
text = [text]
|
||||||
|
|
||||||
|
inputs = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
# 获取[CLS]标记的隐藏状态作为句子表示
|
||||||
|
text_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||||
|
|
||||||
|
# 归一化向量
|
||||||
|
text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
|
||||||
|
return text_embeddings[0] if len(text) == 1 else text_embeddings
|
||||||
|
|
||||||
|
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
|
||||||
|
"""编码图像为向量"""
|
||||||
|
try:
|
||||||
|
logger.info(f"encode_image: 开始编码图像,类型: {type(image)}")
|
||||||
|
|
||||||
|
if isinstance(image, Image.Image):
|
||||||
|
logger.info(f"encode_image: 单个图像,大小: {image.size}")
|
||||||
|
image = [image]
|
||||||
|
else:
|
||||||
|
logger.info(f"encode_image: 图像列表,长度: {len(image)}")
|
||||||
|
|
||||||
|
# 检查图像是否为空
|
||||||
|
if not image or len(image) == 0:
|
||||||
|
logger.error("encode_image: 图像列表为空")
|
||||||
|
# 返回一个空的二维数组
|
||||||
|
return np.zeros((0, self.vector_dim))
|
||||||
|
|
||||||
|
# 检查图像是否有效
|
||||||
|
for i, img in enumerate(image):
|
||||||
|
if not isinstance(img, Image.Image):
|
||||||
|
logger.error(f"encode_image: 第{i}个元素不是有效的PIL图像,类型: {type(img)}")
|
||||||
|
# 返回一个空的二维数组
|
||||||
|
return np.zeros((0, self.vector_dim))
|
||||||
|
|
||||||
|
logger.info("encode_image: 处理图像输入")
|
||||||
|
|
||||||
|
# 检查图像格式
|
||||||
|
for i, img in enumerate(image):
|
||||||
|
logger.info(f"encode_image: 图像 {i} 格式: {img.format}, 模式: {img.mode}, 大小: {img.size}")
|
||||||
|
# 转换为RGB模式,如果不是
|
||||||
|
if img.mode != 'RGB':
|
||||||
|
logger.info(f"encode_image: 将图像 {i} 从 {img.mode} 转换为 RGB")
|
||||||
|
image[i] = img.convert('RGB')
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 直接使用image_processor处理图像
|
||||||
|
if hasattr(self.processor, 'image_processor'):
|
||||||
|
logger.info("encode_image: 使用image_processor处理图像")
|
||||||
|
pixel_values = self.processor.image_processor(images=image, return_tensors="pt").pixel_values
|
||||||
|
inputs = {"pixel_values": pixel_values}
|
||||||
|
else:
|
||||||
|
logger.info("encode_image: 使用processor处理图像")
|
||||||
|
inputs = self.processor(images=image, return_tensors="pt")
|
||||||
|
|
||||||
|
if not inputs or len(inputs) == 0:
|
||||||
|
logger.error("encode_image: processor返回了空的输入")
|
||||||
|
return np.zeros((0, self.vector_dim))
|
||||||
|
|
||||||
|
logger.info(f"encode_image: 处理后的输入键: {list(inputs.keys())}")
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
logger.info("encode_image: 运行模型推理")
|
||||||
|
logger.info(f"Model类型: {type(self.model)}")
|
||||||
|
logger.info(f"Model属性: {dir(self.model)}")
|
||||||
|
|
||||||
|
# 检查模型结构
|
||||||
|
try:
|
||||||
|
logger.info(f"Model配置: {self.model.config}")
|
||||||
|
logger.info(f"Model配置属性: {dir(self.model.config)}")
|
||||||
|
else:
|
||||||
|
visual_outputs = self.model.visual(**inputs)
|
||||||
|
|
||||||
|
if hasattr(visual_outputs, 'pooler_output'):
|
||||||
|
image_embeddings = visual_outputs.pooler_output.cpu().numpy()
|
||||||
|
elif hasattr(visual_outputs, 'last_hidden_state'):
|
||||||
|
image_embeddings = visual_outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||||
|
else:
|
||||||
|
logger.error("encode_image: 无法从视觉模型输出中获取图像向量")
|
||||||
|
raise ValueError("无法从视觉模型输出中获取图像向量")
|
||||||
|
else:
|
||||||
|
# 尝试直接使用模型进行推理
|
||||||
|
logger.info("encode_image: 尝试直接使用模型进行推理")
|
||||||
|
with torch.no_grad():
|
||||||
|
# 使用空文本输入,只提供图像
|
||||||
|
if 'pixel_values' in inputs:
|
||||||
|
outputs = self.model(pixel_values=inputs['pixel_values'], input_ids=None)
|
||||||
|
else:
|
||||||
|
outputs = self.model(**inputs, input_ids=None)
|
||||||
|
|
||||||
|
# 尝试从输出中获取图像向量
|
||||||
|
if hasattr(outputs, 'image_embeds'):
|
||||||
|
image_embeddings = outputs.image_embeds.cpu().numpy()
|
||||||
|
elif hasattr(outputs, 'vision_model_output') and hasattr(outputs.vision_model_output, 'pooler_output'):
|
||||||
|
image_embeddings = outputs.vision_model_output.pooler_output.cpu().numpy()
|
||||||
|
elif hasattr(outputs, 'pooler_output'):
|
||||||
|
image_embeddings = outputs.pooler_output.cpu().numpy()
|
||||||
|
elif hasattr(outputs, 'last_hidden_state'):
|
||||||
|
image_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||||
|
else:
|
||||||
|
logger.error("encode_image: 无法从模型输出中获取图像向量")
|
||||||
|
raise ValueError("无法从模型输出中获取图像向量")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"encode_image: 处理图像时出错: {str(e)}")
|
||||||
|
raise e
|
||||||
|
return np.zeros((0, self.vector_dim))
|
||||||
|
|
||||||
|
# 归一化向量
|
||||||
|
image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
|
||||||
|
|
||||||
|
# 始终返回二维数组,即使只有一个图像
|
||||||
|
if len(image) == 1:
|
||||||
|
result = np.array([image_embeddings[0]])
|
||||||
|
logger.info(f"encode_image: 返回单个图像向量,形状: {result.shape}")
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
logger.info(f"encode_image: 返回多个图像向量,形状: {image_embeddings.shape}")
|
||||||
|
return image_embeddings
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"encode_image: 异常: {str(e)}")
|
||||||
|
# 返回一个空的二维数组
|
||||||
|
return np.zeros((0, self.vector_dim))
|
||||||
|
|
||||||
|
def add_texts(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
metadatas: Optional[List[Dict[str, Any]]] = None
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
添加文本到检索系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: 文本列表
|
||||||
|
metadatas: 元数据列表,每个元素是一个字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
添加的文本ID列表
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if metadatas is None:
|
||||||
|
metadatas = [{} for _ in range(len(texts))]
|
||||||
|
|
||||||
|
if len(texts) != len(metadatas):
|
||||||
|
raise ValueError("texts和metadatas长度必须相同")
|
||||||
|
|
||||||
|
# 编码文本
|
||||||
|
text_embeddings = self.encode_text(texts)
|
||||||
|
|
||||||
|
# 准备元数据
|
||||||
|
start_id = self.index.ntotal
|
||||||
|
ids = list(range(start_id, start_id + len(texts)))
|
||||||
|
|
||||||
|
# 添加到索引
|
||||||
|
self.index.add(np.array(text_embeddings).astype('float32'))
|
||||||
|
|
||||||
|
# 保存元数据
|
||||||
|
for i, id in enumerate(ids):
|
||||||
|
self.metadata[str(id)] = {
|
||||||
|
"text": texts[i],
|
||||||
|
"type": "text",
|
||||||
|
**metadatas[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"成功添加{len(ids)}条文本到检索系统")
|
||||||
|
return [str(id) for id in ids]
|
||||||
|
|
||||||
|
def add_images(
|
||||||
|
self,
|
||||||
|
images: List[Image.Image],
|
||||||
|
metadatas: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
image_paths: Optional[List[str]] = None
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
添加图像到检索系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: PIL图像列表
|
||||||
|
metadatas: 元数据列表,每个元素是一个字典
|
||||||
|
image_paths: 图像路径列表,用于保存到元数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
添加的图像ID列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"add_images: 开始添加图像,数量: {len(images) if images else 0}")
|
||||||
|
|
||||||
|
# 检查图像列表
|
||||||
|
if not images or len(images) == 0:
|
||||||
|
logger.warning("add_images: 图像列表为空")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 准备元数据
|
||||||
|
if metadatas is None:
|
||||||
|
logger.info("add_images: 创建默认元数据")
|
||||||
|
metadatas = [{} for _ in range(len(images))]
|
||||||
|
|
||||||
|
# 检查长度一致性
|
||||||
|
if len(images) != len(metadatas):
|
||||||
|
logger.error(f"add_images: 长度不一致 - images: {len(images)}, metadatas: {len(metadatas)}")
|
||||||
|
raise ValueError("images和metadatas长度必须相同")
|
||||||
|
|
||||||
|
# 编码图像
|
||||||
|
logger.info("add_images: 编码图像")
|
||||||
|
image_embeddings = self.encode_image(images)
|
||||||
|
|
||||||
|
# 检查编码结果
|
||||||
|
if image_embeddings.shape[0] == 0:
|
||||||
|
logger.error("add_images: 图像编码失败,返回空数组")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 准备元数据
|
||||||
|
start_id = self.index.ntotal
|
||||||
|
ids = list(range(start_id, start_id + len(images)))
|
||||||
|
logger.info(f"add_images: 生成索引ID: {start_id} - {start_id + len(images) - 1}")
|
||||||
|
|
||||||
|
# 添加到索引
|
||||||
|
logger.info(f"add_images: 添加向量到FAISS索引,形状: {image_embeddings.shape}")
|
||||||
|
self.index.add(np.array(image_embeddings).astype('float32'))
|
||||||
|
|
||||||
|
# 保存元数据
|
||||||
|
for i, id in enumerate(ids):
|
||||||
|
try:
|
||||||
|
metadata = {
|
||||||
|
"type": "image",
|
||||||
|
"width": images[i].width,
|
||||||
|
"height": images[i].height,
|
||||||
|
**metadatas[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
if image_paths and i < len(image_paths):
|
||||||
|
metadata["path"] = image_paths[i]
|
||||||
|
|
||||||
|
self.metadata[str(id)] = metadata
|
||||||
|
logger.debug(f"add_images: 保存元数据成功 - ID: {id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"add_images: 保存元数据失败 - ID: {id}, 错误: {str(e)}")
|
||||||
|
|
||||||
|
logger.info(f"add_images: 成功添加{len(ids)}张图像到检索系统")
|
||||||
|
return [str(id) for id in ids]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"add_images: 添加图像异常: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def search_by_text(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 5,
|
||||||
|
filter_type: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
文本搜索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 查询文本
|
||||||
|
k: 返回结果数量
|
||||||
|
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
搜索结果列表,每个元素包含相似项和分数
|
||||||
|
"""
|
||||||
|
# 编码查询文本
|
||||||
|
query_embedding = self.encode_text(query)
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
return self._search(query_embedding, k, filter_type)
|
||||||
|
|
||||||
|
def search_by_image(
|
||||||
|
self,
|
||||||
|
image: Image.Image,
|
||||||
|
k: int = 5,
|
||||||
|
filter_type: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
图像搜索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: 查询图像
|
||||||
|
k: 返回结果数量
|
||||||
|
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
搜索结果列表,每个元素包含相似项和分数
|
||||||
|
"""
|
||||||
|
# 编码查询图像
|
||||||
|
query_embedding = self.encode_image(image)
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
return self._search(query_embedding, k, filter_type)
|
||||||
|
|
||||||
|
def _search(
|
||||||
|
self,
|
||||||
|
query_embedding: np.ndarray,
|
||||||
|
k: int = 5,
|
||||||
|
filter_type: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
执行搜索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_embedding: 查询向量
|
||||||
|
k: 返回结果数量
|
||||||
|
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
搜索结果列表
|
||||||
|
"""
|
||||||
|
if self.index.ntotal == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 确保查询向量是2D数组
|
||||||
|
if len(query_embedding.shape) == 1:
|
||||||
|
query_embedding = query_embedding.reshape(1, -1)
|
||||||
|
|
||||||
|
# 执行搜索,获取更多结果以便过滤
|
||||||
|
actual_k = k * 3 if filter_type else k
|
||||||
|
actual_k = min(actual_k, self.index.ntotal)
|
||||||
|
distances, indices = self.index.search(query_embedding.astype('float32'), actual_k)
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
results = []
|
||||||
|
for i in range(len(indices[0])):
|
||||||
|
idx = indices[0][i]
|
||||||
|
if idx < 0: # FAISS可能返回-1表示无效索引
|
||||||
|
continue
|
||||||
|
|
||||||
|
vector_id = str(idx)
|
||||||
|
if vector_id in self.metadata:
|
||||||
|
item = self.metadata[vector_id]
|
||||||
|
|
||||||
|
# 如果指定了过滤类型,则只返回该类型的结果
|
||||||
|
if filter_type and item.get("type") != filter_type:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 添加距离和分数
|
||||||
|
result = item.copy()
|
||||||
|
result["distance"] = float(distances[0][i])
|
||||||
|
result["score"] = float(1.0 / (1.0 + distances[0][i]))
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# 如果已经收集了足够的结果,则停止
|
||||||
|
if len(results) >= k:
|
||||||
|
break
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def save_index(self):
|
||||||
|
"""保存索引和元数据"""
|
||||||
|
# 保存索引
|
||||||
|
index_file = f"{self.index_path}.index"
|
||||||
|
try:
|
||||||
|
faiss.write_index(self.index, index_file)
|
||||||
|
logger.info(f"索引保存成功: {index_file}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"索引保存失败: {str(e)}")
|
||||||
|
|
||||||
|
# 保存元数据
|
||||||
|
metadata_file = f"{self.index_path}_metadata.json"
|
||||||
|
try:
|
||||||
|
with open(metadata_file, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(self.metadata, f, ensure_ascii=False, indent=2)
|
||||||
|
logger.info(f"元数据保存成功: {metadata_file}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"元数据保存失败: {str(e)}")
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""获取检索系统统计信息"""
|
||||||
|
text_count = sum(1 for v in self.metadata.values() if v.get("type") == "text")
|
||||||
|
image_count = sum(1 for v in self.metadata.values() if v.get("type") == "image")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_vectors": self.index.ntotal,
|
||||||
|
"text_count": text_count,
|
||||||
|
"image_count": image_count,
|
||||||
|
"vector_dimension": self.vector_dim,
|
||||||
|
"index_path": self.index_path,
|
||||||
|
"model_path": self.model_path
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_index(self):
|
||||||
|
"""清空索引"""
|
||||||
|
logger.info(f"清空索引: {self.index_path}")
|
||||||
|
|
||||||
|
# 重新创建索引
|
||||||
|
self.index = faiss.IndexFlatL2(self.vector_dim)
|
||||||
|
|
||||||
|
# 清空元数据
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
|
# 保存空索引
|
||||||
|
self.save_index()
|
||||||
|
|
||||||
|
logger.info(f"索引已清空: {self.index_path}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def list_items(self) -> List[Dict[str, Any]]:
|
||||||
|
"""列出所有索引项"""
|
||||||
|
items = []
|
||||||
|
|
||||||
|
for item_id, metadata in self.metadata.items():
|
||||||
|
item = metadata.copy()
|
||||||
|
item['id'] = item_id
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
"""析构函数,确保资源被正确释放并自动保存索引"""
|
||||||
|
try:
|
||||||
|
if hasattr(self, 'model'):
|
||||||
|
del self.model
|
||||||
|
self._clear_all_gpu_memory()
|
||||||
|
if hasattr(self, 'index') and self.index is not None:
|
||||||
|
logger.info("系统关闭前自动保存索引")
|
||||||
|
self.save_index()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"析构时保存索引失败: {str(e)}")
|
||||||
@ -60,7 +60,14 @@ class MultimodalRetrievalVDB:
|
|||||||
"database_name": "multimodal_retrieval"
|
"database_name": "multimodal_retrieval"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
self.vdb = BaiduVDBBackend(**vdb_config)
|
self.vdb = BaiduVDBBackend(**vdb_config)
|
||||||
|
logger.info("✅ VDB后端初始化成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ VDB后端初始化失败: {e}")
|
||||||
|
# 创建一个模拟的VDB后端,避免系统完全崩溃
|
||||||
|
self.vdb = None
|
||||||
|
logger.warning("⚠️ 系统将在无VDB模式下运行,数据将不会持久化")
|
||||||
|
|
||||||
logger.info("多模态检索系统初始化完成")
|
logger.info("多模态检索系统初始化完成")
|
||||||
|
|
||||||
@ -102,6 +109,12 @@ class MultimodalRetrievalVDB:
|
|||||||
# 清理GPU内存
|
# 清理GPU内存
|
||||||
self._clear_gpu_memory()
|
self._clear_gpu_memory()
|
||||||
|
|
||||||
|
# 设置离线模式环境变量
|
||||||
|
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
||||||
|
os.environ['HF_HUB_OFFLINE'] = '1'
|
||||||
|
|
||||||
|
# 尝试加载模型,如果网络失败则使用本地缓存
|
||||||
|
try:
|
||||||
# 加载模型
|
# 加载模型
|
||||||
if self.num_gpus > 1:
|
if self.num_gpus > 1:
|
||||||
# 多GPU加载
|
# 多GPU加载
|
||||||
@ -113,7 +126,8 @@ class MultimodalRetrievalVDB:
|
|||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
max_memory=max_memory,
|
max_memory=max_memory,
|
||||||
low_cpu_mem_usage=True
|
low_cpu_mem_usage=True,
|
||||||
|
local_files_only=False # 允许从网络下载
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# 单GPU加载
|
# 单GPU加载
|
||||||
@ -121,22 +135,75 @@ class MultimodalRetrievalVDB:
|
|||||||
self.model_name,
|
self.model_name,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
device_map=self.primary_device
|
device_map=self.primary_device,
|
||||||
|
local_files_only=False # 允许从网络下载
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info("模型从网络加载成功")
|
||||||
|
|
||||||
|
except Exception as network_error:
|
||||||
|
logger.warning(f"网络加载失败,尝试本地缓存: {network_error}")
|
||||||
|
|
||||||
|
# 尝试从本地缓存加载
|
||||||
|
try:
|
||||||
|
if self.num_gpus > 1:
|
||||||
|
max_memory = {i: "18GiB" for i in self.device_ids}
|
||||||
|
|
||||||
|
self.model = AutoModel.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto",
|
||||||
|
max_memory=max_memory,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
local_files_only=True # 仅使用本地文件
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = AutoModel.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map=self.primary_device,
|
||||||
|
local_files_only=True # 仅使用本地文件
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("模型从本地缓存加载成功")
|
||||||
|
|
||||||
|
except Exception as local_error:
|
||||||
|
logger.error(f"本地缓存加载也失败: {local_error}")
|
||||||
|
raise local_error
|
||||||
|
|
||||||
# 加载分词器和处理器
|
# 加载分词器和处理器
|
||||||
|
try:
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
trust_remote_code=True
|
trust_remote_code=True,
|
||||||
|
local_files_only=False
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Tokenizer网络加载失败,尝试本地: {e}")
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=True
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
trust_remote_code=True
|
trust_remote_code=True,
|
||||||
|
local_files_only=False
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Processor加载失败,使用tokenizer: {e}")
|
logger.warning(f"Processor加载失败,使用tokenizer: {e}")
|
||||||
|
try:
|
||||||
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=True
|
||||||
|
)
|
||||||
|
except Exception as e2:
|
||||||
|
logger.warning(f"Processor本地加载也失败,使用tokenizer: {e2}")
|
||||||
self.processor = self.tokenizer
|
self.processor = self.tokenizer
|
||||||
|
|
||||||
logger.info("模型加载完成")
|
logger.info("模型加载完成")
|
||||||
@ -274,6 +341,10 @@ class MultimodalRetrievalVDB:
|
|||||||
Returns:
|
Returns:
|
||||||
存储的ID列表
|
存储的ID列表
|
||||||
"""
|
"""
|
||||||
|
if self.vdb is None:
|
||||||
|
logger.warning("VDB不可用,文本数据将不会持久化存储")
|
||||||
|
return []
|
||||||
|
|
||||||
logger.info(f"正在存储 {len(texts)} 条文本数据")
|
logger.info(f"正在存储 {len(texts)} 条文本数据")
|
||||||
|
|
||||||
# 分批处理
|
# 分批处理
|
||||||
@ -312,6 +383,10 @@ class MultimodalRetrievalVDB:
|
|||||||
Returns:
|
Returns:
|
||||||
存储的ID列表
|
存储的ID列表
|
||||||
"""
|
"""
|
||||||
|
if self.vdb is None:
|
||||||
|
logger.warning("VDB不可用,图像数据将不会持久化存储")
|
||||||
|
return []
|
||||||
|
|
||||||
logger.info(f"正在存储 {len(image_paths)} 张图像数据")
|
logger.info(f"正在存储 {len(image_paths)} 张图像数据")
|
||||||
|
|
||||||
# 图像处理使用更小的批次
|
# 图像处理使用更小的批次
|
||||||
@ -341,6 +416,10 @@ class MultimodalRetrievalVDB:
|
|||||||
|
|
||||||
def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
"""文搜文:使用文本查询搜索相似文本"""
|
"""文搜文:使用文本查询搜索相似文本"""
|
||||||
|
if self.vdb is None:
|
||||||
|
logger.warning("VDB不可用,无法执行搜索")
|
||||||
|
return []
|
||||||
|
|
||||||
logger.info(f"执行文搜文查询: {query}")
|
logger.info(f"执行文搜文查询: {query}")
|
||||||
|
|
||||||
# 编码查询文本
|
# 编码查询文本
|
||||||
@ -358,6 +437,10 @@ class MultimodalRetrievalVDB:
|
|||||||
|
|
||||||
def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
"""文搜图:使用文本查询搜索相似图像"""
|
"""文搜图:使用文本查询搜索相似图像"""
|
||||||
|
if self.vdb is None:
|
||||||
|
logger.warning("VDB不可用,无法执行搜索")
|
||||||
|
return []
|
||||||
|
|
||||||
logger.info(f"执行文搜图查询: {query}")
|
logger.info(f"执行文搜图查询: {query}")
|
||||||
|
|
||||||
# 编码查询文本
|
# 编码查询文本
|
||||||
@ -375,6 +458,10 @@ class MultimodalRetrievalVDB:
|
|||||||
|
|
||||||
def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
"""图搜文:使用图像查询搜索相似文本"""
|
"""图搜文:使用图像查询搜索相似文本"""
|
||||||
|
if self.vdb is None:
|
||||||
|
logger.warning("VDB不可用,无法执行搜索")
|
||||||
|
return []
|
||||||
|
|
||||||
logger.info(f"执行图搜文查询")
|
logger.info(f"执行图搜文查询")
|
||||||
|
|
||||||
# 编码查询图像
|
# 编码查询图像
|
||||||
@ -392,6 +479,10 @@ class MultimodalRetrievalVDB:
|
|||||||
|
|
||||||
def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
"""图搜图:使用图像查询搜索相似图像"""
|
"""图搜图:使用图像查询搜索相似图像"""
|
||||||
|
if self.vdb is None:
|
||||||
|
logger.warning("VDB不可用,无法执行搜索")
|
||||||
|
return []
|
||||||
|
|
||||||
logger.info(f"执行图搜图查询")
|
logger.info(f"执行图搜图查询")
|
||||||
|
|
||||||
# 编码查询图像
|
# 编码查询图像
|
||||||
@ -426,10 +517,15 @@ class MultimodalRetrievalVDB:
|
|||||||
|
|
||||||
def get_statistics(self) -> Dict[str, Any]:
|
def get_statistics(self) -> Dict[str, Any]:
|
||||||
"""获取系统统计信息"""
|
"""获取系统统计信息"""
|
||||||
|
if self.vdb is None:
|
||||||
|
return {"error": "VDB不可用"}
|
||||||
return self.vdb.get_statistics()
|
return self.vdb.get_statistics()
|
||||||
|
|
||||||
def clear_all_data(self):
|
def clear_all_data(self):
|
||||||
"""清空所有数据"""
|
"""清空所有数据"""
|
||||||
|
if self.vdb is None:
|
||||||
|
logger.warning("VDB不可用,无法清空数据")
|
||||||
|
return
|
||||||
self.vdb.clear_all_data()
|
self.vdb.clear_all_data()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|||||||
49
nohup.out
Normal file
49
nohup.out
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
INFO:baidu_bos_manager:✅ BOS连接测试成功
|
||||||
|
INFO:baidu_bos_manager:✅ BOS客户端初始化成功: dmtyz-demo
|
||||||
|
INFO:mongodb_manager:✅ MongoDB连接成功: mmeb
|
||||||
|
INFO:mongodb_manager:✅ MongoDB索引创建完成
|
||||||
|
INFO:__main__:初始化多模态检索系统...
|
||||||
|
INFO:multimodal_retrieval_local:使用GPU: [0, 1]
|
||||||
|
INFO:multimodal_retrieval_local:加载本地模型和处理器: /root/models/Ops-MM-embedding-v1-7B
|
||||||
|
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
|
||||||
|
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
|
||||||
|
INFO:multimodal_retrieval_local:Processor类型: <class 'transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor'>
|
||||||
|
INFO:multimodal_retrieval_local:Processor方法: ['__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_auto_class', '_check_special_mm_tokens', '_create_repo', '_get_arguments_from_pretrained', '_get_files_timestamps', '_get_num_multimodal_tokens', '_merge_kwargs', '_upload_modified_files', 'apply_chat_template', 'attributes', 'audio_tokenizer', 'batch_decode', 'chat_template', 'check_argument_for_proper_class', 'decode', 'feature_extractor_class', 'from_args_and_dict', 'from_pretrained', 'get_possibly_dynamic_module', 'get_processor_dict', 'image_processor', 'image_processor_class', 'image_token', 'image_token_id', 'model_input_names', 'optional_attributes', 'optional_call_args', 'post_process_image_text_to_text', 'push_to_hub', 'register_for_auto_class', 'save_pretrained', 'to_dict', 'to_json_file', 'to_json_string', 'tokenizer', 'tokenizer_class', 'validate_init_kwargs', 'video_processor', 'video_processor_class', 'video_token', 'video_token_id']
|
||||||
|
INFO:multimodal_retrieval_local:Image processor类型: <class 'transformers.models.qwen2_vl.image_processing_qwen2_vl_fast.Qwen2VLImageProcessorFast'>
|
||||||
|
INFO:multimodal_retrieval_local:Image processor方法: ['__backends', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slotnames__', '__str__', '__subclasshook__', '__weakref__', '_auto_class', '_create_repo', '_further_process_kwargs', '_fuse_mean_std_and_rescale_factor', '_get_files_timestamps', '_prepare_image_like_inputs', '_prepare_images_structure', '_preprocess', '_preprocess_image_like_inputs', '_process_image', '_processor_class', '_set_processor_class', '_upload_modified_files', '_valid_kwargs_names', '_validate_preprocess_kwargs', 'center_crop', 'compile_friendly_resize', 'convert_to_rgb', 'crop_size', 'data_format', 'default_to_square', 'device', 'disable_grouping', 'do_center_crop', 'do_convert_rgb', 'do_normalize', 'do_rescale', 'do_resize', 'fetch_images', 'filter_out_unused_kwargs', 'from_dict', 'from_json_file', 'from_pretrained', 'get_image_processor_dict', 'get_number_of_image_patches', 'image_mean', 'image_processor_type', 'image_std', 'input_data_format', 'max_pixels', 'merge_size', 'min_pixels', 'model_input_names', 'normalize', 'patch_size', 'preprocess', 'push_to_hub', 'register_for_auto_class', 'resample', 'rescale', 'rescale_and_normalize', 'rescale_factor', 'resize', 'return_tensors', 'save_pretrained', 'size', 'temporal_patch_size', 'to_dict', 'to_json_file', 'to_json_string', 'unused_kwargs', 'valid_kwargs']
|
||||||
|
Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 25%|██▌ | 1/4 [03:03<09:10, 183.40s/it]
Loading checkpoint shards: 50%|█████ | 2/4 [04:55<04:43, 141.63s/it]
Loading checkpoint shards: 75%|███████▌ | 3/4 [06:56<02:12, 132.26s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [07:13<00:00, 86.72s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [07:13<00:00, 108.47s/it]
|
||||||
|
INFO:multimodal_retrieval_local:向量维度: 3584
|
||||||
|
INFO:multimodal_retrieval_local:模型和处理器加载成功
|
||||||
|
INFO:multimodal_retrieval_local:加载现有索引: /root/mmeb/local_faiss_index.index
|
||||||
|
INFO:multimodal_retrieval_local:索引加载成功,包含0个向量
|
||||||
|
INFO:multimodal_retrieval_local:元数据加载成功,包含0条记录
|
||||||
|
INFO:multimodal_retrieval_local:多模态检索系统初始化完成,使用本地模型: /root/models/Ops-MM-embedding-v1-7B
|
||||||
|
INFO:multimodal_retrieval_local:向量存储路径: /root/mmeb/local_faiss_index
|
||||||
|
INFO:__main__:多模态检索系统初始化完成
|
||||||
|
* Serving Flask app 'web_app_local'
|
||||||
|
* Debug mode: off
|
||||||
|
INFO:werkzeug:[31m[1mWARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.[0m
|
||||||
|
* Running on all addresses (0.0.0.0)
|
||||||
|
* Running on http://127.0.0.1:5000
|
||||||
|
* Running on http://192.168.48.82:5000
|
||||||
|
INFO:werkzeug:[33mPress CTRL+C to quit[0m
|
||||||
|
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:23] "GET / HTTP/1.1" 200 -
|
||||||
|
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:23] "GET /api/system_info HTTP/1.1" 200 -
|
||||||
|
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:24] "GET /api/system_info HTTP/1.1" 200 -
|
||||||
|
INFO:__main__:处理图像: 微信图片_20250910164839_1_13.jpg (99396 字节)
|
||||||
|
INFO:__main__:成功加载图像: 20250910164839_1_13.jpg, 格式: JPEG, 模式: RGB, 大小: (939, 940)
|
||||||
|
INFO:multimodal_retrieval_local:add_images: 开始添加图像,数量: 1
|
||||||
|
INFO:multimodal_retrieval_local:add_images: 编码图像
|
||||||
|
INFO:multimodal_retrieval_local:encode_image: 开始编码图像,类型: <class 'list'>
|
||||||
|
INFO:multimodal_retrieval_local:encode_image: 图像列表,长度: 1
|
||||||
|
INFO:multimodal_retrieval_local:encode_image: 处理图像输入
|
||||||
|
INFO:multimodal_retrieval_local:encode_image: 图像 0 格式: JPEG, 模式: RGB, 大小: (939, 940)
|
||||||
|
ERROR:multimodal_retrieval_local:encode_image: 处理图像时出错: argument of type 'NoneType' is not iterable
|
||||||
|
ERROR:multimodal_retrieval_local:add_images: 图像编码失败,返回空数组
|
||||||
|
INFO:multimodal_retrieval_local:索引保存成功: /root/mmeb/local_faiss_index.index
|
||||||
|
INFO:multimodal_retrieval_local:元数据保存成功: /root/mmeb/local_faiss_index_metadata.json
|
||||||
|
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:50] "POST /api/add_image HTTP/1.1" 200 -
|
||||||
|
INFO:multimodal_retrieval_local:索引保存成功: /root/mmeb/local_faiss_index.index
|
||||||
|
INFO:multimodal_retrieval_local:元数据保存成功: /root/mmeb/local_faiss_index_metadata.json
|
||||||
|
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:50] "POST /api/save_index HTTP/1.1" 200 -
|
||||||
|
INFO:werkzeug:127.0.0.1 - - [22/Sep/2025 04:02:51] "GET /api/system_info HTTP/1.1" 200 -
|
||||||
@ -30,19 +30,30 @@ class OptimizedFileHandler:
|
|||||||
# 支持的图像格式
|
# 支持的图像格式
|
||||||
SUPPORTED_IMAGE_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}
|
SUPPORTED_IMAGE_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, local_storage_dir=None):
|
||||||
self.bos_manager = get_bos_manager()
|
self.bos_manager = get_bos_manager()
|
||||||
self.mongodb_manager = get_mongodb_manager()
|
self.mongodb_manager = get_mongodb_manager()
|
||||||
self.temp_files = set() # 跟踪临时文件
|
self.temp_files = set() # 跟踪临时文件
|
||||||
|
self.local_storage_dir = local_storage_dir or tempfile.gettempdir()
|
||||||
|
|
||||||
|
# 确保本地存储目录存在
|
||||||
|
if self.local_storage_dir:
|
||||||
|
os.makedirs(self.local_storage_dir, exist_ok=True)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def temp_file_context(self, suffix: str = None, delete_on_exit: bool = True):
|
def temp_file_context(self, content: bytes = None, suffix: str = None, delete_on_exit: bool = True):
|
||||||
"""临时文件上下文管理器,确保自动清理"""
|
"""临时文件上下文管理器,确保自动清理"""
|
||||||
temp_fd, temp_path = tempfile.mkstemp(suffix=suffix)
|
temp_fd, temp_path = tempfile.mkstemp(suffix=suffix, dir=self.local_storage_dir)
|
||||||
self.temp_files.add(temp_path)
|
self.temp_files.add(temp_path)
|
||||||
|
|
||||||
try:
|
# 如果提供了内容,写入文件
|
||||||
|
if content is not None:
|
||||||
|
with os.fdopen(temp_fd, 'wb') as f:
|
||||||
|
f.write(content)
|
||||||
|
else:
|
||||||
os.close(temp_fd) # 关闭文件描述符
|
os.close(temp_fd) # 关闭文件描述符
|
||||||
|
|
||||||
|
try:
|
||||||
yield temp_path
|
yield temp_path
|
||||||
finally:
|
finally:
|
||||||
if delete_on_exit and os.path.exists(temp_path):
|
if delete_on_exit and os.path.exists(temp_path):
|
||||||
@ -96,17 +107,13 @@ class OptimizedFileHandler:
|
|||||||
logger.error(f"❌ 图像验证失败: {filename}, {e}")
|
logger.error(f"❌ 图像验证失败: {filename}, {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 生成唯一ID和BOS键
|
# 生成唯一ID
|
||||||
file_id = str(uuid.uuid4())
|
file_id = str(uuid.uuid4())
|
||||||
bos_key = f"images/memory_{file_id}_{filename}"
|
|
||||||
|
|
||||||
# 直接上传到BOS(从内存)
|
# 保存到本地存储
|
||||||
bos_result = self._upload_to_bos_from_memory(
|
local_path = os.path.join(self.local_storage_dir, f"{file_id}_{filename}")
|
||||||
file_content, bos_key, filename
|
with open(local_path, 'wb') as f:
|
||||||
)
|
f.write(file_content)
|
||||||
|
|
||||||
if not bos_result:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 存储元数据到MongoDB
|
# 存储元数据到MongoDB
|
||||||
metadata = {
|
metadata = {
|
||||||
@ -115,18 +122,25 @@ class OptimizedFileHandler:
|
|||||||
"file_type": "image",
|
"file_type": "image",
|
||||||
"file_size": len(file_content),
|
"file_size": len(file_content),
|
||||||
"processing_method": "memory",
|
"processing_method": "memory",
|
||||||
"bos_key": bos_key,
|
"local_path": local_path
|
||||||
"bos_url": bos_result["url"]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 如果有BOS管理器,也上传到BOS
|
||||||
|
if hasattr(self, 'bos_manager') and self.bos_manager:
|
||||||
|
bos_key = f"images/memory_{file_id}_{filename}"
|
||||||
|
bos_result = self._upload_to_bos_from_memory(file_content, bos_key, filename)
|
||||||
|
if bos_result:
|
||||||
|
metadata["bos_key"] = bos_key
|
||||||
|
metadata["bos_url"] = bos_result["url"]
|
||||||
|
|
||||||
|
if hasattr(self, 'mongodb_manager') and self.mongodb_manager:
|
||||||
self.mongodb_manager.store_file_metadata(metadata=metadata)
|
self.mongodb_manager.store_file_metadata(metadata=metadata)
|
||||||
|
|
||||||
logger.info(f"✅ 内存处理图像成功: {filename} ({len(file_content)} bytes)")
|
logger.info(f"✅ 内存处理图像成功: {filename} ({len(file_content)} bytes)")
|
||||||
return {
|
return {
|
||||||
"file_id": file_id,
|
"file_id": file_id,
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
"bos_key": bos_key,
|
"local_path": local_path,
|
||||||
"bos_result": bos_result,
|
|
||||||
"processing_method": "memory"
|
"processing_method": "memory"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -140,6 +154,12 @@ class OptimizedFileHandler:
|
|||||||
# 获取文件扩展名
|
# 获取文件扩展名
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
|
|
||||||
|
# 生成唯一ID
|
||||||
|
file_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# 创建永久文件路径
|
||||||
|
permanent_path = os.path.join(self.local_storage_dir, f"{file_id}_{filename}")
|
||||||
|
|
||||||
with self.temp_file_context(suffix=ext) as temp_path:
|
with self.temp_file_context(suffix=ext) as temp_path:
|
||||||
# 保存到临时文件
|
# 保存到临时文件
|
||||||
file_obj.seek(0)
|
file_obj.seek(0)
|
||||||
@ -154,35 +174,41 @@ class OptimizedFileHandler:
|
|||||||
logger.error(f"❌ 图像验证失败: {filename}, {e}")
|
logger.error(f"❌ 图像验证失败: {filename}, {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 生成唯一ID和BOS键
|
# 复制到永久存储位置
|
||||||
file_id = str(uuid.uuid4())
|
with open(temp_path, 'rb') as src, open(permanent_path, 'wb') as dst:
|
||||||
bos_key = f"images/temp_{file_id}_{filename}"
|
dst.write(src.read())
|
||||||
|
|
||||||
# 上传到BOS
|
# 获取文件信息
|
||||||
bos_result = self.bos_manager.upload_file(temp_path, bos_key)
|
file_stat = os.stat(permanent_path)
|
||||||
|
|
||||||
# 存储元数据到MongoDB
|
# 存储元数据
|
||||||
file_stat = os.stat(temp_path)
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"_id": file_id,
|
"_id": file_id,
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
"file_type": "image",
|
"file_type": "image",
|
||||||
"file_size": file_stat.st_size,
|
"file_size": file_stat.st_size,
|
||||||
"processing_method": "temp_file",
|
"processing_method": "temp_file",
|
||||||
"bos_key": bos_key,
|
"local_path": permanent_path
|
||||||
"bos_url": bos_result["url"]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 如果有BOS管理器,也上传到BOS
|
||||||
|
if hasattr(self, 'bos_manager') and self.bos_manager:
|
||||||
|
bos_key = f"images/temp_{file_id}_{filename}"
|
||||||
|
bos_result = self.bos_manager.upload_file(temp_path, bos_key)
|
||||||
|
if bos_result:
|
||||||
|
metadata["bos_key"] = bos_key
|
||||||
|
metadata["bos_url"] = bos_result["url"]
|
||||||
|
|
||||||
|
# 存储元数据到MongoDB
|
||||||
|
if hasattr(self, 'mongodb_manager') and self.mongodb_manager:
|
||||||
self.mongodb_manager.store_file_metadata(metadata=metadata)
|
self.mongodb_manager.store_file_metadata(metadata=metadata)
|
||||||
|
|
||||||
logger.info(f"✅ 临时文件处理图像成功: {filename} ({file_stat.st_size} bytes)")
|
logger.info(f"✅ 临时文件处理图像成功: {filename} ({file_stat.st_size} bytes)")
|
||||||
return {
|
return {
|
||||||
"file_id": file_id,
|
"file_id": file_id,
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
"bos_key": bos_key,
|
"local_path": permanent_path,
|
||||||
"bos_result": bos_result,
|
"processing_method": "temp_file"
|
||||||
"processing_method": "temp_file",
|
|
||||||
"temp_path": temp_path # 返回临时路径供模型处理
|
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -290,8 +316,11 @@ class OptimizedFileHandler:
|
|||||||
try:
|
try:
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
ext = os.path.splitext(filename)[1].lower()
|
||||||
|
|
||||||
|
# 生成唯一ID
|
||||||
|
file_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# 创建临时文件(不自动删除,供模型使用)
|
# 创建临时文件(不自动删除,供模型使用)
|
||||||
temp_fd, temp_path = tempfile.mkstemp(suffix=ext)
|
temp_fd, temp_path = tempfile.mkstemp(suffix=ext, dir=self.local_storage_dir)
|
||||||
self.temp_files.add(temp_path)
|
self.temp_files.add(temp_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
995
templates/local_index.html
Normal file
995
templates/local_index.html
Normal file
@ -0,0 +1,995 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>本地多模态检索系统 - FAISS</title>
|
||||||
|
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
||||||
|
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
|
||||||
|
<style>
|
||||||
|
:root {
|
||||||
|
--primary-color: #2563eb;
|
||||||
|
--secondary-color: #64748b;
|
||||||
|
--success-color: #059669;
|
||||||
|
--warning-color: #d97706;
|
||||||
|
--danger-color: #dc2626;
|
||||||
|
--bg-gradient: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
background: var(--bg-gradient);
|
||||||
|
min-height: 100vh;
|
||||||
|
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||||
|
}
|
||||||
|
|
||||||
|
.main-container {
|
||||||
|
background: rgba(255, 255, 255, 0.95);
|
||||||
|
backdrop-filter: blur(10px);
|
||||||
|
border-radius: 20px;
|
||||||
|
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
|
||||||
|
margin: 20px auto;
|
||||||
|
max-width: 1200px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.header {
|
||||||
|
background: linear-gradient(135deg, var(--primary-color), #3b82f6);
|
||||||
|
color: white;
|
||||||
|
padding: 2rem;
|
||||||
|
border-radius: 20px 20px 0 0;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mode-card {
|
||||||
|
background: white;
|
||||||
|
border-radius: 15px;
|
||||||
|
padding: 1.5rem;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
border: 2px solid transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mode-card:hover {
|
||||||
|
transform: translateY(-5px);
|
||||||
|
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.15);
|
||||||
|
}
|
||||||
|
|
||||||
|
.mode-card.active {
|
||||||
|
border-color: var(--primary-color);
|
||||||
|
background: linear-gradient(135deg, #eff6ff, #dbeafe);
|
||||||
|
}
|
||||||
|
|
||||||
|
.mode-icon {
|
||||||
|
font-size: 2.5rem;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
.text-to-text { color: #059669; }
|
||||||
|
.text-to-image { color: #dc2626; }
|
||||||
|
.image-to-text { color: #d97706; }
|
||||||
|
.image-to-image { color: #7c3aed; }
|
||||||
|
|
||||||
|
.search-input {
|
||||||
|
border-radius: 12px;
|
||||||
|
border: 2px solid #e5e7eb;
|
||||||
|
padding: 12px 16px;
|
||||||
|
font-size: 16px;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.search-input:focus {
|
||||||
|
border-color: var(--primary-color);
|
||||||
|
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary {
|
||||||
|
background: var(--primary-color);
|
||||||
|
border: none;
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 12px 24px;
|
||||||
|
font-weight: 600;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary:hover {
|
||||||
|
background: #1d4ed8;
|
||||||
|
transform: translateY(-2px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.file-upload-area {
|
||||||
|
border: 3px dashed #d1d5db;
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 3rem;
|
||||||
|
text-align: center;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.file-upload-area:hover {
|
||||||
|
border-color: var(--primary-color);
|
||||||
|
background: rgba(37, 99, 235, 0.05);
|
||||||
|
}
|
||||||
|
|
||||||
|
.file-upload-area.dragover {
|
||||||
|
border-color: var(--primary-color);
|
||||||
|
background: rgba(37, 99, 235, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-card {
|
||||||
|
background: white;
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 1.5rem;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
||||||
|
border-left: 4px solid var(--primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-image {
|
||||||
|
max-width: 200px;
|
||||||
|
max-height: 150px;
|
||||||
|
border-radius: 8px;
|
||||||
|
object-fit: cover;
|
||||||
|
}
|
||||||
|
|
||||||
|
.score-badge {
|
||||||
|
background: var(--success-color);
|
||||||
|
color: white;
|
||||||
|
padding: 4px 12px;
|
||||||
|
border-radius: 20px;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loading-spinner {
|
||||||
|
display: none;
|
||||||
|
text-align: center;
|
||||||
|
padding: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-indicator {
|
||||||
|
position: fixed;
|
||||||
|
top: 20px;
|
||||||
|
right: 20px;
|
||||||
|
z-index: 1000;
|
||||||
|
}
|
||||||
|
|
||||||
|
.fade-in {
|
||||||
|
animation: fadeIn 0.5s ease-in;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes fadeIn {
|
||||||
|
from { opacity: 0; transform: translateY(20px); }
|
||||||
|
to { opacity: 1; transform: translateY(0); }
|
||||||
|
}
|
||||||
|
|
||||||
|
.query-image {
|
||||||
|
max-width: 300px;
|
||||||
|
max-height: 200px;
|
||||||
|
border-radius: 12px;
|
||||||
|
object-fit: cover;
|
||||||
|
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<!-- 状态指示器 -->
|
||||||
|
<div class="status-indicator">
|
||||||
|
<div id="statusBadge" class="badge bg-secondary">
|
||||||
|
<i class="fas fa-circle-notch fa-spin"></i> 未初始化
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="container-fluid">
|
||||||
|
<div class="main-container">
|
||||||
|
<!-- 头部 -->
|
||||||
|
<div class="header">
|
||||||
|
<h1><i class="fas fa-search"></i> 本地多模态检索系统</h1>
|
||||||
|
<p class="mb-0">基于本地模型和FAISS向量数据库,支持文搜图、文搜文、图搜图、图搜文四种检索模式</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="p-4">
|
||||||
|
<!-- 重新初始化按钮 -->
|
||||||
|
<div class="text-center mb-4">
|
||||||
|
<button id="reinitBtn" class="btn btn-warning">
|
||||||
|
<i class="fas fa-redo"></i> 重新初始化系统
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 检索模式选择 -->
|
||||||
|
<div class="row mb-4" id="modeSelection">
|
||||||
|
<div class="col-md-3">
|
||||||
|
<div class="mode-card text-center" data-mode="text_to_text">
|
||||||
|
<i class="fas fa-file-text mode-icon text-to-text"></i>
|
||||||
|
<h5>文搜文</h5>
|
||||||
|
<p class="text-muted">文本查找相似文本</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-3">
|
||||||
|
<div class="mode-card text-center" data-mode="text_to_image">
|
||||||
|
<i class="fas fa-image mode-icon text-to-image"></i>
|
||||||
|
<h5>文搜图</h5>
|
||||||
|
<p class="text-muted">文本查找相关图片</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-3">
|
||||||
|
<div class="mode-card text-center" data-mode="image_to_text">
|
||||||
|
<i class="fas fa-comment mode-icon image-to-text"></i>
|
||||||
|
<h5>图搜文</h5>
|
||||||
|
<p class="text-muted">图片查找相关文本</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-3">
|
||||||
|
<div class="mode-card text-center" data-mode="image_to_image">
|
||||||
|
<i class="fas fa-images mode-icon image-to-image"></i>
|
||||||
|
<h5>图搜图</h5>
|
||||||
|
<p class="text-muted">图片查找相似图片</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 数据管理界面 -->
|
||||||
|
<div class="row mb-4" id="dataManagement">
|
||||||
|
<div class="col-12">
|
||||||
|
<div class="card">
|
||||||
|
<div class="card-header">
|
||||||
|
<h5><i class="fas fa-database"></i> 数据管理</h5>
|
||||||
|
<small class="text-muted">上传和管理检索数据库</small>
|
||||||
|
</div>
|
||||||
|
<div class="card-body">
|
||||||
|
<div class="row">
|
||||||
|
<!-- 批量上传图片 -->
|
||||||
|
<div class="col-md-6">
|
||||||
|
<div class="upload-section">
|
||||||
|
<h6><i class="fas fa-images text-primary"></i> 批量上传图片</h6>
|
||||||
|
<div class="file-upload-area" id="batchImageUpload">
|
||||||
|
<i class="fas fa-cloud-upload-alt fa-2x text-muted mb-2"></i>
|
||||||
|
<p>拖拽多张图片到此处或点击选择</p>
|
||||||
|
<input type="file" id="batchImageFiles" multiple accept="image/*" style="display: none;">
|
||||||
|
<button class="btn btn-outline-primary btn-sm mt-2" onclick="document.getElementById('batchImageFiles').click()">
|
||||||
|
<i class="fas fa-folder-open"></i> 选择图片
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div id="imageUploadProgress" class="mt-2" style="display: none;">
|
||||||
|
<div class="progress">
|
||||||
|
<div class="progress-bar" role="progressbar" style="width: 0%"></div>
|
||||||
|
</div>
|
||||||
|
<small class="text-muted mt-1 d-block">上传进度: <span id="imageProgressText">0/0</span></small>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 批量上传文本 -->
|
||||||
|
<div class="col-md-6">
|
||||||
|
<div class="upload-section">
|
||||||
|
<h6><i class="fas fa-file-text text-success"></i> 批量上传文本</h6>
|
||||||
|
<div class="mb-3">
|
||||||
|
<textarea id="batchTextInput" class="form-control" rows="8"
|
||||||
|
placeholder="请输入文本数据,每行一条文本记录... 例如: 这是第一条文本记录 这是第二条文本记录 这是第三条文本记录"></textarea>
|
||||||
|
</div>
|
||||||
|
<div class="d-flex gap-2">
|
||||||
|
<button id="uploadTextsBtn" class="btn btn-success">
|
||||||
|
<i class="fas fa-upload"></i> 上传文本
|
||||||
|
</button>
|
||||||
|
<button class="btn btn-outline-secondary" onclick="document.getElementById('textFile').click()">
|
||||||
|
<i class="fas fa-file-import"></i> 从文件导入
|
||||||
|
</button>
|
||||||
|
<input type="file" id="textFile" accept=".txt,.csv" style="display: none;">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 数据统计和管理 -->
|
||||||
|
<div class="row mt-4">
|
||||||
|
<div class="col-md-8">
|
||||||
|
<div class="d-flex gap-3">
|
||||||
|
<!-- 移除构建索引按钮,改为自动构建 -->
|
||||||
|
<button id="viewDataBtn" class="btn btn-info">
|
||||||
|
<i class="fas fa-list"></i> 查看数据
|
||||||
|
</button>
|
||||||
|
<button id="clearDataBtn" class="btn btn-danger">
|
||||||
|
<i class="fas fa-trash"></i> 清空数据
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-4">
|
||||||
|
<div id="dataStats" class="text-end">
|
||||||
|
<small class="text-muted">
|
||||||
|
图片: <span id="imageCount">0</span> 张 |
|
||||||
|
文本: <span id="textCount">0</span> 条
|
||||||
|
</small>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 搜索界面 -->
|
||||||
|
<div id="searchInterface" style="display: none;">
|
||||||
|
<!-- 文本搜索 -->
|
||||||
|
<div id="textSearch" class="search-panel" style="display: none;">
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-md-8">
|
||||||
|
<input type="text" id="textQuery" class="form-control search-input"
|
||||||
|
placeholder="请输入搜索文本...">
|
||||||
|
</div>
|
||||||
|
<div class="col-md-2">
|
||||||
|
<select id="textTopK" class="form-select search-input">
|
||||||
|
<option value="3">Top 3</option>
|
||||||
|
<option value="5" selected>Top 5</option>
|
||||||
|
<option value="10">Top 10</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-2">
|
||||||
|
<button id="textSearchBtn" class="btn btn-primary w-100">
|
||||||
|
<i class="fas fa-search"></i> 搜索
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 图片搜索 -->
|
||||||
|
<div id="imageSearch" class="search-panel" style="display: none;">
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-md-8">
|
||||||
|
<div class="file-upload-area" id="fileUploadArea">
|
||||||
|
<i class="fas fa-cloud-upload-alt fa-3x text-muted mb-3"></i>
|
||||||
|
<h5>拖拽图片到此处或点击选择</h5>
|
||||||
|
<p class="text-muted">支持 PNG, JPG, JPEG, GIF, BMP, WebP 格式</p>
|
||||||
|
<input type="file" id="imageFile" accept="image/*" style="display: none;">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-2">
|
||||||
|
<select id="imageTopK" class="form-select search-input">
|
||||||
|
<option value="3">Top 3</option>
|
||||||
|
<option value="5" selected>Top 5</option>
|
||||||
|
<option value="10">Top 10</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-2">
|
||||||
|
<button id="imageSearchBtn" class="btn btn-primary w-100" disabled>
|
||||||
|
<i class="fas fa-search"></i> 搜索
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 加载动画 -->
|
||||||
|
<div class="loading-spinner" id="loadingSpinner">
|
||||||
|
<div class="spinner-border text-primary" role="status">
|
||||||
|
<span class="visually-hidden">Loading...</span>
|
||||||
|
</div>
|
||||||
|
<p class="mt-2">正在搜索中...</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 搜索结果 -->
|
||||||
|
<div id="searchResults" class="mt-4"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"></script>
|
||||||
|
<script>
|
||||||
|
let currentMode = null;
|
||||||
|
let systemInitialized = false;
|
||||||
|
|
||||||
|
// 重新初始化系统
|
||||||
|
document.getElementById('reinitBtn').addEventListener('click', async function() {
|
||||||
|
const btn = this;
|
||||||
|
const originalText = btn.innerHTML;
|
||||||
|
|
||||||
|
btn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 重新初始化中...';
|
||||||
|
btn.disabled = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('/api/system_info', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {'Content-Type': 'application/json'}
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
systemInitialized = true;
|
||||||
|
document.getElementById('statusBadge').innerHTML =
|
||||||
|
'<i class="fas fa-check-circle"></i> 已重新初始化';
|
||||||
|
document.getElementById('statusBadge').className = 'badge bg-success';
|
||||||
|
|
||||||
|
showAlert('success', `系统重新初始化成功!GPU信息: ${data.gpu_info.length} 个, 向量数量: ${data.retrieval_info.total_vectors || 0}`);
|
||||||
|
} else {
|
||||||
|
throw new Error(data.message);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
showAlert('danger', '重新初始化失败: ' + error.message);
|
||||||
|
} finally {
|
||||||
|
btn.innerHTML = originalText;
|
||||||
|
btn.disabled = false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 模式选择
|
||||||
|
document.querySelectorAll('.mode-card').forEach(card => {
|
||||||
|
card.addEventListener('click', function() {
|
||||||
|
|
||||||
|
// 更新选中状态
|
||||||
|
document.querySelectorAll('.mode-card').forEach(c => c.classList.remove('active'));
|
||||||
|
this.classList.add('active');
|
||||||
|
|
||||||
|
currentMode = this.dataset.mode;
|
||||||
|
setupSearchInterface(currentMode);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// 设置搜索界面
|
||||||
|
function setupSearchInterface(mode) {
|
||||||
|
document.getElementById('searchInterface').style.display = 'block';
|
||||||
|
document.getElementById('textSearch').style.display = 'none';
|
||||||
|
document.getElementById('imageSearch').style.display = 'none';
|
||||||
|
document.getElementById('searchResults').innerHTML = '';
|
||||||
|
|
||||||
|
if (mode === 'text_to_text' || mode === 'text_to_image') {
|
||||||
|
document.getElementById('textSearch').style.display = 'block';
|
||||||
|
document.getElementById('textQuery').placeholder =
|
||||||
|
mode === 'text_to_text' ? '请输入要搜索的文本...' : '请输入要搜索图片的描述...';
|
||||||
|
} else {
|
||||||
|
document.getElementById('imageSearch').style.display = 'block';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 文本搜索
|
||||||
|
document.getElementById('textSearchBtn').addEventListener('click', performTextSearch);
|
||||||
|
document.getElementById('textQuery').addEventListener('keypress', function(e) {
|
||||||
|
if (e.key === 'Enter') performTextSearch();
|
||||||
|
});
|
||||||
|
|
||||||
|
async function performTextSearch() {
|
||||||
|
const query = document.getElementById('textQuery').value.trim();
|
||||||
|
const topK = parseInt(document.getElementById('textTopK').value);
|
||||||
|
|
||||||
|
if (!query) {
|
||||||
|
showAlert('warning', '请输入搜索文本');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
showLoading(true);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const endpoint = '/api/search_by_text';
|
||||||
|
const filter_type = currentMode === 'text_to_text' ? 'text' : 'image';
|
||||||
|
const response = await fetch(endpoint, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {'Content-Type': 'application/json'},
|
||||||
|
body: JSON.stringify({query, k: topK, filter_type: filter_type})
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
displayResults(data, currentMode);
|
||||||
|
} else {
|
||||||
|
throw new Error(data.message);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
showAlert('danger', '搜索失败: ' + error.message);
|
||||||
|
} finally {
|
||||||
|
showLoading(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 图片上传处理
|
||||||
|
const fileUploadArea = document.getElementById('fileUploadArea');
|
||||||
|
const imageFile = document.getElementById('imageFile');
|
||||||
|
|
||||||
|
fileUploadArea.addEventListener('click', () => imageFile.click());
|
||||||
|
fileUploadArea.addEventListener('dragover', handleDragOver);
|
||||||
|
fileUploadArea.addEventListener('drop', handleDrop);
|
||||||
|
imageFile.addEventListener('change', handleFileSelect);
|
||||||
|
|
||||||
|
function handleDragOver(e) {
|
||||||
|
e.preventDefault();
|
||||||
|
fileUploadArea.classList.add('dragover');
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleDrop(e) {
|
||||||
|
e.preventDefault();
|
||||||
|
fileUploadArea.classList.remove('dragover');
|
||||||
|
const files = e.dataTransfer.files;
|
||||||
|
if (files.length > 0) {
|
||||||
|
handleFile(files[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleFileSelect(e) {
|
||||||
|
const file = e.target.files[0];
|
||||||
|
if (file) handleFile(file);
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleFile(file) {
|
||||||
|
if (!file.type.startsWith('image/')) {
|
||||||
|
showAlert('warning', '请选择图片文件');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = function(e) {
|
||||||
|
fileUploadArea.innerHTML = `
|
||||||
|
<img src="${e.target.result}" class="query-image mb-3">
|
||||||
|
<p class="text-success"><i class="fas fa-check"></i> 图片已选择: ${file.name}</p>
|
||||||
|
`;
|
||||||
|
document.getElementById('imageSearchBtn').disabled = false;
|
||||||
|
};
|
||||||
|
reader.readAsDataURL(file);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 图片搜索
|
||||||
|
document.getElementById('imageSearchBtn').addEventListener('click', async function() {
|
||||||
|
const file = imageFile.files[0];
|
||||||
|
const topK = parseInt(document.getElementById('imageTopK').value);
|
||||||
|
|
||||||
|
if (!file) {
|
||||||
|
showAlert('warning', '请选择图片文件');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
showLoading(true);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const endpoint = '/api/search_by_image';
|
||||||
|
const filter_type = currentMode === 'image_to_text' ? 'text' : 'image';
|
||||||
|
|
||||||
|
const formData = new FormData();
|
||||||
|
formData.append('image', file);
|
||||||
|
formData.append('k', topK);
|
||||||
|
formData.append('filter_type', filter_type);
|
||||||
|
const response = await fetch(endpoint, {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
displayResults(data, currentMode);
|
||||||
|
} else {
|
||||||
|
throw new Error(data.message);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
showAlert('danger', '搜索失败: ' + error.message);
|
||||||
|
} finally {
|
||||||
|
showLoading(false);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 显示结果
|
||||||
|
function displayResults(data, mode) {
|
||||||
|
const resultsContainer = document.getElementById('searchResults');
|
||||||
|
|
||||||
|
let html = `
|
||||||
|
<div class="fade-in">
|
||||||
|
<div class="d-flex justify-content-between align-items-center mb-3">
|
||||||
|
<h4><i class="fas fa-search-plus"></i> 搜索结果</h4>
|
||||||
|
<div>
|
||||||
|
<span class="badge bg-info">找到 ${data.results?.length || 0} 个结果</span>
|
||||||
|
<span class="badge bg-secondary">耗时 ${data.search_time || data.time || '0.0'}s</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
|
||||||
|
if (data.query_image) {
|
||||||
|
const imageUrl = data.query_image.startsWith('data:') ? data.query_image : `data:image/jpeg;base64,${data.query_image}`;
|
||||||
|
html += `
|
||||||
|
<div class="result-card">
|
||||||
|
<h6><i class="fas fa-image"></i> 查询图片</h6>
|
||||||
|
<img src="${imageUrl}" class="query-image">
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (data.query) {
|
||||||
|
html += `
|
||||||
|
<div class="result-card">
|
||||||
|
<h6><i class="fas fa-quote-left"></i> 查询文本</h6>
|
||||||
|
<p class="mb-0">"${data.query}"</p>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.results.forEach((result, index) => {
|
||||||
|
html += '<div class="result-card">';
|
||||||
|
|
||||||
|
if (mode === 'text_to_image' || mode === 'image_to_image') {
|
||||||
|
const imageUrl = result.image_base64 ? `data:image/jpeg;base64,${result.image_base64}` :
|
||||||
|
(result.image_url || `/temp/${result.filename || result.id}`);
|
||||||
|
const score = result.score || result.distance ?
|
||||||
|
(result.score ? (result.score * 100).toFixed(1) : (100 - result.distance * 100).toFixed(1)) : '95.0';
|
||||||
|
const title = result.title || result.filename || result.id || `结果 ${index + 1}`;
|
||||||
|
|
||||||
|
html += `
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-md-3">
|
||||||
|
<img src="${imageUrl}" class="result-image" alt="Result ${index + 1}">
|
||||||
|
</div>
|
||||||
|
<div class="col-md-9">
|
||||||
|
<div class="d-flex justify-content-between align-items-start">
|
||||||
|
<h6><i class="fas fa-image"></i> ${title}</h6>
|
||||||
|
<span class="score-badge">相似度: ${score}%</span>
|
||||||
|
</div>
|
||||||
|
<p class="text-muted mb-0">类型: 图片 | ID: ${result.id || index}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
} else {
|
||||||
|
const text = result.text || result.content || (typeof result === 'string' ? result : JSON.stringify(result));
|
||||||
|
const score = result.score || result.distance ?
|
||||||
|
(result.score ? (result.score * 100).toFixed(1) : (100 - result.distance * 100).toFixed(1)) : '95.0';
|
||||||
|
const title = result.title || `结果 ${index + 1}`;
|
||||||
|
|
||||||
|
html += `
|
||||||
|
<div class="d-flex justify-content-between align-items-start">
|
||||||
|
<div>
|
||||||
|
<h6><i class="fas fa-file-text"></i> ${title}</h6>
|
||||||
|
<p class="mb-0">${text}</p>
|
||||||
|
<p class="text-muted small mb-0">类型: 文本 | ID: ${result.id || index}</p>
|
||||||
|
</div>
|
||||||
|
<span class="score-badge">相似度: ${score}%</span>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
}
|
||||||
|
|
||||||
|
html += '</div>';
|
||||||
|
});
|
||||||
|
|
||||||
|
html += '</div>';
|
||||||
|
resultsContainer.innerHTML = html;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 工具函数
|
||||||
|
function showLoading(show) {
|
||||||
|
document.getElementById('loadingSpinner').style.display = show ? 'block' : 'none';
|
||||||
|
}
|
||||||
|
|
||||||
|
function showAlert(type, message) {
|
||||||
|
const alertDiv = document.createElement('div');
|
||||||
|
alertDiv.className = `alert alert-${type} alert-dismissible fade show`;
|
||||||
|
alertDiv.innerHTML = `
|
||||||
|
${message}
|
||||||
|
<button type="button" class="btn-close" data-bs-dismiss="alert"></button>
|
||||||
|
`;
|
||||||
|
|
||||||
|
document.querySelector('.main-container .p-4').insertBefore(alertDiv, document.querySelector('.main-container .p-4').firstChild);
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
if (alertDiv.parentNode) {
|
||||||
|
alertDiv.remove();
|
||||||
|
}
|
||||||
|
}, 5000);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查系统状态
|
||||||
|
async function checkStatus() {
|
||||||
|
try {
|
||||||
|
const response = await fetch('/api/system_info');
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
systemInitialized = true;
|
||||||
|
document.getElementById('statusBadge').innerHTML =
|
||||||
|
'<i class="fas fa-check-circle"></i> 已初始化';
|
||||||
|
document.getElementById('statusBadge').className = 'badge bg-success';
|
||||||
|
} else {
|
||||||
|
document.getElementById('statusBadge').innerHTML =
|
||||||
|
'<i class="fas fa-exclamation-triangle"></i> 未初始化';
|
||||||
|
document.getElementById('statusBadge').className = 'badge bg-warning';
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.log('Status check failed:', error);
|
||||||
|
document.getElementById('statusBadge').innerHTML =
|
||||||
|
'<i class="fas fa-times-circle"></i> 连接失败';
|
||||||
|
document.getElementById('statusBadge').className = 'badge bg-danger';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 页面加载时检查状态
|
||||||
|
checkStatus();
|
||||||
|
|
||||||
|
// 设置数据管理功能事件绑定
|
||||||
|
setupDataManagement();
|
||||||
|
|
||||||
|
function setupDataManagement() {
|
||||||
|
// 批量图片上传事件
|
||||||
|
const batchImageUpload = document.getElementById('batchImageUpload');
|
||||||
|
const batchImageFiles = document.getElementById('batchImageFiles');
|
||||||
|
|
||||||
|
// 拖拽上传
|
||||||
|
batchImageUpload.addEventListener('dragover', function(e) {
|
||||||
|
e.preventDefault();
|
||||||
|
this.classList.add('dragover');
|
||||||
|
});
|
||||||
|
|
||||||
|
batchImageUpload.addEventListener('dragleave', function(e) {
|
||||||
|
e.preventDefault();
|
||||||
|
this.classList.remove('dragover');
|
||||||
|
});
|
||||||
|
|
||||||
|
batchImageUpload.addEventListener('drop', function(e) {
|
||||||
|
e.preventDefault();
|
||||||
|
this.classList.remove('dragover');
|
||||||
|
const files = Array.from(e.dataTransfer.files).filter(file => file.type.startsWith('image/'));
|
||||||
|
if (files.length > 0) {
|
||||||
|
uploadBatchImages(files);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
batchImageFiles.addEventListener('change', function(e) {
|
||||||
|
if (e.target.files.length > 0) {
|
||||||
|
uploadBatchImages(Array.from(e.target.files));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 批量文本上传
|
||||||
|
document.getElementById('uploadTextsBtn').addEventListener('click', function() {
|
||||||
|
const textData = document.getElementById('batchTextInput').value.trim();
|
||||||
|
if (textData) {
|
||||||
|
const texts = textData.split('\n').filter(line => line.trim());
|
||||||
|
if (texts.length > 0) {
|
||||||
|
uploadBatchTexts(texts);
|
||||||
|
} else {
|
||||||
|
showAlert('warning', '请输入有效的文本数据');
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
showAlert('warning', '请输入文本数据');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 从文件导入文本
|
||||||
|
document.getElementById('textFile').addEventListener('change', function(e) {
|
||||||
|
const file = e.target.files[0];
|
||||||
|
if (file) {
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = function(e) {
|
||||||
|
document.getElementById('batchTextInput').value = e.target.result;
|
||||||
|
};
|
||||||
|
reader.readAsText(file);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 移除构建索引按钮的事件监听器
|
||||||
|
|
||||||
|
// 查看数据
|
||||||
|
document.getElementById('viewDataBtn').addEventListener('click', viewData);
|
||||||
|
|
||||||
|
// 清空数据
|
||||||
|
document.getElementById('clearDataBtn').addEventListener('click', clearData);
|
||||||
|
|
||||||
|
// 初始化时更新数据统计
|
||||||
|
updateDataStats();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 批量上传图片
|
||||||
|
async function uploadBatchImages(files) {
|
||||||
|
try {
|
||||||
|
const progressDiv = document.getElementById('imageUploadProgress');
|
||||||
|
const progressBar = progressDiv.querySelector('.progress-bar');
|
||||||
|
const progressText = document.getElementById('imageProgressText');
|
||||||
|
|
||||||
|
progressDiv.style.display = 'block';
|
||||||
|
progressText.textContent = `0/${files.length}`;
|
||||||
|
progressBar.style.width = '0%';
|
||||||
|
|
||||||
|
showAlert('info', `正在上传${files.length}张图片...`);
|
||||||
|
|
||||||
|
let successCount = 0;
|
||||||
|
|
||||||
|
for (let i = 0; i < files.length; i++) {
|
||||||
|
const formData = new FormData();
|
||||||
|
formData.append('image', files[i]);
|
||||||
|
|
||||||
|
const response = await fetch('/api/add_image', {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
if (data.success) {
|
||||||
|
successCount++;
|
||||||
|
} else {
|
||||||
|
console.error(`图片 ${files[i].name} 上传失败: ${data.error}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新进度
|
||||||
|
const progress = Math.round(((i + 1) / files.length) * 100);
|
||||||
|
progressBar.style.width = `${progress}%`;
|
||||||
|
progressText.textContent = `${i + 1}/${files.length}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
showAlert('success', `成功上传 ${successCount}/${files.length} 张图片`);
|
||||||
|
// 自动保存索引
|
||||||
|
await autoSaveIndex();
|
||||||
|
updateDataStats();
|
||||||
|
} catch (error) {
|
||||||
|
showAlert('danger', `图片上传失败: ${error.message}`);
|
||||||
|
} finally {
|
||||||
|
setTimeout(() => {
|
||||||
|
document.getElementById('imageUploadProgress').style.display = 'none';
|
||||||
|
}, 2000);
|
||||||
|
}
|
||||||
|
// 旧代码已删除
|
||||||
|
// 旧代码已删除
|
||||||
|
}
|
||||||
|
|
||||||
|
// 批量上传文本
|
||||||
|
async function uploadBatchTexts(texts) {
|
||||||
|
try {
|
||||||
|
showAlert('info', `正在上传${texts.length}条文本...`);
|
||||||
|
|
||||||
|
for (let i = 0; i < texts.length; i++) {
|
||||||
|
const response = await fetch('/api/add_text', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {'Content-Type': 'application/json'},
|
||||||
|
body: JSON.stringify({text: texts[i]})
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
if (!data.success) {
|
||||||
|
throw new Error(`第${i+1}条文本上传失败: ${data.error}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
showAlert('success', `成功上传${texts.length}条文本`);
|
||||||
|
// 自动保存索引
|
||||||
|
await autoSaveIndex();
|
||||||
|
updateDataStats();
|
||||||
|
} catch (error) {
|
||||||
|
showAlert('danger', `文本上传失败: ${error.message}`);
|
||||||
|
}
|
||||||
|
// 已替换为新的API调用
|
||||||
|
// 旧代码已删除
|
||||||
|
// 已删除
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自动保存索引函数
|
||||||
|
async function autoSaveIndex() {
|
||||||
|
try {
|
||||||
|
const response = await fetch('/api/save_index', {
|
||||||
|
method: 'POST'
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
console.log('索引自动保存成功');
|
||||||
|
} else {
|
||||||
|
console.error(`索引自动保存失败: ${data.message}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`索引自动保存错误: ${error.message}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查看数据
|
||||||
|
async function viewData() {
|
||||||
|
try {
|
||||||
|
const response = await fetch('/api/list_items');
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
let content = '<div class="row">';
|
||||||
|
|
||||||
|
// 显示图片数据
|
||||||
|
if (data.items && data.items.filter(item => item.type === 'image').length > 0) {
|
||||||
|
const imageItems = data.items.filter(item => item.type === 'image');
|
||||||
|
content += '<div class="col-md-6"><h6>图片数据 (' + imageItems.length + ')</h6>';
|
||||||
|
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
|
||||||
|
imageItems.forEach(item => {
|
||||||
|
content += `<div class="list-group-item d-flex justify-content-between align-items-center">
|
||||||
|
<span>${item.id}: ${item.metadata?.title || '无标题'}</span>
|
||||||
|
<img src="/temp/${item.filename || item.id}" class="img-thumbnail" style="width: 50px; height: 50px; object-fit: cover;">
|
||||||
|
</div>`;
|
||||||
|
});
|
||||||
|
content += '</div></div>';
|
||||||
|
}
|
||||||
|
|
||||||
|
// 显示文本数据
|
||||||
|
if (data.items && data.items.filter(item => item.type === 'text').length > 0) {
|
||||||
|
const textItems = data.items.filter(item => item.type === 'text');
|
||||||
|
content += '<div class="col-md-6"><h6>文本数据 (' + textItems.length + ')</h6>';
|
||||||
|
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
|
||||||
|
textItems.forEach((item, index) => {
|
||||||
|
const text = item.content || item.text || '';
|
||||||
|
const shortText = text.length > 50 ? text.substring(0, 50) + '...' : text;
|
||||||
|
content += `<div class="list-group-item">
|
||||||
|
<small class="text-muted">#${item.id}</small><br>
|
||||||
|
${shortText}
|
||||||
|
</div>`;
|
||||||
|
});
|
||||||
|
content += '</div></div>';
|
||||||
|
}
|
||||||
|
|
||||||
|
content += '</div>';
|
||||||
|
|
||||||
|
showModal('数据列表', content);
|
||||||
|
} else {
|
||||||
|
showAlert('danger', `获取数据失败: ${data.message}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
showAlert('danger', `获取数据错误: ${error.message}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清空数据
|
||||||
|
async function clearData() {
|
||||||
|
if (!confirm('确定要清空所有数据吗?此操作不可恢复!')) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('/api/clear_index', {
|
||||||
|
method: 'POST'
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
showAlert('success', '数据已清空');
|
||||||
|
updateDataStats();
|
||||||
|
// 移除构建索引按钮的引用
|
||||||
|
} else {
|
||||||
|
showAlert('danger', `清空失败: ${data.message}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
showAlert('danger', `清空错误: ${error.message}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新数据统计
|
||||||
|
async function updateDataStats() {
|
||||||
|
try {
|
||||||
|
const response = await fetch('/api/system_info');
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.success) {
|
||||||
|
const retrieval_info = data.retrieval_info || {};
|
||||||
|
document.getElementById('imageCount').textContent = retrieval_info.image_count || 0;
|
||||||
|
document.getElementById('textCount').textContent = retrieval_info.text_count || 0;
|
||||||
|
// 移除构建索引按钮的引用
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.log('获取数据统计失败:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 显示模态框
|
||||||
|
function showModal(title, content) {
|
||||||
|
const modalHtml = `
|
||||||
|
<div class="modal fade" id="dataModal" tabindex="-1">
|
||||||
|
<div class="modal-dialog modal-lg">
|
||||||
|
<div class="modal-content">
|
||||||
|
<div class="modal-header">
|
||||||
|
<h5 class="modal-title">${title}</h5>
|
||||||
|
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
|
||||||
|
</div>
|
||||||
|
<div class="modal-body">
|
||||||
|
${content}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
`;
|
||||||
|
|
||||||
|
// 移除已存在的模态框
|
||||||
|
const existingModal = document.getElementById('dataModal');
|
||||||
|
if (existingModal) {
|
||||||
|
existingModal.remove();
|
||||||
|
}
|
||||||
|
|
||||||
|
document.body.insertAdjacentHTML('beforeend', modalHtml);
|
||||||
|
const modal = new bootstrap.Modal(document.getElementById('dataModal'));
|
||||||
|
modal.show();
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
0
test_faiss_local.log
Normal file
0
test_faiss_local.log
Normal file
58
test_faiss_simple.py
Normal file
58
test_faiss_simple.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
FAISS多模态检索系统简单测试
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from multimodal_retrieval_faiss import MultimodalRetrievalFAISS
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def test_text_retrieval():
|
||||||
|
print("=== 测试文本检索 ===")
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
print("初始化检索系统...")
|
||||||
|
retrieval = MultimodalRetrievalFAISS(
|
||||||
|
model_name="OpenSearch-AI/Ops-MM-embedding-v1-7B",
|
||||||
|
use_all_gpus=True,
|
||||||
|
index_path="faiss_index_test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试文本
|
||||||
|
texts = [
|
||||||
|
"一只可爱的橘色猫咪在沙发上睡觉",
|
||||||
|
"城市夜景中的高楼大厦和车流",
|
||||||
|
"阳光明媚的海滩上,人们在冲浪和晒太阳",
|
||||||
|
"美味的意大利面配红酒和沙拉",
|
||||||
|
"雪山上滑雪的运动员"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 添加文本
|
||||||
|
print("\n添加文本到检索系统...")
|
||||||
|
text_ids = retrieval.add_texts(texts)
|
||||||
|
print(f"添加了{len(text_ids)}条文本")
|
||||||
|
print(f"当前向量数量: {retrieval.get_vector_count()}")
|
||||||
|
|
||||||
|
# 测试文本搜索
|
||||||
|
print("\n测试文本搜索...")
|
||||||
|
queries = ["一只猫在睡觉", "都市风光", "海边的景色"]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print(f"\n查询: {query}")
|
||||||
|
results = retrieval.search_by_text(query, k=2)
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
||||||
|
|
||||||
|
# 保存索引
|
||||||
|
print("\n保存索引...")
|
||||||
|
retrieval.save_index()
|
||||||
|
|
||||||
|
print("\n测试完成!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_text_retrieval()
|
||||||
164
test_faiss_with_proxy.py
Normal file
164
test_faiss_with_proxy.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
FAISS多模态检索系统简单测试 - 带代理设置
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 设置代理
|
||||||
|
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890' # 根据实际情况修改
|
||||||
|
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890' # 根据实际情况修改
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 设置离线模式,避免下载模型
|
||||||
|
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
||||||
|
|
||||||
|
# 添加当前目录到路径
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
# 使用简单的向量模型替代大型多模态模型
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class SimpleFaissRetrieval:
|
||||||
|
"""简化版FAISS检索系统,使用sentence-transformers"""
|
||||||
|
|
||||||
|
def __init__(self, model_name="paraphrase-multilingual-MiniLM-L12-v2", index_path="simple_faiss_index"):
|
||||||
|
"""
|
||||||
|
初始化简化版检索系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: 模型名称,使用轻量级模型
|
||||||
|
index_path: 索引文件路径
|
||||||
|
"""
|
||||||
|
self.model_name = model_name
|
||||||
|
self.index_path = index_path
|
||||||
|
|
||||||
|
logger.info(f"加载模型: {model_name}")
|
||||||
|
try:
|
||||||
|
# 尝试加载模型
|
||||||
|
self.model = SentenceTransformer(model_name)
|
||||||
|
self.dimension = self.model.get_sentence_embedding_dimension()
|
||||||
|
logger.info(f"模型加载成功,向量维度: {self.dimension}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模型加载失败: {str(e)}")
|
||||||
|
logger.info("使用随机向量模拟...")
|
||||||
|
self.model = None
|
||||||
|
self.dimension = 384 # 默认维度
|
||||||
|
|
||||||
|
# 初始化索引
|
||||||
|
self.index = faiss.IndexFlatL2(self.dimension)
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
|
logger.info("检索系统初始化完成")
|
||||||
|
|
||||||
|
def encode_text(self, text):
|
||||||
|
"""编码文本为向量"""
|
||||||
|
if self.model is None:
|
||||||
|
# 如果模型加载失败,使用随机向量
|
||||||
|
if isinstance(text, list):
|
||||||
|
vectors = np.random.rand(len(text), self.dimension).astype('float32')
|
||||||
|
return vectors
|
||||||
|
else:
|
||||||
|
return np.random.rand(self.dimension).astype('float32')
|
||||||
|
else:
|
||||||
|
# 使用模型编码
|
||||||
|
return self.model.encode(text, convert_to_numpy=True)
|
||||||
|
|
||||||
|
def add_texts(self, texts, metadatas=None):
|
||||||
|
"""添加文本到索引"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if metadatas is None:
|
||||||
|
metadatas = [{} for _ in range(len(texts))]
|
||||||
|
|
||||||
|
# 编码文本
|
||||||
|
vectors = self.encode_text(texts)
|
||||||
|
|
||||||
|
# 添加到索引
|
||||||
|
start_id = len(self.metadata)
|
||||||
|
ids = list(range(start_id, start_id + len(texts)))
|
||||||
|
|
||||||
|
self.index.add(np.array(vectors).astype('float32'))
|
||||||
|
|
||||||
|
# 保存元数据
|
||||||
|
for i, id in enumerate(ids):
|
||||||
|
self.metadata[str(id)] = {
|
||||||
|
"text": texts[i],
|
||||||
|
"type": "text",
|
||||||
|
**metadatas[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"添加了{len(ids)}条文本,当前索引大小: {self.index.ntotal}")
|
||||||
|
return [str(id) for id in ids]
|
||||||
|
|
||||||
|
def search(self, query, k=5):
|
||||||
|
"""搜索相似文本"""
|
||||||
|
# 编码查询
|
||||||
|
query_vector = self.encode_text(query)
|
||||||
|
if len(query_vector.shape) == 1:
|
||||||
|
query_vector = query_vector.reshape(1, -1)
|
||||||
|
|
||||||
|
# 搜索
|
||||||
|
distances, indices = self.index.search(query_vector.astype('float32'), k)
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
results = []
|
||||||
|
for i in range(len(indices[0])):
|
||||||
|
idx = indices[0][i]
|
||||||
|
if idx < 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
vector_id = str(idx)
|
||||||
|
if vector_id in self.metadata:
|
||||||
|
result = self.metadata[vector_id].copy()
|
||||||
|
result['score'] = float(1.0 / (1.0 + distances[0][i]))
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def test_simple_retrieval():
|
||||||
|
"""测试简化版检索系统"""
|
||||||
|
print("=== 测试简化版FAISS检索系统 ===")
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
print("初始化检索系统...")
|
||||||
|
retrieval = SimpleFaissRetrieval()
|
||||||
|
|
||||||
|
# 测试文本
|
||||||
|
texts = [
|
||||||
|
"一只可爱的橘色猫咪在沙发上睡觉",
|
||||||
|
"城市夜景中的高楼大厦和车流",
|
||||||
|
"阳光明媚的海滩上,人们在冲浪和晒太阳",
|
||||||
|
"美味的意大利面配红酒和沙拉",
|
||||||
|
"雪山上滑雪的运动员"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 添加文本
|
||||||
|
print("\n添加文本到检索系统...")
|
||||||
|
text_ids = retrieval.add_texts(texts)
|
||||||
|
print(f"添加了{len(text_ids)}条文本")
|
||||||
|
|
||||||
|
# 测试文本搜索
|
||||||
|
print("\n测试文本搜索...")
|
||||||
|
queries = ["一只猫在睡觉", "都市风光", "海边的景色"]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print(f"\n查询: {query}")
|
||||||
|
results = retrieval.search(query, k=2)
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
||||||
|
|
||||||
|
print("\n测试完成!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_simple_retrieval()
|
||||||
79
test_fixes.py
Normal file
79
test_fixes.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
测试修复后的系统功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
|
||||||
|
def test_system():
|
||||||
|
"""测试系统功能"""
|
||||||
|
base_url = "http://localhost:5000"
|
||||||
|
|
||||||
|
print("🧪 开始测试修复后的系统...")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# 测试1: 检查系统状态
|
||||||
|
print("1. 测试系统状态...")
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{base_url}/api/status", timeout=10)
|
||||||
|
if response.status_code == 200:
|
||||||
|
status = response.json()
|
||||||
|
print(f" ✅ 系统状态: {status}")
|
||||||
|
else:
|
||||||
|
print(f" ❌ 状态检查失败: {response.status_code}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ 状态检查异常: {e}")
|
||||||
|
|
||||||
|
# 测试2: 检查数据统计
|
||||||
|
print("\n2. 测试数据统计...")
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{base_url}/api/data/stats", timeout=10)
|
||||||
|
if response.status_code == 200:
|
||||||
|
stats = response.json()
|
||||||
|
print(f" ✅ 数据统计: {stats}")
|
||||||
|
else:
|
||||||
|
print(f" ❌ 统计检查失败: {response.status_code}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ 统计检查异常: {e}")
|
||||||
|
|
||||||
|
# 测试3: 检查数据列表
|
||||||
|
print("\n3. 测试数据列表...")
|
||||||
|
try:
|
||||||
|
response = requests.get(f"{base_url}/api/data/list", timeout=10)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data_list = response.json()
|
||||||
|
print(f" ✅ 数据列表: {data_list}")
|
||||||
|
else:
|
||||||
|
print(f" ❌ 列表检查失败: {response.status_code}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ 列表检查异常: {e}")
|
||||||
|
|
||||||
|
# 测试4: 测试文本搜索(如果系统已初始化)
|
||||||
|
print("\n4. 测试文本搜索...")
|
||||||
|
try:
|
||||||
|
search_data = {
|
||||||
|
"query": "测试查询",
|
||||||
|
"top_k": 3
|
||||||
|
}
|
||||||
|
response = requests.post(f"{base_url}/api/search/text_to_text",
|
||||||
|
json=search_data, timeout=10)
|
||||||
|
if response.status_code == 200:
|
||||||
|
result = response.json()
|
||||||
|
print(f" ✅ 文本搜索: {result}")
|
||||||
|
else:
|
||||||
|
print(f" ❌ 文本搜索失败: {response.status_code}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ 文本搜索异常: {e}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("🎉 测试完成!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 等待系统启动
|
||||||
|
print("⏳ 等待系统启动...")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
test_system()
|
||||||
98
test_local_model.py
Normal file
98
test_local_model.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
使用本地模型的FAISS多模态检索系统测试
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
import faiss
|
||||||
|
from typing import List, Dict, Any, Optional, Union
|
||||||
|
import json
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 设置离线模式
|
||||||
|
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
||||||
|
|
||||||
|
def test_local_model():
|
||||||
|
"""测试本地模型加载"""
|
||||||
|
from transformers import AutoModel, AutoTokenizer, AutoProcessor
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# 这里替换为您实际下载的模型路径
|
||||||
|
local_model_path = "/root/models/Ops-MM-embedding-v1-7B"
|
||||||
|
|
||||||
|
if not os.path.exists(local_model_path):
|
||||||
|
logger.error(f"模型路径不存在: {local_model_path}")
|
||||||
|
logger.info("请先下载模型到指定路径")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"加载本地模型: {local_model_path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载tokenizer
|
||||||
|
logger.info("加载tokenizer...")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(local_model_path)
|
||||||
|
|
||||||
|
# 加载processor
|
||||||
|
logger.info("加载processor...")
|
||||||
|
processor = AutoProcessor.from_pretrained(local_model_path)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
logger.info("加载模型...")
|
||||||
|
model = AutoModel.from_pretrained(
|
||||||
|
local_model_path,
|
||||||
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||||
|
device_map="auto" if torch.cuda.device_count() > 0 else None
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("模型加载成功!")
|
||||||
|
|
||||||
|
# 测试文本编码
|
||||||
|
logger.info("测试文本编码...")
|
||||||
|
text = "这是一个测试文本"
|
||||||
|
inputs = tokenizer(text, return_tensors="pt")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
text_embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
|
||||||
|
|
||||||
|
logger.info(f"文本编码维度: {text_embedding.shape}")
|
||||||
|
|
||||||
|
# 如果有图像处理功能,测试图像编码
|
||||||
|
try:
|
||||||
|
logger.info("测试图像编码...")
|
||||||
|
# 创建一个简单的测试图像
|
||||||
|
image = Image.new('RGB', (224, 224), color='red')
|
||||||
|
image_inputs = processor(images=image, return_tensors="pt")
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
image_inputs = {k: v.to("cuda") for k, v in image_inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
image_outputs = model.vision_model(**image_inputs)
|
||||||
|
image_embedding = image_outputs.pooler_output.cpu().numpy()
|
||||||
|
|
||||||
|
logger.info(f"图像编码维度: {image_embedding.shape}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"图像编码测试失败: {str(e)}")
|
||||||
|
|
||||||
|
logger.info("本地模型测试完成!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模型加载失败: {str(e)}")
|
||||||
|
logger.error("请确保模型文件已正确下载")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_local_model()
|
||||||
229
test_local_retrieval.py
Normal file
229
test_local_retrieval.py
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
测试本地模型和FAISS向量数据库的多模态检索系统
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
import time
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
from multimodal_retrieval_local import MultimodalRetrievalLocal
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 设置离线模式
|
||||||
|
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
||||||
|
|
||||||
|
def test_text_retrieval():
|
||||||
|
"""测试文本检索功能"""
|
||||||
|
print("\n=== 测试文本检索 ===")
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
print("初始化检索系统...")
|
||||||
|
retrieval = MultimodalRetrievalLocal(
|
||||||
|
model_path="/root/models/Ops-MM-embedding-v1-7B",
|
||||||
|
use_all_gpus=True,
|
||||||
|
index_path="local_faiss_text_test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 测试文本
|
||||||
|
texts = [
|
||||||
|
"一只可爱的橘色猫咪在沙发上睡觉",
|
||||||
|
"城市夜景中的高楼大厦和车流",
|
||||||
|
"阳光明媚的海滩上,人们在冲浪和晒太阳",
|
||||||
|
"美味的意大利面配红酒和沙拉",
|
||||||
|
"雪山上滑雪的运动员"
|
||||||
|
]
|
||||||
|
|
||||||
|
# 添加文本
|
||||||
|
print("\n添加文本到检索系统...")
|
||||||
|
text_ids = retrieval.add_texts(texts)
|
||||||
|
print(f"添加了{len(text_ids)}条文本")
|
||||||
|
|
||||||
|
# 获取统计信息
|
||||||
|
stats = retrieval.get_stats()
|
||||||
|
print(f"检索系统统计信息: {stats}")
|
||||||
|
|
||||||
|
# 测试文本搜索
|
||||||
|
print("\n测试文本搜索...")
|
||||||
|
queries = ["一只猫在睡觉", "都市风光", "海边的景色"]
|
||||||
|
|
||||||
|
for query in queries:
|
||||||
|
print(f"\n查询: {query}")
|
||||||
|
results = retrieval.search_by_text(query, k=2)
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
||||||
|
|
||||||
|
# 保存索引
|
||||||
|
print("\n保存索引...")
|
||||||
|
retrieval.save_index()
|
||||||
|
|
||||||
|
print("\n文本检索测试完成!")
|
||||||
|
return retrieval
|
||||||
|
|
||||||
|
def test_image_retrieval():
|
||||||
|
"""测试图像检索功能"""
|
||||||
|
print("\n=== 测试图像检索 ===")
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
print("初始化检索系统...")
|
||||||
|
retrieval = MultimodalRetrievalLocal(
|
||||||
|
model_path="/root/models/Ops-MM-embedding-v1-7B",
|
||||||
|
use_all_gpus=True,
|
||||||
|
index_path="local_faiss_image_test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建测试图像
|
||||||
|
print("\n创建测试图像...")
|
||||||
|
images = []
|
||||||
|
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
|
||||||
|
image_paths = []
|
||||||
|
|
||||||
|
for i, color in enumerate(colors):
|
||||||
|
img = Image.new('RGB', (224, 224), color=color)
|
||||||
|
images.append(img)
|
||||||
|
|
||||||
|
# 保存图像
|
||||||
|
img_path = f"/tmp/test_image_{i}.png"
|
||||||
|
img.save(img_path)
|
||||||
|
image_paths.append(img_path)
|
||||||
|
print(f"创建图像: {img_path}")
|
||||||
|
|
||||||
|
# 添加图像
|
||||||
|
print("\n添加图像到检索系统...")
|
||||||
|
metadatas = [{"description": f"测试图像 {i+1}"} for i in range(len(images))]
|
||||||
|
image_ids = retrieval.add_images(images, metadatas, image_paths)
|
||||||
|
print(f"添加了{len(image_ids)}张图像")
|
||||||
|
|
||||||
|
# 获取统计信息
|
||||||
|
stats = retrieval.get_stats()
|
||||||
|
print(f"检索系统统计信息: {stats}")
|
||||||
|
|
||||||
|
# 测试图像搜索
|
||||||
|
print("\n测试图像搜索...")
|
||||||
|
query_image = Image.new('RGB', (224, 224), color=(255, 0, 0)) # 红色图像
|
||||||
|
|
||||||
|
print("\n使用图像查询图像:")
|
||||||
|
results = retrieval.search_by_image(query_image, k=2, filter_type="image")
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
print(f" 结果 {i+1}: {result.get('description', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
||||||
|
|
||||||
|
# 保存索引
|
||||||
|
print("\n保存索引...")
|
||||||
|
retrieval.save_index()
|
||||||
|
|
||||||
|
print("\n图像检索测试完成!")
|
||||||
|
return retrieval
|
||||||
|
|
||||||
|
def test_cross_modal_retrieval():
|
||||||
|
"""测试跨模态检索功能"""
|
||||||
|
print("\n=== 测试跨模态检索 ===")
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
print("初始化检索系统...")
|
||||||
|
retrieval = MultimodalRetrievalLocal(
|
||||||
|
model_path="/root/models/Ops-MM-embedding-v1-7B",
|
||||||
|
use_all_gpus=True,
|
||||||
|
index_path="local_faiss_cross_modal_test"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加文本
|
||||||
|
texts = [
|
||||||
|
"一只红色的苹果",
|
||||||
|
"绿色的草地",
|
||||||
|
"蓝色的大海",
|
||||||
|
"黄色的向日葵",
|
||||||
|
"青色的天空"
|
||||||
|
]
|
||||||
|
print("\n添加文本到检索系统...")
|
||||||
|
text_ids = retrieval.add_texts(texts)
|
||||||
|
print(f"添加了{len(text_ids)}条文本")
|
||||||
|
|
||||||
|
# 添加图像
|
||||||
|
print("\n添加图像到检索系统...")
|
||||||
|
images = []
|
||||||
|
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255)]
|
||||||
|
descriptions = ["红色图像", "绿色图像", "蓝色图像", "黄色图像", "青色图像"]
|
||||||
|
|
||||||
|
for i, color in enumerate(colors):
|
||||||
|
img = Image.new('RGB', (224, 224), color=color)
|
||||||
|
images.append(img)
|
||||||
|
|
||||||
|
metadatas = [{"description": desc} for desc in descriptions]
|
||||||
|
image_ids = retrieval.add_images(images, metadatas)
|
||||||
|
print(f"添加了{len(image_ids)}张图像")
|
||||||
|
|
||||||
|
# 获取统计信息
|
||||||
|
stats = retrieval.get_stats()
|
||||||
|
print(f"检索系统统计信息: {stats}")
|
||||||
|
|
||||||
|
# 测试文搜图
|
||||||
|
print("\n测试文搜图...")
|
||||||
|
query_text = "红色"
|
||||||
|
print(f"查询文本: {query_text}")
|
||||||
|
results = retrieval.search_by_text(query_text, k=2, filter_type="image")
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
print(f" 结果 {i+1}: {result.get('description', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
||||||
|
|
||||||
|
# 测试图搜文
|
||||||
|
print("\n测试图搜文...")
|
||||||
|
query_image = Image.new('RGB', (224, 224), color=(0, 0, 255)) # 蓝色图像
|
||||||
|
print("查询图像: 蓝色图像")
|
||||||
|
results = retrieval.search_by_image(query_image, k=2, filter_type="text")
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
print(f" 结果 {i+1}: {result.get('text', 'N/A')} (分数: {result.get('score', 0):.4f})")
|
||||||
|
|
||||||
|
# 保存索引
|
||||||
|
print("\n保存索引...")
|
||||||
|
retrieval.save_index()
|
||||||
|
|
||||||
|
print("\n跨模态检索测试完成!")
|
||||||
|
return retrieval
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
print("=== 本地多模态检索系统测试 ===")
|
||||||
|
|
||||||
|
# 检查模型路径
|
||||||
|
model_path = "/root/models/Ops-MM-embedding-v1-7B"
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
print(f"错误: 模型路径不存在: {model_path}")
|
||||||
|
print("请先下载模型到指定路径")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 检查模型文件
|
||||||
|
config_file = os.path.join(model_path, "config.json")
|
||||||
|
if not os.path.exists(config_file):
|
||||||
|
print(f"错误: 模型配置文件不存在: {config_file}")
|
||||||
|
print("请确保模型文件已正确下载")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"模型路径验证成功: {model_path}")
|
||||||
|
|
||||||
|
# 运行测试
|
||||||
|
try:
|
||||||
|
# 测试文本检索
|
||||||
|
test_text_retrieval()
|
||||||
|
|
||||||
|
# 测试图像检索
|
||||||
|
test_image_retrieval()
|
||||||
|
|
||||||
|
# 测试跨模态检索
|
||||||
|
test_cross_modal_retrieval()
|
||||||
|
|
||||||
|
print("\n所有测试完成!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"测试过程中发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
63
web_app.log
Normal file
63
web_app.log
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
nohup: ignoring input
|
||||||
|
INFO:__main__:🚀 å<>¯åŠ¨æ—¶è‡ªåŠ¨åˆ<C3A5>始化VDB多模æ€<C3A6>检索系统...
|
||||||
|
INFO:multimodal_retrieval_vdb:检测到 2 个GPU
|
||||||
|
INFO:multimodal_retrieval_vdb:使用GPU: [0, 1], 主设备: cuda:0
|
||||||
|
INFO:multimodal_retrieval_vdb:GPU内å˜å·²æ¸…ç<E280A6>†
|
||||||
|
INFO:multimodal_retrieval_vdb:æ£åœ¨åŠ è½½æ¨¡åž‹åˆ°GPU: [0, 1]
|
||||||
|
INFO:multimodal_retrieval_vdb:GPU内å˜å·²æ¸…ç<E280A6>†
|
||||||
|
🚀 å<>¯åЍVDB多模æ€<C3A6>检索Web应用
|
||||||
|
============================================================
|
||||||
|
访问地å<EFBFBD>€: http://localhost:5000
|
||||||
|
新功能:
|
||||||
|
🗄ï¸<C3AF> 百度VDB - å<>‘é‡<C3A9>æ•°æ<C2B0>®åº“å˜å‚¨
|
||||||
|
📊 实时统计 - VDBæ•°æ<C2B0>®ç»Ÿè®¡ä¿¡æ<C2A1>¯
|
||||||
|
🔄 æ•°æ<C2B0>®å<C2AE>Œæ¥ - 本地文件到VDBå˜å‚¨
|
||||||
|
支æŒ<EFBFBD>功能:
|
||||||
|
ðŸ“<C5B8> æ–‡æ<E280A1>œæ–‡ - 文本查找相似文本
|
||||||
|
🖼ï¸<C3AF> æ–‡æ<E280A1>œå›¾ - 文本查找相关图片
|
||||||
|
ðŸ“<C5B8> 图æ<C2BE>œæ–‡ - 图片查找相关文本
|
||||||
|
🖼ï¸<C3AF> 图æ<C2BE>œå›¾ - 图片查找相似图片
|
||||||
|
📤 批é‡<C3A9>ä¸Šä¼ - 图片和文本数æ<C2B0>®ç®¡ç<C2A1>†
|
||||||
|
GPUé…<EFBFBD>ç½®:
|
||||||
|
🖥ï¸<C3AF> 检测到 2 个GPU
|
||||||
|
GPU 0: NVIDIA GeForce RTX 4090 (23.6GB)
|
||||||
|
GPU 1: NVIDIA GeForce RTX 4090 (23.6GB)
|
||||||
|
VDBé…<EFBFBD>ç½®:
|
||||||
|
ðŸŒ<C5B8> æœ<C3A6>务器: http://180.76.96.191:5287
|
||||||
|
👤 用户: root
|
||||||
|
🗄ï¸<C3AF> æ•°æ<C2B0>®åº“: multimodal_retrieval
|
||||||
|
============================================================
|
||||||
|
Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 25%|██▌ | 1/4 [03:18<09:56, 198.90s/it]
Loading checkpoint shards: 25%|██▌ | 1/4 [03:25<10:15, 205.19s/it]
|
||||||
|
WARNING:multimodal_retrieval_vdb:ç½‘ç»œåŠ è½½å¤±è´¥ï¼Œå°<C3A5>试本地缓å˜: CUDA out of memory. Tried to allocate 130.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 117.00 MiB is free. Process 3982470 has 384.00 MiB memory in use. Process 729183 has 2.64 GiB memory in use. Process 726298 has 7.43 GiB memory in use. Process 726164 WARNING:multimodal_retrieval_vdb:Tokenizerç½‘ç»œåŠ è½½å¤±è´¥ï¼Œå°<C3A5>试本地: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/OpenSearch-AI/Ops-MM-embedding-v1-7B/tree/main/additional_chat_templates?recursive=False&expand=False (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f92386b4280>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 103ac836-6599-4fe2-a569-aed9c945525c)')
|
||||||
|
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
|
||||||
|
WARNING:multimodal_retrieval_vdb:ProcessoråŠ è½½å¤±è´¥ï¼Œä½¿ç”¨tokenizer: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/OpenSearch-AI/Ops-MM-embedding-v1-7B/tree/main/additional_chat_templates?recursive=False&expand=False (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7fbad64d1510>: Failed to establish a new connection: [Errno 101] Network is unreachable'))"), '(Request ID: 96f18121-7beb-4e1a-87cd-c50edf682933)')
|
||||||
|
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.
|
||||||
|
INFO:multimodal_retrieval_vdb:æ¨¡åž‹åŠ è½½å®Œæˆ<C3A6>
|
||||||
|
INFO:baidu_vdb_backend:✅ æˆ<C3A6>功连接到百度VDB: http://180.76.96.191:5287
|
||||||
|
INFO:baidu_vdb_backend:使用现有数æ<C2B0>®åº“: multimodal_retrieval
|
||||||
|
INFO:baidu_vdb_backend:创建文本å<C2AC>‘é‡<C3A9>表: text_vectors
|
||||||
|
ERROR:baidu_vdb_backend:â<>Œ 创建文本表失败: Database.create_table() missing 1 required positional argument: 'partition'
|
||||||
|
ERROR:baidu_vdb_backend:â<>Œ 表æ“<C3A6>作失败: Database.create_table() missing 1 required positional argument: 'partition'
|
||||||
|
ERROR:multimodal_retrieval_vdb:â<>Œ VDBå<42>Žç«¯åˆ<C3A5>始化失败: Database.create_table() missing 1 required positional argument: 'partition'
|
||||||
|
WARNING:multimodal_retrieval_vdb:âš ï¸<C3AF> ç³»ç»Ÿå°†åœ¨æ— VDB模å¼<C3A5>下è¿<C3A8>行,数æ<C2B0>®å°†ä¸<C3A4>会æŒ<C3A6>久化
|
||||||
|
INFO:multimodal_retrieval_vdb:多模æ€<C3A6>检索系统åˆ<C3A5>始化完æˆ<C3A6>
|
||||||
|
ERROR:__main__:â<>Œ VDB系统自动åˆ<C3A5>始化失败: VDB连接失败
|
||||||
|
ERROR:__main__:Traceback (most recent call last):
|
||||||
|
File "/root/mmeb/web_app_vdb.py", line 667, in auto_initialize
|
||||||
|
raise Exception("VDB连接失败")
|
||||||
|
Exception: VDB连接失败
|
||||||
|
|
||||||
|
* Serving Flask app 'web_app_vdb'
|
||||||
|
* Debug mode: off
|
||||||
|
Address already in use
|
||||||
|
Port 5000 is in use by another program. Either identify and stop that program, or start the server with a different port.
|
||||||
|
¤±è´¥
|
||||||
|
ERROR:__main__:Traceback (most recent call last):
|
||||||
|
File "/root/mmeb/web_app_vdb.py", line 664, in auto_initialize
|
||||||
|
raise Exception("æ¨¡åž‹åŠ è½½å¤±è´¥")
|
||||||
|
Exception: æ¨¡åž‹åŠ è½½å¤±è´¥
|
||||||
|
|
||||||
|
* Serving Flask app 'web_app_vdb'
|
||||||
|
* Debug mode: off
|
||||||
|
Address already in use
|
||||||
|
Port 5000 is in use by another program. Either identify and stop that program, or start the server with a different port.
|
||||||
466
web_app_local.py
Normal file
466
web_app_local.py
Normal file
@ -0,0 +1,466 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
本地多模态检索系统Web应用
|
||||||
|
集成本地模型和FAISS向量数据库
|
||||||
|
支持文搜文、文搜图、图搜文、图搜图四种检索模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from flask import Flask, request, jsonify, render_template, send_from_directory
|
||||||
|
from werkzeug.utils import secure_filename
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# 设置离线模式
|
||||||
|
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
||||||
|
|
||||||
|
# 导入本地模块
|
||||||
|
from multimodal_retrieval_local import MultimodalRetrievalLocal
|
||||||
|
from optimized_file_handler import OptimizedFileHandler
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 创建Flask应用
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
# 配置
|
||||||
|
app.config['UPLOAD_FOLDER'] = '/tmp/mmeb_uploads'
|
||||||
|
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB
|
||||||
|
app.config['MODEL_PATH'] = '/root/models/Ops-MM-embedding-v1-7B'
|
||||||
|
app.config['INDEX_PATH'] = '/root/mmeb/local_faiss_index'
|
||||||
|
app.config['ALLOWED_EXTENSIONS'] = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'}
|
||||||
|
|
||||||
|
# 确保上传目录存在
|
||||||
|
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
||||||
|
|
||||||
|
# 创建临时文件夹
|
||||||
|
if not os.path.exists(app.config['UPLOAD_FOLDER']):
|
||||||
|
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
||||||
|
|
||||||
|
# 创建文件处理器
|
||||||
|
from optimized_file_handler import OptimizedFileHandler
|
||||||
|
file_handler = OptimizedFileHandler(local_storage_dir=app.config['UPLOAD_FOLDER'])
|
||||||
|
|
||||||
|
# 全局变量
|
||||||
|
retrieval_system = None
|
||||||
|
|
||||||
|
def allowed_file(filename):
|
||||||
|
"""检查文件扩展名是否允许"""
|
||||||
|
return '.' in filename and \
|
||||||
|
filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']
|
||||||
|
|
||||||
|
def init_retrieval_system():
|
||||||
|
"""初始化检索系统"""
|
||||||
|
global retrieval_system
|
||||||
|
|
||||||
|
if retrieval_system is not None:
|
||||||
|
return retrieval_system
|
||||||
|
|
||||||
|
logger.info("初始化多模态检索系统...")
|
||||||
|
|
||||||
|
# 检查模型路径
|
||||||
|
model_path = app.config['MODEL_PATH']
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
logger.error(f"模型路径不存在: {model_path}")
|
||||||
|
raise FileNotFoundError(f"模型路径不存在: {model_path}")
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
retrieval_system = MultimodalRetrievalLocal(
|
||||||
|
model_path=model_path,
|
||||||
|
use_all_gpus=True,
|
||||||
|
index_path=app.config['INDEX_PATH']
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("多模态检索系统初始化完成")
|
||||||
|
return retrieval_system
|
||||||
|
|
||||||
|
def get_image_base64(image_path):
|
||||||
|
"""将图像转换为base64编码"""
|
||||||
|
with open(image_path, "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
|
return f"data:image/jpeg;base64,{encoded_string}"
|
||||||
|
|
||||||
|
@app.route('/')
|
||||||
|
def index():
|
||||||
|
"""首页"""
|
||||||
|
return render_template('local_index.html')
|
||||||
|
|
||||||
|
@app.route('/api/stats', methods=['GET'])
|
||||||
|
def get_stats():
|
||||||
|
"""获取系统统计信息"""
|
||||||
|
try:
|
||||||
|
retrieval = init_retrieval_system()
|
||||||
|
stats = retrieval.get_stats()
|
||||||
|
return jsonify({"success": True, "stats": stats})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取统计信息失败: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
|
|
||||||
|
@app.route('/api/add_text', methods=['POST'])
|
||||||
|
def add_text():
|
||||||
|
"""添加文本"""
|
||||||
|
try:
|
||||||
|
data = request.json
|
||||||
|
text = data.get('text')
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return jsonify({"success": False, "error": "文本不能为空"}), 400
|
||||||
|
|
||||||
|
# 使用内存处理文本
|
||||||
|
with file_handler.temp_file_context(text.encode('utf-8'), suffix='.txt') as temp_file:
|
||||||
|
logger.info(f"处理文本: {temp_file}")
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
retrieval = init_retrieval_system()
|
||||||
|
|
||||||
|
# 添加文本
|
||||||
|
metadata = {
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"source": "web_upload"
|
||||||
|
}
|
||||||
|
|
||||||
|
text_ids = retrieval.add_texts([text], [metadata])
|
||||||
|
|
||||||
|
# 保存索引
|
||||||
|
retrieval.save_index()
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"success": True,
|
||||||
|
"message": "文本添加成功",
|
||||||
|
"text_id": text_ids[0] if text_ids else None
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"添加文本失败: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
file_handler.cleanup_all_temp_files()
|
||||||
|
|
||||||
|
@app.route('/api/add_image', methods=['POST'])
|
||||||
|
def add_image():
|
||||||
|
"""添加图像"""
|
||||||
|
try:
|
||||||
|
# 检查是否有文件
|
||||||
|
if 'image' not in request.files:
|
||||||
|
return jsonify({"success": False, "error": "没有上传文件"}), 400
|
||||||
|
|
||||||
|
file = request.files['image']
|
||||||
|
|
||||||
|
# 检查文件名
|
||||||
|
if file.filename == '':
|
||||||
|
return jsonify({"success": False, "error": "没有选择文件"}), 400
|
||||||
|
|
||||||
|
if file and allowed_file(file.filename):
|
||||||
|
# 读取图像数据
|
||||||
|
image_data = file.read()
|
||||||
|
file_size = len(image_data)
|
||||||
|
|
||||||
|
# 使用文件处理器处理图像
|
||||||
|
logger.info(f"处理图像: {file.filename} ({file_size} 字节)")
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
retrieval = init_retrieval_system()
|
||||||
|
|
||||||
|
# 创建临时文件
|
||||||
|
file_obj = BytesIO(image_data)
|
||||||
|
filename = secure_filename(file.filename)
|
||||||
|
|
||||||
|
# 保存到本地文件系统
|
||||||
|
image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
||||||
|
with open(image_path, 'wb') as f:
|
||||||
|
f.write(image_data)
|
||||||
|
|
||||||
|
# 加载图像
|
||||||
|
try:
|
||||||
|
image = Image.open(BytesIO(image_data))
|
||||||
|
# 确保图像是RGB模式
|
||||||
|
if image.mode != 'RGB':
|
||||||
|
logger.info(f"将图像从 {image.mode} 转换为 RGB")
|
||||||
|
image = image.convert('RGB')
|
||||||
|
|
||||||
|
logger.info(f"成功加载图像: {filename}, 格式: {image.format}, 模式: {image.mode}, 大小: {image.size}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载图像失败: {filename}, 错误: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": f"图像格式不支持: {str(e)}"}), 400
|
||||||
|
|
||||||
|
# 添加图像
|
||||||
|
metadata = {
|
||||||
|
"filename": filename,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"source": "web_upload",
|
||||||
|
"size": file_size,
|
||||||
|
"local_path": image_path
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加到检索系统
|
||||||
|
image_ids = retrieval.add_images([image], [metadata], [image_path])
|
||||||
|
|
||||||
|
# 保存索引
|
||||||
|
retrieval.save_index()
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"success": True,
|
||||||
|
"message": "图像添加成功",
|
||||||
|
"image_id": image_ids[0] if image_ids else None
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return jsonify({"success": False, "error": "不支持的文件类型"}), 400
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"添加图像失败: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
file_handler.cleanup_all_temp_files()
|
||||||
|
|
||||||
|
@app.route('/api/search_by_text', methods=['POST'])
|
||||||
|
def search_by_text():
|
||||||
|
"""文本搜索"""
|
||||||
|
try:
|
||||||
|
data = request.json
|
||||||
|
query = data.get('query')
|
||||||
|
k = int(data.get('k', 5))
|
||||||
|
filter_type = data.get('filter_type') # "text", "image" 或 null
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
return jsonify({"success": False, "error": "查询文本不能为空"}), 400
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
retrieval = init_retrieval_system()
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
results = retrieval.search_by_text(query, k, filter_type)
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
processed_results = []
|
||||||
|
for result in results:
|
||||||
|
item = {
|
||||||
|
"score": result.get("score", 0),
|
||||||
|
"type": result.get("type")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.get("type") == "text":
|
||||||
|
item["text"] = result.get("text", "")
|
||||||
|
elif result.get("type") == "image":
|
||||||
|
if "path" in result and os.path.exists(result["path"]):
|
||||||
|
item["image"] = get_image_base64(result["path"])
|
||||||
|
item["filename"] = os.path.basename(result["path"])
|
||||||
|
if "description" in result:
|
||||||
|
item["description"] = result["description"]
|
||||||
|
|
||||||
|
processed_results.append(item)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"success": True,
|
||||||
|
"results": processed_results,
|
||||||
|
"query": query,
|
||||||
|
"filter_type": filter_type
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文本搜索失败: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
|
|
||||||
|
@app.route('/api/search_by_image', methods=['POST'])
|
||||||
|
def search_by_image():
|
||||||
|
"""图像搜索"""
|
||||||
|
try:
|
||||||
|
# 检查是否有文件
|
||||||
|
if 'image' not in request.files:
|
||||||
|
return jsonify({"success": False, "error": "没有上传文件"}), 400
|
||||||
|
|
||||||
|
file = request.files['image']
|
||||||
|
k = int(request.form.get('k', 5))
|
||||||
|
filter_type = request.form.get('filter_type') # "text", "image" 或 null
|
||||||
|
|
||||||
|
# 检查文件名
|
||||||
|
if file.filename == '':
|
||||||
|
return jsonify({"success": False, "error": "没有选择文件"}), 400
|
||||||
|
|
||||||
|
if file and allowed_file(file.filename):
|
||||||
|
# 读取图像数据
|
||||||
|
image_data = file.read()
|
||||||
|
file_size = len(image_data)
|
||||||
|
|
||||||
|
# 根据文件大小选择处理方式
|
||||||
|
if file_size <= 5 * 1024 * 1024: # 5MB
|
||||||
|
# 小文件使用内存处理
|
||||||
|
logger.info(f"使用内存处理搜索图像: {file.filename} ({file_size} 字节)")
|
||||||
|
image = Image.open(BytesIO(image_data))
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
retrieval = init_retrieval_system()
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
results = retrieval.search_by_image(image, k, filter_type)
|
||||||
|
else:
|
||||||
|
# 大文件使用临时文件处理
|
||||||
|
with file_handler.temp_file_context(image_data, suffix=os.path.splitext(file.filename)[1]) as temp_file:
|
||||||
|
logger.info(f"使用临时文件处理搜索图像: {temp_file} ({file_size} 字节)")
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
retrieval = init_retrieval_system()
|
||||||
|
|
||||||
|
# 加载图像
|
||||||
|
image = Image.open(temp_file)
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
results = retrieval.search_by_image(image, k, filter_type)
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
processed_results = []
|
||||||
|
for result in results:
|
||||||
|
item = {
|
||||||
|
"score": result.get("score", 0),
|
||||||
|
"type": result.get("type")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.get("type") == "text":
|
||||||
|
item["text"] = result.get("text", "")
|
||||||
|
elif result.get("type") == "image":
|
||||||
|
if "path" in result and os.path.exists(result["path"]):
|
||||||
|
item["image"] = get_image_base64(result["path"])
|
||||||
|
item["filename"] = os.path.basename(result["path"])
|
||||||
|
if "description" in result:
|
||||||
|
item["description"] = result["description"]
|
||||||
|
|
||||||
|
processed_results.append(item)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"success": True,
|
||||||
|
"results": processed_results,
|
||||||
|
"filter_type": filter_type
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return jsonify({"success": False, "error": "不支持的文件类型"}), 400
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"图像搜索失败: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
file_handler.cleanup_all_temp_files()
|
||||||
|
|
||||||
|
@app.route('/uploads/<filename>')
|
||||||
|
def uploaded_file(filename):
|
||||||
|
"""提供上传文件的访问"""
|
||||||
|
return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
|
||||||
|
|
||||||
|
@app.route('/temp/<filename>')
|
||||||
|
def temp_file(filename):
|
||||||
|
"""提供临时文件的访问"""
|
||||||
|
return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
|
||||||
|
|
||||||
|
@app.route('/api/save_index', methods=['POST'])
|
||||||
|
def save_index():
|
||||||
|
"""保存索引"""
|
||||||
|
try:
|
||||||
|
# 初始化检索系统
|
||||||
|
retrieval = init_retrieval_system()
|
||||||
|
|
||||||
|
# 保存索引
|
||||||
|
retrieval.save_index()
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"success": True,
|
||||||
|
"message": "索引保存成功"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"保存索引失败: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
|
|
||||||
|
@app.route('/api/clear_index', methods=['POST'])
|
||||||
|
def clear_index():
|
||||||
|
"""清空索引"""
|
||||||
|
try:
|
||||||
|
# 初始化检索系统
|
||||||
|
retrieval = init_retrieval_system()
|
||||||
|
|
||||||
|
# 清空索引
|
||||||
|
retrieval.clear_index()
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"success": True,
|
||||||
|
"message": "索引已清空"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清空索引失败: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
|
|
||||||
|
@app.route('/api/list_items', methods=['GET'])
|
||||||
|
def list_items():
|
||||||
|
"""列出所有索引项"""
|
||||||
|
try:
|
||||||
|
# 初始化检索系统
|
||||||
|
retrieval = init_retrieval_system()
|
||||||
|
|
||||||
|
# 获取所有项
|
||||||
|
items = retrieval.list_items()
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"success": True,
|
||||||
|
"items": items
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"列出索引项失败: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
|
|
||||||
|
@app.route('/api/system_info', methods=['GET', 'POST'])
|
||||||
|
def system_info():
|
||||||
|
"""获取系统信息"""
|
||||||
|
try:
|
||||||
|
# GPU信息
|
||||||
|
gpu_info = []
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
gpu_info.append({
|
||||||
|
"id": i,
|
||||||
|
"name": torch.cuda.get_device_name(i),
|
||||||
|
"memory_total": torch.cuda.get_device_properties(i).total_memory / (1024 ** 3),
|
||||||
|
"memory_allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
|
||||||
|
"memory_reserved": torch.cuda.memory_reserved(i) / (1024 ** 3)
|
||||||
|
})
|
||||||
|
|
||||||
|
# 检索系统信息
|
||||||
|
retrieval_info = {}
|
||||||
|
if retrieval_system is not None:
|
||||||
|
retrieval_info = retrieval_system.get_stats()
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
"success": True,
|
||||||
|
"gpu_info": gpu_info,
|
||||||
|
"retrieval_info": retrieval_info,
|
||||||
|
"model_path": app.config['MODEL_PATH'],
|
||||||
|
"index_path": app.config['INDEX_PATH']
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取系统信息失败: {str(e)}")
|
||||||
|
return jsonify({"success": False, "error": str(e)}), 500
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
try:
|
||||||
|
# 预初始化检索系统
|
||||||
|
init_retrieval_system()
|
||||||
|
|
||||||
|
# 启动Web应用
|
||||||
|
app.run(host='0.0.0.0', port=5000, debug=False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"启动Web应用失败: {str(e)}")
|
||||||
|
sys.exit(1)
|
||||||
@ -514,6 +514,57 @@ def get_data_stats():
|
|||||||
'message': f'获取统计失败: {str(e)}'
|
'message': f'获取统计失败: {str(e)}'
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/data/list', methods=['GET'])
|
||||||
|
def list_data():
|
||||||
|
"""获取数据列表"""
|
||||||
|
try:
|
||||||
|
# 获取图片文件列表
|
||||||
|
image_files = []
|
||||||
|
for ext in ALLOWED_EXTENSIONS:
|
||||||
|
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
||||||
|
for file_path in glob.glob(pattern):
|
||||||
|
try:
|
||||||
|
# 转换为base64
|
||||||
|
image_base64 = image_to_base64(file_path)
|
||||||
|
image_files.append({
|
||||||
|
'filename': os.path.basename(file_path),
|
||||||
|
'filepath': file_path,
|
||||||
|
'image_base64': image_base64,
|
||||||
|
'size': os.path.getsize(file_path)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"处理图片文件失败 {file_path}: {e}")
|
||||||
|
|
||||||
|
# 获取文本文件列表
|
||||||
|
text_files = []
|
||||||
|
text_file_paths = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
|
||||||
|
text_file_paths.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
|
||||||
|
|
||||||
|
for text_file in text_file_paths:
|
||||||
|
try:
|
||||||
|
text_files.append({
|
||||||
|
'filename': os.path.basename(text_file),
|
||||||
|
'filepath': text_file,
|
||||||
|
'size': os.path.getsize(text_file)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"处理文本文件失败 {text_file}: {e}")
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'image_files': image_files,
|
||||||
|
'text_files': text_files,
|
||||||
|
'image_count': len(image_files),
|
||||||
|
'text_count': len(text_files)
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取数据列表失败: {str(e)}")
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': f'获取数据列表失败: {str(e)}'
|
||||||
|
}), 500
|
||||||
|
|
||||||
@app.route('/api/data/clear', methods=['POST'])
|
@app.route('/api/data/clear', methods=['POST'])
|
||||||
def clear_data():
|
def clear_data():
|
||||||
"""清空所有数据"""
|
"""清空所有数据"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user