base feature

eust-w 2025-08-20 10:01:03 +00:00
parent 1de00fccda
commit 0cd7a4cb41
16 changed files with 2864 additions and 201 deletions

LICENSE

@@ -1,201 +0,0 @@

README.md (new file)

@@ -0,0 +1,170 @@
# Multimodal Retrieval System

A multimodal retrieval system built on the OpenSearch-AI/Ops-MM-embedding-v1-7B model, supporting four retrieval modes: text-to-image, text-to-text, image-to-image, and image-to-text.

## Features

- **Text-to-Text**: search for similar texts with a text query
- **Text-to-Image**: search for relevant images with a text query
- **Image-to-Image**: search for similar images with an image query
- **Image-to-Text**: search for relevant texts with an image query

## Requirements

- Python 3.8+
- CUDA-capable GPU recommended (CPU also supported)
- At least 8 GB of RAM

## Installation
```bash
pip install -r requirements.txt
```
## Quick Start

### 1. Check the GPU environment
```python
from multimodal_retrieval import check_gpu_info
check_gpu_info()
```
### 2. Initialize the system

```python
from multimodal_retrieval import MultimodalRetrieval

# Initialize the retrieval system
retrieval_system = MultimodalRetrieval()
```
### 3. Build indexes

```python
# Build the text index
texts = ["a cute kitten", "beautiful mountain scenery", "a modern city"]
retrieval_system.build_text_index(texts)

# Build the image index
image_paths = ["./images/cat.jpg", "./images/mountain.jpg", "./images/city.jpg"]
retrieval_system.build_image_index(image_paths)
```
### 4. Run retrieval

```python
# Text-to-text
results = retrieval_system.search_text_by_text("kitten", top_k=5)

# Text-to-image
results = retrieval_system.search_images_by_text("animal", top_k=5)

# Image-to-image
results = retrieval_system.search_images_by_image("./query_image.jpg", top_k=5)

# Image-to-text
results = retrieval_system.search_text_by_image("./query_image.jpg", top_k=5)
```
## Running the demo
```bash
python demo.py
```
The demo script automatically:

1. Checks the GPU environment
2. Initializes the multimodal retrieval system
3. Demonstrates all four retrieval modes
4. Prints retrieval results with similarity scores
## Project layout

```
mmeb/
├── multimodal_retrieval.py   # main retrieval system class
├── demo.py                   # demo script
├── requirements.txt          # dependency list
├── README.md                 # project documentation
└── images/                   # image data directory (create it yourself)
```
## API Reference

### MultimodalRetrieval class

#### Constructor parameters

- `model_name`: model name, defaults to "OpenSearch-AI/Ops-MM-embedding-v1-7B"
- `device`: device type, auto-selected by default ("cuda" or "cpu")

#### Main methods

##### `build_text_index(texts, save_path=None)`

Build the text index.

- `texts`: list of texts
- `save_path`: path to save the index (optional)

##### `build_image_index(image_paths, save_path=None)`

Build the image index.

- `image_paths`: list of image paths
- `save_path`: path to save the index (optional)

##### `search_text_by_text(query, top_k=5)`

Text-to-text retrieval.

- `query`: query text
- `top_k`: number of results to return
- Returns: `[(text, similarity score), ...]`

##### `search_images_by_text(query, top_k=5)`

Text-to-image retrieval.

- `query`: query text
- `top_k`: number of results to return
- Returns: `[(image path, similarity score), ...]`

##### `search_images_by_image(query_image, top_k=5)`

Image-to-image retrieval.

- `query_image`: query image path or PIL image
- `top_k`: number of results to return
- Returns: `[(image path, similarity score), ...]`

##### `search_text_by_image(query_image, top_k=5)`

Image-to-text retrieval.

- `query_image`: query image path or PIL image
- `top_k`: number of results to return
- Returns: `[(text, similarity score), ...]`
## Notes

1. **First run**: the model is downloaded automatically on first run, so a network connection is required
2. **Memory**: the 7B-parameter model needs substantial memory; a GPU is recommended
3. **Image formats**: common formats are supported (jpg, png, bmp, gif, etc.)
4. **Batching**: the system batches inputs automatically for efficiency
5. **Index persistence**: indexes can be saved and reloaded to avoid rebuilding
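Point 5 above relies on the files that `build_text_index` / `build_image_index` write when `save_path` is given: a FAISS `.index` file plus a JSON sidecar holding the original texts or paths, so a hit index maps back to its item. A minimal sketch of the JSON side only (no model or FAISS needed; the `idx_text` file name is illustrative):

```python
import json
import tempfile
from pathlib import Path

# Illustrative: build_text_index(texts, save_path="idx") would produce
# "idx_text.index" (FAISS vectors) plus "idx_text.json" (the raw texts).
texts = ["a cute kitten", "beautiful mountain scenery"]

with tempfile.TemporaryDirectory() as tmp:
    sidecar = Path(tmp) / "idx_text.json"
    # Save: mirrors json.dump(data, f, ensure_ascii=False, indent=2)
    sidecar.write_text(json.dumps(texts, ensure_ascii=False, indent=2), encoding="utf-8")
    # Load: a FAISS hit at row i maps back to loaded[i]
    loaded = json.loads(sidecar.read_text(encoding="utf-8"))

print(loaded[0])
```

Reloading both files restores retrieval without re-encoding the corpus.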
## Performance tips

- Use a GPU to accelerate inference
- Choose a reasonable batch size
- Save index files to avoid rebuilding
- Process large datasets in batches
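The last two tips come down to the same stride loop the indexing code uses; a small sketch (the GPU-scaled `batch_size` mirrors the heuristic in the indexing code, and the concrete numbers are illustrative):

```python
def iter_batches(items, batch_size):
    # Yield consecutive slices of at most batch_size items
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

# Mirror of the indexing heuristic: shrink the per-step batch as GPUs increase
num_gpus = 2
batch_size = max(4, 16 // num_gpus)

batches = list(iter_batches(list(range(20)), batch_size))
print(len(batches), batches[-1])
```

Encoding each slice and stacking the results with `np.vstack` reproduces the full embedding matrix without holding every input on the GPU at once.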
## Troubleshooting

### Common issues

1. **CUDA out of memory**: reduce the batch size or fall back to CPU
2. **Model download failure**: check the network connection or use a mirror
3. **Image loading errors**: check image file paths and formats
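For issue 1, one common mitigation (beyond the skip-on-OOM behavior built into index building) is to halve the batch size and retry. A generic sketch, using `MemoryError` as a portable stand-in for `torch.cuda.OutOfMemoryError` so it runs anywhere; the `encode_with_backoff` helper is hypothetical, not part of this project:

```python
def encode_with_backoff(items, encode_fn, batch_size):
    # Halve the batch size on out-of-memory until it succeeds (or reaches 1)
    results = []
    i = 0
    while i < len(items):
        size = batch_size
        while True:
            try:
                results.extend(encode_fn(items[i:i + size]))
                break
            except MemoryError:
                if size == 1:
                    raise
                size //= 2  # back off and retry with a smaller slice
        i += size
    return results

# Fake encoder that "runs out of memory" on batches larger than 3
def fake_encode(batch):
    if len(batch) > 3:
        raise MemoryError
    return [x * 2 for x in batch]

out = encode_with_backoff(list(range(10)), fake_encode, batch_size=8)
print(out)
```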
### Logging

The system logs detailed information, including:

- GPU environment detection results
- model loading progress
- index building status
- retrieval execution details
## License

This project is released under the MIT License.


@@ -0,0 +1,632 @@
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import numpy as np
from PIL import Image
import faiss
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from typing import List, Union, Tuple, Dict
import os
import json
from pathlib import Path
import logging
import gc
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MultiGPUMultimodalRetrieval:
    """Multi-GPU multimodal retrieval system supporting text-to-image, text-to-text, image-to-image, and image-to-text search."""

    def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
                 use_all_gpus: bool = True, gpu_ids: List[int] = None, min_memory_gb=12):
        """
        Initialize the multi-GPU multimodal retrieval system.

        Args:
            model_name: model name
            use_all_gpus: whether to use all available GPUs
            gpu_ids: list of GPU IDs to use
            min_memory_gb: minimum free GPU memory in GB
        """
        self.model_name = model_name
        # Set up GPU devices
        self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
        # Free GPU memory
        self._clear_all_gpu_memory()
        logger.info(f"Loading model onto GPUs: {self.device_ids}")
        # Load the model and processors
        self.model = None
        self.tokenizer = None
        self.processor = None
        self._load_model_multigpu()
        # Initialize indexes
        self.text_index = None
        self.image_index = None
        self.text_data = []
        self.image_data = []
        logger.info("Multi-GPU model loading complete")

    def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb=12):
        """Set up GPU devices."""
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available; cannot use multiple GPUs")
        total_gpus = torch.cuda.device_count()
        logger.info(f"Detected {total_gpus} GPU(s)")
        # Honor CUDA_VISIBLE_DEVICES if it is set
        cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
        if cuda_visible_devices is not None:
            # If CUDA_VISIBLE_DEVICES is set, use the visible GPUs
            visible_gpu_count = len(cuda_visible_devices.split(','))
            self.device_ids = list(range(visible_gpu_count))
            logger.info(f"Using GPUs from CUDA_VISIBLE_DEVICES: {cuda_visible_devices}")
        elif use_all_gpus:
            self.device_ids = self._select_best_gpus(min_memory_gb)
        elif gpu_ids:
            self.device_ids = gpu_ids
        else:
            self.device_ids = [0]
        self.num_gpus = len(self.device_ids)
        self.primary_device = f"cuda:{self.device_ids[0]}"
        logger.info(f"Using GPUs: {self.device_ids}, primary device: {self.primary_device}")

    def _clear_all_gpu_memory(self):
        """Free memory on all GPUs."""
        for gpu_id in self.device_ids:
            torch.cuda.set_device(gpu_id)
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
        logger.info("All GPU memory cleared")
    def _load_model_multigpu(self):
        """Load the model across multiple GPUs."""
        try:
            # Environment variable to reduce memory fragmentation
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
            # Free GPU memory
            self._clear_gpu_memory()
            # First try accelerate's automatic device mapping
            if self.num_gpus > 1:
                # Cap memory at 18 GiB per GPU to leave headroom
                max_memory = {i: "18GiB" for i in self.device_ids}
                logger.info(f"Loading model onto GPUs: {self.device_ids}")
                self.model = AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    max_memory=max_memory,
                    low_cpu_mem_usage=True,
                    offload_folder="./offload"
                )
            else:
                # Single-GPU load
                self.model = AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16,
                    device_map=self.primary_device
                )
            # Load the tokenizer and processor on the primary device
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    trust_remote_code=True
                )
                logger.info("Tokenizer loaded")
            except Exception as e:
                logger.error(f"Failed to load tokenizer: {e}")
                return False
            # Load the processor for image handling
            try:
                self.processor = AutoProcessor.from_pretrained(
                    self.model_name,
                    trust_remote_code=True
                )
                logger.info("Processor loaded")
            except Exception as e:
                logger.warning(f"Failed to load processor: {e}")
                # Fall back to the tokenizer if AutoProcessor fails
                logger.info("Falling back to the tokenizer as processor")
                self.processor = self.tokenizer
            logger.info(f"Model loaded on: {self.model.hf_device_map if hasattr(self.model, 'hf_device_map') else self.primary_device}")
            logger.info("Multi-GPU model loading complete")
            return True
        except Exception as e:
            logger.error(f"Multi-GPU model loading failed: {str(e)}")
            return False

    def _clear_gpu_memory(self):
        """Free GPU memory."""
        for gpu_id in self.device_ids:
            torch.cuda.set_device(gpu_id)
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
        logger.info("GPU memory cleared")
    def _get_gpu_memory_info(self):
        """Query GPU memory usage via nvidia-smi."""
        try:
            import subprocess
            result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv,nounits,noheader'],
                                    capture_output=True, text=True, check=True)
            lines = result.stdout.strip().split('\n')
            gpu_info = []
            for i, line in enumerate(lines):
                used, total = map(int, line.split(', '))
                free = total - used
                gpu_info.append({
                    'gpu_id': i,
                    'used': used,
                    'total': total,
                    'free': free,
                    'usage_percent': (used / total) * 100
                })
            return gpu_info
        except Exception as e:
            logger.warning(f"Could not query GPU memory info: {e}")
            return []

    def _select_best_gpus(self, min_memory_gb=12):
        """Select GPUs with enough free memory."""
        gpu_info = self._get_gpu_memory_info()
        if not gpu_info:
            return list(range(torch.cuda.device_count()))
        # Sort by free memory
        gpu_info.sort(key=lambda x: x['free'], reverse=True)
        # Keep GPUs with enough free memory
        min_memory_mb = min_memory_gb * 1024
        suitable_gpus = []
        for gpu in gpu_info:
            if gpu['free'] >= min_memory_mb:
                suitable_gpus.append(gpu['gpu_id'])
                logger.info(f"GPU {gpu['gpu_id']}: {gpu['free']}MB free (suitable)")
            else:
                logger.warning(f"GPU {gpu['gpu_id']}: {gpu['free']}MB free (insufficient)")
        if not suitable_gpus:
            # If no GPU meets the requirement, pick the one with the most free memory
            logger.warning(f"No GPU has {min_memory_gb}GB free; picking the GPU with the most free memory")
            suitable_gpus = [gpu_info[0]['gpu_id']]
        return suitable_gpus
    def encode_text_batch(self, texts: List[str]) -> np.ndarray:
        """
        Encode a batch of texts into vectors (multi-GPU optimized).

        Args:
            texts: list of texts
        Returns:
            text embeddings
        """
        if not texts:
            return np.array([])
        with torch.no_grad():
            # Preprocess the inputs
            inputs = self.tokenizer(
                text=texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            # Move inputs to the primary device
            inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
            # Forward pass
            outputs = self.model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)
            # Free GPU memory
            del inputs, outputs
            torch.cuda.empty_cache()
        return embeddings.cpu().numpy().astype(np.float32)
    def encode_image_batch(self, images: List[Union[str, Image.Image]]) -> np.ndarray:
        """
        Encode a batch of images into vectors.

        Args:
            images: list of image paths or PIL images
        Returns:
            image embeddings
        """
        if not images:
            return np.array([])
        # Preprocess images
        processed_images = []
        for img in images:
            if isinstance(img, str):
                img = Image.open(img).convert('RGB')
            elif isinstance(img, Image.Image):
                img = img.convert('RGB')
            processed_images.append(img)
        try:
            logger.info(f"Processing {len(processed_images)} image(s)")
            # Generate image embeddings with the multimodal model:
            # pair each image with a short text prompt
            conversations = []
            for i in range(len(processed_images)):
                # Simplified chat format
                conversation = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": processed_images[i]},
                            {"type": "text", "text": "What is in this image?"}
                        ]
                    }
                ]
                conversations.append(conversation)
            try:
                # Try the processor's chat template
                texts = []
                for conv in conversations:
                    text = self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
                    texts.append(text)
                # Process text and images together
                inputs = self.processor(
                    text=texts,
                    images=processed_images,
                    return_tensors="pt",
                    padding=True
                )
                # Move to the GPU
                inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
                # Run the model
                with torch.no_grad():
                    outputs = self.model(**inputs)
                    embeddings = outputs.last_hidden_state.mean(dim=1)
                # Convert to a numpy array
                embeddings = embeddings.cpu().numpy().astype(np.float32)
            except Exception as inner_e:
                # Multimodal path failed: fall back to encoding a plain-text
                # placeholder per image instead of returning zero vectors
                logger.warning(f"Multimodal image encoding failed, falling back to text mode: {inner_e}")
                image_descriptions = ["An image" for _ in processed_images]
                text_inputs = self.processor(
                    text=image_descriptions,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                )
                text_inputs = {k: v.to(self.primary_device) for k, v in text_inputs.items()}
                with torch.no_grad():
                    outputs = self.model(**text_inputs)
                    embeddings = outputs.last_hidden_state.mean(dim=1)
                embeddings = embeddings.cpu().numpy().astype(np.float32)
            logger.info(f"Generated image embeddings: {embeddings.shape}")
            return embeddings
        except Exception as e:
            logger.error(f"Image encoding failed: {e}")
            # Last resort: zero vectors matching the text embedding dimension
            embedding_dim = 3584
            return np.zeros((len(processed_images), embedding_dim), dtype=np.float32)
    def build_text_index_parallel(self, texts: List[str], save_path: str = None):
        """
        Build the text index in batches (multi-GPU optimized).

        Args:
            texts: list of texts
            save_path: path to save the index
        """
        logger.info(f"Building text index for {len(texts)} text(s)")
        # Scale batch size with the number of GPUs
        batch_size = max(4, 16 // self.num_gpus)
        all_embeddings = []
        # Process in batches
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            try:
                embeddings = self.encode_text_batch(batch_texts)
                all_embeddings.append(embeddings)
                # Report progress
                if (i // batch_size + 1) % 10 == 0:
                    logger.info(f"Processed {i + len(batch_texts)}/{len(texts)} texts")
            except torch.cuda.OutOfMemoryError:
                logger.warning(f"GPU out of memory, skipping batch {i}-{i+len(batch_texts)}")
                self._clear_all_gpu_memory()
                continue
            except Exception as e:
                logger.error(f"Error while processing a text batch: {e}")
                continue
        if not all_embeddings:
            raise ValueError("No text was processed successfully")
        # Stack all embeddings
        embeddings = np.vstack(all_embeddings)
        # Build the FAISS index
        dimension = embeddings.shape[1]
        self.text_index = faiss.IndexFlatIP(dimension)
        # L2-normalize so inner product equals cosine similarity
        faiss.normalize_L2(embeddings)
        self.text_index.add(embeddings)
        self.text_data = texts
        if save_path:
            self._save_index(self.text_index, texts, save_path + "_text")
        logger.info("Text index built")
    def build_image_index_parallel(self, image_paths: List[str], save_path: str = None):
        """
        Build the image index in batches (multi-GPU optimized).

        Args:
            image_paths: list of image paths
            save_path: path to save the index
        """
        logger.info(f"Building image index for {len(image_paths)} image(s)")
        # Use smaller batches for images
        batch_size = max(2, 8 // self.num_gpus)
        all_embeddings = []
        for i in range(0, len(image_paths), batch_size):
            batch_images = image_paths[i:i+batch_size]
            try:
                embeddings = self.encode_image_batch(batch_images)
                all_embeddings.append(embeddings)
                # Report progress
                if (i // batch_size + 1) % 5 == 0:
                    logger.info(f"Processed {i + len(batch_images)}/{len(image_paths)} images")
            except torch.cuda.OutOfMemoryError:
                logger.warning(f"GPU out of memory, skipping image batch {i}-{i+len(batch_images)}")
                self._clear_all_gpu_memory()
                continue
            except Exception as e:
                logger.error(f"Error while processing an image batch: {e}")
                continue
        if not all_embeddings:
            raise ValueError("No image was processed successfully")
        embeddings = np.vstack(all_embeddings)
        # Build the FAISS index
        dimension = embeddings.shape[1]
        self.image_index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(embeddings)
        self.image_index.add(embeddings)
        self.image_data = image_paths
        if save_path:
            self._save_index(self.image_index, image_paths, save_path + "_image")
        logger.info("Image index built")
    def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-text: search for similar texts with a text query."""
        if self.text_index is None:
            raise ValueError("Text index not built; call build_text_index_parallel first")
        query_embedding = self.encode_text_batch([query]).astype(np.float32)
        faiss.normalize_L2(query_embedding)
        scores, indices = self.text_index.search(query_embedding, top_k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx != -1:
                results.append((self.text_data[idx], float(score)))
        return results

    def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-image: search for similar images with a text query."""
        if self.image_index is None:
            raise ValueError("Image index not built; call build_image_index_parallel first")
        query_embedding = self.encode_text_batch([query]).astype(np.float32)
        faiss.normalize_L2(query_embedding)
        scores, indices = self.image_index.search(query_embedding, top_k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx != -1:
                results.append((self.image_data[idx], float(score)))
        return results

    def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-image: search for similar images with an image query."""
        if self.image_index is None:
            raise ValueError("Image index not built; call build_image_index_parallel first")
        query_embedding = self.encode_image_batch([query_image]).astype(np.float32)
        faiss.normalize_L2(query_embedding)
        scores, indices = self.image_index.search(query_embedding, top_k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx != -1:
                results.append((self.image_data[idx], float(score)))
        return results

    def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-text: search for similar texts with an image query."""
        if self.text_index is None:
            raise ValueError("Text index not built; call build_text_index_parallel first")
        query_embedding = self.encode_image_batch([query_image]).astype(np.float32)
        faiss.normalize_L2(query_embedding)
        scores, indices = self.text_index.search(query_embedding, top_k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx != -1:
                results.append((self.text_data[idx], float(score)))
        return results
    # Method name aliases for the web app
    def search_text_to_image(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-image (web app alias)."""
        return self.search_images_by_text(query, top_k)

    def search_image_to_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-image (web app alias)."""
        return self.search_images_by_image(query_image, top_k)

    def search_text_to_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Text-to-text (web app alias)."""
        return self.search_text_by_text(query, top_k)

    def search_image_to_text(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
        """Image-to-text (web app alias)."""
        return self.search_text_by_image(query_image, top_k)

    def _save_index(self, index, data, path_prefix):
        """Save an index and its data."""
        faiss.write_index(index, f"{path_prefix}.index")
        with open(f"{path_prefix}.json", 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def load_index(self, path_prefix, index_type="text"):
        """Load a previously saved index."""
        index = faiss.read_index(f"{path_prefix}.index")
        with open(f"{path_prefix}.json", 'r', encoding='utf-8') as f:
            data = json.load(f)
        if index_type == "text":
            self.text_index = index
            self.text_data = data
        else:
            self.image_index = index
            self.image_data = data
        logger.info(f"Loaded {index_type} index")

    def get_gpu_memory_info(self):
        """Report memory usage for all GPUs."""
        memory_info = {}
        for gpu_id in self.device_ids:
            torch.cuda.set_device(gpu_id)
            allocated = torch.cuda.memory_allocated(gpu_id) / 1024**3
            cached = torch.cuda.memory_reserved(gpu_id) / 1024**3
            total = torch.cuda.get_device_properties(gpu_id).total_memory / 1024**3
            free = total - cached
            memory_info[f"GPU_{gpu_id}"] = {
                "total": f"{total:.1f}GB",
                "allocated": f"{allocated:.1f}GB",
                "cached": f"{cached:.1f}GB",
                "free": f"{free:.1f}GB"
            }
        return memory_info
def check_multigpu_info():
    """Print multi-GPU environment information."""
    print("=== Multi-GPU environment ===")
    if not torch.cuda.is_available():
        print("❌ CUDA is not available")
        return
    gpu_count = torch.cuda.device_count()
    print(f"✅ Detected {gpu_count} GPU(s)")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"PyTorch version: {torch.__version__}")
    for i in range(gpu_count):
        gpu_name = torch.cuda.get_device_name(i)
        gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
        print(f"GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)")
    print("=====================")


if __name__ == "__main__":
    # Check the multi-GPU environment
    check_multigpu_info()
    # Example usage
    print("\nInitializing the multi-GPU multimodal retrieval system...")
    try:
        retrieval_system = MultiGPUMultimodalRetrieval()
        print("✅ Multi-GPU system initialized")
        # Show GPU memory usage
        memory_info = retrieval_system.get_gpu_memory_info()
        print("\n📊 GPU memory usage:")
        for gpu, info in memory_info.items():
            print(f"  {gpu}: {info['allocated']} / {info['total']} (used/total)")
        print("\n🚀 Multi-GPU multimodal retrieval system ready")
        print("Supported retrieval modes:")
        print("1. Text-to-text: search_text_by_text()")
        print("2. Text-to-image: search_images_by_text()")
        print("3. Image-to-image: search_images_by_image()")
        print("4. Image-to-text: search_text_by_image()")
    except Exception as e:
        print(f"❌ Multi-GPU system initialization failed: {e}")
        import traceback
        traceback.print_exc()
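The retrieval code above rests on one identity: after `faiss.normalize_L2`, the inner-product scores from `IndexFlatIP` are cosine similarities. A NumPy-only sketch of that equivalence (random vectors stand in for real embeddings; no FAISS or model required):

```python
import numpy as np

rng = np.random.default_rng(0)
db = rng.normal(size=(5, 8)).astype(np.float32)     # stand-in for stored embeddings
query = rng.normal(size=(1, 8)).astype(np.float32)  # stand-in for a query embedding

# L2-normalize rows, as faiss.normalize_L2 does in place
db_n = db / np.linalg.norm(db, axis=1, keepdims=True)
q_n = query / np.linalg.norm(query, axis=1, keepdims=True)

# Inner product of unit vectors == cosine similarity, bounded by [-1, 1]
scores = (db_n @ q_n.T).ravel()

# Top-k ranking, as IndexFlatIP.search would return it
top_k = 3
order = np.argsort(-scores)[:top_k]
print(order, scores[order])
```

Because both the indexed vectors and the query are normalized before `search`, the returned scores can be read directly as cosine similarities.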

ops_mm_embedding_v1.py (new file)

@@ -0,0 +1,309 @@
import math
from typing import List, Optional, TypeAlias, Union

import torch
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from qwen_vl_utils import fetch_image  # provides fetch_image used by _process_images below
from transformers import AutoModelForImageTextToText, AutoProcessor

ImageInput: TypeAlias = Union[Image.Image, List[Image.Image]]
BatchImageInput: TypeAlias = Union[List[Image.Image], List[List[Image.Image]]]


class OpsMMEmbeddingV1(nn.Module):
    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        max_length: Optional[int] = None,
        attn_implementation: Optional[str] = None,
    ):
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.default_instruction = "You are a helpful assistant."
        self.base_model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            attn_implementation=attn_implementation,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
        self.processor.tokenizer.padding_side = "left"
        self.eval()
    def encode_input(self, input):
        hidden_states = self.base_model(**input, return_dict=True, output_hidden_states=True)
        hidden_states = hidden_states.hidden_states[-1]
        pooled_output = self._pooling(hidden_states)
        return pooled_output

    def _pooling(self, last_hidden_state):
        # Last-token pooling (padding side is "left", so the last position is valid), then L2-normalize
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[torch.arange(batch_size), -1, :]
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    def _validate_instructions(
        self,
        texts: Optional[List[str]],
        images: Optional[BatchImageInput],
        instruction: Optional[Union[str, List[str]]],
    ) -> List[str]:
        """Validate and format instructions to match batch size"""
        batch_size = max(len(x) if x is not None else 0 for x in [texts, images])
        if instruction is None:
            return [self.default_instruction] * batch_size
        if isinstance(instruction, str):
            return [instruction] * batch_size
        if isinstance(instruction, list):
            if len(instruction) != batch_size:
                raise ValueError(f"Length of instruction list ({len(instruction)}) must match batch size ({batch_size}) when texts/images are provided")
            return instruction
        raise TypeError("instruction must be str, List[str] or None")

    def _process_images(self, images: ImageInput) -> List[Image.Image]:
        """Convert single image or list of images to processed format"""
        if isinstance(images, (Image.Image, str)):
            return [fetch_image(images)]
        return [fetch_image(i) for i in images]
def embed(
self,
texts: Optional[List[str]] = None,
images: Optional[BatchImageInput] = None,
instruction: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> torch.Tensor:
"""Generate embeddings for text, images, or combined inputs.
Args:
texts: List of text inputs (optional)
images: Can be:
- List[Image.Image]: Single image per input
- List[List[Image.Image]]: Multiple images per input
instruction: Instruction(s) for the model. Can be:
- None: use default instruction
- str: use same instruction for all inputs
- List[str]: per-input instructions (must match batch size)
"""
if texts is None and images is None:
raise ValueError("Either texts or images must be provided")
instructions = self._validate_instructions(texts, images, instruction)
# Determine batch size
batch_size = len(texts) if texts is not None else len(images) # type: ignore
input_texts, input_images = [], []
for i in range(batch_size):
text = texts[i] if texts is not None else None
image = images[i] if images is not None else None
input_str = ""
processed_image = None
if image is not None:
processed_image = self._process_images(image)
input_str += "<|vision_start|><|image_pad|><|vision_end|>" * len(processed_image)
if text is not None:
input_str += text
msg = f"<|im_start|>system\n{instructions[i]}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
input_texts.append(msg)
input_images.append(processed_image)
# Only pass to processor if we actually have images
processed_images = input_images if any(img is not None for img in input_images) else None
inputs = self.processor(
text=input_texts,
images=processed_images,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="pt",
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.inference_mode():
embeddings = self.encode_input(inputs)
return embeddings
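For reference, the chat-template string that `embed` builds for one text-plus-image pair can be reproduced in isolation. The instruction and text below are made-up examples; the template tokens are copied from the code above:

```python
# Reassemble the prompt string exactly as embed() does for one input.
instruction = "Represent the input for retrieval."  # example instruction
text = "a photo of a cat"                           # example text
num_images = 1                                      # one processed image

# One image placeholder per image, followed by the text.
input_str = "<|vision_start|><|image_pad|><|vision_end|>" * num_images + text
msg = (
    f"<|im_start|>system\n{instruction}<|im_end|>\n"
    f"<|im_start|>user\n{input_str}<|im_end|>\n"
    f"<|im_start|>assistant\n<|endoftext|>"
)
print(msg)
```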
def get_text_embeddings(
self,
texts: List[str],
instruction: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> torch.Tensor:
"""Convenience method for text-only embeddings"""
return self.get_fused_embeddings(texts=texts, instruction=instruction, **kwargs)
def get_image_embeddings(
self,
images: BatchImageInput,
instruction: Optional[Union[str, List[str]]] = None,
**kwargs,
) -> torch.Tensor:
"""Convenience method for image-only embeddings.
Args:
images: Can be:
- List[Image.Image]: Single image per input
- List[List[Image.Image]]: Multiple images per input
"""
return self.get_fused_embeddings(images=images, instruction=instruction, **kwargs)
def get_fused_embeddings(
self,
texts: Optional[List[str]] = None,
images: Optional[BatchImageInput] = None,
instruction: Optional[Union[str, List[str]]] = None,
batch_size: int = 8,
show_progress: bool = True,
**kwargs,
) -> torch.Tensor:
"""Batch processing for large collections of texts/images.
Args:
texts: List of text inputs (optional)
images: Can be:
- List[Image.Image]: Single image per input
- List[List[Image.Image]]: Multiple images per input
instruction: Instruction(s) for the model
batch_size: Number of items to process at once
show_progress: Whether to display progress bar
"""
if texts is None and images is None:
raise ValueError("Either texts or images must be provided")
total_items = len(texts) if texts is not None else len(images) # type: ignore
num_batches = math.ceil(total_items / batch_size)
all_embeddings = []
progress = tqdm(total=num_batches, disable=not show_progress, desc="Processing")
for i in range(0, total_items, batch_size):
batch_texts = texts[i : i + batch_size] if texts is not None else None
batch_images = images[i : i + batch_size] if images is not None else None
batch_emb = self.embed(texts=batch_texts, images=batch_images, instruction=instruction)
all_embeddings.append(batch_emb.cpu())
progress.update(1)
progress.close()
return torch.cat(all_embeddings, dim=0).to(self.device)
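The slicing loop in `get_fused_embeddings` covers every item, including a short final batch; the batching arithmetic in isolation:

```python
import math

total_items, batch_size = 10, 4
num_batches = math.ceil(total_items / batch_size)  # ceil(10/4) = 3 batches: 4 + 4 + 2
batches = [list(range(i, min(i + batch_size, total_items)))
           for i in range(0, total_items, batch_size)]
print(num_batches, batches[-1])  # 3 [8, 9]
```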
def forward(self, **inputs) -> torch.Tensor:
"""Alias for encode_input"""
return self.encode_input(inputs)
### Modified from qwen_vl_utils.vision_process.py
import base64
import logging
import math
from io import BytesIO
import requests
IMAGE_FACTOR = 28
MIN_PIXELS = 256 * 28 * 28
MAX_PIXELS = 1280 * 28 * 28
MAX_RATIO = 200
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int | float, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int | float, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
logging.warning(f"Aspect ratio {max(h_bar, w_bar) / min(h_bar, w_bar):.2f} exceeds {MAX_RATIO}; clamping the longer side")
if h_bar > w_bar:
h_bar = w_bar * MAX_RATIO
else:
w_bar = h_bar * MAX_RATIO
return h_bar, w_bar
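Worked numbers for `smart_resize`'s rounding step, re-deriving the rounding helper so the snippet runs standalone:

```python
FACTOR = 28

def round_by_factor(n, f):
    # Closest multiple of f to n, as in the module above.
    return round(n / f) * f

# A 1000x600 input snaps to multiples of 28:
h_bar = max(FACTOR, round_by_factor(1000, FACTOR))  # 36 * 28 = 1008
w_bar = max(FACTOR, round_by_factor(600, FACTOR))   # 21 * 28 = 588
# 1008 * 588 = 592,704 pixels, inside [256*28*28, 1280*28*28],
# so neither the upscale nor the downscale branch fires.
print(h_bar, w_bar)  # 1008 588
```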
def fetch_image(
image: str | Image.Image,
size_factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
) -> Image.Image:
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
image_obj = Image.open(requests.get(image, stream=True, timeout=30).raw)  # type: ignore
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(f"Unrecognized image input; supported inputs are a local path, an http(s) URL, a base64 data URI, or a PIL.Image, got {image}")
image = image_obj.convert("RGB")
width, height = image.size
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
image = image.resize((resized_width, resized_height))
return image
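`fetch_image`'s `data:image` branch splits on `"base64,"` and decodes the remainder; that parsing step in isolation (the payload here is plain text rather than a real image, to keep the sketch dependency-free):

```python
import base64

# Build a data: URI and decode it the way fetch_image does.
uri = "data:image/png;base64," + base64.b64encode(b"hello").decode()
_, base64_data = uri.split("base64,", 1)
payload = base64.b64decode(base64_data)
print(payload)  # b'hello'
```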
###

requirements.txt Normal file

@@ -0,0 +1,12 @@
torch>=2.0.0
torchvision>=0.15.0
transformers>=4.30.0
accelerate>=0.20.0
faiss-cpu>=1.7.4
numpy>=1.21.0
Pillow>=9.0.0
scikit-learn>=1.3.0
tqdm>=4.65.0
flask>=2.3.0
werkzeug>=2.3.0
psutil>=5.9.0

Binary files not shown (4 new images added: 189 KiB, 201 KiB, 144 KiB, 166 KiB)

templates/index.html Normal file

@@ -0,0 +1,971 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>多模态检索系统</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
<style>
:root {
--primary-color: #2563eb;
--secondary-color: #64748b;
--success-color: #059669;
--warning-color: #d97706;
--danger-color: #dc2626;
--bg-gradient: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
body {
background: var(--bg-gradient);
min-height: 100vh;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.main-container {
background: rgba(255, 255, 255, 0.95);
backdrop-filter: blur(10px);
border-radius: 20px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
margin: 20px auto;
max-width: 1200px;
}
.header {
background: linear-gradient(135deg, var(--primary-color), #3b82f6);
color: white;
padding: 2rem;
border-radius: 20px 20px 0 0;
text-align: center;
}
.mode-card {
background: white;
border-radius: 15px;
padding: 1.5rem;
margin-bottom: 1.5rem;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
transition: all 0.3s ease;
border: 2px solid transparent;
}
.mode-card:hover {
transform: translateY(-5px);
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.15);
}
.mode-card.active {
border-color: var(--primary-color);
background: linear-gradient(135deg, #eff6ff, #dbeafe);
}
.mode-icon {
font-size: 2.5rem;
margin-bottom: 1rem;
display: block;
}
.text-to-text { color: #059669; }
.text-to-image { color: #dc2626; }
.image-to-text { color: #d97706; }
.image-to-image { color: #7c3aed; }
.search-input {
border-radius: 12px;
border: 2px solid #e5e7eb;
padding: 12px 16px;
font-size: 16px;
transition: all 0.3s ease;
}
.search-input:focus {
border-color: var(--primary-color);
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
}
.btn-primary {
background: var(--primary-color);
border: none;
border-radius: 12px;
padding: 12px 24px;
font-weight: 600;
transition: all 0.3s ease;
}
.btn-primary:hover {
background: #1d4ed8;
transform: translateY(-2px);
}
.file-upload-area {
border: 3px dashed #d1d5db;
border-radius: 12px;
padding: 3rem;
text-align: center;
transition: all 0.3s ease;
cursor: pointer;
}
.file-upload-area:hover {
border-color: var(--primary-color);
background: rgba(37, 99, 235, 0.05);
}
.file-upload-area.dragover {
border-color: var(--primary-color);
background: rgba(37, 99, 235, 0.1);
}
.result-card {
background: white;
border-radius: 12px;
padding: 1.5rem;
margin-bottom: 1rem;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
border-left: 4px solid var(--primary-color);
}
.result-image {
max-width: 200px;
max-height: 150px;
border-radius: 8px;
object-fit: cover;
}
.score-badge {
background: var(--success-color);
color: white;
padding: 4px 12px;
border-radius: 20px;
font-size: 0.85rem;
font-weight: 600;
}
.loading-spinner {
display: none;
text-align: center;
padding: 2rem;
}
.status-indicator {
position: fixed;
top: 20px;
right: 20px;
z-index: 1000;
}
.fade-in {
animation: fadeIn 0.5s ease-in;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
.query-image {
max-width: 300px;
max-height: 200px;
border-radius: 12px;
object-fit: cover;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
</style>
</head>
<body>
<!-- Status indicator -->
<div class="status-indicator">
<div id="statusBadge" class="badge bg-secondary">
<i class="fas fa-circle-notch fa-spin"></i> 未初始化
</div>
</div>
<div class="container-fluid">
<div class="main-container">
<!-- Header -->
<div class="header">
<h1><i class="fas fa-search"></i> 多模态检索系统</h1>
<p class="mb-0">支持文搜图、文搜文、图搜图、图搜文四种检索模式</p>
</div>
<div class="p-4">
<!-- Re-initialize button -->
<div class="text-center mb-4">
<button id="reinitBtn" class="btn btn-warning">
<i class="fas fa-redo"></i> 重新初始化系统
</button>
</div>
<!-- Retrieval mode selection -->
<div class="row mb-4" id="modeSelection">
<div class="col-md-3">
<div class="mode-card text-center" data-mode="text_to_text">
<i class="fas fa-file-text mode-icon text-to-text"></i>
<h5>文搜文</h5>
<p class="text-muted">文本查找相似文本</p>
</div>
</div>
<div class="col-md-3">
<div class="mode-card text-center" data-mode="text_to_image">
<i class="fas fa-image mode-icon text-to-image"></i>
<h5>文搜图</h5>
<p class="text-muted">文本查找相关图片</p>
</div>
</div>
<div class="col-md-3">
<div class="mode-card text-center" data-mode="image_to_text">
<i class="fas fa-comment mode-icon image-to-text"></i>
<h5>图搜文</h5>
<p class="text-muted">图片查找相关文本</p>
</div>
</div>
<div class="col-md-3">
<div class="mode-card text-center" data-mode="image_to_image">
<i class="fas fa-images mode-icon image-to-image"></i>
<h5>图搜图</h5>
<p class="text-muted">图片查找相似图片</p>
</div>
</div>
</div>
<!-- Data management panel -->
<div class="row mb-4" id="dataManagement">
<div class="col-12">
<div class="card">
<div class="card-header">
<h5><i class="fas fa-database"></i> 数据管理</h5>
<small class="text-muted">上传和管理检索数据库</small>
</div>
<div class="card-body">
<div class="row">
<!-- Batch image upload -->
<div class="col-md-6">
<div class="upload-section">
<h6><i class="fas fa-images text-primary"></i> 批量上传图片</h6>
<div class="file-upload-area" id="batchImageUpload">
<i class="fas fa-cloud-upload-alt fa-2x text-muted mb-2"></i>
<p>拖拽多张图片到此处或点击选择</p>
<input type="file" id="batchImageFiles" multiple accept="image/*" style="display: none;">
<button class="btn btn-outline-primary btn-sm mt-2" onclick="document.getElementById('batchImageFiles').click()">
<i class="fas fa-folder-open"></i> 选择图片
</button>
</div>
<div id="imageUploadProgress" class="mt-2" style="display: none;">
<div class="progress">
<div class="progress-bar" role="progressbar" style="width: 0%"></div>
</div>
<small class="text-muted mt-1 d-block">上传进度: <span id="imageProgressText">0/0</span></small>
</div>
</div>
</div>
<!-- Batch text upload -->
<div class="col-md-6">
<div class="upload-section">
<h6><i class="fas fa-file-text text-success"></i> 批量上传文本</h6>
<div class="mb-3">
<textarea id="batchTextInput" class="form-control" rows="8"
placeholder="请输入文本数据,每行一条文本记录...&#10;例如:&#10;这是第一条文本记录&#10;这是第二条文本记录&#10;这是第三条文本记录"></textarea>
</div>
<div class="d-flex gap-2">
<button id="uploadTextsBtn" class="btn btn-success">
<i class="fas fa-upload"></i> 上传文本
</button>
<button class="btn btn-outline-secondary" onclick="document.getElementById('textFile').click()">
<i class="fas fa-file-import"></i> 从文件导入
</button>
<input type="file" id="textFile" accept=".txt,.csv" style="display: none;">
</div>
</div>
</div>
</div>
<!-- Data stats and management -->
<div class="row mt-4">
<div class="col-md-8">
<div class="d-flex gap-3">
<button id="buildIndexBtn" class="btn btn-warning" disabled>
<i class="fas fa-cogs"></i> 构建索引
</button>
<button id="viewDataBtn" class="btn btn-info">
<i class="fas fa-list"></i> 查看数据
</button>
<button id="clearDataBtn" class="btn btn-danger">
<i class="fas fa-trash"></i> 清空数据
</button>
</div>
</div>
<div class="col-md-4">
<div id="dataStats" class="text-end">
<small class="text-muted">
图片: <span id="imageCount">0</span> 张 |
文本: <span id="textCount">0</span>
</small>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
<!-- Search interface -->
<div id="searchInterface" style="display: none;">
<!-- Text search -->
<div id="textSearch" class="search-panel" style="display: none;">
<div class="row">
<div class="col-md-8">
<input type="text" id="textQuery" class="form-control search-input"
placeholder="请输入搜索文本...">
</div>
<div class="col-md-2">
<select id="textTopK" class="form-select search-input">
<option value="3">Top 3</option>
<option value="5" selected>Top 5</option>
<option value="10">Top 10</option>
</select>
</div>
<div class="col-md-2">
<button id="textSearchBtn" class="btn btn-primary w-100">
<i class="fas fa-search"></i> 搜索
</button>
</div>
</div>
</div>
<!-- Image search -->
<div id="imageSearch" class="search-panel" style="display: none;">
<div class="row">
<div class="col-md-8">
<div class="file-upload-area" id="fileUploadArea">
<i class="fas fa-cloud-upload-alt fa-3x text-muted mb-3"></i>
<h5>拖拽图片到此处或点击选择</h5>
<p class="text-muted">支持 PNG, JPG, JPEG, GIF, BMP, WebP 格式</p>
<input type="file" id="imageFile" accept="image/*" style="display: none;">
</div>
</div>
<div class="col-md-2">
<select id="imageTopK" class="form-select search-input">
<option value="3">Top 3</option>
<option value="5" selected>Top 5</option>
<option value="10">Top 10</option>
</select>
</div>
<div class="col-md-2">
<button id="imageSearchBtn" class="btn btn-primary w-100" disabled>
<i class="fas fa-search"></i> 搜索
</button>
</div>
</div>
</div>
</div>
<!-- Loading spinner -->
<div class="loading-spinner" id="loadingSpinner">
<div class="spinner-border text-primary" role="status">
<span class="visually-hidden">Loading...</span>
</div>
<p class="mt-2">正在搜索中...</p>
</div>
<!-- Search results -->
<div id="searchResults" class="mt-4"></div>
</div>
</div>
</div>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"></script>
<script>
let currentMode = null;
let systemInitialized = false;
// Re-initialize the system
document.getElementById('reinitBtn').addEventListener('click', async function() {
const btn = this;
const originalText = btn.innerHTML;
btn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 重新初始化中...';
btn.disabled = true;
try {
const response = await fetch('/api/init', {
method: 'POST',
headers: {'Content-Type': 'application/json'}
});
const data = await response.json();
if (data.success) {
systemInitialized = true;
document.getElementById('statusBadge').innerHTML =
'<i class="fas fa-check-circle"></i> 已重新初始化';
document.getElementById('statusBadge').className = 'badge bg-success';
showAlert('success', `系统重新初始化成功GPU: ${data.gpu_count} 个`);
} else {
throw new Error(data.message);
}
} catch (error) {
showAlert('danger', '重新初始化失败: ' + error.message);
} finally {
btn.innerHTML = originalText;
btn.disabled = false;
}
});
// Mode selection
document.querySelectorAll('.mode-card').forEach(card => {
card.addEventListener('click', function() {
// Update selected state
document.querySelectorAll('.mode-card').forEach(c => c.classList.remove('active'));
this.classList.add('active');
currentMode = this.dataset.mode;
setupSearchInterface(currentMode);
});
});
// Set up the search interface
function setupSearchInterface(mode) {
document.getElementById('searchInterface').style.display = 'block';
document.getElementById('textSearch').style.display = 'none';
document.getElementById('imageSearch').style.display = 'none';
document.getElementById('searchResults').innerHTML = '';
if (mode === 'text_to_text' || mode === 'text_to_image') {
document.getElementById('textSearch').style.display = 'block';
document.getElementById('textQuery').placeholder =
mode === 'text_to_text' ? '请输入要搜索的文本...' : '请输入要搜索图片的描述...';
} else {
document.getElementById('imageSearch').style.display = 'block';
}
}
// Text search
document.getElementById('textSearchBtn').addEventListener('click', performTextSearch);
document.getElementById('textQuery').addEventListener('keypress', function(e) {
if (e.key === 'Enter') performTextSearch();
});
async function performTextSearch() {
const query = document.getElementById('textQuery').value.trim();
const topK = parseInt(document.getElementById('textTopK').value);
if (!query) {
showAlert('warning', '请输入搜索文本');
return;
}
showLoading(true);
try {
const endpoint = currentMode === 'text_to_text' ? '/api/search/text_to_text' : '/api/search/text_to_image';
const response = await fetch(endpoint, {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({query, top_k: topK})
});
const data = await response.json();
if (data.success) {
displayResults(data, currentMode);
} else {
throw new Error(data.message);
}
} catch (error) {
showAlert('danger', '搜索失败: ' + error.message);
} finally {
showLoading(false);
}
}
// Image upload handling
const fileUploadArea = document.getElementById('fileUploadArea');
const imageFile = document.getElementById('imageFile');
fileUploadArea.addEventListener('click', () => imageFile.click());
fileUploadArea.addEventListener('dragover', handleDragOver);
fileUploadArea.addEventListener('drop', handleDrop);
imageFile.addEventListener('change', handleFileSelect);
function handleDragOver(e) {
e.preventDefault();
fileUploadArea.classList.add('dragover');
}
function handleDrop(e) {
e.preventDefault();
fileUploadArea.classList.remove('dragover');
const files = e.dataTransfer.files;
if (files.length > 0) {
handleFile(files[0]);
}
}
function handleFileSelect(e) {
const file = e.target.files[0];
if (file) handleFile(file);
}
function handleFile(file) {
if (!file.type.startsWith('image/')) {
showAlert('warning', '请选择图片文件');
return;
}
const reader = new FileReader();
reader.onload = function(e) {
fileUploadArea.innerHTML = `
<img src="${e.target.result}" class="query-image mb-3">
<p class="text-success"><i class="fas fa-check"></i> 图片已选择: ${file.name}</p>
`;
document.getElementById('imageSearchBtn').disabled = false;
};
reader.readAsDataURL(file);
}
// Image search
document.getElementById('imageSearchBtn').addEventListener('click', async function() {
const file = imageFile.files[0];
const topK = parseInt(document.getElementById('imageTopK').value);
if (!file) {
showAlert('warning', '请选择图片文件');
return;
}
showLoading(true);
try {
const formData = new FormData();
formData.append('image', file);
formData.append('top_k', topK);
const endpoint = currentMode === 'image_to_text' ? '/api/search/image_to_text' : '/api/search/image_to_image';
const response = await fetch(endpoint, {
method: 'POST',
body: formData
});
const data = await response.json();
if (data.success) {
displayResults(data, currentMode);
} else {
throw new Error(data.message);
}
} catch (error) {
showAlert('danger', '搜索失败: ' + error.message);
} finally {
showLoading(false);
}
});
// Display results
function displayResults(data, mode) {
const resultsContainer = document.getElementById('searchResults');
let html = `
<div class="fade-in">
<div class="d-flex justify-content-between align-items-center mb-3">
<h4><i class="fas fa-search-plus"></i> 搜索结果</h4>
<div>
<span class="badge bg-info">找到 ${data.result_count} 个结果</span>
<span class="badge bg-secondary">耗时 ${data.search_time}s</span>
</div>
</div>
`;
if (data.query_image) {
html += `
<div class="result-card">
<h6><i class="fas fa-image"></i> 查询图片</h6>
<img src="data:image/jpeg;base64,${data.query_image}" class="query-image">
</div>
`;
}
if (data.query) {
html += `
<div class="result-card">
<h6><i class="fas fa-quote-left"></i> 查询文本</h6>
<p class="mb-0">"${data.query}"</p>
</div>
`;
}
data.results.forEach((result, index) => {
html += '<div class="result-card">';
if (mode === 'text_to_image' || mode === 'image_to_image') {
html += `
<div class="row">
<div class="col-md-3">
<img src="data:image/jpeg;base64,${result.image_base64}"
class="result-image" alt="Result ${index + 1}">
</div>
<div class="col-md-9">
<div class="d-flex justify-content-between align-items-start">
<h6><i class="fas fa-image"></i> ${result.filename}</h6>
<span class="score-badge">相似度: ${(result.score * 100).toFixed(1)}%</span>
</div>
<p class="text-muted mb-0">路径: ${result.image_path}</p>
</div>
</div>
`;
} else {
html += `
<div class="d-flex justify-content-between align-items-start">
<div>
<h6><i class="fas fa-file-text"></i> 结果 ${index + 1}</h6>
<p class="mb-0">${result.text || result}</p>
</div>
<span class="score-badge">相似度: ${((result.score || 0.95) * 100).toFixed(1)}%</span>
</div>
`;
}
html += '</div>';
});
html += '</div>';
resultsContainer.innerHTML = html;
}
// Utility functions
function showLoading(show) {
document.getElementById('loadingSpinner').style.display = show ? 'block' : 'none';
}
function showAlert(type, message) {
const alertDiv = document.createElement('div');
alertDiv.className = `alert alert-${type} alert-dismissible fade show`;
alertDiv.innerHTML = `
${message}
<button type="button" class="btn-close" data-bs-dismiss="alert"></button>
`;
document.querySelector('.main-container .p-4').insertBefore(alertDiv, document.querySelector('.main-container .p-4').firstChild);
setTimeout(() => {
if (alertDiv.parentNode) {
alertDiv.remove();
}
}, 5000);
}
// Check system status
async function checkStatus() {
try {
const response = await fetch('/api/status');
const data = await response.json();
if (data.initialized) {
systemInitialized = true;
document.getElementById('statusBadge').innerHTML =
'<i class="fas fa-check-circle"></i> 已初始化';
document.getElementById('statusBadge').className = 'badge bg-success';
} else {
document.getElementById('statusBadge').innerHTML =
'<i class="fas fa-exclamation-triangle"></i> 未初始化';
document.getElementById('statusBadge').className = 'badge bg-warning';
}
} catch (error) {
console.log('Status check failed:', error);
document.getElementById('statusBadge').innerHTML =
'<i class="fas fa-times-circle"></i> 连接失败';
document.getElementById('statusBadge').className = 'badge bg-danger';
}
}
// Check status on page load
checkStatus();
// Bind data-management event handlers
setupDataManagement();
function setupDataManagement() {
// Batch image upload events
const batchImageUpload = document.getElementById('batchImageUpload');
const batchImageFiles = document.getElementById('batchImageFiles');
// Drag-and-drop upload
batchImageUpload.addEventListener('dragover', function(e) {
e.preventDefault();
this.classList.add('dragover');
});
batchImageUpload.addEventListener('dragleave', function(e) {
e.preventDefault();
this.classList.remove('dragover');
});
batchImageUpload.addEventListener('drop', function(e) {
e.preventDefault();
this.classList.remove('dragover');
const files = Array.from(e.dataTransfer.files).filter(file => file.type.startsWith('image/'));
if (files.length > 0) {
uploadBatchImages(files);
}
});
batchImageFiles.addEventListener('change', function(e) {
if (e.target.files.length > 0) {
uploadBatchImages(Array.from(e.target.files));
}
});
// Batch text upload
document.getElementById('uploadTextsBtn').addEventListener('click', function() {
const textData = document.getElementById('batchTextInput').value.trim();
if (textData) {
const texts = textData.split('\n').filter(line => line.trim());
if (texts.length > 0) {
uploadBatchTexts(texts);
} else {
showAlert('warning', '请输入有效的文本数据');
}
} else {
showAlert('warning', '请输入文本数据');
}
});
// Import text from a file
document.getElementById('textFile').addEventListener('change', function(e) {
const file = e.target.files[0];
if (file) {
const reader = new FileReader();
reader.onload = function(e) {
document.getElementById('batchTextInput').value = e.target.result;
};
reader.readAsText(file);
}
});
// Build index
document.getElementById('buildIndexBtn').addEventListener('click', buildIndex);
// View data
document.getElementById('viewDataBtn').addEventListener('click', viewData);
// Clear data
document.getElementById('clearDataBtn').addEventListener('click', clearData);
// Update data stats on init
updateDataStats();
}
// Batch-upload images
async function uploadBatchImages(files) {
const progressDiv = document.getElementById('imageUploadProgress');
const progressBar = progressDiv.querySelector('.progress-bar');
const progressText = document.getElementById('imageProgressText');
progressDiv.style.display = 'block';
progressText.textContent = `0/${files.length}`;
progressBar.style.width = '0%';
const formData = new FormData();
files.forEach(file => {
formData.append('images', file);
});
try {
const response = await fetch('/api/upload/images', {
method: 'POST',
body: formData
});
const data = await response.json();
if (data.success) {
progressBar.style.width = '100%';
progressText.textContent = `${files.length}/${files.length}`;
showAlert('success', `成功上传 ${data.uploaded_count} 张图片`);
updateDataStats();
document.getElementById('buildIndexBtn').disabled = false;
} else {
showAlert('danger', `上传失败: ${data.message}`);
}
} catch (error) {
showAlert('danger', `上传错误: ${error.message}`);
} finally {
setTimeout(() => {
progressDiv.style.display = 'none';
}, 2000);
}
}
// Batch-upload texts
async function uploadBatchTexts(texts) {
try {
const response = await fetch('/api/upload/texts', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ texts: texts })
});
const data = await response.json();
if (data.success) {
showAlert('success', `成功上传 ${data.uploaded_count} 条文本`);
document.getElementById('batchTextInput').value = '';
updateDataStats();
document.getElementById('buildIndexBtn').disabled = false;
} else {
showAlert('danger', `上传失败: ${data.message}`);
}
} catch (error) {
showAlert('danger', `上传错误: ${error.message}`);
}
}
// Build index
async function buildIndex() {
const btn = document.getElementById('buildIndexBtn');
const originalText = btn.innerHTML;
btn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 构建中...';
btn.disabled = true;
try {
const response = await fetch('/api/build_index', {
method: 'POST'
});
const data = await response.json();
if (data.success) {
showAlert('success', '索引构建完成!现在可以进行搜索了');
} else {
showAlert('danger', `索引构建失败: ${data.message}`);
}
} catch (error) {
showAlert('danger', `构建错误: ${error.message}`);
} finally {
btn.innerHTML = originalText;
btn.disabled = false;
}
}
// View data
async function viewData() {
try {
const response = await fetch('/api/data/list');
const data = await response.json();
if (data.success) {
let content = '<div class="row">';
// Show image data
if (data.images && data.images.length > 0) {
content += '<div class="col-md-6"><h6>图片数据 (' + data.images.length + ')</h6>';
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
data.images.forEach(img => {
content += `<div class="list-group-item d-flex justify-content-between align-items-center">
<span>${img}</span>
<img src="/uploads/${img}" class="img-thumbnail" style="width: 50px; height: 50px; object-fit: cover;">
</div>`;
});
content += '</div></div>';
}
// Show text data
if (data.texts && data.texts.length > 0) {
content += '<div class="col-md-6"><h6>文本数据 (' + data.texts.length + ')</h6>';
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
data.texts.forEach((text, index) => {
const shortText = text.length > 50 ? text.substring(0, 50) + '...' : text;
content += `<div class="list-group-item">
<small class="text-muted">#${index + 1}</small><br>
${shortText}
</div>`;
});
content += '</div></div>';
}
content += '</div>';
showModal('数据列表', content);
} else {
showAlert('danger', `获取数据失败: ${data.message}`);
}
} catch (error) {
showAlert('danger', `获取数据错误: ${error.message}`);
}
}
// Clear data
async function clearData() {
if (!confirm('确定要清空所有数据吗?此操作不可恢复!')) {
return;
}
try {
const response = await fetch('/api/data/clear', {
method: 'POST'
});
const data = await response.json();
if (data.success) {
showAlert('success', '数据已清空');
updateDataStats();
document.getElementById('buildIndexBtn').disabled = true;
} else {
showAlert('danger', `清空失败: ${data.message}`);
}
} catch (error) {
showAlert('danger', `清空错误: ${error.message}`);
}
}
// Update data stats
async function updateDataStats() {
try {
const response = await fetch('/api/data/stats');
const data = await response.json();
if (data.success) {
document.getElementById('imageCount').textContent = data.image_count || 0;
document.getElementById('textCount').textContent = data.text_count || 0;
}
} catch (error) {
console.log('获取数据统计失败:', error);
}
}
// Show modal
function showModal(title, content) {
const modalHtml = `
<div class="modal fade" id="dataModal" tabindex="-1">
<div class="modal-dialog modal-lg">
<div class="modal-content">
<div class="modal-header">
<h5 class="modal-title">${title}</h5>
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
</div>
<div class="modal-body">
${content}
</div>
</div>
</div>
</div>
`;
// Remove any existing modal
const existingModal = document.getElementById('dataModal');
if (existingModal) {
existingModal.remove();
}
document.body.insertAdjacentHTML('beforeend', modalHtml);
const modal = new bootstrap.Modal(document.getElementById('dataModal'));
modal.show();
}
</script>
</body>
</html>
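The front end above talks to a small JSON API. From the `fetch` calls in the script, the text-search request and response shapes look like the following. This is a sketch: field names are read off the client code (`query`, `top_k`, `success`, `result_count`, `search_time`, `results[].text`, `results[].score`); the server's actual behavior is assumed.

```python
import json

# Body that performTextSearch() POSTs to /api/search/text_to_text
request_body = json.dumps({"query": "小动物", "top_k": 5})

# Shape of a successful response, as consumed by displayResults()
response = {
    "success": True,
    "result_count": 1,
    "search_time": 0.12,
    "query": "小动物",
    "results": [{"text": "一只可爱的小猫", "score": 0.87}],
}
print(request_body)
```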

test_all_retrieval_modes.py Normal file

@@ -0,0 +1,104 @@
#!/usr/bin/env python3
"""
Test all four multimodal retrieval modes
"""
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
import numpy as np
from PIL import Image
import os
def test_all_retrieval_modes():
print('正在初始化多GPU多模态检索系统...')
retrieval = MultiGPUMultimodalRetrieval()
# Prepare test data
test_texts = [
"一只可爱的小猫",
"美丽的风景照片",
"现代建筑设计",
"colorful flowers in garden"
]
test_images = [
'sample_images/1755677101_1__.jpg',
'sample_images/1755677101_2__.jpg',
'sample_images/1755677101_3__.jpg',
'sample_images/1755677101_4__.jpg'
]
# Verify that the test images exist
existing_images = [img for img in test_images if os.path.exists(img)]
if not existing_images:
print("❌ 没有找到测试图像文件")
return
print(f"找到 {len(existing_images)} 张测试图像")
try:
# 1. Build the text index
print('\n=== 构建文本索引 ===')
retrieval.build_text_index_parallel(test_texts)
print('✅ 文本索引构建完成')
# 2. Build the image index
print('\n=== 构建图像索引 ===')
retrieval.build_image_index_parallel(existing_images)
print('✅ 图像索引构建完成')
# 3. Test text-to-text retrieval
print('\n=== 测试文本到文本检索 ===')
query = "小动物"
results = retrieval.search_text_by_text(query, top_k=3)
print(f'查询: "{query}"')
for i, (text, score) in enumerate(results):
print(f' {i+1}. {text} (相似度: {score:.4f})')
# 4. Test text-to-image retrieval
print('\n=== 测试文本到图像检索 ===')
query = "beautiful image"
results = retrieval.search_images_by_text(query, top_k=3)
print(f'查询: "{query}"')
for i, (image_path, score) in enumerate(results):
print(f' {i+1}. {image_path} (相似度: {score:.4f})')
# 5. Test image-to-text retrieval
print('\n=== 测试图像到文本检索 ===')
query_image = existing_images[0]
results = retrieval.search_text_by_image(query_image, top_k=3)
print(f'查询图像: {query_image}')
for i, (text, score) in enumerate(results):
print(f' {i+1}. {text} (相似度: {score:.4f})')
# 6. 测试图像到图像检索
print('\n=== 测试图像到图像检索 ===')
query_image = existing_images[0]
results = retrieval.search_images_by_image(query_image, top_k=3)
print(f'查询图像: {query_image}')
for i, (image_path, score) in enumerate(results):
print(f' {i+1}. {image_path} (相似度: {score:.4f})')
print('\n✅ 所有四种检索模式测试完成!')
# 7. 测试Web应用兼容的方法名
print('\n=== 测试Web应用兼容方法 ===')
try:
results = retrieval.search_text_to_image("test query", top_k=2)
print('✅ search_text_to_image 方法正常')
results = retrieval.search_image_to_text(existing_images[0], top_k=2)
print('✅ search_image_to_text 方法正常')
results = retrieval.search_image_to_image(existing_images[0], top_k=2)
print('✅ search_image_to_image 方法正常')
except Exception as e:
print(f'❌ Web应用兼容方法测试失败: {e}')
except Exception as e:
print(f'❌ 测试过程中出现错误: {e}')
import traceback
traceback.print_exc()
if __name__ == "__main__":
test_all_retrieval_modes()
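The index methods exercised above (`build_*_index_parallel`, `search_*_by_*`) come from `MultiGPUMultimodalRetrieval`, whose internals are not shown in this commit. The ranking they are expected to perform reduces to cosine similarity over embedding matrices, which can be sketched in plain numpy; the function name and toy data below are illustrative, not part of the real class:

```python
import numpy as np

def top_k_by_cosine(query_vec, index_matrix, top_k=3):
    """Rank index rows by cosine similarity to the query.

    query_vec: (d,) array; index_matrix: (n, d) array.
    Returns a list of (row_index, score), best first.
    """
    # L2-normalize so that a dot product equals cosine similarity
    q = query_vec / np.linalg.norm(query_vec)
    m = index_matrix / np.linalg.norm(index_matrix, axis=1, keepdims=True)
    scores = m @ q
    order = np.argsort(-scores)[:top_k]
    return [(int(i), float(scores[i])) for i in order]

# Toy example: three 2-D "embeddings"; the query is closest to row 0
index = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
results = top_k_by_cosine(np.array([0.9, 0.1]), index, top_k=2)
```

With a real system the rows would be CLIP-style text or image embeddings, and a library such as FAISS would replace the brute-force matrix product.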

test_image_encoding.py (new file, +49 lines)
#!/usr/bin/env python3
"""
Test the image encoding functionality.
"""
import numpy as np

from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval


def test_image_encoding():
    print('Initializing multi-GPU multimodal retrieval system...')
    retrieval = MultiGPUMultimodalRetrieval()

    # Text encoding
    print('Testing text encoding...')
    text_embeddings = retrieval.encode_text_batch(['这是一个测试文本'])  # "this is a test text"
    print(f'Text embedding shape: {text_embeddings.shape}')
    print(f'Text embedding dtype: {text_embeddings.dtype}')

    # Image encoding
    print('Testing image encoding...')
    test_images = ['sample_images/1755677101_1__.jpg']
    image_embeddings = retrieval.encode_image_batch(test_images)
    print(f'Image embedding shape: {image_embeddings.shape}')
    print(f'Image embedding dtype: {image_embeddings.dtype}')

    # Encoding the same image twice should yield (nearly) identical embeddings
    print('Testing embedding consistency...')
    image_embeddings2 = retrieval.encode_image_batch(test_images)
    consistency = np.allclose(image_embeddings, image_embeddings2, rtol=1e-5)
    print(f'Same-image embedding consistency: {consistency}')

    # Different images should produce different embeddings
    print('Testing cross-image embedding difference...')
    test_images2 = ['sample_images/1755677101_2__.jpg']
    image_embeddings3 = retrieval.encode_image_batch(test_images2)
    similarity = np.dot(image_embeddings[0], image_embeddings3[0]) / (
        np.linalg.norm(image_embeddings[0]) * np.linalg.norm(image_embeddings3[0]))
    print(f'Cross-image cosine similarity: {similarity:.4f}')

    # Text and image embeddings must share a dimension for cross-modal search
    if text_embeddings.shape[1] == image_embeddings.shape[1]:
        print('✅ Text and image embedding dimensions match')
    else:
        print('❌ Text and image embedding dimensions differ')

    print('Test complete!')


if __name__ == "__main__":
    test_image_encoding()
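The cross-image similarity above is computed inline; the same expression can be factored into a small helper, which makes its properties (1.0 for identical vectors, 0.0 for orthogonal ones, invariance to scaling) easy to check. A minimal sketch:

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two 1-D vectors."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

same = cosine_similarity([1.0, 2.0, 2.0], [1.0, 2.0, 2.0])  # identical -> 1.0
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])      # orthogonal -> 0.0
scaled = cosine_similarity([1.0, 2.0], [2.0, 4.0])          # scale-invariant -> 1.0
```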

(Two binary image files added, not shown: 172 KiB and 201 KiB.)

web_app_multigpu.py (new file, +617 lines)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multi-GPU multimodal retrieval system - web application.
Optimized for dual-GPU deployment.
"""
import os
import json
import time
from flask import Flask, render_template, request, jsonify, send_file, url_for
from werkzeug.utils import secure_filename
from PIL import Image
import base64
import io
import logging
import traceback
import glob

# Environment variable to reduce GPU memory fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
app.config['SECRET_KEY'] = 'multigpu_multimodal_retrieval_2024'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max upload size

# Upload folders
UPLOAD_FOLDER = 'uploads'
SAMPLE_IMAGES_FOLDER = 'sample_images'
TEXT_DATA_FOLDER = 'text_data'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}

# Make sure the folders exist
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)

# Global retrieval system instance
retrieval_system = None

def allowed_file(filename):
    """Check whether the file extension is allowed."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

def image_to_base64(image_path):
    """Encode an image file as base64."""
    try:
        with open(image_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')
    except Exception as e:
        logger.error(f"Image conversion failed: {e}")
        return None
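`image_to_base64` returns an ASCII-safe string that the frontend can embed in a `data:` URL; decoding it recovers the original bytes exactly. A minimal round-trip sketch (the fake PNG payload is illustrative):

```python
import base64

payload = b"\x89PNG\r\n\x1a\n fake image bytes"
encoded = base64.b64encode(payload).decode('utf-8')  # ASCII-safe string
decoded = base64.b64decode(encoded)                  # round-trips exactly
data_url = f"data:image/png;base64,{encoded}"        # what a browser can render
```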
@app.route('/')
def index():
    """Home page."""
    return render_template('index.html')

@app.route('/api/status')
def get_status():
    """Report system status."""
    global retrieval_system
    status = {
        'initialized': retrieval_system is not None,
        'gpu_count': 0,
        'model_loaded': False
    }
    try:
        import torch
        if torch.cuda.is_available():
            status['gpu_count'] = torch.cuda.device_count()
        if retrieval_system and retrieval_system.model:
            status['model_loaded'] = True
            status['device_ids'] = retrieval_system.device_ids
    except Exception as e:
        logger.error(f"Failed to get status: {e}")
    return jsonify(status)

@app.route('/api/init', methods=['POST'])
def initialize_system():
    """Initialize the multi-GPU retrieval system."""
    global retrieval_system
    try:
        logger.info("Initializing multi-GPU retrieval system...")
        # Import the multi-GPU retrieval system
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
        # Initialize
        retrieval_system = MultiGPUMultimodalRetrieval()
        if retrieval_system.model is None:
            raise Exception("Model loading failed")
        logger.info("✅ Multi-GPU system initialized")
        return jsonify({
            'success': True,
            'message': 'Multi-GPU system initialized successfully',
            'device_ids': retrieval_system.device_ids,
            'gpu_count': len(retrieval_system.device_ids)
        })
    except Exception as e:
        error_msg = f"System initialization failed: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500

@app.route('/api/search/text_to_text', methods=['POST'])
def search_text_to_text():
    """Text-to-text search."""
    return handle_search('text_to_text')

@app.route('/api/search/text_to_image', methods=['POST'])
def search_text_to_image():
    """Text-to-image search."""
    return handle_search('text_to_image')

@app.route('/api/search/image_to_text', methods=['POST'])
def search_image_to_text():
    """Image-to-text search."""
    return handle_search('image_to_text')

@app.route('/api/search/image_to_image', methods=['POST'])
def search_image_to_image():
    """Image-to-image search."""
    return handle_search('image_to_image')

@app.route('/api/search', methods=['POST'])
def search():
    """Generic search endpoint (backwards compatible)."""
    # request.json is unavailable for form posts, so guard with get_json(silent=True)
    mode = request.form.get('mode') or (request.get_json(silent=True) or {}).get('mode', 'text_to_text')
    return handle_search(mode)
def handle_search(mode):
    """Shared handler for all search requests."""
    global retrieval_system
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': 'System not initialized; click the init button first'
        }), 400
    try:
        top_k = int(request.form.get('top_k', 5))
        if mode in ['text_to_text', 'text_to_image']:
            # Text query (request.json is unavailable for form posts, so guard)
            query = request.form.get('query') or (request.get_json(silent=True) or {}).get('query', '')
            if not query.strip():
                return jsonify({
                    'success': False,
                    'message': 'Please enter a query text'
                }), 400
            logger.info(f"Running {mode} search: {query}")
            # Run the search
            if mode == 'text_to_text':
                results = retrieval_system.search_text_to_text(query, top_k=top_k)
            else:  # text_to_image
                results = retrieval_system.search_text_to_image(query, top_k=top_k)
            return jsonify({
                'success': True,
                'mode': mode,
                'query': query,
                'results': results,
                'result_count': len(results)
            })
        elif mode in ['image_to_text', 'image_to_image']:
            # Image query
            if 'image' not in request.files:
                return jsonify({
                    'success': False,
                    'message': 'Please upload a query image'
                }), 400
            file = request.files['image']
            if file.filename == '' or not allowed_file(file.filename):
                return jsonify({
                    'success': False,
                    'message': 'Please upload a valid image file'
                }), 400
            # Save the uploaded image
            filename = secure_filename(file.filename)
            timestamp = str(int(time.time()))
            filename = f"query_{timestamp}_{filename}"
            filepath = os.path.join(UPLOAD_FOLDER, filename)
            file.save(filepath)
            logger.info(f"Running {mode} search, image: {filename}")
            # Run the search
            if mode == 'image_to_text':
                results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
            else:  # image_to_image
                results = retrieval_system.search_image_to_image(filepath, top_k=top_k)
            # Encode the query image as base64 for the response
            query_image_b64 = image_to_base64(filepath)
            return jsonify({
                'success': True,
                'mode': mode,
                'query_image': query_image_b64,
                'results': results,
                'result_count': len(results)
            })
        else:
            return jsonify({
                'success': False,
                'message': f'Unsupported search mode: {mode}'
            }), 400
    except Exception as e:
        error_msg = f"Search failed: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500
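A client of these endpoints can rely on the response contract above (`success`, `mode`, `results`, `result_count`, and `message` on failure). A minimal sketch of a client-side validator; the function name and sample payload are illustrative, not part of this app:

```python
def validate_search_response(payload):
    """Check a /api/search JSON response against the contract used above."""
    if not payload.get('success'):
        # Failure responses carry a human-readable 'message'
        raise ValueError(payload.get('message', 'search failed'))
    results = payload.get('results', [])
    if payload.get('result_count') != len(results):
        raise ValueError('result_count does not match results length')
    return results

ok = validate_search_response({
    'success': True, 'mode': 'text_to_image',
    'results': [('sample_images/a.jpg', 0.81)], 'result_count': 1,
})
```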
@app.route('/api/upload/images', methods=['POST'])
def upload_images():
    """Batch-upload images."""
    try:
        uploaded_files = []
        if 'images' not in request.files:
            return jsonify({'success': False, 'message': 'No files selected'}), 400
        files = request.files.getlist('images')
        for file in files:
            if file and file.filename != '' and allowed_file(file.filename):
                filename = secure_filename(file.filename)
                timestamp = str(int(time.time()))
                filename = f"{timestamp}_{filename}"
                filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
                file.save(filepath)
                uploaded_files.append(filename)
        return jsonify({
            'success': True,
            'message': f'Uploaded {len(uploaded_files)} image files',
            'uploaded_count': len(uploaded_files),
            'files': uploaded_files
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'Upload failed: {str(e)}'
        }), 500

@app.route('/api/upload/texts', methods=['POST'])
def upload_texts():
    """Batch-upload text data."""
    try:
        data = request.get_json()
        if not data or 'texts' not in data:
            return jsonify({'success': False, 'message': 'No text data provided'}), 400
        texts = data['texts']
        if not isinstance(texts, list):
            return jsonify({'success': False, 'message': 'Text data must be a list'}), 400
        # Persist the text data to a file
        timestamp = str(int(time.time()))
        filename = f"texts_{timestamp}.json"
        filepath = os.path.join(TEXT_DATA_FOLDER, filename)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(texts, f, ensure_ascii=False, indent=2)
        return jsonify({
            'success': True,
            'message': f'Uploaded {len(texts)} texts',
            'uploaded_count': len(texts)
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'Upload failed: {str(e)}'
        }), 500

@app.route('/api/upload/file', methods=['POST'])
def upload_single_file():
    """Upload a single file."""
    if 'file' not in request.files:
        return jsonify({'success': False, 'message': 'No file selected'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'success': False, 'message': 'No file selected'}), 400
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        timestamp = str(int(time.time()))
        filename = f"{timestamp}_{filename}"
        filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
        file.save(filepath)
        return jsonify({
            'success': True,
            'message': 'File uploaded successfully',
            'filename': filename
        })
    return jsonify({'success': False, 'message': 'Unsupported file type'}), 400
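`upload_texts` writes the uploaded list to `texts_<timestamp>.json` with `ensure_ascii=False` so that non-ASCII text stays readable on disk, and `build_index` later reads it back with `json.load`. A self-contained sketch of that round trip, using a temporary directory in place of `TEXT_DATA_FOLDER`:

```python
import json
import os
import tempfile

texts = ["一只可爱的小猫", "colorful flowers in garden"]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "texts_0.json")
    # Same write path as upload_texts: keep non-ASCII readable on disk
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(texts, f, ensure_ascii=False, indent=2)
    # Same read path as build_index
    with open(path, 'r', encoding='utf-8') as f:
        loaded = json.load(f)
```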
@app.route('/api/data/list', methods=['GET'])
def list_data():
    """List uploaded data."""
    try:
        # List image files
        images = []
        if os.path.exists(SAMPLE_IMAGES_FOLDER):
            for filename in os.listdir(SAMPLE_IMAGES_FOLDER):
                if allowed_file(filename):
                    filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
                    stat = os.stat(filepath)
                    images.append({
                        'filename': filename,
                        'size': stat.st_size,
                        'modified': stat.st_mtime
                    })
        # List text files
        texts = []
        if os.path.exists(TEXT_DATA_FOLDER):
            for filename in os.listdir(TEXT_DATA_FOLDER):
                if filename.endswith('.json'):
                    filepath = os.path.join(TEXT_DATA_FOLDER, filename)
                    stat = os.stat(filepath)
                    texts.append({
                        'filename': filename,
                        'size': stat.st_size,
                        'modified': stat.st_mtime
                    })
        return jsonify({
            'success': True,
            'data': {
                'images': images,
                'texts': texts
            }
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'Failed to list data: {str(e)}'
        }), 500

@app.route('/api/gpu_status')
def gpu_status():
    """Report GPU status."""
    try:
        from smart_gpu_launcher import get_gpu_memory_info
        gpu_info = get_gpu_memory_info()
        return jsonify({
            'success': True,
            'gpu_info': gpu_info
        })
    except Exception as e:
        return jsonify({
            'success': False,
            'message': f"Failed to get GPU status: {str(e)}"
        }), 500
@app.route('/api/build_index', methods=['POST'])
def build_index():
    """Build the retrieval indexes."""
    global retrieval_system
    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': 'System not initialized'
        }), 400
    try:
        # Collect all image and text files
        image_files = []
        text_data = []
        # Scan for image files
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            image_files.extend(glob.glob(pattern))
        # Read text files (.json and .txt are supported)
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
        for text_file in text_files:
            try:
                if text_file.endswith('.json'):
                    # JSON-formatted text data
                    with open(text_file, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    if isinstance(data, list):
                        text_data.extend([str(item).strip() for item in data if str(item).strip()])
                    else:
                        text_data.append(str(data).strip())
                else:
                    # Plain-text data, one entry per line
                    with open(text_file, 'r', encoding='utf-8') as f:
                        lines = [line.strip() for line in f.readlines() if line.strip()]
                    text_data.extend(lines)
            except Exception as e:
                logger.warning(f"Failed to read text file {text_file}: {e}")
        # Make sure there is data to index
        if not image_files and not text_data:
            return jsonify({
                'success': False,
                'message': 'No usable image or text data found; upload data first'
            }), 400
        # Build the indexes
        if image_files:
            logger.info(f"Building image index with {len(image_files)} images")
            retrieval_system.build_image_index_parallel(image_files)
        if text_data:
            logger.info(f"Building text index with {len(text_data)} texts")
            retrieval_system.build_text_index_parallel(text_data)
        return jsonify({
            'success': True,
            'message': f'Index build complete! Images: {len(image_files)}, texts: {len(text_data)}',
            'image_count': len(image_files),
            'text_count': len(text_data)
        })
    except Exception as e:
        logger.error(f"Index build failed: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'Index build failed: {str(e)}'
        }), 500
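`build_index` scans one glob pattern per allowed extension. One caveat worth knowing: on POSIX systems `glob` matching is case-sensitive, so a file named `PHOTO.JPG` would be skipped even though `allowed_file` (which lowercases the extension) would accept it. A self-contained sketch of the scan against a temporary directory:

```python
import glob
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    for name in ('a.jpg', 'b.png', 'c.JPG', 'notes.txt'):
        open(os.path.join(tmp, name), 'w').close()
    exts = {'png', 'jpg', 'jpeg'}
    matched = []
    for ext in exts:
        matched.extend(glob.glob(os.path.join(tmp, f"*.{ext}")))
    # On POSIX, glob is case-sensitive: 'c.JPG' is not matched by '*.jpg'
    found = sorted(os.path.basename(p) for p in matched)
```

Adding the uppercase variants to the pattern set (or normalizing filenames at upload time, as the upload routes here effectively do via their `{timestamp}_` prefix and `secure_filename`) avoids the mismatch.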
@app.route('/api/data/stats', methods=['GET'])
def get_data_stats():
    """Report data statistics."""
    try:
        # Count image files
        image_count = 0
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            image_count += len(glob.glob(pattern))
        # Count text entries: uploads are stored as .json lists, and .txt
        # files (one entry per line) are also supported by build_index
        text_count = 0
        for text_file in glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")):
            try:
                with open(text_file, 'r', encoding='utf-8') as f:
                    text_count += len([line for line in f if line.strip()])
            except Exception:
                continue
        for text_file in glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json")):
            try:
                with open(text_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                text_count += len(data) if isinstance(data, list) else 1
            except Exception:
                continue
        return jsonify({
            'success': True,
            'image_count': image_count,
            'text_count': text_count
        })
    except Exception as e:
        logger.error(f"Failed to get data stats: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'Failed to get stats: {str(e)}'
        }), 500

@app.route('/api/data/clear', methods=['POST'])
def clear_data():
    """Clear all uploaded data."""
    try:
        # Remove image files
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            for file_path in glob.glob(pattern):
                try:
                    os.remove(file_path)
                except Exception as e:
                    logger.warning(f"Failed to delete image file {file_path}: {e}")
        # Remove text files (both .txt and the .json files written by uploads)
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json")))
        for text_file in text_files:
            try:
                os.remove(text_file)
            except Exception as e:
                logger.warning(f"Failed to delete text file {text_file}: {e}")
        # Reset the in-memory indexes
        global retrieval_system
        if retrieval_system:
            retrieval_system.text_index = None
            retrieval_system.image_index = None
            retrieval_system.text_data = []
            retrieval_system.image_data = []
        return jsonify({
            'success': True,
            'message': 'Data cleared'
        })
    except Exception as e:
        logger.error(f"Failed to clear data: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'Failed to clear data: {str(e)}'
        }), 500
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve uploaded files."""
    return send_file(os.path.join(SAMPLE_IMAGES_FOLDER, filename))

def print_startup_info():
    """Print startup information."""
    print("🚀 Starting the multi-GPU multimodal retrieval web app")
    print("=" * 60)
    print("URL: http://localhost:5000")
    print("Features:")
    print("  📝 Text-to-text   - find similar texts for a text query")
    print("  🖼️ Text-to-image  - find related images for a text query")
    print("  📝 Image-to-text  - find related texts for an image query")
    print("  🖼️ Image-to-image - find similar images for an image query")
    print("  📤 Batch upload   - manage image and text data")
    print("GPU configuration:")
    try:
        import torch
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"  🖥️ Detected {gpu_count} GPU(s)")
            for i in range(gpu_count):
                name = torch.cuda.get_device_name(i)
                props = torch.cuda.get_device_properties(i)
                memory_gb = props.total_memory / 1024**3
                print(f"    GPU {i}: {name} ({memory_gb:.1f}GB)")
        else:
            print("  ❌ CUDA is not available")
    except Exception as e:
        print(f"  ❌ GPU check failed: {e}")
    print("=" * 60)

def auto_initialize():
    """Automatically initialize the system at startup."""
    global retrieval_system
    try:
        logger.info("🚀 Auto-initializing the multi-GPU retrieval system at startup...")
        # Import the multi-GPU retrieval system
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
        # Initialize
        retrieval_system = MultiGPUMultimodalRetrieval()
        if retrieval_system.model is None:
            raise Exception("Model loading failed")
        logger.info("✅ System auto-initialized")
        return True
    except Exception as e:
        logger.error(f"❌ System auto-initialization failed: {str(e)}")
        logger.error(traceback.format_exc())
        return False

if __name__ == '__main__':
    print_startup_info()
    # Auto-initialize at startup
    auto_initialize()
    # Start the Flask app
    app.run(
        host='0.0.0.0',
        port=5000,
        debug=False,
        threaded=True
    )