Compare commits
No commits in common. "2b480e5277d23ac73fdacac65274f42e038e52a0" and "f6099c312e284e0871474d86ddbd30697188df8c" have entirely different histories.
2b480e5277
...
f6099c312e
BIN
__pycache__/multimodal_retrieval_multigpu.cpython-310.pyc
Normal file
@ -1,108 +0,0 @@
|
|||||||
# 多模态模型下载指南
|
|
||||||
|
|
||||||
## 下载 OpenSearch-AI/Ops-MM-embedding-v1-7B 模型
|
|
||||||
|
|
||||||
### 方法1:使用 git-lfs
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 安装 git-lfs
|
|
||||||
apt-get install git-lfs
|
|
||||||
# 或
|
|
||||||
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash
|
|
||||||
apt-get install git-lfs
|
|
||||||
|
|
||||||
# 初始化 git-lfs
|
|
||||||
git lfs install
|
|
||||||
|
|
||||||
# 克隆模型仓库
|
|
||||||
mkdir -p ~/models
|
|
||||||
git clone https://huggingface.co/OpenSearch-AI/Ops-MM-embedding-v1-7B ~/models/Ops-MM-embedding-v1-7B
|
|
||||||
```
|
|
||||||
|
|
||||||
### 方法2:使用 huggingface-cli
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 安装 huggingface-hub
|
|
||||||
pip install huggingface-hub
|
|
||||||
|
|
||||||
# 下载模型
|
|
||||||
mkdir -p ~/models
|
|
||||||
huggingface-cli download OpenSearch-AI/Ops-MM-embedding-v1-7B --local-dir ~/models/Ops-MM-embedding-v1-7B
|
|
||||||
```
|
|
||||||
|
|
||||||
### 方法3:手动下载关键文件
|
|
||||||
|
|
||||||
如果上述方法不可行,可以手动下载以下关键文件:
|
|
||||||
|
|
||||||
1. 访问 https://huggingface.co/OpenSearch-AI/Ops-MM-embedding-v1-7B/tree/main
|
|
||||||
2. 下载以下文件:
|
|
||||||
- `config.json`
|
|
||||||
- `pytorch_model.bin` (或分片文件 `pytorch_model-00001-of-00002.bin` 等)
|
|
||||||
- `tokenizer.json`
|
|
||||||
- `tokenizer_config.json`
|
|
||||||
- `special_tokens_map.json`
|
|
||||||
- `vocab.txt`
|
|
||||||
|
|
||||||
## 下载替代轻量级模型
|
|
||||||
|
|
||||||
如果主模型太大,可以下载这些较小的替代模型:
|
|
||||||
|
|
||||||
### CLIP 模型
|
|
||||||
|
|
||||||
```bash
|
|
||||||
mkdir -p ~/models/clip-ViT-B-32
|
|
||||||
huggingface-cli download openai/clip-vit-base-patch32 --local-dir ~/models/clip-ViT-B-32
|
|
||||||
```
|
|
||||||
|
|
||||||
### 多语言CLIP模型
|
|
||||||
|
|
||||||
```bash
|
|
||||||
mkdir -p ~/models/clip-multilingual
|
|
||||||
huggingface-cli download sentence-transformers/clip-ViT-B-32-multilingual-v1 --local-dir ~/models/clip-multilingual
|
|
||||||
```
|
|
||||||
|
|
||||||
## 传输模型文件
|
|
||||||
|
|
||||||
下载完成后,使用以下方法将模型传输到目标服务器:
|
|
||||||
|
|
||||||
### 使用 scp
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 从当前机器传输到目标服务器
|
|
||||||
scp -r ~/models/Ops-MM-embedding-v1-7B user@target-server:/root/models/
|
|
||||||
```
|
|
||||||
|
|
||||||
### 使用压缩文件
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 压缩
|
|
||||||
tar -czvf model.tar.gz ~/models/Ops-MM-embedding-v1-7B
|
|
||||||
|
|
||||||
# 传输压缩文件
|
|
||||||
scp model.tar.gz user@target-server:/root/
|
|
||||||
|
|
||||||
# 在目标服务器上解压
|
|
||||||
ssh user@target-server
|
|
||||||
mkdir -p /root/models
|
|
||||||
tar -xzvf /root/model.tar.gz -C /root/models
|
|
||||||
```
|
|
||||||
|
|
||||||
## 验证模型文件
|
|
||||||
|
|
||||||
模型下载完成后,目录结构应类似于:
|
|
||||||
|
|
||||||
```
|
|
||||||
/root/models/Ops-MM-embedding-v1-7B/
|
|
||||||
├── config.json
|
|
||||||
├── pytorch_model.bin (或分片文件)
|
|
||||||
├── tokenizer.json
|
|
||||||
├── tokenizer_config.json
|
|
||||||
├── special_tokens_map.json
|
|
||||||
└── vocab.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
使用以下命令验证文件完整性:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ls -la /root/models/Ops-MM-embedding-v1-7B/
|
|
||||||
```
|
|
||||||
@ -1,605 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
使用本地模型的多模态检索系统
|
|
||||||
支持文搜文、文搜图、图搜文、图搜图四种检索模式
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from typing import List, Union, Tuple, Dict, Any, Optional
|
|
||||||
import os
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
import logging
|
|
||||||
import gc
|
|
||||||
import faiss
|
|
||||||
import time
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 设置离线模式
|
|
||||||
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
|
||||||
|
|
||||||
class MultimodalRetrievalLocal:
|
|
||||||
"""使用本地模型的多模态检索系统"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
model_path: str = "/root/models/Ops-MM-embedding-v1-7B",
|
|
||||||
use_all_gpus: bool = True,
|
|
||||||
gpu_ids: List[int] = None,
|
|
||||||
min_memory_gb: int = 12,
|
|
||||||
index_path: str = "local_faiss_index",
|
|
||||||
shard_model_across_gpus: bool = False,
|
|
||||||
load_in_4bit: bool = False,
|
|
||||||
load_in_8bit: bool = False,
|
|
||||||
torch_dtype: Optional[torch.dtype] = torch.bfloat16):
|
|
||||||
"""
|
|
||||||
初始化多模态检索系统
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: 本地模型路径
|
|
||||||
use_all_gpus: 是否使用所有可用GPU
|
|
||||||
gpu_ids: 指定使用的GPU ID列表
|
|
||||||
min_memory_gb: 最小可用内存(GB)
|
|
||||||
index_path: FAISS索引文件路径
|
|
||||||
"""
|
|
||||||
self.model_path = model_path
|
|
||||||
self.index_path = index_path
|
|
||||||
self.shard_model_across_gpus = shard_model_across_gpus
|
|
||||||
self.load_in_4bit = load_in_4bit
|
|
||||||
self.load_in_8bit = load_in_8bit
|
|
||||||
self.torch_dtype = torch_dtype
|
|
||||||
|
|
||||||
# 检查模型路径
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
logger.error(f"模型路径不存在: {model_path}")
|
|
||||||
logger.info("请先下载模型到指定路径")
|
|
||||||
raise FileNotFoundError(f"模型路径不存在: {model_path}")
|
|
||||||
|
|
||||||
# 设置GPU设备
|
|
||||||
self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
|
|
||||||
|
|
||||||
# 清理GPU内存
|
|
||||||
self._clear_all_gpu_memory()
|
|
||||||
|
|
||||||
# 加载嵌入模型
|
|
||||||
self._load_embedding_model()
|
|
||||||
|
|
||||||
# 初始化FAISS索引
|
|
||||||
self._init_index()
|
|
||||||
|
|
||||||
logger.info(f"多模态检索系统初始化完成,使用本地模型: {model_path}")
|
|
||||||
logger.info(f"向量存储路径: {index_path}")
|
|
||||||
|
|
||||||
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
|
|
||||||
"""设置GPU设备"""
|
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
self.use_gpu = self.device.type == "cuda"
|
|
||||||
|
|
||||||
if self.use_gpu:
|
|
||||||
self.available_gpus = self._get_available_gpus(min_memory_gb)
|
|
||||||
|
|
||||||
if not self.available_gpus:
|
|
||||||
logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB,将使用CPU")
|
|
||||||
self.device = torch.device("cpu")
|
|
||||||
self.use_gpu = False
|
|
||||||
else:
|
|
||||||
if gpu_ids:
|
|
||||||
self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
|
|
||||||
if not self.gpu_ids:
|
|
||||||
logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足,将使用可用的GPU: {self.available_gpus}")
|
|
||||||
self.gpu_ids = self.available_gpus
|
|
||||||
elif use_all_gpus:
|
|
||||||
self.gpu_ids = self.available_gpus
|
|
||||||
else:
|
|
||||||
self.gpu_ids = [self.available_gpus[0]]
|
|
||||||
|
|
||||||
logger.info(f"使用GPU: {self.gpu_ids}")
|
|
||||||
self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
|
|
||||||
else:
|
|
||||||
logger.warning("没有可用的GPU,将使用CPU")
|
|
||||||
self.gpu_ids = []
|
|
||||||
|
|
||||||
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
|
|
||||||
"""获取可用的GPU列表"""
|
|
||||||
available_gpus = []
|
|
||||||
for i in range(torch.cuda.device_count()):
|
|
||||||
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
|
|
||||||
if total_mem >= min_memory_gb:
|
|
||||||
available_gpus.append(i)
|
|
||||||
return available_gpus
|
|
||||||
|
|
||||||
def _clear_all_gpu_memory(self):
|
|
||||||
"""清理GPU内存"""
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
def _load_embedding_model(self):
|
|
||||||
"""加载多模态嵌入模型 OpsMMEmbeddingV1"""
|
|
||||||
logger.info(f"加载本地多模态嵌入模型: {self.model_path}")
|
|
||||||
try:
|
|
||||||
self.models: List[OpsMMEmbeddingV1] = []
|
|
||||||
if self.use_gpu and len(self.gpu_ids) > 1 and self.shard_model_across_gpus:
|
|
||||||
# Tensor-parallel sharding using device_map='auto' across visible GPUs
|
|
||||||
logger.info(f"启用模型跨GPU切片(shard),使用设备: {self.gpu_ids}")
|
|
||||||
# Rely on CUDA_VISIBLE_DEVICES to constrain which GPUs are visible, or use accelerate's auto mapping
|
|
||||||
ref_model = OpsMMEmbeddingV1(
|
|
||||||
self.model_path,
|
|
||||||
device="cuda",
|
|
||||||
attn_implementation=None,
|
|
||||||
device_map="auto",
|
|
||||||
load_in_4bit=self.load_in_4bit,
|
|
||||||
load_in_8bit=self.load_in_8bit,
|
|
||||||
torch_dtype=self.torch_dtype,
|
|
||||||
)
|
|
||||||
# For sharded model, we keep a single logical model reference
|
|
||||||
self.models.append(ref_model)
|
|
||||||
else:
|
|
||||||
if self.use_gpu and len(self.gpu_ids) > 1:
|
|
||||||
logger.info(f"检测到多GPU,可用: {self.gpu_ids},为每张卡加载一个模型副本(数据并行)")
|
|
||||||
for gid in self.gpu_ids:
|
|
||||||
device_str = f"cuda:{gid}"
|
|
||||||
self.models.append(
|
|
||||||
OpsMMEmbeddingV1(
|
|
||||||
self.model_path,
|
|
||||||
device=device_str,
|
|
||||||
attn_implementation=None,
|
|
||||||
load_in_4bit=self.load_in_4bit,
|
|
||||||
load_in_8bit=self.load_in_8bit,
|
|
||||||
torch_dtype=self.torch_dtype,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
ref_model = self.models[0]
|
|
||||||
else:
|
|
||||||
device_str = "cuda" if self.use_gpu else "cpu"
|
|
||||||
logger.info(f"使用单设备: {device_str}")
|
|
||||||
ref_model = OpsMMEmbeddingV1(
|
|
||||||
self.model_path,
|
|
||||||
device=device_str,
|
|
||||||
attn_implementation=None,
|
|
||||||
load_in_4bit=self.load_in_4bit,
|
|
||||||
load_in_8bit=self.load_in_8bit,
|
|
||||||
torch_dtype=self.torch_dtype,
|
|
||||||
)
|
|
||||||
self.models.append(ref_model)
|
|
||||||
|
|
||||||
# 获取向量维度
|
|
||||||
self.vector_dim = int(getattr(ref_model.base_model.config, "hidden_size"))
|
|
||||||
logger.info(f"向量维度: {self.vector_dim}")
|
|
||||||
logger.info("嵌入模型加载成功")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"嵌入模型加载失败: {str(e)}")
|
|
||||||
raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
|
|
||||||
|
|
||||||
def _init_index(self):
|
|
||||||
"""初始化FAISS索引"""
|
|
||||||
index_file = f"{self.index_path}.index"
|
|
||||||
if os.path.exists(index_file):
|
|
||||||
logger.info(f"加载现有索引: {index_file}")
|
|
||||||
try:
|
|
||||||
self.index = faiss.read_index(index_file)
|
|
||||||
logger.info(f"索引加载成功,包含{self.index.ntotal}个向量")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"索引加载失败: {str(e)}")
|
|
||||||
logger.info("创建新索引...")
|
|
||||||
self.index = faiss.IndexFlatL2(self.vector_dim)
|
|
||||||
else:
|
|
||||||
logger.info(f"创建新索引,维度: {self.vector_dim}")
|
|
||||||
self.index = faiss.IndexFlatL2(self.vector_dim)
|
|
||||||
|
|
||||||
# 加载元数据
|
|
||||||
self.metadata = {}
|
|
||||||
metadata_file = f"{self.index_path}_metadata.json"
|
|
||||||
if os.path.exists(metadata_file):
|
|
||||||
try:
|
|
||||||
with open(metadata_file, 'r', encoding='utf-8') as f:
|
|
||||||
self.metadata = json.load(f)
|
|
||||||
logger.info(f"元数据加载成功,包含{len(self.metadata)}条记录")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"元数据加载失败: {str(e)}")
|
|
||||||
|
|
||||||
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
|
||||||
"""编码文本为向量(使用 OpsMMEmbeddingV1,多GPU并行)"""
|
|
||||||
if isinstance(text, str):
|
|
||||||
text = [text]
|
|
||||||
if len(text) == 0:
|
|
||||||
return np.zeros((0, self.vector_dim))
|
|
||||||
|
|
||||||
# 单条或单模型直接处理
|
|
||||||
if len(self.models) == 1 or len(text) == 1:
|
|
||||||
with torch.inference_mode():
|
|
||||||
emb = self.models[0].get_text_embeddings(texts=text)
|
|
||||||
arr = emb.detach().float().cpu().numpy()
|
|
||||||
return arr[0] if len(text) == 1 else arr
|
|
||||||
|
|
||||||
# 多GPU并行:按设备平均切分
|
|
||||||
shard_count = len(self.models)
|
|
||||||
shards: List[List[str]] = []
|
|
||||||
for i in range(shard_count):
|
|
||||||
shards.append([])
|
|
||||||
for idx, t in enumerate(text):
|
|
||||||
shards[idx % shard_count].append(t)
|
|
||||||
|
|
||||||
results: List[np.ndarray] = [None] * shard_count # type: ignore
|
|
||||||
|
|
||||||
def run_shard(model: OpsMMEmbeddingV1, shard_texts: List[str], shard_idx: int):
|
|
||||||
if len(shard_texts) == 0:
|
|
||||||
results[shard_idx] = np.zeros((0, self.vector_dim))
|
|
||||||
return
|
|
||||||
with torch.inference_mode():
|
|
||||||
emb = model.get_text_embeddings(texts=shard_texts)
|
|
||||||
results[shard_idx] = emb.detach().float().cpu().numpy()
|
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=shard_count) as ex:
|
|
||||||
for i, (model, shard_texts) in enumerate(zip(self.models, shards)):
|
|
||||||
ex.submit(run_shard, model, shard_texts, i)
|
|
||||||
|
|
||||||
# 还原原始顺序(按 round-robin 拼接)
|
|
||||||
merged: List[np.ndarray] = []
|
|
||||||
max_len = max(len(s) for s in shards)
|
|
||||||
for k in range(max_len):
|
|
||||||
for i in range(shard_count):
|
|
||||||
if k < len(shards[i]) and results[i].shape[0] > 0:
|
|
||||||
merged.append(results[i][k:k+1])
|
|
||||||
if len(merged) == 0:
|
|
||||||
return np.zeros((0, self.vector_dim))
|
|
||||||
return np.vstack(merged)
|
|
||||||
|
|
||||||
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
|
|
||||||
"""编码图像为向量(使用 OpsMMEmbeddingV1,多GPU并行)"""
|
|
||||||
try:
|
|
||||||
# 规范为列表
|
|
||||||
if isinstance(image, Image.Image):
|
|
||||||
images: List[Image.Image] = [image]
|
|
||||||
else:
|
|
||||||
images = image
|
|
||||||
if not images:
|
|
||||||
logger.error("encode_image: 图像列表为空")
|
|
||||||
return np.zeros((0, self.vector_dim))
|
|
||||||
|
|
||||||
# 强制为 RGB
|
|
||||||
rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images]
|
|
||||||
|
|
||||||
# 单张或单模型直接处理
|
|
||||||
if len(self.models) == 1 or len(rgb_images) == 1:
|
|
||||||
with torch.inference_mode():
|
|
||||||
emb = self.models[0].get_image_embeddings(images=rgb_images)
|
|
||||||
return emb.detach().float().cpu().numpy()
|
|
||||||
|
|
||||||
# 多GPU并行:按设备平均切分
|
|
||||||
shard_count = len(self.models)
|
|
||||||
shards: List[List[Image.Image]] = [[] for _ in range(shard_count)]
|
|
||||||
for idx, img in enumerate(rgb_images):
|
|
||||||
shards[idx % shard_count].append(img)
|
|
||||||
|
|
||||||
results: List[np.ndarray] = [None] * shard_count # type: ignore
|
|
||||||
|
|
||||||
def run_shard(model: OpsMMEmbeddingV1, shard_imgs: List[Image.Image], shard_idx: int):
|
|
||||||
if len(shard_imgs) == 0:
|
|
||||||
results[shard_idx] = np.zeros((0, self.vector_dim))
|
|
||||||
return
|
|
||||||
with torch.inference_mode():
|
|
||||||
emb = model.get_image_embeddings(images=shard_imgs)
|
|
||||||
results[shard_idx] = emb.detach().float().cpu().numpy()
|
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=shard_count) as ex:
|
|
||||||
for i, (model, shard_imgs) in enumerate(zip(self.models, shards)):
|
|
||||||
ex.submit(run_shard, model, shard_imgs, i)
|
|
||||||
|
|
||||||
# 还原原始顺序(按 round-robin 拼接)
|
|
||||||
merged: List[np.ndarray] = []
|
|
||||||
max_len = max(len(s) for s in shards)
|
|
||||||
for k in range(max_len):
|
|
||||||
for i in range(shard_count):
|
|
||||||
if k < len(shards[i]) and results[i].shape[0] > 0:
|
|
||||||
merged.append(results[i][k:k+1])
|
|
||||||
if len(merged) == 0:
|
|
||||||
return np.zeros((0, self.vector_dim))
|
|
||||||
return np.vstack(merged)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"encode_image: 异常: {str(e)}")
|
|
||||||
return np.zeros((0, self.vector_dim))
|
|
||||||
|
|
||||||
def add_texts(
|
|
||||||
self,
|
|
||||||
texts: List[str],
|
|
||||||
metadatas: Optional[List[Dict[str, Any]]] = None
|
|
||||||
) -> List[str]:
|
|
||||||
"""
|
|
||||||
添加文本到检索系统
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: 文本列表
|
|
||||||
metadatas: 元数据列表,每个元素是一个字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
添加的文本ID列表
|
|
||||||
"""
|
|
||||||
if not texts:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if metadatas is None:
|
|
||||||
metadatas = [{} for _ in range(len(texts))]
|
|
||||||
|
|
||||||
if len(texts) != len(metadatas):
|
|
||||||
raise ValueError("texts和metadatas长度必须相同")
|
|
||||||
|
|
||||||
# 编码文本
|
|
||||||
text_embeddings = self.encode_text(texts)
|
|
||||||
|
|
||||||
# 准备元数据
|
|
||||||
start_id = self.index.ntotal
|
|
||||||
ids = list(range(start_id, start_id + len(texts)))
|
|
||||||
|
|
||||||
# 添加到索引
|
|
||||||
self.index.add(np.array(text_embeddings).astype('float32'))
|
|
||||||
|
|
||||||
# 保存元数据
|
|
||||||
for i, id in enumerate(ids):
|
|
||||||
self.metadata[str(id)] = {
|
|
||||||
"text": texts[i],
|
|
||||||
"type": "text",
|
|
||||||
**metadatas[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"成功添加{len(ids)}条文本到检索系统")
|
|
||||||
return [str(id) for id in ids]
|
|
||||||
|
|
||||||
def add_images(
|
|
||||||
self,
|
|
||||||
images: List[Image.Image],
|
|
||||||
metadatas: Optional[List[Dict[str, Any]]] = None,
|
|
||||||
image_paths: Optional[List[str]] = None
|
|
||||||
) -> List[str]:
|
|
||||||
"""
|
|
||||||
添加图像到检索系统
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images: PIL图像列表
|
|
||||||
metadatas: 元数据列表,每个元素是一个字典
|
|
||||||
image_paths: 图像路径列表,用于保存到元数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
添加的图像ID列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info(f"add_images: 开始添加图像,数量: {len(images) if images else 0}")
|
|
||||||
|
|
||||||
# 检查图像列表
|
|
||||||
if not images or len(images) == 0:
|
|
||||||
logger.warning("add_images: 图像列表为空")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 准备元数据
|
|
||||||
if metadatas is None:
|
|
||||||
logger.info("add_images: 创建默认元数据")
|
|
||||||
metadatas = [{} for _ in range(len(images))]
|
|
||||||
|
|
||||||
# 检查长度一致性
|
|
||||||
if len(images) != len(metadatas):
|
|
||||||
logger.error(f"add_images: 长度不一致 - images: {len(images)}, metadatas: {len(metadatas)}")
|
|
||||||
raise ValueError("images和metadatas长度必须相同")
|
|
||||||
|
|
||||||
# 编码图像
|
|
||||||
logger.info("add_images: 编码图像")
|
|
||||||
image_embeddings = self.encode_image(images)
|
|
||||||
|
|
||||||
# 检查编码结果
|
|
||||||
if image_embeddings.shape[0] == 0:
|
|
||||||
logger.error("add_images: 图像编码失败,返回空数组")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 准备元数据
|
|
||||||
start_id = self.index.ntotal
|
|
||||||
ids = list(range(start_id, start_id + len(images)))
|
|
||||||
logger.info(f"add_images: 生成索引ID: {start_id} - {start_id + len(images) - 1}")
|
|
||||||
|
|
||||||
# 添加到索引
|
|
||||||
logger.info(f"add_images: 添加向量到FAISS索引,形状: {image_embeddings.shape}")
|
|
||||||
self.index.add(np.array(image_embeddings).astype('float32'))
|
|
||||||
|
|
||||||
# 保存元数据
|
|
||||||
for i, id in enumerate(ids):
|
|
||||||
try:
|
|
||||||
metadata = {
|
|
||||||
"type": "image",
|
|
||||||
"width": images[i].width,
|
|
||||||
"height": images[i].height,
|
|
||||||
**metadatas[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
if image_paths and i < len(image_paths):
|
|
||||||
metadata["path"] = image_paths[i]
|
|
||||||
|
|
||||||
self.metadata[str(id)] = metadata
|
|
||||||
logger.debug(f"add_images: 保存元数据成功 - ID: {id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"add_images: 保存元数据失败 - ID: {id}, 错误: {str(e)}")
|
|
||||||
|
|
||||||
logger.info(f"add_images: 成功添加{len(ids)}张图像到检索系统")
|
|
||||||
return [str(id) for id in ids]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"add_images: 添加图像异常: {str(e)}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def search_by_text(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
k: int = 5,
|
|
||||||
filter_type: Optional[str] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
文本搜索
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 查询文本
|
|
||||||
k: 返回结果数量
|
|
||||||
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
搜索结果列表,每个元素包含相似项和分数
|
|
||||||
"""
|
|
||||||
# 编码查询文本
|
|
||||||
query_embedding = self.encode_text(query)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
return self._search(query_embedding, k, filter_type)
|
|
||||||
|
|
||||||
def search_by_image(
|
|
||||||
self,
|
|
||||||
image: Image.Image,
|
|
||||||
k: int = 5,
|
|
||||||
filter_type: Optional[str] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
图像搜索
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image: 查询图像
|
|
||||||
k: 返回结果数量
|
|
||||||
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
搜索结果列表,每个元素包含相似项和分数
|
|
||||||
"""
|
|
||||||
# 编码查询图像
|
|
||||||
query_embedding = self.encode_image(image)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
return self._search(query_embedding, k, filter_type)
|
|
||||||
|
|
||||||
def _search(
|
|
||||||
self,
|
|
||||||
query_embedding: np.ndarray,
|
|
||||||
k: int = 5,
|
|
||||||
filter_type: Optional[str] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
执行搜索
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_embedding: 查询向量
|
|
||||||
k: 返回结果数量
|
|
||||||
filter_type: 过滤类型,可选值: "text", "image", None(不过滤)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
搜索结果列表
|
|
||||||
"""
|
|
||||||
if self.index.ntotal == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 确保查询向量是2D数组
|
|
||||||
if len(query_embedding.shape) == 1:
|
|
||||||
query_embedding = query_embedding.reshape(1, -1)
|
|
||||||
|
|
||||||
# 执行搜索,获取更多结果以便过滤
|
|
||||||
actual_k = k * 3 if filter_type else k
|
|
||||||
actual_k = min(actual_k, self.index.ntotal)
|
|
||||||
distances, indices = self.index.search(query_embedding.astype('float32'), actual_k)
|
|
||||||
|
|
||||||
# 处理结果
|
|
||||||
results = []
|
|
||||||
for i in range(len(indices[0])):
|
|
||||||
idx = indices[0][i]
|
|
||||||
if idx < 0: # FAISS可能返回-1表示无效索引
|
|
||||||
continue
|
|
||||||
|
|
||||||
vector_id = str(idx)
|
|
||||||
if vector_id in self.metadata:
|
|
||||||
item = self.metadata[vector_id]
|
|
||||||
|
|
||||||
# 如果指定了过滤类型,则只返回该类型的结果
|
|
||||||
if filter_type and item.get("type") != filter_type:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 添加距离和分数
|
|
||||||
result = item.copy()
|
|
||||||
result["distance"] = float(distances[0][i])
|
|
||||||
result["score"] = float(1.0 / (1.0 + distances[0][i]))
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
# 如果已经收集了足够的结果,则停止
|
|
||||||
if len(results) >= k:
|
|
||||||
break
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def save_index(self):
|
|
||||||
"""保存索引和元数据"""
|
|
||||||
# 保存索引
|
|
||||||
index_file = f"{self.index_path}.index"
|
|
||||||
try:
|
|
||||||
faiss.write_index(self.index, index_file)
|
|
||||||
logger.info(f"索引保存成功: {index_file}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"索引保存失败: {str(e)}")
|
|
||||||
|
|
||||||
# 保存元数据
|
|
||||||
metadata_file = f"{self.index_path}_metadata.json"
|
|
||||||
try:
|
|
||||||
with open(metadata_file, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(self.metadata, f, ensure_ascii=False, indent=2)
|
|
||||||
logger.info(f"元数据保存成功: {metadata_file}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"元数据保存失败: {str(e)}")
|
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
|
||||||
"""获取检索系统统计信息"""
|
|
||||||
text_count = sum(1 for v in self.metadata.values() if v.get("type") == "text")
|
|
||||||
image_count = sum(1 for v in self.metadata.values() if v.get("type") == "image")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_vectors": self.index.ntotal,
|
|
||||||
"text_count": text_count,
|
|
||||||
"image_count": image_count,
|
|
||||||
"vector_dimension": self.vector_dim,
|
|
||||||
"index_path": self.index_path,
|
|
||||||
"model_path": self.model_path
|
|
||||||
}
|
|
||||||
|
|
||||||
def clear_index(self):
|
|
||||||
"""清空索引"""
|
|
||||||
logger.info(f"清空索引: {self.index_path}")
|
|
||||||
|
|
||||||
# 重新创建索引
|
|
||||||
self.index = faiss.IndexFlatL2(self.vector_dim)
|
|
||||||
|
|
||||||
# 清空元数据
|
|
||||||
self.metadata = {}
|
|
||||||
|
|
||||||
# 保存空索引
|
|
||||||
self.save_index()
|
|
||||||
|
|
||||||
logger.info(f"索引已清空: {self.index_path}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def list_items(self) -> List[Dict[str, Any]]:
|
|
||||||
"""列出所有索引项"""
|
|
||||||
items = []
|
|
||||||
|
|
||||||
for item_id, metadata in self.metadata.items():
|
|
||||||
item = metadata.copy()
|
|
||||||
item['id'] = item_id
|
|
||||||
items.append(item)
|
|
||||||
|
|
||||||
return items
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
"""析构函数,确保资源被正确释放并自动保存索引"""
|
|
||||||
try:
|
|
||||||
if hasattr(self, 'model'):
|
|
||||||
del self.model
|
|
||||||
self._clear_all_gpu_memory()
|
|
||||||
if hasattr(self, 'index') and self.index is not None:
|
|
||||||
logger.info("系统关闭前自动保存索引")
|
|
||||||
self.save_index()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"析构时保存索引失败: {str(e)}")
|
|
||||||
632
multimodal_retrieval_multigpu.py
Normal file
@ -0,0 +1,632 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import faiss
|
||||||
|
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||||||
|
from typing import List, Union, Tuple, Dict
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class MultiGPUMultimodalRetrieval:
|
||||||
|
"""多GPU优化的多模态检索系统,支持文搜图、文搜文、图搜图、图搜文"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
|
||||||
|
use_all_gpus: bool = True, gpu_ids: List[int] = None, min_memory_gb=12):
|
||||||
|
"""
|
||||||
|
初始化多GPU多模态检索系统
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: 模型名称
|
||||||
|
use_all_gpus: 是否使用所有可用GPU
|
||||||
|
gpu_ids: 指定使用的GPU ID列表
|
||||||
|
min_memory_gb: 最小可用内存(GB)
|
||||||
|
"""
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
|
# 设置GPU设备
|
||||||
|
self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
|
||||||
|
|
||||||
|
# 清理GPU内存
|
||||||
|
self._clear_all_gpu_memory()
|
||||||
|
|
||||||
|
logger.info(f"正在加载模型到多GPU: {self.device_ids}")
|
||||||
|
|
||||||
|
# 加载模型和处理器
|
||||||
|
self.model = None
|
||||||
|
self.tokenizer = None
|
||||||
|
self.processor = None
|
||||||
|
self._load_model_multigpu()
|
||||||
|
|
||||||
|
# 初始化索引
|
||||||
|
self.text_index = None
|
||||||
|
self.image_index = None
|
||||||
|
self.text_data = []
|
||||||
|
self.image_data = []
|
||||||
|
|
||||||
|
logger.info("多GPU模型加载完成")
|
||||||
|
|
||||||
|
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb=12):
|
||||||
|
"""设置GPU设备"""
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
raise RuntimeError("CUDA不可用,无法使用多GPU")
|
||||||
|
|
||||||
|
total_gpus = torch.cuda.device_count()
|
||||||
|
logger.info(f"检测到 {total_gpus} 个GPU")
|
||||||
|
|
||||||
|
# 检查是否设置了CUDA_VISIBLE_DEVICES
|
||||||
|
cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
|
||||||
|
if cuda_visible_devices is not None:
|
||||||
|
# 如果设置了CUDA_VISIBLE_DEVICES,使用可见的GPU
|
||||||
|
visible_gpu_count = len(cuda_visible_devices.split(','))
|
||||||
|
self.device_ids = list(range(visible_gpu_count))
|
||||||
|
logger.info(f"使用CUDA_VISIBLE_DEVICES指定的GPU: {cuda_visible_devices}")
|
||||||
|
elif use_all_gpus:
|
||||||
|
self.device_ids = self._select_best_gpus(min_memory_gb)
|
||||||
|
elif gpu_ids:
|
||||||
|
self.device_ids = gpu_ids
|
||||||
|
else:
|
||||||
|
self.device_ids = [0]
|
||||||
|
|
||||||
|
self.num_gpus = len(self.device_ids)
|
||||||
|
self.primary_device = f"cuda:{self.device_ids[0]}"
|
||||||
|
|
||||||
|
logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")
|
||||||
|
|
||||||
|
def _clear_all_gpu_memory(self):
|
||||||
|
"""清理所有GPU内存"""
|
||||||
|
for gpu_id in self.device_ids:
|
||||||
|
torch.cuda.set_device(gpu_id)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
gc.collect()
|
||||||
|
logger.info("所有GPU内存已清理")
|
||||||
|
|
||||||
|
def _load_model_multigpu(self):
|
||||||
|
"""加载模型到多GPU"""
|
||||||
|
try:
|
||||||
|
# 设置环境变量优化内存使用
|
||||||
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||||
|
|
||||||
|
# 清理GPU内存
|
||||||
|
self._clear_gpu_memory()
|
||||||
|
|
||||||
|
# 首先尝试使用accelerate的自动设备映射
|
||||||
|
if self.num_gpus > 1:
|
||||||
|
# 设置最大内存限制(每个GPU 18GB,留出缓冲)
|
||||||
|
max_memory = {i: "18GiB" for i in self.device_ids}
|
||||||
|
|
||||||
|
logger.info(f"正在加载模型到多GPU: {self.device_ids}")
|
||||||
|
self.model = AutoModel.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map="auto",
|
||||||
|
max_memory=max_memory,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
offload_folder="./offload"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 单GPU加载
|
||||||
|
self.model = AutoModel.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
device_map=self.primary_device
|
||||||
|
)
|
||||||
|
|
||||||
|
# 加载分词器和处理器到主设备
|
||||||
|
try:
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
logger.info("Tokenizer加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Tokenizer加载失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 加载处理器用于图像处理
|
||||||
|
try:
|
||||||
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
logger.info("Processor加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Processor加载失败: {e}")
|
||||||
|
# 如果AutoProcessor失败,尝试使用tokenizer作为fallback
|
||||||
|
logger.info("尝试使用tokenizer作为processor的fallback")
|
||||||
|
self.processor = self.tokenizer
|
||||||
|
|
||||||
|
logger.info(f"模型已成功加载到设备: {self.model.hf_device_map if hasattr(self.model, 'hf_device_map') else self.primary_device}")
|
||||||
|
logger.info("多GPU模型加载完成")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"多GPU模型加载失败: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _clear_gpu_memory(self):
|
||||||
|
"""清理GPU内存"""
|
||||||
|
for gpu_id in self.device_ids:
|
||||||
|
torch.cuda.set_device(gpu_id)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
gc.collect()
|
||||||
|
logger.info("GPU内存已清理")
|
||||||
|
|
||||||
|
def _get_gpu_memory_info(self):
|
||||||
|
"""获取GPU内存使用情况"""
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv,nounits,noheader'],
|
||||||
|
capture_output=True, text=True, check=True)
|
||||||
|
lines = result.stdout.strip().split('\n')
|
||||||
|
gpu_info = []
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
used, total = map(int, line.split(', '))
|
||||||
|
free = total - used
|
||||||
|
gpu_info.append({
|
||||||
|
'gpu_id': i,
|
||||||
|
'used': used,
|
||||||
|
'total': total,
|
||||||
|
'free': free,
|
||||||
|
'usage_percent': (used / total) * 100
|
||||||
|
})
|
||||||
|
return gpu_info
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"无法获取GPU内存信息: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _select_best_gpus(self, min_memory_gb=12):
|
||||||
|
"""选择内存充足的GPU"""
|
||||||
|
gpu_info = self._get_gpu_memory_info()
|
||||||
|
if not gpu_info:
|
||||||
|
return list(range(torch.cuda.device_count()))
|
||||||
|
|
||||||
|
# 按可用内存排序
|
||||||
|
gpu_info.sort(key=lambda x: x['free'], reverse=True)
|
||||||
|
|
||||||
|
# 选择内存充足的GPU
|
||||||
|
min_memory_mb = min_memory_gb * 1024
|
||||||
|
suitable_gpus = []
|
||||||
|
|
||||||
|
for gpu in gpu_info:
|
||||||
|
if gpu['free'] >= min_memory_mb:
|
||||||
|
suitable_gpus.append(gpu['gpu_id'])
|
||||||
|
logger.info(f"GPU {gpu['gpu_id']}: {gpu['free']}MB 可用 (合适)")
|
||||||
|
else:
|
||||||
|
logger.warning(f"GPU {gpu['gpu_id']}: {gpu['free']}MB 可用 (不足)")
|
||||||
|
|
||||||
|
if not suitable_gpus:
|
||||||
|
# 如果没有GPU满足要求,选择可用内存最多的
|
||||||
|
logger.warning(f"没有GPU有足够内存({min_memory_gb}GB),选择可用内存最多的GPU")
|
||||||
|
suitable_gpus = [gpu_info[0]['gpu_id']]
|
||||||
|
|
||||||
|
return suitable_gpus
|
||||||
|
|
||||||
|
def encode_text_batch(self, texts: List[str]) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
批量编码文本为向量(多GPU优化)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: 文本列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
文本向量
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return np.array([])
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# 预处理输入
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
text=texts,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=512
|
||||||
|
)
|
||||||
|
|
||||||
|
# 将输入移动到主设备
|
||||||
|
inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
# 前向传播
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||||
|
|
||||||
|
# 清理GPU内存
|
||||||
|
del inputs, outputs
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return embeddings.cpu().numpy().astype(np.float32)
|
||||||
|
|
||||||
|
def encode_image_batch(self, images: List[Union[str, Image.Image]]) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
批量编码图像为向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: 图像路径或PIL图像列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
图像向量
|
||||||
|
"""
|
||||||
|
if not images:
|
||||||
|
return np.array([])
|
||||||
|
|
||||||
|
# 预处理图像
|
||||||
|
processed_images = []
|
||||||
|
for img in images:
|
||||||
|
if isinstance(img, str):
|
||||||
|
img = Image.open(img).convert('RGB')
|
||||||
|
elif isinstance(img, Image.Image):
|
||||||
|
img = img.convert('RGB')
|
||||||
|
processed_images.append(img)
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"处理 {len(processed_images)} 张图像")
|
||||||
|
|
||||||
|
# 使用多模态模型生成图像embedding
|
||||||
|
# 为每张图像创建简单的文本描述作为输入
|
||||||
|
conversations = []
|
||||||
|
for i in range(len(processed_images)):
|
||||||
|
# 使用简化的对话格式
|
||||||
|
conversation = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "image": processed_images[i]},
|
||||||
|
{"type": "text", "text": "What is in this image?"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
conversations.append(conversation)
|
||||||
|
|
||||||
|
# 使用processor处理
|
||||||
|
try:
|
||||||
|
# 尝试使用apply_chat_template方法
|
||||||
|
texts = []
|
||||||
|
for conv in conversations:
|
||||||
|
text = self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
|
||||||
|
texts.append(text)
|
||||||
|
|
||||||
|
# 处理文本和图像
|
||||||
|
inputs = self.processor(
|
||||||
|
text=texts,
|
||||||
|
images=processed_images,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 移动到GPU
|
||||||
|
inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
# 获取模型输出
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||||
|
|
||||||
|
# 转换为numpy数组
|
||||||
|
embeddings = embeddings.cpu().numpy().astype(np.float32)
|
||||||
|
|
||||||
|
except Exception as inner_e:
|
||||||
|
logger.warning(f"多模态模型图像编码失败,使用文本模式: {inner_e}")
|
||||||
|
return np.zeros((len(processed_images), 3584), dtype=np.float32)
|
||||||
|
# 如果多模态失败,使用纯文本描述作为fallback
|
||||||
|
image_descriptions = ["An image" for _ in processed_images]
|
||||||
|
text_inputs = self.processor(
|
||||||
|
text=image_descriptions,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=512
|
||||||
|
)
|
||||||
|
text_inputs = {k: v.to(self.primary_device) for k, v in text_inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**text_inputs)
|
||||||
|
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||||
|
|
||||||
|
embeddings = embeddings.cpu().numpy().astype(np.float32)
|
||||||
|
|
||||||
|
logger.info(f"生成图像embeddings: {embeddings.shape}")
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"图像编码失败: {e}")
|
||||||
|
# 返回与文本embedding维度一致的零向量作为fallback
|
||||||
|
embedding_dim = 3584
|
||||||
|
embeddings = np.zeros((len(processed_images), embedding_dim), dtype=np.float32)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def build_text_index_parallel(self, texts: List[str], save_path: str = None):
|
||||||
|
"""
|
||||||
|
并行构建文本索引(多GPU优化)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: 文本列表
|
||||||
|
save_path: 索引保存路径
|
||||||
|
"""
|
||||||
|
logger.info(f"正在并行构建文本索引,共 {len(texts)} 条文本")
|
||||||
|
|
||||||
|
# 根据GPU数量调整批次大小
|
||||||
|
batch_size = max(4, 16 // self.num_gpus)
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
# 分批处理
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch_texts = texts[i:i+batch_size]
|
||||||
|
|
||||||
|
try:
|
||||||
|
embeddings = self.encode_text_batch(batch_texts)
|
||||||
|
all_embeddings.append(embeddings)
|
||||||
|
|
||||||
|
# 显示进度
|
||||||
|
if (i // batch_size + 1) % 10 == 0:
|
||||||
|
logger.info(f"已处理 {i + len(batch_texts)}/{len(texts)} 条文本")
|
||||||
|
|
||||||
|
except torch.cuda.OutOfMemoryError:
|
||||||
|
logger.warning(f"GPU内存不足,跳过批次 {i}-{i+len(batch_texts)}")
|
||||||
|
self._clear_all_gpu_memory()
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理文本批次时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_embeddings:
|
||||||
|
raise ValueError("没有成功处理任何文本")
|
||||||
|
|
||||||
|
# 合并所有嵌入向量
|
||||||
|
embeddings = np.vstack(all_embeddings)
|
||||||
|
|
||||||
|
# 构建FAISS索引
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
self.text_index = faiss.IndexFlatIP(dimension)
|
||||||
|
|
||||||
|
# 归一化向量
|
||||||
|
faiss.normalize_L2(embeddings)
|
||||||
|
self.text_index.add(embeddings)
|
||||||
|
|
||||||
|
self.text_data = texts
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
self._save_index(self.text_index, texts, save_path + "_text")
|
||||||
|
|
||||||
|
logger.info("文本索引构建完成")
|
||||||
|
|
||||||
|
def build_image_index_parallel(self, image_paths: List[str], save_path: str = None):
|
||||||
|
"""
|
||||||
|
并行构建图像索引(多GPU优化)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_paths: 图像路径列表
|
||||||
|
save_path: 索引保存路径
|
||||||
|
"""
|
||||||
|
logger.info(f"正在并行构建图像索引,共 {len(image_paths)} 张图像")
|
||||||
|
|
||||||
|
# 图像处理使用更小的批次
|
||||||
|
batch_size = max(2, 8 // self.num_gpus)
|
||||||
|
all_embeddings = []
|
||||||
|
|
||||||
|
for i in range(0, len(image_paths), batch_size):
|
||||||
|
batch_images = image_paths[i:i+batch_size]
|
||||||
|
|
||||||
|
try:
|
||||||
|
embeddings = self.encode_image_batch(batch_images)
|
||||||
|
all_embeddings.append(embeddings)
|
||||||
|
|
||||||
|
# 显示进度
|
||||||
|
if (i // batch_size + 1) % 5 == 0:
|
||||||
|
logger.info(f"已处理 {i + len(batch_images)}/{len(image_paths)} 张图像")
|
||||||
|
|
||||||
|
except torch.cuda.OutOfMemoryError:
|
||||||
|
logger.warning(f"GPU内存不足,跳过图像批次 {i}-{i+len(batch_images)}")
|
||||||
|
self._clear_all_gpu_memory()
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理图像批次时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_embeddings:
|
||||||
|
raise ValueError("没有成功处理任何图像")
|
||||||
|
|
||||||
|
embeddings = np.vstack(all_embeddings)
|
||||||
|
|
||||||
|
# 构建FAISS索引
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
self.image_index = faiss.IndexFlatIP(dimension)
|
||||||
|
|
||||||
|
faiss.normalize_L2(embeddings)
|
||||||
|
self.image_index.add(embeddings)
|
||||||
|
|
||||||
|
self.image_data = image_paths
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
self._save_index(self.image_index, image_paths, save_path + "_image")
|
||||||
|
|
||||||
|
logger.info("图像索引构建完成")
|
||||||
|
|
||||||
|
def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
|
"""文搜文:使用文本查询搜索相似文本"""
|
||||||
|
if self.text_index is None:
|
||||||
|
raise ValueError("文本索引未构建,请先调用 build_text_index_parallel")
|
||||||
|
|
||||||
|
query_embedding = self.encode_text_batch([query]).astype(np.float32)
|
||||||
|
faiss.normalize_L2(query_embedding)
|
||||||
|
|
||||||
|
scores, indices = self.text_index.search(query_embedding, top_k)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for score, idx in zip(scores[0], indices[0]):
|
||||||
|
if idx != -1:
|
||||||
|
results.append((self.text_data[idx], float(score)))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
|
"""文搜图:使用文本查询搜索相似图像"""
|
||||||
|
if self.image_index is None:
|
||||||
|
raise ValueError("图像索引未构建,请先调用 build_image_index_parallel")
|
||||||
|
|
||||||
|
query_embedding = self.encode_text_batch([query]).astype(np.float32)
|
||||||
|
faiss.normalize_L2(query_embedding)
|
||||||
|
|
||||||
|
scores, indices = self.image_index.search(query_embedding, top_k)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for score, idx in zip(scores[0], indices[0]):
|
||||||
|
if idx != -1:
|
||||||
|
results.append((self.image_data[idx], float(score)))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
|
"""图搜图:使用图像查询搜索相似图像"""
|
||||||
|
if self.image_index is None:
|
||||||
|
raise ValueError("图像索引未构建,请先调用 build_image_index_parallel")
|
||||||
|
|
||||||
|
query_embedding = self.encode_image_batch([query_image]).astype(np.float32)
|
||||||
|
faiss.normalize_L2(query_embedding)
|
||||||
|
|
||||||
|
scores, indices = self.image_index.search(query_embedding, top_k)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for score, idx in zip(scores[0], indices[0]):
|
||||||
|
if idx != -1:
|
||||||
|
results.append((self.image_data[idx], float(score)))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
|
"""图搜文:使用图像查询搜索相似文本"""
|
||||||
|
if self.text_index is None:
|
||||||
|
raise ValueError("文本索引未构建,请先调用 build_text_index_parallel")
|
||||||
|
|
||||||
|
query_embedding = self.encode_image_batch([query_image]).astype(np.float32)
|
||||||
|
faiss.normalize_L2(query_embedding)
|
||||||
|
|
||||||
|
scores, indices = self.text_index.search(query_embedding, top_k)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for score, idx in zip(scores[0], indices[0]):
|
||||||
|
if idx != -1:
|
||||||
|
results.append((self.text_data[idx], float(score)))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
# Web应用兼容的方法名称
|
||||||
|
def search_text_to_image(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
|
"""文搜图:Web应用兼容方法"""
|
||||||
|
return self.search_images_by_text(query, top_k)
|
||||||
|
|
||||||
|
def search_image_to_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
|
"""图搜图:Web应用兼容方法"""
|
||||||
|
return self.search_images_by_image(query_image, top_k)
|
||||||
|
|
||||||
|
def search_text_to_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
|
"""文搜文:Web应用兼容方法"""
|
||||||
|
return self.search_text_by_text(query, top_k)
|
||||||
|
|
||||||
|
def search_image_to_text(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
||||||
|
"""图搜文:Web应用兼容方法"""
|
||||||
|
return self.search_text_by_image(query_image, top_k)
|
||||||
|
|
||||||
|
def _save_index(self, index, data, path_prefix):
|
||||||
|
"""保存索引和数据"""
|
||||||
|
faiss.write_index(index, f"{path_prefix}.index")
|
||||||
|
with open(f"{path_prefix}.json", 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
def load_index(self, path_prefix, index_type="text"):
|
||||||
|
"""加载已保存的索引"""
|
||||||
|
index = faiss.read_index(f"{path_prefix}.index")
|
||||||
|
with open(f"{path_prefix}.json", 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
if index_type == "text":
|
||||||
|
self.text_index = index
|
||||||
|
self.text_data = data
|
||||||
|
else:
|
||||||
|
self.image_index = index
|
||||||
|
self.image_data = data
|
||||||
|
|
||||||
|
logger.info(f"已加载 {index_type} 索引")
|
||||||
|
|
||||||
|
def get_gpu_memory_info(self):
|
||||||
|
"""获取所有GPU内存使用信息"""
|
||||||
|
memory_info = {}
|
||||||
|
for gpu_id in self.device_ids:
|
||||||
|
torch.cuda.set_device(gpu_id)
|
||||||
|
allocated = torch.cuda.memory_allocated(gpu_id) / 1024**3
|
||||||
|
cached = torch.cuda.memory_reserved(gpu_id) / 1024**3
|
||||||
|
total = torch.cuda.get_device_properties(gpu_id).total_memory / 1024**3
|
||||||
|
free = total - cached
|
||||||
|
|
||||||
|
memory_info[f"GPU_{gpu_id}"] = {
|
||||||
|
"total": f"{total:.1f}GB",
|
||||||
|
"allocated": f"{allocated:.1f}GB",
|
||||||
|
"cached": f"{cached:.1f}GB",
|
||||||
|
"free": f"{free:.1f}GB"
|
||||||
|
}
|
||||||
|
|
||||||
|
return memory_info
|
||||||
|
|
||||||
|
def check_multigpu_info():
|
||||||
|
"""检查多GPU环境信息"""
|
||||||
|
print("=== 多GPU环境信息 ===")
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("❌ CUDA不可用")
|
||||||
|
return
|
||||||
|
|
||||||
|
gpu_count = torch.cuda.device_count()
|
||||||
|
print(f"✅ 检测到 {gpu_count} 个GPU")
|
||||||
|
print(f"CUDA版本: {torch.version.cuda}")
|
||||||
|
print(f"PyTorch版本: {torch.__version__}")
|
||||||
|
|
||||||
|
for i in range(gpu_count):
|
||||||
|
gpu_name = torch.cuda.get_device_name(i)
|
||||||
|
gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
|
||||||
|
print(f"GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)")
|
||||||
|
|
||||||
|
print("=====================")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 检查多GPU环境
|
||||||
|
check_multigpu_info()
|
||||||
|
|
||||||
|
# 示例使用
|
||||||
|
print("\n正在初始化多GPU多模态检索系统...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
retrieval_system = MultiGPUMultimodalRetrieval()
|
||||||
|
print("✅ 多GPU系统初始化成功!")
|
||||||
|
|
||||||
|
# 显示GPU内存使用情况
|
||||||
|
memory_info = retrieval_system.get_gpu_memory_info()
|
||||||
|
print("\n📊 GPU内存使用情况:")
|
||||||
|
for gpu, info in memory_info.items():
|
||||||
|
print(f" {gpu}: {info['allocated']} / {info['total']} (已用/总计)")
|
||||||
|
|
||||||
|
print("\n🚀 多GPU多模态检索系统就绪!")
|
||||||
|
print("支持的检索模式:")
|
||||||
|
print("1. 文搜文: search_text_by_text()")
|
||||||
|
print("2. 文搜图: search_images_by_text()")
|
||||||
|
print("3. 图搜图: search_images_by_image()")
|
||||||
|
print("4. 图搜文: search_text_by_image()")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 多GPU系统初始化失败: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
@ -18,45 +18,19 @@ class OpsMMEmbeddingV1(nn.Module):
|
|||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
attn_implementation: Optional[str] = None,
|
attn_implementation: Optional[str] = None,
|
||||||
device_map: Optional[str] = None,
|
|
||||||
load_in_4bit: bool = False,
|
|
||||||
load_in_8bit: bool = False,
|
|
||||||
torch_dtype: Optional[torch.dtype] = torch.bfloat16,
|
|
||||||
processor_min_pixels: int = 128 * 28 * 28,
|
|
||||||
processor_max_pixels: int = 512 * 28 * 28,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.default_instruction = "You are a helpful assistant."
|
self.default_instruction = "You are a helpful assistant."
|
||||||
from transformers import AutoModelForImageTextToText
|
|
||||||
load_kwargs = dict(
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
attn_implementation=attn_implementation,
|
|
||||||
)
|
|
||||||
# Quantization options (requires bitsandbytes for 4/8-bit)
|
|
||||||
if load_in_4bit:
|
|
||||||
load_kwargs["load_in_4bit"] = True
|
|
||||||
if load_in_8bit:
|
|
||||||
load_kwargs["load_in_8bit"] = True
|
|
||||||
if device_map is not None:
|
|
||||||
load_kwargs["device_map"] = device_map
|
|
||||||
|
|
||||||
self.base_model = AutoModelForImageTextToText.from_pretrained(
|
self.base_model = AutoModelForImageTextToText.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
**load_kwargs,
|
torch_dtype=torch.bfloat16,
|
||||||
)
|
low_cpu_mem_usage=True,
|
||||||
# Only move to a single device when not using tensor-parallel sharding
|
attn_implementation=attn_implementation,
|
||||||
if device_map is None:
|
).to(self.device)
|
||||||
self.base_model = self.base_model.to(self.device)
|
|
||||||
|
|
||||||
# Use configurable pixel limits to control VRAM usage
|
self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
min_pixels=processor_min_pixels,
|
|
||||||
max_pixels=processor_max_pixels,
|
|
||||||
)
|
|
||||||
self.processor.tokenizer.padding_side = "left"
|
self.processor.tokenizer.padding_side = "left"
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
@ -146,11 +120,9 @@ class OpsMMEmbeddingV1(nn.Module):
|
|||||||
input_texts.append(msg)
|
input_texts.append(msg)
|
||||||
input_images.append(processed_image)
|
input_images.append(processed_image)
|
||||||
|
|
||||||
# Only pass images when present; some processors expect paired inputs and
|
# Only pass to processor if we actually have images
|
||||||
# can raise unpack errors if we pass images=None with multi-modal processor.
|
processed_images = input_images if any(img is not None for img in input_images) else None
|
||||||
has_images = any(img is not None for img in input_images)
|
|
||||||
if has_images:
|
|
||||||
processed_images = input_images
|
|
||||||
inputs = self.processor(
|
inputs = self.processor(
|
||||||
text=input_texts,
|
text=input_texts,
|
||||||
images=processed_images,
|
images=processed_images,
|
||||||
@ -159,14 +131,6 @@ class OpsMMEmbeddingV1(nn.Module):
|
|||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
inputs = self.processor(
|
|
||||||
text=input_texts,
|
|
||||||
padding=True,
|
|
||||||
truncation=True,
|
|
||||||
max_length=self.max_length,
|
|
||||||
return_tensors="pt",
|
|
||||||
)
|
|
||||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
|
|||||||
@ -1,367 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
优化的文件处理器
|
|
||||||
支持自动清理、内存处理和流式上传
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import io
|
|
||||||
import tempfile
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Dict, List, Optional, Any, Union, BinaryIO
|
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# Optional external managers (BOS/Mongo) are disabled in minimal setup
|
|
||||||
# Previously:
|
|
||||||
# from baidu_bos_manager import get_bos_manager
|
|
||||||
# from mongodb_manager import get_mongodb_manager
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class OptimizedFileHandler:
|
|
||||||
"""优化的文件处理器"""
|
|
||||||
|
|
||||||
# 小文件阈值 (5MB)
|
|
||||||
SMALL_FILE_THRESHOLD = 5 * 1024 * 1024
|
|
||||||
|
|
||||||
# 支持的图像格式
|
|
||||||
SUPPORTED_IMAGE_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}
|
|
||||||
|
|
||||||
def __init__(self, local_storage_dir=None):
|
|
||||||
# In minimal setup, BOS and MongoDB are not used
|
|
||||||
self.bos_manager = None
|
|
||||||
self.mongodb_manager = None
|
|
||||||
self.temp_files = set() # 跟踪临时文件
|
|
||||||
self.local_storage_dir = local_storage_dir or tempfile.gettempdir()
|
|
||||||
|
|
||||||
# 确保本地存储目录存在
|
|
||||||
if self.local_storage_dir:
|
|
||||||
os.makedirs(self.local_storage_dir, exist_ok=True)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def temp_file_context(self, content: bytes = None, suffix: str = None, delete_on_exit: bool = True):
|
|
||||||
"""临时文件上下文管理器,确保自动清理"""
|
|
||||||
temp_fd, temp_path = tempfile.mkstemp(suffix=suffix, dir=self.local_storage_dir)
|
|
||||||
self.temp_files.add(temp_path)
|
|
||||||
|
|
||||||
# 如果提供了内容,写入文件
|
|
||||||
if content is not None:
|
|
||||||
with os.fdopen(temp_fd, 'wb') as f:
|
|
||||||
f.write(content)
|
|
||||||
else:
|
|
||||||
os.close(temp_fd) # 关闭文件描述符
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield temp_path
|
|
||||||
finally:
|
|
||||||
if delete_on_exit and os.path.exists(temp_path):
|
|
||||||
try:
|
|
||||||
os.unlink(temp_path)
|
|
||||||
self.temp_files.discard(temp_path)
|
|
||||||
logger.debug(f"🗑️ 临时文件已清理: {temp_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"⚠️ 临时文件清理失败: {temp_path}, {e}")
|
|
||||||
|
|
||||||
def cleanup_all_temp_files(self):
|
|
||||||
"""清理所有跟踪的临时文件"""
|
|
||||||
for temp_path in list(self.temp_files):
|
|
||||||
if os.path.exists(temp_path):
|
|
||||||
try:
|
|
||||||
os.unlink(temp_path)
|
|
||||||
logger.debug(f"🗑️ 清理临时文件: {temp_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"⚠️ 清理临时文件失败: {temp_path}, {e}")
|
|
||||||
self.temp_files.clear()
|
|
||||||
|
|
||||||
def get_file_size(self, file_obj) -> int:
|
|
||||||
"""获取文件大小"""
|
|
||||||
if hasattr(file_obj, 'content_length') and file_obj.content_length:
|
|
||||||
return file_obj.content_length
|
|
||||||
|
|
||||||
# 通过读取内容获取大小
|
|
||||||
current_pos = file_obj.tell()
|
|
||||||
file_obj.seek(0, 2) # 移动到文件末尾
|
|
||||||
size = file_obj.tell()
|
|
||||||
file_obj.seek(current_pos) # 恢复原位置
|
|
||||||
return size
|
|
||||||
|
|
||||||
def is_small_file(self, file_obj) -> bool:
|
|
||||||
"""判断是否为小文件"""
|
|
||||||
return self.get_file_size(file_obj) <= self.SMALL_FILE_THRESHOLD
|
|
||||||
|
|
||||||
def process_image_in_memory(self, file_obj, filename: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""在内存中处理小图像文件"""
|
|
||||||
try:
|
|
||||||
# 读取文件内容到内存
|
|
||||||
file_obj.seek(0)
|
|
||||||
file_content = file_obj.read()
|
|
||||||
file_obj.seek(0)
|
|
||||||
|
|
||||||
# 验证图像格式
|
|
||||||
try:
|
|
||||||
image = Image.open(io.BytesIO(file_content))
|
|
||||||
image.verify() # 验证图像完整性
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 图像验证失败: {filename}, {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 生成唯一ID
|
|
||||||
file_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
# 保存到本地存储
|
|
||||||
local_path = os.path.join(self.local_storage_dir, f"{file_id}_{filename}")
|
|
||||||
with open(local_path, 'wb') as f:
|
|
||||||
f.write(file_content)
|
|
||||||
|
|
||||||
# 存储元数据到MongoDB
|
|
||||||
metadata = {
|
|
||||||
"_id": file_id,
|
|
||||||
"filename": filename,
|
|
||||||
"file_type": "image",
|
|
||||||
"file_size": len(file_content),
|
|
||||||
"processing_method": "memory",
|
|
||||||
"local_path": local_path
|
|
||||||
}
|
|
||||||
|
|
||||||
# 如果有BOS管理器,也上传到BOS
|
|
||||||
if self.bos_manager:
|
|
||||||
bos_key = f"images/memory_{file_id}_{filename}"
|
|
||||||
bos_result = self._upload_to_bos_from_memory(file_content, bos_key, filename)
|
|
||||||
if bos_result:
|
|
||||||
metadata["bos_key"] = bos_key
|
|
||||||
metadata["bos_url"] = bos_result["url"]
|
|
||||||
|
|
||||||
if self.mongodb_manager:
|
|
||||||
self.mongodb_manager.store_file_metadata(metadata=metadata)
|
|
||||||
|
|
||||||
logger.info(f"✅ 内存处理图像成功: {filename} ({len(file_content)} bytes)")
|
|
||||||
return {
|
|
||||||
"file_id": file_id,
|
|
||||||
"filename": filename,
|
|
||||||
"local_path": local_path,
|
|
||||||
"processing_method": "memory"
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 内存处理图像失败: {filename}, {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def process_image_with_temp_file(self, file_obj, filename: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""使用临时文件处理大图像文件"""
|
|
||||||
try:
|
|
||||||
# 获取文件扩展名
|
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
|
||||||
|
|
||||||
# 生成唯一ID
|
|
||||||
file_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
# 创建永久文件路径
|
|
||||||
permanent_path = os.path.join(self.local_storage_dir, f"{file_id}_{filename}")
|
|
||||||
|
|
||||||
with self.temp_file_context(suffix=ext) as temp_path:
|
|
||||||
# 保存到临时文件
|
|
||||||
file_obj.seek(0)
|
|
||||||
with open(temp_path, 'wb') as temp_file:
|
|
||||||
temp_file.write(file_obj.read())
|
|
||||||
|
|
||||||
# 验证图像
|
|
||||||
try:
|
|
||||||
with Image.open(temp_path) as image:
|
|
||||||
image.verify()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 图像验证失败: {filename}, {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 复制到永久存储位置
|
|
||||||
with open(temp_path, 'rb') as src, open(permanent_path, 'wb') as dst:
|
|
||||||
dst.write(src.read())
|
|
||||||
|
|
||||||
# 获取文件信息
|
|
||||||
file_stat = os.stat(permanent_path)
|
|
||||||
|
|
||||||
# 存储元数据
|
|
||||||
metadata = {
|
|
||||||
"_id": file_id,
|
|
||||||
"filename": filename,
|
|
||||||
"file_type": "image",
|
|
||||||
"file_size": file_stat.st_size,
|
|
||||||
"processing_method": "temp_file",
|
|
||||||
"local_path": permanent_path
|
|
||||||
}
|
|
||||||
|
|
||||||
# 如果有BOS管理器,也上传到BOS
|
|
||||||
if self.bos_manager:
|
|
||||||
bos_key = f"images/temp_{file_id}_{filename}"
|
|
||||||
bos_result = self.bos_manager.upload_file(temp_path, bos_key)
|
|
||||||
if bos_result:
|
|
||||||
metadata["bos_key"] = bos_key
|
|
||||||
metadata["bos_url"] = bos_result["url"]
|
|
||||||
|
|
||||||
# 存储元数据到MongoDB
|
|
||||||
if self.mongodb_manager:
|
|
||||||
self.mongodb_manager.store_file_metadata(metadata=metadata)
|
|
||||||
|
|
||||||
logger.info(f"✅ 临时文件处理图像成功: {filename} ({file_stat.st_size} bytes)")
|
|
||||||
return {
|
|
||||||
"file_id": file_id,
|
|
||||||
"filename": filename,
|
|
||||||
"local_path": permanent_path,
|
|
||||||
"processing_method": "temp_file"
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 临时文件处理图像失败: {filename}, {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def process_image_smart(self, file_obj, filename: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""智能处理图像文件(自动选择内存或临时文件)"""
|
|
||||||
if self.is_small_file(file_obj):
|
|
||||||
logger.info(f"📦 小文件内存处理: {filename}")
|
|
||||||
return self.process_image_in_memory(file_obj, filename)
|
|
||||||
else:
|
|
||||||
logger.info(f"📁 大文件临时处理: {filename}")
|
|
||||||
return self.process_image_with_temp_file(file_obj, filename)
|
|
||||||
|
|
||||||
def process_text_in_memory(self, texts: List[str]) -> List[Dict[str, Any]]:
|
|
||||||
"""在内存中处理文本数据"""
|
|
||||||
processed_texts = []
|
|
||||||
|
|
||||||
for i, text in enumerate(texts):
|
|
||||||
try:
|
|
||||||
# 生成唯一ID和BOS键
|
|
||||||
file_id = str(uuid.uuid4())
|
|
||||||
bos_key = f"texts/memory_{file_id}.txt"
|
|
||||||
|
|
||||||
# 将文本转换为字节
|
|
||||||
text_bytes = text.encode('utf-8')
|
|
||||||
|
|
||||||
# 直接上传到BOS
|
|
||||||
bos_result = self._upload_to_bos_from_memory(
|
|
||||||
text_bytes, bos_key, f"text_{i}.txt"
|
|
||||||
)
|
|
||||||
|
|
||||||
if bos_result:
|
|
||||||
# 存储元数据到MongoDB
|
|
||||||
metadata = {
|
|
||||||
"_id": file_id,
|
|
||||||
"filename": f"text_{i}.txt",
|
|
||||||
"file_type": "text",
|
|
||||||
"file_size": len(text_bytes),
|
|
||||||
"processing_method": "memory",
|
|
||||||
"bos_key": bos_key,
|
|
||||||
"bos_url": bos_result["url"],
|
|
||||||
"text_content": text
|
|
||||||
}
|
|
||||||
|
|
||||||
self.mongodb_manager.store_file_metadata(metadata=metadata)
|
|
||||||
|
|
||||||
processed_texts.append({
|
|
||||||
"file_id": file_id,
|
|
||||||
"text_content": text,
|
|
||||||
"bos_key": bos_key,
|
|
||||||
"bos_result": bos_result
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info(f"✅ 内存处理文本成功: text_{i} ({len(text_bytes)} bytes)")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 内存处理文本失败 {i}: {e}")
|
|
||||||
|
|
||||||
return processed_texts
|
|
||||||
|
|
||||||
def download_from_bos_for_processing(self, bos_key: str, local_filename: str = None) -> Optional[str]:
|
|
||||||
"""从BOS下载文件用于模型处理"""
|
|
||||||
try:
|
|
||||||
# 生成临时文件路径
|
|
||||||
if local_filename:
|
|
||||||
ext = os.path.splitext(local_filename)[1]
|
|
||||||
else:
|
|
||||||
ext = os.path.splitext(bos_key)[1]
|
|
||||||
|
|
||||||
with self.temp_file_context(suffix=ext, delete_on_exit=False) as temp_path:
|
|
||||||
# 从BOS下载文件
|
|
||||||
if not self.bos_manager:
|
|
||||||
logger.warning("BOS manager is not available; skip download.")
|
|
||||||
return None
|
|
||||||
success = self.bos_manager.download_file(bos_key, temp_path)
|
|
||||||
|
|
||||||
if success:
|
|
||||||
logger.info(f"✅ 从BOS下载文件用于处理: {bos_key}")
|
|
||||||
return temp_path
|
|
||||||
else:
|
|
||||||
logger.error(f"❌ 从BOS下载文件失败: {bos_key}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 从BOS下载文件异常: {bos_key}, {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _upload_to_bos_from_memory(self, content: bytes, bos_key: str, filename: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""从内存直接上传到BOS"""
|
|
||||||
try:
|
|
||||||
if not self.bos_manager:
|
|
||||||
return None
|
|
||||||
# 创建临时文件用于上传
|
|
||||||
with self.temp_file_context() as temp_path:
|
|
||||||
with open(temp_path, 'wb') as temp_file:
|
|
||||||
temp_file.write(content)
|
|
||||||
result = self.bos_manager.upload_file(temp_path, bos_key)
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 内存上传到BOS失败: {filename}, {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_temp_file_for_model(self, file_obj, filename: str) -> Optional[str]:
|
|
||||||
"""为模型处理获取临时文件路径(确保文件存在于本地)"""
|
|
||||||
try:
|
|
||||||
ext = os.path.splitext(filename)[1].lower()
|
|
||||||
|
|
||||||
# 生成唯一ID
|
|
||||||
file_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
# 创建临时文件(不自动删除,供模型使用)
|
|
||||||
temp_fd, temp_path = tempfile.mkstemp(suffix=ext, dir=self.local_storage_dir)
|
|
||||||
self.temp_files.add(temp_path)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 写入文件内容
|
|
||||||
file_obj.seek(0)
|
|
||||||
with os.fdopen(temp_fd, 'wb') as temp_file:
|
|
||||||
temp_file.write(file_obj.read())
|
|
||||||
|
|
||||||
logger.debug(f"📁 为模型创建临时文件: {temp_path}")
|
|
||||||
return temp_path
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
os.close(temp_fd)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"❌ 为模型创建临时文件失败: {filename}, {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def cleanup_temp_file(self, temp_path: str):
|
|
||||||
"""清理指定的临时文件"""
|
|
||||||
if temp_path and os.path.exists(temp_path):
|
|
||||||
try:
|
|
||||||
os.unlink(temp_path)
|
|
||||||
self.temp_files.discard(temp_path)
|
|
||||||
logger.debug(f"🗑️ 清理临时文件: {temp_path}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"⚠️ 清理临时文件失败: {temp_path}, {e}")
|
|
||||||
|
|
||||||
|
|
||||||
# 全局实例
|
|
||||||
file_handler = None
|
|
||||||
|
|
||||||
def get_file_handler() -> OptimizedFileHandler:
|
|
||||||
"""获取优化文件处理器实例"""
|
|
||||||
global file_handler
|
|
||||||
if file_handler is None:
|
|
||||||
file_handler = OptimizedFileHandler()
|
|
||||||
return file_handler
|
|
||||||
@ -1,11 +1,12 @@
|
|||||||
torch>=2.0.0
|
torch>=2.0.0
|
||||||
|
torchvision>=0.15.0
|
||||||
transformers>=4.30.0
|
transformers>=4.30.0
|
||||||
accelerate>=0.20.0
|
accelerate>=0.20.0
|
||||||
faiss-cpu>=1.7.4
|
faiss-cpu>=1.7.4
|
||||||
numpy>=1.21.0
|
numpy>=1.21.0
|
||||||
Pillow>=9.0.0
|
Pillow>=9.0.0
|
||||||
|
scikit-learn>=1.3.0
|
||||||
tqdm>=4.65.0
|
tqdm>=4.65.0
|
||||||
flask>=2.3.0
|
flask>=2.3.0
|
||||||
werkzeug>=2.3.0
|
werkzeug>=2.3.0
|
||||||
requests>=2.31.0
|
psutil>=5.9.0
|
||||||
safetensors>=0.4.0
|
|
||||||
|
|||||||
@ -1,14 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
# Select GPUs 0 and 1 for tensor-parallel sharding
|
|
||||||
export CUDA_VISIBLE_DEVICES=0,1
|
|
||||||
|
|
||||||
# Unbuffered stdout for real-time logs
|
|
||||||
export PYTHONUNBUFFERED=1
|
|
||||||
|
|
||||||
# Help PyTorch allocator avoid fragmentation (see OOM hint)
|
|
||||||
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
|
||||||
|
|
||||||
# Start the local web app
|
|
||||||
exec python3 web_app_local.py "$@"
|
|
||||||
BIN
sample_images/1755691510_2__.jpg
Normal file
|
After Width: | Height: | Size: 109 KiB |
BIN
sample_images/1755691510_4__.jpg
Normal file
|
After Width: | Height: | Size: 189 KiB |
BIN
sample_images/1755691510_5__.jpg
Normal file
|
After Width: | Height: | Size: 201 KiB |
BIN
sample_images/1755691510_7__.jpg
Normal file
|
After Width: | Height: | Size: 166 KiB |
BIN
sample_images/1755691510_data_generation_67caeb93_00028_.png
Normal file
|
After Width: | Height: | Size: 150 KiB |
BIN
sample_images/1755691510_data_generation_67d3f794_00013_.png
Normal file
|
After Width: | Height: | Size: 312 KiB |
BIN
sample_images/1755691510_data_generation_67d3f794_00040_.png
Normal file
|
After Width: | Height: | Size: 325 KiB |
BIN
sample_images/1755691510_jpeg
Normal file
|
After Width: | Height: | Size: 4.5 KiB |
|
Before Width: | Height: | Size: 785 KiB |
@ -3,8 +3,7 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
<title>本地多模态检索系统 - FAISS</title>
|
<title>多模态检索系统</title>
|
||||||
<link rel="icon" href="/favicon.ico" />
|
|
||||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
||||||
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
|
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
|
||||||
<style>
|
<style>
|
||||||
@ -154,25 +153,6 @@
|
|||||||
right: 20px;
|
right: 20px;
|
||||||
z-index: 1000;
|
z-index: 1000;
|
||||||
}
|
}
|
||||||
.status-bar {
|
|
||||||
position: fixed;
|
|
||||||
top: 20px;
|
|
||||||
right: 120px;
|
|
||||||
background: rgba(255,255,255,0.9);
|
|
||||||
border-radius: 12px;
|
|
||||||
padding: 8px 12px;
|
|
||||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
|
||||||
font-size: 12px;
|
|
||||||
color: #334155;
|
|
||||||
display: flex;
|
|
||||||
gap: 12px;
|
|
||||||
align-items: center;
|
|
||||||
}
|
|
||||||
.status-item {
|
|
||||||
display: flex;
|
|
||||||
gap: 6px;
|
|
||||||
align-items: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
.fade-in {
|
.fade-in {
|
||||||
animation: fadeIn 0.5s ease-in;
|
animation: fadeIn 0.5s ease-in;
|
||||||
@ -199,28 +179,13 @@
|
|||||||
<i class="fas fa-circle-notch fa-spin"></i> 未初始化
|
<i class="fas fa-circle-notch fa-spin"></i> 未初始化
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<!-- 右上角状态栏 -->
|
|
||||||
<div class="status-bar">
|
|
||||||
<div class="status-item" title="Vector Dimension">
|
|
||||||
<i class="fas fa-ruler-combined text-primary"></i>
|
|
||||||
<span id="statusVectorDim">-</span>
|
|
||||||
</div>
|
|
||||||
<div class="status-item" title="Total Vectors">
|
|
||||||
<i class="fas fa-database text-success"></i>
|
|
||||||
<span id="statusTotalVectors">-</span>
|
|
||||||
</div>
|
|
||||||
<div class="status-item" title="Server Time (UTC)">
|
|
||||||
<i class="fas fa-clock text-warning"></i>
|
|
||||||
<span id="statusServerTime">-</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="container-fluid">
|
<div class="container-fluid">
|
||||||
<div class="main-container">
|
<div class="main-container">
|
||||||
<!-- 头部 -->
|
<!-- 头部 -->
|
||||||
<div class="header">
|
<div class="header">
|
||||||
<h1><i class="fas fa-search"></i> 本地多模态检索系统</h1>
|
<h1><i class="fas fa-search"></i> 多模态检索系统</h1>
|
||||||
<p class="mb-0">基于本地模型和FAISS向量数据库,支持文搜图、文搜文、图搜图、图搜文四种检索模式</p>
|
<p class="mb-0">支持文搜图、文搜文、图搜图、图搜文四种检索模式</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="p-4">
|
<div class="p-4">
|
||||||
@ -319,7 +284,9 @@
|
|||||||
<div class="row mt-4">
|
<div class="row mt-4">
|
||||||
<div class="col-md-8">
|
<div class="col-md-8">
|
||||||
<div class="d-flex gap-3">
|
<div class="d-flex gap-3">
|
||||||
<!-- 移除构建索引按钮,改为自动构建 -->
|
<button id="buildIndexBtn" class="btn btn-warning" disabled>
|
||||||
|
<i class="fas fa-cogs"></i> 构建索引
|
||||||
|
</button>
|
||||||
<button id="viewDataBtn" class="btn btn-info">
|
<button id="viewDataBtn" class="btn btn-info">
|
||||||
<i class="fas fa-list"></i> 查看数据
|
<i class="fas fa-list"></i> 查看数据
|
||||||
</button>
|
</button>
|
||||||
@ -356,11 +323,7 @@
|
|||||||
<option value="3">Top 3</option>
|
<option value="3">Top 3</option>
|
||||||
<option value="5" selected>Top 5</option>
|
<option value="5" selected>Top 5</option>
|
||||||
<option value="10">Top 10</option>
|
<option value="10">Top 10</option>
|
||||||
<option value="50">Top 50</option>
|
|
||||||
<option value="100">Top 100</option>
|
|
||||||
<option value="custom">自定义</option>
|
|
||||||
</select>
|
</select>
|
||||||
<input id="textTopKCustom" type="number" min="1" max="1000" placeholder="自定义K" class="form-control mt-2" style="display:none;" />
|
|
||||||
</div>
|
</div>
|
||||||
<div class="col-md-2">
|
<div class="col-md-2">
|
||||||
<button id="textSearchBtn" class="btn btn-primary w-100">
|
<button id="textSearchBtn" class="btn btn-primary w-100">
|
||||||
@ -386,11 +349,7 @@
|
|||||||
<option value="3">Top 3</option>
|
<option value="3">Top 3</option>
|
||||||
<option value="5" selected>Top 5</option>
|
<option value="5" selected>Top 5</option>
|
||||||
<option value="10">Top 10</option>
|
<option value="10">Top 10</option>
|
||||||
<option value="50">Top 50</option>
|
|
||||||
<option value="100">Top 100</option>
|
|
||||||
<option value="custom">自定义</option>
|
|
||||||
</select>
|
</select>
|
||||||
<input id="imageTopKCustom" type="number" min="1" max="1000" placeholder="自定义K" class="form-control mt-2" style="display:none;" />
|
|
||||||
</div>
|
</div>
|
||||||
<div class="col-md-2">
|
<div class="col-md-2">
|
||||||
<button id="imageSearchBtn" class="btn btn-primary w-100" disabled>
|
<button id="imageSearchBtn" class="btn btn-primary w-100" disabled>
|
||||||
@ -429,7 +388,7 @@
|
|||||||
btn.disabled = true;
|
btn.disabled = true;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/system_info', {
|
const response = await fetch('/api/init', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {'Content-Type': 'application/json'}
|
headers: {'Content-Type': 'application/json'}
|
||||||
});
|
});
|
||||||
@ -442,7 +401,7 @@
|
|||||||
'<i class="fas fa-check-circle"></i> 已重新初始化';
|
'<i class="fas fa-check-circle"></i> 已重新初始化';
|
||||||
document.getElementById('statusBadge').className = 'badge bg-success';
|
document.getElementById('statusBadge').className = 'badge bg-success';
|
||||||
|
|
||||||
showAlert('success', `系统重新初始化成功!GPU信息: ${data.gpu_info.length} 个, 向量数量: ${data.retrieval_info.total_vectors || 0}`);
|
showAlert('success', `系统重新初始化成功!GPU: ${data.gpu_count} 个`);
|
||||||
} else {
|
} else {
|
||||||
throw new Error(data.message);
|
throw new Error(data.message);
|
||||||
}
|
}
|
||||||
@ -489,34 +448,9 @@
|
|||||||
if (e.key === 'Enter') performTextSearch();
|
if (e.key === 'Enter') performTextSearch();
|
||||||
});
|
});
|
||||||
|
|
||||||
// TopK 自定义输入显隐
|
|
||||||
const textTopKSel = document.getElementById('textTopK');
|
|
||||||
const textTopKCustom = document.getElementById('textTopKCustom');
|
|
||||||
const imageTopKSel = document.getElementById('imageTopK');
|
|
||||||
const imageTopKCustom = document.getElementById('imageTopKCustom');
|
|
||||||
textTopKSel.addEventListener('change', () => {
|
|
||||||
textTopKCustom.style.display = textTopKSel.value === 'custom' ? 'block' : 'none';
|
|
||||||
});
|
|
||||||
imageTopKSel.addEventListener('change', () => {
|
|
||||||
imageTopKCustom.style.display = imageTopKSel.value === 'custom' ? 'block' : 'none';
|
|
||||||
});
|
|
||||||
|
|
||||||
function resolveTopK(selectEl, customEl) {
|
|
||||||
let v = selectEl.value;
|
|
||||||
if (v === 'custom') {
|
|
||||||
const n = parseInt(customEl.value, 10);
|
|
||||||
if (!Number.isFinite(n) || n <= 0) {
|
|
||||||
showAlert('warning', '请输入有效的自定义 Top K 数值');
|
|
||||||
throw new Error('Invalid custom TopK');
|
|
||||||
}
|
|
||||||
return n;
|
|
||||||
}
|
|
||||||
return parseInt(v, 10);
|
|
||||||
}
|
|
||||||
|
|
||||||
async function performTextSearch() {
|
async function performTextSearch() {
|
||||||
const query = document.getElementById('textQuery').value.trim();
|
const query = document.getElementById('textQuery').value.trim();
|
||||||
const topK = resolveTopK(textTopKSel, textTopKCustom);
|
const topK = parseInt(document.getElementById('textTopK').value);
|
||||||
|
|
||||||
if (!query) {
|
if (!query) {
|
||||||
showAlert('warning', '请输入搜索文本');
|
showAlert('warning', '请输入搜索文本');
|
||||||
@ -526,12 +460,11 @@
|
|||||||
showLoading(true);
|
showLoading(true);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const endpoint = '/api/search_by_text';
|
const endpoint = currentMode === 'text_to_text' ? '/api/search/text_to_text' : '/api/search/text_to_image';
|
||||||
const filter_type = currentMode === 'text_to_text' ? 'text' : 'image';
|
|
||||||
const response = await fetch(endpoint, {
|
const response = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {'Content-Type': 'application/json'},
|
headers: {'Content-Type': 'application/json'},
|
||||||
body: JSON.stringify({query, k: topK, filter_type: filter_type})
|
body: JSON.stringify({query, top_k: topK})
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
@ -596,7 +529,7 @@
|
|||||||
// 图片搜索
|
// 图片搜索
|
||||||
document.getElementById('imageSearchBtn').addEventListener('click', async function() {
|
document.getElementById('imageSearchBtn').addEventListener('click', async function() {
|
||||||
const file = imageFile.files[0];
|
const file = imageFile.files[0];
|
||||||
const topK = resolveTopK(imageTopKSel, imageTopKCustom);
|
const topK = parseInt(document.getElementById('imageTopK').value);
|
||||||
|
|
||||||
if (!file) {
|
if (!file) {
|
||||||
showAlert('warning', '请选择图片文件');
|
showAlert('warning', '请选择图片文件');
|
||||||
@ -606,13 +539,11 @@
|
|||||||
showLoading(true);
|
showLoading(true);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const endpoint = '/api/search_by_image';
|
|
||||||
const filter_type = currentMode === 'image_to_text' ? 'text' : 'image';
|
|
||||||
|
|
||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
formData.append('image', file);
|
formData.append('image', file);
|
||||||
formData.append('k', topK);
|
formData.append('top_k', topK);
|
||||||
formData.append('filter_type', filter_type);
|
|
||||||
|
const endpoint = currentMode === 'image_to_text' ? '/api/search/image_to_text' : '/api/search/image_to_image';
|
||||||
const response = await fetch(endpoint, {
|
const response = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: formData
|
body: formData
|
||||||
@ -641,18 +572,17 @@
|
|||||||
<div class="d-flex justify-content-between align-items-center mb-3">
|
<div class="d-flex justify-content-between align-items-center mb-3">
|
||||||
<h4><i class="fas fa-search-plus"></i> 搜索结果</h4>
|
<h4><i class="fas fa-search-plus"></i> 搜索结果</h4>
|
||||||
<div>
|
<div>
|
||||||
<span class="badge bg-info">找到 ${data.results?.length || 0} 个结果</span>
|
<span class="badge bg-info">找到 ${data.result_count} 个结果</span>
|
||||||
<span class="badge bg-secondary">耗时 ${data.search_time || data.time || '0.0'}s</span>
|
<span class="badge bg-secondary">耗时 ${data.search_time}s</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
`;
|
`;
|
||||||
|
|
||||||
if (data.query_image) {
|
if (data.query_image) {
|
||||||
const imageUrl = data.query_image.startsWith('data:') ? data.query_image : `data:image/jpeg;base64,${data.query_image}`;
|
|
||||||
html += `
|
html += `
|
||||||
<div class="result-card">
|
<div class="result-card">
|
||||||
<h6><i class="fas fa-image"></i> 查询图片</h6>
|
<h6><i class="fas fa-image"></i> 查询图片</h6>
|
||||||
<img src="${imageUrl}" class="query-image">
|
<img src="data:image/jpeg;base64,${data.query_image}" class="query-image">
|
||||||
</div>
|
</div>
|
||||||
`;
|
`;
|
||||||
}
|
}
|
||||||
@ -670,40 +600,29 @@
|
|||||||
html += '<div class="result-card">';
|
html += '<div class="result-card">';
|
||||||
|
|
||||||
if (mode === 'text_to_image' || mode === 'image_to_image') {
|
if (mode === 'text_to_image' || mode === 'image_to_image') {
|
||||||
const imageUrl = result.image_base64 ? `data:image/jpeg;base64,${result.image_base64}` :
|
|
||||||
(result.image_url || `/temp/${result.filename || result.id}`);
|
|
||||||
const score = result.score || result.distance ?
|
|
||||||
(result.score ? (result.score * 100).toFixed(1) : (100 - result.distance * 100).toFixed(1)) : '95.0';
|
|
||||||
const title = result.title || result.filename || result.id || `结果 ${index + 1}`;
|
|
||||||
|
|
||||||
html += `
|
html += `
|
||||||
<div class="row">
|
<div class="row">
|
||||||
<div class="col-md-3">
|
<div class="col-md-3">
|
||||||
<img src="${imageUrl}" class="result-image" alt="Result ${index + 1}">
|
<img src="data:image/jpeg;base64,${result.image_base64}"
|
||||||
|
class="result-image" alt="Result ${index + 1}">
|
||||||
</div>
|
</div>
|
||||||
<div class="col-md-9">
|
<div class="col-md-9">
|
||||||
<div class="d-flex justify-content-between align-items-start">
|
<div class="d-flex justify-content-between align-items-start">
|
||||||
<h6><i class="fas fa-image"></i> ${title}</h6>
|
<h6><i class="fas fa-image"></i> ${result.filename}</h6>
|
||||||
<span class="score-badge">相似度: ${score}%</span>
|
<span class="score-badge">相似度: ${(result.score * 100).toFixed(1)}%</span>
|
||||||
</div>
|
</div>
|
||||||
<p class="text-muted mb-0">类型: 图片 | ID: ${result.id || index}</p>
|
<p class="text-muted mb-0">路径: ${result.image_path}</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
`;
|
`;
|
||||||
} else {
|
} else {
|
||||||
const text = result.text || result.content || (typeof result === 'string' ? result : JSON.stringify(result));
|
|
||||||
const score = result.score || result.distance ?
|
|
||||||
(result.score ? (result.score * 100).toFixed(1) : (100 - result.distance * 100).toFixed(1)) : '95.0';
|
|
||||||
const title = result.title || `结果 ${index + 1}`;
|
|
||||||
|
|
||||||
html += `
|
html += `
|
||||||
<div class="d-flex justify-content-between align-items-start">
|
<div class="d-flex justify-content-between align-items-start">
|
||||||
<div>
|
<div>
|
||||||
<h6><i class="fas fa-file-text"></i> ${title}</h6>
|
<h6><i class="fas fa-file-text"></i> 结果 ${index + 1}</h6>
|
||||||
<p class="mb-0">${text}</p>
|
<p class="mb-0">${result.text || result}</p>
|
||||||
<p class="text-muted small mb-0">类型: 文本 | ID: ${result.id || index}</p>
|
|
||||||
</div>
|
</div>
|
||||||
<span class="score-badge">相似度: ${score}%</span>
|
<span class="score-badge">相似度: ${((result.score || 0.95) * 100).toFixed(1)}%</span>
|
||||||
</div>
|
</div>
|
||||||
`;
|
`;
|
||||||
}
|
}
|
||||||
@ -740,10 +659,10 @@
|
|||||||
// 检查系统状态
|
// 检查系统状态
|
||||||
async function checkStatus() {
|
async function checkStatus() {
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/system_info');
|
const response = await fetch('/api/status');
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (data.success) {
|
if (data.initialized) {
|
||||||
systemInitialized = true;
|
systemInitialized = true;
|
||||||
document.getElementById('statusBadge').innerHTML =
|
document.getElementById('statusBadge').innerHTML =
|
||||||
'<i class="fas fa-check-circle"></i> 已初始化';
|
'<i class="fas fa-check-circle"></i> 已初始化';
|
||||||
@ -761,28 +680,8 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 轮询 /api/stats 更新右上角状态栏
|
// 页面加载时检查状态
|
||||||
async function updateStatusBar() {
|
|
||||||
try {
|
|
||||||
const res = await fetch('/api/stats');
|
|
||||||
const data = await res.json();
|
|
||||||
if (data && data.success) {
|
|
||||||
const dim = data.debug?.vector_dimension ?? data.stats?.vector_dimension ?? '-';
|
|
||||||
const total = data.debug?.total_vectors ?? data.stats?.total_vectors ?? '-';
|
|
||||||
const time = data.debug?.server_time ?? '-';
|
|
||||||
document.getElementById('statusVectorDim').textContent = dim;
|
|
||||||
document.getElementById('statusTotalVectors').textContent = total;
|
|
||||||
document.getElementById('statusServerTime').textContent = time.replace('T',' ').replace('Z','');
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
// 忽略一次失败,等待下次轮询
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 页面加载时检查状态并开始轮询状态栏
|
|
||||||
checkStatus();
|
checkStatus();
|
||||||
updateStatusBar();
|
|
||||||
setInterval(updateStatusBar, 5000);
|
|
||||||
|
|
||||||
// 设置数据管理功能事件绑定
|
// 设置数据管理功能事件绑定
|
||||||
setupDataManagement();
|
setupDataManagement();
|
||||||
@ -845,7 +744,8 @@
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// 移除构建索引按钮的事件监听器
|
// 构建索引
|
||||||
|
document.getElementById('buildIndexBtn').addEventListener('click', buildIndex);
|
||||||
|
|
||||||
// 查看数据
|
// 查看数据
|
||||||
document.getElementById('viewDataBtn').addEventListener('click', viewData);
|
document.getElementById('viewDataBtn').addEventListener('click', viewData);
|
||||||
@ -857,9 +757,8 @@
|
|||||||
updateDataStats();
|
updateDataStats();
|
||||||
}
|
}
|
||||||
|
|
||||||
// 批量上传图片(串行:并发=1,每次上传一张,调用 /api/add_image)
|
// 批量上传图片
|
||||||
async function uploadBatchImages(files) {
|
async function uploadBatchImages(files) {
|
||||||
try {
|
|
||||||
const progressDiv = document.getElementById('imageUploadProgress');
|
const progressDiv = document.getElementById('imageUploadProgress');
|
||||||
const progressBar = progressDiv.querySelector('.progress-bar');
|
const progressBar = progressDiv.querySelector('.progress-bar');
|
||||||
const progressText = document.getElementById('imageProgressText');
|
const progressText = document.getElementById('imageProgressText');
|
||||||
@ -868,116 +767,120 @@ async function uploadBatchImages(files) {
|
|||||||
progressText.textContent = `0/${files.length}`;
|
progressText.textContent = `0/${files.length}`;
|
||||||
progressBar.style.width = '0%';
|
progressBar.style.width = '0%';
|
||||||
|
|
||||||
showAlert('info', `开始上传 ${files.length} 张图片(串行)...`);
|
|
||||||
|
|
||||||
let successCount = 0;
|
|
||||||
for (let i = 0; i < files.length; i++) {
|
|
||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
formData.append('image', files[i]);
|
files.forEach(file => {
|
||||||
|
formData.append('images', file);
|
||||||
|
});
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const resp = await fetch('/api/add_image', { method: 'POST', body: formData });
|
const response = await fetch('/api/upload/images', {
|
||||||
const data = await resp.json();
|
method: 'POST',
|
||||||
if (data && data.success) {
|
body: formData
|
||||||
successCount++;
|
});
|
||||||
} else {
|
|
||||||
// 不中断流程,记录失败
|
|
||||||
console.warn('Upload failed for file:', files[i].name, data?.error);
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.warn('Request failed for file:', files[i].name, e);
|
|
||||||
}
|
|
||||||
|
|
||||||
const progress = Math.round(((i + 1) / files.length) * 100);
|
const data = await response.json();
|
||||||
progressBar.style.width = `${progress}%`;
|
|
||||||
progressText.textContent = `${i + 1}/${files.length}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
showAlert('success', `上传完成:成功 ${successCount}/${files.length} 张`);
|
if (data.success) {
|
||||||
await autoSaveIndex();
|
progressBar.style.width = '100%';
|
||||||
|
progressText.textContent = `${files.length}/${files.length}`;
|
||||||
|
showAlert('success', `成功上传 ${data.uploaded_count} 张图片`);
|
||||||
updateDataStats();
|
updateDataStats();
|
||||||
|
document.getElementById('buildIndexBtn').disabled = false;
|
||||||
|
} else {
|
||||||
|
showAlert('danger', `上传失败: ${data.message}`);
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showAlert('danger', `图片上传失败: ${error.message}`);
|
showAlert('danger', `上传错误: ${error.message}`);
|
||||||
} finally {
|
} finally {
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
document.getElementById('imageUploadProgress').style.display = 'none';
|
progressDiv.style.display = 'none';
|
||||||
}, 1500);
|
}, 2000);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 批量上传文本(调用批量API以利用多卡并行)
|
// 批量上传文本
|
||||||
async function uploadBatchTexts(texts) {
|
async function uploadBatchTexts(texts) {
|
||||||
try {
|
try {
|
||||||
showAlert('info', `正在批量上传 ${texts.length} 条文本...`);
|
const response = await fetch('/api/upload/texts', {
|
||||||
const response = await fetch('/api/add_texts_batch', {
|
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {'Content-Type': 'application/json'},
|
headers: {
|
||||||
body: JSON.stringify({texts})
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ texts: texts })
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
if (!data.success) {
|
|
||||||
throw new Error(data.error || '批量上传失败');
|
if (data.success) {
|
||||||
}
|
showAlert('success', `成功上传 ${data.uploaded_count} 条文本`);
|
||||||
showAlert('success', data.message || `成功上传 ${texts.length} 条文本`);
|
document.getElementById('batchTextInput').value = '';
|
||||||
await autoSaveIndex();
|
|
||||||
updateDataStats();
|
updateDataStats();
|
||||||
|
document.getElementById('buildIndexBtn').disabled = false;
|
||||||
|
} else {
|
||||||
|
showAlert('danger', `上传失败: ${data.message}`);
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showAlert('danger', `文本上传失败: ${error.message}`);
|
showAlert('danger', `上传错误: ${error.message}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 自动保存索引函数
|
// 构建索引
|
||||||
async function autoSaveIndex() {
|
async function buildIndex() {
|
||||||
|
const btn = document.getElementById('buildIndexBtn');
|
||||||
|
const originalText = btn.innerHTML;
|
||||||
|
btn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 构建中...';
|
||||||
|
btn.disabled = true;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/save_index', {
|
const response = await fetch('/api/build_index', {
|
||||||
method: 'POST'
|
method: 'POST'
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (data.success) {
|
if (data.success) {
|
||||||
console.log('索引自动保存成功');
|
showAlert('success', '索引构建完成!现在可以进行搜索了');
|
||||||
} else {
|
} else {
|
||||||
console.error(`索引自动保存失败: ${data.message}`);
|
showAlert('danger', `索引构建失败: ${data.message}`);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`索引自动保存错误: ${error.message}`);
|
showAlert('danger', `构建错误: ${error.message}`);
|
||||||
|
} finally {
|
||||||
|
btn.innerHTML = originalText;
|
||||||
|
btn.disabled = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查看数据
|
// 查看数据
|
||||||
async function viewData() {
|
async function viewData() {
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/list_items');
|
const response = await fetch('/api/data/list');
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (data.success) {
|
if (data.success) {
|
||||||
let content = '<div class="row">';
|
let content = '<div class="row">';
|
||||||
|
|
||||||
// 显示图片数据
|
// 显示图片数据
|
||||||
if (data.items && data.items.filter(item => item.type === 'image').length > 0) {
|
if (data.images && data.images.length > 0) {
|
||||||
const imageItems = data.items.filter(item => item.type === 'image');
|
content += '<div class="col-md-6"><h6>图片数据 (' + data.images.length + ')</h6>';
|
||||||
content += '<div class="col-md-6"><h6>图片数据 (' + imageItems.length + ')</h6>';
|
|
||||||
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
|
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
|
||||||
imageItems.forEach(item => {
|
data.images.forEach(img => {
|
||||||
content += `<div class="list-group-item d-flex justify-content-between align-items-center">
|
content += `<div class="list-group-item d-flex justify-content-between align-items-center">
|
||||||
<span>${item.id}: ${item.metadata?.title || '无标题'}</span>
|
<span>${img}</span>
|
||||||
<img src="/temp/${item.filename || item.id}" class="img-thumbnail" style="width: 50px; height: 50px; object-fit: cover;">
|
<img src="/uploads/${img}" class="img-thumbnail" style="width: 50px; height: 50px; object-fit: cover;">
|
||||||
</div>`;
|
</div>`;
|
||||||
});
|
});
|
||||||
content += '</div></div>';
|
content += '</div></div>';
|
||||||
}
|
}
|
||||||
|
|
||||||
// 显示文本数据
|
// 显示文本数据
|
||||||
if (data.items && data.items.filter(item => item.type === 'text').length > 0) {
|
if (data.texts && data.texts.length > 0) {
|
||||||
const textItems = data.items.filter(item => item.type === 'text');
|
content += '<div class="col-md-6"><h6>文本数据 (' + data.texts.length + ')</h6>';
|
||||||
content += '<div class="col-md-6"><h6>文本数据 (' + textItems.length + ')</h6>';
|
|
||||||
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
|
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
|
||||||
textItems.forEach((item, index) => {
|
data.texts.forEach((text, index) => {
|
||||||
const text = item.content || item.text || '';
|
|
||||||
const shortText = text.length > 50 ? text.substring(0, 50) + '...' : text;
|
const shortText = text.length > 50 ? text.substring(0, 50) + '...' : text;
|
||||||
content += `<div class="list-group-item">
|
content += `<div class="list-group-item">
|
||||||
<small class="text-muted">#${item.id}</small><br>
|
<small class="text-muted">#${index + 1}</small><br>
|
||||||
${shortText}
|
${shortText}
|
||||||
</div>`;
|
</div>`;
|
||||||
});
|
});
|
||||||
@ -1002,7 +905,7 @@ async function uploadBatchTexts(texts) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/clear_index', {
|
const response = await fetch('/api/data/clear', {
|
||||||
method: 'POST'
|
method: 'POST'
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -1011,7 +914,7 @@ async function uploadBatchTexts(texts) {
|
|||||||
if (data.success) {
|
if (data.success) {
|
||||||
showAlert('success', '数据已清空');
|
showAlert('success', '数据已清空');
|
||||||
updateDataStats();
|
updateDataStats();
|
||||||
// 移除构建索引按钮的引用
|
document.getElementById('buildIndexBtn').disabled = true;
|
||||||
} else {
|
} else {
|
||||||
showAlert('danger', `清空失败: ${data.message}`);
|
showAlert('danger', `清空失败: ${data.message}`);
|
||||||
}
|
}
|
||||||
@ -1023,14 +926,12 @@ async function uploadBatchTexts(texts) {
|
|||||||
// 更新数据统计
|
// 更新数据统计
|
||||||
async function updateDataStats() {
|
async function updateDataStats() {
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/system_info');
|
const response = await fetch('/api/data/stats');
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (data.success) {
|
if (data.success) {
|
||||||
const retrieval_info = data.retrieval_info || {};
|
document.getElementById('imageCount').textContent = data.image_count || 0;
|
||||||
document.getElementById('imageCount').textContent = retrieval_info.image_count || 0;
|
document.getElementById('textCount').textContent = data.text_count || 0;
|
||||||
document.getElementById('textCount').textContent = retrieval_info.text_count || 0;
|
|
||||||
// 移除构建索引按钮的引用
|
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.log('获取数据统计失败:', error);
|
console.log('获取数据统计失败:', error);
|
||||||
104
test_all_retrieval_modes.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试所有四种多模态检索模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
|
||||||
|
def test_all_retrieval_modes():
|
||||||
|
print('正在初始化多GPU多模态检索系统...')
|
||||||
|
retrieval = MultiGPUMultimodalRetrieval()
|
||||||
|
|
||||||
|
# 准备测试数据
|
||||||
|
test_texts = [
|
||||||
|
"一只可爱的小猫",
|
||||||
|
"美丽的风景照片",
|
||||||
|
"现代建筑设计",
|
||||||
|
"colorful flowers in garden"
|
||||||
|
]
|
||||||
|
|
||||||
|
test_images = [
|
||||||
|
'sample_images/1755677101_1__.jpg',
|
||||||
|
'sample_images/1755677101_2__.jpg',
|
||||||
|
'sample_images/1755677101_3__.jpg',
|
||||||
|
'sample_images/1755677101_4__.jpg'
|
||||||
|
]
|
||||||
|
|
||||||
|
# 验证测试图像存在
|
||||||
|
existing_images = [img for img in test_images if os.path.exists(img)]
|
||||||
|
if not existing_images:
|
||||||
|
print("❌ 没有找到测试图像文件")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"找到 {len(existing_images)} 张测试图像")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 构建文本索引
|
||||||
|
print('\n=== 构建文本索引 ===')
|
||||||
|
retrieval.build_text_index_parallel(test_texts)
|
||||||
|
print('✅ 文本索引构建完成')
|
||||||
|
|
||||||
|
# 2. 构建图像索引
|
||||||
|
print('\n=== 构建图像索引 ===')
|
||||||
|
retrieval.build_image_index_parallel(existing_images)
|
||||||
|
print('✅ 图像索引构建完成')
|
||||||
|
|
||||||
|
# 3. 测试文本到文本检索
|
||||||
|
print('\n=== 测试文本到文本检索 ===')
|
||||||
|
query = "小动物"
|
||||||
|
results = retrieval.search_text_by_text(query, top_k=3)
|
||||||
|
print(f'查询: "{query}"')
|
||||||
|
for i, (text, score) in enumerate(results):
|
||||||
|
print(f' {i+1}. {text} (相似度: {score:.4f})')
|
||||||
|
|
||||||
|
# 4. 测试文本到图像检索
|
||||||
|
print('\n=== 测试文本到图像检索 ===')
|
||||||
|
query = "beautiful image"
|
||||||
|
results = retrieval.search_images_by_text(query, top_k=3)
|
||||||
|
print(f'查询: "{query}"')
|
||||||
|
for i, (image_path, score) in enumerate(results):
|
||||||
|
print(f' {i+1}. {image_path} (相似度: {score:.4f})')
|
||||||
|
|
||||||
|
# 5. 测试图像到文本检索
|
||||||
|
print('\n=== 测试图像到文本检索 ===')
|
||||||
|
query_image = existing_images[0]
|
||||||
|
results = retrieval.search_text_by_image(query_image, top_k=3)
|
||||||
|
print(f'查询图像: {query_image}')
|
||||||
|
for i, (text, score) in enumerate(results):
|
||||||
|
print(f' {i+1}. {text} (相似度: {score:.4f})')
|
||||||
|
|
||||||
|
# 6. 测试图像到图像检索
|
||||||
|
print('\n=== 测试图像到图像检索 ===')
|
||||||
|
query_image = existing_images[0]
|
||||||
|
results = retrieval.search_images_by_image(query_image, top_k=3)
|
||||||
|
print(f'查询图像: {query_image}')
|
||||||
|
for i, (image_path, score) in enumerate(results):
|
||||||
|
print(f' {i+1}. {image_path} (相似度: {score:.4f})')
|
||||||
|
|
||||||
|
print('\n✅ 所有四种检索模式测试完成!')
|
||||||
|
|
||||||
|
# 7. 测试Web应用兼容的方法名
|
||||||
|
print('\n=== 测试Web应用兼容方法 ===')
|
||||||
|
try:
|
||||||
|
results = retrieval.search_text_to_image("test query", top_k=2)
|
||||||
|
print('✅ search_text_to_image 方法正常')
|
||||||
|
|
||||||
|
results = retrieval.search_image_to_text(existing_images[0], top_k=2)
|
||||||
|
print('✅ search_image_to_text 方法正常')
|
||||||
|
|
||||||
|
results = retrieval.search_image_to_image(existing_images[0], top_k=2)
|
||||||
|
print('✅ search_image_to_image 方法正常')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f'❌ Web应用兼容方法测试失败: {e}')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f'❌ 测试过程中出现错误: {e}')
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_all_retrieval_modes()
|
||||||
49
test_image_encoding.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试图像编码功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
def test_image_encoding():
|
||||||
|
print('正在初始化多GPU多模态检索系统...')
|
||||||
|
retrieval = MultiGPUMultimodalRetrieval()
|
||||||
|
|
||||||
|
# 测试文本编码
|
||||||
|
print('测试文本编码...')
|
||||||
|
text_embeddings = retrieval.encode_text_batch(['这是一个测试文本'])
|
||||||
|
print(f'文本embedding形状: {text_embeddings.shape}')
|
||||||
|
print(f'文本embedding数据类型: {text_embeddings.dtype}')
|
||||||
|
|
||||||
|
# 测试图像编码
|
||||||
|
print('测试图像编码...')
|
||||||
|
test_images = ['sample_images/1755677101_1__.jpg']
|
||||||
|
image_embeddings = retrieval.encode_image_batch(test_images)
|
||||||
|
print(f'图像embedding形状: {image_embeddings.shape}')
|
||||||
|
print(f'图像embedding数据类型: {image_embeddings.dtype}')
|
||||||
|
|
||||||
|
# 测试两次相同图像的embedding是否一致
|
||||||
|
print('测试embedding一致性...')
|
||||||
|
image_embeddings2 = retrieval.encode_image_batch(test_images)
|
||||||
|
consistency = np.allclose(image_embeddings, image_embeddings2, rtol=1e-5)
|
||||||
|
print(f'相同图像embedding一致性: {consistency}')
|
||||||
|
|
||||||
|
# 测试不同图像的embedding差异
|
||||||
|
print('测试不同图像embedding差异...')
|
||||||
|
test_images2 = ['sample_images/1755677101_2__.jpg']
|
||||||
|
image_embeddings3 = retrieval.encode_image_batch(test_images2)
|
||||||
|
similarity = np.dot(image_embeddings[0], image_embeddings3[0]) / (np.linalg.norm(image_embeddings[0]) * np.linalg.norm(image_embeddings3[0]))
|
||||||
|
print(f'不同图像间相似度: {similarity:.4f}')
|
||||||
|
|
||||||
|
# 验证维度一致性
|
||||||
|
if text_embeddings.shape[1] == image_embeddings.shape[1]:
|
||||||
|
print('✅ 文本和图像embedding维度一致')
|
||||||
|
else:
|
||||||
|
print('❌ 文本和图像embedding维度不一致')
|
||||||
|
|
||||||
|
print('测试完成!')
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_image_encoding()
|
||||||
BIN
uploads/query_1755675080_10__.jpg
Normal file
|
After Width: | Height: | Size: 172 KiB |
BIN
uploads/query_1755681423_5__.jpg
Normal file
|
After Width: | Height: | Size: 201 KiB |
646
web_app_local.py
@ -1,646 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""
|
|
||||||
本地多模态检索系统Web应用
|
|
||||||
集成本地模型和FAISS向量数据库
|
|
||||||
支持文搜文、文搜图、图搜文、图搜图四种检索模式
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import base64
|
|
||||||
from io import BytesIO
|
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
from flask import Flask, request, jsonify, render_template, send_from_directory, send_file
|
|
||||||
from werkzeug.utils import secure_filename
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# 设置离线模式
|
|
||||||
os.environ['TRANSFORMERS_OFFLINE'] = '1'
|
|
||||||
|
|
||||||
# 导入本地模块
|
|
||||||
from multimodal_retrieval_local import MultimodalRetrievalLocal
|
|
||||||
from optimized_file_handler import OptimizedFileHandler
|
|
||||||
|
|
||||||
# 设置日志
|
|
||||||
logging.basicConfig(level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 创建Flask应用
|
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
# 配置
|
|
||||||
app.config['UPLOAD_FOLDER'] = '/tmp/mmeb_uploads'
|
|
||||||
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB
|
|
||||||
app.config['MODEL_PATH'] = '/root/models/Ops-MM-embedding-v1-7B'
|
|
||||||
app.config['INDEX_PATH'] = '/root/mmeb/local_faiss_index'
|
|
||||||
app.config['ALLOWED_EXTENSIONS'] = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'}
|
|
||||||
|
|
||||||
# 确保上传目录存在
|
|
||||||
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
|
||||||
|
|
||||||
# 创建临时文件夹
|
|
||||||
if not os.path.exists(app.config['UPLOAD_FOLDER']):
|
|
||||||
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
|
||||||
|
|
||||||
# 确保静态资源目录与 favicon 存在(方案二:静态文件)
|
|
||||||
app.static_folder = os.path.join(os.path.dirname(__file__), 'static')
|
|
||||||
os.makedirs(app.static_folder, exist_ok=True)
|
|
||||||
favicon_path = os.path.join(app.static_folder, 'favicon.ico')
|
|
||||||
if not os.path.exists(favicon_path):
|
|
||||||
# 写入一个 1x1 透明 PNG 转 ICO 的简易占位图标(使用 PNG 作为内容也可被多数浏览器识别)
|
|
||||||
# 这里直接写入一个极小的 PNG 文件,并命名为 .ico 以简化处理
|
|
||||||
transparent_png_base64 = (
|
|
||||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8Xw8AAqMBhHqg7T8AAAAASUVORK5CYII="
|
|
||||||
)
|
|
||||||
with open(favicon_path, 'wb') as f:
|
|
||||||
f.write(base64.b64decode(transparent_png_base64))
|
|
||||||
|
|
||||||
# 创建文件处理器
|
|
||||||
from optimized_file_handler import OptimizedFileHandler
|
|
||||||
file_handler = OptimizedFileHandler(local_storage_dir=app.config['UPLOAD_FOLDER'])
|
|
||||||
|
|
||||||
# 全局变量
|
|
||||||
retrieval_system = None
|
|
||||||
|
|
||||||
def allowed_file(filename):
|
|
||||||
"""检查文件扩展名是否允许"""
|
|
||||||
return '.' in filename and \
|
|
||||||
filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']
|
|
||||||
|
|
||||||
def init_retrieval_system():
|
|
||||||
"""初始化检索系统"""
|
|
||||||
global retrieval_system
|
|
||||||
|
|
||||||
if retrieval_system is not None:
|
|
||||||
return retrieval_system
|
|
||||||
|
|
||||||
logger.info("初始化多模态检索系统...")
|
|
||||||
|
|
||||||
# 检查模型路径
|
|
||||||
model_path = app.config['MODEL_PATH']
|
|
||||||
if not os.path.exists(model_path):
|
|
||||||
logger.error(f"模型路径不存在: {model_path}")
|
|
||||||
raise FileNotFoundError(f"模型路径不存在: {model_path}")
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
retrieval_system = MultimodalRetrievalLocal(
|
|
||||||
model_path=model_path,
|
|
||||||
use_all_gpus=True,
|
|
||||||
index_path=app.config['INDEX_PATH'],
|
|
||||||
shard_model_across_gpus=True
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("多模态检索系统初始化完成")
|
|
||||||
return retrieval_system
|
|
||||||
|
|
||||||
def get_image_base64(image_path):
|
|
||||||
"""将图像转换为base64编码"""
|
|
||||||
with open(image_path, "rb") as image_file:
|
|
||||||
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
|
||||||
return f"data:image/jpeg;base64,{encoded_string}"
|
|
||||||
|
|
||||||
@app.route('/')
|
|
||||||
def index():
|
|
||||||
"""首页"""
|
|
||||||
return render_template('local_index.html')
|
|
||||||
|
|
||||||
@app.route('/favicon.ico')
|
|
||||||
def favicon():
|
|
||||||
"""提供静态 favicon(方案二)"""
|
|
||||||
return send_from_directory(app.static_folder, 'favicon.ico')
|
|
||||||
|
|
||||||
@app.route('/api/stats', methods=['GET'])
|
|
||||||
def get_stats():
|
|
||||||
"""获取系统统计信息"""
|
|
||||||
try:
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
debug_info = {
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
"model_path": stats.get("model_path"),
|
|
||||||
"index_path": stats.get("index_path"),
|
|
||||||
}
|
|
||||||
return jsonify({"success": True, "stats": stats, "debug": debug_info})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取统计信息失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
|
|
||||||
@app.route('/api/add_text', methods=['POST'])
|
|
||||||
def add_text():
|
|
||||||
"""添加文本"""
|
|
||||||
try:
|
|
||||||
data = request.json
|
|
||||||
text = data.get('text')
|
|
||||||
|
|
||||||
if not text:
|
|
||||||
return jsonify({"success": False, "error": "文本不能为空"}), 400
|
|
||||||
|
|
||||||
# 使用内存处理文本
|
|
||||||
with file_handler.temp_file_context(text.encode('utf-8'), suffix='.txt') as temp_file:
|
|
||||||
logger.info(f"处理文本: {temp_file}")
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
|
|
||||||
# 添加文本
|
|
||||||
metadata = {
|
|
||||||
"timestamp": time.time(),
|
|
||||||
"source": "web_upload"
|
|
||||||
}
|
|
||||||
|
|
||||||
text_ids = retrieval.add_texts([text], [metadata])
|
|
||||||
|
|
||||||
# 保存索引
|
|
||||||
retrieval.save_index()
|
|
||||||
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
return jsonify({
|
|
||||||
"success": True,
|
|
||||||
"message": "文本添加成功",
|
|
||||||
"text_id": text_ids[0] if text_ids else None,
|
|
||||||
"debug": {
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"添加文本失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
finally:
|
|
||||||
# 清理临时文件
|
|
||||||
file_handler.cleanup_all_temp_files()
|
|
||||||
|
|
||||||
@app.route('/api/add_images_batch', methods=['POST'])
|
|
||||||
def add_images_batch():
|
|
||||||
"""批量添加图像(多卡并行友好)"""
|
|
||||||
try:
|
|
||||||
# 读取多文件
|
|
||||||
files = request.files.getlist('images')
|
|
||||||
if not files:
|
|
||||||
return jsonify({"success": False, "error": "没有上传文件"}), 400
|
|
||||||
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
|
|
||||||
images = []
|
|
||||||
metadatas = []
|
|
||||||
image_paths = []
|
|
||||||
|
|
||||||
for file in files:
|
|
||||||
if file and allowed_file(file.filename):
|
|
||||||
image_data = file.read()
|
|
||||||
filename = secure_filename(file.filename)
|
|
||||||
image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
|
||||||
with open(image_path, 'wb') as f:
|
|
||||||
f.write(image_data)
|
|
||||||
try:
|
|
||||||
img = Image.open(BytesIO(image_data))
|
|
||||||
if img.mode != 'RGB':
|
|
||||||
img = img.convert('RGB')
|
|
||||||
except Exception as e:
|
|
||||||
return jsonify({"success": False, "error": f"图像格式不支持: {filename}: {str(e)}"}), 400
|
|
||||||
images.append(img)
|
|
||||||
metadatas.append({
|
|
||||||
"filename": filename,
|
|
||||||
"timestamp": time.time(),
|
|
||||||
"source": "web_upload_batch",
|
|
||||||
"size": len(image_data),
|
|
||||||
"local_path": image_path
|
|
||||||
})
|
|
||||||
image_paths.append(image_path)
|
|
||||||
else:
|
|
||||||
return jsonify({"success": False, "error": f"不支持的文件类型: {file.filename}"}), 400
|
|
||||||
|
|
||||||
image_ids = retrieval.add_images(images, metadatas, image_paths)
|
|
||||||
retrieval.save_index()
|
|
||||||
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
return jsonify({
|
|
||||||
"success": True,
|
|
||||||
"message": f"批量添加成功: {len(image_ids)}/{len(images)}",
|
|
||||||
"image_ids": image_ids,
|
|
||||||
"debug": {
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"批量添加图像失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
finally:
|
|
||||||
file_handler.cleanup_all_temp_files()
|
|
||||||
|
|
||||||
@app.route('/api/add_texts_batch', methods=['POST'])
|
|
||||||
def add_texts_batch():
|
|
||||||
"""批量添加文本(多卡并行友好)"""
|
|
||||||
try:
|
|
||||||
data = request.json
|
|
||||||
if not data or 'texts' not in data:
|
|
||||||
return jsonify({"success": False, "error": "缺少 texts 数组"}), 400
|
|
||||||
texts = [t for t in data.get('texts', []) if isinstance(t, str) and t.strip()]
|
|
||||||
if not texts:
|
|
||||||
return jsonify({"success": False, "error": "文本数组为空"}), 400
|
|
||||||
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
metadatas = [{"timestamp": time.time(), "source": "web_upload_batch"} for _ in texts]
|
|
||||||
text_ids = retrieval.add_texts(texts, metadatas)
|
|
||||||
retrieval.save_index()
|
|
||||||
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
return jsonify({
|
|
||||||
"success": True,
|
|
||||||
"message": f"批量添加成功: {len(text_ids)}/{len(texts)}",
|
|
||||||
"text_ids": text_ids,
|
|
||||||
"debug": {
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"批量添加文本失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
|
|
||||||
@app.route('/api/add_image', methods=['POST'])
|
|
||||||
def add_image():
|
|
||||||
"""添加图像"""
|
|
||||||
try:
|
|
||||||
# 检查是否有文件
|
|
||||||
if 'image' not in request.files:
|
|
||||||
return jsonify({"success": False, "error": "没有上传文件"}), 400
|
|
||||||
|
|
||||||
file = request.files['image']
|
|
||||||
|
|
||||||
# 检查文件名
|
|
||||||
if file.filename == '':
|
|
||||||
return jsonify({"success": False, "error": "没有选择文件"}), 400
|
|
||||||
|
|
||||||
if file and allowed_file(file.filename):
|
|
||||||
# 读取图像数据
|
|
||||||
image_data = file.read()
|
|
||||||
file_size = len(image_data)
|
|
||||||
|
|
||||||
# 使用文件处理器处理图像
|
|
||||||
logger.info(f"处理图像: {file.filename} ({file_size} 字节)")
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
|
|
||||||
# 创建临时文件
|
|
||||||
file_obj = BytesIO(image_data)
|
|
||||||
filename = secure_filename(file.filename)
|
|
||||||
|
|
||||||
# 保存到本地文件系统
|
|
||||||
image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
|
||||||
with open(image_path, 'wb') as f:
|
|
||||||
f.write(image_data)
|
|
||||||
|
|
||||||
# 加载图像
|
|
||||||
try:
|
|
||||||
image = Image.open(BytesIO(image_data))
|
|
||||||
# 确保图像是RGB模式
|
|
||||||
if image.mode != 'RGB':
|
|
||||||
logger.info(f"将图像从 {image.mode} 转换为 RGB")
|
|
||||||
image = image.convert('RGB')
|
|
||||||
|
|
||||||
logger.info(f"成功加载图像: {filename}, 格式: {image.format}, 模式: {image.mode}, 大小: {image.size}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"加载图像失败: {filename}, 错误: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": f"图像格式不支持: {str(e)}"}), 400
|
|
||||||
|
|
||||||
# 添加图像
|
|
||||||
metadata = {
|
|
||||||
"filename": filename,
|
|
||||||
"timestamp": time.time(),
|
|
||||||
"source": "web_upload",
|
|
||||||
"size": file_size,
|
|
||||||
"local_path": image_path
|
|
||||||
}
|
|
||||||
|
|
||||||
# 添加到检索系统
|
|
||||||
image_ids = retrieval.add_images([image], [metadata], [image_path])
|
|
||||||
|
|
||||||
# 保存索引
|
|
||||||
retrieval.save_index()
|
|
||||||
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
return jsonify({
|
|
||||||
"success": True,
|
|
||||||
"message": "图像添加成功",
|
|
||||||
"image_id": image_ids[0] if image_ids else None,
|
|
||||||
"debug": {
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
"image_size": [image.width, image.height] if 'image' in locals() else None,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
return jsonify({"success": False, "error": "不支持的文件类型"}), 400
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"添加图像失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
finally:
|
|
||||||
# 清理临时文件
|
|
||||||
file_handler.cleanup_all_temp_files()
|
|
||||||
|
|
||||||
@app.route('/api/search_by_text', methods=['POST'])
|
|
||||||
def search_by_text():
|
|
||||||
"""文本搜索"""
|
|
||||||
try:
|
|
||||||
data = request.json
|
|
||||||
query = data.get('query')
|
|
||||||
k = int(data.get('k', 5))
|
|
||||||
filter_type = data.get('filter_type') # "text", "image" 或 null
|
|
||||||
|
|
||||||
if not query:
|
|
||||||
return jsonify({"success": False, "error": "查询文本不能为空"}), 400
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
results = retrieval.search_by_text(query, k, filter_type)
|
|
||||||
|
|
||||||
# 处理结果
|
|
||||||
processed_results = []
|
|
||||||
for result in results:
|
|
||||||
item = {
|
|
||||||
"score": result.get("score", 0),
|
|
||||||
"type": result.get("type")
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.get("type") == "text":
|
|
||||||
item["text"] = result.get("text", "")
|
|
||||||
elif result.get("type") == "image":
|
|
||||||
if "path" in result and os.path.exists(result["path"]):
|
|
||||||
item["image"] = get_image_base64(result["path"])
|
|
||||||
item["filename"] = os.path.basename(result["path"])
|
|
||||||
if "description" in result:
|
|
||||||
item["description"] = result["description"]
|
|
||||||
|
|
||||||
processed_results.append(item)
|
|
||||||
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
return jsonify({
|
|
||||||
"success": True,
|
|
||||||
"results": processed_results,
|
|
||||||
"query": query,
|
|
||||||
"filter_type": filter_type,
|
|
||||||
"debug": {
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"文本搜索失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
|
|
||||||
@app.route('/api/search_by_image', methods=['POST'])
|
|
||||||
def search_by_image():
|
|
||||||
"""图像搜索"""
|
|
||||||
try:
|
|
||||||
# 检查是否有文件
|
|
||||||
if 'image' not in request.files:
|
|
||||||
return jsonify({"success": False, "error": "没有上传文件"}), 400
|
|
||||||
|
|
||||||
file = request.files['image']
|
|
||||||
k = int(request.form.get('k', 5))
|
|
||||||
filter_type = request.form.get('filter_type') # "text", "image" 或 null
|
|
||||||
|
|
||||||
# 检查文件名
|
|
||||||
if file.filename == '':
|
|
||||||
return jsonify({"success": False, "error": "没有选择文件"}), 400
|
|
||||||
|
|
||||||
if file and allowed_file(file.filename):
|
|
||||||
# 读取图像数据
|
|
||||||
image_data = file.read()
|
|
||||||
file_size = len(image_data)
|
|
||||||
|
|
||||||
# 根据文件大小选择处理方式
|
|
||||||
if file_size <= 5 * 1024 * 1024: # 5MB
|
|
||||||
# 小文件使用内存处理
|
|
||||||
logger.info(f"使用内存处理搜索图像: {file.filename} ({file_size} 字节)")
|
|
||||||
image = Image.open(BytesIO(image_data))
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
results = retrieval.search_by_image(image, k, filter_type)
|
|
||||||
else:
|
|
||||||
# 大文件使用临时文件处理
|
|
||||||
with file_handler.temp_file_context(image_data, suffix=os.path.splitext(file.filename)[1]) as temp_file:
|
|
||||||
logger.info(f"使用临时文件处理搜索图像: {temp_file} ({file_size} 字节)")
|
|
||||||
|
|
||||||
# 初始化检索系统
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
|
|
||||||
# 加载图像
|
|
||||||
image = Image.open(temp_file)
|
|
||||||
|
|
||||||
# 执行搜索
|
|
||||||
results = retrieval.search_by_image(image, k, filter_type)
|
|
||||||
|
|
||||||
# 处理结果
|
|
||||||
processed_results = []
|
|
||||||
for result in results:
|
|
||||||
item = {
|
|
||||||
"score": result.get("score", 0),
|
|
||||||
"type": result.get("type")
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.get("type") == "text":
|
|
||||||
item["text"] = result.get("text", "")
|
|
||||||
elif result.get("type") == "image":
|
|
||||||
if "path" in result and os.path.exists(result["path"]):
|
|
||||||
item["image"] = get_image_base64(result["path"])
|
|
||||||
item["filename"] = os.path.basename(result["path"])
|
|
||||||
if "description" in result:
|
|
||||||
item["description"] = result["description"]
|
|
||||||
|
|
||||||
processed_results.append(item)
|
|
||||||
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
return jsonify({
|
|
||||||
"success": True,
|
|
||||||
"results": processed_results,
|
|
||||||
"filter_type": filter_type,
|
|
||||||
"debug": {
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
return jsonify({"success": False, "error": "不支持的文件类型"}), 400
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图像搜索失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
finally:
|
|
||||||
# 清理临时文件
|
|
||||||
file_handler.cleanup_all_temp_files()
|
|
||||||
|
|
||||||
@app.route('/uploads/<filename>')
|
|
||||||
def uploaded_file(filename):
|
|
||||||
"""提供上传文件的访问"""
|
|
||||||
return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
|
|
||||||
|
|
||||||
@app.route('/temp/<filename>')
|
|
||||||
def temp_file(filename):
|
|
||||||
"""提供临时文件的访问"""
|
|
||||||
return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
|
|
||||||
|
|
||||||
@app.route('/api/save_index', methods=['POST'])
|
|
||||||
def save_index():
|
|
||||||
"""保存索引"""
|
|
||||||
try:
|
|
||||||
# 初始化检索系统
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
|
|
||||||
# 保存索引
|
|
||||||
retrieval.save_index()
|
|
||||||
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
return jsonify({
|
|
||||||
"success": True,
|
|
||||||
"message": "索引保存成功",
|
|
||||||
"debug": {
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"保存索引失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
|
|
||||||
@app.route('/api/clear_index', methods=['POST'])
|
|
||||||
def clear_index():
|
|
||||||
"""清空索引"""
|
|
||||||
try:
|
|
||||||
# 初始化检索系统
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
|
|
||||||
# 清空索引
|
|
||||||
retrieval.clear_index()
|
|
||||||
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
return jsonify({
|
|
||||||
"success": True,
|
|
||||||
"message": "索引已清空",
|
|
||||||
"debug": {
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清空索引失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
|
|
||||||
@app.route('/api/list_items', methods=['GET'])
|
|
||||||
def list_items():
|
|
||||||
"""列出所有索引项"""
|
|
||||||
try:
|
|
||||||
# 初始化检索系统
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
|
|
||||||
# 获取所有项
|
|
||||||
items = retrieval.list_items()
|
|
||||||
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
return jsonify({
|
|
||||||
"success": True,
|
|
||||||
"items": items,
|
|
||||||
"debug": {
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"列出索引项失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
|
|
||||||
@app.route('/api/system_info', methods=['GET', 'POST'])
|
|
||||||
def system_info():
|
|
||||||
"""获取系统信息"""
|
|
||||||
try:
|
|
||||||
# GPU信息
|
|
||||||
gpu_info = []
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
for i in range(torch.cuda.device_count()):
|
|
||||||
gpu_info.append({
|
|
||||||
"id": i,
|
|
||||||
"name": torch.cuda.get_device_name(i),
|
|
||||||
"memory_total": torch.cuda.get_device_properties(i).total_memory / (1024 ** 3),
|
|
||||||
"memory_allocated": torch.cuda.memory_allocated(i) / (1024 ** 3),
|
|
||||||
"memory_reserved": torch.cuda.memory_reserved(i) / (1024 ** 3)
|
|
||||||
})
|
|
||||||
|
|
||||||
# 检索系统信息
|
|
||||||
retrieval_info = {}
|
|
||||||
if retrieval_system is not None:
|
|
||||||
retrieval_info = retrieval_system.get_stats()
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
"success": True,
|
|
||||||
"gpu_info": gpu_info,
|
|
||||||
"retrieval_info": retrieval_info,
|
|
||||||
"model_path": app.config['MODEL_PATH'],
|
|
||||||
"index_path": app.config['INDEX_PATH'],
|
|
||||||
"debug": {"server_time": datetime.now(timezone.utc).isoformat()}
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取系统信息失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e)}), 500
|
|
||||||
|
|
||||||
@app.route('/api/health', methods=['GET'])
|
|
||||||
def health():
|
|
||||||
"""健康检查:模型可用性、索引状态、GPU 信息"""
|
|
||||||
try:
|
|
||||||
retrieval = init_retrieval_system()
|
|
||||||
stats = retrieval.get_stats()
|
|
||||||
health_info = {
|
|
||||||
"model_loaded": True,
|
|
||||||
"vector_dimension": stats.get("vector_dimension"),
|
|
||||||
"total_vectors": stats.get("total_vectors"),
|
|
||||||
"gpu_available": torch.cuda.is_available(),
|
|
||||||
"cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
|
|
||||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
|
||||||
}
|
|
||||||
return jsonify({"success": True, "health": health_info})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"健康检查失败: {str(e)}")
|
|
||||||
return jsonify({"success": False, "error": str(e), "health": {"model_loaded": False, "server_time": datetime.now(timezone.utc).isoformat()}}), 500
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
try:
|
|
||||||
# 预初始化检索系统
|
|
||||||
init_retrieval_system()
|
|
||||||
|
|
||||||
# 启动Web应用
|
|
||||||
app.run(host='0.0.0.0', port=5000, debug=False)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"启动Web应用失败: {str(e)}")
|
|
||||||
sys.exit(1)
|
|
||||||
617
web_app_multigpu.py
Normal file
@ -0,0 +1,617 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
多GPU多模态检索系统 - Web应用
|
||||||
|
专为双GPU部署优化
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from flask import Flask, render_template, request, jsonify, send_file, url_for
|
||||||
|
from werkzeug.utils import secure_filename
|
||||||
|
from PIL import Image
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
import glob
|
||||||
|
|
||||||
|
# 设置环境变量优化GPU内存
|
||||||
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config['SECRET_KEY'] = 'multigpu_multimodal_retrieval_2024'
|
||||||
|
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
||||||
|
|
||||||
|
# 配置上传文件夹
|
||||||
|
UPLOAD_FOLDER = 'uploads'
|
||||||
|
SAMPLE_IMAGES_FOLDER = 'sample_images'
|
||||||
|
TEXT_DATA_FOLDER = 'text_data'
|
||||||
|
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
|
||||||
|
|
||||||
|
# 确保文件夹存在
|
||||||
|
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
||||||
|
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
|
||||||
|
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)
|
||||||
|
|
||||||
|
# 全局检索系统实例
|
||||||
|
retrieval_system = None
|
||||||
|
|
||||||
|
def allowed_file(filename):
|
||||||
|
"""检查文件扩展名是否允许"""
|
||||||
|
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
||||||
|
|
||||||
|
def image_to_base64(image_path):
|
||||||
|
"""将图片转换为base64编码"""
|
||||||
|
try:
|
||||||
|
with open(image_path, "rb") as img_file:
|
||||||
|
return base64.b64encode(img_file.read()).decode('utf-8')
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"图片转换失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@app.route('/')
|
||||||
|
def index():
|
||||||
|
"""主页"""
|
||||||
|
return render_template('index.html')
|
||||||
|
|
||||||
|
@app.route('/api/status')
|
||||||
|
def get_status():
|
||||||
|
"""获取系统状态"""
|
||||||
|
global retrieval_system
|
||||||
|
|
||||||
|
status = {
|
||||||
|
'initialized': retrieval_system is not None,
|
||||||
|
'gpu_count': 0,
|
||||||
|
'model_loaded': False
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
status['gpu_count'] = torch.cuda.device_count()
|
||||||
|
|
||||||
|
if retrieval_system and retrieval_system.model:
|
||||||
|
status['model_loaded'] = True
|
||||||
|
status['device_ids'] = retrieval_system.device_ids
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取状态失败: {e}")
|
||||||
|
|
||||||
|
return jsonify(status)
|
||||||
|
|
||||||
|
@app.route('/api/init', methods=['POST'])
|
||||||
|
def initialize_system():
|
||||||
|
"""初始化多GPU检索系统"""
|
||||||
|
global retrieval_system
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("正在初始化多GPU检索系统...")
|
||||||
|
|
||||||
|
# 导入多GPU检索系统
|
||||||
|
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
||||||
|
|
||||||
|
# 初始化系统
|
||||||
|
retrieval_system = MultiGPUMultimodalRetrieval()
|
||||||
|
|
||||||
|
if retrieval_system.model is None:
|
||||||
|
raise Exception("模型加载失败")
|
||||||
|
|
||||||
|
logger.info("✅ 多GPU系统初始化成功")
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'message': '多GPU系统初始化成功',
|
||||||
|
'device_ids': retrieval_system.device_ids,
|
||||||
|
'gpu_count': len(retrieval_system.device_ids)
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"系统初始化失败: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': error_msg
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/search/text_to_text', methods=['POST'])
|
||||||
|
def search_text_to_text():
|
||||||
|
"""文本搜索文本"""
|
||||||
|
return handle_search('text_to_text')
|
||||||
|
|
||||||
|
@app.route('/api/search/text_to_image', methods=['POST'])
|
||||||
|
def search_text_to_image():
|
||||||
|
"""文本搜索图片"""
|
||||||
|
return handle_search('text_to_image')
|
||||||
|
|
||||||
|
@app.route('/api/search/image_to_text', methods=['POST'])
|
||||||
|
def search_image_to_text():
|
||||||
|
"""图片搜索文本"""
|
||||||
|
return handle_search('image_to_text')
|
||||||
|
|
||||||
|
@app.route('/api/search/image_to_image', methods=['POST'])
|
||||||
|
def search_image_to_image():
|
||||||
|
"""图片搜索图片"""
|
||||||
|
return handle_search('image_to_image')
|
||||||
|
|
||||||
|
@app.route('/api/search', methods=['POST'])
|
||||||
|
def search():
|
||||||
|
"""通用搜索接口(兼容旧版本)"""
|
||||||
|
mode = request.form.get('mode') or request.json.get('mode', 'text_to_text')
|
||||||
|
return handle_search(mode)
|
||||||
|
|
||||||
|
def handle_search(mode):
|
||||||
|
"""处理搜索请求的通用函数"""
|
||||||
|
global retrieval_system
|
||||||
|
|
||||||
|
if not retrieval_system:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': '系统未初始化,请先点击初始化按钮'
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
try:
|
||||||
|
top_k = int(request.form.get('top_k', 5))
|
||||||
|
|
||||||
|
if mode in ['text_to_text', 'text_to_image']:
|
||||||
|
# 文本查询
|
||||||
|
query = request.form.get('query') or request.json.get('query', '')
|
||||||
|
if not query.strip():
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': '请输入查询文本'
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
logger.info(f"执行{mode}搜索: {query}")
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
if mode == 'text_to_text':
|
||||||
|
results = retrieval_system.search_text_to_text(query, top_k=top_k)
|
||||||
|
else: # text_to_image
|
||||||
|
results = retrieval_system.search_text_to_image(query, top_k=top_k)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'mode': mode,
|
||||||
|
'query': query,
|
||||||
|
'results': results,
|
||||||
|
'result_count': len(results)
|
||||||
|
})
|
||||||
|
|
||||||
|
elif mode in ['image_to_text', 'image_to_image']:
|
||||||
|
# 图片查询
|
||||||
|
if 'image' not in request.files:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': '请上传查询图片'
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
file = request.files['image']
|
||||||
|
if file.filename == '' or not allowed_file(file.filename):
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': '请上传有效的图片文件'
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# 保存上传的图片
|
||||||
|
filename = secure_filename(file.filename)
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
filename = f"query_{timestamp}_{filename}"
|
||||||
|
filepath = os.path.join(UPLOAD_FOLDER, filename)
|
||||||
|
file.save(filepath)
|
||||||
|
|
||||||
|
logger.info(f"执行{mode}搜索,图片: {filename}")
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
if mode == 'image_to_text':
|
||||||
|
results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
|
||||||
|
else: # image_to_image
|
||||||
|
results = retrieval_system.search_image_to_image(filepath, top_k=top_k)
|
||||||
|
|
||||||
|
# 转换查询图片为base64
|
||||||
|
query_image_b64 = image_to_base64(filepath)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'mode': mode,
|
||||||
|
'query_image': query_image_b64,
|
||||||
|
'results': results,
|
||||||
|
'result_count': len(results)
|
||||||
|
})
|
||||||
|
|
||||||
|
else:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': f'不支持的搜索模式: {mode}'
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"搜索失败: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': error_msg
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/upload/images', methods=['POST'])
|
||||||
|
def upload_images():
|
||||||
|
"""批量上传图片"""
|
||||||
|
try:
|
||||||
|
uploaded_files = []
|
||||||
|
|
||||||
|
if 'images' not in request.files:
|
||||||
|
return jsonify({'success': False, 'message': '没有选择文件'}), 400
|
||||||
|
|
||||||
|
files = request.files.getlist('images')
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
if file and file.filename != '' and allowed_file(file.filename):
|
||||||
|
filename = secure_filename(file.filename)
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
filename = f"{timestamp}_{filename}"
|
||||||
|
filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
|
||||||
|
file.save(filepath)
|
||||||
|
uploaded_files.append(filename)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'message': f'成功上传 {len(uploaded_files)} 个图片文件',
|
||||||
|
'uploaded_count': len(uploaded_files),
|
||||||
|
'files': uploaded_files
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': f'上传失败: {str(e)}'
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/upload/texts', methods=['POST'])
|
||||||
|
def upload_texts():
|
||||||
|
"""批量上传文本数据"""
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
|
||||||
|
if not data or 'texts' not in data:
|
||||||
|
return jsonify({'success': False, 'message': '没有提供文本数据'}), 400
|
||||||
|
|
||||||
|
texts = data['texts']
|
||||||
|
if not isinstance(texts, list):
|
||||||
|
return jsonify({'success': False, 'message': '文本数据格式错误'}), 400
|
||||||
|
|
||||||
|
# 保存文本数据到文件
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
filename = f"texts_{timestamp}.json"
|
||||||
|
filepath = os.path.join(TEXT_DATA_FOLDER, filename)
|
||||||
|
|
||||||
|
with open(filepath, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(texts, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'message': f'成功上传 {len(texts)} 条文本',
|
||||||
|
'uploaded_count': len(texts)
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': f'上传失败: {str(e)}'
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/upload/file', methods=['POST'])
|
||||||
|
def upload_single_file():
|
||||||
|
"""上传单个文件"""
|
||||||
|
if 'file' not in request.files:
|
||||||
|
return jsonify({'success': False, 'message': '没有选择文件'}), 400
|
||||||
|
|
||||||
|
file = request.files['file']
|
||||||
|
if file.filename == '':
|
||||||
|
return jsonify({'success': False, 'message': '没有选择文件'}), 400
|
||||||
|
|
||||||
|
if file and allowed_file(file.filename):
|
||||||
|
filename = secure_filename(file.filename)
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
filename = f"{timestamp}_{filename}"
|
||||||
|
filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
|
||||||
|
file.save(filepath)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'message': '文件上传成功',
|
||||||
|
'filename': filename
|
||||||
|
})
|
||||||
|
|
||||||
|
return jsonify({'success': False, 'message': '不支持的文件类型'}), 400
|
||||||
|
|
||||||
|
@app.route('/api/data/list', methods=['GET'])
|
||||||
|
def list_data():
|
||||||
|
"""列出已上传的数据"""
|
||||||
|
try:
|
||||||
|
# 列出图片文件
|
||||||
|
images = []
|
||||||
|
if os.path.exists(SAMPLE_IMAGES_FOLDER):
|
||||||
|
for filename in os.listdir(SAMPLE_IMAGES_FOLDER):
|
||||||
|
if allowed_file(filename):
|
||||||
|
filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
|
||||||
|
stat = os.stat(filepath)
|
||||||
|
images.append({
|
||||||
|
'filename': filename,
|
||||||
|
'size': stat.st_size,
|
||||||
|
'modified': stat.st_mtime
|
||||||
|
})
|
||||||
|
|
||||||
|
# 列出文本文件
|
||||||
|
texts = []
|
||||||
|
if os.path.exists(TEXT_DATA_FOLDER):
|
||||||
|
for filename in os.listdir(TEXT_DATA_FOLDER):
|
||||||
|
if filename.endswith('.json'):
|
||||||
|
filepath = os.path.join(TEXT_DATA_FOLDER, filename)
|
||||||
|
stat = os.stat(filepath)
|
||||||
|
texts.append({
|
||||||
|
'filename': filename,
|
||||||
|
'size': stat.st_size,
|
||||||
|
'modified': stat.st_mtime
|
||||||
|
})
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'data': {
|
||||||
|
'images': images,
|
||||||
|
'texts': texts
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': f'获取数据列表失败: {str(e)}'
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/gpu_status')
|
||||||
|
def gpu_status():
|
||||||
|
"""获取GPU状态"""
|
||||||
|
try:
|
||||||
|
from smart_gpu_launcher import get_gpu_memory_info
|
||||||
|
gpu_info = get_gpu_memory_info()
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'gpu_info': gpu_info
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': f"获取GPU状态失败: {str(e)}"
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/build_index', methods=['POST'])
|
||||||
|
def build_index():
|
||||||
|
"""构建检索索引"""
|
||||||
|
global retrieval_system
|
||||||
|
|
||||||
|
if not retrieval_system:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': '系统未初始化'
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 获取所有图片和文本文件
|
||||||
|
image_files = []
|
||||||
|
text_data = []
|
||||||
|
|
||||||
|
# 扫描图片文件
|
||||||
|
for ext in ALLOWED_EXTENSIONS:
|
||||||
|
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
||||||
|
image_files.extend(glob.glob(pattern))
|
||||||
|
|
||||||
|
# 读取文本文件(支持.json和.txt格式)
|
||||||
|
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
|
||||||
|
text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
|
||||||
|
|
||||||
|
for text_file in text_files:
|
||||||
|
try:
|
||||||
|
if text_file.endswith('.json'):
|
||||||
|
# 读取JSON格式的文本数据
|
||||||
|
with open(text_file, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
if isinstance(data, list):
|
||||||
|
text_data.extend([str(item).strip() for item in data if str(item).strip()])
|
||||||
|
else:
|
||||||
|
text_data.append(str(data).strip())
|
||||||
|
else:
|
||||||
|
# 读取TXT格式的文本数据
|
||||||
|
with open(text_file, 'r', encoding='utf-8') as f:
|
||||||
|
lines = [line.strip() for line in f.readlines() if line.strip()]
|
||||||
|
text_data.extend(lines)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"读取文本文件失败 {text_file}: {e}")
|
||||||
|
|
||||||
|
# 检查是否有数据可以构建索引
|
||||||
|
if not image_files and not text_data:
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': '没有找到可用的图片或文本数据,请先上传数据'
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
# 构建索引
|
||||||
|
if image_files:
|
||||||
|
logger.info(f"构建图片索引,共 {len(image_files)} 张图片")
|
||||||
|
retrieval_system.build_image_index_parallel(image_files)
|
||||||
|
|
||||||
|
if text_data:
|
||||||
|
logger.info(f"构建文本索引,共 {len(text_data)} 条文本")
|
||||||
|
retrieval_system.build_text_index_parallel(text_data)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'message': f'索引构建完成!图片: {len(image_files)} 张,文本: {len(text_data)} 条',
|
||||||
|
'image_count': len(image_files),
|
||||||
|
'text_count': len(text_data)
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"构建索引失败: {str(e)}")
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': f'构建索引失败: {str(e)}'
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/data/stats', methods=['GET'])
|
||||||
|
def get_data_stats():
|
||||||
|
"""获取数据统计信息"""
|
||||||
|
try:
|
||||||
|
# 统计图片文件
|
||||||
|
image_count = 0
|
||||||
|
for ext in ALLOWED_EXTENSIONS:
|
||||||
|
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
||||||
|
image_count += len(glob.glob(pattern))
|
||||||
|
|
||||||
|
# 统计文本数据
|
||||||
|
text_count = 0
|
||||||
|
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
|
||||||
|
for text_file in text_files:
|
||||||
|
try:
|
||||||
|
with open(text_file, 'r', encoding='utf-8') as f:
|
||||||
|
lines = [line.strip() for line in f.readlines() if line.strip()]
|
||||||
|
text_count += len(lines)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'image_count': image_count,
|
||||||
|
'text_count': text_count
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取数据统计失败: {str(e)}")
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': f'获取统计失败: {str(e)}'
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/api/data/clear', methods=['POST'])
|
||||||
|
def clear_data():
|
||||||
|
"""清空所有数据"""
|
||||||
|
try:
|
||||||
|
# 清空图片文件
|
||||||
|
for ext in ALLOWED_EXTENSIONS:
|
||||||
|
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
||||||
|
for file_path in glob.glob(pattern):
|
||||||
|
try:
|
||||||
|
os.remove(file_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"删除图片文件失败 {file_path}: {e}")
|
||||||
|
|
||||||
|
# 清空文本文件
|
||||||
|
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
|
||||||
|
for text_file in text_files:
|
||||||
|
try:
|
||||||
|
os.remove(text_file)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"删除文本文件失败 {text_file}: {e}")
|
||||||
|
|
||||||
|
# 重置索引
|
||||||
|
global retrieval_system
|
||||||
|
if retrieval_system:
|
||||||
|
retrieval_system.text_index = None
|
||||||
|
retrieval_system.image_index = None
|
||||||
|
retrieval_system.text_data = []
|
||||||
|
retrieval_system.image_data = []
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'success': True,
|
||||||
|
'message': '数据已清空'
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清空数据失败: {str(e)}")
|
||||||
|
return jsonify({
|
||||||
|
'success': False,
|
||||||
|
'message': f'清空数据失败: {str(e)}'
|
||||||
|
}), 500
|
||||||
|
|
||||||
|
@app.route('/uploads/<filename>')
|
||||||
|
def uploaded_file(filename):
|
||||||
|
"""提供上传文件的访问"""
|
||||||
|
return send_file(os.path.join(SAMPLE_IMAGES_FOLDER, filename))
|
||||||
|
|
||||||
|
def print_startup_info():
|
||||||
|
"""打印启动信息"""
|
||||||
|
print("🚀 启动多GPU多模态检索Web应用")
|
||||||
|
print("=" * 60)
|
||||||
|
print("访问地址: http://localhost:5000")
|
||||||
|
print("支持功能:")
|
||||||
|
print(" 📝 文搜文 - 文本查找相似文本")
|
||||||
|
print(" 🖼️ 文搜图 - 文本查找相关图片")
|
||||||
|
print(" 📝 图搜文 - 图片查找相关文本")
|
||||||
|
print(" 🖼️ 图搜图 - 图片查找相似图片")
|
||||||
|
print(" 📤 批量上传 - 图片和文本数据管理")
|
||||||
|
print("GPU配置:")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
gpu_count = torch.cuda.device_count()
|
||||||
|
print(f" 🖥️ 检测到 {gpu_count} 个GPU")
|
||||||
|
for i in range(gpu_count):
|
||||||
|
name = torch.cuda.get_device_name(i)
|
||||||
|
props = torch.cuda.get_device_properties(i)
|
||||||
|
memory_gb = props.total_memory / 1024**3
|
||||||
|
print(f" GPU {i}: {name} ({memory_gb:.1f}GB)")
|
||||||
|
else:
|
||||||
|
print(" ❌ CUDA不可用")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ GPU检查失败: {e}")
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
def auto_initialize():
|
||||||
|
"""启动时自动初始化系统"""
|
||||||
|
global retrieval_system
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info("🚀 启动时自动初始化多GPU检索系统...")
|
||||||
|
|
||||||
|
# 导入多GPU检索系统
|
||||||
|
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
||||||
|
|
||||||
|
# 初始化系统
|
||||||
|
retrieval_system = MultiGPUMultimodalRetrieval()
|
||||||
|
|
||||||
|
if retrieval_system.model is None:
|
||||||
|
raise Exception("模型加载失败")
|
||||||
|
|
||||||
|
logger.info("✅ 系统自动初始化成功")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 系统自动初始化失败: {str(e)}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print_startup_info()
|
||||||
|
|
||||||
|
# 启动时自动初始化
|
||||||
|
auto_initialize()
|
||||||
|
|
||||||
|
# 启动Flask应用
|
||||||
|
app.run(
|
||||||
|
host='0.0.0.0',
|
||||||
|
port=5000,
|
||||||
|
debug=False,
|
||||||
|
threaded=True
|
||||||
|
)
|
||||||