Compare commits

...

10 Commits

Author SHA1 Message Date
eust-w
2b480e5277 📝 del redundance 2025-10-28 14:51:45 +08:00
eust-w
7969226a6a add more topk choice 2025-09-23 11:22:37 +08:00
eust-w
b71b8c1b23 upload image one by one 2025-09-23 11:18:55 +08:00
eust-w
e8ed31d335 🐛 fix txt bug 2025-09-23 11:12:14 +08:00
eust-w
73ce51c611 add multi card run 2025-09-22 19:47:14 +08:00
eust-w
cadbab7541 add feat 2025-09-22 19:12:45 +08:00
eust-w
39e3fe76ea ♻️ format project 2025-09-22 18:57:10 +08:00
eust-w
202fad85ec 📝 tem push 2025-09-22 10:13:11 +00:00
eust-w
36021e817c add vdb and bos 2025-09-01 11:24:01 +00:00
eust-w
4c0bc822cb v2 2025-08-20 12:20:50 +00:00
24 changed files with 2013 additions and 1541 deletions

108
model_download_guide.md Normal file
View File

@ -0,0 +1,108 @@
# 多模态模型下载指南
## 下载 OpenSearch-AI/Ops-MM-embedding-v1-7B 模型
### 方法1使用 git-lfs
```bash
# 安装 git-lfs
apt-get install git-lfs
# 或
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash
apt-get install git-lfs
# 初始化 git-lfs
git lfs install
# 克隆模型仓库
mkdir -p ~/models
git clone https://huggingface.co/OpenSearch-AI/Ops-MM-embedding-v1-7B ~/models/Ops-MM-embedding-v1-7B
```
### 方法2使用 huggingface-cli
```bash
# 安装 huggingface-hub
pip install huggingface-hub
# 下载模型
mkdir -p ~/models
huggingface-cli download OpenSearch-AI/Ops-MM-embedding-v1-7B --local-dir ~/models/Ops-MM-embedding-v1-7B
```
### 方法3手动下载关键文件
如果上述方法不可行,可以手动下载以下关键文件:
1. 访问 https://huggingface.co/OpenSearch-AI/Ops-MM-embedding-v1-7B/tree/main
2. 下载以下文件:
- `config.json`
- `pytorch_model.bin` (或分片文件 `pytorch_model-00001-of-00002.bin` 等)
- `tokenizer.json`
- `tokenizer_config.json`
- `special_tokens_map.json`
- `vocab.txt`
## 下载替代轻量级模型
如果主模型太大,可以下载这些较小的替代模型:
### CLIP 模型
```bash
mkdir -p ~/models/clip-ViT-B-32
huggingface-cli download openai/clip-vit-base-patch32 --local-dir ~/models/clip-ViT-B-32
```
### 多语言CLIP模型
```bash
mkdir -p ~/models/clip-multilingual
huggingface-cli download sentence-transformers/clip-ViT-B-32-multilingual-v1 --local-dir ~/models/clip-multilingual
```
## 传输模型文件
下载完成后,使用以下方法将模型传输到目标服务器:
### 使用 scp
```bash
# 从当前机器传输到目标服务器
scp -r ~/models/Ops-MM-embedding-v1-7B user@target-server:/root/models/
```
### 使用压缩文件
```bash
# 压缩
tar -czvf model.tar.gz ~/models/Ops-MM-embedding-v1-7B
# 传输压缩文件
scp model.tar.gz user@target-server:/root/
# 在目标服务器上解压
ssh user@target-server
mkdir -p /root/models
tar -xzvf /root/model.tar.gz -C /root/models
```
## 验证模型文件
模型下载完成后,目录结构应类似于:
```
/root/models/Ops-MM-embedding-v1-7B/
├── config.json
├── pytorch_model.bin (或分片文件)
├── tokenizer.json
├── tokenizer_config.json
├── special_tokens_map.json
└── vocab.txt
```
使用以下命令验证文件完整性:
```bash
ls -la /root/models/Ops-MM-embedding-v1-7B/
```

View File

@ -0,0 +1,605 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
使用本地模型的多模态检索系统
支持文搜文、文搜图、图搜文、图搜图四种检索模式
"""
import torch
import numpy as np
from PIL import Image
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
from concurrent.futures import ThreadPoolExecutor
from typing import List, Union, Tuple, Dict, Any, Optional
import os
import json
from pathlib import Path
import logging
import gc
import faiss
import time
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 设置离线模式
os.environ['TRANSFORMERS_OFFLINE'] = '1'
class MultimodalRetrievalLocal:
    """Multimodal retrieval system backed by a locally stored model.

    Wraps an OpsMMEmbeddingV1 embedding model (one replica per GPU, or one
    sharded model) and a FAISS flat-L2 index for text/image indexing and search.
    """

    def __init__(self,
                 model_path: str = "/root/models/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True,
                 gpu_ids: List[int] = None,
                 min_memory_gb: int = 12,
                 index_path: str = "local_faiss_index",
                 shard_model_across_gpus: bool = False,
                 load_in_4bit: bool = False,
                 load_in_8bit: bool = False,
                 torch_dtype: Optional[torch.dtype] = torch.bfloat16):
        """Initialize the multimodal retrieval system.

        Args:
            model_path: Local model directory.
            use_all_gpus: Use every GPU that passes the memory check.
            gpu_ids: Explicit GPU ids to use (filtered against availability).
            min_memory_gb: Minimum total GPU memory (GB) a device must have.
            index_path: Path prefix for the FAISS index and metadata files.
            shard_model_across_gpus: Shard one model across GPUs instead of
                loading one replica per GPU.
            load_in_4bit: Load the model with 4-bit quantization.
            load_in_8bit: Load the model with 8-bit quantization.
            torch_dtype: Torch dtype used when loading the model.

        Raises:
            FileNotFoundError: If model_path does not exist.
        """
        self.model_path = model_path
        self.index_path = index_path
        self.shard_model_across_gpus = shard_model_across_gpus
        self.load_in_4bit = load_in_4bit
        self.load_in_8bit = load_in_8bit
        self.torch_dtype = torch_dtype
        # Fail fast when the model directory is missing.
        if not os.path.exists(model_path):
            logger.error(f"模型路径不存在: {model_path}")
            logger.info("请先下载模型到指定路径")
            raise FileNotFoundError(f"模型路径不存在: {model_path}")
        # Select devices, free GPU memory, load the model, then the index.
        self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
        self._clear_all_gpu_memory()
        self._load_embedding_model()
        self._init_index()
        logger.info(f"多模态检索系统初始化完成,使用本地模型: {model_path}")
        logger.info(f"向量存储路径: {index_path}")
    def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb: int):
        """Choose compute device(s); populates self.device, self.use_gpu, self.gpu_ids."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.use_gpu = self.device.type == "cuda"
        if self.use_gpu:
            self.available_gpus = self._get_available_gpus(min_memory_gb)
            if not self.available_gpus:
                # No GPU has enough total memory; fall back to CPU.
                logger.warning(f"没有可用的GPU或GPU内存不足{min_memory_gb}GB将使用CPU")
                self.device = torch.device("cpu")
                self.use_gpu = False
            else:
                if gpu_ids:
                    # Keep only the requested ids that are actually usable.
                    self.gpu_ids = [gid for gid in gpu_ids if gid in self.available_gpus]
                    if not self.gpu_ids:
                        logger.warning(f"指定的GPU {gpu_ids}不可用或内存不足将使用可用的GPU: {self.available_gpus}")
                        self.gpu_ids = self.available_gpus
                elif use_all_gpus:
                    self.gpu_ids = self.available_gpus
                else:
                    self.gpu_ids = [self.available_gpus[0]]
                logger.info(f"使用GPU: {self.gpu_ids}")
                # The first selected GPU becomes the primary device.
                self.device = torch.device(f"cuda:{self.gpu_ids[0]}")
        else:
            logger.warning("没有可用的GPU将使用CPU")
            self.gpu_ids = []
def _get_available_gpus(self, min_memory_gb: int) -> List[int]:
"""获取可用的GPU列表"""
available_gpus = []
for i in range(torch.cuda.device_count()):
total_mem = torch.cuda.get_device_properties(i).total_memory / (1024 ** 3) # GB
if total_mem >= min_memory_gb:
available_gpus.append(i)
return available_gpus
def _clear_all_gpu_memory(self):
"""清理GPU内存"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
    def _load_embedding_model(self):
        """Load the OpsMMEmbeddingV1 multimodal embedding model.

        Depending on configuration this either shards one model across GPUs
        (device_map='auto'), loads one replica per GPU (data parallel), or
        loads a single copy on one device. Populates self.models and
        self.vector_dim.

        Raises:
            RuntimeError: If the model fails to load.
        """
        logger.info(f"加载本地多模态嵌入模型: {self.model_path}")
        try:
            self.models: List[OpsMMEmbeddingV1] = []
            if self.use_gpu and len(self.gpu_ids) > 1 and self.shard_model_across_gpus:
                # Tensor-parallel sharding using device_map='auto' across visible GPUs.
                logger.info(f"启用模型跨GPU切片(shard),使用设备: {self.gpu_ids}")
                # Rely on CUDA_VISIBLE_DEVICES to constrain which GPUs are
                # visible, or use accelerate's auto mapping.
                ref_model = OpsMMEmbeddingV1(
                    self.model_path,
                    device="cuda",
                    attn_implementation=None,
                    device_map="auto",
                    load_in_4bit=self.load_in_4bit,
                    load_in_8bit=self.load_in_8bit,
                    torch_dtype=self.torch_dtype,
                )
                # For a sharded model we keep a single logical model reference.
                self.models.append(ref_model)
            else:
                if self.use_gpu and len(self.gpu_ids) > 1:
                    # One full replica per GPU (data parallelism).
                    logger.info(f"检测到多GPU可用: {self.gpu_ids},为每张卡加载一个模型副本(数据并行)")
                    for gid in self.gpu_ids:
                        device_str = f"cuda:{gid}"
                        self.models.append(
                            OpsMMEmbeddingV1(
                                self.model_path,
                                device=device_str,
                                attn_implementation=None,
                                load_in_4bit=self.load_in_4bit,
                                load_in_8bit=self.load_in_8bit,
                                torch_dtype=self.torch_dtype,
                            )
                        )
                    ref_model = self.models[0]
                else:
                    # Single device: one GPU or CPU.
                    device_str = "cuda" if self.use_gpu else "cpu"
                    logger.info(f"使用单设备: {device_str}")
                    ref_model = OpsMMEmbeddingV1(
                        self.model_path,
                        device=device_str,
                        attn_implementation=None,
                        load_in_4bit=self.load_in_4bit,
                        load_in_8bit=self.load_in_8bit,
                        torch_dtype=self.torch_dtype,
                    )
                    self.models.append(ref_model)
            # Embedding dimensionality comes from the base model's hidden size.
            self.vector_dim = int(getattr(ref_model.base_model.config, "hidden_size"))
            logger.info(f"向量维度: {self.vector_dim}")
            logger.info("嵌入模型加载成功")
        except Exception as e:
            logger.error(f"嵌入模型加载失败: {str(e)}")
            raise RuntimeError(f"嵌入模型加载失败: {str(e)}")
    def _init_index(self):
        """Load the FAISS index and metadata from disk, or create fresh ones."""
        index_file = f"{self.index_path}.index"
        if os.path.exists(index_file):
            logger.info(f"加载现有索引: {index_file}")
            try:
                self.index = faiss.read_index(index_file)
                logger.info(f"索引加载成功,包含{self.index.ntotal}个向量")
            except Exception as e:
                # A corrupt/unreadable index file falls back to a new empty index.
                logger.error(f"索引加载失败: {str(e)}")
                logger.info("创建新索引...")
                self.index = faiss.IndexFlatL2(self.vector_dim)
        else:
            logger.info(f"创建新索引,维度: {self.vector_dim}")
            self.index = faiss.IndexFlatL2(self.vector_dim)
        # Metadata maps stringified vector id -> record dict.
        self.metadata = {}
        metadata_file = f"{self.index_path}_metadata.json"
        if os.path.exists(metadata_file):
            try:
                with open(metadata_file, 'r', encoding='utf-8') as f:
                    self.metadata = json.load(f)
                logger.info(f"元数据加载成功,包含{len(self.metadata)}条记录")
            except Exception as e:
                logger.error(f"元数据加载失败: {str(e)}")
    def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
        """Encode text into embedding vectors, round-robin across model replicas.

        Args:
            text: A single string or a list of strings.

        Returns:
            A 1-D vector when exactly one text is given, otherwise a 2-D
            (n, dim) array; a (0, dim) array for empty input.
        """
        if isinstance(text, str):
            text = [text]
        if len(text) == 0:
            return np.zeros((0, self.vector_dim))
        # Single item or single model: no sharding needed.
        if len(self.models) == 1 or len(text) == 1:
            with torch.inference_mode():
                emb = self.models[0].get_text_embeddings(texts=text)
            arr = emb.detach().float().cpu().numpy()
            # NOTE: a single text yields a 1-D vector, not a (1, dim) array —
            # callers adding to FAISS must reshape accordingly.
            return arr[0] if len(text) == 1 else arr
        # Multi-GPU: distribute texts round-robin across the model replicas.
        shard_count = len(self.models)
        shards: List[List[str]] = []
        for i in range(shard_count):
            shards.append([])
        for idx, t in enumerate(text):
            shards[idx % shard_count].append(t)
        results: List[np.ndarray] = [None] * shard_count  # type: ignore

        def run_shard(model: OpsMMEmbeddingV1, shard_texts: List[str], shard_idx: int):
            # Encode one shard on its own model/GPU; empty shards produce (0, dim).
            if len(shard_texts) == 0:
                results[shard_idx] = np.zeros((0, self.vector_dim))
                return
            with torch.inference_mode():
                emb = model.get_text_embeddings(texts=shard_texts)
            results[shard_idx] = emb.detach().float().cpu().numpy()

        with ThreadPoolExecutor(max_workers=shard_count) as ex:
            for i, (model, shard_texts) in enumerate(zip(self.models, shards)):
                ex.submit(run_shard, model, shard_texts, i)
        # Re-interleave shard outputs back into the original input order.
        merged: List[np.ndarray] = []
        max_len = max(len(s) for s in shards)
        for k in range(max_len):
            for i in range(shard_count):
                if k < len(shards[i]) and results[i].shape[0] > 0:
                    merged.append(results[i][k:k+1])
        if len(merged) == 0:
            return np.zeros((0, self.vector_dim))
        return np.vstack(merged)
    def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
        """Encode image(s) into embedding vectors, round-robin across replicas.

        Args:
            image: A single PIL image or a list of PIL images.

        Returns:
            A 2-D (n, dim) array; a (0, dim) array on empty input or error.
        """
        try:
            # Normalize the argument to a list.
            if isinstance(image, Image.Image):
                images: List[Image.Image] = [image]
            else:
                images = image
            if not images:
                logger.error("encode_image: 图像列表为空")
                return np.zeros((0, self.vector_dim))
            # Force RGB so image mode differences don't affect the model input.
            rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images]
            # Single image or single model: no sharding needed.
            if len(self.models) == 1 or len(rgb_images) == 1:
                with torch.inference_mode():
                    emb = self.models[0].get_image_embeddings(images=rgb_images)
                return emb.detach().float().cpu().numpy()
            # Multi-GPU: distribute images round-robin across model replicas.
            shard_count = len(self.models)
            shards: List[List[Image.Image]] = [[] for _ in range(shard_count)]
            for idx, img in enumerate(rgb_images):
                shards[idx % shard_count].append(img)
            results: List[np.ndarray] = [None] * shard_count  # type: ignore

            def run_shard(model: OpsMMEmbeddingV1, shard_imgs: List[Image.Image], shard_idx: int):
                # Encode one shard on its own model/GPU; empty shards produce (0, dim).
                if len(shard_imgs) == 0:
                    results[shard_idx] = np.zeros((0, self.vector_dim))
                    return
                with torch.inference_mode():
                    emb = model.get_image_embeddings(images=shard_imgs)
                results[shard_idx] = emb.detach().float().cpu().numpy()

            with ThreadPoolExecutor(max_workers=shard_count) as ex:
                for i, (model, shard_imgs) in enumerate(zip(self.models, shards)):
                    ex.submit(run_shard, model, shard_imgs, i)
            # Re-interleave shard outputs back into the original input order.
            merged: List[np.ndarray] = []
            max_len = max(len(s) for s in shards)
            for k in range(max_len):
                for i in range(shard_count):
                    if k < len(shards[i]) and results[i].shape[0] > 0:
                        merged.append(results[i][k:k+1])
            if len(merged) == 0:
                return np.zeros((0, self.vector_dim))
            return np.vstack(merged)
        except Exception as e:
            logger.error(f"encode_image: 异常: {str(e)}")
            return np.zeros((0, self.vector_dim))
def add_texts(
self,
texts: List[str],
metadatas: Optional[List[Dict[str, Any]]] = None
) -> List[str]:
"""
添加文本到检索系统
Args:
texts: 文本列表
metadatas: 元数据列表每个元素是一个字典
Returns:
添加的文本ID列表
"""
if not texts:
return []
if metadatas is None:
metadatas = [{} for _ in range(len(texts))]
if len(texts) != len(metadatas):
raise ValueError("texts和metadatas长度必须相同")
# 编码文本
text_embeddings = self.encode_text(texts)
# 准备元数据
start_id = self.index.ntotal
ids = list(range(start_id, start_id + len(texts)))
# 添加到索引
self.index.add(np.array(text_embeddings).astype('float32'))
# 保存元数据
for i, id in enumerate(ids):
self.metadata[str(id)] = {
"text": texts[i],
"type": "text",
**metadatas[i]
}
logger.info(f"成功添加{len(ids)}条文本到检索系统")
return [str(id) for id in ids]
    def add_images(
        self,
        images: List[Image.Image],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        image_paths: Optional[List[str]] = None
    ) -> List[str]:
        """Add images to the retrieval system.

        Args:
            images: List of PIL images.
            metadatas: Per-image metadata dicts (defaults to empty dicts).
            image_paths: Optional source paths stored into the metadata.

        Returns:
            The string ids assigned to the added images. Returns [] on empty
            input, encoding failure, or any exception (a length mismatch
            raises ValueError internally, which is caught, logged, and also
            yields []).
        """
        try:
            logger.info(f"add_images: 开始添加图像,数量: {len(images) if images else 0}")
            if not images or len(images) == 0:
                logger.warning("add_images: 图像列表为空")
                return []
            if metadatas is None:
                logger.info("add_images: 创建默认元数据")
                metadatas = [{} for _ in range(len(images))]
            if len(images) != len(metadatas):
                logger.error(f"add_images: 长度不一致 - images: {len(images)}, metadatas: {len(metadatas)}")
                raise ValueError("images和metadatas长度必须相同")
            # Encode all images to vectors.
            logger.info("add_images: 编码图像")
            image_embeddings = self.encode_image(images)
            if image_embeddings.shape[0] == 0:
                logger.error("add_images: 图像编码失败,返回空数组")
                return []
            # Ids continue from the current index size.
            start_id = self.index.ntotal
            ids = list(range(start_id, start_id + len(images)))
            logger.info(f"add_images: 生成索引ID: {start_id} - {start_id + len(images) - 1}")
            logger.info(f"add_images: 添加向量到FAISS索引形状: {image_embeddings.shape}")
            self.index.add(np.array(image_embeddings).astype('float32'))
            # Store per-image metadata keyed by string id.
            for i, id in enumerate(ids):
                try:
                    metadata = {
                        "type": "image",
                        "width": images[i].width,
                        "height": images[i].height,
                        **metadatas[i]
                    }
                    if image_paths and i < len(image_paths):
                        metadata["path"] = image_paths[i]
                    self.metadata[str(id)] = metadata
                    logger.debug(f"add_images: 保存元数据成功 - ID: {id}")
                except Exception as e:
                    logger.error(f"add_images: 保存元数据失败 - ID: {id}, 错误: {str(e)}")
            logger.info(f"add_images: 成功添加{len(ids)}张图像到检索系统")
            return [str(id) for id in ids]
        except Exception as e:
            logger.error(f"add_images: 添加图像异常: {str(e)}")
            return []
def search_by_text(
self,
query: str,
k: int = 5,
filter_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
文本搜索
Args:
query: 查询文本
k: 返回结果数量
filter_type: 过滤类型可选值: "text", "image", None(不过滤)
Returns:
搜索结果列表每个元素包含相似项和分数
"""
# 编码查询文本
query_embedding = self.encode_text(query)
# 执行搜索
return self._search(query_embedding, k, filter_type)
def search_by_image(
self,
image: Image.Image,
k: int = 5,
filter_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
图像搜索
Args:
image: 查询图像
k: 返回结果数量
filter_type: 过滤类型可选值: "text", "image", None(不过滤)
Returns:
搜索结果列表每个元素包含相似项和分数
"""
# 编码查询图像
query_embedding = self.encode_image(image)
# 执行搜索
return self._search(query_embedding, k, filter_type)
def _search(
self,
query_embedding: np.ndarray,
k: int = 5,
filter_type: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
执行搜索
Args:
query_embedding: 查询向量
k: 返回结果数量
filter_type: 过滤类型可选值: "text", "image", None(不过滤)
Returns:
搜索结果列表
"""
if self.index.ntotal == 0:
return []
# 确保查询向量是2D数组
if len(query_embedding.shape) == 1:
query_embedding = query_embedding.reshape(1, -1)
# 执行搜索,获取更多结果以便过滤
actual_k = k * 3 if filter_type else k
actual_k = min(actual_k, self.index.ntotal)
distances, indices = self.index.search(query_embedding.astype('float32'), actual_k)
# 处理结果
results = []
for i in range(len(indices[0])):
idx = indices[0][i]
if idx < 0: # FAISS可能返回-1表示无效索引
continue
vector_id = str(idx)
if vector_id in self.metadata:
item = self.metadata[vector_id]
# 如果指定了过滤类型,则只返回该类型的结果
if filter_type and item.get("type") != filter_type:
continue
# 添加距离和分数
result = item.copy()
result["distance"] = float(distances[0][i])
result["score"] = float(1.0 / (1.0 + distances[0][i]))
results.append(result)
# 如果已经收集了足够的结果,则停止
if len(results) >= k:
break
return results
    def save_index(self):
        """Persist the FAISS index and the metadata JSON to disk.

        Failures are logged rather than raised so shutdown paths stay safe.
        """
        # Save the FAISS index.
        index_file = f"{self.index_path}.index"
        try:
            faiss.write_index(self.index, index_file)
            logger.info(f"索引保存成功: {index_file}")
        except Exception as e:
            logger.error(f"索引保存失败: {str(e)}")
        # Save the metadata as UTF-8 JSON.
        metadata_file = f"{self.index_path}_metadata.json"
        try:
            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f, ensure_ascii=False, indent=2)
            logger.info(f"元数据保存成功: {metadata_file}")
        except Exception as e:
            logger.error(f"元数据保存失败: {str(e)}")
def get_stats(self) -> Dict[str, Any]:
"""获取检索系统统计信息"""
text_count = sum(1 for v in self.metadata.values() if v.get("type") == "text")
image_count = sum(1 for v in self.metadata.values() if v.get("type") == "image")
return {
"total_vectors": self.index.ntotal,
"text_count": text_count,
"image_count": image_count,
"vector_dimension": self.vector_dim,
"index_path": self.index_path,
"model_path": self.model_path
}
    def clear_index(self):
        """Drop all vectors and metadata, persist the empty state, return True."""
        logger.info(f"清空索引: {self.index_path}")
        # Recreate an empty flat-L2 index with the same dimensionality.
        self.index = faiss.IndexFlatL2(self.vector_dim)
        self.metadata = {}
        # Persist so the on-disk files match the cleared in-memory state.
        self.save_index()
        logger.info(f"索引已清空: {self.index_path}")
        return True
def list_items(self) -> List[Dict[str, Any]]:
"""列出所有索引项"""
items = []
for item_id, metadata in self.metadata.items():
item = metadata.copy()
item['id'] = item_id
items.append(item)
return items
def __del__(self):
"""析构函数,确保资源被正确释放并自动保存索引"""
try:
if hasattr(self, 'model'):
del self.model
self._clear_all_gpu_memory()
if hasattr(self, 'index') and self.index is not None:
logger.info("系统关闭前自动保存索引")
self.save_index()
except Exception as e:
logger.error(f"析构时保存索引失败: {str(e)}")

View File

@ -1,632 +0,0 @@
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import numpy as np
from PIL import Image
import faiss
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from typing import List, Union, Tuple, Dict
import os
import json
from pathlib import Path
import logging
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultiGPUMultimodalRetrieval:
    """Multi-GPU multimodal retrieval: text->text, text->image, image->image, image->text."""

    def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True, gpu_ids: List[int] = None, min_memory_gb=12):
        """Initialize the multi-GPU multimodal retrieval system.

        Args:
            model_name: Hugging Face model identifier.
            use_all_gpus: Use every GPU that passes the memory check.
            gpu_ids: Explicit GPU ids to use.
            min_memory_gb: Minimum free GPU memory (GB) required per device.
        """
        self.model_name = model_name
        # Pick devices, then free their memory before loading the model.
        self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
        self._clear_all_gpu_memory()
        logger.info(f"正在加载模型到多GPU: {self.device_ids}")
        # Model/tokenizer/processor are populated by _load_model_multigpu.
        self.model = None
        self.tokenizer = None
        self.processor = None
        self._load_model_multigpu()
        # FAISS indexes and the raw data they were built from.
        self.text_index = None
        self.image_index = None
        self.text_data = []
        self.image_data = []
        logger.info("多GPU模型加载完成")
    def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb=12):
        """Decide which GPUs to use; populates device_ids, num_gpus, primary_device.

        Raises:
            RuntimeError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA不可用无法使用多GPU")
        total_gpus = torch.cuda.device_count()
        logger.info(f"检测到 {total_gpus} 个GPU")
        # CUDA_VISIBLE_DEVICES (when set) takes precedence over other options.
        cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
        if cuda_visible_devices is not None:
            # Visible devices are renumbered 0..n-1 inside the process.
            visible_gpu_count = len(cuda_visible_devices.split(','))
            self.device_ids = list(range(visible_gpu_count))
            logger.info(f"使用CUDA_VISIBLE_DEVICES指定的GPU: {cuda_visible_devices}")
        elif use_all_gpus:
            self.device_ids = self._select_best_gpus(min_memory_gb)
        elif gpu_ids:
            self.device_ids = gpu_ids
        else:
            self.device_ids = [0]
        self.num_gpus = len(self.device_ids)
        # The first selected GPU is the primary device for inputs.
        self.primary_device = f"cuda:{self.device_ids[0]}"
        logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")
    def _clear_all_gpu_memory(self):
        """Empty the CUDA cache on every selected GPU, then force a GC pass."""
        for gpu_id in self.device_ids:
            torch.cuda.set_device(gpu_id)
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
        logger.info("所有GPU内存已清理")
    def _load_model_multigpu(self):
        """Load model, tokenizer, and processor onto the selected GPU(s).

        Returns:
            True on success, False on failure. NOTE(review): __init__ ignores
            this return value, so a failed load leaves self.model as None.
        """
        try:
            # Reduce CUDA fragmentation-related OOMs.
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
            self._clear_gpu_memory()
            if self.num_gpus > 1:
                # Cap each GPU at 18GiB so accelerate's auto map leaves headroom.
                max_memory = {i: "18GiB" for i in self.device_ids}
                logger.info(f"正在加载模型到多GPU: {self.device_ids}")
                self.model = AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    max_memory=max_memory,
                    low_cpu_mem_usage=True,
                    offload_folder="./offload"
                )
            else:
                # Single-GPU load pinned to the primary device.
                self.model = AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    device_map=self.primary_device
                )
            # Tokenizer is mandatory; abort on failure.
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    trust_remote_code=True
                )
                logger.info("Tokenizer加载成功")
            except Exception as e:
                logger.error(f"Tokenizer加载失败: {e}")
                return False
            # Processor is optional; fall back to the tokenizer when missing.
            try:
                self.processor = AutoProcessor.from_pretrained(
                    self.model_name,
                    trust_remote_code=True
                )
                logger.info("Processor加载成功")
            except Exception as e:
                logger.warning(f"Processor加载失败: {e}")
                logger.info("尝试使用tokenizer作为processor的fallback")
                self.processor = self.tokenizer
            logger.info(f"模型已成功加载到设备: {self.model.hf_device_map if hasattr(self.model, 'hf_device_map') else self.primary_device}")
            logger.info("多GPU模型加载完成")
            return True
        except Exception as e:
            logger.error(f"多GPU模型加载失败: {str(e)}")
            return False
    def _clear_gpu_memory(self):
        """Empty the CUDA cache on every selected GPU, then force a GC pass."""
        for gpu_id in self.device_ids:
            torch.cuda.set_device(gpu_id)
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
        logger.info("GPU内存已清理")
    def _get_gpu_memory_info(self):
        """Query nvidia-smi for per-GPU memory usage.

        Returns:
            A list of dicts with gpu_id / used / total / free (MB) and
            usage_percent; [] when nvidia-smi is unavailable or fails.
        """
        try:
            import subprocess
            # CSV without units/header keeps parsing trivial: "used, total" per line.
            result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv,nounits,noheader'],
                                    capture_output=True, text=True, check=True)
            lines = result.stdout.strip().split('\n')
            gpu_info = []
            for i, line in enumerate(lines):
                used, total = map(int, line.split(', '))
                free = total - used
                gpu_info.append({
                    'gpu_id': i,
                    'used': used,
                    'total': total,
                    'free': free,
                    'usage_percent': (used / total) * 100
                })
            return gpu_info
        except Exception as e:
            logger.warning(f"无法获取GPU内存信息: {e}")
            return []
    def _select_best_gpus(self, min_memory_gb=12):
        """Pick GPUs with at least min_memory_gb free memory, best-first.

        Falls back to all GPUs when nvidia-smi info is unavailable, and to the
        single freest GPU when none meets the threshold.
        """
        gpu_info = self._get_gpu_memory_info()
        if not gpu_info:
            return list(range(torch.cuda.device_count()))
        # Sort by free memory, descending.
        gpu_info.sort(key=lambda x: x['free'], reverse=True)
        min_memory_mb = min_memory_gb * 1024
        suitable_gpus = []
        for gpu in gpu_info:
            if gpu['free'] >= min_memory_mb:
                suitable_gpus.append(gpu['gpu_id'])
                logger.info(f"GPU {gpu['gpu_id']}: {gpu['free']}MB 可用 (合适)")
            else:
                logger.warning(f"GPU {gpu['gpu_id']}: {gpu['free']}MB 可用 (不足)")
        if not suitable_gpus:
            # No GPU qualifies; use the one with the most free memory.
            logger.warning(f"没有GPU有足够内存({min_memory_gb}GB)选择可用内存最多的GPU")
            suitable_gpus = [gpu_info[0]['gpu_id']]
        return suitable_gpus
    def encode_text_batch(self, texts: List[str]) -> np.ndarray:
        """Batch-encode texts into embedding vectors.

        Args:
            texts: List of input strings.

        Returns:
            float32 array of shape (n_texts, hidden_size); an empty array
            for empty input.
        """
        if not texts:
            return np.array([])
        with torch.no_grad():
            # Tokenize with padding/truncation capped at 512 tokens.
            inputs = self.tokenizer(
                text=texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            # Move input tensors onto the primary device.
            inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
            outputs = self.model(**inputs)
            # Mean-pool the last hidden states over the sequence dimension.
            embeddings = outputs.last_hidden_state.mean(dim=1)
            # Free intermediates before returning.
            del inputs, outputs
            torch.cuda.empty_cache()
        return embeddings.cpu().numpy().astype(np.float32)
def encode_image_batch(self, images: List[Union[str, Image.Image]]) -> np.ndarray:
"""
批量编码图像为向量
Args:
images: 图像路径或PIL图像列表
Returns:
图像向量
"""
if not images:
return np.array([])
# 预处理图像
processed_images = []
for img in images:
if isinstance(img, str):
img = Image.open(img).convert('RGB')
elif isinstance(img, Image.Image):
img = img.convert('RGB')
processed_images.append(img)
try:
logger.info(f"处理 {len(processed_images)} 张图像")
# 使用多模态模型生成图像embedding
# 为每张图像创建简单的文本描述作为输入
conversations = []
for i in range(len(processed_images)):
# 使用简化的对话格式
conversation = [
{
"role": "user",
"content": [
{"type": "image", "image": processed_images[i]},
{"type": "text", "text": "What is in this image?"}
]
}
]
conversations.append(conversation)
# 使用processor处理
try:
# 尝试使用apply_chat_template方法
texts = []
for conv in conversations:
text = self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
texts.append(text)
# 处理文本和图像
inputs = self.processor(
text=texts,
images=processed_images,
return_tensors="pt",
padding=True
)
# 移动到GPU
inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
# 获取模型输出
with torch.no_grad():
outputs = self.model(**inputs)
embeddings = outputs.last_hidden_state.mean(dim=1)
# 转换为numpy数组
embeddings = embeddings.cpu().numpy().astype(np.float32)
except Exception as inner_e:
logger.warning(f"多模态模型图像编码失败,使用文本模式: {inner_e}")
return np.zeros((len(processed_images), 3584), dtype=np.float32)
# 如果多模态失败使用纯文本描述作为fallback
image_descriptions = ["An image" for _ in processed_images]
text_inputs = self.processor(
text=image_descriptions,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
)
text_inputs = {k: v.to(self.primary_device) for k, v in text_inputs.items()}
with torch.no_grad():
outputs = self.model(**text_inputs)
embeddings = outputs.last_hidden_state.mean(dim=1)
embeddings = embeddings.cpu().numpy().astype(np.float32)
logger.info(f"生成图像embeddings: {embeddings.shape}")
return embeddings
except Exception as e:
logger.error(f"图像编码失败: {e}")
# 返回与文本embedding维度一致的零向量作为fallback
embedding_dim = 3584
embeddings = np.zeros((len(processed_images), embedding_dim), dtype=np.float32)
return embeddings
    def build_text_index_parallel(self, texts: List[str], save_path: str = None):
        """Encode texts in batches and build an inner-product FAISS index.

        Args:
            texts: Texts to index.
            save_path: Optional path prefix for persisting the index.

        Raises:
            ValueError: If no batch could be encoded successfully.
        """
        logger.info(f"正在并行构建文本索引,共 {len(texts)} 条文本")
        # Smaller batches when more GPUs share the work.
        batch_size = max(4, 16 // self.num_gpus)
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            try:
                embeddings = self.encode_text_batch(batch_texts)
                all_embeddings.append(embeddings)
                # Log progress every 10 batches.
                if (i // batch_size + 1) % 10 == 0:
                    logger.info(f"已处理 {i + len(batch_texts)}/{len(texts)} 条文本")
            except torch.cuda.OutOfMemoryError:
                # Skip batches that do not fit; free memory and continue.
                logger.warning(f"GPU内存不足跳过批次 {i}-{i+len(batch_texts)}")
                self._clear_all_gpu_memory()
                continue
            except Exception as e:
                logger.error(f"处理文本批次时出错: {e}")
                continue
        if not all_embeddings:
            raise ValueError("没有成功处理任何文本")
        embeddings = np.vstack(all_embeddings)
        # Cosine-style similarity: L2-normalize then use inner product.
        dimension = embeddings.shape[1]
        self.text_index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(embeddings)
        self.text_index.add(embeddings)
        self.text_data = texts
        if save_path:
            self._save_index(self.text_index, texts, save_path + "_text")
        logger.info("文本索引构建完成")
    def build_image_index_parallel(self, image_paths: List[str], save_path: str = None):
        """Encode images in batches and build an inner-product FAISS index.

        Args:
            image_paths: Paths of the images to index.
            save_path: Optional path prefix for persisting the index.

        Raises:
            ValueError: If no batch could be encoded successfully.
        """
        logger.info(f"正在并行构建图像索引,共 {len(image_paths)} 张图像")
        # Images are heavier than text, so use an even smaller batch size.
        batch_size = max(2, 8 // self.num_gpus)
        all_embeddings = []
        for i in range(0, len(image_paths), batch_size):
            batch_images = image_paths[i:i+batch_size]
            try:
                embeddings = self.encode_image_batch(batch_images)
                all_embeddings.append(embeddings)
                # Log progress every 5 batches.
                if (i // batch_size + 1) % 5 == 0:
                    logger.info(f"已处理 {i + len(batch_images)}/{len(image_paths)} 张图像")
            except torch.cuda.OutOfMemoryError:
                # Skip batches that do not fit; free memory and continue.
                logger.warning(f"GPU内存不足跳过图像批次 {i}-{i+len(batch_images)}")
                self._clear_all_gpu_memory()
                continue
            except Exception as e:
                logger.error(f"处理图像批次时出错: {e}")
                continue
        if not all_embeddings:
            raise ValueError("没有成功处理任何图像")
        embeddings = np.vstack(all_embeddings)
        # Cosine-style similarity: L2-normalize then use inner product.
        dimension = embeddings.shape[1]
        self.image_index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(embeddings)
        self.image_index.add(embeddings)
        self.image_data = image_paths
        if save_path:
            self._save_index(self.image_index, image_paths, save_path + "_image")
        logger.info("图像索引构建完成")
def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
"""文搜文:使用文本查询搜索相似文本"""
if self.text_index is None:
raise ValueError("文本索引未构建,请先调用 build_text_index_parallel")
query_embedding = self.encode_text_batch([query]).astype(np.float32)
faiss.normalize_L2(query_embedding)
scores, indices = self.text_index.search(query_embedding, top_k)
results = []
for score, idx in zip(scores[0], indices[0]):
if idx != -1:
results.append((self.text_data[idx], float(score)))
return results
def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
"""文搜图:使用文本查询搜索相似图像"""
if self.image_index is None:
raise ValueError("图像索引未构建,请先调用 build_image_index_parallel")
query_embedding = self.encode_text_batch([query]).astype(np.float32)
faiss.normalize_L2(query_embedding)
scores, indices = self.image_index.search(query_embedding, top_k)
results = []
for score, idx in zip(scores[0], indices[0]):
if idx != -1:
results.append((self.image_data[idx], float(score)))
return results
def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
"""图搜图:使用图像查询搜索相似图像"""
if self.image_index is None:
raise ValueError("图像索引未构建,请先调用 build_image_index_parallel")
query_embedding = self.encode_image_batch([query_image]).astype(np.float32)
faiss.normalize_L2(query_embedding)
scores, indices = self.image_index.search(query_embedding, top_k)
results = []
for score, idx in zip(scores[0], indices[0]):
if idx != -1:
results.append((self.image_data[idx], float(score)))
return results
def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
"""图搜文:使用图像查询搜索相似文本"""
if self.text_index is None:
raise ValueError("文本索引未构建,请先调用 build_text_index_parallel")
query_embedding = self.encode_image_batch([query_image]).astype(np.float32)
faiss.normalize_L2(query_embedding)
scores, indices = self.text_index.search(query_embedding, top_k)
results = []
for score, idx in zip(scores[0], indices[0]):
if idx != -1:
results.append((self.text_data[idx], float(score)))
return results
    # Alias methods matching the names the web application expects.
    def search_text_to_image(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-image search (web-app compatible alias)."""
        return self.search_images_by_text(query, top_k)

    def search_image_to_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-image search (web-app compatible alias)."""
        return self.search_images_by_image(query_image, top_k)

    def search_text_to_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-text search (web-app compatible alias)."""
        return self.search_text_by_text(query, top_k)

    def search_image_to_text(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-text search (web-app compatible alias)."""
        return self.search_text_by_image(query_image, top_k)
    def _save_index(self, index, data, path_prefix):
        """Persist a FAISS index (.index) and its source data (.json) under path_prefix."""
        faiss.write_index(index, f"{path_prefix}.index")
        with open(f"{path_prefix}.json", 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
def load_index(self, path_prefix, index_type="text"):
"""加载已保存的索引"""
index = faiss.read_index(f"{path_prefix}.index")
with open(f"{path_prefix}.json", 'r', encoding='utf-8') as f:
data = json.load(f)
if index_type == "text":
self.text_index = index
self.text_data = data
else:
self.image_index = index
self.image_data = data
logger.info(f"已加载 {index_type} 索引")
def get_gpu_memory_info(self):
"""获取所有GPU内存使用信息"""
memory_info = {}
for gpu_id in self.device_ids:
torch.cuda.set_device(gpu_id)
allocated = torch.cuda.memory_allocated(gpu_id) / 1024**3
cached = torch.cuda.memory_reserved(gpu_id) / 1024**3
total = torch.cuda.get_device_properties(gpu_id).total_memory / 1024**3
free = total - cached
memory_info[f"GPU_{gpu_id}"] = {
"total": f"{total:.1f}GB",
"allocated": f"{allocated:.1f}GB",
"cached": f"{cached:.1f}GB",
"free": f"{free:.1f}GB"
}
return memory_info
def check_multigpu_info():
    """Print a summary of the multi-GPU environment (count, versions, per-GPU memory)."""
    print("=== 多GPU环境信息 ===")
    if not torch.cuda.is_available():
        print("❌ CUDA不可用")
        return
    device_total = torch.cuda.device_count()
    print(f"✅ 检测到 {device_total} 个GPU")
    print(f"CUDA版本: {torch.version.cuda}")
    print(f"PyTorch版本: {torch.__version__}")
    for idx in range(device_total):
        name = torch.cuda.get_device_name(idx)
        mem_gb = torch.cuda.get_device_properties(idx).total_memory / 1024**3
        print(f"GPU {idx}: {name} ({mem_gb:.1f}GB)")
    print("=====================")
if __name__ == "__main__":
    # Inspect the multi-GPU environment first.
    check_multigpu_info()

    # Example usage: bring up the retrieval system and report readiness.
    print("\n正在初始化多GPU多模态检索系统...")
    try:
        system = MultiGPUMultimodalRetrieval()
        print("✅ 多GPU系统初始化成功")

        # Show per-GPU memory usage.
        usage = system.get_gpu_memory_info()
        print("\n📊 GPU内存使用情况:")
        for name, stats in usage.items():
            print(f" {name}: {stats['allocated']} / {stats['total']} (已用/总计)")

        print("\n🚀 多GPU多模态检索系统就绪")
        print("支持的检索模式:")
        print("1. 文搜文: search_text_by_text()")
        print("2. 文搜图: search_images_by_text()")
        print("3. 图搜图: search_images_by_image()")
        print("4. 图搜文: search_text_by_image()")
    except Exception as e:
        print(f"❌ 多GPU系统初始化失败: {e}")
        import traceback
        traceback.print_exc()

View File

@ -18,19 +18,45 @@ class OpsMMEmbeddingV1(nn.Module):
device: str = "cuda", device: str = "cuda",
max_length: Optional[int] = None, max_length: Optional[int] = None,
attn_implementation: Optional[str] = None, attn_implementation: Optional[str] = None,
device_map: Optional[str] = None,
load_in_4bit: bool = False,
load_in_8bit: bool = False,
torch_dtype: Optional[torch.dtype] = torch.bfloat16,
processor_min_pixels: int = 128 * 28 * 28,
processor_max_pixels: int = 512 * 28 * 28,
): ):
super().__init__() super().__init__()
self.device = device self.device = device
self.max_length = max_length self.max_length = max_length
self.default_instruction = "You are a helpful assistant." self.default_instruction = "You are a helpful assistant."
self.base_model = AutoModelForImageTextToText.from_pretrained( from transformers import AutoModelForImageTextToText
model_name, load_kwargs = dict(
torch_dtype=torch.bfloat16, torch_dtype=torch_dtype,
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
attn_implementation=attn_implementation, attn_implementation=attn_implementation,
).to(self.device) )
# Quantization options (requires bitsandbytes for 4/8-bit)
if load_in_4bit:
load_kwargs["load_in_4bit"] = True
if load_in_8bit:
load_kwargs["load_in_8bit"] = True
if device_map is not None:
load_kwargs["device_map"] = device_map
self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28) self.base_model = AutoModelForImageTextToText.from_pretrained(
model_name,
**load_kwargs,
)
# Only move to a single device when not using tensor-parallel sharding
if device_map is None:
self.base_model = self.base_model.to(self.device)
# Use configurable pixel limits to control VRAM usage
self.processor = AutoProcessor.from_pretrained(
model_name,
min_pixels=processor_min_pixels,
max_pixels=processor_max_pixels,
)
self.processor.tokenizer.padding_side = "left" self.processor.tokenizer.padding_side = "left"
self.eval() self.eval()
@ -120,9 +146,11 @@ class OpsMMEmbeddingV1(nn.Module):
input_texts.append(msg) input_texts.append(msg)
input_images.append(processed_image) input_images.append(processed_image)
# Only pass to processor if we actually have images # Only pass images when present; some processors expect paired inputs and
processed_images = input_images if any(img is not None for img in input_images) else None # can raise unpack errors if we pass images=None with multi-modal processor.
has_images = any(img is not None for img in input_images)
if has_images:
processed_images = input_images
inputs = self.processor( inputs = self.processor(
text=input_texts, text=input_texts,
images=processed_images, images=processed_images,
@ -131,6 +159,14 @@ class OpsMMEmbeddingV1(nn.Module):
max_length=self.max_length, max_length=self.max_length,
return_tensors="pt", return_tensors="pt",
) )
else:
inputs = self.processor(
text=input_texts,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="pt",
)
inputs = {k: v.to(self.device) for k, v in inputs.items()} inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.inference_mode(): with torch.inference_mode():

367
optimized_file_handler.py Normal file
View File

@ -0,0 +1,367 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
优化的文件处理器
支持自动清理内存处理和流式上传
"""
import os
import io
import tempfile
import logging
import uuid
from contextlib import contextmanager
from typing import Dict, List, Optional, Any, Union, BinaryIO
from pathlib import Path
from PIL import Image
import numpy as np
# Optional external managers (BOS/Mongo) are disabled in minimal setup
# Previously:
# from baidu_bos_manager import get_bos_manager
# from mongodb_manager import get_mongodb_manager
logger = logging.getLogger(__name__)
class OptimizedFileHandler:
    """Optimized file handler.

    Processes uploads either fully in memory (small files) or via temporary
    files (large files), tracks every temp file it creates so they can be
    cleaned up, and optionally mirrors data to BOS / MongoDB (both disabled
    in this minimal setup).
    """

    # Files at or below this size (5MB) are processed entirely in memory.
    SMALL_FILE_THRESHOLD = 5 * 1024 * 1024

    # Image extensions accepted by the handler.
    SUPPORTED_IMAGE_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'}

    def __init__(self, local_storage_dir=None):
        # In the minimal setup, BOS and MongoDB are not used.
        self.bos_manager = None
        self.mongodb_manager = None
        self.temp_files = set()  # paths of temp files we created (for cleanup)
        self.local_storage_dir = local_storage_dir or tempfile.gettempdir()
        # Make sure the local storage directory exists.
        if self.local_storage_dir:
            os.makedirs(self.local_storage_dir, exist_ok=True)

    @contextmanager
    def temp_file_context(self, content: bytes = None, suffix: str = None, delete_on_exit: bool = True):
        """Temp-file context manager; guarantees cleanup on exit.

        Args:
            content: Optional bytes written to the file before yielding.
            suffix: Optional filename suffix (e.g. ".jpg").
            delete_on_exit: When False the file is kept; the caller owns cleanup.

        Yields:
            The path of the created temporary file.
        """
        temp_fd, temp_path = tempfile.mkstemp(suffix=suffix, dir=self.local_storage_dir)
        self.temp_files.add(temp_path)
        # Write the payload if one was provided; otherwise just release the fd.
        if content is not None:
            with os.fdopen(temp_fd, 'wb') as f:
                f.write(content)
        else:
            os.close(temp_fd)
        try:
            yield temp_path
        finally:
            if delete_on_exit and os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                    self.temp_files.discard(temp_path)
                    logger.debug(f"🗑️ 临时文件已清理: {temp_path}")
                except Exception as e:
                    logger.warning(f"⚠️ 临时文件清理失败: {temp_path}, {e}")

    def cleanup_all_temp_files(self):
        """Remove every temp file this handler is still tracking (best effort)."""
        for temp_path in list(self.temp_files):
            if os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                    logger.debug(f"🗑️ 清理临时文件: {temp_path}")
                except Exception as e:
                    logger.warning(f"⚠️ 清理临时文件失败: {temp_path}, {e}")
        self.temp_files.clear()

    def get_file_size(self, file_obj) -> int:
        """Return the size of ``file_obj`` in bytes.

        Prefers the ``content_length`` attribute (Werkzeug uploads); falls
        back to seeking to the end, restoring the original position.
        """
        if hasattr(file_obj, 'content_length') and file_obj.content_length:
            return file_obj.content_length
        # Determine size by seeking; restore the caller's position afterwards.
        current_pos = file_obj.tell()
        file_obj.seek(0, 2)  # seek to end
        size = file_obj.tell()
        file_obj.seek(current_pos)  # restore original position
        return size

    def is_small_file(self, file_obj) -> bool:
        """True if the file is small enough to process entirely in memory."""
        return self.get_file_size(file_obj) <= self.SMALL_FILE_THRESHOLD

    def process_image_in_memory(self, file_obj, filename: str) -> Optional[Dict[str, Any]]:
        """Process a small image file fully in memory.

        Validates the image, writes it to local storage, optionally uploads
        to BOS and records metadata in MongoDB. Returns a result dict or
        None on failure.
        """
        try:
            # Read the whole payload into memory; restore stream position.
            file_obj.seek(0)
            file_content = file_obj.read()
            file_obj.seek(0)

            # Validate the image format and integrity.
            try:
                image = Image.open(io.BytesIO(file_content))
                image.verify()  # raises if the image is corrupt
            except Exception as e:
                logger.error(f"❌ 图像验证失败: (unknown), {e}")
                return None

            # Unique ID for this stored file.
            file_id = str(uuid.uuid4())

            # Persist to local storage.
            local_path = os.path.join(self.local_storage_dir, f"{file_id}_(unknown)")
            with open(local_path, 'wb') as f:
                f.write(file_content)

            # Metadata destined for MongoDB (if configured).
            metadata = {
                "_id": file_id,
                "filename": filename,
                "file_type": "image",
                "file_size": len(file_content),
                "processing_method": "memory",
                "local_path": local_path
            }

            # Mirror to BOS when a manager is configured.
            if self.bos_manager:
                bos_key = f"images/memory_{file_id}_(unknown)"
                bos_result = self._upload_to_bos_from_memory(file_content, bos_key, filename)
                if bos_result:
                    metadata["bos_key"] = bos_key
                    metadata["bos_url"] = bos_result["url"]
            if self.mongodb_manager:
                self.mongodb_manager.store_file_metadata(metadata=metadata)
            logger.info(f"✅ 内存处理图像成功: (unknown) ({len(file_content)} bytes)")
            return {
                "file_id": file_id,
                "filename": filename,
                "local_path": local_path,
                "processing_method": "memory"
            }
        except Exception as e:
            logger.error(f"❌ 内存处理图像失败: (unknown), {e}")
            return None

    def process_image_with_temp_file(self, file_obj, filename: str) -> Optional[Dict[str, Any]]:
        """Process a large image via a temporary file.

        Spools the upload to a temp file, validates it, copies it to
        permanent local storage, and records metadata. Returns a result
        dict or None on failure.
        """
        try:
            # Preserve the original extension for the temp/permanent files.
            ext = os.path.splitext(filename)[1].lower()
            # Unique ID for this stored file.
            file_id = str(uuid.uuid4())
            # Permanent destination inside local storage.
            permanent_path = os.path.join(self.local_storage_dir, f"{file_id}_(unknown)")
            with self.temp_file_context(suffix=ext) as temp_path:
                # Spool the upload to the temp file.
                file_obj.seek(0)
                with open(temp_path, 'wb') as temp_file:
                    temp_file.write(file_obj.read())

                # Validate the image before committing it to storage.
                try:
                    with Image.open(temp_path) as image:
                        image.verify()
                except Exception as e:
                    logger.error(f"❌ 图像验证失败: (unknown), {e}")
                    return None

                # Copy to the permanent location.
                with open(temp_path, 'rb') as src, open(permanent_path, 'wb') as dst:
                    dst.write(src.read())

                file_stat = os.stat(permanent_path)

                metadata = {
                    "_id": file_id,
                    "filename": filename,
                    "file_type": "image",
                    "file_size": file_stat.st_size,
                    "processing_method": "temp_file",
                    "local_path": permanent_path
                }

                # Mirror to BOS when a manager is configured.
                if self.bos_manager:
                    bos_key = f"images/temp_{file_id}_(unknown)"
                    bos_result = self.bos_manager.upload_file(temp_path, bos_key)
                    if bos_result:
                        metadata["bos_key"] = bos_key
                        metadata["bos_url"] = bos_result["url"]

                # Record metadata in MongoDB when configured.
                if self.mongodb_manager:
                    self.mongodb_manager.store_file_metadata(metadata=metadata)
                logger.info(f"✅ 临时文件处理图像成功: (unknown) ({file_stat.st_size} bytes)")
                return {
                    "file_id": file_id,
                    "filename": filename,
                    "local_path": permanent_path,
                    "processing_method": "temp_file"
                }
        except Exception as e:
            logger.error(f"❌ 临时文件处理图像失败: (unknown), {e}")
            return None

    def process_image_smart(self, file_obj, filename: str) -> Optional[Dict[str, Any]]:
        """Process an image, choosing in-memory vs temp-file by size."""
        if self.is_small_file(file_obj):
            logger.info(f"📦 小文件内存处理: (unknown)")
            return self.process_image_in_memory(file_obj, filename)
        else:
            logger.info(f"📁 大文件临时处理: (unknown)")
            return self.process_image_with_temp_file(file_obj, filename)

    def process_text_in_memory(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Process text snippets in memory, uploading each to BOS.

        Returns one entry per successfully uploaded text; texts that fail
        to upload are skipped (best effort, matching the image paths).
        """
        processed_texts = []
        for i, text in enumerate(texts):
            try:
                file_id = str(uuid.uuid4())
                bos_key = f"texts/memory_{file_id}.txt"
                text_bytes = text.encode('utf-8')
                # Upload straight from memory.
                bos_result = self._upload_to_bos_from_memory(
                    text_bytes, bos_key, f"text_{i}.txt"
                )
                if bos_result:
                    metadata = {
                        "_id": file_id,
                        "filename": f"text_{i}.txt",
                        "file_type": "text",
                        "file_size": len(text_bytes),
                        "processing_method": "memory",
                        "bos_key": bos_key,
                        "bos_url": bos_result["url"],
                        "text_content": text
                    }
                    # Guard like the image paths do: Mongo may be disabled
                    # even when BOS is configured (was an unguarded call).
                    if self.mongodb_manager:
                        self.mongodb_manager.store_file_metadata(metadata=metadata)
                    processed_texts.append({
                        "file_id": file_id,
                        "text_content": text,
                        "bos_key": bos_key,
                        "bos_result": bos_result
                    })
                    logger.info(f"✅ 内存处理文本成功: text_{i} ({len(text_bytes)} bytes)")
            except Exception as e:
                logger.error(f"❌ 内存处理文本失败 {i}: {e}")
        return processed_texts

    def download_from_bos_for_processing(self, bos_key: str, local_filename: str = None) -> Optional[str]:
        """Download a BOS object to a local temp file for model processing.

        The returned path is NOT auto-deleted (delete_on_exit=False); the
        caller must release it via cleanup_temp_file(). Returns None on
        failure (any partially created temp file is removed).
        """
        try:
            # Keep the extension so downstream tools can sniff the type.
            if local_filename:
                ext = os.path.splitext(local_filename)[1]
            else:
                ext = os.path.splitext(bos_key)[1]
            with self.temp_file_context(suffix=ext, delete_on_exit=False) as temp_path:
                if not self.bos_manager:
                    logger.warning("BOS manager is not available; skip download.")
                    # Don't leak the empty temp file we just created.
                    self.cleanup_temp_file(temp_path)
                    return None
                success = self.bos_manager.download_file(bos_key, temp_path)
                if success:
                    logger.info(f"✅ 从BOS下载文件用于处理: {bos_key}")
                    return temp_path
                else:
                    logger.error(f"❌ 从BOS下载文件失败: {bos_key}")
                    # Remove the partial/empty temp file on failure.
                    self.cleanup_temp_file(temp_path)
                    return None
        except Exception as e:
            logger.error(f"❌ 从BOS下载文件异常: {bos_key}, {e}")
            return None

    def _upload_to_bos_from_memory(self, content: bytes, bos_key: str, filename: str) -> Optional[Dict[str, Any]]:
        """Upload in-memory bytes to BOS via a short-lived temp file."""
        try:
            if not self.bos_manager:
                return None
            # BOS upload API takes a path, so stage the bytes on disk briefly.
            with self.temp_file_context() as temp_path:
                with open(temp_path, 'wb') as temp_file:
                    temp_file.write(content)
                result = self.bos_manager.upload_file(temp_path, bos_key)
                return result
        except Exception as e:
            logger.error(f"❌ 内存上传到BOS失败: (unknown), {e}")
            return None

    def get_temp_file_for_model(self, file_obj, filename: str) -> Optional[str]:
        """Spool an upload to a local temp file for model consumption.

        The file is tracked but NOT auto-deleted; release it with
        cleanup_temp_file(). Returns the temp path, or None on failure.
        """
        try:
            ext = os.path.splitext(filename)[1].lower()
            # Create a temp file that persists for the model to read.
            temp_fd, temp_path = tempfile.mkstemp(suffix=ext, dir=self.local_storage_dir)
            self.temp_files.add(temp_path)
            fd_owned = True  # until os.fdopen takes ownership of the descriptor
            try:
                file_obj.seek(0)
                with os.fdopen(temp_fd, 'wb') as temp_file:
                    # From here the file object owns (and will close) the fd;
                    # closing it again on error would raise and mask the cause.
                    fd_owned = False
                    temp_file.write(file_obj.read())
                logger.debug(f"📁 为模型创建临时文件: {temp_path}")
                return temp_path
            except Exception:
                if fd_owned:
                    os.close(temp_fd)
                raise
        except Exception as e:
            logger.error(f"❌ 为模型创建临时文件失败: (unknown), {e}")
            return None

    def cleanup_temp_file(self, temp_path: str):
        """Remove one tracked temp file (best effort)."""
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
                self.temp_files.discard(temp_path)
                logger.debug(f"🗑️ 清理临时文件: {temp_path}")
            except Exception as e:
                logger.warning(f"⚠️ 清理临时文件失败: {temp_path}, {e}")
# 全局实例
# Module-level singleton, created lazily by get_file_handler().
file_handler = None

def get_file_handler() -> OptimizedFileHandler:
    """Return the shared OptimizedFileHandler, creating it on first use."""
    global file_handler
    if file_handler is None:
        file_handler = OptimizedFileHandler()
    return file_handler

View File

@ -1,12 +1,11 @@
torch>=2.0.0 torch>=2.0.0
torchvision>=0.15.0
transformers>=4.30.0 transformers>=4.30.0
accelerate>=0.20.0 accelerate>=0.20.0
faiss-cpu>=1.7.4 faiss-cpu>=1.7.4
numpy>=1.21.0 numpy>=1.21.0
Pillow>=9.0.0 Pillow>=9.0.0
scikit-learn>=1.3.0
tqdm>=4.65.0 tqdm>=4.65.0
flask>=2.3.0 flask>=2.3.0
werkzeug>=2.3.0 werkzeug>=2.3.0
psutil>=5.9.0 requests>=2.31.0
safetensors>=0.4.0

14
run_server.sh Normal file
View File

@ -0,0 +1,14 @@
#!/usr/bin/env bash
set -euo pipefail

# Expose GPUs 0 and 1 so the model can shard across both cards
# (tensor-parallel sharding).
export CUDA_VISIBLE_DEVICES=0,1

# Stream logs immediately instead of buffering stdout.
export PYTHONUNBUFFERED=1

# Reduce CUDA allocator fragmentation (mitigates fragmentation-driven OOM).
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Replace this shell with the web app process, forwarding all arguments.
exec python3 web_app_local.py "$@"

Binary file not shown.

Before

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 189 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 201 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 166 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 150 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 312 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 325 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.5 KiB

BIN
static/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 785 KiB

View File

@ -3,7 +3,8 @@
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>多模态检索系统</title> <title>本地多模态检索系统 - FAISS</title>
<link rel="icon" href="/favicon.ico" />
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet"> <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet"> <link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
<style> <style>
@ -153,6 +154,25 @@
right: 20px; right: 20px;
z-index: 1000; z-index: 1000;
} }
.status-bar {
position: fixed;
top: 20px;
right: 120px;
background: rgba(255,255,255,0.9);
border-radius: 12px;
padding: 8px 12px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
font-size: 12px;
color: #334155;
display: flex;
gap: 12px;
align-items: center;
}
.status-item {
display: flex;
gap: 6px;
align-items: center;
}
.fade-in { .fade-in {
animation: fadeIn 0.5s ease-in; animation: fadeIn 0.5s ease-in;
@ -179,13 +199,28 @@
<i class="fas fa-circle-notch fa-spin"></i> 未初始化 <i class="fas fa-circle-notch fa-spin"></i> 未初始化
</div> </div>
</div> </div>
<!-- 右上角状态栏 -->
<div class="status-bar">
<div class="status-item" title="Vector Dimension">
<i class="fas fa-ruler-combined text-primary"></i>
<span id="statusVectorDim">-</span>
</div>
<div class="status-item" title="Total Vectors">
<i class="fas fa-database text-success"></i>
<span id="statusTotalVectors">-</span>
</div>
<div class="status-item" title="Server Time (UTC)">
<i class="fas fa-clock text-warning"></i>
<span id="statusServerTime">-</span>
</div>
</div>
<div class="container-fluid"> <div class="container-fluid">
<div class="main-container"> <div class="main-container">
<!-- 头部 --> <!-- 头部 -->
<div class="header"> <div class="header">
<h1><i class="fas fa-search"></i> 多模态检索系统</h1> <h1><i class="fas fa-search"></i> 本地多模态检索系统</h1>
<p class="mb-0">支持文搜图、文搜文、图搜图、图搜文四种检索模式</p> <p class="mb-0">基于本地模型和FAISS向量数据库支持文搜图、文搜文、图搜图、图搜文四种检索模式</p>
</div> </div>
<div class="p-4"> <div class="p-4">
@ -284,9 +319,7 @@
<div class="row mt-4"> <div class="row mt-4">
<div class="col-md-8"> <div class="col-md-8">
<div class="d-flex gap-3"> <div class="d-flex gap-3">
<button id="buildIndexBtn" class="btn btn-warning" disabled> <!-- 移除构建索引按钮,改为自动构建 -->
<i class="fas fa-cogs"></i> 构建索引
</button>
<button id="viewDataBtn" class="btn btn-info"> <button id="viewDataBtn" class="btn btn-info">
<i class="fas fa-list"></i> 查看数据 <i class="fas fa-list"></i> 查看数据
</button> </button>
@ -323,7 +356,11 @@
<option value="3">Top 3</option> <option value="3">Top 3</option>
<option value="5" selected>Top 5</option> <option value="5" selected>Top 5</option>
<option value="10">Top 10</option> <option value="10">Top 10</option>
<option value="50">Top 50</option>
<option value="100">Top 100</option>
<option value="custom">自定义</option>
</select> </select>
<input id="textTopKCustom" type="number" min="1" max="1000" placeholder="自定义K" class="form-control mt-2" style="display:none;" />
</div> </div>
<div class="col-md-2"> <div class="col-md-2">
<button id="textSearchBtn" class="btn btn-primary w-100"> <button id="textSearchBtn" class="btn btn-primary w-100">
@ -349,7 +386,11 @@
<option value="3">Top 3</option> <option value="3">Top 3</option>
<option value="5" selected>Top 5</option> <option value="5" selected>Top 5</option>
<option value="10">Top 10</option> <option value="10">Top 10</option>
<option value="50">Top 50</option>
<option value="100">Top 100</option>
<option value="custom">自定义</option>
</select> </select>
<input id="imageTopKCustom" type="number" min="1" max="1000" placeholder="自定义K" class="form-control mt-2" style="display:none;" />
</div> </div>
<div class="col-md-2"> <div class="col-md-2">
<button id="imageSearchBtn" class="btn btn-primary w-100" disabled> <button id="imageSearchBtn" class="btn btn-primary w-100" disabled>
@ -388,7 +429,7 @@
btn.disabled = true; btn.disabled = true;
try { try {
const response = await fetch('/api/init', { const response = await fetch('/api/system_info', {
method: 'POST', method: 'POST',
headers: {'Content-Type': 'application/json'} headers: {'Content-Type': 'application/json'}
}); });
@ -401,7 +442,7 @@
'<i class="fas fa-check-circle"></i> 已重新初始化'; '<i class="fas fa-check-circle"></i> 已重新初始化';
document.getElementById('statusBadge').className = 'badge bg-success'; document.getElementById('statusBadge').className = 'badge bg-success';
showAlert('success', `系统重新初始化成功GPU: ${data.gpu_count} 个`); showAlert('success', `系统重新初始化成功GPU信息: ${data.gpu_info.length} 个, 向量数量: ${data.retrieval_info.total_vectors || 0}`);
} else { } else {
throw new Error(data.message); throw new Error(data.message);
} }
@ -448,9 +489,34 @@
if (e.key === 'Enter') performTextSearch(); if (e.key === 'Enter') performTextSearch();
}); });
// TopK 自定义输入显隐
const textTopKSel = document.getElementById('textTopK');
const textTopKCustom = document.getElementById('textTopKCustom');
const imageTopKSel = document.getElementById('imageTopK');
const imageTopKCustom = document.getElementById('imageTopKCustom');
textTopKSel.addEventListener('change', () => {
textTopKCustom.style.display = textTopKSel.value === 'custom' ? 'block' : 'none';
});
imageTopKSel.addEventListener('change', () => {
imageTopKCustom.style.display = imageTopKSel.value === 'custom' ? 'block' : 'none';
});
function resolveTopK(selectEl, customEl) {
let v = selectEl.value;
if (v === 'custom') {
const n = parseInt(customEl.value, 10);
if (!Number.isFinite(n) || n <= 0) {
showAlert('warning', '请输入有效的自定义 Top K 数值');
throw new Error('Invalid custom TopK');
}
return n;
}
return parseInt(v, 10);
}
async function performTextSearch() { async function performTextSearch() {
const query = document.getElementById('textQuery').value.trim(); const query = document.getElementById('textQuery').value.trim();
const topK = parseInt(document.getElementById('textTopK').value); const topK = resolveTopK(textTopKSel, textTopKCustom);
if (!query) { if (!query) {
showAlert('warning', '请输入搜索文本'); showAlert('warning', '请输入搜索文本');
@ -460,11 +526,12 @@
showLoading(true); showLoading(true);
try { try {
const endpoint = currentMode === 'text_to_text' ? '/api/search/text_to_text' : '/api/search/text_to_image'; const endpoint = '/api/search_by_text';
const filter_type = currentMode === 'text_to_text' ? 'text' : 'image';
const response = await fetch(endpoint, { const response = await fetch(endpoint, {
method: 'POST', method: 'POST',
headers: {'Content-Type': 'application/json'}, headers: {'Content-Type': 'application/json'},
body: JSON.stringify({query, top_k: topK}) body: JSON.stringify({query, k: topK, filter_type: filter_type})
}); });
const data = await response.json(); const data = await response.json();
@ -529,7 +596,7 @@
// 图片搜索 // 图片搜索
document.getElementById('imageSearchBtn').addEventListener('click', async function() { document.getElementById('imageSearchBtn').addEventListener('click', async function() {
const file = imageFile.files[0]; const file = imageFile.files[0];
const topK = parseInt(document.getElementById('imageTopK').value); const topK = resolveTopK(imageTopKSel, imageTopKCustom);
if (!file) { if (!file) {
showAlert('warning', '请选择图片文件'); showAlert('warning', '请选择图片文件');
@ -539,11 +606,13 @@
showLoading(true); showLoading(true);
try { try {
const endpoint = '/api/search_by_image';
const filter_type = currentMode === 'image_to_text' ? 'text' : 'image';
const formData = new FormData(); const formData = new FormData();
formData.append('image', file); formData.append('image', file);
formData.append('top_k', topK); formData.append('k', topK);
formData.append('filter_type', filter_type);
const endpoint = currentMode === 'image_to_text' ? '/api/search/image_to_text' : '/api/search/image_to_image';
const response = await fetch(endpoint, { const response = await fetch(endpoint, {
method: 'POST', method: 'POST',
body: formData body: formData
@ -572,17 +641,18 @@
<div class="d-flex justify-content-between align-items-center mb-3"> <div class="d-flex justify-content-between align-items-center mb-3">
<h4><i class="fas fa-search-plus"></i> 搜索结果</h4> <h4><i class="fas fa-search-plus"></i> 搜索结果</h4>
<div> <div>
<span class="badge bg-info">找到 ${data.result_count} 个结果</span> <span class="badge bg-info">找到 ${data.results?.length || 0} 个结果</span>
<span class="badge bg-secondary">耗时 ${data.search_time}s</span> <span class="badge bg-secondary">耗时 ${data.search_time || data.time || '0.0'}s</span>
</div> </div>
</div> </div>
`; `;
if (data.query_image) { if (data.query_image) {
const imageUrl = data.query_image.startsWith('data:') ? data.query_image : `data:image/jpeg;base64,${data.query_image}`;
html += ` html += `
<div class="result-card"> <div class="result-card">
<h6><i class="fas fa-image"></i> 查询图片</h6> <h6><i class="fas fa-image"></i> 查询图片</h6>
<img src="data:image/jpeg;base64,${data.query_image}" class="query-image"> <img src="${imageUrl}" class="query-image">
</div> </div>
`; `;
} }
@ -600,29 +670,40 @@
html += '<div class="result-card">'; html += '<div class="result-card">';
if (mode === 'text_to_image' || mode === 'image_to_image') { if (mode === 'text_to_image' || mode === 'image_to_image') {
const imageUrl = result.image_base64 ? `data:image/jpeg;base64,${result.image_base64}` :
(result.image_url || `/temp/${result.filename || result.id}`);
const score = result.score || result.distance ?
(result.score ? (result.score * 100).toFixed(1) : (100 - result.distance * 100).toFixed(1)) : '95.0';
const title = result.title || result.filename || result.id || `结果 ${index + 1}`;
html += ` html += `
<div class="row"> <div class="row">
<div class="col-md-3"> <div class="col-md-3">
<img src="data:image/jpeg;base64,${result.image_base64}" <img src="${imageUrl}" class="result-image" alt="Result ${index + 1}">
class="result-image" alt="Result ${index + 1}">
</div> </div>
<div class="col-md-9"> <div class="col-md-9">
<div class="d-flex justify-content-between align-items-start"> <div class="d-flex justify-content-between align-items-start">
<h6><i class="fas fa-image"></i> ${result.filename}</h6> <h6><i class="fas fa-image"></i> ${title}</h6>
<span class="score-badge">相似度: ${(result.score * 100).toFixed(1)}%</span> <span class="score-badge">相似度: ${score}%</span>
</div> </div>
<p class="text-muted mb-0">路径: ${result.image_path}</p> <p class="text-muted mb-0">类型: 图片 | ID: ${result.id || index}</p>
</div> </div>
</div> </div>
`; `;
} else { } else {
const text = result.text || result.content || (typeof result === 'string' ? result : JSON.stringify(result));
const score = result.score || result.distance ?
(result.score ? (result.score * 100).toFixed(1) : (100 - result.distance * 100).toFixed(1)) : '95.0';
const title = result.title || `结果 ${index + 1}`;
html += ` html += `
<div class="d-flex justify-content-between align-items-start"> <div class="d-flex justify-content-between align-items-start">
<div> <div>
<h6><i class="fas fa-file-text"></i> 结果 ${index + 1}</h6> <h6><i class="fas fa-file-text"></i> ${title}</h6>
<p class="mb-0">${result.text || result}</p> <p class="mb-0">${text}</p>
<p class="text-muted small mb-0">类型: 文本 | ID: ${result.id || index}</p>
</div> </div>
<span class="score-badge">相似度: ${((result.score || 0.95) * 100).toFixed(1)}%</span> <span class="score-badge">相似度: ${score}%</span>
</div> </div>
`; `;
} }
@ -659,10 +740,10 @@
// 检查系统状态 // 检查系统状态
async function checkStatus() { async function checkStatus() {
try { try {
const response = await fetch('/api/status'); const response = await fetch('/api/system_info');
const data = await response.json(); const data = await response.json();
if (data.initialized) { if (data.success) {
systemInitialized = true; systemInitialized = true;
document.getElementById('statusBadge').innerHTML = document.getElementById('statusBadge').innerHTML =
'<i class="fas fa-check-circle"></i> 已初始化'; '<i class="fas fa-check-circle"></i> 已初始化';
@ -680,8 +761,28 @@
} }
} }
// 页面加载时检查状态 // 轮询 /api/stats 更新右上角状态栏
async function updateStatusBar() {
try {
const res = await fetch('/api/stats');
const data = await res.json();
if (data && data.success) {
const dim = data.debug?.vector_dimension ?? data.stats?.vector_dimension ?? '-';
const total = data.debug?.total_vectors ?? data.stats?.total_vectors ?? '-';
const time = data.debug?.server_time ?? '-';
document.getElementById('statusVectorDim').textContent = dim;
document.getElementById('statusTotalVectors').textContent = total;
document.getElementById('statusServerTime').textContent = time.replace('T',' ').replace('Z','');
}
} catch (e) {
// 忽略一次失败,等待下次轮询
}
}
// 页面加载时检查状态并开始轮询状态栏
checkStatus(); checkStatus();
updateStatusBar();
setInterval(updateStatusBar, 5000);
// 设置数据管理功能事件绑定 // 设置数据管理功能事件绑定
setupDataManagement(); setupDataManagement();
@ -744,8 +845,7 @@
} }
}); });
// 构建索引 // 移除构建索引按钮的事件监听器
document.getElementById('buildIndexBtn').addEventListener('click', buildIndex);
// 查看数据 // 查看数据
document.getElementById('viewDataBtn').addEventListener('click', viewData); document.getElementById('viewDataBtn').addEventListener('click', viewData);
@ -757,8 +857,9 @@
updateDataStats(); updateDataStats();
} }
// 批量上传图片 // 批量上传图片(串行:并发=1每次上传一张调用 /api/add_image
async function uploadBatchImages(files) { async function uploadBatchImages(files) {
try {
const progressDiv = document.getElementById('imageUploadProgress'); const progressDiv = document.getElementById('imageUploadProgress');
const progressBar = progressDiv.querySelector('.progress-bar'); const progressBar = progressDiv.querySelector('.progress-bar');
const progressText = document.getElementById('imageProgressText'); const progressText = document.getElementById('imageProgressText');
@ -767,120 +868,116 @@
progressText.textContent = `0/${files.length}`; progressText.textContent = `0/${files.length}`;
progressBar.style.width = '0%'; progressBar.style.width = '0%';
showAlert('info', `开始上传 ${files.length} 张图片(串行)...`);
let successCount = 0;
for (let i = 0; i < files.length; i++) {
const formData = new FormData(); const formData = new FormData();
files.forEach(file => { formData.append('image', files[i]);
formData.append('images', file);
});
try { try {
const response = await fetch('/api/upload/images', { const resp = await fetch('/api/add_image', { method: 'POST', body: formData });
method: 'POST', const data = await resp.json();
body: formData if (data && data.success) {
}); successCount++;
const data = await response.json();
if (data.success) {
progressBar.style.width = '100%';
progressText.textContent = `${files.length}/${files.length}`;
showAlert('success', `成功上传 ${data.uploaded_count} 张图片`);
updateDataStats();
document.getElementById('buildIndexBtn').disabled = false;
} else { } else {
showAlert('danger', `上传失败: ${data.message}`); // 不中断流程,记录失败
console.warn('Upload failed for file:', files[i].name, data?.error);
} }
} catch (e) {
console.warn('Request failed for file:', files[i].name, e);
}
const progress = Math.round(((i + 1) / files.length) * 100);
progressBar.style.width = `${progress}%`;
progressText.textContent = `${i + 1}/${files.length}`;
}
showAlert('success', `上传完成:成功 ${successCount}/${files.length} 张`);
await autoSaveIndex();
updateDataStats();
} catch (error) { } catch (error) {
showAlert('danger', `上传错误: ${error.message}`); showAlert('danger', `图片上传失败: ${error.message}`);
} finally { } finally {
setTimeout(() => { setTimeout(() => {
progressDiv.style.display = 'none'; document.getElementById('imageUploadProgress').style.display = 'none';
}, 2000); }, 1500);
}
} }
}
// 批量上传文本 // 批量上传文本调用批量API以利用多卡并行
async function uploadBatchTexts(texts) { async function uploadBatchTexts(texts) {
try { try {
const response = await fetch('/api/upload/texts', { showAlert('info', `正在批量上传 ${texts.length} 条文本...`);
const response = await fetch('/api/add_texts_batch', {
method: 'POST', method: 'POST',
headers: { headers: {'Content-Type': 'application/json'},
'Content-Type': 'application/json', body: JSON.stringify({texts})
},
body: JSON.stringify({ texts: texts })
}); });
const data = await response.json(); const data = await response.json();
if (!data.success) {
if (data.success) { throw new Error(data.error || '批量上传失败');
showAlert('success', `成功上传 ${data.uploaded_count} 条文本`); }
document.getElementById('batchTextInput').value = ''; showAlert('success', data.message || `成功上传 ${texts.length} 条文本`);
await autoSaveIndex();
updateDataStats(); updateDataStats();
document.getElementById('buildIndexBtn').disabled = false;
} else {
showAlert('danger', `上传失败: ${data.message}`);
}
} catch (error) { } catch (error) {
showAlert('danger', `上传错误: ${error.message}`); showAlert('danger', `文本上传失败: ${error.message}`);
}
} }
}
// 构建索引 // 自动保存索引函数
async function buildIndex() { async function autoSaveIndex() {
const btn = document.getElementById('buildIndexBtn');
const originalText = btn.innerHTML;
btn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 构建中...';
btn.disabled = true;
try { try {
const response = await fetch('/api/build_index', { const response = await fetch('/api/save_index', {
method: 'POST' method: 'POST'
}); });
const data = await response.json(); const data = await response.json();
if (data.success) { if (data.success) {
showAlert('success', '索引构建完成!现在可以进行搜索了'); console.log('索引自动保存成功');
} else { } else {
showAlert('danger', `索引构建失败: ${data.message}`); console.error(`索引自动保存失败: ${data.message}`);
} }
} catch (error) { } catch (error) {
showAlert('danger', `构建错误: ${error.message}`); console.error(`索引自动保存错误: ${error.message}`);
} finally {
btn.innerHTML = originalText;
btn.disabled = false;
} }
} }
// 查看数据 // 查看数据
async function viewData() { async function viewData() {
try { try {
const response = await fetch('/api/data/list'); const response = await fetch('/api/list_items');
const data = await response.json(); const data = await response.json();
if (data.success) { if (data.success) {
let content = '<div class="row">'; let content = '<div class="row">';
// 显示图片数据 // 显示图片数据
if (data.images && data.images.length > 0) { if (data.items && data.items.filter(item => item.type === 'image').length > 0) {
content += '<div class="col-md-6"><h6>图片数据 (' + data.images.length + ')</h6>'; const imageItems = data.items.filter(item => item.type === 'image');
content += '<div class="col-md-6"><h6>图片数据 (' + imageItems.length + ')</h6>';
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">'; content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
data.images.forEach(img => { imageItems.forEach(item => {
content += `<div class="list-group-item d-flex justify-content-between align-items-center"> content += `<div class="list-group-item d-flex justify-content-between align-items-center">
<span>${img}</span> <span>${item.id}: ${item.metadata?.title || '无标题'}</span>
<img src="/uploads/${img}" class="img-thumbnail" style="width: 50px; height: 50px; object-fit: cover;"> <img src="/temp/${item.filename || item.id}" class="img-thumbnail" style="width: 50px; height: 50px; object-fit: cover;">
</div>`; </div>`;
}); });
content += '</div></div>'; content += '</div></div>';
} }
// 显示文本数据 // 显示文本数据
if (data.texts && data.texts.length > 0) { if (data.items && data.items.filter(item => item.type === 'text').length > 0) {
content += '<div class="col-md-6"><h6>文本数据 (' + data.texts.length + ')</h6>'; const textItems = data.items.filter(item => item.type === 'text');
content += '<div class="col-md-6"><h6>文本数据 (' + textItems.length + ')</h6>';
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">'; content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
data.texts.forEach((text, index) => { textItems.forEach((item, index) => {
const text = item.content || item.text || '';
const shortText = text.length > 50 ? text.substring(0, 50) + '...' : text; const shortText = text.length > 50 ? text.substring(0, 50) + '...' : text;
content += `<div class="list-group-item"> content += `<div class="list-group-item">
<small class="text-muted">#${index + 1}</small><br> <small class="text-muted">#${item.id}</small><br>
${shortText} ${shortText}
</div>`; </div>`;
}); });
@ -905,7 +1002,7 @@
} }
try { try {
const response = await fetch('/api/data/clear', { const response = await fetch('/api/clear_index', {
method: 'POST' method: 'POST'
}); });
@ -914,7 +1011,7 @@
if (data.success) { if (data.success) {
showAlert('success', '数据已清空'); showAlert('success', '数据已清空');
updateDataStats(); updateDataStats();
document.getElementById('buildIndexBtn').disabled = true; // 移除构建索引按钮的引用
} else { } else {
showAlert('danger', `清空失败: ${data.message}`); showAlert('danger', `清空失败: ${data.message}`);
} }
@ -926,12 +1023,14 @@
// 更新数据统计 // 更新数据统计
async function updateDataStats() { async function updateDataStats() {
try { try {
const response = await fetch('/api/data/stats'); const response = await fetch('/api/system_info');
const data = await response.json(); const data = await response.json();
if (data.success) { if (data.success) {
document.getElementById('imageCount').textContent = data.image_count || 0; const retrieval_info = data.retrieval_info || {};
document.getElementById('textCount').textContent = data.text_count || 0; document.getElementById('imageCount').textContent = retrieval_info.image_count || 0;
document.getElementById('textCount').textContent = retrieval_info.text_count || 0;
// 移除构建索引按钮的引用
} }
} catch (error) { } catch (error) {
console.log('获取数据统计失败:', error); console.log('获取数据统计失败:', error);

View File

@ -1,104 +0,0 @@
#!/usr/bin/env python3
"""
测试所有四种多模态检索模式
"""
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
import numpy as np
from PIL import Image
import os
def test_all_retrieval_modes():
    """Smoke-test the four retrieval modes of MultiGPUMultimodalRetrieval.

    Builds a small text index and an image index, runs one query per mode
    (text->text, text->image, image->text, image->image) and prints the
    ranked results; finally exercises the web-app compatible alias
    methods.  Manual script: results are printed, never asserted.
    """
    print('正在初始化多GPU多模态检索系统...')
    retrieval = MultiGPUMultimodalRetrieval()
    # Test corpus: three Chinese texts plus one English text.
    test_texts = [
        "一只可爱的小猫",
        "美丽的风景照片",
        "现代建筑设计",
        "colorful flowers in garden"
    ]
    test_images = [
        'sample_images/1755677101_1__.jpg',
        'sample_images/1755677101_2__.jpg',
        'sample_images/1755677101_3__.jpg',
        'sample_images/1755677101_4__.jpg'
    ]
    # Keep only the images that exist on disk so a partial checkout still runs.
    existing_images = [img for img in test_images if os.path.exists(img)]
    if not existing_images:
        print("❌ 没有找到测试图像文件")
        return
    print(f"找到 {len(existing_images)} 张测试图像")
    try:
        # 1. Build the text index.
        print('\n=== 构建文本索引 ===')
        retrieval.build_text_index_parallel(test_texts)
        print('✅ 文本索引构建完成')
        # 2. Build the image index.
        print('\n=== 构建图像索引 ===')
        retrieval.build_image_index_parallel(existing_images)
        print('✅ 图像索引构建完成')
        # 3. Text-to-text retrieval.
        print('\n=== 测试文本到文本检索 ===')
        query = "小动物"
        results = retrieval.search_text_by_text(query, top_k=3)
        print(f'查询: "{query}"')
        for i, (text, score) in enumerate(results):
            print(f' {i+1}. {text} (相似度: {score:.4f})')
        # 4. Text-to-image retrieval.
        print('\n=== 测试文本到图像检索 ===')
        query = "beautiful image"
        results = retrieval.search_images_by_text(query, top_k=3)
        print(f'查询: "{query}"')
        for i, (image_path, score) in enumerate(results):
            print(f' {i+1}. {image_path} (相似度: {score:.4f})')
        # 5. Image-to-text retrieval (first existing image as the query).
        print('\n=== 测试图像到文本检索 ===')
        query_image = existing_images[0]
        results = retrieval.search_text_by_image(query_image, top_k=3)
        print(f'查询图像: {query_image}')
        for i, (text, score) in enumerate(results):
            print(f' {i+1}. {text} (相似度: {score:.4f})')
        # 6. Image-to-image retrieval.
        print('\n=== 测试图像到图像检索 ===')
        query_image = existing_images[0]
        results = retrieval.search_images_by_image(query_image, top_k=3)
        print(f'查询图像: {query_image}')
        for i, (image_path, score) in enumerate(results):
            print(f' {i+1}. {image_path} (相似度: {score:.4f})')
        print('\n✅ 所有四种检索模式测试完成!')
        # 7. Alias methods the Flask web app calls.
        print('\n=== 测试Web应用兼容方法 ===')
        try:
            results = retrieval.search_text_to_image("test query", top_k=2)
            print('✅ search_text_to_image 方法正常')
            results = retrieval.search_image_to_text(existing_images[0], top_k=2)
            print('✅ search_image_to_text 方法正常')
            results = retrieval.search_image_to_image(existing_images[0], top_k=2)
            print('✅ search_image_to_image 方法正常')
        except Exception as e:
            print(f'❌ Web应用兼容方法测试失败: {e}')
    except Exception as e:
        # Catch-all so the script reports rather than crashes; prints traceback.
        print(f'❌ 测试过程中出现错误: {e}')
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
    test_all_retrieval_modes()

View File

@ -1,49 +0,0 @@
#!/usr/bin/env python3
"""
测试图像编码功能
"""
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
import numpy as np
from PIL import Image
def test_image_encoding():
    """Sanity-check the text/image encoders of MultiGPUMultimodalRetrieval.

    Verifies embedding shapes and dtypes, determinism for identical input,
    dissimilarity across different images, and that text and image
    embeddings share a dimension.  Manual script: prints, never asserts.
    """
    print('正在初始化多GPU多模态检索系统...')
    retrieval = MultiGPUMultimodalRetrieval()
    # Text encoding.
    print('测试文本编码...')
    text_embeddings = retrieval.encode_text_batch(['这是一个测试文本'])
    print(f'文本embedding形状: {text_embeddings.shape}')
    print(f'文本embedding数据类型: {text_embeddings.dtype}')
    # Image encoding.
    print('测试图像编码...')
    test_images = ['sample_images/1755677101_1__.jpg']
    image_embeddings = retrieval.encode_image_batch(test_images)
    print(f'图像embedding形状: {image_embeddings.shape}')
    print(f'图像embedding数据类型: {image_embeddings.dtype}')
    # Determinism: encoding the same image twice should match within rtol.
    print('测试embedding一致性...')
    image_embeddings2 = retrieval.encode_image_batch(test_images)
    consistency = np.allclose(image_embeddings, image_embeddings2, rtol=1e-5)
    print(f'相同图像embedding一致性: {consistency}')
    # Different images should not be identical: report cosine similarity.
    print('测试不同图像embedding差异...')
    test_images2 = ['sample_images/1755677101_2__.jpg']
    image_embeddings3 = retrieval.encode_image_batch(test_images2)
    similarity = np.dot(image_embeddings[0], image_embeddings3[0]) / (np.linalg.norm(image_embeddings[0]) * np.linalg.norm(image_embeddings3[0]))
    print(f'不同图像间相似度: {similarity:.4f}')
    # Text and image embeddings must live in the same vector space.
    if text_embeddings.shape[1] == image_embeddings.shape[1]:
        print('✅ 文本和图像embedding维度一致')
    else:
        print('❌ 文本和图像embedding维度不一致')
    print('测试完成!')
if __name__ == "__main__":
    test_image_encoding()

Binary file not shown.

Before

Width:  |  Height:  |  Size: 172 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 201 KiB

646
web_app_local.py Normal file
View File

@ -0,0 +1,646 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
本地多模态检索系统Web应用
集成本地模型和FAISS向量数据库
支持文搜文文搜图图搜文图搜图四种检索模式
"""
import os
import sys
import logging
import time
import json
import base64
from io import BytesIO
from pathlib import Path
from datetime import datetime, timezone
import numpy as np
from PIL import Image
from flask import Flask, request, jsonify, render_template, send_from_directory, send_file
from werkzeug.utils import secure_filename
import torch
# Force transformers into offline mode — models are loaded from local disk only.
os.environ['TRANSFORMERS_OFFLINE'] = '1'
# Project-local modules (imported after the offline switch on purpose).
from multimodal_retrieval_local import MultimodalRetrievalLocal
from optimized_file_handler import OptimizedFileHandler
# Logging setup.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Flask application and configuration.
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = '/tmp/mmeb_uploads'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB
app.config['MODEL_PATH'] = '/root/models/Ops-MM-embedding-v1-7B'
app.config['INDEX_PATH'] = '/root/mmeb/local_faiss_index'
app.config['ALLOWED_EXTENSIONS'] = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'}
# Ensure the upload directory exists.  (fix: the original issued this
# makedirs twice behind a redundant os.path.exists check; exist_ok=True
# already makes a single call idempotent.)
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# Static dir + placeholder favicon so browsers do not 404 on /favicon.ico.
app.static_folder = os.path.join(os.path.dirname(__file__), 'static')
os.makedirs(app.static_folder, exist_ok=True)
favicon_path = os.path.join(app.static_folder, 'favicon.ico')
if not os.path.exists(favicon_path):
    # A 1x1 transparent PNG written as .ico — most browsers accept PNG content.
    transparent_png_base64 = (
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8Xw8AAqMBhHqg7T8AAAAASUVORK5CYII="
    )
    with open(favicon_path, 'wb') as f:
        f.write(base64.b64decode(transparent_png_base64))
# Shared upload/temp-file helper.  (fix: removed a duplicate
# "from optimized_file_handler import OptimizedFileHandler" that repeated
# the import done above.)
file_handler = OptimizedFileHandler(local_storage_dir=app.config['UPLOAD_FOLDER'])
# Lazily-created singleton retrieval system (see init_retrieval_system).
retrieval_system = None
def allowed_file(filename):
    """Return True when *filename* carries an extension listed in
    app.config['ALLOWED_EXTENSIONS'] (case-insensitive)."""
    _, dot, ext = filename.rpartition('.')
    return bool(dot) and ext.lower() in app.config['ALLOWED_EXTENSIONS']
def init_retrieval_system():
    """Lazily create and cache the global MultimodalRetrievalLocal instance.

    Raises FileNotFoundError when the configured model path is missing.
    """
    global retrieval_system
    if retrieval_system is None:
        logger.info("初始化多模态检索系统...")
        model_path = app.config['MODEL_PATH']
        # Fail fast with a clear error when the local model is absent.
        if not os.path.exists(model_path):
            logger.error(f"模型路径不存在: {model_path}")
            raise FileNotFoundError(f"模型路径不存在: {model_path}")
        retrieval_system = MultimodalRetrievalLocal(
            model_path=model_path,
            use_all_gpus=True,
            index_path=app.config['INDEX_PATH'],
            shard_model_across_gpus=True
        )
        logger.info("多模态检索系统初始化完成")
    return retrieval_system
def get_image_base64(image_path):
    """Return the file at *image_path* as a base64 data URI.

    Fix: the original hard-coded "image/jpeg" for every file, mislabeling
    PNG/GIF results.  The MIME type is now guessed from the extension and
    falls back to the previous image/jpeg value when it cannot be guessed.
    """
    import mimetypes  # local import: tiny stdlib module, used only here
    mime, _ = mimetypes.guess_type(image_path)
    if not mime or not mime.startswith("image/"):
        mime = "image/jpeg"  # backward-compatible default
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
    return f"data:{mime};base64,{encoded_string}"
@app.route('/')
def index():
    """Serve the single-page UI."""
    return render_template('local_index.html')
@app.route('/favicon.ico')
def favicon():
    """Serve the static favicon (a placeholder is generated at startup if missing)."""
    return send_from_directory(app.static_folder, 'favicon.ico')
@app.route('/api/stats', methods=['GET'])
def get_stats():
    """Return index statistics plus a small debug section."""
    try:
        stats = init_retrieval_system().get_stats()
        return jsonify({
            "success": True,
            "stats": stats,
            "debug": {
                "server_time": datetime.now(timezone.utc).isoformat(),
                "vector_dimension": stats.get("vector_dimension"),
                "total_vectors": stats.get("total_vectors"),
                "model_path": stats.get("model_path"),
                "index_path": stats.get("index_path"),
            },
        })
    except Exception as e:
        logger.error(f"获取统计信息失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/add_text', methods=['POST'])
def add_text():
    """Add one text snippet to the index.

    Expects JSON {"text": "..."}; embeds the text, stores it with basic
    metadata, persists the index and returns the new id plus debug info.
    """
    try:
        data = request.json
        text = data.get('text')
        if not text:
            return jsonify({"success": False, "error": "文本不能为空"}), 400
        # Route the raw bytes through the shared temp-file helper,
        # mirroring the image upload path.
        with file_handler.temp_file_context(text.encode('utf-8'), suffix='.txt') as temp_file:
            logger.info(f"处理文本: {temp_file}")
            # Lazily initialise the retrieval backend.
            retrieval = init_retrieval_system()
            metadata = {
                "timestamp": time.time(),
                "source": "web_upload"
            }
            text_ids = retrieval.add_texts([text], [metadata])
            # Persist immediately so a crash does not lose the addition.
            retrieval.save_index()
            stats = retrieval.get_stats()
            return jsonify({
                "success": True,
                "message": "文本添加成功",
                "text_id": text_ids[0] if text_ids else None,
                "debug": {
                    "server_time": datetime.now(timezone.utc).isoformat(),
                    "vector_dimension": stats.get("vector_dimension"),
                    "total_vectors": stats.get("total_vectors"),
                }
            })
    except Exception as e:
        logger.error(f"添加文本失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
    finally:
        # Best-effort cleanup of any temp files created above.
        file_handler.cleanup_all_temp_files()
@app.route('/api/add_images_batch', methods=['POST'])
def add_images_batch():
    """Batch image ingestion (multi-GPU friendly).

    Reads every file from the 'images' multipart field, validates and
    RGB-normalises each image, persists it under UPLOAD_FOLDER, embeds the
    whole batch at once and saves the index.

    Fix: the image is now decoded/validated BEFORE its bytes are written
    to disk, so a rejected upload no longer leaves an orphaned file in the
    upload folder (the original wrote first and validated after).
    """
    try:
        files = request.files.getlist('images')
        if not files:
            return jsonify({"success": False, "error": "没有上传文件"}), 400
        retrieval = init_retrieval_system()
        images = []
        metadatas = []
        image_paths = []
        for file in files:
            if not (file and allowed_file(file.filename)):
                return jsonify({"success": False, "error": f"不支持的文件类型: {file.filename}"}), 400
            image_data = file.read()
            # Validate / normalise before any disk write.
            try:
                img = Image.open(BytesIO(image_data))
                if img.mode != 'RGB':
                    img = img.convert('RGB')
            except Exception as e:
                # NOTE(review): "(unknown)" looks like a redacted {filename}
                # placeholder in the original message — confirm against VCS.
                return jsonify({"success": False, "error": f"图像格式不支持: (unknown): {str(e)}"}), 400
            filename = secure_filename(file.filename)
            image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            with open(image_path, 'wb') as f:
                f.write(image_data)
            images.append(img)
            metadatas.append({
                "filename": filename,
                "timestamp": time.time(),
                "source": "web_upload_batch",
                "size": len(image_data),
                "local_path": image_path
            })
            image_paths.append(image_path)
        image_ids = retrieval.add_images(images, metadatas, image_paths)
        retrieval.save_index()
        stats = retrieval.get_stats()
        return jsonify({
            "success": True,
            "message": f"批量添加成功: {len(image_ids)}/{len(images)}",
            "image_ids": image_ids,
            "debug": {
                "server_time": datetime.now(timezone.utc).isoformat(),
                "vector_dimension": stats.get("vector_dimension"),
                "total_vectors": stats.get("total_vectors"),
            }
        })
    except Exception as e:
        logger.error(f"批量添加图像失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
    finally:
        file_handler.cleanup_all_temp_files()
@app.route('/api/add_texts_batch', methods=['POST'])
def add_texts_batch():
    """Batch text ingestion (multi-GPU friendly).

    Expects JSON {"texts": [...]}; blank and non-string entries are dropped.
    """
    try:
        data = request.json
        if not data or 'texts' not in data:
            return jsonify({"success": False, "error": "缺少 texts 数组"}), 400
        # Keep only non-blank string entries.
        texts = [entry for entry in data.get('texts', [])
                 if isinstance(entry, str) and entry.strip()]
        if not texts:
            return jsonify({"success": False, "error": "文本数组为空"}), 400
        retrieval = init_retrieval_system()
        metadatas = [{"timestamp": time.time(), "source": "web_upload_batch"} for _ in texts]
        text_ids = retrieval.add_texts(texts, metadatas)
        retrieval.save_index()
        stats = retrieval.get_stats()
        return jsonify({
            "success": True,
            "message": f"批量添加成功: {len(text_ids)}/{len(texts)}",
            "text_ids": text_ids,
            "debug": {
                "server_time": datetime.now(timezone.utc).isoformat(),
                "vector_dimension": stats.get("vector_dimension"),
                "total_vectors": stats.get("total_vectors"),
            }
        })
    except Exception as e:
        logger.error(f"批量添加文本失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/add_image', methods=['POST'])
def add_image():
    """Add one uploaded image to the index.

    Persists the raw bytes under UPLOAD_FOLDER, normalises the image to
    RGB, embeds it and saves the index.  400 for missing/invalid files,
    500 on internal errors.

    Fix: removed a dead local (`file_obj = BytesIO(image_data)`) that was
    created and never used.
    """
    try:
        if 'image' not in request.files:
            return jsonify({"success": False, "error": "没有上传文件"}), 400
        file = request.files['image']
        if file.filename == '':
            return jsonify({"success": False, "error": "没有选择文件"}), 400
        if file and allowed_file(file.filename):
            image_data = file.read()
            file_size = len(image_data)
            logger.info(f"处理图像: {file.filename} ({file_size} 字节)")
            retrieval = init_retrieval_system()
            filename = secure_filename(file.filename)
            # Persist the original bytes so the file can be served back later.
            image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            with open(image_path, 'wb') as f:
                f.write(image_data)
            try:
                image = Image.open(BytesIO(image_data))
                if image.mode != 'RGB':
                    logger.info(f"将图像从 {image.mode} 转换为 RGB")
                    image = image.convert('RGB')
                # NOTE(review): after convert() PIL drops .format, so this log
                # can print "格式: None" for converted images — cosmetic only.
                logger.info(f"成功加载图像: (unknown), 格式: {image.format}, 模式: {image.mode}, 大小: {image.size}")
            except Exception as e:
                logger.error(f"加载图像失败: (unknown), 错误: {str(e)}")
                return jsonify({"success": False, "error": f"图像格式不支持: {str(e)}"}), 400
            metadata = {
                "filename": filename,
                "timestamp": time.time(),
                "source": "web_upload",
                "size": file_size,
                "local_path": image_path
            }
            image_ids = retrieval.add_images([image], [metadata], [image_path])
            retrieval.save_index()
            stats = retrieval.get_stats()
            return jsonify({
                "success": True,
                "message": "图像添加成功",
                "image_id": image_ids[0] if image_ids else None,
                "debug": {
                    "server_time": datetime.now(timezone.utc).isoformat(),
                    "vector_dimension": stats.get("vector_dimension"),
                    "total_vectors": stats.get("total_vectors"),
                    "image_size": [image.width, image.height] if 'image' in locals() else None,
                }
            })
        else:
            return jsonify({"success": False, "error": "不支持的文件类型"}), 400
    except Exception as e:
        logger.error(f"添加图像失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
    finally:
        file_handler.cleanup_all_temp_files()
@app.route('/api/search_by_text', methods=['POST'])
def search_by_text():
    """Text-driven search.

    JSON body: query (required), k (default 5), filter_type
    ("text", "image" or absent = both modalities).
    """
    try:
        data = request.json
        query = data.get('query')
        k = int(data.get('k', 5))
        filter_type = data.get('filter_type')  # "text", "image" or None
        if not query:
            return jsonify({"success": False, "error": "查询文本不能为空"}), 400
        retrieval = init_retrieval_system()
        hits = retrieval.search_by_text(query, k, filter_type)

        def _to_item(hit):
            # Shape a raw hit for the frontend; image hits are inlined as base64.
            entry = {
                "score": hit.get("score", 0),
                "type": hit.get("type")
            }
            if hit.get("type") == "text":
                entry["text"] = hit.get("text", "")
            elif hit.get("type") == "image":
                if "path" in hit and os.path.exists(hit["path"]):
                    entry["image"] = get_image_base64(hit["path"])
                    entry["filename"] = os.path.basename(hit["path"])
                if "description" in hit:
                    entry["description"] = hit["description"]
            return entry

        processed_results = [_to_item(hit) for hit in hits]
        stats = retrieval.get_stats()
        return jsonify({
            "success": True,
            "results": processed_results,
            "query": query,
            "filter_type": filter_type,
            "debug": {
                "server_time": datetime.now(timezone.utc).isoformat(),
                "vector_dimension": stats.get("vector_dimension"),
                "total_vectors": stats.get("total_vectors"),
            }
        })
    except Exception as e:
        logger.error(f"文本搜索失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/search_by_image', methods=['POST'])
def search_by_image():
    """Image-driven search endpoint.

    Form fields: image (file, required), k (int, default 5), filter_type
    ("text", "image" or absent = both).  Uploads up to 5 MB are decoded
    in memory; larger uploads go through a temp file first.
    """
    try:
        if 'image' not in request.files:
            return jsonify({"success": False, "error": "没有上传文件"}), 400
        file = request.files['image']
        k = int(request.form.get('k', 5))
        filter_type = request.form.get('filter_type')  # "text", "image" or None
        if file.filename == '':
            return jsonify({"success": False, "error": "没有选择文件"}), 400
        if file and allowed_file(file.filename):
            image_data = file.read()
            file_size = len(image_data)
            # 5 MB threshold decides between in-memory and temp-file decoding.
            if file_size <= 5 * 1024 * 1024:  # 5MB
                logger.info(f"使用内存处理搜索图像: {file.filename} ({file_size} 字节)")
                image = Image.open(BytesIO(image_data))
                retrieval = init_retrieval_system()
                results = retrieval.search_by_image(image, k, filter_type)
            else:
                with file_handler.temp_file_context(image_data, suffix=os.path.splitext(file.filename)[1]) as temp_file:
                    logger.info(f"使用临时文件处理搜索图像: {temp_file} ({file_size} 字节)")
                    retrieval = init_retrieval_system()
                    image = Image.open(temp_file)
                    # Search while the temp file still exists (PIL may read lazily).
                    results = retrieval.search_by_image(image, k, filter_type)
            # Convert raw hits into the JSON shape the frontend expects.
            processed_results = []
            for result in results:
                item = {
                    "score": result.get("score", 0),
                    "type": result.get("type")
                }
                if result.get("type") == "text":
                    item["text"] = result.get("text", "")
                elif result.get("type") == "image":
                    # Inline the matched image only if it still exists on disk.
                    if "path" in result and os.path.exists(result["path"]):
                        item["image"] = get_image_base64(result["path"])
                        item["filename"] = os.path.basename(result["path"])
                    if "description" in result:
                        item["description"] = result["description"]
                processed_results.append(item)
            stats = retrieval.get_stats()
            return jsonify({
                "success": True,
                "results": processed_results,
                "filter_type": filter_type,
                "debug": {
                    "server_time": datetime.now(timezone.utc).isoformat(),
                    "vector_dimension": stats.get("vector_dimension"),
                    "total_vectors": stats.get("total_vectors"),
                }
            })
        else:
            return jsonify({"success": False, "error": "不支持的文件类型"}), 400
    except Exception as e:
        logger.error(f"图像搜索失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
    finally:
        # Best-effort removal of any temp files created above.
        file_handler.cleanup_all_temp_files()
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve a previously uploaded file from UPLOAD_FOLDER."""
    return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
@app.route('/temp/<filename>')
def temp_file(filename):
    """Alias of /uploads/<filename>; both resolve inside UPLOAD_FOLDER."""
    return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
@app.route('/api/save_index', methods=['POST'])
def save_index():
    """Persist the current index to disk."""
    try:
        retrieval = init_retrieval_system()
        retrieval.save_index()
        stats = retrieval.get_stats()
        payload = {
            "success": True,
            "message": "索引保存成功",
            "debug": {
                "server_time": datetime.now(timezone.utc).isoformat(),
                "vector_dimension": stats.get("vector_dimension"),
                "total_vectors": stats.get("total_vectors"),
            },
        }
        return jsonify(payload)
    except Exception as e:
        logger.error(f"保存索引失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/clear_index', methods=['POST'])
def clear_index():
    """Remove every vector from the index."""
    try:
        retrieval = init_retrieval_system()
        retrieval.clear_index()
        stats = retrieval.get_stats()
        debug = {
            "server_time": datetime.now(timezone.utc).isoformat(),
            "vector_dimension": stats.get("vector_dimension"),
            "total_vectors": stats.get("total_vectors"),
        }
        return jsonify({"success": True, "message": "索引已清空", "debug": debug})
    except Exception as e:
        logger.error(f"清空索引失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/list_items', methods=['GET'])
def list_items():
    """Return every item currently stored in the index."""
    try:
        retrieval = init_retrieval_system()
        items = retrieval.list_items()
        stats = retrieval.get_stats()
        return jsonify({
            "success": True,
            "items": items,
            "debug": {
                "server_time": datetime.now(timezone.utc).isoformat(),
                "vector_dimension": stats.get("vector_dimension"),
                "total_vectors": stats.get("total_vectors"),
            }
        })
    except Exception as e:
        logger.error(f"列出索引项失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/system_info', methods=['GET', 'POST'])
def system_info():
    """Report GPU memory usage and index statistics.

    Deliberately does NOT trigger lazy model initialisation: when the
    retrieval system has not been created yet, retrieval_info is empty.
    """
    try:
        gpu_info = []
        if torch.cuda.is_available():
            gib = 1024 ** 3
            for idx in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(idx)
                gpu_info.append({
                    "id": idx,
                    "name": torch.cuda.get_device_name(idx),
                    "memory_total": props.total_memory / gib,
                    "memory_allocated": torch.cuda.memory_allocated(idx) / gib,
                    "memory_reserved": torch.cuda.memory_reserved(idx) / gib
                })
        retrieval_info = retrieval_system.get_stats() if retrieval_system is not None else {}
        return jsonify({
            "success": True,
            "gpu_info": gpu_info,
            "retrieval_info": retrieval_info,
            "model_path": app.config['MODEL_PATH'],
            "index_path": app.config['INDEX_PATH'],
            "debug": {"server_time": datetime.now(timezone.utc).isoformat()}
        })
    except Exception as e:
        logger.error(f"获取系统信息失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/health', methods=['GET'])
def health():
    """Health probe: model availability, index stats and GPU visibility."""
    try:
        stats = init_retrieval_system().get_stats()
        cuda_ok = torch.cuda.is_available()
        health_info = {
            "model_loaded": True,
            "vector_dimension": stats.get("vector_dimension"),
            "total_vectors": stats.get("total_vectors"),
            "gpu_available": cuda_ok,
            "cuda_device_count": torch.cuda.device_count() if cuda_ok else 0,
            "server_time": datetime.now(timezone.utc).isoformat(),
        }
        return jsonify({"success": True, "health": health_info})
    except Exception as e:
        logger.error(f"健康检查失败: {str(e)}")
        return jsonify({"success": False, "error": str(e), "health": {"model_loaded": False, "server_time": datetime.now(timezone.utc).isoformat()}}), 500
if __name__ == '__main__':
    try:
        # Pre-load the model and index so the first request is not slow.
        init_retrieval_system()
        # Listen on all interfaces; debug disabled.
        app.run(host='0.0.0.0', port=5000, debug=False)
    except Exception as e:
        logger.error(f"启动Web应用失败: {str(e)}")
        sys.exit(1)

View File

@ -1,617 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
多GPU多模态检索系统 - Web应用
专为双GPU部署优化
"""
import os
import json
import time
from flask import Flask, render_template, request, jsonify, send_file, url_for
from werkzeug.utils import secure_filename
from PIL import Image
import base64
import io
import logging
import traceback
import glob
# Reduce CUDA memory fragmentation via the expandable-segments allocator.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)
app.config['SECRET_KEY'] = 'multigpu_multimodal_retrieval_2024'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size
# Upload locations and the set of accepted image extensions.
UPLOAD_FOLDER = 'uploads'
SAMPLE_IMAGES_FOLDER = 'sample_images'
TEXT_DATA_FOLDER = 'text_data'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
# Make sure the folders exist before any request arrives.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)
# Global retrieval-system singleton, created by POST /api/init.
retrieval_system = None
def allowed_file(filename):
    """True when the filename's extension is in ALLOWED_EXTENSIONS (case-insensitive)."""
    _, sep, ext = filename.rpartition('.')
    return sep != '' and ext.lower() in ALLOWED_EXTENSIONS
def image_to_base64(image_path):
    """Read *image_path* and return its contents base64-encoded.

    Best-effort: on any failure the error is logged and None is returned.
    """
    try:
        with open(image_path, "rb") as img_file:
            raw = img_file.read()
    except Exception as e:
        logger.error(f"图片转换失败: {e}")
        return None
    return base64.b64encode(raw).decode('utf-8')
@app.route('/')
def index():
    """Serve the main page."""
    return render_template('index.html')
@app.route('/api/status')
def get_status():
    """Report initialisation state, visible GPU count and model readiness."""
    global retrieval_system
    status = {
        'initialized': retrieval_system is not None,
        'gpu_count': 0,
        'model_loaded': False
    }
    try:
        import torch
        if torch.cuda.is_available():
            status['gpu_count'] = torch.cuda.device_count()
        # Only expose device ids once the model is actually loaded.
        if retrieval_system and retrieval_system.model:
            status['model_loaded'] = True
            status['device_ids'] = retrieval_system.device_ids
    except Exception as e:
        logger.error(f"获取状态失败: {e}")
    return jsonify(status)
@app.route('/api/init', methods=['POST'])
def initialize_system():
    """Create the global MultiGPUMultimodalRetrieval instance."""
    global retrieval_system
    try:
        logger.info("正在初始化多GPU检索系统...")
        # Imported lazily so the server starts even without the GPU stack.
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
        retrieval_system = MultiGPUMultimodalRetrieval()
        if retrieval_system.model is None:
            raise Exception("模型加载失败")
        logger.info("✅ 多GPU系统初始化成功")
        device_ids = retrieval_system.device_ids
        return jsonify({
            'success': True,
            'message': '多GPU系统初始化成功',
            'device_ids': device_ids,
            'gpu_count': len(device_ids)
        })
    except Exception as e:
        error_msg = f"系统初始化失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500
@app.route('/api/search/text_to_text', methods=['POST'])
def search_text_to_text():
    """Text query against the text index."""
    return handle_search('text_to_text')
@app.route('/api/search/text_to_image', methods=['POST'])
def search_text_to_image():
    """Text query against the image index."""
    return handle_search('text_to_image')
@app.route('/api/search/image_to_text', methods=['POST'])
def search_image_to_text():
    """Image query against the text index."""
    return handle_search('image_to_text')
@app.route('/api/search/image_to_image', methods=['POST'])
def search_image_to_image():
    """Image query against the image index."""
    return handle_search('image_to_image')
@app.route('/api/search', methods=['POST'])
def search():
    """Generic search endpoint (legacy compatibility).

    Accepts the search mode as a form field or in a JSON body, defaulting
    to 'text_to_text'.  Fix: the original called request.json.get(...),
    which raises when a form-only request carries no JSON body;
    get_json(silent=True) tolerates its absence.
    """
    payload = request.get_json(silent=True) or {}
    mode = request.form.get('mode') or payload.get('mode', 'text_to_text')
    return handle_search(mode)
def handle_search(mode):
    """Shared implementation behind all /api/search* endpoints.

    *mode* is one of text_to_text, text_to_image, image_to_text,
    image_to_image.  Text modes read the query from form data or a JSON
    body; image modes read an uploaded file.  Returns a Flask response:
    400 for bad input, 500 on internal failure.

    Fix: the original read request.json.get('query') which raises when a
    form-only request has no JSON body; the body is now read once with
    get_json(silent=True).  top_k additionally honors a JSON field
    (backward compatible: form value still wins, default stays 5).
    """
    global retrieval_system
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化,请先点击初始化按钮'
        }), 400
    try:
        # Read the JSON body once, tolerating its absence.
        payload = request.get_json(silent=True) or {}
        top_k = int(request.form.get('top_k') or payload.get('top_k', 5))
        if mode in ['text_to_text', 'text_to_image']:
            query = request.form.get('query') or payload.get('query', '')
            if not query.strip():
                return jsonify({
                    'success': False,
                    'message': '请输入查询文本'
                }), 400
            logger.info(f"执行{mode}搜索: {query}")
            if mode == 'text_to_text':
                results = retrieval_system.search_text_to_text(query, top_k=top_k)
            else:  # text_to_image
                results = retrieval_system.search_text_to_image(query, top_k=top_k)
            return jsonify({
                'success': True,
                'mode': mode,
                'query': query,
                'results': results,
                'result_count': len(results)
            })
        elif mode in ['image_to_text', 'image_to_image']:
            if 'image' not in request.files:
                return jsonify({
                    'success': False,
                    'message': '请上传查询图片'
                }), 400
            file = request.files['image']
            if file.filename == '' or not allowed_file(file.filename):
                return jsonify({
                    'success': False,
                    'message': '请上传有效的图片文件'
                }), 400
            # Persist the query image so it can be echoed back to the UI.
            filename = secure_filename(file.filename)
            timestamp = str(int(time.time()))
            # NOTE(review): "(unknown)" looks like a redacted {filename}
            # placeholder in the original source — confirm against VCS.
            filename = f"query_{timestamp}_(unknown)"
            filepath = os.path.join(UPLOAD_FOLDER, filename)
            file.save(filepath)
            logger.info(f"执行{mode}搜索,图片: (unknown)")
            if mode == 'image_to_text':
                results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
            else:  # image_to_image
                results = retrieval_system.search_image_to_image(filepath, top_k=top_k)
            query_image_b64 = image_to_base64(filepath)
            return jsonify({
                'success': True,
                'mode': mode,
                'query_image': query_image_b64,
                'results': results,
                'result_count': len(results)
            })
        else:
            return jsonify({
                'success': False,
                'message': f'不支持的搜索模式: {mode}'
            }), 400
    except Exception as e:
        error_msg = f"搜索失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500
@app.route('/api/upload/images', methods=['POST'])
def upload_images():
    """Bulk image upload into SAMPLE_IMAGES_FOLDER; silently skips invalid entries."""
    try:
        if 'images' not in request.files:
            return jsonify({'success': False, 'message': '没有选择文件'}), 400
        uploaded_files = []
        for file in request.files.getlist('images'):
            acceptable = file and file.filename != '' and allowed_file(file.filename)
            if not acceptable:
                continue
            filename = secure_filename(file.filename)
            timestamp = str(int(time.time()))
            # NOTE(review): "(unknown)" looks like a redacted {filename}
            # placeholder — confirm against VCS.
            filename = f"{timestamp}_(unknown)"
            filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
            file.save(filepath)
            uploaded_files.append(filename)
        return jsonify({
            'success': True,
            'message': f'成功上传 {len(uploaded_files)} 个图片文件',
            'uploaded_count': len(uploaded_files),
            'files': uploaded_files
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500
@app.route('/api/upload/texts', methods=['POST'])
def upload_texts():
    """Bulk text upload: persists the JSON list under TEXT_DATA_FOLDER."""
    try:
        data = request.get_json()
        if not data or 'texts' not in data:
            return jsonify({'success': False, 'message': '没有提供文本数据'}), 400
        texts = data['texts']
        if not isinstance(texts, list):
            return jsonify({'success': False, 'message': '文本数据格式错误'}), 400
        # One timestamped JSON file per upload batch.
        target = os.path.join(TEXT_DATA_FOLDER, f"texts_{int(time.time())}.json")
        with open(target, 'w', encoding='utf-8') as f:
            json.dump(texts, f, ensure_ascii=False, indent=2)
        return jsonify({
            'success': True,
            'message': f'成功上传 {len(texts)} 条文本',
            'uploaded_count': len(texts)
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500
@app.route('/api/upload/file', methods=['POST'])
def upload_single_file():
    """Upload a single file.

    Expects multipart form data with one file under the ``file`` field.
    The file is stored in ``SAMPLE_IMAGES_FOLDER`` as
    ``<timestamp>_<secure_name>``.

    Returns:
        JSON with the stored filename; 400 when no file was sent or the
        extension is not allowed.
    """
    if 'file' not in request.files:
        return jsonify({'success': False, 'message': '没有选择文件'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'success': False, 'message': '没有选择文件'}), 400
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        timestamp = str(int(time.time()))
        # BUGFIX: the sanitized name was previously dropped, so every upload
        # got the same timestamp-only name. Keep the original sanitized name.
        filename = f"{timestamp}_{filename}"
        filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
        file.save(filepath)
        return jsonify({
            'success': True,
            'message': '文件上传成功',
            'filename': filename
        })
    return jsonify({'success': False, 'message': '不支持的文件类型'}), 400
@app.route('/api/data/list', methods=['GET'])
def list_data():
    """List uploaded data files.

    Scans ``SAMPLE_IMAGES_FOLDER`` for allowed image files and
    ``TEXT_DATA_FOLDER`` for ``.json`` text files, reporting each file's
    name, size (bytes) and last-modified time.

    Returns:
        JSON with ``data.images`` and ``data.texts`` lists; 500 on error.
    """
    def _describe(folder, keep):
        # Build one metadata record per file in `folder` accepted by `keep`.
        records = []
        if os.path.exists(folder):
            for name in os.listdir(folder):
                if keep(name):
                    info = os.stat(os.path.join(folder, name))
                    records.append({
                        'filename': name,
                        'size': info.st_size,
                        'modified': info.st_mtime,
                    })
        return records

    try:
        images = _describe(SAMPLE_IMAGES_FOLDER, allowed_file)
        texts = _describe(TEXT_DATA_FOLDER, lambda name: name.endswith('.json'))
        return jsonify({
            'success': True,
            'data': {'images': images, 'texts': texts}
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'获取数据列表失败: {str(e)}'
        }), 500
@app.route('/api/gpu_status')
def gpu_status():
    """Report GPU status as JSON.

    Delegates to ``smart_gpu_launcher.get_gpu_memory_info``; any failure
    (import error, query error) is reported as a 500 response.
    """
    try:
        # Imported lazily so the web app can start even when the
        # launcher module is unavailable.
        from smart_gpu_launcher import get_gpu_memory_info
        return jsonify({'success': True, 'gpu_info': get_gpu_memory_info()})
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f"获取GPU状态失败: {str(e)}"
        }), 500
def _load_text_entries(text_file):
    """Read one text file and return its non-empty entries.

    A ``.json`` file may hold a list (each element becomes one entry,
    blank entries dropped) or a single value (one entry); any other file
    is read as plain text, one entry per non-blank line.
    """
    if text_file.endswith('.json'):
        with open(text_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if isinstance(data, list):
            return [str(item).strip() for item in data if str(item).strip()]
        return [str(data).strip()]
    with open(text_file, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]
@app.route('/api/build_index', methods=['POST'])
def build_index():
    """Build the retrieval indexes from all uploaded data.

    Scans ``SAMPLE_IMAGES_FOLDER`` for images (by allowed extension) and
    ``TEXT_DATA_FOLDER`` for ``.json``/``.txt`` text files, then builds
    the image and/or text indexes on the global ``retrieval_system``.

    Returns:
        JSON with the indexed counts; 400 when the system is not
        initialized or no data exists, 500 on build errors.
    """
    global retrieval_system
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化'
        }), 400
    try:
        # Collect all candidate image files.
        image_files = []
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            image_files.extend(glob.glob(pattern))
        # Collect text entries from both supported formats; a single
        # unreadable file is logged and skipped rather than failing the build.
        text_data = []
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
        for text_file in text_files:
            try:
                text_data.extend(_load_text_entries(text_file))
            except Exception as e:
                logger.warning(f"读取文本文件失败 {text_file}: {e}")
        if not image_files and not text_data:
            return jsonify({
                'success': False,
                'message': '没有找到可用的图片或文本数据,请先上传数据'
            }), 400
        # Build whichever indexes have data.
        if image_files:
            logger.info(f"构建图片索引,共 {len(image_files)} 张图片")
            retrieval_system.build_image_index_parallel(image_files)
        if text_data:
            logger.info(f"构建文本索引,共 {len(text_data)} 条文本")
            retrieval_system.build_text_index_parallel(text_data)
        return jsonify({
            'success': True,
            'message': f'索引构建完成!图片: {len(image_files)} 张,文本: {len(text_data)}',
            'image_count': len(image_files),
            'text_count': len(text_data)
        })
    except Exception as e:
        logger.error(f"构建索引失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'构建索引失败: {str(e)}'
        }), 500
@app.route('/api/data/stats', methods=['GET'])
def get_data_stats():
    """Return counts of uploaded images and text entries.

    Images are counted by allowed extension in ``SAMPLE_IMAGES_FOLDER``.
    Text entries are counted from both ``.json`` files (as written by
    the text-upload endpoint) and ``.txt`` files — previously only
    ``.txt`` was counted, so normally uploaded texts showed as zero.

    Returns:
        JSON with ``image_count`` and ``text_count``; 500 on error.
    """
    try:
        # Count image files per allowed extension.
        image_count = 0
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            image_count += len(glob.glob(pattern))
        # Count text entries across both supported formats; unreadable
        # files are skipped so stats always succeed on partial data.
        text_count = 0
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
        for text_file in text_files:
            try:
                if text_file.endswith('.json'):
                    with open(text_file, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    if isinstance(data, list):
                        text_count += len([item for item in data if str(item).strip()])
                    else:
                        text_count += 1
                else:
                    with open(text_file, 'r', encoding='utf-8') as f:
                        text_count += len([line for line in f if line.strip()])
            except Exception:
                continue
        return jsonify({
            'success': True,
            'image_count': image_count,
            'text_count': text_count
        })
    except Exception as e:
        logger.error(f"获取数据统计失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'获取统计失败: {str(e)}'
        }), 500
@app.route('/api/data/clear', methods=['POST'])
def clear_data():
    """Delete all uploaded data and reset the in-memory indexes.

    Removes image files (by allowed extension) and text files — both
    ``.json`` (as written by the text-upload endpoint) and ``.txt``;
    previously only ``.txt`` was deleted, so "cleared" JSON texts were
    still picked up by the next index build. Per-file deletion failures
    are logged and skipped.

    Returns:
        JSON confirming the wipe; 500 on unexpected errors.
    """
    try:
        # Remove image files.
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            for file_path in glob.glob(pattern):
                try:
                    os.remove(file_path)
                except Exception as e:
                    logger.warning(f"删除图片文件失败 {file_path}: {e}")
        # Remove text files in both supported formats.
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
        for text_file in text_files:
            try:
                os.remove(text_file)
            except Exception as e:
                logger.warning(f"删除文本文件失败 {text_file}: {e}")
        # Drop the in-memory indexes so stale results cannot be served.
        global retrieval_system
        if retrieval_system:
            retrieval_system.text_index = None
            retrieval_system.image_index = None
            retrieval_system.text_data = []
            retrieval_system.image_data = []
        return jsonify({
            'success': True,
            'message': '数据已清空'
        })
    except Exception as e:
        logger.error(f"清空数据失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'清空数据失败: {str(e)}'
        }), 500
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve an uploaded image file.

    SECURITY: ``filename`` comes straight from the URL, so it is passed
    through ``secure_filename`` before joining to block path traversal
    (e.g. ``../``). Stored files were named via ``secure_filename`` at
    upload time, so legitimate names are unaffected.
    """
    safe_name = secure_filename(filename)
    return send_file(os.path.join(SAMPLE_IMAGES_FOLDER, safe_name))
def print_startup_info():
    """Print the startup banner: endpoints, supported features and GPUs.

    GPU detection is best-effort: any torch/CUDA failure is printed
    rather than raised, so startup never aborts here.
    """
    separator = "=" * 60
    print("🚀 启动多GPU多模态检索Web应用")
    print(separator)
    print("访问地址: http://localhost:5000")
    print("支持功能:")
    for feature_line in (
        " 📝 文搜文 - 文本查找相似文本",
        " 🖼️ 文搜图 - 文本查找相关图片",
        " 📝 图搜文 - 图片查找相关文本",
        " 🖼️ 图搜图 - 图片查找相似图片",
        " 📤 批量上传 - 图片和文本数据管理",
    ):
        print(feature_line)
    print("GPU配置:")
    try:
        import torch
        if not torch.cuda.is_available():
            print(" ❌ CUDA不可用")
        else:
            gpu_count = torch.cuda.device_count()
            print(f" 🖥️ 检测到 {gpu_count} 个GPU")
            for i in range(gpu_count):
                name = torch.cuda.get_device_name(i)
                props = torch.cuda.get_device_properties(i)
                memory_gb = props.total_memory / 1024**3
                print(f" GPU {i}: {name} ({memory_gb:.1f}GB)")
    except Exception as e:
        print(f" ❌ GPU检查失败: {e}")
    print(separator)
def auto_initialize():
    """Initialize the multi-GPU retrieval system at startup.

    Sets the module-global ``retrieval_system``. Returns True on
    success, False on any failure (failures are logged, not raised, so
    the web app still starts without a working model).
    """
    global retrieval_system
    try:
        logger.info("🚀 启动时自动初始化多GPU检索系统...")
        # Imported lazily: the heavy model stack is only pulled in here.
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
        retrieval_system = MultiGPUMultimodalRetrieval()
        if retrieval_system.model is None:
            raise Exception("模型加载失败")
    except Exception as e:
        logger.error(f"❌ 系统自动初始化失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False
    logger.info("✅ 系统自动初始化成功")
    return True
if __name__ == '__main__':
    # Show the banner and GPU diagnostics first.
    print_startup_info()
    # Best-effort model initialization; the server starts regardless.
    auto_initialize()
    # Dev server: listen on all interfaces, handle requests concurrently.
    app.run(host='0.0.0.0', port=5000, debug=False, threaded=True)