base feature
This commit is contained in:
parent
1de00fccda
commit
0cd7a4cb41
201
LICENSE
201
LICENSE
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
170
README.md
Normal file
170
README.md
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
# 多模态检索系统 (Multimodal Retrieval System)
|
||||||
|
|
||||||
|
基于 OpenSearch-AI/Ops-MM-embedding-v1-7B 模型的多模态检索系统,支持四种检索模式:文搜图、文搜文、图搜图、图搜文。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- **文搜文 (Text-to-Text)**: 使用文本查询搜索相似文本
|
||||||
|
- **文搜图 (Text-to-Image)**: 使用文本查询搜索相关图像
|
||||||
|
- **图搜图 (Image-to-Image)**: 使用图像查询搜索相似图像
|
||||||
|
- **图搜文 (Image-to-Text)**: 使用图像查询搜索相关文本
|
||||||
|
|
||||||
|
## 环境要求
|
||||||
|
|
||||||
|
- Python 3.8+
|
||||||
|
- CUDA 支持的 GPU (推荐,也可使用CPU)
|
||||||
|
- 至少 8GB 内存
|
||||||
|
|
||||||
|
## 安装依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 1. 检查GPU环境
|
||||||
|
|
||||||
|
```python
|
||||||
|
from multimodal_retrieval import check_gpu_info
|
||||||
|
check_gpu_info()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 初始化系统
|
||||||
|
|
||||||
|
```python
|
||||||
|
from multimodal_retrieval import MultimodalRetrieval
|
||||||
|
|
||||||
|
# 初始化检索系统
|
||||||
|
retrieval_system = MultimodalRetrieval()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 构建索引
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 构建文本索引
|
||||||
|
texts = ["一只可爱的小猫", "美丽的山景", "现代化城市"]
|
||||||
|
retrieval_system.build_text_index(texts)
|
||||||
|
|
||||||
|
# 构建图像索引
|
||||||
|
image_paths = ["./images/cat.jpg", "./images/mountain.jpg", "./images/city.jpg"]
|
||||||
|
retrieval_system.build_image_index(image_paths)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 执行检索
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 文搜文
|
||||||
|
results = retrieval_system.search_text_by_text("猫咪", top_k=5)
|
||||||
|
|
||||||
|
# 文搜图
|
||||||
|
results = retrieval_system.search_images_by_text("动物", top_k=5)
|
||||||
|
|
||||||
|
# 图搜图
|
||||||
|
results = retrieval_system.search_images_by_image("./query_image.jpg", top_k=5)
|
||||||
|
|
||||||
|
# 图搜文
|
||||||
|
results = retrieval_system.search_text_by_image("./query_image.jpg", top_k=5)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 运行演示
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python demo.py
|
||||||
|
```
|
||||||
|
|
||||||
|
演示脚本会自动:
|
||||||
|
1. 检查GPU环境信息
|
||||||
|
2. 初始化多模态检索系统
|
||||||
|
3. 演示四种检索模式
|
||||||
|
4. 显示检索结果和相似度分数
|
||||||
|
|
||||||
|
## 文件结构
|
||||||
|
|
||||||
|
```
|
||||||
|
mmeb/
|
||||||
|
├── multimodal_retrieval.py # 主要检索系统类
|
||||||
|
├── demo.py # 演示脚本
|
||||||
|
├── requirements.txt # 依赖包列表
|
||||||
|
├── README.md # 项目说明
|
||||||
|
└── images/ # 图像数据目录 (需要自行创建)
|
||||||
|
```
|
||||||
|
|
||||||
|
## API 参考
|
||||||
|
|
||||||
|
### MultimodalRetrieval 类
|
||||||
|
|
||||||
|
#### 初始化参数
|
||||||
|
- `model_name`: 模型名称,默认 "OpenSearch-AI/Ops-MM-embedding-v1-7B"
|
||||||
|
- `device`: 设备类型,默认自动选择 ("cuda" 或 "cpu")
|
||||||
|
|
||||||
|
#### 主要方法
|
||||||
|
|
||||||
|
##### `build_text_index(texts, save_path=None)`
|
||||||
|
构建文本索引
|
||||||
|
- `texts`: 文本列表
|
||||||
|
- `save_path`: 索引保存路径 (可选)
|
||||||
|
|
||||||
|
##### `build_image_index(image_paths, save_path=None)`
|
||||||
|
构建图像索引
|
||||||
|
- `image_paths`: 图像路径列表
|
||||||
|
- `save_path`: 索引保存路径 (可选)
|
||||||
|
|
||||||
|
##### `search_text_by_text(query, top_k=5)`
|
||||||
|
文搜文检索
|
||||||
|
- `query`: 查询文本
|
||||||
|
- `top_k`: 返回结果数量
|
||||||
|
- 返回: `[(文本, 相似度分数), ...]`
|
||||||
|
|
||||||
|
##### `search_images_by_text(query, top_k=5)`
|
||||||
|
文搜图检索
|
||||||
|
- `query`: 查询文本
|
||||||
|
- `top_k`: 返回结果数量
|
||||||
|
- 返回: `[(图像路径, 相似度分数), ...]`
|
||||||
|
|
||||||
|
##### `search_images_by_image(query_image, top_k=5)`
|
||||||
|
图搜图检索
|
||||||
|
- `query_image`: 查询图像路径或PIL图像
|
||||||
|
- `top_k`: 返回结果数量
|
||||||
|
- 返回: `[(图像路径, 相似度分数), ...]`
|
||||||
|
|
||||||
|
##### `search_text_by_image(query_image, top_k=5)`
|
||||||
|
图搜文检索
|
||||||
|
- `query_image`: 查询图像路径或PIL图像
|
||||||
|
- `top_k`: 返回结果数量
|
||||||
|
- 返回: `[(文本, 相似度分数), ...]`
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. **首次运行**: 首次运行时会自动下载模型,需要网络连接
|
||||||
|
2. **内存需求**: 7B参数模型需要较大内存,建议使用GPU
|
||||||
|
3. **图像格式**: 支持常见图像格式 (jpg, png, bmp, gif等)
|
||||||
|
4. **批处理**: 系统自动进行批处理以提高效率
|
||||||
|
5. **索引保存**: 可以保存和加载索引以避免重复构建
|
||||||
|
|
||||||
|
## 性能优化建议
|
||||||
|
|
||||||
|
- 使用GPU加速推理
|
||||||
|
- 合理设置批处理大小
|
||||||
|
- 保存索引文件避免重复构建
|
||||||
|
- 对大量数据使用分批处理
|
||||||
|
|
||||||
|
## 故障排除
|
||||||
|
|
||||||
|
### 常见问题
|
||||||
|
|
||||||
|
1. **CUDA内存不足**: 减小批处理大小或使用CPU
|
||||||
|
2. **模型下载失败**: 检查网络连接或使用镜像源
|
||||||
|
3. **图像加载错误**: 检查图像文件路径和格式
|
||||||
|
|
||||||
|
### 日志信息
|
||||||
|
|
||||||
|
系统会输出详细的日志信息,包括:
|
||||||
|
- GPU环境检测结果
|
||||||
|
- 模型加载进度
|
||||||
|
- 索引构建状态
|
||||||
|
- 检索执行情况
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
本项目遵循 MIT 许可证。
|
||||||
BIN
__pycache__/multimodal_retrieval_multigpu.cpython-310.pyc
Normal file
BIN
__pycache__/multimodal_retrieval_multigpu.cpython-310.pyc
Normal file
Binary file not shown.
632
multimodal_retrieval_multigpu.py
Normal file
632
multimodal_retrieval_multigpu.py
Normal file
@ -0,0 +1,632 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import faiss
|
||||||
|
from transformers import AutoModel, AutoProcessor, AutoTokenizer
|
||||||
|
from typing import List, Union, Tuple, Dict
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# 设置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class MultiGPUMultimodalRetrieval:
|
||||||
|
"""多GPU优化的多模态检索系统,支持文搜图、文搜文、图搜图、图搜文"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B",
             use_all_gpus: bool = True, gpu_ids: List[int] = None, min_memory_gb=12):
    """Initialize the multi-GPU multimodal retrieval system.

    Args:
        model_name: Hugging Face model identifier to load.
        use_all_gpus: auto-select every GPU with sufficient free memory.
        gpu_ids: explicit GPU id list (used only when use_all_gpus is False).
        min_memory_gb: minimum free memory (GB) a GPU must have to be picked.
    """
    self.model_name = model_name

    # Pick the CUDA devices this instance will run on.
    self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb)

    # Free cached CUDA memory before the (large) model is loaded.
    self._clear_all_gpu_memory()

    logger.info(f"正在加载模型到多GPU: {self.device_ids}")

    # These are filled in by _load_model_multigpu().
    self.model = None
    self.tokenizer = None
    self.processor = None
    self._load_model_multigpu()

    # FAISS indices plus the payload lists their row ids map back to.
    self.text_index = None
    self.image_index = None
    self.text_data = []
    self.image_data = []

    logger.info("多GPU模型加载完成")
|
||||||
|
|
||||||
|
def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb=12):
    """Resolve the list of CUDA device ids and the primary device string.

    Priority: CUDA_VISIBLE_DEVICES env var > use_all_gpus auto-selection >
    explicit gpu_ids > GPU 0. Raises RuntimeError when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA不可用,无法使用多GPU")

    total_gpus = torch.cuda.device_count()
    logger.info(f"检测到 {total_gpus} 个GPU")

    # When CUDA_VISIBLE_DEVICES is set, torch renumbers the visible GPUs
    # from 0, so we simply use 0..N-1 for the N visible devices.
    cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if cuda_visible_devices is not None:
        visible = cuda_visible_devices.split(',')
        self.device_ids = list(range(len(visible)))
        logger.info(f"使用CUDA_VISIBLE_DEVICES指定的GPU: {cuda_visible_devices}")
    elif use_all_gpus:
        self.device_ids = self._select_best_gpus(min_memory_gb)
    elif gpu_ids:
        self.device_ids = gpu_ids
    else:
        self.device_ids = [0]

    self.num_gpus = len(self.device_ids)
    self.primary_device = f"cuda:{self.device_ids[0]}"

    logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}")
|
||||||
|
|
||||||
|
def _clear_all_gpu_memory(self):
    """Release cached CUDA memory on every selected device.

    NOTE(review): reconstructed from an indentation-mangled source; the
    per-device gc.collect() placement matches the sibling _clear_gpu_memory.
    """
    for dev in self.device_ids:
        torch.cuda.set_device(dev)
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()
    logger.info("所有GPU内存已清理")
|
||||||
|
|
||||||
|
def _load_model_multigpu(self):
    """Load model, tokenizer and processor, sharding across GPUs when possible.

    Returns:
        True on success, False when the model or tokenizer failed to load.
        A processor failure is non-fatal: the tokenizer is used as fallback.
    """
    try:
        # Reduce fragmentation for large-model loading.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

        self._clear_gpu_memory()

        if self.num_gpus > 1:
            # Cap each GPU at 18 GiB so accelerate's device_map="auto"
            # leaves headroom for activations.
            max_memory = {i: "18GiB" for i in self.device_ids}

            logger.info(f"正在加载模型到多GPU: {self.device_ids}")
            self.model = AutoModel.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map="auto",
                max_memory=max_memory,
                low_cpu_mem_usage=True,
                offload_folder="./offload",
            )
        else:
            # Single-GPU path: pin the whole model to the primary device.
            self.model = AutoModel.from_pretrained(
                self.model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map=self.primary_device,
            )

        # Tokenizer is mandatory; bail out if it cannot be loaded.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
            )
            logger.info("Tokenizer加载成功")
        except Exception as e:
            logger.error(f"Tokenizer加载失败: {e}")
            return False

        # Processor is optional; fall back to the tokenizer when missing.
        try:
            self.processor = AutoProcessor.from_pretrained(
                self.model_name,
                trust_remote_code=True,
            )
            logger.info("Processor加载成功")
        except Exception as e:
            logger.warning(f"Processor加载失败: {e}")
            logger.info("尝试使用tokenizer作为processor的fallback")
            self.processor = self.tokenizer

        logger.info(f"模型已成功加载到设备: {self.model.hf_device_map if hasattr(self.model, 'hf_device_map') else self.primary_device}")
        logger.info("多GPU模型加载完成")
        return True

    except Exception as e:
        logger.error(f"多GPU模型加载失败: {str(e)}")
        return False
|
||||||
|
|
||||||
|
def _clear_gpu_memory(self):
    """Release cached CUDA memory on the selected devices.

    NOTE(review): functionally duplicates _clear_all_gpu_memory apart from
    the log message; kept for interface compatibility.
    """
    for dev in self.device_ids:
        torch.cuda.set_device(dev)
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()
    logger.info("GPU内存已清理")
|
||||||
|
|
||||||
|
def _get_gpu_memory_info(self):
    """Query per-GPU memory usage via nvidia-smi.

    Returns:
        A list of dicts with keys gpu_id/used/total/free (MB) and
        usage_percent; an empty list when nvidia-smi is unavailable.
    """
    try:
        import subprocess
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv,nounits,noheader'],
            capture_output=True, text=True, check=True)

        info = []
        for idx, row in enumerate(result.stdout.strip().split('\n')):
            used, total = map(int, row.split(', '))
            info.append({
                'gpu_id': idx,
                'used': used,
                'total': total,
                'free': total - used,
                'usage_percent': (used / total) * 100,
            })
        return info
    except Exception as e:
        logger.warning(f"无法获取GPU内存信息: {e}")
        return []
|
||||||
|
|
||||||
|
def _select_best_gpus(self, min_memory_gb=12):
    """Pick GPU ids whose free memory meets the threshold.

    Falls back to all CUDA devices when nvidia-smi info is unavailable, and
    to the single freest GPU when none meets the threshold.
    """
    gpu_info = self._get_gpu_memory_info()
    if not gpu_info:
        # No telemetry available: assume every device is usable.
        return list(range(torch.cuda.device_count()))

    # Freest GPUs first.
    gpu_info.sort(key=lambda g: g['free'], reverse=True)

    threshold_mb = min_memory_gb * 1024
    chosen = []
    for gpu in gpu_info:
        if gpu['free'] >= threshold_mb:
            chosen.append(gpu['gpu_id'])
            logger.info(f"GPU {gpu['gpu_id']}: {gpu['free']}MB 可用 (合适)")
        else:
            logger.warning(f"GPU {gpu['gpu_id']}: {gpu['free']}MB 可用 (不足)")

    if not chosen:
        # Nothing qualifies: degrade gracefully to the freest single GPU.
        logger.warning(f"没有GPU有足够内存({min_memory_gb}GB),选择可用内存最多的GPU")
        chosen = [gpu_info[0]['gpu_id']]

    return chosen
|
||||||
|
|
||||||
|
def encode_text_batch(self, texts: List[str]) -> np.ndarray:
    """
    Encode a batch of texts into embedding vectors (multi-GPU aware).

    Args:
        texts: list of input strings.

    Returns:
        float32 array of shape (len(texts), hidden_dim); empty array for
        empty input.
    """
    if not texts:
        return np.array([])

    with torch.no_grad():
        # Tokenize the whole batch at once.
        batch = self.tokenizer(
            text=texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        # All inputs go to the primary device; accelerate routes the rest.
        batch = {name: tensor.to(self.primary_device) for name, tensor in batch.items()}

        model_out = self.model(**batch)
        # Mean-pool the final hidden states over the sequence dimension.
        pooled = model_out.last_hidden_state.mean(dim=1)

        # Drop GPU tensors promptly to keep memory pressure low.
        del batch, model_out
        torch.cuda.empty_cache()

        return pooled.cpu().numpy().astype(np.float32)
|
||||||
|
|
||||||
|
def encode_image_batch(self, images: List[Union[str, Image.Image]]) -> np.ndarray:
    """
    Encode a batch of images into embedding vectors.

    Args:
        images: list of image file paths or PIL images.

    Returns:
        float32 array of shape (len(images), dim). Falls back to text-mode
        encoding when multimodal processing fails, and to zero vectors
        (dim 3584, matching the text embedding size) on total failure.
    """
    if not images:
        return np.array([])

    # Normalize every input to an RGB PIL image.
    processed_images = []
    for img in images:
        if isinstance(img, str):
            img = Image.open(img).convert('RGB')
        elif isinstance(img, Image.Image):
            img = img.convert('RGB')
        processed_images.append(img)

    try:
        logger.info(f"处理 {len(processed_images)} 张图像")

        # One single-turn conversation per image so the multimodal model
        # sees the image together with a short text prompt.
        conversations = []
        for i in range(len(processed_images)):
            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": processed_images[i]},
                        {"type": "text", "text": "What is in this image?"}
                    ]
                }
            ]
            conversations.append(conversation)

        try:
            # Preferred path: chat template + processor over text and images.
            texts = []
            for conv in conversations:
                text = self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False)
                texts.append(text)

            inputs = self.processor(
                text=texts,
                images=processed_images,
                return_tensors="pt",
                padding=True
            )
            inputs = {k: v.to(self.primary_device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
                embeddings = outputs.last_hidden_state.mean(dim=1)

            embeddings = embeddings.cpu().numpy().astype(np.float32)

        except Exception as inner_e:
            # BUG FIX: the original returned np.zeros(...) here, which made
            # the text-mode fallback below unreachable dead code even though
            # the log message promises it. Now the fallback actually runs.
            logger.warning(f"多模态模型图像编码失败,使用文本模式: {inner_e}")
            image_descriptions = ["An image" for _ in processed_images]
            text_inputs = self.processor(
                text=image_descriptions,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            text_inputs = {k: v.to(self.primary_device) for k, v in text_inputs.items()}

            with torch.no_grad():
                outputs = self.model(**text_inputs)
                embeddings = outputs.last_hidden_state.mean(dim=1)

            embeddings = embeddings.cpu().numpy().astype(np.float32)

        logger.info(f"生成图像embeddings: {embeddings.shape}")
        return embeddings

    except Exception as e:
        logger.error(f"图像编码失败: {e}")
        # Last-resort fallback: zero vectors with the text embedding
        # dimension so downstream vstack/index code keeps working.
        embedding_dim = 3584
        embeddings = np.zeros((len(processed_images), embedding_dim), dtype=np.float32)
        return embeddings
|
||||||
|
|
||||||
|
def build_text_index_parallel(self, texts: List[str], save_path: str = None):
    """
    Build the FAISS text index in batches (multi-GPU aware).

    Args:
        texts: list of texts to index.
        save_path: optional path prefix for persisting the index.

    Raises:
        ValueError: if no batch could be encoded successfully.
    """
    logger.info(f"正在并行构建文本索引,共 {len(texts)} 条文本")

    # Smaller per-step batches as the GPU count grows (work is sharded).
    batch_size = max(4, 16 // self.num_gpus)
    all_embeddings = []
    # BUG FIX: track only texts whose batch actually succeeded. The original
    # assigned self.text_data = texts unconditionally, so any skipped batch
    # (OOM / error) desynchronized FAISS row ids from the payload list and
    # searches returned the wrong texts.
    indexed_texts = []

    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]

        try:
            embeddings = self.encode_text_batch(batch_texts)
            all_embeddings.append(embeddings)
            indexed_texts.extend(batch_texts)

            # Progress log every 10 batches.
            if (i // batch_size + 1) % 10 == 0:
                logger.info(f"已处理 {i + len(batch_texts)}/{len(texts)} 条文本")

        except torch.cuda.OutOfMemoryError:
            logger.warning(f"GPU内存不足,跳过批次 {i}-{i+len(batch_texts)}")
            self._clear_all_gpu_memory()
            continue
        except Exception as e:
            logger.error(f"处理文本批次时出错: {e}")
            continue

    if not all_embeddings:
        raise ValueError("没有成功处理任何文本")

    embeddings = np.vstack(all_embeddings)

    # Inner-product index over L2-normalized vectors == cosine similarity.
    dimension = embeddings.shape[1]
    self.text_index = faiss.IndexFlatIP(dimension)

    faiss.normalize_L2(embeddings)
    self.text_index.add(embeddings)

    self.text_data = indexed_texts

    if save_path:
        self._save_index(self.text_index, indexed_texts, save_path + "_text")

    logger.info("文本索引构建完成")
|
||||||
|
|
||||||
|
def build_image_index_parallel(self, image_paths: List[str], save_path: str = None):
    """Build a FAISS inner-product index over images in batches (multi-GPU aware).

    Args:
        image_paths: List of image paths to index.
        save_path: Optional path prefix; if given, index is saved as
            f"{save_path}_image" via self._save_index.

    Raises:
        ValueError: if no batch could be encoded successfully.
    """
    logger.info(f"正在并行构建图像索引,共 {len(image_paths)} 张图像")

    # Images are heavier than text: use a smaller per-step batch.
    batch_size = max(2, 8 // self.num_gpus)
    all_embeddings = []
    # BUG FIX: batches skipped on OOM/error previously still left
    # self.image_data = image_paths, so FAISS row i no longer matched
    # image_paths[i]. Track only the paths that were actually encoded.
    indexed_paths = []

    for i in range(0, len(image_paths), batch_size):
        batch_images = image_paths[i:i+batch_size]

        try:
            embeddings = self.encode_image_batch(batch_images)
            all_embeddings.append(embeddings)
            indexed_paths.extend(batch_images)

            # Progress log every 5 batches.
            if (i // batch_size + 1) % 5 == 0:
                logger.info(f"已处理 {i + len(batch_images)}/{len(image_paths)} 张图像")

        except torch.cuda.OutOfMemoryError:
            logger.warning(f"GPU内存不足,跳过图像批次 {i}-{i+len(batch_images)}")
            self._clear_all_gpu_memory()
            continue
        except Exception as e:
            logger.error(f"处理图像批次时出错: {e}")
            continue

    if not all_embeddings:
        raise ValueError("没有成功处理任何图像")

    # FAISS requires contiguous float32; encoder output dtype is not guaranteed.
    embeddings = np.vstack(all_embeddings).astype(np.float32)

    # Build the FAISS index.
    dimension = embeddings.shape[1]
    self.image_index = faiss.IndexFlatIP(dimension)

    # L2-normalize so inner product == cosine similarity.
    faiss.normalize_L2(embeddings)
    self.image_index.add(embeddings)

    # Keep only successfully indexed paths so row i <-> self.image_data[i].
    self.image_data = indexed_paths

    if save_path:
        self._save_index(self.image_index, indexed_paths, save_path + "_image")

    logger.info("图像索引构建完成")
|
||||||
|
|
||||||
|
def search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """Text-to-text retrieval: return up to top_k (text, score) pairs for query."""
    if self.text_index is None:
        raise ValueError("文本索引未构建,请先调用 build_text_index_parallel")

    # Encode the single query and normalize so IP scores are cosine similarities.
    query_vec = self.encode_text_batch([query]).astype(np.float32)
    faiss.normalize_L2(query_vec)

    scores, indices = self.text_index.search(query_vec, top_k)

    # FAISS pads with -1 when fewer than top_k hits exist; drop those slots.
    return [
        (self.text_data[idx], float(score))
        for score, idx in zip(scores[0], indices[0])
        if idx != -1
    ]
|
||||||
|
|
||||||
|
def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """Text-to-image retrieval: return up to top_k (image_path, score) pairs for query."""
    if self.image_index is None:
        raise ValueError("图像索引未构建,请先调用 build_image_index_parallel")

    # Encode the text query into the shared embedding space and normalize.
    query_vec = self.encode_text_batch([query]).astype(np.float32)
    faiss.normalize_L2(query_vec)

    scores, indices = self.image_index.search(query_vec, top_k)

    # FAISS pads with -1 when fewer than top_k hits exist; drop those slots.
    return [
        (self.image_data[idx], float(score))
        for score, idx in zip(scores[0], indices[0])
        if idx != -1
    ]
|
||||||
|
|
||||||
|
def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
    """Image-to-image retrieval: return up to top_k (image_path, score) pairs."""
    if self.image_index is None:
        raise ValueError("图像索引未构建,请先调用 build_image_index_parallel")

    # Encode the query image and normalize so IP scores are cosine similarities.
    query_vec = self.encode_image_batch([query_image]).astype(np.float32)
    faiss.normalize_L2(query_vec)

    scores, indices = self.image_index.search(query_vec, top_k)

    # FAISS pads with -1 when fewer than top_k hits exist; drop those slots.
    return [
        (self.image_data[idx], float(score))
        for score, idx in zip(scores[0], indices[0])
        if idx != -1
    ]
|
||||||
|
|
||||||
|
def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
    """Image-to-text retrieval: return up to top_k (text, score) pairs."""
    if self.text_index is None:
        raise ValueError("文本索引未构建,请先调用 build_text_index_parallel")

    # Encode the query image into the shared embedding space and normalize.
    query_vec = self.encode_image_batch([query_image]).astype(np.float32)
    faiss.normalize_L2(query_vec)

    scores, indices = self.text_index.search(query_vec, top_k)

    # FAISS pads with -1 when fewer than top_k hits exist; drop those slots.
    return [
        (self.text_data[idx], float(score))
        for score, idx in zip(scores[0], indices[0])
        if idx != -1
    ]
|
||||||
|
|
||||||
|
# Web应用兼容的方法名称
|
||||||
|
def search_text_to_image(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """Web-app compatible alias for text-to-image search (delegates to search_images_by_text)."""
    return self.search_images_by_text(query=query, top_k=top_k)
|
||||||
|
|
||||||
|
def search_image_to_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
    """Web-app compatible alias for image-to-image search (delegates to search_images_by_image)."""
    return self.search_images_by_image(query_image=query_image, top_k=top_k)
|
||||||
|
|
||||||
|
def search_text_to_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
    """Web-app compatible alias for text-to-text search (delegates to search_text_by_text)."""
    return self.search_text_by_text(query=query, top_k=top_k)
|
||||||
|
|
||||||
|
def search_image_to_text(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]:
    """Web-app compatible alias for image-to-text search (delegates to search_text_by_image)."""
    return self.search_text_by_image(query_image=query_image, top_k=top_k)
|
||||||
|
|
||||||
|
def _save_index(self, index, data, path_prefix):
    """Persist a FAISS index plus its aligned data list under path_prefix.

    Writes f"{path_prefix}.index" (FAISS binary) and f"{path_prefix}.json"
    (the row-aligned texts/paths).
    """
    index_file = f"{path_prefix}.index"
    data_file = f"{path_prefix}.json"

    faiss.write_index(index, index_file)
    with open(data_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
def load_index(self, path_prefix, index_type="text"):
    """Load a previously saved index + data and attach it to this instance.

    Args:
        path_prefix: prefix used at save time (reads .index and .json files).
        index_type: "text" attaches to text_index/text_data; anything else
            attaches to image_index/image_data.
    """
    loaded_index = faiss.read_index(f"{path_prefix}.index")
    with open(f"{path_prefix}.json", 'r', encoding='utf-8') as f:
        loaded_data = json.load(f)

    if index_type == "text":
        self.text_index, self.text_data = loaded_index, loaded_data
    else:
        self.image_index, self.image_data = loaded_index, loaded_data

    logger.info(f"已加载 {index_type} 索引")
|
||||||
|
|
||||||
|
def get_gpu_memory_info(self):
    """Return per-GPU memory usage as {"GPU_<id>": {total/allocated/cached/free}}.

    Values are formatted strings in GiB with one decimal; "free" is
    total minus reserved (cached) memory.
    """
    gib = 1024 ** 3
    memory_info = {}

    for gpu_id in self.device_ids:
        torch.cuda.set_device(gpu_id)
        allocated = torch.cuda.memory_allocated(gpu_id) / gib
        cached = torch.cuda.memory_reserved(gpu_id) / gib
        total = torch.cuda.get_device_properties(gpu_id).total_memory / gib

        memory_info[f"GPU_{gpu_id}"] = {
            "total": f"{total:.1f}GB",
            "allocated": f"{allocated:.1f}GB",
            "cached": f"{cached:.1f}GB",
            "free": f"{total - cached:.1f}GB",
        }

    return memory_info
|
||||||
|
|
||||||
|
def check_multigpu_info():
    """Print a summary of the multi-GPU environment (count, versions, per-GPU specs)."""
    print("=== 多GPU环境信息 ===")

    # Bail out early when no CUDA runtime is present.
    if not torch.cuda.is_available():
        print("❌ CUDA不可用")
        return

    gpu_count = torch.cuda.device_count()
    print(f"✅ 检测到 {gpu_count} 个GPU")
    print(f"CUDA版本: {torch.version.cuda}")
    print(f"PyTorch版本: {torch.__version__}")

    # One line per device: name and total memory in GiB.
    for device_idx in range(gpu_count):
        device_name = torch.cuda.get_device_name(device_idx)
        device_mem = torch.cuda.get_device_properties(device_idx).total_memory / 1024**3
        print(f"GPU {device_idx}: {device_name} ({device_mem:.1f}GB)")

    print("=====================")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Inspect the multi-GPU environment before initializing anything.
    check_multigpu_info()

    # Example usage: bring up the retrieval system and report status.
    print("\n正在初始化多GPU多模态检索系统...")

    try:
        retrieval_system = MultiGPUMultimodalRetrieval()
        print("✅ 多GPU系统初始化成功!")

        # Show current per-GPU memory usage after model load.
        memory_info = retrieval_system.get_gpu_memory_info()
        print("\n📊 GPU内存使用情况:")
        for gpu, info in memory_info.items():
            print(f"  {gpu}: {info['allocated']} / {info['total']} (已用/总计)")

        print("\n🚀 多GPU多模态检索系统就绪!")
        print("支持的检索模式:")
        print("1. 文搜文: search_text_by_text()")
        print("2. 文搜图: search_images_by_text()")
        print("3. 图搜图: search_images_by_image()")
        print("4. 图搜文: search_text_by_image()")

    except Exception as e:
        # Surface the failure with a full traceback for debugging.
        print(f"❌ 多GPU系统初始化失败: {e}")
        import traceback
        traceback.print_exc()
|
||||||
309
ops_mm_embedding_v1.py
Normal file
309
ops_mm_embedding_v1.py
Normal file
@ -0,0 +1,309 @@
|
|||||||
|
import math
|
||||||
|
from typing import List, Optional, TypeAlias, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoModelForImageTextToText, AutoProcessor
|
||||||
|
|
||||||
|
ImageInput: TypeAlias = Union[Image.Image, List[Image.Image]]
|
||||||
|
BatchImageInput: TypeAlias = Union[List[Image.Image], List[List[Image.Image]]]
|
||||||
|
|
||||||
|
|
||||||
|
class OpsMMEmbeddingV1(nn.Module):
    """Multimodal (text/image) embedding model on an image-text-to-text backbone.

    Embeddings are taken from the last hidden layer, pooled at the final
    (left-padded) token position and L2-normalized.
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        max_length: Optional[int] = None,
        attn_implementation: Optional[str] = None,
    ):
        """Load backbone and processor.

        Args:
            model_name: HF model id or local path for AutoModelForImageTextToText.
            device: device the model (and all inputs) are placed on.
            max_length: optional truncation length passed to the processor.
            attn_implementation: optional attention backend for from_pretrained.
        """
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.default_instruction = "You are a helpful assistant."
        # bfloat16 weights; low_cpu_mem_usage streams the checkpoint to limit RAM.
        self.base_model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            attn_implementation=attn_implementation,
        ).to(self.device)

        # Pixel bounds match the module-level MIN_PIXELS/MAX_PIXELS constants.
        self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
        # Left padding so the last token position is valid for every row (see _pooling).
        self.processor.tokenizer.padding_side = "left"
        self.eval()

    def encode_input(self, input):
        """Run the backbone and return pooled, normalized embeddings for a processed batch."""
        hidden_states = self.base_model(**input, return_dict=True, output_hidden_states=True)
        # Take the last transformer layer's hidden states: (batch, seq, dim).
        hidden_states = hidden_states.hidden_states[-1]
        pooled_output = self._pooling(hidden_states)
        return pooled_output

    def _pooling(self, last_hidden_state):
        """Last-token pooling + L2 normalization (valid because padding is on the left)."""
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[torch.arange(batch_size), -1, :]
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    def _validate_instructions(
        self,
        texts: Optional[List[str]],
        images: Optional[BatchImageInput],
        instruction: Optional[Union[str, List[str]]],
    ) -> List[str]:
        """Validate and format instructions to match batch size"""
        # Batch size is the longer of the two optional inputs.
        batch_size = max(len(x) if x is not None else 0 for x in [texts, images])

        if instruction is None:
            return [self.default_instruction] * batch_size

        if isinstance(instruction, str):
            # Broadcast a single instruction over the whole batch.
            return [instruction] * batch_size

        if isinstance(instruction, list):
            if len(instruction) != batch_size:
                raise ValueError(f"Length of instruction list ({len(instruction)}) must match batch size ({batch_size}) when texts/images are provided")
            return instruction

        raise TypeError("instruction must be str, List[str] or None")

    def _process_images(self, images: ImageInput) -> List[Image.Image]:
        """Convert single image or list of images to processed format"""
        # A bare image/path becomes a one-element list; fetch_image resizes/normalizes.
        if isinstance(images, Image.Image) or isinstance(images, str):
            return [fetch_image(images)]
        return [fetch_image(i) for i in images]

    def embed(
        self,
        texts: Optional[List[str]] = None,
        images: Optional[BatchImageInput] = None,
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Generate embeddings for text, images, or combined inputs.

        Args:
            texts: List of text inputs (optional)
            images: Can be:
                - List[Image.Image]: Single image per input
                - List[List[Image.Image]]: Multiple images per input
            instruction: Instruction(s) for the model. Can be:
                - None: use default instruction
                - str: use same instruction for all inputs
                - List[str]: per-input instructions (must match batch size)

        Returns:
            (batch, dim) tensor of L2-normalized embeddings on self.device.
        """
        if texts is None and images is None:
            raise ValueError("Either texts or images must be provided")

        instructions = self._validate_instructions(texts, images, instruction)

        # Determine batch size
        batch_size = len(texts) if texts is not None else len(images)  # type: ignore

        input_texts, input_images = [], []
        for i in range(batch_size):
            text = texts[i] if texts is not None else None
            image = images[i] if images is not None else None

            input_str = ""
            processed_image = None
            if image is not None:
                processed_image = self._process_images(image)
                # One vision placeholder span per image in this input.
                input_str += "<|vision_start|><|image_pad|><|vision_end|>" * len(processed_image)

            if text is not None:
                input_str += text

            # Chat-template prompt; embeddings are pooled at the trailing token.
            msg = f"<|im_start|>system\n{instructions[i]}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"

            input_texts.append(msg)
            input_images.append(processed_image)

        # Only pass to processor if we actually have images
        processed_images = input_images if any(img is not None for img in input_images) else None

        inputs = self.processor(
            text=input_texts,
            images=processed_images,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # No grad/autograd state needed for embedding extraction.
        with torch.inference_mode():
            embeddings = self.encode_input(inputs)

        return embeddings

    def get_text_embeddings(
        self,
        texts: List[str],
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convenience method for text-only embeddings"""
        return self.get_fused_embeddings(texts=texts, instruction=instruction, **kwargs)

    def get_image_embeddings(
        self,
        images: BatchImageInput,
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convenience method for image-only embeddings.

        Args:
            images: Can be:
                - List[Image.Image]: Single image per input
                - List[List[Image.Image]]: Multiple images per input
        """
        return self.get_fused_embeddings(images=images, instruction=instruction, **kwargs)

    def get_fused_embeddings(
        self,
        texts: Optional[List[str]] = None,
        images: Optional[BatchImageInput] = None,
        instruction: Optional[Union[str, List[str]]] = None,
        batch_size: int = 8,
        show_progress: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Batch processing for large collections of texts/images.

        Args:
            texts: List of text inputs (optional)
            images: Can be:
                - List[Image.Image]: Single image per input
                - List[List[Image.Image]]: Multiple images per input
            instruction: Instruction(s) for the model
            batch_size: Number of items to process at once
            show_progress: Whether to display progress bar

        Returns:
            (total_items, dim) tensor of embeddings, concatenated over batches.
        """

        if texts is None and images is None:
            raise ValueError("Either texts or images must be provided")

        total_items = len(texts) if texts is not None else len(images)  # type: ignore
        num_batches = math.ceil(total_items / batch_size)

        all_embeddings = []
        progress = tqdm(total=num_batches, disable=not show_progress, desc="Processing")

        for i in range(0, total_items, batch_size):
            batch_texts = texts[i : i + batch_size] if texts is not None else None
            batch_images = images[i : i + batch_size] if images is not None else None
            # NOTE: a str/list instruction is re-broadcast per batch by embed().
            batch_emb = self.embed(texts=batch_texts, images=batch_images, instruction=instruction)

            # Park batch results on CPU to bound GPU memory across many batches.
            all_embeddings.append(batch_emb.cpu())
            progress.update(1)

        progress.close()
        return torch.cat(all_embeddings, dim=0).to(self.device)

    def forward(self, **inputs) -> torch.Tensor:
        """Alias for encode_input"""
        return self.encode_input(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
### Modified from qwen_vl_utils.vision_process.py
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
IMAGE_FACTOR = 28
|
||||||
|
MIN_PIXELS = 256 * 28 * 28
|
||||||
|
MAX_PIXELS = 1280 * 28 * 28
|
||||||
|
MAX_RATIO = 200
|
||||||
|
|
||||||
|
|
||||||
|
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of *factor* closest to *number* (ties follow round())."""
    quotient = number / factor
    return factor * round(quotient)
|
||||||
|
|
||||||
|
|
||||||
|
def ceil_by_factor(number: int | float, factor: int) -> int:
|
||||||
|
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||||
|
return math.ceil(number / factor) * factor
|
||||||
|
|
||||||
|
|
||||||
|
def floor_by_factor(number: int | float, factor: int) -> int:
|
||||||
|
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||||
|
return math.floor(number / factor) * factor
|
||||||
|
|
||||||
|
|
||||||
|
def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:
    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    # Start from the nearest factor-multiples, never below one factor.
    new_h = max(factor, round_by_factor(height, factor))
    new_w = max(factor, round_by_factor(width, factor))

    area = new_h * new_w
    if area > max_pixels:
        # Shrink both dimensions by the same scale, rounding down.
        scale = math.sqrt((height * width) / max_pixels)
        new_h = floor_by_factor(height / scale, factor)
        new_w = floor_by_factor(width / scale, factor)
    elif area < min_pixels:
        # Grow both dimensions by the same scale, rounding up.
        scale = math.sqrt(min_pixels / (height * width))
        new_h = ceil_by_factor(height * scale, factor)
        new_w = ceil_by_factor(width * scale, factor)

    # Clamp extreme aspect ratios to at most MAX_RATIO by shrinking the long side.
    ratio = max(new_h, new_w) / min(new_h, new_w)
    if ratio > MAX_RATIO:
        logging.warning(f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(new_h, new_w) / min(new_h, new_w)}")
        if new_h > new_w:
            new_h = new_w * MAX_RATIO
        else:
            new_w = new_h * MAX_RATIO
    return new_h, new_w
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_image(
    image: str | Image.Image,
    size_factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Image.Image:
    """Load an image from several source kinds and resize it via smart_resize.

    Accepted sources: a PIL.Image, an http(s) URL, a file:// URI, a
    data:image base64 URI, or a plain local path.

    Returns:
        An RGB PIL.Image whose dimensions satisfy smart_resize's constraints.

    Raises:
        ValueError: if the input could not be resolved into an image.
    """
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        # NOTE(review): no timeout/status check on the HTTP fetch — confirm acceptable.
        image_obj = Image.open(requests.get(image, stream=True).raw)  # type: ignore
    elif image.startswith("file://"):
        # Strip the 7-char "file://" prefix to get the local path.
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
        # NOTE(review): a data:image URI without "base64," leaves image_obj None
        # and falls through to the ValueError below — confirm intended.
    else:
        # Anything else is treated as a plain local filesystem path.
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    # Normalize mode so downstream processing always sees 3 channels.
    image = image_obj.convert("RGB")
    width, height = image.size
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=size_factor,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )
    # PIL's resize takes (width, height) order.
    image = image.resize((resized_width, resized_height))

    return image
|
||||||
|
|
||||||
|
|
||||||
|
###
|
||||||
12
requirements.txt
Normal file
12
requirements.txt
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
torch>=2.0.0
|
||||||
|
torchvision>=0.15.0
|
||||||
|
transformers>=4.30.0
|
||||||
|
accelerate>=0.20.0
|
||||||
|
faiss-cpu>=1.7.4
|
||||||
|
numpy>=1.21.0
|
||||||
|
Pillow>=9.0.0
|
||||||
|
scikit-learn>=1.3.0
|
||||||
|
tqdm>=4.65.0
|
||||||
|
flask>=2.3.0
|
||||||
|
werkzeug>=2.3.0
|
||||||
|
psutil>=5.9.0
|
||||||
BIN
sample_images/1755681385_4__.jpg
Normal file
BIN
sample_images/1755681385_4__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 189 KiB |
BIN
sample_images/1755681385_5__.jpg
Normal file
BIN
sample_images/1755681385_5__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 201 KiB |
BIN
sample_images/1755681385_6__.jpg
Normal file
BIN
sample_images/1755681385_6__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 144 KiB |
BIN
sample_images/1755681385_7__.jpg
Normal file
BIN
sample_images/1755681385_7__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 166 KiB |
971
templates/index.html
Normal file
971
templates/index.html
Normal file
@ -0,0 +1,971 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>多模态检索系统</title>
|
||||||
|
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" rel="stylesheet">
|
||||||
|
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" rel="stylesheet">
|
||||||
|
<style>
|
||||||
|
:root {
|
||||||
|
--primary-color: #2563eb;
|
||||||
|
--secondary-color: #64748b;
|
||||||
|
--success-color: #059669;
|
||||||
|
--warning-color: #d97706;
|
||||||
|
--danger-color: #dc2626;
|
||||||
|
--bg-gradient: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
background: var(--bg-gradient);
|
||||||
|
min-height: 100vh;
|
||||||
|
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||||
|
}
|
||||||
|
|
||||||
|
.main-container {
|
||||||
|
background: rgba(255, 255, 255, 0.95);
|
||||||
|
backdrop-filter: blur(10px);
|
||||||
|
border-radius: 20px;
|
||||||
|
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
|
||||||
|
margin: 20px auto;
|
||||||
|
max-width: 1200px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.header {
|
||||||
|
background: linear-gradient(135deg, var(--primary-color), #3b82f6);
|
||||||
|
color: white;
|
||||||
|
padding: 2rem;
|
||||||
|
border-radius: 20px 20px 0 0;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mode-card {
|
||||||
|
background: white;
|
||||||
|
border-radius: 15px;
|
||||||
|
padding: 1.5rem;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
border: 2px solid transparent;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mode-card:hover {
|
||||||
|
transform: translateY(-5px);
|
||||||
|
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.15);
|
||||||
|
}
|
||||||
|
|
||||||
|
.mode-card.active {
|
||||||
|
border-color: var(--primary-color);
|
||||||
|
background: linear-gradient(135deg, #eff6ff, #dbeafe);
|
||||||
|
}
|
||||||
|
|
||||||
|
.mode-icon {
|
||||||
|
font-size: 2.5rem;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
.text-to-text { color: #059669; }
|
||||||
|
.text-to-image { color: #dc2626; }
|
||||||
|
.image-to-text { color: #d97706; }
|
||||||
|
.image-to-image { color: #7c3aed; }
|
||||||
|
|
||||||
|
.search-input {
|
||||||
|
border-radius: 12px;
|
||||||
|
border: 2px solid #e5e7eb;
|
||||||
|
padding: 12px 16px;
|
||||||
|
font-size: 16px;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.search-input:focus {
|
||||||
|
border-color: var(--primary-color);
|
||||||
|
box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary {
|
||||||
|
background: var(--primary-color);
|
||||||
|
border: none;
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 12px 24px;
|
||||||
|
font-weight: 600;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary:hover {
|
||||||
|
background: #1d4ed8;
|
||||||
|
transform: translateY(-2px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.file-upload-area {
|
||||||
|
border: 3px dashed #d1d5db;
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 3rem;
|
||||||
|
text-align: center;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.file-upload-area:hover {
|
||||||
|
border-color: var(--primary-color);
|
||||||
|
background: rgba(37, 99, 235, 0.05);
|
||||||
|
}
|
||||||
|
|
||||||
|
.file-upload-area.dragover {
|
||||||
|
border-color: var(--primary-color);
|
||||||
|
background: rgba(37, 99, 235, 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-card {
|
||||||
|
background: white;
|
||||||
|
border-radius: 12px;
|
||||||
|
padding: 1.5rem;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
|
||||||
|
border-left: 4px solid var(--primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.result-image {
|
||||||
|
max-width: 200px;
|
||||||
|
max-height: 150px;
|
||||||
|
border-radius: 8px;
|
||||||
|
object-fit: cover;
|
||||||
|
}
|
||||||
|
|
||||||
|
.score-badge {
|
||||||
|
background: var(--success-color);
|
||||||
|
color: white;
|
||||||
|
padding: 4px 12px;
|
||||||
|
border-radius: 20px;
|
||||||
|
font-size: 0.85rem;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.loading-spinner {
|
||||||
|
display: none;
|
||||||
|
text-align: center;
|
||||||
|
padding: 2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-indicator {
|
||||||
|
position: fixed;
|
||||||
|
top: 20px;
|
||||||
|
right: 20px;
|
||||||
|
z-index: 1000;
|
||||||
|
}
|
||||||
|
|
||||||
|
.fade-in {
|
||||||
|
animation: fadeIn 0.5s ease-in;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes fadeIn {
|
||||||
|
from { opacity: 0; transform: translateY(20px); }
|
||||||
|
to { opacity: 1; transform: translateY(0); }
|
||||||
|
}
|
||||||
|
|
||||||
|
.query-image {
|
||||||
|
max-width: 300px;
|
||||||
|
max-height: 200px;
|
||||||
|
border-radius: 12px;
|
||||||
|
object-fit: cover;
|
||||||
|
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<!-- 状态指示器 -->
|
||||||
|
<div class="status-indicator">
|
||||||
|
<div id="statusBadge" class="badge bg-secondary">
|
||||||
|
<i class="fas fa-circle-notch fa-spin"></i> 未初始化
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="container-fluid">
|
||||||
|
<div class="main-container">
|
||||||
|
<!-- 头部 -->
|
||||||
|
<div class="header">
|
||||||
|
<h1><i class="fas fa-search"></i> 多模态检索系统</h1>
|
||||||
|
<p class="mb-0">支持文搜图、文搜文、图搜图、图搜文四种检索模式</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="p-4">
|
||||||
|
<!-- 重新初始化按钮 -->
|
||||||
|
<div class="text-center mb-4">
|
||||||
|
<button id="reinitBtn" class="btn btn-warning">
|
||||||
|
<i class="fas fa-redo"></i> 重新初始化系统
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 检索模式选择 -->
|
||||||
|
<div class="row mb-4" id="modeSelection">
|
||||||
|
<div class="col-md-3">
|
||||||
|
<div class="mode-card text-center" data-mode="text_to_text">
|
||||||
|
<i class="fas fa-file-text mode-icon text-to-text"></i>
|
||||||
|
<h5>文搜文</h5>
|
||||||
|
<p class="text-muted">文本查找相似文本</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-3">
|
||||||
|
<div class="mode-card text-center" data-mode="text_to_image">
|
||||||
|
<i class="fas fa-image mode-icon text-to-image"></i>
|
||||||
|
<h5>文搜图</h5>
|
||||||
|
<p class="text-muted">文本查找相关图片</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-3">
|
||||||
|
<div class="mode-card text-center" data-mode="image_to_text">
|
||||||
|
<i class="fas fa-comment mode-icon image-to-text"></i>
|
||||||
|
<h5>图搜文</h5>
|
||||||
|
<p class="text-muted">图片查找相关文本</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-3">
|
||||||
|
<div class="mode-card text-center" data-mode="image_to_image">
|
||||||
|
<i class="fas fa-images mode-icon image-to-image"></i>
|
||||||
|
<h5>图搜图</h5>
|
||||||
|
<p class="text-muted">图片查找相似图片</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 数据管理界面 -->
|
||||||
|
<div class="row mb-4" id="dataManagement">
|
||||||
|
<div class="col-12">
|
||||||
|
<div class="card">
|
||||||
|
<div class="card-header">
|
||||||
|
<h5><i class="fas fa-database"></i> 数据管理</h5>
|
||||||
|
<small class="text-muted">上传和管理检索数据库</small>
|
||||||
|
</div>
|
||||||
|
<div class="card-body">
|
||||||
|
<div class="row">
|
||||||
|
<!-- 批量上传图片 -->
|
||||||
|
<div class="col-md-6">
|
||||||
|
<div class="upload-section">
|
||||||
|
<h6><i class="fas fa-images text-primary"></i> 批量上传图片</h6>
|
||||||
|
<div class="file-upload-area" id="batchImageUpload">
|
||||||
|
<i class="fas fa-cloud-upload-alt fa-2x text-muted mb-2"></i>
|
||||||
|
<p>拖拽多张图片到此处或点击选择</p>
|
||||||
|
<input type="file" id="batchImageFiles" multiple accept="image/*" style="display: none;">
|
||||||
|
<button class="btn btn-outline-primary btn-sm mt-2" onclick="document.getElementById('batchImageFiles').click()">
|
||||||
|
<i class="fas fa-folder-open"></i> 选择图片
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
<div id="imageUploadProgress" class="mt-2" style="display: none;">
|
||||||
|
<div class="progress">
|
||||||
|
<div class="progress-bar" role="progressbar" style="width: 0%"></div>
|
||||||
|
</div>
|
||||||
|
<small class="text-muted mt-1 d-block">上传进度: <span id="imageProgressText">0/0</span></small>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 批量上传文本 -->
|
||||||
|
<div class="col-md-6">
|
||||||
|
<div class="upload-section">
|
||||||
|
<h6><i class="fas fa-file-text text-success"></i> 批量上传文本</h6>
|
||||||
|
<div class="mb-3">
|
||||||
|
<textarea id="batchTextInput" class="form-control" rows="8"
|
||||||
|
placeholder="请输入文本数据,每行一条文本记录... 例如: 这是第一条文本记录 这是第二条文本记录 这是第三条文本记录"></textarea>
|
||||||
|
</div>
|
||||||
|
<div class="d-flex gap-2">
|
||||||
|
<button id="uploadTextsBtn" class="btn btn-success">
|
||||||
|
<i class="fas fa-upload"></i> 上传文本
|
||||||
|
</button>
|
||||||
|
<button class="btn btn-outline-secondary" onclick="document.getElementById('textFile').click()">
|
||||||
|
<i class="fas fa-file-import"></i> 从文件导入
|
||||||
|
</button>
|
||||||
|
<input type="file" id="textFile" accept=".txt,.csv" style="display: none;">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 数据统计和管理 -->
|
||||||
|
<div class="row mt-4">
|
||||||
|
<div class="col-md-8">
|
||||||
|
<div class="d-flex gap-3">
|
||||||
|
<button id="buildIndexBtn" class="btn btn-warning" disabled>
|
||||||
|
<i class="fas fa-cogs"></i> 构建索引
|
||||||
|
</button>
|
||||||
|
<button id="viewDataBtn" class="btn btn-info">
|
||||||
|
<i class="fas fa-list"></i> 查看数据
|
||||||
|
</button>
|
||||||
|
<button id="clearDataBtn" class="btn btn-danger">
|
||||||
|
<i class="fas fa-trash"></i> 清空数据
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-4">
|
||||||
|
<div id="dataStats" class="text-end">
|
||||||
|
<small class="text-muted">
|
||||||
|
图片: <span id="imageCount">0</span> 张 |
|
||||||
|
文本: <span id="textCount">0</span> 条
|
||||||
|
</small>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 搜索界面 -->
|
||||||
|
<div id="searchInterface" style="display: none;">
|
||||||
|
<!-- 文本搜索 -->
|
||||||
|
<div id="textSearch" class="search-panel" style="display: none;">
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-md-8">
|
||||||
|
<input type="text" id="textQuery" class="form-control search-input"
|
||||||
|
placeholder="请输入搜索文本...">
|
||||||
|
</div>
|
||||||
|
<div class="col-md-2">
|
||||||
|
<select id="textTopK" class="form-select search-input">
|
||||||
|
<option value="3">Top 3</option>
|
||||||
|
<option value="5" selected>Top 5</option>
|
||||||
|
<option value="10">Top 10</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-2">
|
||||||
|
<button id="textSearchBtn" class="btn btn-primary w-100">
|
||||||
|
<i class="fas fa-search"></i> 搜索
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 图片搜索 -->
|
||||||
|
<div id="imageSearch" class="search-panel" style="display: none;">
|
||||||
|
<div class="row">
|
||||||
|
<div class="col-md-8">
|
||||||
|
<div class="file-upload-area" id="fileUploadArea">
|
||||||
|
<i class="fas fa-cloud-upload-alt fa-3x text-muted mb-3"></i>
|
||||||
|
<h5>拖拽图片到此处或点击选择</h5>
|
||||||
|
<p class="text-muted">支持 PNG, JPG, JPEG, GIF, BMP, WebP 格式</p>
|
||||||
|
<input type="file" id="imageFile" accept="image/*" style="display: none;">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-2">
|
||||||
|
<select id="imageTopK" class="form-select search-input">
|
||||||
|
<option value="3">Top 3</option>
|
||||||
|
<option value="5" selected>Top 5</option>
|
||||||
|
<option value="10">Top 10</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div class="col-md-2">
|
||||||
|
<button id="imageSearchBtn" class="btn btn-primary w-100" disabled>
|
||||||
|
<i class="fas fa-search"></i> 搜索
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 加载动画 -->
|
||||||
|
<div class="loading-spinner" id="loadingSpinner">
|
||||||
|
<div class="spinner-border text-primary" role="status">
|
||||||
|
<span class="visually-hidden">Loading...</span>
|
||||||
|
</div>
|
||||||
|
<p class="mt-2">正在搜索中...</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 搜索结果 -->
|
||||||
|
<div id="searchResults" class="mt-4"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/js/bootstrap.bundle.min.js"></script>
|
||||||
|
<script>
|
||||||
|
// Global UI state: selected retrieval mode and backend readiness flag.
let currentMode = null;
let systemInitialized = false;

// Re-initialize the backend retrieval system when the user clicks the button.
const reinitBtn = document.getElementById('reinitBtn');
reinitBtn.addEventListener('click', async () => {
    const previousLabel = reinitBtn.innerHTML;
    reinitBtn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 重新初始化中...';
    reinitBtn.disabled = true;

    try {
        const response = await fetch('/api/init', {
            method: 'POST',
            headers: {'Content-Type': 'application/json'}
        });
        const data = await response.json();
        if (!data.success) {
            throw new Error(data.message);
        }
        systemInitialized = true;
        const badge = document.getElementById('statusBadge');
        badge.innerHTML = '<i class="fas fa-check-circle"></i> 已重新初始化';
        badge.className = 'badge bg-success';
        showAlert('success', `系统重新初始化成功!GPU: ${data.gpu_count} 个`);
    } catch (error) {
        showAlert('danger', '重新初始化失败: ' + error.message);
    } finally {
        // Always restore the button, whether init succeeded or failed.
        reinitBtn.innerHTML = previousLabel;
        reinitBtn.disabled = false;
    }
});
|
||||||
|
|
||||||
|
// 模式选择
|
||||||
|
// Mode selection: clicking a card highlights it and swaps in the matching panel.
document.querySelectorAll('.mode-card').forEach(card => {
    card.addEventListener('click', () => {
        for (const other of document.querySelectorAll('.mode-card')) {
            other.classList.remove('active');
        }
        card.classList.add('active');
        currentMode = card.dataset.mode;
        setupSearchInterface(currentMode);
    });
});

// Reveal the search panel (text or image) that matches the chosen mode and
// clear any previous results.
function setupSearchInterface(mode) {
    document.getElementById('searchInterface').style.display = 'block';
    document.getElementById('textSearch').style.display = 'none';
    document.getElementById('imageSearch').style.display = 'none';
    document.getElementById('searchResults').innerHTML = '';

    const usesTextQuery = mode === 'text_to_text' || mode === 'text_to_image';
    if (usesTextQuery) {
        document.getElementById('textSearch').style.display = 'block';
        document.getElementById('textQuery').placeholder =
            mode === 'text_to_text' ? '请输入要搜索的文本...' : '请输入要搜索图片的描述...';
    } else {
        document.getElementById('imageSearch').style.display = 'block';
    }
}
|
||||||
|
|
||||||
|
// 文本搜索
|
||||||
|
// Wire up the text-search button and the Enter-key shortcut.
document.getElementById('textSearchBtn').addEventListener('click', performTextSearch);
document.getElementById('textQuery').addEventListener('keypress', (e) => {
    if (e.key === 'Enter') performTextSearch();
});

// Run a text query against the backend; the endpoint depends on currentMode
// (text->text or text->image). Results are handed to displayResults.
async function performTextSearch() {
    const query = document.getElementById('textQuery').value.trim();
    const topK = parseInt(document.getElementById('textTopK').value);

    if (!query) {
        showAlert('warning', '请输入搜索文本');
        return;
    }

    showLoading(true);
    const endpoint = currentMode === 'text_to_text'
        ? '/api/search/text_to_text'
        : '/api/search/text_to_image';
    try {
        const response = await fetch(endpoint, {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: JSON.stringify({query, top_k: topK})
        });
        const data = await response.json();
        if (!data.success) {
            throw new Error(data.message);
        }
        displayResults(data, currentMode);
    } catch (error) {
        showAlert('danger', '搜索失败: ' + error.message);
    } finally {
        showLoading(false);
    }
}
|
||||||
|
|
||||||
|
// 图片上传处理
|
||||||
|
// Query-image upload widget: click-to-pick plus drag & drop onto the area.
const fileUploadArea = document.getElementById('fileUploadArea');
const imageFile = document.getElementById('imageFile');

fileUploadArea.addEventListener('click', () => imageFile.click());
fileUploadArea.addEventListener('dragover', handleDragOver);
fileUploadArea.addEventListener('drop', handleDrop);
imageFile.addEventListener('change', handleFileSelect);

// Highlight the drop zone while a file is dragged over it.
function handleDragOver(e) {
    e.preventDefault();
    fileUploadArea.classList.add('dragover');
}

// Accept the first dropped file, if any.
function handleDrop(e) {
    e.preventDefault();
    fileUploadArea.classList.remove('dragover');
    const dropped = e.dataTransfer.files;
    if (dropped.length > 0) {
        handleFile(dropped[0]);
    }
}

// Accept a file chosen through the hidden <input type="file">.
function handleFileSelect(e) {
    const chosen = e.target.files[0];
    if (chosen) handleFile(chosen);
}

// Validate that the file is an image, render a preview into the drop zone,
// and enable the image-search button.
function handleFile(file) {
    if (!file.type.startsWith('image/')) {
        showAlert('warning', '请选择图片文件');
        return;
    }

    const reader = new FileReader();
    reader.onload = (e) => {
        fileUploadArea.innerHTML = `
            <img src="${e.target.result}" class="query-image mb-3">
            <p class="text-success"><i class="fas fa-check"></i> 图片已选择: ${file.name}</p>
        `;
        document.getElementById('imageSearchBtn').disabled = false;
    };
    reader.readAsDataURL(file);
}
|
||||||
|
|
||||||
|
// 图片搜索
|
||||||
|
// Image search: POST the selected file as multipart form data to the
// mode-specific endpoint (image->text or image->image).
document.getElementById('imageSearchBtn').addEventListener('click', async () => {
    const file = imageFile.files[0];
    const topK = parseInt(document.getElementById('imageTopK').value);

    if (!file) {
        showAlert('warning', '请选择图片文件');
        return;
    }

    showLoading(true);
    try {
        const payload = new FormData();
        payload.append('image', file);
        payload.append('top_k', topK);

        const endpoint = currentMode === 'image_to_text'
            ? '/api/search/image_to_text'
            : '/api/search/image_to_image';
        const response = await fetch(endpoint, {method: 'POST', body: payload});
        const data = await response.json();
        if (!data.success) {
            throw new Error(data.message);
        }
        displayResults(data, currentMode);
    } catch (error) {
        showAlert('danger', '搜索失败: ' + error.message);
    } finally {
        showLoading(false);
    }
});
|
||||||
|
|
||||||
|
// 显示结果
|
||||||
|
// Escape a value for safe interpolation into HTML markup. Needed because
// search results (query text, filenames, paths, indexed text) come from the
// server / uploaded data and were previously injected into innerHTML raw,
// which allows XSS if any of them contains markup.
function escapeHtml(value) {
    return String(value)
        .replace(/&/g, '&amp;')
        .replace(/</g, '&lt;')
        .replace(/>/g, '&gt;')
        .replace(/"/g, '&quot;')
        .replace(/'/g, '&#39;');
}

// Render search results into #searchResults. Image modes show a thumbnail,
// filename, path and score; text modes show the text and score. Counts and
// timing come from the server response (data.result_count, data.search_time).
function displayResults(data, mode) {
    const resultsContainer = document.getElementById('searchResults');

    let html = `
        <div class="fade-in">
            <div class="d-flex justify-content-between align-items-center mb-3">
                <h4><i class="fas fa-search-plus"></i> 搜索结果</h4>
                <div>
                    <span class="badge bg-info">找到 ${data.result_count} 个结果</span>
                    <span class="badge bg-secondary">耗时 ${data.search_time}s</span>
                </div>
            </div>
    `;

    if (data.query_image) {
        html += `
            <div class="result-card">
                <h6><i class="fas fa-image"></i> 查询图片</h6>
                <img src="data:image/jpeg;base64,${data.query_image}" class="query-image">
            </div>
        `;
    }

    if (data.query) {
        html += `
            <div class="result-card">
                <h6><i class="fas fa-quote-left"></i> 查询文本</h6>
                <p class="mb-0">"${escapeHtml(data.query)}"</p>
            </div>
        `;
    }

    data.results.forEach((result, index) => {
        html += '<div class="result-card">';

        if (mode === 'text_to_image' || mode === 'image_to_image') {
            html += `
                <div class="row">
                    <div class="col-md-3">
                        <img src="data:image/jpeg;base64,${result.image_base64}"
                             class="result-image" alt="Result ${index + 1}">
                    </div>
                    <div class="col-md-9">
                        <div class="d-flex justify-content-between align-items-start">
                            <h6><i class="fas fa-image"></i> ${escapeHtml(result.filename)}</h6>
                            <span class="score-badge">相似度: ${(result.score * 100).toFixed(1)}%</span>
                        </div>
                        <p class="text-muted mb-0">路径: ${escapeHtml(result.image_path)}</p>
                    </div>
                </div>
            `;
        } else {
            html += `
                <div class="d-flex justify-content-between align-items-start">
                    <div>
                        <h6><i class="fas fa-file-text"></i> 结果 ${index + 1}</h6>
                        <p class="mb-0">${escapeHtml(result.text || result)}</p>
                    </div>
                    <span class="score-badge">相似度: ${((result.score || 0.95) * 100).toFixed(1)}%</span>
                </div>
            `;
        }

        html += '</div>';
    });

    html += '</div>';
    resultsContainer.innerHTML = html;
}
|
||||||
|
|
||||||
|
// 工具函数
|
||||||
|
// Toggle the global "searching…" spinner.
function showLoading(show) {
    const spinner = document.getElementById('loadingSpinner');
    spinner.style.display = show ? 'block' : 'none';
}

// Insert a dismissible Bootstrap alert at the top of the main container.
// The alert removes itself after five seconds if still attached.
function showAlert(type, message) {
    const alertDiv = document.createElement('div');
    alertDiv.className = `alert alert-${type} alert-dismissible fade show`;
    alertDiv.innerHTML = `
        ${message}
        <button type="button" class="btn-close" data-bs-dismiss="alert"></button>
    `;

    const container = document.querySelector('.main-container .p-4');
    container.insertBefore(alertDiv, container.firstChild);

    setTimeout(() => {
        if (alertDiv.parentNode) {
            alertDiv.remove();
        }
    }, 5000);
}
|
||||||
|
|
||||||
|
// 检查系统状态
|
||||||
|
// Probe backend readiness and reflect the answer in the status badge.
async function checkStatus() {
    const badge = document.getElementById('statusBadge');
    try {
        const response = await fetch('/api/status');
        const data = await response.json();

        if (data.initialized) {
            systemInitialized = true;
            badge.innerHTML = '<i class="fas fa-check-circle"></i> 已初始化';
            badge.className = 'badge bg-success';
        } else {
            badge.innerHTML = '<i class="fas fa-exclamation-triangle"></i> 未初始化';
            badge.className = 'badge bg-warning';
        }
    } catch (error) {
        // Network failure / backend down: show a distinct "unreachable" state.
        console.log('Status check failed:', error);
        badge.innerHTML = '<i class="fas fa-times-circle"></i> 连接失败';
        badge.className = 'badge bg-danger';
    }
}

// On page load: check the backend and wire up the data-management panel.
checkStatus();
setupDataManagement();
|
||||||
|
|
||||||
|
// Wire up every data-management control: batch image upload (drag & drop and
// file picker), batch text upload, text-file import, index build, data
// listing and data clearing. Called once at startup.
function setupDataManagement() {
    const batchImageUpload = document.getElementById('batchImageUpload');
    const batchImageFiles = document.getElementById('batchImageFiles');

    // Drag & drop upload for batch images; only image/* files are accepted.
    batchImageUpload.addEventListener('dragover', (e) => {
        e.preventDefault();
        batchImageUpload.classList.add('dragover');
    });
    batchImageUpload.addEventListener('dragleave', (e) => {
        e.preventDefault();
        batchImageUpload.classList.remove('dragover');
    });
    batchImageUpload.addEventListener('drop', (e) => {
        e.preventDefault();
        batchImageUpload.classList.remove('dragover');
        const images = Array.from(e.dataTransfer.files).filter(f => f.type.startsWith('image/'));
        if (images.length > 0) {
            uploadBatchImages(images);
        }
    });
    batchImageFiles.addEventListener('change', (e) => {
        if (e.target.files.length > 0) {
            uploadBatchImages(Array.from(e.target.files));
        }
    });

    // Batch text upload: one entry per non-empty line in the textarea.
    document.getElementById('uploadTextsBtn').addEventListener('click', () => {
        const raw = document.getElementById('batchTextInput').value.trim();
        if (!raw) {
            showAlert('warning', '请输入文本数据');
            return;
        }
        const texts = raw.split('\n').filter(line => line.trim());
        if (texts.length > 0) {
            uploadBatchTexts(texts);
        } else {
            showAlert('warning', '请输入有效的文本数据');
        }
    });

    // Import a local text file's contents into the textarea.
    document.getElementById('textFile').addEventListener('change', (e) => {
        const file = e.target.files[0];
        if (!file) return;
        const reader = new FileReader();
        reader.onload = (ev) => {
            document.getElementById('batchTextInput').value = ev.target.result;
        };
        reader.readAsText(file);
    });

    document.getElementById('buildIndexBtn').addEventListener('click', buildIndex);
    document.getElementById('viewDataBtn').addEventListener('click', viewData);
    document.getElementById('clearDataBtn').addEventListener('click', clearData);

    // Populate the counters once at startup.
    updateDataStats();
}
|
||||||
|
|
||||||
|
// 批量上传图片
|
||||||
|
// Upload a batch of image files in a single multipart request, then refresh
// stats and enable index building. NOTE(review): the progress bar jumps
// straight from 0% to 100% — fetch provides no per-file upload progress.
async function uploadBatchImages(files) {
    const progressDiv = document.getElementById('imageUploadProgress');
    const progressBar = progressDiv.querySelector('.progress-bar');
    const progressText = document.getElementById('imageProgressText');

    progressDiv.style.display = 'block';
    progressText.textContent = `0/${files.length}`;
    progressBar.style.width = '0%';

    const formData = new FormData();
    for (const file of files) {
        formData.append('images', file);
    }

    try {
        const response = await fetch('/api/upload/images', {method: 'POST', body: formData});
        const data = await response.json();

        if (data.success) {
            progressBar.style.width = '100%';
            progressText.textContent = `${files.length}/${files.length}`;
            showAlert('success', `成功上传 ${data.uploaded_count} 张图片`);
            updateDataStats();
            document.getElementById('buildIndexBtn').disabled = false;
        } else {
            showAlert('danger', `上传失败: ${data.message}`);
        }
    } catch (error) {
        showAlert('danger', `上传错误: ${error.message}`);
    } finally {
        // Keep the bar visible briefly so the user sees completion.
        setTimeout(() => {
            progressDiv.style.display = 'none';
        }, 2000);
    }
}
|
||||||
|
|
||||||
|
// 批量上传文本
|
||||||
|
// Upload a list of text entries as JSON, then clear the input and refresh
// the data counters on success.
async function uploadBatchTexts(texts) {
    try {
        const response = await fetch('/api/upload/texts', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json',
            },
            body: JSON.stringify({texts: texts})
        });
        const data = await response.json();

        if (data.success) {
            showAlert('success', `成功上传 ${data.uploaded_count} 条文本`);
            document.getElementById('batchTextInput').value = '';
            updateDataStats();
            document.getElementById('buildIndexBtn').disabled = false;
        } else {
            showAlert('danger', `上传失败: ${data.message}`);
        }
    } catch (error) {
        showAlert('danger', `上传错误: ${error.message}`);
    }
}
|
||||||
|
|
||||||
|
// 构建索引
|
||||||
|
// Ask the backend to (re)build the search indexes over all uploaded data.
// The button is disabled and shows a spinner while the build runs.
async function buildIndex() {
    const btn = document.getElementById('buildIndexBtn');
    const savedLabel = btn.innerHTML;
    btn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> 构建中...';
    btn.disabled = true;

    try {
        const response = await fetch('/api/build_index', {method: 'POST'});
        const data = await response.json();

        if (data.success) {
            showAlert('success', '索引构建完成!现在可以进行搜索了');
        } else {
            showAlert('danger', `索引构建失败: ${data.message}`);
        }
    } catch (error) {
        showAlert('danger', `构建错误: ${error.message}`);
    } finally {
        btn.innerHTML = savedLabel;
        btn.disabled = false;
    }
}
|
||||||
|
|
||||||
|
// 查看数据
|
||||||
|
// Fetch the stored image/text lists from the backend and show them in a modal.
// NOTE(review): entries are interpolated into HTML unescaped — sanitize here
// too if uploaded filenames/texts can contain markup.
async function viewData() {
    try {
        const response = await fetch('/api/data/list');
        const data = await response.json();

        if (!data.success) {
            showAlert('danger', `获取数据失败: ${data.message}`);
            return;
        }

        let content = '<div class="row">';

        // Left column: uploaded images with thumbnails.
        if (data.images && data.images.length > 0) {
            content += '<div class="col-md-6"><h6>图片数据 (' + data.images.length + ')</h6>';
            content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
            for (const img of data.images) {
                content += `<div class="list-group-item d-flex justify-content-between align-items-center">
                        <span>${img}</span>
                        <img src="/uploads/${img}" class="img-thumbnail" style="width: 50px; height: 50px; object-fit: cover;">
                    </div>`;
            }
            content += '</div></div>';
        }

        // Right column: stored texts, truncated to 50 characters.
        if (data.texts && data.texts.length > 0) {
            content += '<div class="col-md-6"><h6>文本数据 (' + data.texts.length + ')</h6>';
            content += '<div class="list-group" style="max-height: 300px; overflow-y: auto;">';
            data.texts.forEach((text, index) => {
                const shortText = text.length > 50 ? text.substring(0, 50) + '...' : text;
                content += `<div class="list-group-item">
                        <small class="text-muted">#${index + 1}</small><br>
                        ${shortText}
                    </div>`;
            });
            content += '</div></div>';
        }

        content += '</div>';
        showModal('数据列表', content);
    } catch (error) {
        showAlert('danger', `获取数据错误: ${error.message}`);
    }
}
|
||||||
|
|
||||||
|
// 清空数据
|
||||||
|
// Delete all uploaded data after an explicit user confirmation; on success
// the counters are refreshed and index building is disabled again.
async function clearData() {
    if (!confirm('确定要清空所有数据吗?此操作不可恢复!')) {
        return;
    }

    try {
        const response = await fetch('/api/data/clear', {method: 'POST'});
        const data = await response.json();

        if (data.success) {
            showAlert('success', '数据已清空');
            updateDataStats();
            document.getElementById('buildIndexBtn').disabled = true;
        } else {
            showAlert('danger', `清空失败: ${data.message}`);
        }
    } catch (error) {
        showAlert('danger', `清空错误: ${error.message}`);
    }
}
|
||||||
|
|
||||||
|
// 更新数据统计
|
||||||
|
// Refresh the image/text counters shown in the data panel. Failures are
// logged only — stats are non-critical, so no user-facing alert.
async function updateDataStats() {
    try {
        const response = await fetch('/api/data/stats');
        const data = await response.json();
        if (data.success) {
            document.getElementById('imageCount').textContent = data.image_count || 0;
            document.getElementById('textCount').textContent = data.text_count || 0;
        }
    } catch (error) {
        console.log('获取数据统计失败:', error);
    }
}
|
||||||
|
|
||||||
|
// 显示模态框
|
||||||
|
// Build a Bootstrap modal with the given title/content and show it, first
// removing any modal left over from a previous call.
function showModal(title, content) {
    const modalHtml = `
        <div class="modal fade" id="dataModal" tabindex="-1">
            <div class="modal-dialog modal-lg">
                <div class="modal-content">
                    <div class="modal-header">
                        <h5 class="modal-title">${title}</h5>
                        <button type="button" class="btn-close" data-bs-dismiss="modal"></button>
                    </div>
                    <div class="modal-body">
                        ${content}
                    </div>
                </div>
            </div>
        </div>
    `;

    const existingModal = document.getElementById('dataModal');
    if (existingModal) {
        existingModal.remove();
    }

    document.body.insertAdjacentHTML('beforeend', modalHtml);
    new bootstrap.Modal(document.getElementById('dataModal')).show();
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
104
test_all_retrieval_modes.py
Normal file
104
test_all_retrieval_modes.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试所有四种多模态检索模式
|
||||||
|
"""
|
||||||
|
|
||||||
|
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
|
||||||
|
def test_all_retrieval_modes():
    """Smoke-test all four multimodal retrieval modes end to end.

    Builds text and image indexes from a small fixture set, then exercises
    text->text, text->image, image->text and image->image search, plus the
    web-app compatible method aliases. Prints progress; returns nothing.
    """
    print('正在初始化多GPU多模态检索系统...')
    retrieval = MultiGPUMultimodalRetrieval()

    # Fixture data: a few short texts and sample image paths.
    test_texts = [
        "一只可爱的小猫",
        "美丽的风景照片",
        "现代建筑设计",
        "colorful flowers in garden",
    ]
    test_images = [
        'sample_images/1755677101_1__.jpg',
        'sample_images/1755677101_2__.jpg',
        'sample_images/1755677101_3__.jpg',
        'sample_images/1755677101_4__.jpg',
    ]

    # Only keep images that actually exist on disk; abort if none do.
    existing_images = [img for img in test_images if os.path.exists(img)]
    if not existing_images:
        print("❌ 没有找到测试图像文件")
        return
    print(f"找到 {len(existing_images)} 张测试图像")

    def _show(results):
        # Print ranked (item, score) pairs, 1-based.
        for rank, (item, score) in enumerate(results, start=1):
            print(f'  {rank}. {item} (相似度: {score:.4f})')

    try:
        print('\n=== 构建文本索引 ===')
        retrieval.build_text_index_parallel(test_texts)
        print('✅ 文本索引构建完成')

        print('\n=== 构建图像索引 ===')
        retrieval.build_image_index_parallel(existing_images)
        print('✅ 图像索引构建完成')

        print('\n=== 测试文本到文本检索 ===')
        query = "小动物"
        results = retrieval.search_text_by_text(query, top_k=3)
        print(f'查询: "{query}"')
        _show(results)

        print('\n=== 测试文本到图像检索 ===')
        query = "beautiful image"
        results = retrieval.search_images_by_text(query, top_k=3)
        print(f'查询: "{query}"')
        _show(results)

        print('\n=== 测试图像到文本检索 ===')
        query_image = existing_images[0]
        results = retrieval.search_text_by_image(query_image, top_k=3)
        print(f'查询图像: {query_image}')
        _show(results)

        print('\n=== 测试图像到图像检索 ===')
        query_image = existing_images[0]
        results = retrieval.search_images_by_image(query_image, top_k=3)
        print(f'查询图像: {query_image}')
        _show(results)

        print('\n✅ 所有四种检索模式测试完成!')

        # The Flask app calls these alias methods; make sure they exist too.
        print('\n=== 测试Web应用兼容方法 ===')
        try:
            results = retrieval.search_text_to_image("test query", top_k=2)
            print('✅ search_text_to_image 方法正常')

            results = retrieval.search_image_to_text(existing_images[0], top_k=2)
            print('✅ search_image_to_text 方法正常')

            results = retrieval.search_image_to_image(existing_images[0], top_k=2)
            print('✅ search_image_to_image 方法正常')

        except Exception as e:
            print(f'❌ Web应用兼容方法测试失败: {e}')

    except Exception as e:
        print(f'❌ 测试过程中出现错误: {e}')
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    test_all_retrieval_modes()
|
||||||
49
test_image_encoding.py
Normal file
49
test_image_encoding.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
测试图像编码功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
def test_image_encoding():
    """Sanity-check the encoders: shapes, dtypes, determinism, and that text
    and image embeddings share the same dimensionality."""
    print('正在初始化多GPU多模态检索系统...')
    retrieval = MultiGPUMultimodalRetrieval()

    # Text encoding: report shape and dtype.
    print('测试文本编码...')
    text_embeddings = retrieval.encode_text_batch(['这是一个测试文本'])
    print(f'文本embedding形状: {text_embeddings.shape}')
    print(f'文本embedding数据类型: {text_embeddings.dtype}')

    # Image encoding: report shape and dtype.
    print('测试图像编码...')
    test_images = ['sample_images/1755677101_1__.jpg']
    image_embeddings = retrieval.encode_image_batch(test_images)
    print(f'图像embedding形状: {image_embeddings.shape}')
    print(f'图像embedding数据类型: {image_embeddings.dtype}')

    # Encoding the same image twice must be numerically reproducible.
    print('测试embedding一致性...')
    repeat_embeddings = retrieval.encode_image_batch(test_images)
    consistency = np.allclose(image_embeddings, repeat_embeddings, rtol=1e-5)
    print(f'相同图像embedding一致性: {consistency}')

    # A different image should yield a distinct embedding (cosine similarity).
    print('测试不同图像embedding差异...')
    test_images2 = ['sample_images/1755677101_2__.jpg']
    other_embeddings = retrieval.encode_image_batch(test_images2)
    a, b = image_embeddings[0], other_embeddings[0]
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    print(f'不同图像间相似度: {similarity:.4f}')

    # Cross-modal search requires both encoders to share a vector space.
    if text_embeddings.shape[1] == image_embeddings.shape[1]:
        print('✅ 文本和图像embedding维度一致')
    else:
        print('❌ 文本和图像embedding维度不一致')

    print('测试完成!')


if __name__ == "__main__":
    test_image_encoding()
|
||||||
BIN
uploads/query_1755675080_10__.jpg
Normal file
BIN
uploads/query_1755675080_10__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 172 KiB |
BIN
uploads/query_1755681423_5__.jpg
Normal file
BIN
uploads/query_1755681423_5__.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 201 KiB |
617
web_app_multigpu.py
Normal file
617
web_app_multigpu.py
Normal file
@ -0,0 +1,617 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
多GPU多模态检索系统 - Web应用
|
||||||
|
专为双GPU部署优化
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from flask import Flask, render_template, request, jsonify, send_file, url_for
|
||||||
|
from werkzeug.utils import secure_filename
|
||||||
|
from PIL import Image
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
import glob
|
||||||
|
|
||||||
|
# 设置环境变量优化GPU内存
|
||||||
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config['SECRET_KEY'] = 'multigpu_multimodal_retrieval_2024'
|
||||||
|
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
|
||||||
|
|
||||||
|
# 配置上传文件夹
|
||||||
|
UPLOAD_FOLDER = 'uploads'
|
||||||
|
SAMPLE_IMAGES_FOLDER = 'sample_images'
|
||||||
|
TEXT_DATA_FOLDER = 'text_data'
|
||||||
|
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
|
||||||
|
|
||||||
|
# 确保文件夹存在
|
||||||
|
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
||||||
|
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
|
||||||
|
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)
|
||||||
|
|
||||||
|
# 全局检索系统实例
|
||||||
|
retrieval_system = None
|
||||||
|
|
||||||
|
def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    suffix = filename.rsplit('.', 1)[1].lower()
    return suffix in ALLOWED_EXTENSIONS
|
||||||
|
|
||||||
|
def image_to_base64(image_path):
    """Read the file at *image_path* and return its bytes base64-encoded as
    UTF-8 text, or None (after logging the error) if the file can't be read."""
    try:
        with open(image_path, "rb") as img_file:
            raw = img_file.read()
        return base64.b64encode(raw).decode('utf-8')
    except Exception as e:
        logger.error(f"图片转换失败: {e}")
        return None
|
||||||
|
|
||||||
|
@app.route('/')
def index():
    """Serve the single-page UI."""
    return render_template('index.html')
|
||||||
|
|
||||||
|
@app.route('/api/status')
def get_status():
    """Return JSON readiness info: init flag, visible GPU count, model state."""
    global retrieval_system

    status = {'initialized': retrieval_system is not None,
              'gpu_count': 0,
              'model_loaded': False}

    try:
        # torch is imported lazily so a broken CUDA setup can't prevent the
        # status endpoint from answering at all.
        import torch
        if torch.cuda.is_available():
            status['gpu_count'] = torch.cuda.device_count()

        if retrieval_system and retrieval_system.model:
            status['model_loaded'] = True
            status['device_ids'] = retrieval_system.device_ids

    except Exception as e:
        logger.error(f"获取状态失败: {e}")

    return jsonify(status)
|
||||||
|
|
||||||
|
@app.route('/api/init', methods=['POST'])
def initialize_system():
    """Initialize the global multi-GPU retrieval system.

    Returns JSON with device ids and GPU count on success; on failure
    returns a 500 with the error message and resets the global so the
    system is not left half-initialized (previously a failed model load
    left `retrieval_system` set, and /api/status would report it as
    initialized).
    """
    global retrieval_system

    try:
        logger.info("正在初始化多GPU检索系统...")

        # Imported lazily so the web app can start even if the heavy model
        # dependencies are missing or broken.
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        retrieval_system = MultiGPUMultimodalRetrieval()

        # Use a specific exception type instead of bare Exception.
        if retrieval_system.model is None:
            raise RuntimeError("模型加载失败")

        logger.info("✅ 多GPU系统初始化成功")

        return jsonify({
            'success': True,
            'message': '多GPU系统初始化成功',
            'device_ids': retrieval_system.device_ids,
            'gpu_count': len(retrieval_system.device_ids)
        })

    except Exception as e:
        # Roll back the global so later requests don't use a broken system.
        retrieval_system = None
        error_msg = f"系统初始化失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())

        return jsonify({
            'success': False,
            'message': error_msg
        }), 500
|
||||||
|
|
||||||
|
@app.route('/api/search/text_to_text', methods=['POST'])
|
||||||
|
def search_text_to_text():
|
||||||
|
"""文本搜索文本"""
|
||||||
|
return handle_search('text_to_text')
|
||||||
|
|
||||||
|
@app.route('/api/search/text_to_image', methods=['POST'])
|
||||||
|
def search_text_to_image():
|
||||||
|
"""文本搜索图片"""
|
||||||
|
return handle_search('text_to_image')
|
||||||
|
|
||||||
|
@app.route('/api/search/image_to_text', methods=['POST'])
def search_image_to_text():
    """图片搜索文本 — thin endpoint that delegates to the shared handler."""
    return handle_search('image_to_text')
|
||||||
|
|
||||||
|
@app.route('/api/search/image_to_image', methods=['POST'])
def search_image_to_image():
    """图片搜索图片 — thin endpoint that delegates to the shared handler."""
    return handle_search('image_to_image')
|
||||||
|
|
||||||
|
@app.route('/api/search', methods=['POST'])
def search():
    """通用搜索接口(兼容旧版本)

    Reads the search mode from form data or a JSON body and dispatches to
    handle_search(). Defaults to 'text_to_text'.
    """
    # BUG FIX: request.json raises (or is None) when the body is not JSON,
    # which crashed form-only requests. get_json(silent=True) tolerates both.
    payload = request.get_json(silent=True) or {}
    mode = request.form.get('mode') or payload.get('mode', 'text_to_text')
    return handle_search(mode)
|
||||||
|
|
||||||
|
def handle_search(mode):
    """处理搜索请求的通用函数

    Dispatches a search request based on *mode*:
    'text_to_text' / 'text_to_image' take a text query (form or JSON body),
    'image_to_text' / 'image_to_image' take an uploaded 'image' file.

    Returns a Flask JSON response: 400 on bad input or unknown mode,
    500 on internal failure.
    """
    global retrieval_system

    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化,请先点击初始化按钮'
        }), 400

    try:
        # BUG FIX: request.json is None/raises for non-JSON bodies; use
        # get_json(silent=True) so multipart/form requests don't crash.
        payload = request.get_json(silent=True) or {}
        top_k = int(request.form.get('top_k') or payload.get('top_k') or 5)

        if mode in ('text_to_text', 'text_to_image'):
            return _search_by_text(mode, top_k, payload)
        elif mode in ('image_to_text', 'image_to_image'):
            return _search_by_image(mode, top_k)
        else:
            return jsonify({
                'success': False,
                'message': f'不支持的搜索模式: {mode}'
            }), 400

    except Exception as e:
        error_msg = f"搜索失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return jsonify({
            'success': False,
            'message': error_msg
        }), 500


def _search_by_text(mode, top_k, payload):
    """Run a text-query search (text_to_text / text_to_image) and build the response."""
    query = request.form.get('query') or payload.get('query', '')
    if not query.strip():
        return jsonify({
            'success': False,
            'message': '请输入查询文本'
        }), 400

    logger.info(f"执行{mode}搜索: {query}")

    if mode == 'text_to_text':
        results = retrieval_system.search_text_to_text(query, top_k=top_k)
    else:  # text_to_image
        results = retrieval_system.search_text_to_image(query, top_k=top_k)

    return jsonify({
        'success': True,
        'mode': mode,
        'query': query,
        'results': results,
        'result_count': len(results)
    })


def _search_by_image(mode, top_k):
    """Run an image-query search (image_to_text / image_to_image) and build the response."""
    if 'image' not in request.files:
        return jsonify({
            'success': False,
            'message': '请上传查询图片'
        }), 400

    file = request.files['image']
    if file.filename == '' or not allowed_file(file.filename):
        return jsonify({
            'success': False,
            'message': '请上传有效的图片文件'
        }), 400

    # Save the uploaded query image under a timestamped, sanitized name.
    # BUG FIX: the name template had lost its "{filename}" placeholder, so
    # every query image saved in the same second overwrote the previous one.
    filename = secure_filename(file.filename)
    timestamp = str(int(time.time()))
    filename = f"query_{timestamp}_{filename}"
    filepath = os.path.join(UPLOAD_FOLDER, filename)
    file.save(filepath)

    logger.info(f"执行{mode}搜索,图片: {filename}")

    if mode == 'image_to_text':
        results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
    else:  # image_to_image
        results = retrieval_system.search_image_to_image(filepath, top_k=top_k)

    # Echo the query image back to the client as base64 for display.
    query_image_b64 = image_to_base64(filepath)

    return jsonify({
        'success': True,
        'mode': mode,
        'query_image': query_image_b64,
        'results': results,
        'result_count': len(results)
    })
|
||||||
|
|
||||||
|
@app.route('/api/upload/images', methods=['POST'])
def upload_images():
    """批量上传图片

    Accepts a multipart field 'images' holding multiple files; each allowed
    file is stored in SAMPLE_IMAGES_FOLDER under a timestamped name.
    Returns the stored filenames.
    """
    try:
        uploaded_files = []

        if 'images' not in request.files:
            return jsonify({'success': False, 'message': '没有选择文件'}), 400

        files = request.files.getlist('images')

        for file in files:
            if file and file.filename != '' and allowed_file(file.filename):
                safe_name = secure_filename(file.filename)
                timestamp = str(int(time.time()))
                # BUG FIX: the stored name had lost its "{filename}" part, so
                # all files in one batch collided on the same path.
                filename = f"{timestamp}_{safe_name}"
                filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
                file.save(filepath)
                uploaded_files.append(filename)

        return jsonify({
            'success': True,
            'message': f'成功上传 {len(uploaded_files)} 个图片文件',
            'uploaded_count': len(uploaded_files),
            'files': uploaded_files
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500
|
||||||
|
|
||||||
|
@app.route('/api/upload/texts', methods=['POST'])
def upload_texts():
    """批量上传文本数据

    Expects a JSON body {'texts': [...]}; persists the list as a timestamped
    JSON file in TEXT_DATA_FOLDER for later index building.
    """
    try:
        payload = request.get_json()

        if not payload or 'texts' not in payload:
            return jsonify({'success': False, 'message': '没有提供文本数据'}), 400

        texts = payload['texts']
        if not isinstance(texts, list):
            return jsonify({'success': False, 'message': '文本数据格式错误'}), 400

        # Persist the batch under a timestamped name so uploads never clash.
        target_name = f"texts_{int(time.time())}.json"
        target_path = os.path.join(TEXT_DATA_FOLDER, target_name)

        with open(target_path, 'w', encoding='utf-8') as fh:
            json.dump(texts, fh, ensure_ascii=False, indent=2)

        return jsonify({
            'success': True,
            'message': f'成功上传 {len(texts)} 条文本',
            'uploaded_count': len(texts)
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500
|
||||||
|
|
||||||
|
@app.route('/api/upload/file', methods=['POST'])
def upload_single_file():
    """上传单个文件

    Accepts a multipart field 'file'; stores it in SAMPLE_IMAGES_FOLDER
    under a timestamped name when the extension is allowed.
    """
    if 'file' not in request.files:
        return jsonify({'success': False, 'message': '没有选择文件'}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({'success': False, 'message': '没有选择文件'}), 400

    if file and allowed_file(file.filename):
        safe_name = secure_filename(file.filename)
        timestamp = str(int(time.time()))
        # BUG FIX: keep the sanitized original name in the stored filename so
        # files uploaded within the same second do not overwrite each other.
        filename = f"{timestamp}_{safe_name}"
        filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
        file.save(filepath)

        return jsonify({
            'success': True,
            'message': '文件上传成功',
            'filename': filename
        })

    return jsonify({'success': False, 'message': '不支持的文件类型'}), 400
|
||||||
|
|
||||||
|
@app.route('/api/data/list', methods=['GET'])
def list_data():
    """列出已上传的数据

    Returns per-file metadata (filename, size, mtime) for uploaded images
    and text files.
    """
    def _entries(folder, keep):
        """Collect {filename, size, modified} dicts for files passing *keep*."""
        items = []
        if os.path.exists(folder):
            for fname in os.listdir(folder):
                if keep(fname):
                    st = os.stat(os.path.join(folder, fname))
                    items.append({
                        'filename': fname,
                        'size': st.st_size,
                        'modified': st.st_mtime
                    })
        return items

    try:
        images = _entries(SAMPLE_IMAGES_FOLDER, allowed_file)
        # CONSISTENCY FIX: build_index consumes both .json and .txt text
        # files, but the listing previously showed only .json.
        texts = _entries(TEXT_DATA_FOLDER,
                         lambda f: f.endswith(('.json', '.txt')))

        return jsonify({
            'success': True,
            'data': {
                'images': images,
                'texts': texts
            }
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'获取数据列表失败: {str(e)}'
        }), 500
|
||||||
|
|
||||||
|
@app.route('/api/gpu_status')
def gpu_status():
    """获取GPU状态

    Delegates to the launcher helper, which reports per-GPU memory usage.
    """
    try:
        from smart_gpu_launcher import get_gpu_memory_info

        return jsonify({
            'success': True,
            'gpu_info': get_gpu_memory_info()
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': f"获取GPU状态失败: {str(e)}"
        }), 500
|
||||||
|
|
||||||
|
def _collect_image_files():
    """Return paths of all indexable images under SAMPLE_IMAGES_FOLDER."""
    paths = []
    for ext in ALLOWED_EXTENSIONS:
        paths.extend(glob.glob(os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")))
    return paths


def _collect_text_data():
    """Load all text entries from .json and .txt files under TEXT_DATA_FOLDER.

    JSON files may hold a list (each item becomes one entry) or a single
    value; TXT files contribute one entry per non-blank line. Unreadable
    files are logged and skipped.
    """
    entries = []
    text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
    text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))

    for text_file in text_files:
        try:
            if text_file.endswith('.json'):
                with open(text_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    entries.extend(str(item).strip() for item in data
                                   if str(item).strip())
                else:
                    entries.append(str(data).strip())
            else:
                with open(text_file, 'r', encoding='utf-8') as f:
                    entries.extend(line.strip() for line in f if line.strip())
        except Exception as e:
            logger.warning(f"读取文本文件失败 {text_file}: {e}")
    return entries


@app.route('/api/build_index', methods=['POST'])
def build_index():
    """构建检索索引

    Scans uploaded images and text files, then builds the parallel image
    and text indexes. Returns 400 when the system is uninitialized or no
    data exists, 500 on build failure.
    """
    global retrieval_system

    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化'
        }), 400

    try:
        image_files = _collect_image_files()
        text_data = _collect_text_data()

        if not image_files and not text_data:
            return jsonify({
                'success': False,
                'message': '没有找到可用的图片或文本数据,请先上传数据'
            }), 400

        if image_files:
            logger.info(f"构建图片索引,共 {len(image_files)} 张图片")
            retrieval_system.build_image_index_parallel(image_files)

        if text_data:
            logger.info(f"构建文本索引,共 {len(text_data)} 条文本")
            retrieval_system.build_text_index_parallel(text_data)

        return jsonify({
            'success': True,
            'message': f'索引构建完成!图片: {len(image_files)} 张,文本: {len(text_data)} 条',
            'image_count': len(image_files),
            'text_count': len(text_data)
        })

    except Exception as e:
        logger.error(f"构建索引失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'构建索引失败: {str(e)}'
        }), 500
|
||||||
|
|
||||||
|
@app.route('/api/data/stats', methods=['GET'])
def get_data_stats():
    """获取数据统计信息

    Counts indexable images and text entries. Text counting covers both
    .txt files (one entry per non-blank line) and .json batches written by
    upload_texts — previously .json was ignored, so texts uploaded through
    the API were always reported as 0.
    """
    try:
        # Count image files by allowed extension.
        image_count = 0
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            image_count += len(glob.glob(pattern))

        # Count text entries from both .txt and .json sources, matching
        # what build_index actually consumes.
        text_count = 0
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json")))
        for text_file in text_files:
            try:
                with open(text_file, 'r', encoding='utf-8') as f:
                    if text_file.endswith('.json'):
                        data = json.load(f)
                        text_count += len(data) if isinstance(data, list) else 1
                    else:
                        text_count += sum(1 for line in f if line.strip())
            except Exception:
                continue  # best-effort: skip unreadable files

        return jsonify({
            'success': True,
            'image_count': image_count,
            'text_count': text_count
        })

    except Exception as e:
        logger.error(f"获取数据统计失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'获取统计失败: {str(e)}'
        }), 500
|
||||||
|
|
||||||
|
@app.route('/api/data/clear', methods=['POST'])
def clear_data():
    """清空所有数据

    Deletes all uploaded images and text files, then resets the in-memory
    indexes so stale results cannot be served. Text cleanup removes .json
    batches too — upload_texts writes .json, which the original clear left
    behind.
    """
    global retrieval_system

    try:
        # Remove image files by allowed extension.
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            for file_path in glob.glob(pattern):
                try:
                    os.remove(file_path)
                except Exception as e:
                    logger.warning(f"删除图片文件失败 {file_path}: {e}")

        # Remove text files: both .txt and the .json batches from upload_texts.
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json")))
        for text_file in text_files:
            try:
                os.remove(text_file)
            except Exception as e:
                logger.warning(f"删除文本文件失败 {text_file}: {e}")

        # Reset the in-memory indexes.
        if retrieval_system:
            retrieval_system.text_index = None
            retrieval_system.image_index = None
            retrieval_system.text_data = []
            retrieval_system.image_data = []

        return jsonify({
            'success': True,
            'message': '数据已清空'
        })

    except Exception as e:
        logger.error(f"清空数据失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'清空数据失败: {str(e)}'
        }), 500
|
||||||
|
|
||||||
|
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """提供上传文件的访问

    Serves a stored upload by name. All stored files are named via
    secure_filename(), so any requested name that sanitization would alter
    (e.g. containing '..' or separators) cannot match a real upload and is
    rejected, blocking path traversal.
    """
    # SECURITY FIX: reject names that secure_filename() would change.
    if filename != secure_filename(filename):
        return jsonify({'success': False, 'message': '非法文件名'}), 400
    return send_file(os.path.join(SAMPLE_IMAGES_FOLDER, filename))
|
||||||
|
|
||||||
|
def print_startup_info():
    """打印启动信息 — banner, feature list, and detected GPU summary."""
    banner = "=" * 60
    print("🚀 启动多GPU多模态检索Web应用")
    print(banner)
    print("访问地址: http://localhost:5000")
    print("支持功能:")
    feature_lines = (
        " 📝 文搜文 - 文本查找相似文本",
        " 🖼️ 文搜图 - 文本查找相关图片",
        " 📝 图搜文 - 图片查找相关文本",
        " 🖼️ 图搜图 - 图片查找相似图片",
        " 📤 批量上传 - 图片和文本数据管理",
    )
    for line in feature_lines:
        print(line)
    print("GPU配置:")

    try:
        import torch
        if torch.cuda.is_available():
            total = torch.cuda.device_count()
            print(f" 🖥️ 检测到 {total} 个GPU")
            for idx in range(total):
                device_name = torch.cuda.get_device_name(idx)
                mem_gb = torch.cuda.get_device_properties(idx).total_memory / 1024**3
                print(f"   GPU {idx}: {device_name} ({mem_gb:.1f}GB)")
        else:
            print(" ❌ CUDA不可用")
    except Exception as e:
        print(f" ❌ GPU检查失败: {e}")

    print(banner)
|
||||||
|
|
||||||
|
def auto_initialize():
    """启动时自动初始化系统

    Builds the retrieval system at startup and stores it in the module-level
    ``retrieval_system``. Returns True on success, False on failure.
    """
    global retrieval_system

    try:
        logger.info("🚀 启动时自动初始化多GPU检索系统...")

        # Deferred import: keeps module import cheap and failure contained here.
        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        retrieval_system = MultiGPUMultimodalRetrieval()

        if retrieval_system.model is None:
            raise Exception("模型加载失败")

        logger.info("✅ 系统自动初始化成功")
        return True

    except Exception as e:
        logger.error(f"❌ 系统自动初始化失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    print_startup_info()

    # Bring the retrieval system up before serving; failure is logged and the
    # user can retry via the /api/init endpoint.
    auto_initialize()

    # Serve on all interfaces; threaded so concurrent searches don't block.
    app.run(host='0.0.0.0', port=5000, debug=False, threaded=True)
|
||||||
Loading…
x
Reference in New Issue
Block a user