diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 261eeb9..0000000 --- a/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). 
- - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. 
Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative 
Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..e116d75 --- /dev/null +++ b/README.md @@ -0,0 +1,170 @@ +# 多模态检索系统 (Multimodal Retrieval System) + +基于 OpenSearch-AI/Ops-MM-embedding-v1-7B 模型的多模态检索系统,支持四种检索模式:文搜图、文搜文、图搜图、图搜文。 + +## 功能特性 + +- **文搜文 (Text-to-Text)**: 使用文本查询搜索相似文本 +- **文搜图 (Text-to-Image)**: 使用文本查询搜索相关图像 +- **图搜图 (Image-to-Image)**: 使用图像查询搜索相似图像 +- **图搜文 (Image-to-Text)**: 使用图像查询搜索相关文本 + +## 环境要求 + +- Python 3.8+ +- CUDA 支持的 GPU (推荐,也可使用CPU) +- 至少 8GB 内存 + +## 安装依赖 + +```bash +pip install -r requirements.txt +``` + +## 快速开始 + +### 1. 检查GPU环境 + +```python +from multimodal_retrieval import check_gpu_info +check_gpu_info() +``` + +### 2. 初始化系统 + +```python +from multimodal_retrieval import MultimodalRetrieval + +# 初始化检索系统 +retrieval_system = MultimodalRetrieval() +``` + +### 3. 
构建索引 + +```python +# 构建文本索引 +texts = ["一只可爱的小猫", "美丽的山景", "现代化城市"] +retrieval_system.build_text_index(texts) + +# 构建图像索引 +image_paths = ["./images/cat.jpg", "./images/mountain.jpg", "./images/city.jpg"] +retrieval_system.build_image_index(image_paths) +``` + +### 4. 执行检索 + +```python +# 文搜文 +results = retrieval_system.search_text_by_text("猫咪", top_k=5) + +# 文搜图 +results = retrieval_system.search_images_by_text("动物", top_k=5) + +# 图搜图 +results = retrieval_system.search_images_by_image("./query_image.jpg", top_k=5) + +# 图搜文 +results = retrieval_system.search_text_by_image("./query_image.jpg", top_k=5) +``` + +## 运行演示 + +```bash +python demo.py +``` + +演示脚本会自动: +1. 检查GPU环境信息 +2. 初始化多模态检索系统 +3. 演示四种检索模式 +4. 显示检索结果和相似度分数 + +## 文件结构 + +``` +mmeb/ +├── multimodal_retrieval.py # 主要检索系统类 +├── demo.py # 演示脚本 +├── requirements.txt # 依赖包列表 +├── README.md # 项目说明 +└── images/ # 图像数据目录 (需要自行创建) +``` + +## API 参考 + +### MultimodalRetrieval 类 + +#### 初始化参数 +- `model_name`: 模型名称,默认 "OpenSearch-AI/Ops-MM-embedding-v1-7B" +- `device`: 设备类型,默认自动选择 ("cuda" 或 "cpu") + +#### 主要方法 + +##### `build_text_index(texts, save_path=None)` +构建文本索引 +- `texts`: 文本列表 +- `save_path`: 索引保存路径 (可选) + +##### `build_image_index(image_paths, save_path=None)` +构建图像索引 +- `image_paths`: 图像路径列表 +- `save_path`: 索引保存路径 (可选) + +##### `search_text_by_text(query, top_k=5)` +文搜文检索 +- `query`: 查询文本 +- `top_k`: 返回结果数量 +- 返回: `[(文本, 相似度分数), ...]` + +##### `search_images_by_text(query, top_k=5)` +文搜图检索 +- `query`: 查询文本 +- `top_k`: 返回结果数量 +- 返回: `[(图像路径, 相似度分数), ...]` + +##### `search_images_by_image(query_image, top_k=5)` +图搜图检索 +- `query_image`: 查询图像路径或PIL图像 +- `top_k`: 返回结果数量 +- 返回: `[(图像路径, 相似度分数), ...]` + +##### `search_text_by_image(query_image, top_k=5)` +图搜文检索 +- `query_image`: 查询图像路径或PIL图像 +- `top_k`: 返回结果数量 +- 返回: `[(文本, 相似度分数), ...]` + +## 注意事项 + +1. **首次运行**: 首次运行时会自动下载模型,需要网络连接 +2. **内存需求**: 7B参数模型需要较大内存,建议使用GPU +3. **图像格式**: 支持常见图像格式 (jpg, png, bmp, gif等) +4. **批处理**: 系统自动进行批处理以提高效率 +5. 
**索引保存**: 可以保存和加载索引以避免重复构建 + +## 性能优化建议 + +- 使用GPU加速推理 +- 合理设置批处理大小 +- 保存索引文件避免重复构建 +- 对大量数据使用分批处理 + +## 故障排除 + +### 常见问题 + +1. **CUDA内存不足**: 减小批处理大小或使用CPU +2. **模型下载失败**: 检查网络连接或使用镜像源 +3. **图像加载错误**: 检查图像文件路径和格式 + +### 日志信息 + +系统会输出详细的日志信息,包括: +- GPU环境检测结果 +- 模型加载进度 +- 索引构建状态 +- 检索执行情况 + +## 许可证 + +本项目遵循 MIT 许可证。 diff --git a/__pycache__/multimodal_retrieval_multigpu.cpython-310.pyc b/__pycache__/multimodal_retrieval_multigpu.cpython-310.pyc new file mode 100644 index 0000000..bdd00a6 Binary files /dev/null and b/__pycache__/multimodal_retrieval_multigpu.cpython-310.pyc differ diff --git a/multimodal_retrieval_multigpu.py b/multimodal_retrieval_multigpu.py new file mode 100644 index 0000000..ea31525 --- /dev/null +++ b/multimodal_retrieval_multigpu.py @@ -0,0 +1,632 @@ +import torch +import torch.nn as nn +from torch.nn.parallel import DataParallel, DistributedDataParallel +import numpy as np +from PIL import Image +import faiss +from transformers import AutoModel, AutoProcessor, AutoTokenizer +from typing import List, Union, Tuple, Dict +import os +import json +from pathlib import Path +import logging +import gc +from concurrent.futures import ThreadPoolExecutor, as_completed +import threading + +# 设置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class MultiGPUMultimodalRetrieval: + """多GPU优化的多模态检索系统,支持文搜图、文搜文、图搜图、图搜文""" + + def __init__(self, model_name: str = "OpenSearch-AI/Ops-MM-embedding-v1-7B", + use_all_gpus: bool = True, gpu_ids: List[int] = None, min_memory_gb=12): + """ + 初始化多GPU多模态检索系统 + + Args: + model_name: 模型名称 + use_all_gpus: 是否使用所有可用GPU + gpu_ids: 指定使用的GPU ID列表 + min_memory_gb: 最小可用内存(GB) + """ + self.model_name = model_name + + # 设置GPU设备 + self._setup_devices(use_all_gpus, gpu_ids, min_memory_gb) + + # 清理GPU内存 + self._clear_all_gpu_memory() + + logger.info(f"正在加载模型到多GPU: {self.device_ids}") + + # 加载模型和处理器 + self.model = None + self.tokenizer = None + self.processor = None + self._load_model_multigpu() + + # 
初始化索引 + self.text_index = None + self.image_index = None + self.text_data = [] + self.image_data = [] + + logger.info("多GPU模型加载完成") + + def _setup_devices(self, use_all_gpus: bool, gpu_ids: List[int], min_memory_gb=12): + """设置GPU设备""" + if not torch.cuda.is_available(): + raise RuntimeError("CUDA不可用,无法使用多GPU") + + total_gpus = torch.cuda.device_count() + logger.info(f"检测到 {total_gpus} 个GPU") + + # 检查是否设置了CUDA_VISIBLE_DEVICES + cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES') + if cuda_visible_devices is not None: + # 如果设置了CUDA_VISIBLE_DEVICES,使用可见的GPU + visible_gpu_count = len(cuda_visible_devices.split(',')) + self.device_ids = list(range(visible_gpu_count)) + logger.info(f"使用CUDA_VISIBLE_DEVICES指定的GPU: {cuda_visible_devices}") + elif use_all_gpus: + self.device_ids = self._select_best_gpus(min_memory_gb) + elif gpu_ids: + self.device_ids = gpu_ids + else: + self.device_ids = [0] + + self.num_gpus = len(self.device_ids) + self.primary_device = f"cuda:{self.device_ids[0]}" + + logger.info(f"使用GPU: {self.device_ids}, 主设备: {self.primary_device}") + + def _clear_all_gpu_memory(self): + """清理所有GPU内存""" + for gpu_id in self.device_ids: + torch.cuda.set_device(gpu_id) + torch.cuda.empty_cache() + torch.cuda.synchronize() + gc.collect() + logger.info("所有GPU内存已清理") + + def _load_model_multigpu(self): + """加载模型到多GPU""" + try: + # 设置环境变量优化内存使用 + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' + + # 清理GPU内存 + self._clear_gpu_memory() + + # 首先尝试使用accelerate的自动设备映射 + if self.num_gpus > 1: + # 设置最大内存限制(每个GPU 18GB,留出缓冲) + max_memory = {i: "18GiB" for i in self.device_ids} + + logger.info(f"正在加载模型到多GPU: {self.device_ids}") + self.model = AutoModel.from_pretrained( + self.model_name, + trust_remote_code=True, + torch_dtype=torch.float16, + device_map="auto", + max_memory=max_memory, + low_cpu_mem_usage=True, + offload_folder="./offload" + ) + else: + # 单GPU加载 + self.model = AutoModel.from_pretrained( + self.model_name, + trust_remote_code=True, + 
torch_dtype=torch.float16, + device_map=self.primary_device + ) + + # 加载分词器和处理器到主设备 + try: + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name, + trust_remote_code=True + ) + logger.info("Tokenizer加载成功") + except Exception as e: + logger.error(f"Tokenizer加载失败: {e}") + return False + + # 加载处理器用于图像处理 + try: + self.processor = AutoProcessor.from_pretrained( + self.model_name, + trust_remote_code=True + ) + logger.info("Processor加载成功") + except Exception as e: + logger.warning(f"Processor加载失败: {e}") + # 如果AutoProcessor失败,尝试使用tokenizer作为fallback + logger.info("尝试使用tokenizer作为processor的fallback") + self.processor = self.tokenizer + + logger.info(f"模型已成功加载到设备: {self.model.hf_device_map if hasattr(self.model, 'hf_device_map') else self.primary_device}") + logger.info("多GPU模型加载完成") + return True + + except Exception as e: + logger.error(f"多GPU模型加载失败: {str(e)}") + return False + + def _clear_gpu_memory(self): + """清理GPU内存""" + for gpu_id in self.device_ids: + torch.cuda.set_device(gpu_id) + torch.cuda.empty_cache() + torch.cuda.synchronize() + gc.collect() + logger.info("GPU内存已清理") + + def _get_gpu_memory_info(self): + """获取GPU内存使用情况""" + try: + import subprocess + result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.total', '--format=csv,nounits,noheader'], + capture_output=True, text=True, check=True) + lines = result.stdout.strip().split('\n') + gpu_info = [] + for i, line in enumerate(lines): + used, total = map(int, line.split(', ')) + free = total - used + gpu_info.append({ + 'gpu_id': i, + 'used': used, + 'total': total, + 'free': free, + 'usage_percent': (used / total) * 100 + }) + return gpu_info + except Exception as e: + logger.warning(f"无法获取GPU内存信息: {e}") + return [] + + def _select_best_gpus(self, min_memory_gb=12): + """选择内存充足的GPU""" + gpu_info = self._get_gpu_memory_info() + if not gpu_info: + return list(range(torch.cuda.device_count())) + + # 按可用内存排序 + gpu_info.sort(key=lambda x: x['free'], reverse=True) + + # 选择内存充足的GPU + 
min_memory_mb = min_memory_gb * 1024 + suitable_gpus = [] + + for gpu in gpu_info: + if gpu['free'] >= min_memory_mb: + suitable_gpus.append(gpu['gpu_id']) + logger.info(f"GPU {gpu['gpu_id']}: {gpu['free']}MB 可用 (合适)") + else: + logger.warning(f"GPU {gpu['gpu_id']}: {gpu['free']}MB 可用 (不足)") + + if not suitable_gpus: + # 如果没有GPU满足要求,选择可用内存最多的 + logger.warning(f"没有GPU有足够内存({min_memory_gb}GB),选择可用内存最多的GPU") + suitable_gpus = [gpu_info[0]['gpu_id']] + + return suitable_gpus + + def encode_text_batch(self, texts: List[str]) -> np.ndarray: + """ + 批量编码文本为向量(多GPU优化) + + Args: + texts: 文本列表 + + Returns: + 文本向量 + """ + if not texts: + return np.array([]) + + with torch.no_grad(): + # 预处理输入 + inputs = self.tokenizer( + text=texts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512 + ) + + # 将输入移动到主设备 + inputs = {k: v.to(self.primary_device) for k, v in inputs.items()} + + # 前向传播 + outputs = self.model(**inputs) + embeddings = outputs.last_hidden_state.mean(dim=1) + + # 清理GPU内存 + del inputs, outputs + torch.cuda.empty_cache() + + return embeddings.cpu().numpy().astype(np.float32) + + def encode_image_batch(self, images: List[Union[str, Image.Image]]) -> np.ndarray: + """ + 批量编码图像为向量 + + Args: + images: 图像路径或PIL图像列表 + + Returns: + 图像向量 + """ + if not images: + return np.array([]) + + # 预处理图像 + processed_images = [] + for img in images: + if isinstance(img, str): + img = Image.open(img).convert('RGB') + elif isinstance(img, Image.Image): + img = img.convert('RGB') + processed_images.append(img) + + try: + logger.info(f"处理 {len(processed_images)} 张图像") + + # 使用多模态模型生成图像embedding + # 为每张图像创建简单的文本描述作为输入 + conversations = [] + for i in range(len(processed_images)): + # 使用简化的对话格式 + conversation = [ + { + "role": "user", + "content": [ + {"type": "image", "image": processed_images[i]}, + {"type": "text", "text": "What is in this image?"} + ] + } + ] + conversations.append(conversation) + + # 使用processor处理 + try: + # 尝试使用apply_chat_template方法 + texts = [] + 
for conv in conversations: + text = self.processor.apply_chat_template(conv, tokenize=False, add_generation_prompt=False) + texts.append(text) + + # 处理文本和图像 + inputs = self.processor( + text=texts, + images=processed_images, + return_tensors="pt", + padding=True + ) + + # 移动到GPU + inputs = {k: v.to(self.primary_device) for k, v in inputs.items()} + + # 获取模型输出 + with torch.no_grad(): + outputs = self.model(**inputs) + embeddings = outputs.last_hidden_state.mean(dim=1) + + # 转换为numpy数组 + embeddings = embeddings.cpu().numpy().astype(np.float32) + + except Exception as inner_e: + logger.warning(f"多模态模型图像编码失败,使用文本模式: {inner_e}") + return np.zeros((len(processed_images), 3584), dtype=np.float32) + # 如果多模态失败,使用纯文本描述作为fallback + image_descriptions = ["An image" for _ in processed_images] + text_inputs = self.processor( + text=image_descriptions, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512 + ) + text_inputs = {k: v.to(self.primary_device) for k, v in text_inputs.items()} + + with torch.no_grad(): + outputs = self.model(**text_inputs) + embeddings = outputs.last_hidden_state.mean(dim=1) + + embeddings = embeddings.cpu().numpy().astype(np.float32) + + logger.info(f"生成图像embeddings: {embeddings.shape}") + return embeddings + + except Exception as e: + logger.error(f"图像编码失败: {e}") + # 返回与文本embedding维度一致的零向量作为fallback + embedding_dim = 3584 + embeddings = np.zeros((len(processed_images), embedding_dim), dtype=np.float32) + return embeddings + + def build_text_index_parallel(self, texts: List[str], save_path: str = None): + """ + 并行构建文本索引(多GPU优化) + + Args: + texts: 文本列表 + save_path: 索引保存路径 + """ + logger.info(f"正在并行构建文本索引,共 {len(texts)} 条文本") + + # 根据GPU数量调整批次大小 + batch_size = max(4, 16 // self.num_gpus) + all_embeddings = [] + + # 分批处理 + for i in range(0, len(texts), batch_size): + batch_texts = texts[i:i+batch_size] + + try: + embeddings = self.encode_text_batch(batch_texts) + all_embeddings.append(embeddings) + + # 显示进度 + if (i // batch_size + 1) 
% 10 == 0: + logger.info(f"已处理 {i + len(batch_texts)}/{len(texts)} 条文本") + + except torch.cuda.OutOfMemoryError: + logger.warning(f"GPU内存不足,跳过批次 {i}-{i+len(batch_texts)}") + self._clear_all_gpu_memory() + continue + except Exception as e: + logger.error(f"处理文本批次时出错: {e}") + continue + + if not all_embeddings: + raise ValueError("没有成功处理任何文本") + + # 合并所有嵌入向量 + embeddings = np.vstack(all_embeddings) + + # 构建FAISS索引 + dimension = embeddings.shape[1] + self.text_index = faiss.IndexFlatIP(dimension) + + # 归一化向量 + faiss.normalize_L2(embeddings) + self.text_index.add(embeddings) + + self.text_data = texts + + if save_path: + self._save_index(self.text_index, texts, save_path + "_text") + + logger.info("文本索引构建完成") + + def build_image_index_parallel(self, image_paths: List[str], save_path: str = None): + """ + 并行构建图像索引(多GPU优化) + + Args: + image_paths: 图像路径列表 + save_path: 索引保存路径 + """ + logger.info(f"正在并行构建图像索引,共 {len(image_paths)} 张图像") + + # 图像处理使用更小的批次 + batch_size = max(2, 8 // self.num_gpus) + all_embeddings = [] + + for i in range(0, len(image_paths), batch_size): + batch_images = image_paths[i:i+batch_size] + + try: + embeddings = self.encode_image_batch(batch_images) + all_embeddings.append(embeddings) + + # 显示进度 + if (i // batch_size + 1) % 5 == 0: + logger.info(f"已处理 {i + len(batch_images)}/{len(image_paths)} 张图像") + + except torch.cuda.OutOfMemoryError: + logger.warning(f"GPU内存不足,跳过图像批次 {i}-{i+len(batch_images)}") + self._clear_all_gpu_memory() + continue + except Exception as e: + logger.error(f"处理图像批次时出错: {e}") + continue + + if not all_embeddings: + raise ValueError("没有成功处理任何图像") + + embeddings = np.vstack(all_embeddings) + + # 构建FAISS索引 + dimension = embeddings.shape[1] + self.image_index = faiss.IndexFlatIP(dimension) + + faiss.normalize_L2(embeddings) + self.image_index.add(embeddings) + + self.image_data = image_paths + + if save_path: + self._save_index(self.image_index, image_paths, save_path + "_image") + + logger.info("图像索引构建完成") + + def 
search_text_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: + """文搜文:使用文本查询搜索相似文本""" + if self.text_index is None: + raise ValueError("文本索引未构建,请先调用 build_text_index_parallel") + + query_embedding = self.encode_text_batch([query]).astype(np.float32) + faiss.normalize_L2(query_embedding) + + scores, indices = self.text_index.search(query_embedding, top_k) + + results = [] + for score, idx in zip(scores[0], indices[0]): + if idx != -1: + results.append((self.text_data[idx], float(score))) + + return results + + def search_images_by_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: + """文搜图:使用文本查询搜索相似图像""" + if self.image_index is None: + raise ValueError("图像索引未构建,请先调用 build_image_index_parallel") + + query_embedding = self.encode_text_batch([query]).astype(np.float32) + faiss.normalize_L2(query_embedding) + + scores, indices = self.image_index.search(query_embedding, top_k) + + results = [] + for score, idx in zip(scores[0], indices[0]): + if idx != -1: + results.append((self.image_data[idx], float(score))) + + return results + + def search_images_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]: + """图搜图:使用图像查询搜索相似图像""" + if self.image_index is None: + raise ValueError("图像索引未构建,请先调用 build_image_index_parallel") + + query_embedding = self.encode_image_batch([query_image]).astype(np.float32) + faiss.normalize_L2(query_embedding) + + scores, indices = self.image_index.search(query_embedding, top_k) + + results = [] + for score, idx in zip(scores[0], indices[0]): + if idx != -1: + results.append((self.image_data[idx], float(score))) + + return results + + def search_text_by_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]: + """图搜文:使用图像查询搜索相似文本""" + if self.text_index is None: + raise ValueError("文本索引未构建,请先调用 build_text_index_parallel") + + query_embedding = self.encode_image_batch([query_image]).astype(np.float32) + 
faiss.normalize_L2(query_embedding) + + scores, indices = self.text_index.search(query_embedding, top_k) + + results = [] + for score, idx in zip(scores[0], indices[0]): + if idx != -1: + results.append((self.text_data[idx], float(score))) + + return results + + # Web应用兼容的方法名称 + def search_text_to_image(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: + """文搜图:Web应用兼容方法""" + return self.search_images_by_text(query, top_k) + + def search_image_to_image(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]: + """图搜图:Web应用兼容方法""" + return self.search_images_by_image(query_image, top_k) + + def search_text_to_text(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]: + """文搜文:Web应用兼容方法""" + return self.search_text_by_text(query, top_k) + + def search_image_to_text(self, query_image: Union[str, Image.Image], top_k: int = 5) -> List[Tuple[str, float]]: + """图搜文:Web应用兼容方法""" + return self.search_text_by_image(query_image, top_k) + + def _save_index(self, index, data, path_prefix): + """保存索引和数据""" + faiss.write_index(index, f"{path_prefix}.index") + with open(f"{path_prefix}.json", 'w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + def load_index(self, path_prefix, index_type="text"): + """加载已保存的索引""" + index = faiss.read_index(f"{path_prefix}.index") + with open(f"{path_prefix}.json", 'r', encoding='utf-8') as f: + data = json.load(f) + + if index_type == "text": + self.text_index = index + self.text_data = data + else: + self.image_index = index + self.image_data = data + + logger.info(f"已加载 {index_type} 索引") + + def get_gpu_memory_info(self): + """获取所有GPU内存使用信息""" + memory_info = {} + for gpu_id in self.device_ids: + torch.cuda.set_device(gpu_id) + allocated = torch.cuda.memory_allocated(gpu_id) / 1024**3 + cached = torch.cuda.memory_reserved(gpu_id) / 1024**3 + total = torch.cuda.get_device_properties(gpu_id).total_memory / 1024**3 + free = total - cached + + 
memory_info[f"GPU_{gpu_id}"] = { + "total": f"{total:.1f}GB", + "allocated": f"{allocated:.1f}GB", + "cached": f"{cached:.1f}GB", + "free": f"{free:.1f}GB" + } + + return memory_info + +def check_multigpu_info(): + """检查多GPU环境信息""" + print("=== 多GPU环境信息 ===") + + if not torch.cuda.is_available(): + print("❌ CUDA不可用") + return + + gpu_count = torch.cuda.device_count() + print(f"✅ 检测到 {gpu_count} 个GPU") + print(f"CUDA版本: {torch.version.cuda}") + print(f"PyTorch版本: {torch.__version__}") + + for i in range(gpu_count): + gpu_name = torch.cuda.get_device_name(i) + gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3 + print(f"GPU {i}: {gpu_name} ({gpu_memory:.1f}GB)") + + print("=====================") + +if __name__ == "__main__": + # 检查多GPU环境 + check_multigpu_info() + + # 示例使用 + print("\n正在初始化多GPU多模态检索系统...") + + try: + retrieval_system = MultiGPUMultimodalRetrieval() + print("✅ 多GPU系统初始化成功!") + + # 显示GPU内存使用情况 + memory_info = retrieval_system.get_gpu_memory_info() + print("\n📊 GPU内存使用情况:") + for gpu, info in memory_info.items(): + print(f" {gpu}: {info['allocated']} / {info['total']} (已用/总计)") + + print("\n🚀 多GPU多模态检索系统就绪!") + print("支持的检索模式:") + print("1. 文搜文: search_text_by_text()") + print("2. 文搜图: search_images_by_text()") + print("3. 图搜图: search_images_by_image()") + print("4. 
# ===== File: ops_mm_embedding_v1.py =====
import math
from typing import List, Optional, TypeAlias, Union

import torch
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForImageTextToText, AutoProcessor

# One input's image payload: a single image or a list of images.
ImageInput: TypeAlias = Union[Image.Image, List[Image.Image]]
# A batch of image payloads: one image (or several images) per batch item.
BatchImageInput: TypeAlias = Union[List[Image.Image], List[List[Image.Image]]]


class OpsMMEmbeddingV1(nn.Module):
    """Multimodal embedding model.

    Wraps an image-text LLM and maps text, images, or text+image pairs to
    L2-normalized vectors via last-token pooling of the final hidden layer.
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        max_length: Optional[int] = None,
        attn_implementation: Optional[str] = None,
    ):
        """Load the backbone model and its processor.

        Args:
            model_name: HF hub id or local path of the image-text model.
            device: torch device the model and inputs are moved to.
            max_length: optional truncation length for the processor.
            attn_implementation: attention backend name, passed through to HF.
        """
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.default_instruction = "You are a helpful assistant."
        self.base_model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            attn_implementation=attn_implementation,
        ).to(self.device)

        self.processor = AutoProcessor.from_pretrained(
            model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
        )
        # Left padding so the last position of every sequence holds a real
        # token — required by the last-token pooling in _pooling().
        self.processor.tokenizer.padding_side = "left"
        self.eval()

    def encode_input(self, input):
        """Run the backbone and pool the last hidden layer into embeddings."""
        hidden_states = self.base_model(**input, return_dict=True, output_hidden_states=True)
        hidden_states = hidden_states.hidden_states[-1]
        pooled_output = self._pooling(hidden_states)
        return pooled_output

    def _pooling(self, last_hidden_state):
        """Last-token pooling followed by L2 normalization."""
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[torch.arange(batch_size), -1, :]
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    def _validate_instructions(
        self,
        texts: Optional[List[str]],
        images: Optional[BatchImageInput],
        instruction: Optional[Union[str, List[str]]],
    ) -> List[str]:
        """Validate and broadcast instructions to one entry per batch item."""
        batch_size = max(len(x) if x is not None else 0 for x in [texts, images])

        if instruction is None:
            return [self.default_instruction] * batch_size

        if isinstance(instruction, str):
            return [instruction] * batch_size

        if isinstance(instruction, list):
            if len(instruction) != batch_size:
                raise ValueError(
                    f"Length of instruction list ({len(instruction)}) must match "
                    f"batch size ({batch_size}) when texts/images are provided"
                )
            return instruction

        raise TypeError("instruction must be str, List[str] or None")

    def _process_images(self, images: ImageInput) -> List[Image.Image]:
        """Normalize a single image/path or a list thereof to a list of PIL images."""
        if isinstance(images, (Image.Image, str)):
            return [fetch_image(images)]
        return [fetch_image(i) for i in images]

    def embed(
        self,
        texts: Optional[List[str]] = None,
        images: Optional[BatchImageInput] = None,
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Generate embeddings for text, images, or combined inputs.

        Args:
            texts: List of text inputs (optional).
            images: List[Image.Image] (one image per input) or
                List[List[Image.Image]] (several images per input).
            instruction: None (default instruction), a single str applied to
                all inputs, or a per-input List[str] matching the batch size.

        Returns:
            Tensor of shape (batch, dim) with L2-normalized embeddings.

        Raises:
            ValueError: if neither texts nor images are given, if both are
                given with mismatched lengths, or if only some batch items
                carry images.
        """
        if texts is None and images is None:
            raise ValueError("Either texts or images must be provided")
        # Fail fast on ragged input; the original code would IndexError later.
        if texts is not None and images is not None and len(texts) != len(images):
            raise ValueError(
                f"texts ({len(texts)}) and images ({len(images)}) must have the same length"
            )

        instructions = self._validate_instructions(texts, images, instruction)

        batch_size = len(texts) if texts is not None else len(images)  # type: ignore

        input_texts, input_images = [], []
        for i in range(batch_size):
            text = texts[i] if texts is not None else None
            image = images[i] if images is not None else None

            input_str = ""
            processed_image = None
            if image is not None:
                processed_image = self._process_images(image)
                # One vision placeholder per image in this input.
                input_str += "<|vision_start|><|image_pad|><|vision_end|>" * len(processed_image)

            if text is not None:
                input_str += text

            msg = (
                f"<|im_start|>system\n{instructions[i]}<|im_end|>\n"
                f"<|im_start|>user\n{input_str}<|im_end|>\n"
                f"<|im_start|>assistant\n<|endoftext|>"
            )

            input_texts.append(msg)
            input_images.append(processed_image)

        has_image = [img is not None for img in input_images]
        if any(has_image) and not all(has_image):
            # A half-image batch would hand the processor a list containing
            # None entries and fail deep inside preprocessing; fail fast with
            # a clear message instead.
            raise ValueError("All items in a batch must either have images or have none")
        # Only pass images to the processor when the batch actually has them.
        processed_images = input_images if any(has_image) else None

        inputs = self.processor(
            text=input_texts,
            images=processed_images,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.inference_mode():
            embeddings = self.encode_input(inputs)

        return embeddings

    def get_text_embeddings(
        self,
        texts: List[str],
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convenience method for text-only embeddings."""
        return self.get_fused_embeddings(texts=texts, instruction=instruction, **kwargs)

    def get_image_embeddings(
        self,
        images: BatchImageInput,
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convenience method for image-only embeddings.

        Args:
            images: List[Image.Image] (one image per input) or
                List[List[Image.Image]] (several images per input).
        """
        return self.get_fused_embeddings(images=images, instruction=instruction, **kwargs)

    def get_fused_embeddings(
        self,
        texts: Optional[List[str]] = None,
        images: Optional[BatchImageInput] = None,
        instruction: Optional[Union[str, List[str]]] = None,
        batch_size: int = 8,
        show_progress: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Batched embedding for large collections of texts/images.

        Args:
            texts: List of text inputs (optional).
            images: List of image payloads (optional; see embed()).
            instruction: instruction(s) for the model (see embed()).
            batch_size: number of items processed per forward pass.
            show_progress: whether to display a tqdm progress bar.
        """
        if texts is None and images is None:
            raise ValueError("Either texts or images must be provided")

        total_items = len(texts) if texts is not None else len(images)  # type: ignore
        num_batches = math.ceil(total_items / batch_size)

        all_embeddings = []
        progress = tqdm(total=num_batches, disable=not show_progress, desc="Processing")

        for i in range(0, total_items, batch_size):
            batch_texts = texts[i : i + batch_size] if texts is not None else None
            batch_images = images[i : i + batch_size] if images is not None else None
            # BUGFIX: a per-input instruction list must be sliced in lockstep
            # with the batch; passing the full list tripped the length check
            # in _validate_instructions for every batch after the first.
            batch_instruction = (
                instruction[i : i + batch_size] if isinstance(instruction, list) else instruction
            )
            batch_emb = self.embed(
                texts=batch_texts, images=batch_images, instruction=batch_instruction
            )

            all_embeddings.append(batch_emb.cpu())
            progress.update(1)

        progress.close()
        return torch.cat(all_embeddings, dim=0).to(self.device)

    def forward(self, **inputs) -> torch.Tensor:
        """Alias for encode_input()."""
        return self.encode_input(inputs)


### Vision preprocessing utilities, adapted from qwen_vl_utils.vision_process.
# (duplicate `import math` from the original removed — math is imported above)
import base64
import logging
from io import BytesIO

IMAGE_FACTOR = 28
MIN_PIXELS = 256 * 28 * 28
MAX_PIXELS = 1280 * 28 * 28
+MAX_RATIO = 200 + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int | float, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int | float, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, + width: int, + factor: int = IMAGE_FACTOR, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. 
+ """ + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + + if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO: + logging.warning(f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}") + if h_bar > w_bar: + h_bar = w_bar * MAX_RATIO + else: + w_bar = h_bar * MAX_RATIO + return h_bar, w_bar + + +def fetch_image( + image: str | Image.Image, + size_factor: int = IMAGE_FACTOR, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, +) -> Image.Image: + image_obj = None + if isinstance(image, Image.Image): + image_obj = image + elif image.startswith("http://") or image.startswith("https://"): + image_obj = Image.open(requests.get(image, stream=True).raw) # type: ignore + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + image_obj = Image.open(BytesIO(data)) + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") + image = image_obj.convert("RGB") + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +### diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ab6d48d --- /dev/null +++ 
b/requirements.txt @@ -0,0 +1,12 @@ +torch>=2.0.0 +torchvision>=0.15.0 +transformers>=4.30.0 +accelerate>=0.20.0 +faiss-cpu>=1.7.4 +numpy>=1.21.0 +Pillow>=9.0.0 +scikit-learn>=1.3.0 +tqdm>=4.65.0 +flask>=2.3.0 +werkzeug>=2.3.0 +psutil>=5.9.0 diff --git a/sample_images/1755681385_4__.jpg b/sample_images/1755681385_4__.jpg new file mode 100644 index 0000000..1ae598e Binary files /dev/null and b/sample_images/1755681385_4__.jpg differ diff --git a/sample_images/1755681385_5__.jpg b/sample_images/1755681385_5__.jpg new file mode 100644 index 0000000..da8cddd Binary files /dev/null and b/sample_images/1755681385_5__.jpg differ diff --git a/sample_images/1755681385_6__.jpg b/sample_images/1755681385_6__.jpg new file mode 100644 index 0000000..721c123 Binary files /dev/null and b/sample_images/1755681385_6__.jpg differ diff --git a/sample_images/1755681385_7__.jpg b/sample_images/1755681385_7__.jpg new file mode 100644 index 0000000..7152063 Binary files /dev/null and b/sample_images/1755681385_7__.jpg differ diff --git a/templates/index.html b/templates/index.html new file mode 100644 index 0000000..41a2ea2 --- /dev/null +++ b/templates/index.html @@ -0,0 +1,971 @@ + + + + + + 多模态检索系统 + + + + + + +
+
+ 未初始化 +
+
+ +
+
+ +
+

多模态检索系统

+

支持文搜图、文搜文、图搜图、图搜文四种检索模式

+
+ +
+ +
+ +
+ + +
+
+
+ +
文搜文
+

文本查找相似文本

+
+
+
+
+ +
文搜图
+

文本查找相关图片

+
+
+
+
+ +
图搜文
+

图片查找相关文本

+
+
+
+
+ +
图搜图
+

图片查找相似图片

+
+
+
+ + +
+
+
+
+
数据管理
+ 上传和管理检索数据库 +
+
+
+ +
+
+
批量上传图片
+
+ +

拖拽多张图片到此处或点击选择

+ + +
+ +
+
+ + +
+
+
批量上传文本
+
+ +
+
+ + + +
+
+
+
+ + +
+
+
+ + + +
+
+
+
+ + 图片: 0 张 | + 文本: 0 条 + +
+
+
+
+
+
+
+ + + + + +
+
+ Loading... +
+

正在搜索中...

+
+ + +
+
+
+
# ===== File: test_all_retrieval_modes.py =====
#!/usr/bin/env python3
"""Exercise all four multimodal retrieval modes (text->text, text->image,
image->text, image->image) against the multi-GPU retrieval system."""

import glob
import os

import numpy as np
from PIL import Image

from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval


def test_all_retrieval_modes():
    """Build both indices, then run every retrieval direction once."""
    print('正在初始化多GPU多模态检索系统...')
    retrieval = MultiGPUMultimodalRetrieval()

    test_texts = [
        "一只可爱的小猫",
        "美丽的风景照片",
        "现代建筑设计",
        "colorful flowers in garden",
    ]

    test_images = [
        'sample_images/1755677101_1__.jpg',
        'sample_images/1755677101_2__.jpg',
        'sample_images/1755677101_3__.jpg',
        'sample_images/1755677101_4__.jpg',
    ]

    # Prefer the hard-coded fixtures, but fall back to whatever images are
    # actually present in sample_images/ so the test runs on any checkout.
    existing_images = [img for img in test_images if os.path.exists(img)]
    if not existing_images:
        existing_images = sorted(glob.glob('sample_images/*.jpg'))
    if not existing_images:
        print("❌ 没有找到测试图像文件")
        return

    print(f"找到 {len(existing_images)} 张测试图像")

    try:
        # 1. Build the text index.
        print('\n=== 构建文本索引 ===')
        retrieval.build_text_index_parallel(test_texts)
        print('✅ 文本索引构建完成')

        # 2. Build the image index.
        print('\n=== 构建图像索引 ===')
        retrieval.build_image_index_parallel(existing_images)
        print('✅ 图像索引构建完成')

        # 3. Text -> text retrieval.
        print('\n=== 测试文本到文本检索 ===')
        query = "小动物"
        results = retrieval.search_text_by_text(query, top_k=3)
        print(f'查询: "{query}"')
        for i, (text, score) in enumerate(results):
            print(f'  {i+1}. {text} (相似度: {score:.4f})')

        # 4. Text -> image retrieval.
        print('\n=== 测试文本到图像检索 ===')
        query = "beautiful image"
        results = retrieval.search_images_by_text(query, top_k=3)
        print(f'查询: "{query}"')
        for i, (image_path, score) in enumerate(results):
            print(f'  {i+1}. {image_path} (相似度: {score:.4f})')

        # 5. Image -> text retrieval.
        print('\n=== 测试图像到文本检索 ===')
        query_image = existing_images[0]
        results = retrieval.search_text_by_image(query_image, top_k=3)
        print(f'查询图像: {query_image}')
        for i, (text, score) in enumerate(results):
            print(f'  {i+1}. {text} (相似度: {score:.4f})')

        # 6. Image -> image retrieval.
        print('\n=== 测试图像到图像检索 ===')
        query_image = existing_images[0]
        results = retrieval.search_images_by_image(query_image, top_k=3)
        print(f'查询图像: {query_image}')
        for i, (image_path, score) in enumerate(results):
            print(f'  {i+1}. {image_path} (相似度: {score:.4f})')

        print('\n✅ 所有四种检索模式测试完成!')

        # 7. The method names the web app expects must also exist.
        print('\n=== 测试Web应用兼容方法 ===')
        try:
            retrieval.search_text_to_image("test query", top_k=2)
            print('✅ search_text_to_image 方法正常')

            retrieval.search_image_to_text(existing_images[0], top_k=2)
            print('✅ search_image_to_text 方法正常')

            retrieval.search_image_to_image(existing_images[0], top_k=2)
            print('✅ search_image_to_image 方法正常')

        except Exception as e:
            print(f'❌ Web应用兼容方法测试失败: {e}')

    except Exception as e:
        print(f'❌ 测试过程中出现错误: {e}')
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    test_all_retrieval_modes()


# ===== File: test_image_encoding.py =====
#!/usr/bin/env python3
"""Sanity-check image/text encoding: shapes, dtypes, determinism, and that
text and image embeddings share one vector space."""


def test_image_encoding():
    """Encode sample text/images and verify embedding properties."""
    print('正在初始化多GPU多模态检索系统...')
    retrieval = MultiGPUMultimodalRetrieval()

    # Text encoding.
    print('测试文本编码...')
    text_embeddings = retrieval.encode_text_batch(['这是一个测试文本'])
    print(f'文本embedding形状: {text_embeddings.shape}')
    print(f'文本embedding数据类型: {text_embeddings.dtype}')

    # Image encoding; fall back to any available sample images so the
    # script is not tied to one hard-coded filename.
    print('测试图像编码...')
    available = sorted(glob.glob('sample_images/*.jpg'))
    test_images = ['sample_images/1755677101_1__.jpg']
    if not os.path.exists(test_images[0]) and available:
        test_images = available[:1]
    image_embeddings = retrieval.encode_image_batch(test_images)
    print(f'图像embedding形状: {image_embeddings.shape}')
    print(f'图像embedding数据类型: {image_embeddings.dtype}')

    # Encoding the same image twice must be deterministic.
    print('测试embedding一致性...')
    image_embeddings2 = retrieval.encode_image_batch(test_images)
    consistency = np.allclose(image_embeddings, image_embeddings2, rtol=1e-5)
    print(f'相同图像embedding一致性: {consistency}')

    # Different images should yield different (but comparable) embeddings.
    print('测试不同图像embedding差异...')
    test_images2 = ['sample_images/1755677101_2__.jpg']
    if not os.path.exists(test_images2[0]) and len(available) > 1:
        test_images2 = available[1:2]
    image_embeddings3 = retrieval.encode_image_batch(test_images2)
    similarity = np.dot(image_embeddings[0], image_embeddings3[0]) / (
        np.linalg.norm(image_embeddings[0]) * np.linalg.norm(image_embeddings3[0])
    )
    print(f'不同图像间相似度: {similarity:.4f}')

    # Text and image embeddings must live in the same space.
    if text_embeddings.shape[1] == image_embeddings.shape[1]:
        print('✅ 文本和图像embedding维度一致')
    else:
        print('❌ 文本和图像embedding维度不一致')

    print('测试完成!')


if __name__ == "__main__":
    test_image_encoding()


# ===== File: web_app_multigpu.py (header) =====
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Multi-GPU multimodal retrieval system — web application.

Optimized for dual-GPU deployment.
"""

import os
import json
import time
from flask import Flask, render_template, request, jsonify, send_file, url_for
from werkzeug.utils import secure_filename
from PIL import Image
import base64
import io
import logging
import traceback

# Reduce CUDA fragmentation for long-running processes.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
app.config['SECRET_KEY'] = 'multigpu_multimodal_retrieval_2024'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size

# Upload/data folder layout.
UPLOAD_FOLDER = 'uploads'
SAMPLE_IMAGES_FOLDER = 'sample_images'
TEXT_DATA_FOLDER = 'text_data'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}

# Ensure the data folders exist before any request touches them.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(SAMPLE_IMAGES_FOLDER, exist_ok=True)
os.makedirs(TEXT_DATA_FOLDER, exist_ok=True)

# Global retrieval system instance (set by /api/init or auto_initialize()).
retrieval_system = None


def allowed_file(filename):
    """Return True when the filename carries an allowed image extension."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def image_to_base64(image_path):
    """Read an image file and return its base64 text, or None on failure."""
    try:
        with open(image_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')
    except Exception as e:
        logger.error(f"图片转换失败: {e}")
        return None


@app.route('/')
def index():
    """Serve the single-page UI."""
    return render_template('index.html')


@app.route('/api/status')
def get_status():
    """Report initialization/GPU/model state for the UI status badge."""
    global retrieval_system

    status = {
        'initialized': retrieval_system is not None,
        'gpu_count': 0,
        'model_loaded': False
    }

    try:
        import torch
        if torch.cuda.is_available():
            status['gpu_count'] = torch.cuda.device_count()

        if retrieval_system and retrieval_system.model:
            status['model_loaded'] = True
            status['device_ids'] = retrieval_system.device_ids

    except Exception as e:
        logger.error(f"获取状态失败: {e}")

    return jsonify(status)


@app.route('/api/init', methods=['POST'])
def initialize_system():
    """Initialize the multi-GPU retrieval system on demand."""
    global retrieval_system

    try:
        logger.info("正在初始化多GPU检索系统...")

        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        retrieval_system = MultiGPUMultimodalRetrieval()

        if retrieval_system.model is None:
            raise Exception("模型加载失败")

        logger.info("✅ 多GPU系统初始化成功")

        return jsonify({
            'success': True,
            'message': '多GPU系统初始化成功',
            'device_ids': retrieval_system.device_ids,
            'gpu_count': len(retrieval_system.device_ids)
        })

    except Exception as e:
        error_msg = f"系统初始化失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())

        return jsonify({
            'success': False,
            'message': error_msg
        }), 500


@app.route('/api/search/text_to_text', methods=['POST'])
def search_text_to_text():
    """Text -> text search."""
    return handle_search('text_to_text')


@app.route('/api/search/text_to_image', methods=['POST'])
def search_text_to_image():
    """Text -> image search."""
    return handle_search('text_to_image')


@app.route('/api/search/image_to_text', methods=['POST'])
def search_image_to_text():
    """Image -> text search."""
    return handle_search('image_to_text')


@app.route('/api/search/image_to_image', methods=['POST'])
def search_image_to_image():
    """Image -> image search."""
    return handle_search('image_to_image')


@app.route('/api/search', methods=['POST'])
def search():
    """Generic search endpoint (backwards compatible with older clients)."""
    # BUGFIX: request.json raises 415 when the request has no JSON body
    # (e.g. multipart image uploads); get_json(silent=True) returns None.
    payload = request.get_json(silent=True) or {}
    mode = request.form.get('mode') or payload.get('mode', 'text_to_text')
    return handle_search(mode)


def handle_search(mode):
    """Dispatch a search request in the given mode to the retrieval system."""
    global retrieval_system

    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化,请先点击初始化按钮'
        }), 400

    try:
        top_k = int(request.form.get('top_k', 5))

        if mode in ['text_to_text', 'text_to_image']:
            # Text query — may arrive as form data or a JSON body.
            payload = request.get_json(silent=True) or {}
            query = request.form.get('query') or payload.get('query', '')
            if not query.strip():
                return jsonify({
                    'success': False,
                    'message': '请输入查询文本'
                }), 400

            logger.info(f"执行{mode}搜索: {query}")

            if mode == 'text_to_text':
                results = retrieval_system.search_text_to_text(query, top_k=top_k)
            else:  # text_to_image
                results = retrieval_system.search_text_to_image(query, top_k=top_k)

            return jsonify({
                'success': True,
                'mode': mode,
                'query': query,
                'results': results,
                'result_count': len(results)
            })

        elif mode in ['image_to_text', 'image_to_image']:
            # Image query — must be a multipart upload.
            if 'image' not in request.files:
                return jsonify({
                    'success': False,
                    'message': '请上传查询图片'
                }), 400

            file = request.files['image']
            if file.filename == '' or not allowed_file(file.filename):
                return jsonify({
                    'success': False,
                    'message': '请上传有效的图片文件'
                }), 400

            # Persist the query image under a timestamped, sanitized name.
            filename = secure_filename(file.filename)
            timestamp = str(int(time.time()))
            filename = f"query_{timestamp}_{filename}"
            filepath = os.path.join(UPLOAD_FOLDER, filename)
            file.save(filepath)

            logger.info(f"执行{mode}搜索,图片: {filename}")

            if mode == 'image_to_text':
                results = retrieval_system.search_image_to_text(filepath, top_k=top_k)
            else:  # image_to_image
                results = retrieval_system.search_image_to_image(filepath, top_k=top_k)

            # Echo the query image back to the UI as base64.
            query_image_b64 = image_to_base64(filepath)

            return jsonify({
                'success': True,
                'mode': mode,
                'query_image': query_image_b64,
                'results': results,
                'result_count': len(results)
            })

        else:
            return jsonify({
                'success': False,
                'message': f'不支持的搜索模式: {mode}'
            }), 400

    except Exception as e:
        error_msg = f"搜索失败: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())

        return jsonify({
            'success': False,
            'message': error_msg
        }), 500


@app.route('/api/upload/images', methods=['POST'])
def upload_images():
    """Bulk-upload image files into the sample-images corpus."""
    try:
        uploaded_files = []

        if 'images' not in request.files:
            return jsonify({'success': False, 'message': '没有选择文件'}), 400

        files = request.files.getlist('images')

        for file in files:
            if file and file.filename != '' and allowed_file(file.filename):
                filename = secure_filename(file.filename)
                timestamp = str(int(time.time()))
                filename = f"{timestamp}_{filename}"
                filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
                file.save(filepath)
                uploaded_files.append(filename)

        return jsonify({
            'success': True,
            'message': f'成功上传 {len(uploaded_files)} 个图片文件',
            'uploaded_count': len(uploaded_files),
            'files': uploaded_files
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500


@app.route('/api/upload/texts', methods=['POST'])
def upload_texts():
    """Bulk-upload text entries; stored as one timestamped JSON file."""
    try:
        data = request.get_json()

        if not data or 'texts' not in data:
            return jsonify({'success': False, 'message': '没有提供文本数据'}), 400

        texts = data['texts']
        if not isinstance(texts, list):
            return jsonify({'success': False, 'message': '文本数据格式错误'}), 400

        timestamp = str(int(time.time()))
        filename = f"texts_{timestamp}.json"
        filepath = os.path.join(TEXT_DATA_FOLDER, filename)

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(texts, f, ensure_ascii=False, indent=2)

        return jsonify({
            'success': True,
            'message': f'成功上传 {len(texts)} 条文本',
            'uploaded_count': len(texts)
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'上传失败: {str(e)}'
        }), 500


@app.route('/api/upload/file', methods=['POST'])
def upload_single_file():
    """Upload a single image file into the sample-images corpus."""
    if 'file' not in request.files:
        return jsonify({'success': False, 'message': '没有选择文件'}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({'success': False, 'message': '没有选择文件'}), 400

    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        timestamp = str(int(time.time()))
        filename = f"{timestamp}_{filename}"
        filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
        file.save(filepath)

        return jsonify({
            'success': True,
            'message': '文件上传成功',
            'filename': filename
        })

    return jsonify({'success': False, 'message': '不支持的文件类型'}), 400


@app.route('/api/data/list', methods=['GET'])
def list_data():
    """List uploaded image and text files with size/mtime metadata."""
    try:
        images = []
        if os.path.exists(SAMPLE_IMAGES_FOLDER):
            for filename in os.listdir(SAMPLE_IMAGES_FOLDER):
                if allowed_file(filename):
                    filepath = os.path.join(SAMPLE_IMAGES_FOLDER, filename)
                    stat = os.stat(filepath)
                    images.append({
                        'filename': filename,
                        'size': stat.st_size,
                        'modified': stat.st_mtime
                    })

        texts = []
        if os.path.exists(TEXT_DATA_FOLDER):
            for filename in os.listdir(TEXT_DATA_FOLDER):
                if filename.endswith('.json'):
                    filepath = os.path.join(TEXT_DATA_FOLDER, filename)
                    stat = os.stat(filepath)
                    texts.append({
                        'filename': filename,
                        'size': stat.st_size,
                        'modified': stat.st_mtime
                    })

        return jsonify({
            'success': True,
            'data': {
                'images': images,
                'texts': texts
            }
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': f'获取数据列表失败: {str(e)}'
        }), 500


@app.route('/api/gpu_status')
def gpu_status():
    """Expose GPU memory info from the launcher helper."""
    try:
        from smart_gpu_launcher import get_gpu_memory_info
        gpu_info = get_gpu_memory_info()

        return jsonify({
            'success': True,
            'gpu_info': gpu_info
        })

    except Exception as e:
        return jsonify({
            'success': False,
            'message': f"获取GPU状态失败: {str(e)}"
        }), 500


@app.route('/api/build_index', methods=['POST'])
def build_index():
    """Build the retrieval indices from all uploaded images and texts."""
    global retrieval_system

    if not retrieval_system:
        return jsonify({
            'success': False,
            'message': '系统未初始化'
        }), 400

    try:
        image_files = []
        text_data = []

        # Gather image files of every allowed extension.
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            image_files.extend(glob.glob(pattern))

        # Gather text data from both .json and .txt files.
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt")))

        for text_file in text_files:
            try:
                if text_file.endswith('.json'):
                    with open(text_file, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    if isinstance(data, list):
                        text_data.extend([str(item).strip() for item in data if str(item).strip()])
                    else:
                        text_data.append(str(data).strip())
                else:
                    # One text entry per non-empty line.
                    with open(text_file, 'r', encoding='utf-8') as f:
                        lines = [line.strip() for line in f.readlines() if line.strip()]
                    text_data.extend(lines)
            except Exception as e:
                logger.warning(f"读取文本文件失败 {text_file}: {e}")

        if not image_files and not text_data:
            return jsonify({
                'success': False,
                'message': '没有找到可用的图片或文本数据,请先上传数据'
            }), 400

        if image_files:
            logger.info(f"构建图片索引,共 {len(image_files)} 张图片")
            retrieval_system.build_image_index_parallel(image_files)

        if text_data:
            logger.info(f"构建文本索引,共 {len(text_data)} 条文本")
            retrieval_system.build_text_index_parallel(text_data)

        return jsonify({
            'success': True,
            'message': f'索引构建完成!图片: {len(image_files)} 张,文本: {len(text_data)} 条',
            'image_count': len(image_files),
            'text_count': len(text_data)
        })

    except Exception as e:
        logger.error(f"构建索引失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'构建索引失败: {str(e)}'
        }), 500


@app.route('/api/data/stats', methods=['GET'])
def get_data_stats():
    """Count uploaded images and text entries for the UI counters."""
    try:
        image_count = 0
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            image_count += len(glob.glob(pattern))

        text_count = 0
        # Plain-text files: one entry per non-empty line.
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
        for text_file in text_files:
            try:
                with open(text_file, 'r', encoding='utf-8') as f:
                    lines = [line.strip() for line in f.readlines() if line.strip()]
                text_count += len(lines)
            except Exception:
                continue
        # BUGFIX: upload_texts saves *.json, but stats previously counted
        # only *.txt — uploaded texts never showed up in the counters.
        json_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json"))
        for json_file in json_files:
            try:
                with open(json_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                text_count += len(data) if isinstance(data, list) else 1
            except Exception:
                continue

        return jsonify({
            'success': True,
            'image_count': image_count,
            'text_count': text_count
        })

    except Exception as e:
        logger.error(f"获取数据统计失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'获取统计失败: {str(e)}'
        }), 500


@app.route('/api/data/clear', methods=['POST'])
def clear_data():
    """Delete all uploaded data and reset the in-memory indices."""
    try:
        # Remove image files.
        for ext in ALLOWED_EXTENSIONS:
            pattern = os.path.join(SAMPLE_IMAGES_FOLDER, f"*.{ext}")
            for file_path in glob.glob(pattern):
                try:
                    os.remove(file_path)
                except Exception as e:
                    logger.warning(f"删除图片文件失败 {file_path}: {e}")

        # Remove text files.
        # BUGFIX: also delete *.json — upload_texts stores texts as JSON,
        # so clearing only *.txt left uploaded texts behind.
        text_files = glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.txt"))
        text_files.extend(glob.glob(os.path.join(TEXT_DATA_FOLDER, "*.json")))
        for text_file in text_files:
            try:
                os.remove(text_file)
            except Exception as e:
                logger.warning(f"删除文本文件失败 {text_file}: {e}")

        # Reset the in-memory indices.
        global retrieval_system
        if retrieval_system:
            retrieval_system.text_index = None
            retrieval_system.image_index = None
            retrieval_system.text_data = []
            retrieval_system.image_data = []

        return jsonify({
            'success': True,
            'message': '数据已清空'
        })

    except Exception as e:
        logger.error(f"清空数据失败: {str(e)}")
        return jsonify({
            'success': False,
            'message': f'清空数据失败: {str(e)}'
        }), 500


@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve an uploaded sample image.

    SECURITY: the filename is sanitized to block path traversal
    (e.g. ../../etc/passwd) before it is joined onto the folder.
    """
    safe_name = secure_filename(filename)
    return send_file(os.path.join(SAMPLE_IMAGES_FOLDER, safe_name))


def print_startup_info():
    """Print the service banner and detected GPU configuration."""
    print("🚀 启动多GPU多模态检索Web应用")
    print("=" * 60)
    print("访问地址: http://localhost:5000")
    print("支持功能:")
    print("  📝 文搜文 - 文本查找相似文本")
    print("  🖼️ 文搜图 - 文本查找相关图片")
    print("  📝 图搜文 - 图片查找相关文本")
    print("  🖼️ 图搜图 - 图片查找相似图片")
    print("  📤 批量上传 - 图片和文本数据管理")
    print("GPU配置:")

    try:
        import torch
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"  🖥️ 检测到 {gpu_count} 个GPU")
            for i in range(gpu_count):
                name = torch.cuda.get_device_name(i)
                props = torch.cuda.get_device_properties(i)
                memory_gb = props.total_memory / 1024**3
                print(f"    GPU {i}: {name} ({memory_gb:.1f}GB)")
        else:
            print("  ❌ CUDA不可用")
    except Exception as e:
        print(f"  ❌ GPU检查失败: {e}")

    print("=" * 60)


def auto_initialize():
    """Initialize the retrieval system once at startup; return success flag."""
    global retrieval_system

    try:
        logger.info("🚀 启动时自动初始化多GPU检索系统...")

        from multimodal_retrieval_multigpu import MultiGPUMultimodalRetrieval

        retrieval_system = MultiGPUMultimodalRetrieval()

        if retrieval_system.model is None:
            raise Exception("模型加载失败")

        logger.info("✅ 系统自动初始化成功")
        return True

    except Exception as e:
        logger.error(f"❌ 系统自动初始化失败: {str(e)}")
        logger.error(traceback.format_exc())
        return False


if __name__ == '__main__':
    print_startup_info()

    # Initialize the model once at startup rather than on first request.
    auto_initialize()

    app.run(
        host='0.0.0.0',
        port=5000,
        debug=False,
        threaded=True
    )