base feature
This commit is contained in:
parent
1de00fccda
commit
0cd7a4cb41
201
LICENSE
201
LICENSE
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
170
README.md
Normal file
170
README.md
Normal file
@ -0,0 +1,170 @@
|
||||
# 多模态检索系统 (Multimodal Retrieval System)
|
||||
|
||||
基于 OpenSearch-AI/Ops-MM-embedding-v1-7B 模型的多模态检索系统,支持四种检索模式:文搜图、文搜文、图搜图、图搜文。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- **文搜文 (Text-to-Text)**: 使用文本查询搜索相似文本
|
||||
- **文搜图 (Text-to-Image)**: 使用文本查询搜索相关图像
|
||||
- **图搜图 (Image-to-Image)**: 使用图像查询搜索相似图像
|
||||
- **图搜文 (Image-to-Text)**: 使用图像查询搜索相关文本
|
||||
|
||||
## 环境要求
|
||||
|
||||
- Python 3.8+
|
||||
- CUDA 支持的 GPU (推荐,也可使用CPU)
|
||||
- 至少 8GB 内存
|
||||
|
||||
## 安装依赖
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 检查GPU环境
|
||||
|
||||
```python
|
||||
from multimodal_retrieval import check_gpu_info
|
||||
check_gpu_info()
|
||||
```
|
||||
|
||||
### 2. 初始化系统
|
||||
|
||||
```python
|
||||
from multimodal_retrieval import MultimodalRetrieval
|
||||
|
||||
# 初始化检索系统
|
||||
retrieval_system = MultimodalRetrieval()
|
||||
```
|
||||
|
||||
### 3. 构建索引
|
||||
|
||||
```python
|
||||
# 构建文本索引
|
||||
texts = ["一只可爱的小猫", "美丽的山景", "现代化城市"]
|
||||
retrieval_system.build_text_index(texts)
|
||||
|
||||
# 构建图像索引
|
||||
image_paths = ["./images/cat.jpg", "./images/mountain.jpg", "./images/city.jpg"]
|
||||
retrieval_system.build_image_index(image_paths)
|
||||
```
|
||||
|
||||
### 4. 执行检索
|
||||
|
||||
```python
|
||||
# 文搜文
|
||||
results = retrieval_system.search_text_by_text("猫咪", top_k=5)
|
||||
|
||||
# 文搜图
|
||||
results = retrieval_system.search_images_by_text("动物", top_k=5)
|
||||
|
||||
# 图搜图
|
||||
results = retrieval_system.search_images_by_image("./query_image.jpg", top_k=5)
|
||||
|
||||
# 图搜文
|
||||
results = retrieval_system.search_text_by_image("./query_image.jpg", top_k=5)
|
||||
```
|
||||
|
||||
## 运行演示
|
||||
|
||||
```bash
|
||||
python demo.py
|
||||
```
|
||||
|
||||
演示脚本会自动:
|
||||
1. 检查GPU环境信息
|
||||
2. 初始化多模态检索系统
|
||||
3. 演示四种检索模式
|
||||
4. 显示检索结果和相似度分数
|
||||
|
||||
## 文件结构
|
||||
|
||||
```
|
||||
mmeb/
|
||||
├── multimodal_retrieval.py # 主要检索系统类
|
||||
├── demo.py # 演示脚本
|
||||
├── requirements.txt # 依赖包列表
|
||||
├── README.md # 项目说明
|
||||
└── images/ # 图像数据目录 (需要自行创建)
|
||||
```
|
||||
|
||||
## API 参考
|
||||
|
||||
### MultimodalRetrieval 类
|
||||
|
||||
#### 初始化参数
|
||||
- `model_name`: 模型名称,默认 "OpenSearch-AI/Ops-MM-embedding-v1-7B"
|
||||
- `device`: 设备类型,默认自动选择 ("cuda" 或 "cpu")
|
||||
|
||||
#### 主要方法
|
||||
|
||||
##### `build_text_index(texts, save_path=None)`
|
||||
构建文本索引
|
||||
- `texts`: 文本列表
|
||||
- `save_path`: 索引保存路径 (可选)
|
||||
|
||||
##### `build_image_index(image_paths, save_path=None)`
|
||||
构建图像索引
|
||||
- `image_paths`: 图像路径列表
|
||||
- `save_path`: 索引保存路径 (可选)
|
||||
|
||||
##### `search_text_by_text(query, top_k=5)`
|
||||
文搜文检索
|
||||
- `query`: 查询文本
|
||||
- `top_k`: 返回结果数量
|
||||
- 返回: `[(文本, 相似度分数), ...]`
|
||||
|
||||
##### `search_images_by_text(query, top_k=5)`
|
||||
文搜图检索
|
||||
- `query`: 查询文本
|
||||
- `top_k`: 返回结果数量
|
||||
- 返回: `[(图像路径, 相似度分数), ...]`
|
||||
|
||||
##### `search_images_by_image(query_image, top_k=5)`
|
||||
图搜图检索
|
||||
- `query_image`: 查询图像路径或PIL图像
|
||||
- `top_k`: 返回结果数量
|
||||
- 返回: `[(图像路径, 相似度分数), ...]`
|
||||
|
||||
##### `search_text_by_image(query_image, top_k=5)`
|
||||
图搜文检索
|
||||
- `query_image`: 查询图像路径或PIL图像
|
||||
- `top_k`: 返回结果数量
|
||||
- 返回: `[(文本, 相似度分数), ...]`
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **首次运行**: 首次运行时会自动下载模型,需要网络连接
|
||||
2. **内存需求**: 7B参数模型需要较大内存,建议使用GPU
|
||||
3. **图像格式**: 支持常见图像格式 (jpg, png, bmp, gif等)
|
||||
4. **批处理**: 系统自动进行批处理以提高效率
|
||||
5. **索引保存**: 可以保存和加载索引以避免重复构建
|
||||
|
||||
## 性能优化建议
|
||||
|
||||
- 使用GPU加速推理
|
||||
- 合理设置批处理大小
|
||||
- 保存索引文件避免重复构建
|
||||
- 对大量数据使用分批处理
|
||||
|
||||
## 故障排除
|
||||
|
||||
### 常见问题
|
||||
|
||||
1. **CUDA内存不足**: 减小批处理大小或使用CPU
|
||||
2. **模型下载失败**: 检查网络连接或使用镜像源
|
||||
3. **图像加载错误**: 检查图像文件路径和格式
|
||||
|
||||
### 日志信息
|
||||
|
||||
系统会输出详细的日志信息,包括:
|
||||
- GPU环境检测结果
|
||||
- 模型加载进度
|
||||
- 索引构建状态
|
||||
- 检索执行情况
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目遵循 MIT 许可证。
|
||||
BIN
__pycache__/multimodal_retrieval_multigpu.cpython-310.pyc
Normal file
BIN
__pycache__/multimodal_retrieval_multigpu.cpython-310.pyc
Normal file
Binary file not shown.
632
multimodal_retrieval_multigpu.py
Normal file
632
multimodal_retrieval_multigpu.py
Normal file
@ -0,0 +1,632 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import faiss
|
||||
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||||
from typing import List, Union, Tuple, Dict
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import gc
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import threading
|
||||
|
||||
# 设置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MultiGPUMultimodalRetrieval:
|
||||
"""多GPU优化的多模态检索系统,支持文搜图、文搜文、图搜图、图搜文"""
|
||||
|
||||
def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
|
||||
use_all_gpus: bool = True, gpu_ids: List[int] = None, min_memory_gb=12):
|
||||
"""
|
||||
初始化多GPU多模态检索系统
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
use_all_gpus: 是否使用所有可用GPU
|
||||
gpu_ids: 指定使用的GPU ID列表
|
||||
min_memory_gb: 最小可用内存(GB)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
|
||||
# 设置GPU设备
|
||||
self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)
|
||||
|
||||
# 清理GPU内存
|
||||
self._clear_all_gpu_memory()
|
||||
|
||||
logger.info(f"正在加载模型到多GPU: {self.device_ids}")
|
||||
|
||||
# 加载模型和处理器
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.processor = None
|
||||
self._load_model_multigpu()
|
||||
|
||||
# 初始化索引
|
||||
self.text_index = None
|
||||
self.image_index = None
|
||||
self.text_data = []
|
||||
self.image_data = []
|
||||
|
||||
logger.info("多GPU模型加载完成")
|
||||
|
||||
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb=12):
|
||||
"""设置GPU设备"""
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA不可用,无法使用多GPU")
|
||||
|
||||
total_gpus = torch.cuda.device_count()
|
||||
logger.info(f"检测到 {total_gpus} 个GPU")
|
||||
|
||||
# 检查是否设置了CUDA_VISIBLE_DEVICES
|
||||
cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
|
||||
if cuda_visible_devices is not None:
|
||||
# 如果设置了CUDA_VISIBLE_DEVICES,使用可见的GPU
|
||||
visible_gpu_count = len(cuda_visible_devices.split(','))
|
||||
self.device_ids = list(range(visible_gpu_count))
|
||||
logger.info(f"使用CUDA_VISIBLE_DEVICES指定的GPU: {cuda_visible_devices}")
|
||||
elif use_all_gpus:
|
||||
self.device_ids = self._select_best_gpus(min_memory_gb)
|
||||
elif gpu_ids:
|
||||
self.device_ids = gpu_ids
|
||||
else:
|
||||
self.device_ids = [0]
|
||||
|
||||
self.num_gpus = len(self.device_ids)
|
||||
self.primary_device = f"cuda:{self.device_ids[0]}"
|
||||
|
||||
logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")
|
||||
|
||||
def _clear_all_gpu_memory(self):
|
||||
"""清理所有GPU内存"""
|
||||
for gpu_id in self.device_ids:
|
||||
torch.cuda.set_device(gpu_id)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
logger.info("所有GPU内存已清理")
|
||||
|
||||
def _load_model_multigpu(self):
|
||||
"""加载模型到多GPU"""
|
||||
try:
|
||||
# 设置环境变量优化内存使用
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||
|
||||
# 清理GPU内存
|
||||
self._clear_gpu_memory()
|
||||
|
||||
# 首先尝试使用accelerate的自动设备映射
|
||||
if self.num_gpus > 1:
|
||||
# 设置最大内存限制(每个GPU 18GB,留出缓冲)
|
||||
max_memory = {i: "18GiB" for i in self.device_ids}
|
||||
|
||||
logger.info(f"正在加载模型到多GPU: {self.device_ids}")
|
||||
self.model = AutoModel.from_pretrained(
|
||||
self.model_name,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
max_memory=max_memory,
|
||||
low_cpu_mem_usage=True,
|
||||
offload_folder="./offload"
|
||||
)
|
||||
else:
|
||||
# 单GPU加载
|
||||
self.model = AutoModel.from_pretrained(
|
||||
self.model_name,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.float16,
|
||||
device_map=self.primary_device
|
||||
)
|
||||
|
||||
# 加载分词器和处理器到主设备
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name,
|
||||
trust_remote_code=True
|
||||
)
|
||||
logger.info("Tokenizer加载成功")
|
||||
except Exception as e:
|
||||
logger.error(f"Tokenizer加载失败: {e}")
|
||||
return False
|
||||
|
||||
# 加载处理器用于图像处理
|
||||
try:
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
self.model_name,
|
||||
trust_remote_code=True
|
||||
)
|
||||
logger.info("Processor加载成功")
|
||||
except Exception as e:
|
||||
logger.warning(f"Processor加载失败: {e}")
|
||||
# 如果AutoProcessor失败,尝试使用tokenizer作为fallback
|
||||
logger.info("尝试使用tokenizer作为processor的fallback")
|
||||
self.processor = self.tokenizer
|
||||
|
||||
logger.info(f"模型已成功加载到设备: {self.model.hf_device_map if hasattr(self.model, 'hf_device_map') else self.primary_device}")
|
||||
logger.info("多GPU模型加载完成")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"多GPU模型加载失败: {str(e)}")
|
||||
return False
|
||||
|
||||
def _clear_gpu_memory(self):
|
||||
"""清理GPU内存"""
|
||||
for gpu_id in self.device_ids:
|
||||
torch.cuda.set_device(gpu_id)
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
logger.info("GPU内存已清理")
|
||||
|
||||
def _get_gpu_memory_info(self):
|
||||
"""获取GPU内存使用情况"""
|
||||
try:
|
||||
import subprocess
|
||||
result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv,nounits,noheader'],
|
||||
capture_output=True, text=True, check=True)
|
||||
lines = result.stdout.strip().split('\n')
|
||||
gpu_info = []
|
||||
for i, line in enumerate(lines):
|
||||
used, total = map(int, line.split(', '))
|
||||
free = total - used
|
||||
gpu_info.append({
|
||||
'gpu_id': i,
|
||||
'used': used,
|
||||
'total': total,
|
||||
'free': free,
|
||||
'usage_percent': (used / total) * 100
|
||||
})
|
||||
return gpu_info
|
||||
except Exception as e:
|
||||
logger.warning(f"无法获取GPU内存信息: {e}")
|
||||
return []
|
||||
|
||||
def _select_best_gpus(self, min_memory_gb=12):
|
||||
"""选择内存充足的GPU"""
|
||||
gpu_info = self._get_gpu_memory_info()
|
||||
if not gpu_info:
|
||||
return list(range(torch.cuda.device_count()))
|
||||
|
||||
# 按可用内存排序
|
||||
gpu_info.sort(key=lambda x: x['free'], reverse=True)
|
||||
|
||||
# 选择内存充足的GPU
|
||||
min_memory_mb = min_memory_gb * 1024
|
||||
suitable_gpus = []
|
||||
|
||||
for gpu in gpu_info:
|
||||
if gpu['free'] >= min_memory_mb:
|
||||
suitable_gpus.append(gpu['gpu_id'])
|
||||
logger.info(f"GPU {gpu['gpu_id']}: {gpu['free']}MB 可用 (合适)")
|
||||
else:
|
||||
logger.warning(f"GPU {gpu['gpu_id']}: {gpu['free']}MB 可用 (不足)")
|
||||
|
||||
if not suitable_gpus:
|
||||
# 如果没有GPU满足要求,选择可用内存最多的
|
||||
logger.warning(f"没有GPU有足够内存({min_memory_gb}GB),选择可用内存最多的GPU")
|
||||
suitable_gpus = [gpu_info[0]['gpu_id']]
|
||||
|
||||
return suitable_gpus
|
||||
|
||||
def encode_text_batch(self, texts: List[str]) -> np.ndarray:
|
||||
"""
|
||||
批量编码文本为向量(多GPU优化)
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
|
||||
Returns:
|
||||
文本向量
|
||||
"""
|
||||
if not texts:
|
||||
return np.array([])
|
||||
|
||||
with torch.no_grad():
|
||||
# 预处理输入
|
||||
inputs = self.tokenizer(
|
||||
text=texts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
|
||||
# 将输入移动到主设备
|
||||
inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
|
||||
|
||||
# 前向传播
|
||||
outputs = self.model(**inputs)
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
|
||||
# 清理GPU内存
|
||||
del inputs, outputs
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return embeddings.cpu().numpy().astype(np.float32)
|
||||
|
||||
def encode_image_batch(self, images: List[Union[str, Image.Image]]) -> np.ndarray:
|
||||
"""
|
||||
批量编码图像为向量
|
||||
|
||||
Args:
|
||||
images: 图像路径或PIL图像列表
|
||||
|
||||
Returns:
|
||||
图像向量
|
||||
"""
|
||||
if not images:
|
||||
return np.array([])
|
||||
|
||||
# 预处理图像
|
||||
processed_images = []
|
||||
for img in images:
|
||||
if isinstance(img, str):
|
||||
img = Image.open(img).convert('RGB')
|
||||
elif isinstance(img, Image.Image):
|
||||
img = img.convert('RGB')
|
||||
processed_images.append(img)
|
||||
|
||||
try:
|
||||
logger.info(f"处理 {len(processed_images)} 张图像")
|
||||
|
||||
# 使用多模态模型生成图像embedding
|
||||
# 为每张图像创建简单的文本描述作为输入
|
||||
conversations = []
|
||||
for i in range(len(processed_images)):
|
||||
# 使用简化的对话格式
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": processed_images[i]},
|
||||
{"type": "text", "text": "What is in this image?"}
|
||||
]
|
||||
}
|
||||
]
|
||||
conversations.append(conversation)
|
||||
|
||||
# 使用processor处理
|
||||
try:
|
||||
# 尝试使用apply_chat_template方法
|
||||
texts = []
|
||||
for conv in conversations:
|
||||
text = self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
|
||||
texts.append(text)
|
||||
|
||||
# 处理文本和图像
|
||||
inputs = self.processor(
|
||||
text=texts,
|
||||
images=processed_images,
|
||||
return_tensors="pt",
|
||||
padding=True
|
||||
)
|
||||
|
||||
# 移动到GPU
|
||||
inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}
|
||||
|
||||
# 获取模型输出
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs)
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
|
||||
# 转换为numpy数组
|
||||
embeddings = embeddings.cpu().numpy().astype(np.float32)
|
||||
|
||||
except Exception as inner_e:
|
||||
logger.warning(f"多模态模型图像编码失败,使用文本模式: {inner_e}")
|
||||
return np.zeros((len(processed_images), 3584), dtype=np.float32)
|
||||
# 如果多模态失败,使用纯文本描述作为fallback
|
||||
image_descriptions = ["An image" for _ in processed_images]
|
||||
text_inputs = self.processor(
|
||||
text=image_descriptions,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=512
|
||||
)
|
||||
text_inputs = {k: v.to(self.primary_device) for k, v in text_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**text_inputs)
|
||||
embeddings = outputs.last_hidden_state.mean(dim=1)
|
||||
|
||||
embeddings = embeddings.cpu().numpy().astype(np.float32)
|
||||
|
||||
logger.info(f"生成图像embeddings: {embeddings.shape}")
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"图像编码失败: {e}")
|
||||
# 返回与文本embedding维度一致的零向量作为fallback
|
||||
embedding_dim = 3584
|
||||
embeddings = np.zeros((len(processed_images), embedding_dim), dtype=np.float32)
|
||||
return embeddings
|
||||
|
||||
def build_text_index_parallel(self, texts: List[str], save_path: str = None):
|
||||
"""
|
||||
并行构建文本索引(多GPU优化)
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
save_path: 索引保存路径
|
||||
"""
|
||||
logger.info(f"正在并行构建文本索引,共 {len(texts)} 条文本")
|
||||
|
||||
# 根据GPU数量调整批次大小
|
||||
batch_size = max(4, 16 // self.num_gpus)
|
||||
all_embeddings = []
|
||||
|
||||
# 分批处理
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i:i+batch_size]
|
||||
|
||||
try:
|
||||
embeddings = self.encode_text_batch(batch_texts)
|
||||
all_embeddings.append(embeddings)
|
||||
|
||||
# 显示进度
|
||||
if (i // batch_size + 1) % 10 == 0:
|
||||
logger.info(f"已处理 {i + len(batch_texts)}/{len(texts)} 条文本")
|
||||
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
logger.warning(f"GPU内存不足,跳过批次 {i}-{i+len(batch_texts)}")
|
||||
self._clear_all_gpu_memory()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"处理文本批次时出错: {e}")
|
||||
continue
|
||||
|
||||
if not all_embeddings:
|
||||
raise ValueError("没有成功处理任何文本")
|
||||
|
||||
# 合并所有嵌入向量
|
||||
embeddings = np.vstack(all_embeddings)
|
||||
|
||||
# 构建FAISS索引
|
||||
dimension = embeddings.shape[1]
|
||||
self.text_index = faiss.IndexFlatIP(dimension)
|
||||
|
||||
# 归一化向量
|
||||
faiss.normalize_L2(embeddings)
|
||||
self.text_index.add(embeddings)
|
||||
|
||||
self.text_data = texts
|
||||
|
||||
if save_path:
|
||||
self._save_index(self.text_index, texts, save_path + "_text")
|
||||
|
||||
logger.info("文本索引构建完成")
|
||||
|
||||
def build_image_index_parallel(self, image_paths: List[str], save_path: str = None):
|
||||
"""
|
||||
并行构建图像索引(多GPU优化)
|
||||
|
||||
Args:
|
||||
image_paths: 图像路径列表
|
||||
save_path: 索引保存路径
|
||||
"""
|
||||
logger.info(f"正在并行构建图像索引,共 {len(image_paths)} 张图像")
|
||||
|
||||
# 图像处理使用更小的批次
|
||||
batch_size = max(2, 8 // self.num_gpus)
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(image_paths), batch_size):
|
||||
batch_images = image_paths[i:i+batch_size]
|
||||
|
||||
try:
|
||||
embeddings = self.encode_image_batch(batch_images)
|
||||
all_embeddings.append(embeddings)
|
||||
|
||||
# 显示进度
|
||||
if (i // batch_size + 1) % 5 == 0:
|
||||
logger.info(f"已处理 {i + len(batch_images)}/{len(image_paths)} 张图像")
|
||||
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
logger.warning(f"GPU内存不足,跳过图像批次 {i}-{i+len(batch_images)}")
|
||||
self._clear_all_gpu_memory()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"处理图像批次时出错: {e}")
|
||||
continue
|
||||
|
||||
if not all_embeddings:
|
||||
raise ValueError("没有成功处理任何图像")
|
||||
|
||||
embeddings = np.vstack(all_embeddings)
|
||||
|
||||
# 构建FAISS索引
|
||||
dimension = embeddings.shape[1]
|
||||
self.image_index = faiss.IndexFlatIP(dimension)
|
||||
|
||||
faiss.normalize_L2(embeddings)
|
||||
self.image_index.add(embeddings)
|
||||
|
||||
self.image_data = image_paths
|
||||
|
||||
if save_path:
|
||||
self._save_index(self.image_index, image_paths, save_path + "_image")
|
||||
|
||||
logger.info("图像索引构建完成")
|
||||
|
||||
def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
||||
"""文搜文:使用文本查询搜索相似文本"""
|
||||
if self.text_index is None:
|
||||
raise ValueError("文本索引未构建,请先调用 build_text_index_parallel")
|
||||
|
||||
query_embedding = self.encode_text_batch([query]).astype(np.float32)
|
||||
faiss.normalize_L2(query_embedding)
|
||||
|
||||
scores, indices = self.text_index.search(query_embedding, top_k)
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx != -1:
|
||||
results.append((self.text_data[idx], float(score)))
|
||||
|
||||
return results
|
||||
|
||||
def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
||||
"""文搜图:使用文本查询搜索相似图像"""
|
||||
if self.image_index is None:
|
||||
raise ValueError("图像索引未构建,请先调用 build_image_index_parallel")
|
||||
|
||||
query_embedding = self.encode_text_batch([query]).astype(np.float32)
|
||||
faiss.normalize_L2(query_embedding)
|
||||
|
||||
scores, indices = self.image_index.search(query_embedding, top_k)
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx != -1:
|
||||
results.append((self.image_data[idx], float(score)))
|
||||
|
||||
return results
|
||||
|
||||
def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
||||
"""图搜图:使用图像查询搜索相似图像"""
|
||||
if self.image_index is None:
|
||||
raise ValueError("图像索引未构建,请先调用 build_image_index_parallel")
|
||||
|
||||
query_embedding = self.encode_image_batch([query_image]).astype(np.float32)
|
||||
faiss.normalize_L2(query_embedding)
|
||||
|
||||
scores, indices = self.image_index.search(query_embedding, top_k)
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx != -1:
|
||||
results.append((self.image_data[idx], float(score)))
|
||||
|
||||
return results
|
||||
|
||||
def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
||||
"""图搜文:使用图像查询搜索相似文本"""
|
||||
if self.text_index is None:
|
||||
raise ValueError("文本索引未构建,请先调用 build_text_index_parallel")
|
||||
|
||||
query_embedding = self.encode_image_batch([query_image]).astype(np.float32)
|
||||
faiss.normalize_L2(query_embedding)
|
||||
|
||||
scores, indices = self.text_index.search(query_embedding, top_k)
|
||||
|
||||
results = []
|
||||
for score, idx in zip(scores[0], indices[0]):
|
||||
if idx != -1:
|
||||
results.append((self.text_data[idx], float(score)))
|
||||
|
||||
return results
|
||||
|
||||
# Web应用兼容的方法名称
|
||||
def search_text_to_image(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
||||
"""文搜图:Web应用兼容方法"""
|
||||
return self.search_images_by_text(query, top_k)
|
||||
|
||||
def search_image_to_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
||||
"""图搜图:Web应用兼容方法"""
|
||||
return self.search_images_by_image(query_image, top_k)
|
||||
|
||||
def search_text_to_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
|
||||
"""文搜文:Web应用兼容方法"""
|
||||
return self.search_text_by_text(query, top_k)
|
||||
|
||||
def search_image_to_text(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
|
||||
"""图搜文:Web应用兼容方法"""
|
||||
return self.search_text_by_image(query_image, top_k)
|
||||
|
||||
def _save_index(self, index, data, path_prefix):
|
||||
"""保存索引和数据"""
|
||||
faiss.write_index(index, f"{path_prefix}.index")
|
||||
with open(f"{path_prefix}.json", 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def load_index(self, path_prefix, index_type="text"):
|
||||
"""加载已保存的索引"""
|
||||
index = faiss.read_index(f"{path_prefix}.index")
|
||||
with open(f"{path_prefix}.json", 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
if index_type == "text":
|
||||
self.text_index = index
|
||||
self.text_data = data
|
||||
else:
|
||||
self.image_index = index
|
||||
self.image_data = data
|
||||
|
||||
logger.info(f"已加载 {index_type} 索引")
|
||||
|
||||
def get_gpu_memory_info(self):
|
||||
"""获取所有GPU内存使用信息"""
|
||||
memory_info = {}
|
||||
for gpu_id in self.device_ids:
|
||||
torch.cuda.set_device(gpu_id)
|
||||
allocated = torch.cuda.memory_allocated(gpu_id) / 1024**3
|
||||
cached = torch.cuda.memory_reserved(gpu_id) / 1024**3
|
||||
total = torch.cuda.get_device_properties(gpu_id).total_memory / 1024**3
|
||||
free = total - cached
|
||||
|
||||
memory_info[f"GPU_{gpu_id}"] = {
|
||||
"total": f"{total:.1f}GB",
|
||||
"allocated": f"{allocated:.1f}GB",
|
||||
"cached": f"{cached:.1f}GB",
|
||||
"free": f"{free:.1f}GB"
|
||||
}
|
||||
|
||||
return memory_info
|
||||
|
||||
def check_multigpu_info():
|
||||
"""检查多GPU环境信息"""
|
||||
print("=== 多GPU环境信息 ===")
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print("❌ CUDA不可用")
|
||||
return
|
||||
|
||||
gpu_count = torch.cuda.device_count()
|
||||
print(f"✅ 检测到 {gpu_count} 个GPU")
|
||||
print(f"CUDA版本: {torch.version.cuda}")
|
||||
print(f"PyTorch版本: {torch.__version__}")
|
||||
|
||||
for i in range(gpu_count):
|
||||
gpu_name = torch.cuda.get_device_name(i)
|
||||
gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
|
||||
print(f"GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)")
|
||||
|
||||
print("=====================")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 检查多GPU环境
|
||||
check_multigpu_info()
|
||||
|
||||
# 示例使用
|
||||
print("\n正在初始化多GPU多模态检索系统...")
|
||||
|
||||
try:
|
||||
retrieval_system = MultiGPUMultimodalRetrieval()
|
||||
print("✅ 多GPU系统初始化成功!")
|
||||
|
||||
# 显示GPU内存使用情况
|
||||
memory_info = retrieval_system.get_gpu_memory_info()
|
||||
print("\n📊 GPU内存使用情况:")
|
||||
for gpu, info in memory_info.items():
|
||||
print(f" {gpu}: {info['allocated']} / {info['total']} (已用/总计)")
|
||||
|
||||
print("\n🚀 多GPU多模态检索系统就绪!")
|
||||
print("支持的检索模式:")
|
||||
print("1. 文搜文: search_text_by_text()")
|
||||
print("2. 文搜图: search_images_by_text()")
|
||||
print("3. 图搜图: search_images_by_image()")
|
||||
print("4. 图搜文: search_text_by_image()")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 多GPU系统初始化失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
309
ops_mm_embedding_v1.py
Normal file
309
ops_mm_embedding_v1.py
Normal file
@ -0,0 +1,309 @@
|
||||
import math
|
||||
from typing import List, Optional, TypeAlias, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor
|
||||
|
||||
ImageInput: TypeAlias = Union[Image.Image, List[Image.Image]]
|
||||
BatchImageInput: TypeAlias = Union[List[Image.Image], List[List[Image.Image]]]
|
||||
|
||||
|
||||
class OpsMMEmbeddingV1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
device: str = "cuda",
|
||||
max_length: Optional[int] = None,
|
||||
attn_implementation: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.default_instruction = "You are a helpful assistant."
|
||||
self.base_model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
low_cpu_mem_usage=True,
|
||||
attn_implementation=attn_implementation,
|
||||
).to(self.device)
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
|
||||
self.processor.tokenizer.padding_side = "left"
|
||||
self.eval()
|
||||
|
||||
def encode_input(self, input):
|
||||
hidden_states = self.base_model(**input, return_dict=True, output_hidden_states=True)
|
||||
hidden_states = hidden_states.hidden_states[-1]
|
||||
pooled_output = self._pooling(hidden_states)
|
||||
return pooled_output
|
||||
|
||||
def _pooling(self, last_hidden_state):
|
||||
batch_size = last_hidden_state.shape[0]
|
||||
reps = last_hidden_state[torch.arange(batch_size), -1, :]
|
||||
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
|
||||
return reps
|
||||
|
||||
def _validate_instructions(
|
||||
self,
|
||||
texts: Optional[List[str]],
|
||||
images: Optional[BatchImageInput],
|
||||
instruction: Optional[Union[str, List[str]]],
|
||||
) -> List[str]:
|
||||
"""Validate and format instructions to match batch size"""
|
||||
batch_size = max(len(x) if x is not None else 0 for x in [texts, images])
|
||||
|
||||
if instruction is None:
|
||||
return [self.default_instruction] * batch_size
|
||||
|
||||
if isinstance(instruction, str):
|
||||
return [instruction] * batch_size
|
||||
|
||||
if isinstance(instruction, list):
|
||||
if len(instruction) != batch_size:
|
||||
raise ValueError(f"Length of instruction list ({len(instruction)}) must match batch size ({batch_size}) when texts/images are provided")
|
||||
return instruction
|
||||
|
||||
raise TypeError("instruction must be str, List[str] or None")
|
||||
|
||||
def _process_images(self, images: ImageInput) -> List[Image.Image]:
|
||||
"""Convert single image or list of images to processed format"""
|
||||
if isinstance(images, Image.Image) or isinstance(images, str):
|
||||
return [fetch_image(images)]
|
||||
return [fetch_image(i) for i in images]
|
||||
|
||||
def embed(
|
||||
self,
|
||||
texts: Optional[List[str]] = None,
|
||||
images: Optional[BatchImageInput] = None,
|
||||
instruction: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Generate embeddings for text, images, or combined inputs.
|
||||
|
||||
Args:
|
||||
texts: List of text inputs (optional)
|
||||
images: Can be:
|
||||
- List[Image.Image]: Single image per input
|
||||
- List[List[Image.Image]]: Multiple images per input
|
||||
instruction: Instruction(s) for the model. Can be:
|
||||
- None: use default instruction
|
||||
- str: use same instruction for all inputs
|
||||
- List[str]: per-input instructions (must match batch size)
|
||||
"""
|
||||
if texts is None and images is None:
|
||||
raise ValueError("Either texts or images must be provided")
|
||||
|
||||
instructions = self._validate_instructions(texts, images, instruction)
|
||||
|
||||
# Determine batch size
|
||||
batch_size = len(texts) if texts is not None else len(images) # type: ignore
|
||||
|
||||
input_texts, input_images = [], []
|
||||
for i in range(batch_size):
|
||||
text = texts[i] if texts is not None else None
|
||||
image = images[i] if images is not None else None
|
||||
|
||||
input_str = ""
|
||||
processed_image = None
|
||||
if image is not None:
|
||||
processed_image = self._process_images(image)
|
||||
input_str += "<|vision_start|><|image_pad|><|vision_end|>" * len(processed_image)
|
||||
|
||||
if text is not None:
|
||||
input_str += text
|
||||
|
||||
msg = f"<|im_start|>system\n{instructions[i]}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
|
||||
|
||||
input_texts.append(msg)
|
||||
input_images.append(processed_image)
|
||||
|
||||
# Only pass to processor if we actually have images
|
||||
processed_images = input_images if any(img is not None for img in input_images) else None
|
||||
|
||||
inputs = self.processor(
|
||||
text=input_texts,
|
||||
images=processed_images,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
with torch.inference_mode():
|
||||
embeddings = self.encode_input(inputs)
|
||||
|
||||
return embeddings
|
||||
|
||||
def get_text_embeddings(
|
||||
self,
|
||||
texts: List[str],
|
||||
instruction: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Convenience method for text-only embeddings"""
|
||||
return self.get_fused_embeddings(texts=texts, instruction=instruction, **kwargs)
|
||||
|
||||
def get_image_embeddings(
|
||||
self,
|
||||
images: BatchImageInput,
|
||||
instruction: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Convenience method for image-only embeddings.
|
||||
|
||||
Args:
|
||||
images: Can be:
|
||||
- List[Image.Image]: Single image per input
|
||||
- List[List[Image.Image]]: Multiple images per input
|
||||
"""
|
||||
return self.get_fused_embeddings(images=images, instruction=instruction, **kwargs)
|
||||
|
||||
def get_fused_embeddings(
|
||||
self,
|
||||
texts: Optional[List[str]] = None,
|
||||
images: Optional[BatchImageInput] = None,
|
||||
instruction: Optional[Union[str, List[str]]] = None,
|
||||
batch_size: int = 8,
|
||||
show_progress: bool = True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""Batch processing for large collections of texts/images.
|
||||
|
||||
Args:
|
||||
texts: List of text inputs (optional)
|
||||
images: Can be:
|
||||
- List[Image.Image]: Single image per input
|
||||
- List[List[Image.Image]]: Multiple images per input
|
||||
instruction: Instruction(s) for the model
|
||||
batch_size: Number of items to process at once
|
||||
show_progress: Whether to display progress bar
|
||||
"""
|
||||
|
||||
if texts is None and images is None:
|
||||
raise ValueError("Either texts or images must be provided")
|
||||
|
||||
total_items = len(texts) if texts is not None else len(images) # type: ignore
|
||||
num_batches = math.ceil(total_items / batch_size)
|
||||
|
||||
all_embeddings = []
|
||||
progress = tqdm(total=num_batches, disable=not show_progress, desc="Processing")
|
||||
|
||||
for i in range(0, total_items, batch_size):
|
||||
batch_texts = texts[i : i + batch_size] if texts is not None else None
|
||||
batch_images = images[i : i + batch_size] if images is not None else None
|
||||
batch_emb = self.embed(texts=batch_texts, images=batch_images, instruction=instruction)
|
||||
|
||||
all_embeddings.append(batch_emb.cpu())
|
||||
progress.update(1)
|
||||
|
||||
progress.close()
|
||||
return torch.cat(all_embeddings, dim=0).to(self.device)
|
||||
|
||||
def forward(self, **inputs) -> torch.Tensor:
|
||||
"""Alias for encode_input"""
|
||||
return self.encode_input(inputs)
|
||||
|
||||
|
||||
### Modified from qwen_vl_utils.vision_process.py
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
|
||||
IMAGE_FACTOR = 28
|
||||
MIN_PIXELS = 256 * 28 * 28
|
||||
MAX_PIXELS = 1280 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
|
||||
def ceil_by_factor(number: int | float, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
|
||||
def floor_by_factor(number: int | float, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
|
||||
def smart_resize(
|
||||
height: int,
|
||||
width: int,
|
||||
factor: int = IMAGE_FACTOR,
|
||||
min_pixels: int = MIN_PIXELS,
|
||||
max_pixels: int = MAX_PIXELS,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Rescales the image so that the following conditions are met:
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
"""
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
|
||||
if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
|
||||
logging.warning(f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}")
|
||||
if h_bar > w_bar:
|
||||
h_bar = w_bar * MAX_RATIO
|
||||
else:
|
||||
w_bar = h_bar * MAX_RATIO
|
||||
return h_bar, w_bar
|
||||
|
||||
|
||||
def fetch_image(
|
||||
image: str | Image.Image,
|
||||
size_factor: int = IMAGE_FACTOR,
|
||||
min_pixels: int = MIN_PIXELS,
|
||||
max_pixels: int = MAX_PIXELS,
|
||||
) -> Image.Image:
|
||||
image_obj = None
|
||||
if isinstance(image, Image.Image):
|
||||
image_obj = image
|
||||
elif image.startswith("http://") or image.startswith("https://"):
|
||||
image_obj = Image.open(requests.get(image, stream=True).raw) # type: ignore
|
||||
elif image.startswith("file://"):
|
||||
image_obj = Image.open(image[7:])
|
||||
elif image.startswith("data:image"):
|
||||
if "base64," in image:
|
||||
_, base64_data = image.split("base64,", 1)
|
||||
data = base64.b64decode(base64_data)
|
||||
image_obj = Image.open(BytesIO(data))
|
||||
else:
|
||||
image_obj = Image.open(image)
|
||||
if image_obj is None:
|
||||
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
|
||||
image = image_obj.convert("RGB")
|
||||
width, height = image.size
|
||||
resized_height, resized_width = smart_resize(
|
||||
height,
|
||||
width,
|
||||
factor=size_factor,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
image = image.resize((resized_width, resized_height))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
###
|
||||
12
requirements.txt
Normal file
12
requirements.txt
Normal file
@ -0,0 +1,12 @@
|
||||
torch>=2.0.0
|
||||
torchvision>=0.15.0
|
||||
transformers>=4.30.0
|
||||
accelerate>=0.20.0
|
||||
faiss-cpu>=1.7.4
|
||||
numpy>=1.21.0
|
||||
Pillow>=9.0.0
|
||||
scikit-learn>=1.3.0
|
||||
tqdm>=4.65.0
|
||||
flask>=2.3.0
|
||||
werkzeug>=2.3.0
|
||||
psutil>=5.9.0
|
||||
BIN
sample_images/1755681385_4__.jpg
Normal file
BIN
sample_images/1755681385_4__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 189 KiB |
BIN
sample_images/1755681385_5__.jpg
Normal file
BIN
sample_images/1755681385_5__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 201 KiB |
BIN
sample_images/1755681385_6__.jpg
Normal file
BIN
sample_images/1755681385_6__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 144 KiB |
BIN
sample_images/1755681385_7__.jpg
Normal file
BIN
sample_images/1755681385_7__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 166 KiB |
971
templates/index.html
Normal file
971
templates/index.html
Normal file
@ -0,0 +1,971 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>多模态检索系统</title>
|
||||
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
||||
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
|
||||
<style>
|
||||
:root {
|
||||
--primary-color: #2563eb;
|
||||
--secondary-color: #64748b;
|
||||
--success-color: #059669;
|
||||
--warning-color: #d97706;
|
||||
--danger-color: #dc2626;
|
||||
--bg-gradient: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
}
|
||||
|
||||
body {
|
||||
background: var(--bg-gradient);
|
||||
min-height: 100vh;
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
}
|
||||
|
||||
.main-container {
|
||||
background: rgba(255, 255, 255, 0.95);
|
||||
backdrop-filter: blur(10px);
|
||||
border-radius: 20px;
|
||||
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
|
||||
margin: 20px auto;
|
||||
max-width: 1200px;
|
||||
}
|
||||
|
||||
.header {
|
||||
background: linear-gradient(135deg, var(--primary-color), #3b82f6);
|
||||
color: white;
|
||||
padding: 2rem;
|
||||
border-radius: 20px 20px 0 0;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.mode-card {
|
||||
background: white;
|
||||
border-radius: 15px;
|
||||
padding: 1.5rem;
|
||||
margin-bottom: 1.5rem;
|
||||
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
|
||||
transition: all 0.3s ease;
|
||||
border: 2px solid transparent;
|
||||
}
|
||||
|
||||
.mode-card:hover {
|
||||
transform: translateY(-5px);
|
||||
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
.mode-card.active {
|
||||
border-color: var(--primary-color);
|
||||
background: linear-gradient(135deg, #eff6ff, #dbeafe);
|
||||
}
|
||||
|
||||
.mode-icon {
|
||||
font-size: 2.5rem;
|
||||
margin-bottom: 1rem;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.text-to-text { color: #059669; }
|
||||
.text-to-image { color: #dc2626; }
|
||||
.image-to-text { color: #d97706; }
|
||||
.image-to-image { color: #7c3aed; }
|
||||
|
||||
.search-input {
|
||||
border-radius: 12px;
|
||||
border: 2px solid #e5e7eb;
|
||||
padding: 12px 16px;
|
||||
font-size: 16px;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.search-input:focus {
|
||||
border-color: var(--primary-color);
|
||||
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: var(--primary-color);
|
||||
border: none;
|
||||
border-radius: 12px;
|
||||
padding: 12px 24px;
|
||||
font-weight: 600;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.btn-primary:hover {
|
||||
background: #1d4ed8;
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
.file-upload-area {
|
||||
border: 3px dashed #d1d5db;
|
||||
border-radius: 12px;
|
||||
padding: 3rem;
|
||||
text-align: center;
|
||||
transition: all 0.3s ease;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.file-upload-area:hover {
|
||||
border-color: var(--primary-color);
|
||||
background: rgba(37, 99, 235, 0.05);
|
||||
}
|
||||
|
||||
.file-upload-area.dragover {
|
||||
border-color: var(--primary-color);
|
||||
background: rgba(37, 99, 235, 0.1);
|
||||
}
|
||||
|
||||
.result-card {
|
||||
background: white;
|
||||
border-radius: 12px;
|
||||
padding: 1.5rem;
|
||||
margin-bottom: 1rem;
|
||||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
||||
border-left: 4px solid var(--primary-color);
|
||||
}
|
||||
|
||||
.result-image {
|
||||
max-width: 200px;
|
||||
max-height: 150px;
|
||||
border-radius: 8px;
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.score-badge {
|
||||
background: var(--success-color);
|
||||
color: white;
|
||||
padding: 4px 12px;
|
||||
border-radius: 20px;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.loading-spinner {
|
||||
display: none;
|
||||
text-align: center;
|
||||
padding: 2rem;
|
||||
}
|
||||
|
||||
.status-indicator {
|
||||
position: fixed;
|
||||
top: 20px;
|
||||
right: 20px;
|
||||
z-index: 1000;
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
animation: fadeIn 0.5s ease-in;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(20px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
.query-image {
|
||||
max-width: 300px;
|
||||
max-height: 200px;
|
||||
border-radius: 12px;
|
||||
object-fit: cover;
|
||||
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<!-- 状态指示器 -->
|
||||
<div class="status-indicator">
|
||||
<div id="statusBadge" class="badge bg-secondary">
|
||||
<i class="fas fa-circle-notch fa-spin"></i> 未初始化
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="container-fluid">
|
||||
<div class="main-container">
|
||||
<!-- 头部 -->
|
||||
<div class="header">
|
||||
<h1><i class="fas fa-search"></i> 多模态检索系统</h1>
|
||||
<p class="mb-0">支持文搜图、文搜文、图搜图、图搜文四种检索模式</p>
|
||||
</div>
|
||||
|
||||
<div class="p-4">
|
||||
<!-- 重新初始化按钮 -->
|
||||
<div class="text-center mb-4">
|
||||
<button id="reinitBtn" class="btn btn-warning">
|
||||
<i class="fas fa-redo"></i> 重新初始化系统
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- 检索模式选择 -->
|
||||
<div class="row mb-4" id="modeSelection">
|
||||
<div class="col-md-3">
|
||||
<div class="mode-card text-center" data-mode="text_to_text">
|
||||
<i class="fas fa-file-text mode-icon text-to-text"></i>
|
||||
<h5>文搜文</h5>
|
||||
<p class="text-muted">文本查找相似文本</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-3">
|
||||
<div class="mode-card text-center" data-mode="text_to_image">
|
||||
<i class="fas fa-image mode-icon text-to-image"></i>
|
||||
<h5>文搜图</h5>
|
||||
<p class="text-muted">文本查找相关图片</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-3">
|
||||
<div class="mode-card text-center" data-mode="image_to_text">
|
||||
<i class="fas fa-comment mode-icon image-to-text"></i>
|
||||
<h5>图搜文</h5>
|
||||
<p class="text-muted">图片查找相关文本</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-3">
|
||||
<div class="mode-card text-center" data-mode="image_to_image">
|
||||
<i class="fas fa-images mode-icon image-to-image"></i>
|
||||
<h5>图搜图</h5>
|
||||
<p class="text-muted">图片查找相似图片</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 数据管理界面 -->
|
||||
<div class="row mb-4" id="dataManagement">
|
||||
<div class="col-12">
|
||||
<div class="card">
|
||||
<div class="card-header">
|
||||
<h5><i class="fas fa-database"></i> 数据管理</h5>
|
||||
<small class="text-muted">上传和管理检索数据库</small>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="row">
|
||||
<!-- 批量上传图片 -->
|
||||
<div class="col-md-6">
|
||||
<div class="upload-section">
|
||||
<h6><i class="fas fa-images text-primary"></i> 批量上传图片</h6>
|
||||
<div class="file-upload-area" id="batchImageUpload">
|
||||
<i class="fas fa-cloud-upload-alt fa-2x text-muted mb-2"></i>
|
||||
<p>拖拽多张图片到此处或点击选择</p>
|
||||
<input type="file" id="batchImageFiles" multiple accept="image/*" style="display: none;">
|
||||
<button class="btn btn-outline-primary btn-sm mt-2" onclick="document.getElementById('batchImageFiles').click()">
|
||||
<i class="fas fa-folder-open"></i> 选择图片
|
||||
</button>
|
||||
</div>
|
||||
<div id="imageUploadProgress" class="mt-2" style="display: none;">
|
||||
<div class="progress">
|
||||
<div class="progress-bar" role="progressbar" style="width: 0%"></div>
|
||||
</div>
|
||||
<small class="text-muted mt-1 d-block">上传进度: <span id="imageProgressText">0/0</span></small>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 批量上传文本 -->
|
||||
<div class="col-md-6">
|
||||
<div class="upload-section">
|
||||
<h6><i class="fas fa-file-text text-success"></i> 批量上传文本</h6>
|
||||
<div class="mb-3">
|
||||
<textarea id="batchTextInput" class="form-control" rows="8"
|
||||
placeholder="请输入文本数据,每行一条文本记录... 例如: 这是第一条文本记录 这是第二条文本记录 这是第三条文本记录"></textarea>
|
||||
</div>
|
||||
<div class="d-flex gap-2">
|
||||
<button id="uploadTextsBtn" class="btn btn-success">
|
||||
<i class="fas fa-upload"></i> 上传文本
|
||||
</button>
|
||||
<button class="btn btn-outline-secondary" onclick="document.getElementById('textFile').click()">
|
||||
<i class="fas fa-file-import"></i> 从文件导入
|
||||
</button>
|
||||
<input type="file" id="textFile" accept=".txt,.csv" style="display: none;">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 数据统计和管理 -->
|
||||
<div class="row mt-4">
|
||||
<div class="col-md-8">
|
||||
<div class="d-flex gap-3">
|
||||
<button id="buildIndexBtn" class="btn btn-warning" disabled>
|
||||
<i class="fas fa-cogs"></i> 构建索引
|
||||
</button>
|
||||
<button id="viewDataBtn" class="btn btn-info">
|
||||
<i class="fas fa-list"></i> 查看数据
|
||||
</button>
|
||||
<button id="clearDataBtn" class="btn btn-danger">
|
||||
<i class="fas fa-trash"></i> 清空数据
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-4">
|
||||
<div id="dataStats" class="text-end">
|
||||
<small class="text-muted">
|
||||
图片: <span id="imageCount">0</span> 张 |
|
||||
文本: <span id="textCount">0</span> 条
|
||||
</small>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 搜索界面 -->
|
||||
<div id="searchInterface" style="display: none;">
|
||||
<!-- 文本搜索 -->
|
||||
<div id="textSearch" class="search-panel" style="display: none;">
|
||||
<div class="row">
|
||||
<div class="col-md-8">
|
||||
<input type="text" id="textQuery" class="form-control search-input"
|
||||
placeholder="请输入搜索文本...">
|
||||
</div>
|
||||
<div class="col-md-2">
|
||||
<select id="textTopK" class="form-select search-input">
|
||||
<option value="3">Top 3</option>
|
||||
<option value="5" selected>Top 5</option>
|
||||
<option value="10">Top 10</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="col-md-2">
|
||||
<button id="textSearchBtn" class="btn btn-primary w-100">
|
||||
<i class="fas fa-search"></i> 搜索
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 图片搜索 -->
|
||||
<div id="imageSearch" class="search-panel" style="display: none;">
|
||||
<div class="row">
|
||||
<div class="col-md-8">
|
||||
<div class="file-upload-area" id="fileUploadArea">
|
||||
<i class="fas fa-cloud-upload-alt fa-3x text-muted mb-3"></i>
|
||||
<h5>拖拽图片到此处或点击选择</h5>
|
||||
<p class="text-muted">支持 PNG, JPG, JPEG, GIF, BMP, WebP 格式</p>
|
||||
<input type="file" id="imageFile" accept="image/*" style="display: none;">
|
||||
</div>
|
||||
</div>
|
||||
<div class="col-md-2">
|
||||
<select id="imageTopK" class="form-select search-input">
|
||||
<option value="3">Top 3</option>
|
||||
<option value="5" selected>Top 5</option>
|
||||
<option value="10">Top 10</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="col-md-2">
|
||||
<button id="imageSearchBtn" class="btn btn-primary w-100" disabled>
|
||||
<i class="fas fa-search"></i> 搜索
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 加载动画 -->
|
||||
<div class="loading-spinner" id="loadingSpinner">
|
||||
<div class="spinner-border text-primary" role="status">
|
||||
<span class="visually-hidden">Loading...</span>
|
||||
</div>
|
||||
<p class="mt-2">正在搜索中...</p>
|
||||
</div>
|
||||
|
||||
<!-- 搜索结果 -->
|
||||
<div id="searchResults" class="mt-4"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"></script>
|
||||
<script>
|
||||
let currentMode = null;
|
||||
let systemInitialized = false;
|
||||
|
||||
// 重新初始化系统
|
||||
document.getElementById('reinitBtn').addEventListener('click', async function() {
|
||||
const btn = this;
|
||||
const originalText = btn.innerHTML;
|
||||
|
||||
btn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 重新初始化中...';
|
||||
btn.disabled = true;
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/init', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'}
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
systemInitialized = true;
|
||||
document.getElementById('statusBadge').innerHTML =
|
||||
'<i class="fas fa-check-circle"></i> 已重新初始化';
|
||||
document.getElementById('statusBadge').className = 'badge bg-success';
|
||||
|
||||
showAlert('success', `系统重新初始化成功!GPU: ${data.gpu_count} 个`);
|
||||
} else {
|
||||
throw new Error(data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showAlert('danger', '重新初始化失败: ' + error.message);
|
||||
} finally {
|
||||
btn.innerHTML = originalText;
|
||||
btn.disabled = false;
|
||||
}
|
||||
});
|
||||
|
||||
// 模式选择
|
||||
document.querySelectorAll('.mode-card').forEach(card => {
|
||||
card.addEventListener('click', function() {
|
||||
|
||||
// 更新选中状态
|
||||
document.querySelectorAll('.mode-card').forEach(c => c.classList.remove('active'));
|
||||
this.classList.add('active');
|
||||
|
||||
currentMode = this.dataset.mode;
|
||||
setupSearchInterface(currentMode);
|
||||
});
|
||||
});
|
||||
|
||||
// 设置搜索界面
|
||||
function setupSearchInterface(mode) {
|
||||
document.getElementById('searchInterface').style.display = 'block';
|
||||
document.getElementById('textSearch').style.display = 'none';
|
||||
document.getElementById('imageSearch').style.display = 'none';
|
||||
document.getElementById('searchResults').innerHTML = '';
|
||||
|
||||
if (mode === 'text_to_text' || mode === 'text_to_image') {
|
||||
document.getElementById('textSearch').style.display = 'block';
|
||||
document.getElementById('textQuery').placeholder =
|
||||
mode === 'text_to_text' ? '请输入要搜索的文本...' : '请输入要搜索图片的描述...';
|
||||
} else {
|
||||
document.getElementById('imageSearch').style.display = 'block';
|
||||
}
|
||||
}
|
||||
|
||||
// 文本搜索
|
||||
document.getElementById('textSearchBtn').addEventListener('click', performTextSearch);
|
||||
document.getElementById('textQuery').addEventListener('keypress', function(e) {
|
||||
if (e.key === 'Enter') performTextSearch();
|
||||
});
|
||||
|
||||
async function performTextSearch() {
|
||||
const query = document.getElementById('textQuery').value.trim();
|
||||
const topK = parseInt(document.getElementById('textTopK').value);
|
||||
|
||||
if (!query) {
|
||||
showAlert('warning', '请输入搜索文本');
|
||||
return;
|
||||
}
|
||||
|
||||
showLoading(true);
|
||||
|
||||
try {
|
||||
const endpoint = currentMode === 'text_to_text' ? '/api/search/text_to_text' : '/api/search/text_to_image';
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: JSON.stringify({query, top_k: topK})
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
displayResults(data, currentMode);
|
||||
} else {
|
||||
throw new Error(data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showAlert('danger', '搜索失败: ' + error.message);
|
||||
} finally {
|
||||
showLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
// 图片上传处理
|
||||
const fileUploadArea = document.getElementById('fileUploadArea');
|
||||
const imageFile = document.getElementById('imageFile');
|
||||
|
||||
fileUploadArea.addEventListener('click', () => imageFile.click());
|
||||
fileUploadArea.addEventListener('dragover', handleDragOver);
|
||||
fileUploadArea.addEventListener('drop', handleDrop);
|
||||
imageFile.addEventListener('change', handleFileSelect);
|
||||
|
||||
function handleDragOver(e) {
|
||||
e.preventDefault();
|
||||
fileUploadArea.classList.add('dragover');
|
||||
}
|
||||
|
||||
function handleDrop(e) {
|
||||
e.preventDefault();
|
||||
fileUploadArea.classList.remove('dragover');
|
||||
const files = e.dataTransfer.files;
|
||||
if (files.length > 0) {
|
||||
handleFile(files[0]);
|
||||
}
|
||||
}
|
||||
|
||||
function handleFileSelect(e) {
|
||||
const file = e.target.files[0];
|
||||
if (file) handleFile(file);
|
||||
}
|
||||
|
||||
function handleFile(file) {
|
||||
if (!file.type.startsWith('image/')) {
|
||||
showAlert('warning', '请选择图片文件');
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = new FileReader();
|
||||
reader.onload = function(e) {
|
||||
fileUploadArea.innerHTML = `
|
||||
<img src="${e.target.result}" class="query-image mb-3">
|
||||
<p class="text-success"><i class="fas fa-check"></i> 图片已选择: ${file.name}</p>
|
||||
`;
|
||||
document.getElementById('imageSearchBtn').disabled = false;
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
}
|
||||
|
||||
// 图片搜索
|
||||
document.getElementById('imageSearchBtn').addEventListener('click', async function() {
|
||||
const file = imageFile.files[0];
|
||||
const topK = parseInt(document.getElementById('imageTopK').value);
|
||||
|
||||
if (!file) {
|
||||
showAlert('warning', '请选择图片文件');
|
||||
return;
|
||||
}
|
||||
|
||||
showLoading(true);
|
||||
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.append('image', file);
|
||||
formData.append('top_k', topK);
|
||||
|
||||
const endpoint = currentMode === 'image_to_text' ? '/api/search/image_to_text' : '/api/search/image_to_image';
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
displayResults(data, currentMode);
|
||||
} else {
|
||||
throw new Error(data.message);
|
||||
}
|
||||
} catch (error) {
|
||||
showAlert('danger', '搜索失败: ' + error.message);
|
||||
} finally {
|
||||
showLoading(false);
|
||||
}
|
||||
});
|
||||
|
||||
// 显示结果
|
||||
function displayResults(data, mode) {
|
||||
const resultsContainer = document.getElementById('searchResults');
|
||||
|
||||
let html = `
|
||||
<div class="fade-in">
|
||||
<div class="d-flex justify-content-between align-items-center mb-3">
|
||||
<h4><i class="fas fa-search-plus"></i> 搜索结果</h4>
|
||||
<div>
|
||||
<span class="badge bg-info">找到 ${data.result_count} 个结果</span>
|
||||
<span class="badge bg-secondary">耗时 ${data.search_time}s</span>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
if (data.query_image) {
|
||||
html += `
|
||||
<div class="result-card">
|
||||
<h6><i class="fas fa-image"></i> 查询图片</h6>
|
||||
<img src="data:image/jpeg;base64,${data.query_image}" class="query-image">
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
if (data.query) {
|
||||
html += `
|
||||
<div class="result-card">
|
||||
<h6><i class="fas fa-quote-left"></i> 查询文本</h6>
|
||||
<p class="mb-0">"${data.query}"</p>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
data.results.forEach((result, index) => {
|
||||
html += '<div class="result-card">';
|
||||
|
||||
if (mode === 'text_to_image' || mode === 'image_to_image') {
|
||||
html += `
|
||||
<div class="row">
|
||||
<div class="col-md-3">
|
||||
<img src="data:image/jpeg;base64,${result.image_base64}"
|
||||
class="result-image" alt="Result ${index + 1}">
|
||||
</div>
|
||||
<div class="col-md-9">
|
||||
<div class="d-flex justify-content-between align-items-start">
|
||||
<h6><i class="fas fa-image"></i> ${result.filename}</h6>
|
||||
<span class="score-badge">相似度: ${(result.score * 100).toFixed(1)}%</span>
|
||||
</div>
|
||||
<p class="text-muted mb-0">路径: ${result.image_path}</p>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
} else {
|
||||
html += `
|
||||
<div class="d-flex justify-content-between align-items-start">
|
||||
<div>
|
||||
<h6><i class="fas fa-file-text"></i> 结果 ${index + 1}</h6>
|
||||
<p class="mb-0">${result.text || result}</p>
|
||||
</div>
|
||||
<span class="score-badge">相似度: ${((result.score || 0.95) * 100).toFixed(1)}%</span>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
html += '</div>';
|
||||
});
|
||||
|
||||
html += '</div>';
|
||||
resultsContainer.innerHTML = html;
|
||||
}
|
||||
|
||||
// 工具函数
|
||||
function showLoading(show) {
|
||||
document.getElementById('loadingSpinner').style.display = show ? 'block' : 'none';
|
||||
}
|
||||
|
||||
function showAlert(type, message) {
|
||||
const alertDiv = document.createElement('div');
|
||||
alertDiv.className = `alert alert-${type} alert-dismissible fade show`;
|
||||
alertDiv.innerHTML = `
|
||||
${message}
|
||||
<button type="button" class="btn-close" data-bs-dismiss="alert"></button>
|
||||
`;
|
||||
|
||||
document.querySelector('.main-container .p-4').insertBefore(alertDiv, document.querySelector('.main-container .p-4').firstChild);
|
||||
|
||||
setTimeout(() => {
|
||||
if (alertDiv.parentNode) {
|
||||
alertDiv.remove();
|
||||
}
|
||||
}, 5000);
|
||||
}
|
||||
|
||||
// 检查系统状态
|
||||
async function checkStatus() {
|
||||
try {
|
||||
const response = await fetch('/api/status');
|
||||
const data = await response.json();
|
||||
|
||||
if (data.initialized) {
|
||||
systemInitialized = true;
|
||||
document.getElementById('statusBadge').innerHTML =
|
||||
'<i class="fas fa-check-circle"></i> 已初始化';
|
||||
document.getElementById('statusBadge').className = 'badge bg-success';
|
||||
} else {
|
||||
document.getElementById('statusBadge').innerHTML =
|
||||
'<i class="fas fa-exclamation-triangle"></i> 未初始化';
|
||||
document.getElementById('statusBadge').className = 'badge bg-warning';
|
||||
}
|
||||
} catch (error) {
|
||||
console.log('Status check failed:', error);
|
||||
document.getElementById('statusBadge').innerHTML =
|
||||
'<i class="fas fa-times-circle"></i> 连接失败';
|
||||
document.getElementById('statusBadge').className = 'badge bg-danger';
|
||||
}
|
||||
}
|
||||
|
||||
// 页面加载时检查状态
|
||||
checkStatus();
|
||||
|
||||
// 设置数据管理功能事件绑定
|
||||
setupDataManagement();
|
||||
|
||||
function setupDataManagement() {
|
||||
// 批量图片上传事件
|
||||
const batchImageUpload = document.getElementById('batchImageUpload');
|
||||
const batchImageFiles = document.getElementById('batchImageFiles');
|
||||
|
||||
// 拖拽上传
|
||||
batchImageUpload.addEventListener('dragover', function(e) {
|
||||
e.preventDefault();
|
||||
this.classList.add('dragover');
|
||||
});
|
||||
|
||||
batchImageUpload.addEventListener('dragleave', function(e) {
|
||||
e.preventDefault();
|
||||
this.classList.remove('dragover');
|
||||
});
|
||||
|
||||
batchImageUpload.addEventListener('drop', function(e) {
|
||||
e.preventDefault();
|
||||
this.classList.remove('dragover');
|
||||
const files = Array.from(e.dataTransfer.files).filter(file => file.type.startsWith('image/'));
|
||||
if (files.length > 0) {
|
||||
uploadBatchImages(files);
|
||||
}
|
||||
});
|
||||
|
||||
batchImageFiles.addEventListener('change', function(e) {
|
||||
if (e.target.files.length > 0) {
|
||||
uploadBatchImages(Array.from(e.target.files));
|
||||
}
|
||||
});
|
||||
|
||||
// 批量文本上传
|
||||
document.getElementById('uploadTextsBtn').addEventListener('click', function() {
|
||||
const textData = document.getElementById('batchTextInput').value.trim();
|
||||
if (textData) {
|
||||
const texts = textData.split('\n').filter(line => line.trim());
|
||||
if (texts.length > 0) {
|
||||
uploadBatchTexts(texts);
|
||||
} else {
|
||||
showAlert('warning', '请输入有效的文本数据');
|
||||
}
|
||||
} else {
|
||||
showAlert('warning', '请输入文本数据');
|
||||
}
|
||||
});
|
||||
|
||||
// 从文件导入文本
|
||||
document.getElementById('textFile').addEventListener('change', function(e) {
|
||||
const file = e.target.files[0];
|
||||
if (file) {
|
||||
const reader = new FileReader();
|
||||
reader.onload = function(e) {
|
||||
document.getElementById('batchTextInput').value = e.target.result;
|
||||
};
|
||||
reader.readAsText(file);
|
||||
}
|
||||
});
|
||||
|
||||
// 构建索引
|
||||
document.getElementById('buildIndexBtn').addEventListener('click', buildIndex);
|
||||
|
||||
// 查看数据
|
||||
document.getElementById('viewDataBtn').addEventListener('click', viewData);
|
||||
|
||||
// 清空数据
|
||||
document.getElementById('clearDataBtn').addEventListener('click', clearData);
|
||||
|
||||
// 初始化时更新数据统计
|
||||
updateDataStats();
|
||||
}
|
||||
|
||||
// 批量上传图片
|
||||
async function uploadBatchImages(files) {
|
||||
const progressDiv = document.getElementById('imageUploadProgress');
|
||||
const progressBar = progressDiv.querySelector('.progress-bar');
|
||||
const progressText = document.getElementById('imageProgressText');
|
||||
|
||||
progressDiv.style.display = 'block';
|
||||
progressText.textContent = `0/${files.length}`;
|
||||
progressBar.style.width = '0%';
|
||||
|
||||
const formData = new FormData();
|
||||
files.forEach(file => {
|
||||
formData.append('images', file);
|
||||
});
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/upload/images', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
progressBar.style.width = '100%';
|
||||
progressText.textContent = `${files.length}/${files.length}`;
|
||||
showAlert('success', `成功上传 ${data.uploaded_count} 张图片`);
|
||||
updateDataStats();
|
||||
document.getElementById('buildIndexBtn').disabled = false;
|
||||
} else {
|
||||
showAlert('danger', `上传失败: ${data.message}`);
|
||||
}
|
||||
} catch (error) {
|
||||
showAlert('danger', `上传错误: ${error.message}`);
|
||||
} finally {
|
||||
setTimeout(() => {
|
||||
progressDiv.style.display = 'none';
|
||||
}, 2000);
|
||||
}
|
||||
}
|
||||
|
||||
// 批量上传文本
|
||||
async function uploadBatchTexts(texts) {
|
||||
try {
|
||||
const response = await fetch('/api/upload/texts', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ texts: texts })
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
showAlert('success', `成功上传 ${data.uploaded_count} 条文本`);
|
||||
document.getElementById('batchTextInput').value = '';
|
||||
updateDataStats();
|
||||
document.getElementById('buildIndexBtn').disabled = false;
|
||||
} else {
|
||||
showAlert('danger', `上传失败: ${data.message}`);
|
||||
}
|
||||
} catch (error) {
|
||||
showAlert('danger', `上传错误: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 构建索引
|
||||
async function buildIndex() {
|
||||
const btn = document.getElementById('buildIndexBtn');
|
||||
const originalText = btn.innerHTML;
|
||||
btn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 构建中...';
|
||||
btn.disabled = true;
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/build_index', {
|
||||
method: 'POST'
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
showAlert('success', '索引构建完成!现在可以进行搜索了');
|
||||
} else {
|
||||
showAlert('danger', `索引构建失败: ${data.message}`);
|
||||
}
|
||||
} catch (error) {
|
||||
showAlert('danger', `构建错误: ${error.message}`);
|
||||
} finally {
|
||||
btn.innerHTML = originalText;
|
||||
btn.disabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
// 查看数据
|
||||
async function viewData() {
|
||||
try {
|
||||
const response = await fetch('/api/data/list');
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
let content = '<div class="row">';
|
||||
|
||||
// 显示图片数据
|
||||
if (data.images && data.images.length > 0) {
|
||||
content += '<div class="col-md-6"><h6>图片数据 (' + data.images.length + ')</h6>';
|
||||
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
|
||||
data.images.forEach(img => {
|
||||
content += `<div class="list-group-item d-flex justify-content-between align-items-center">
|
||||
<span>${img}</span>
|
||||
<img src="/uploads/${img}" class="img-thumbnail" style="width: 50px; height: 50px; object-fit: cover;">
|
||||
</div>`;
|
||||
});
|
||||
content += '</div></div>';
|
||||
}
|
||||
|
||||
// 显示文本数据
|
||||
if (data.texts && data.texts.length > 0) {
|
||||
content += '<div class="col-md-6"><h6>文本数据 (' + data.texts.length + ')</h6>';
|
||||
content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
|
||||
data.texts.forEach((text, index) => {
|
||||
const shortText = text.length > 50 ? text.substring(0, 50) + '...' : text;
|
||||
content += `<div class="list-group-item">
|
||||
<small class="text-muted">#${index + 1}</small><br>
|
||||
${shortText}
|
||||
</div>`;
|
||||
});
|
||||
content += '</div></div>';
|
||||
}
|
||||
|
||||
content += '</div>';
|
||||
|
||||
showModal('数据列表', content);
|
||||
} else {
|
||||
showAlert('danger', `获取数据失败: ${data.message}`);
|
||||
}
|
||||
} catch (error) {
|
||||
showAlert('danger', `获取数据错误: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 清空数据
|
||||
async function clearData() {
|
||||
if (!confirm('确定要清空所有数据吗?此操作不可恢复!')) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch('/api/data/clear', {
|
||||
method: 'POST'
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
showAlert('success', '数据已清空');
|
||||
updateDataStats();
|
||||
document.getElementById('buildIndexBtn').disabled = true;
|
||||
} else {
|
||||
showAlert('danger', `清空失败: ${data.message}`);
|
||||
}
|
||||
} catch (error) {
|
||||
showAlert('danger', `清空错误: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 更新数据统计
|
||||
async function updateDataStats() {
|
||||
try {
|
||||
const response = await fetch('/api/data/stats');
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success) {
|
||||
document.getElementById('imageCount').textContent = data.image_count || 0;
|
||||
document.getElementById('textCount').textContent = data.text_count || 0;
|
||||
}
|
||||
} catch (error) {
|
||||
console.log('获取数据统计失败:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// 显示模态框
|
||||
function showModal(title, content) {
|
||||
const modalHtml = `
|
||||
<div class="modal fade" id="dataModal" tabindex="-1">
|
||||
<div class="modal-dialog modal-lg">
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h5 class="modal-title">${title}</h5>
|
||||
<button type="button" class="btn-close" data-bs-dismiss="modal"></button>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
${content}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
|
||||
// 移除已存在的模态框
|
||||
const existingModal = document.getElementById('dataModal');
|
||||
if (existingModal) {
|
||||
existingModal.remove();
|
||||
}
|
||||
|
||||
document.body.insertAdjacentHTML('beforeend', modalHtml);
|
||||
const modal = new bootstrap.Modal(document.getElementById('dataModal'));
|
||||
modal.show();
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
104
test_all_retrieval_modes.py
Normal file
104
test_all_retrieval_modes.py
Normal file
@ -0,0 +1,104 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试所有四种多模态检索模式
|
||||
"""
|
||||
|
||||
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
def test_all_retrieval_modes():
|
||||
print('正在初始化多GPU多模态检索系统...')
|
||||
retrieval = MultiGPUMultimodalRetrieval()
|
||||
|
||||
# 准备测试数据
|
||||
test_texts = [
|
||||
"一只可爱的小猫",
|
||||
"美丽的风景照片",
|
||||
"现代建筑设计",
|
||||
"colorful flowers in garden"
|
||||
]
|
||||
|
||||
test_images = [
|
||||
'sample_images/1755677101_1__.jpg',
|
||||
'sample_images/1755677101_2__.jpg',
|
||||
'sample_images/1755677101_3__.jpg',
|
||||
'sample_images/1755677101_4__.jpg'
|
||||
]
|
||||
|
||||
# 验证测试图像存在
|
||||
existing_images = [img for img in test_images if os.path.exists(img)]
|
||||
if not existing_images:
|
||||
print("❌ 没有找到测试图像文件")
|
||||
return
|
||||
|
||||
print(f"找到 {len(existing_images)} 张测试图像")
|
||||
|
||||
try:
|
||||
# 1. 构建文本索引
|
||||
print('\n=== 构建文本索引 ===')
|
||||
retrieval.build_text_index_parallel(test_texts)
|
||||
print('✅ 文本索引构建完成')
|
||||
|
||||
# 2. 构建图像索引
|
||||
print('\n=== 构建图像索引 ===')
|
||||
retrieval.build_image_index_parallel(existing_images)
|
||||
print('✅ 图像索引构建完成')
|
||||
|
||||
# 3. 测试文本到文本检索
|
||||
print('\n=== 测试文本到文本检索 ===')
|
||||
query = "小动物"
|
||||
results = retrieval.search_text_by_text(query, top_k=3)
|
||||
print(f'查询: "{query}"')
|
||||
for i, (text, score) in enumerate(results):
|
||||
print(f' {i+1}. {text} (相似度: {score:.4f})')
|
||||
|
||||
# 4. 测试文本到图像检索
|
||||
print('\n=== 测试文本到图像检索 ===')
|
||||
query = "beautiful image"
|
||||
results = retrieval.search_images_by_text(query, top_k=3)
|
||||
print(f'查询: "{query}"')
|
||||
for i, (image_path, score) in enumerate(results):
|
||||
print(f' {i+1}. {image_path} (相似度: {score:.4f})')
|
||||
|
||||
# 5. 测试图像到文本检索
|
||||
print('\n=== 测试图像到文本检索 ===')
|
||||
query_image = existing_images[0]
|
||||
results = retrieval.search_text_by_image(query_image, top_k=3)
|
||||
print(f'查询图像: {query_image}')
|
||||
for i, (text, score) in enumerate(results):
|
||||
print(f' {i+1}. {text} (相似度: {score:.4f})')
|
||||
|
||||
# 6. 测试图像到图像检索
|
||||
print('\n=== 测试图像到图像检索 ===')
|
||||
query_image = existing_images[0]
|
||||
results = retrieval.search_images_by_image(query_image, top_k=3)
|
||||
print(f'查询图像: {query_image}')
|
||||
for i, (image_path, score) in enumerate(results):
|
||||
print(f' {i+1}. {image_path} (相似度: {score:.4f})')
|
||||
|
||||
print('\n✅ 所有四种检索模式测试完成!')
|
||||
|
||||
# 7. 测试Web应用兼容的方法名
|
||||
print('\n=== 测试Web应用兼容方法 ===')
|
||||
try:
|
||||
results = retrieval.search_text_to_image("test query", top_k=2)
|
||||
print('✅ search_text_to_image 方法正常')
|
||||
|
||||
results = retrieval.search_image_to_text(existing_images[0], top_k=2)
|
||||
print('✅ search_image_to_text 方法正常')
|
||||
|
||||
results = retrieval.search_image_to_image(existing_images[0], top_k=2)
|
||||
print('✅ search_image_to_image 方法正常')
|
||||
|
||||
except Exception as e:
|
||||
print(f'❌ Web应用兼容方法测试失败: {e}')
|
||||
|
||||
except Exception as e:
|
||||
print(f'❌ 测试过程中出现错误: {e}')
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_all_retrieval_modes()
|
||||
49
test_image_encoding.py
Normal file
49
test_image_encoding.py
Normal file
@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试图像编码功能
|
||||
"""
|
||||
|
||||
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
def test_image_encoding():
|
||||
print('正在初始化多GPU多模态检索系统...')
|
||||
retrieval = MultiGPUMultimodalRetrieval()
|
||||
|
||||
# 测试文本编码
|
||||
print('测试文本编码...')
|
||||
text_embeddings = retrieval.encode_text_batch(['这是一个测试文本'])
|
||||
print(f'文本embedding形状: {text_embeddings.shape}')
|
||||
print(f'文本embedding数据类型: {text_embeddings.dtype}')
|
||||
|
||||
# 测试图像编码
|
||||
print('测试图像编码...')
|
||||
test_images = ['sample_images/1755677101_1__.jpg']
|
||||
image_embeddings = retrieval.encode_image_batch(test_images)
|
||||
print(f'图像embedding形状: {image_embeddings.shape}')
|
||||
print(f'图像embedding数据类型: {image_embeddings.dtype}')
|
||||
|
||||
# 测试两次相同图像的embedding是否一致
|
||||
print('测试embedding一致性...')
|
||||
image_embeddings2 = retrieval.encode_image_batch(test_images)
|
||||
consistency = np.allclose(image_embeddings, image_embeddings2, rtol=1e-5)
|
||||
print(f'相同图像embedding一致性: {consistency}')
|
||||
|
||||
# 测试不同图像的embedding差异
|
||||
print('测试不同图像embedding差异...')
|
||||
test_images2 = ['sample_images/1755677101_2__.jpg']
|
||||
image_embeddings3 = retrieval.encode_image_batch(test_images2)
|
||||
similarity = np.dot(image_embeddings[0], image_embeddings3[0]) / (np.linalg.norm(image_embeddings[0]) * np.linalg.norm(image_embeddings3[0]))
|
||||
print(f'不同图像间相似度: {similarity:.4f}')
|
||||
|
||||
# 验证维度一致性
|
||||
if text_embeddings.shape[1] == image_embeddings.shape[1]:
|
||||
print('✅ 文本和图像embedding维度一致')
|
||||
else:
|
||||
print('❌ 文本和图像embedding维度不一致')
|
||||
|
||||
print('测试完成!')
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_image_encoding()
|
||||
BIN
uploads/query_1755675080_10__.jpg
Normal file
BIN
uploads/query_1755675080_10__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 172 KiB |
BIN
uploads/query_1755681423_5__.jpg
Normal file
BIN
uploads/query_1755681423_5__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 201 KiB |
617
web_app_multigpu.py
Normal file
617
web_app_multigpu.py
Normal file
@ -0,0 +1,617 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
多GPU多模态检索系统 - Web应用
|
||||
专为双GPU部署优化
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from flask import Flask, render_template, request, jsonify, send_file, url_for
|
||||
from werkzeug.utils import secure_filename
|
||||
from PIL import Image
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import traceback
|
||||
import glob
|
||||
|
||||
# 设置环境变量优化GPU内存
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config['SECRET_KEY'] = 'multigpu_multimodal_retrieval_2024'
|
||||
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
||||
|
||||
# 配置上传文件夹
|
||||
UPLOAD_FOLDER = 'uploads'
|
||||
SAMPLE_IMAGES_FOLDER = 'sample_images'
|
||||
TEXT_DATA_FOLDER = 'text_data'
|
||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
|
||||
|
||||
# 确保文件夹存在
|
||||
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
||||
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
|
||||
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)
|
||||
|
||||
# 全局检索系统实例
|
||||
retrieval_system = None
|
||||
|
||||
def allowed_file(filename):
|
||||
"""检查文件扩展名是否允许"""
|
||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
||||
|
||||
def image_to_base64(image_path):
|
||||
"""将图片转换为base64编码"""
|
||||
try:
|
||||
with open(image_path, "rb") as img_file:
|
||||
return base64.b64encode(img_file.read()).decode('utf-8')
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换失败: {e}")
|
||||
return None
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
"""主页"""
|
||||
return render_template('index.html')
|
||||
|
||||
@app.route('/api/status')
|
||||
def get_status():
|
||||
"""获取系统状态"""
|
||||
global retrieval_system
|
||||
|
||||
status = {
|
||||
'initialized': retrieval_system is not None,
|
||||
'gpu_count': 0,
|
||||
'model_loaded': False
|
||||
}
|
||||
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
status['gpu_count'] = torch.cuda.device_count()
|
||||
|
||||
if retrieval_system and retrieval_system.model:
|
||||
status['model_loaded'] = True
|
||||
status['device_ids'] = retrieval_system.device_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取状态失败: {e}")
|
||||
|
||||
return jsonify(status)
|
||||
|
||||
@app.route('/api/init', methods=['POST'])
|
||||
def initialize_system():
|
||||
"""初始化多GPU检索系统"""
|
||||
global retrieval_system
|
||||
|
||||
try:
|
||||
logger.info("正在初始化多GPU检索系统...")
|
||||
|
||||
# 导入多GPU检索系统
|
||||
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
||||
|
||||
# 初始化系统
|
||||
retrieval_system = MultiGPUMultimodalRetrieval()
|
||||
|
||||
if retrieval_system.model is None:
|
||||
raise Exception("模型加载失败")
|
||||
|
||||
logger.info("✅ 多GPU系统初始化成功")
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'message': '多GPU系统初始化成功',
|
||||
'device_ids': retrieval_system.device_ids,
|
||||
'gpu_count': len(retrieval_system.device_ids)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"系统初始化失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': error_msg
|
||||
}), 500
|
||||
|
||||
@app.route('/api/search/text_to_text', methods=['POST'])
|
||||
def search_text_to_text():
|
||||
"""文本搜索文本"""
|
||||
return handle_search('text_to_text')
|
||||
|
||||
@app.route('/api/search/text_to_image', methods=['POST'])
|
||||
def search_text_to_image():
|
||||
"""文本搜索图片"""
|
||||
return handle_search('text_to_image')
|
||||
|
||||
@app.route('/api/search/image_to_text', methods=['POST'])
|
||||
def search_image_to_text():
|
||||
"""图片搜索文本"""
|
||||
return handle_search('image_to_text')
|
||||
|
||||
@app.route('/api/search/image_to_image', methods=['POST'])
|
||||
def search_image_to_image():
|
||||
"""图片搜索图片"""
|
||||
return handle_search('image_to_image')
|
||||
|
||||
@app.route('/api/search', methods=['POST'])
|
||||
def search():
|
||||
"""通用搜索接口(兼容旧版本)"""
|
||||
mode = request.form.get('mode') or request.json.get('mode', 'text_to_text')
|
||||
return handle_search(mode)
|
||||
|
||||
def handle_search(mode):
|
||||
"""处理搜索请求的通用函数"""
|
||||
global retrieval_system
|
||||
|
||||
if not retrieval_system:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': '系统未初始化,请先点击初始化按钮'
|
||||
}), 400
|
||||
|
||||
try:
|
||||
top_k = int(request.form.get('top_k', 5))
|
||||
|
||||
if mode in ['text_to_text', 'text_to_image']:
|
||||
# 文本查询
|
||||
query = request.form.get('query') or request.json.get('query', '')
|
||||
if not query.strip():
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': '请输入查询文本'
|
||||
}), 400
|
||||
|
||||
logger.info(f"执行{mode}搜索: {query}")
|
||||
|
||||
# 执行搜索
|
||||
if mode == 'text_to_text':
|
||||
results = retrieval_system.search_text_to_text(query, top_k=top_k)
|
||||
else: # text_to_image
|
||||
results = retrieval_system.search_text_to_image(query, top_k=top_k)
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'mode': mode,
|
||||
'query': query,
|
||||
'results': results,
|
||||
'result_count': len(results)
|
||||
})
|
||||
|
||||
elif mode in ['image_to_text', 'image_to_image']:
|
||||
# 图片查询
|
||||
if 'image' not in request.files:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': '请上传查询图片'
|
||||
}), 400
|
||||
|
||||
file = request.files['image']
|
||||
if file.filename == '' or not allowed_file(file.filename):
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': '请上传有效的图片文件'
|
||||
}), 400
|
||||
|
||||
# 保存上传的图片
|
||||
filename = secure_filename(file.filename)
|
||||
timestamp = str(int(time.time()))
|
||||
filename = f"query_{timestamp}_{filename}"
|
||||
filepath = os.path.join(UPLOAD_FOLDER, filename)
|
||||
file.save(filepath)
|
||||
|
||||
logger.info(f"执行{mode}搜索,图片: {filename}")
|
||||
|
||||
# 执行搜索
|
||||
if mode == 'image_to_text':
|
||||
results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
|
||||
else: # image_to_image
|
||||
results = retrieval_system.search_image_to_image(filepath, top_k=top_k)
|
||||
|
||||
# 转换查询图片为base64
|
||||
query_image_b64 = image_to_base64(filepath)
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'mode': mode,
|
||||
'query_image': query_image_b64,
|
||||
'results': results,
|
||||
'result_count': len(results)
|
||||
})
|
||||
|
||||
else:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': f'不支持的搜索模式: {mode}'
|
||||
}), 400
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"搜索失败: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': error_msg
|
||||
}), 500
|
||||
|
||||
@app.route('/api/upload/images', methods=['POST'])
|
||||
def upload_images():
|
||||
"""批量上传图片"""
|
||||
try:
|
||||
uploaded_files = []
|
||||
|
||||
if 'images' not in request.files:
|
||||
return jsonify({'success': False, 'message': '没有选择文件'}), 400
|
||||
|
||||
files = request.files.getlist('images')
|
||||
|
||||
for file in files:
|
||||
if file and file.filename != '' and allowed_file(file.filename):
|
||||
filename = secure_filename(file.filename)
|
||||
timestamp = str(int(time.time()))
|
||||
filename = f"{timestamp}_{filename}"
|
||||
filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
|
||||
file.save(filepath)
|
||||
uploaded_files.append(filename)
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'message': f'成功上传 {len(uploaded_files)} 个图片文件',
|
||||
'uploaded_count': len(uploaded_files),
|
||||
'files': uploaded_files
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': f'上传失败: {str(e)}'
|
||||
}), 500
|
||||
|
||||
@app.route('/api/upload/texts', methods=['POST'])
|
||||
def upload_texts():
|
||||
"""批量上传文本数据"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
|
||||
if not data or 'texts' not in data:
|
||||
return jsonify({'success': False, 'message': '没有提供文本数据'}), 400
|
||||
|
||||
texts = data['texts']
|
||||
if not isinstance(texts, list):
|
||||
return jsonify({'success': False, 'message': '文本数据格式错误'}), 400
|
||||
|
||||
# 保存文本数据到文件
|
||||
timestamp = str(int(time.time()))
|
||||
filename = f"texts_{timestamp}.json"
|
||||
filepath = os.path.join(TEXT_DATA_FOLDER, filename)
|
||||
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(texts, f, ensure_ascii=False, indent=2)
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'message': f'成功上传 {len(texts)} 条文本',
|
||||
'uploaded_count': len(texts)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': f'上传失败: {str(e)}'
|
||||
}), 500
|
||||
|
||||
@app.route('/api/upload/file', methods=['POST'])
|
||||
def upload_single_file():
|
||||
"""上传单个文件"""
|
||||
if 'file' not in request.files:
|
||||
return jsonify({'success': False, 'message': '没有选择文件'}), 400
|
||||
|
||||
file = request.files['file']
|
||||
if file.filename == '':
|
||||
return jsonify({'success': False, 'message': '没有选择文件'}), 400
|
||||
|
||||
if file and allowed_file(file.filename):
|
||||
filename = secure_filename(file.filename)
|
||||
timestamp = str(int(time.time()))
|
||||
filename = f"{timestamp}_{filename}"
|
||||
filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
|
||||
file.save(filepath)
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'message': '文件上传成功',
|
||||
'filename': filename
|
||||
})
|
||||
|
||||
return jsonify({'success': False, 'message': '不支持的文件类型'}), 400
|
||||
|
||||
@app.route('/api/data/list', methods=['GET'])
|
||||
def list_data():
|
||||
"""列出已上传的数据"""
|
||||
try:
|
||||
# 列出图片文件
|
||||
images = []
|
||||
if os.path.exists(SAMPLE_IMAGES_FOLDER):
|
||||
for filename in os.listdir(SAMPLE_IMAGES_FOLDER):
|
||||
if allowed_file(filename):
|
||||
filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
|
||||
stat = os.stat(filepath)
|
||||
images.append({
|
||||
'filename': filename,
|
||||
'size': stat.st_size,
|
||||
'modified': stat.st_mtime
|
||||
})
|
||||
|
||||
# 列出文本文件
|
||||
texts = []
|
||||
if os.path.exists(TEXT_DATA_FOLDER):
|
||||
for filename in os.listdir(TEXT_DATA_FOLDER):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(TEXT_DATA_FOLDER, filename)
|
||||
stat = os.stat(filepath)
|
||||
texts.append({
|
||||
'filename': filename,
|
||||
'size': stat.st_size,
|
||||
'modified': stat.st_mtime
|
||||
})
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'data': {
|
||||
'images': images,
|
||||
'texts': texts
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': f'获取数据列表失败: {str(e)}'
|
||||
}), 500
|
||||
|
||||
@app.route('/api/gpu_status')
|
||||
def gpu_status():
|
||||
"""获取GPU状态"""
|
||||
try:
|
||||
from smart_gpu_launcher import get_gpu_memory_info
|
||||
gpu_info = get_gpu_memory_info()
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'gpu_info': gpu_info
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': f"获取GPU状态失败: {str(e)}"
|
||||
}), 500
|
||||
|
||||
@app.route('/api/build_index', methods=['POST'])
|
||||
def build_index():
|
||||
"""构建检索索引"""
|
||||
global retrieval_system
|
||||
|
||||
if not retrieval_system:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': '系统未初始化'
|
||||
}), 400
|
||||
|
||||
try:
|
||||
# 获取所有图片和文本文件
|
||||
image_files = []
|
||||
text_data = []
|
||||
|
||||
# 扫描图片文件
|
||||
for ext in ALLOWED_EXTENSIONS:
|
||||
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
||||
image_files.extend(glob.glob(pattern))
|
||||
|
||||
# 读取文本文件(支持.json和.txt格式)
|
||||
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
|
||||
text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))
|
||||
|
||||
for text_file in text_files:
|
||||
try:
|
||||
if text_file.endswith('.json'):
|
||||
# 读取JSON格式的文本数据
|
||||
with open(text_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, list):
|
||||
text_data.extend([str(item).strip() for item in data if str(item).strip()])
|
||||
else:
|
||||
text_data.append(str(data).strip())
|
||||
else:
|
||||
# 读取TXT格式的文本数据
|
||||
with open(text_file, 'r', encoding='utf-8') as f:
|
||||
lines = [line.strip() for line in f.readlines() if line.strip()]
|
||||
text_data.extend(lines)
|
||||
except Exception as e:
|
||||
logger.warning(f"读取文本文件失败 {text_file}: {e}")
|
||||
|
||||
# 检查是否有数据可以构建索引
|
||||
if not image_files and not text_data:
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': '没有找到可用的图片或文本数据,请先上传数据'
|
||||
}), 400
|
||||
|
||||
# 构建索引
|
||||
if image_files:
|
||||
logger.info(f"构建图片索引,共 {len(image_files)} 张图片")
|
||||
retrieval_system.build_image_index_parallel(image_files)
|
||||
|
||||
if text_data:
|
||||
logger.info(f"构建文本索引,共 {len(text_data)} 条文本")
|
||||
retrieval_system.build_text_index_parallel(text_data)
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'message': f'索引构建完成!图片: {len(image_files)} 张,文本: {len(text_data)} 条',
|
||||
'image_count': len(image_files),
|
||||
'text_count': len(text_data)
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建索引失败: {str(e)}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': f'构建索引失败: {str(e)}'
|
||||
}), 500
|
||||
|
||||
@app.route('/api/data/stats', methods=['GET'])
|
||||
def get_data_stats():
|
||||
"""获取数据统计信息"""
|
||||
try:
|
||||
# 统计图片文件
|
||||
image_count = 0
|
||||
for ext in ALLOWED_EXTENSIONS:
|
||||
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
||||
image_count += len(glob.glob(pattern))
|
||||
|
||||
# 统计文本数据
|
||||
text_count = 0
|
||||
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
|
||||
for text_file in text_files:
|
||||
try:
|
||||
with open(text_file, 'r', encoding='utf-8') as f:
|
||||
lines = [line.strip() for line in f.readlines() if line.strip()]
|
||||
text_count += len(lines)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'image_count': image_count,
|
||||
'text_count': text_count
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取数据统计失败: {str(e)}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': f'获取统计失败: {str(e)}'
|
||||
}), 500
|
||||
|
||||
@app.route('/api/data/clear', methods=['POST'])
|
||||
def clear_data():
|
||||
"""清空所有数据"""
|
||||
try:
|
||||
# 清空图片文件
|
||||
for ext in ALLOWED_EXTENSIONS:
|
||||
pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
|
||||
for file_path in glob.glob(pattern):
|
||||
try:
|
||||
os.remove(file_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"删除图片文件失败 {file_path}: {e}")
|
||||
|
||||
# 清空文本文件
|
||||
text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
|
||||
for text_file in text_files:
|
||||
try:
|
||||
os.remove(text_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"删除文本文件失败 {text_file}: {e}")
|
||||
|
||||
# 重置索引
|
||||
global retrieval_system
|
||||
if retrieval_system:
|
||||
retrieval_system.text_index = None
|
||||
retrieval_system.image_index = None
|
||||
retrieval_system.text_data = []
|
||||
retrieval_system.image_data = []
|
||||
|
||||
return jsonify({
|
||||
'success': True,
|
||||
'message': '数据已清空'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清空数据失败: {str(e)}")
|
||||
return jsonify({
|
||||
'success': False,
|
||||
'message': f'清空数据失败: {str(e)}'
|
||||
}), 500
|
||||
|
||||
@app.route('/uploads/<filename>')
|
||||
def uploaded_file(filename):
|
||||
"""提供上传文件的访问"""
|
||||
return send_file(os.path.join(SAMPLE_IMAGES_FOLDER, filename))
|
||||
|
||||
def print_startup_info():
|
||||
"""打印启动信息"""
|
||||
print("🚀 启动多GPU多模态检索Web应用")
|
||||
print("=" * 60)
|
||||
print("访问地址: http://localhost:5000")
|
||||
print("支持功能:")
|
||||
print(" 📝 文搜文 - 文本查找相似文本")
|
||||
print(" 🖼️ 文搜图 - 文本查找相关图片")
|
||||
print(" 📝 图搜文 - 图片查找相关文本")
|
||||
print(" 🖼️ 图搜图 - 图片查找相似图片")
|
||||
print(" 📤 批量上传 - 图片和文本数据管理")
|
||||
print("GPU配置:")
|
||||
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
gpu_count = torch.cuda.device_count()
|
||||
print(f" 🖥️ 检测到 {gpu_count} 个GPU")
|
||||
for i in range(gpu_count):
|
||||
name = torch.cuda.get_device_name(i)
|
||||
props = torch.cuda.get_device_properties(i)
|
||||
memory_gb = props.total_memory / 1024**3
|
||||
print(f" GPU {i}: {name} ({memory_gb:.1f}GB)")
|
||||
else:
|
||||
print(" ❌ CUDA不可用")
|
||||
except Exception as e:
|
||||
print(f" ❌ GPU检查失败: {e}")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
def auto_initialize():
|
||||
"""启动时自动初始化系统"""
|
||||
global retrieval_system
|
||||
|
||||
try:
|
||||
logger.info("🚀 启动时自动初始化多GPU检索系统...")
|
||||
|
||||
# 导入多GPU检索系统
|
||||
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
||||
|
||||
# 初始化系统
|
||||
retrieval_system = MultiGPUMultimodalRetrieval()
|
||||
|
||||
if retrieval_system.model is None:
|
||||
raise Exception("模型加载失败")
|
||||
|
||||
logger.info("✅ 系统自动初始化成功")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 系统自动初始化失败: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
if __name__ == '__main__':
|
||||
print_startup_info()
|
||||
|
||||
# 启动时自动初始化
|
||||
auto_initialize()
|
||||
|
||||
# 启动Flask应用
|
||||
app.run(
|
||||
host='0.0.0.0',
|
||||
port=5000,
|
||||
debug=False,
|
||||
threaded=True
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user