From 73ce51c611cefa6d3e4ded35aaea8d93ca8876f5 Mon Sep 17 00:00:00 2001
From: eust-w
Date: Mon, 22 Sep 2025 19:47:14 +0800
Subject: [PATCH] :sparkles: add multi-card run

---
 multimodal_retrieval_local.py | 160 +++++++++++++++++++++++++++++-----
 ops_mm_embedding_v1.py        |  27 +++++-
 run_server.sh                 |  11 +++
 templates/local_index.html    | 137 ++++++++++++-----------------
 web_app_local.py              |  94 +++++++++++++++++++-
 5 files changed, 324 insertions(+), 105 deletions(-)
 create mode 100644 run_server.sh

diff --git a/multimodal_retrieval_local.py b/multimodal_retrieval_local.py
index f50d3f6..21c42b5 100644
--- a/multimodal_retrieval_local.py
+++ b/multimodal_retrieval_local.py
@@ -9,6 +9,7 @@ import torch
 import numpy as np
 from PIL import Image
 from ops_mm_embedding_v1 import OpsMMEmbeddingV1
+from concurrent.futures import ThreadPoolExecutor
 from typing import List, Union, Tuple, Dict, Any, Optional
 import os
 import json
@@ -33,7 +34,11 @@ class MultimodalRetrievalLocal:
                  use_all_gpus: bool = True,
                  gpu_ids: List[int] = None,
                  min_memory_gb: int = 12,
-                 index_path: str = "local_faiss_index"):
+                 index_path: str = "local_faiss_index",
+                 shard_model_across_gpus: bool = False,
+                 load_in_4bit: bool = False,
+                 load_in_8bit: bool = False,
+                 torch_dtype: Optional[torch.dtype] = torch.bfloat16):
         """
         Initialize the multimodal retrieval system.

@@ -46,6 +51,10 @@
         """
         self.model_path = model_path
         self.index_path = index_path
+        self.shard_model_across_gpus = shard_model_across_gpus
+        self.load_in_4bit = load_in_4bit
+        self.load_in_8bit = load_in_8bit
+        self.torch_dtype = torch_dtype
 
         # Check that the model path exists
         if not os.path.exists(model_path):
@@ -116,14 +125,53 @@ class MultimodalRetrievalLocal:
         """Load the OpsMMEmbeddingV1 multimodal embedding model"""
         logger.info(f"Loading local multimodal embedding model: {self.model_path}")
         try:
-            device_str = "cuda" if self.use_gpu else "cpu"
-            self.model = OpsMMEmbeddingV1(
-                self.model_path,
-                device=device_str,
-                attn_implementation=None,
-            )
+            self.models: List[OpsMMEmbeddingV1] = []
+            if self.use_gpu and len(self.gpu_ids) > 1 and self.shard_model_across_gpus:
+                # Shard one model copy across the visible GPUs via device_map="auto"
+                # (layer-wise placement handled by accelerate).
+                logger.info(f"Sharding the model across GPUs, devices: {self.gpu_ids}")
+                # Rely on CUDA_VISIBLE_DEVICES to constrain which GPUs are visible.
+                ref_model = OpsMMEmbeddingV1(
+                    self.model_path,
+                    device="cuda",
+                    attn_implementation=None,
+                    device_map="auto",
+                    load_in_4bit=self.load_in_4bit,
+                    load_in_8bit=self.load_in_8bit,
+                    torch_dtype=self.torch_dtype,
+                )
+                # A sharded model is kept as a single logical model reference
+                self.models.append(ref_model)
+            else:
+                if self.use_gpu and len(self.gpu_ids) > 1:
+                    logger.info(f"Multiple GPUs detected ({self.gpu_ids}); loading one model replica per card (data parallelism)")
+                    for gid in self.gpu_ids:
+                        device_str = f"cuda:{gid}"
+                        self.models.append(
+                            OpsMMEmbeddingV1(
+                                self.model_path,
+                                device=device_str,
+                                attn_implementation=None,
+                                load_in_4bit=self.load_in_4bit,
+                                load_in_8bit=self.load_in_8bit,
+                                torch_dtype=self.torch_dtype,
+                            )
+                        )
+                    ref_model = self.models[0]
+                else:
+                    device_str = "cuda" if self.use_gpu else "cpu"
+                    logger.info(f"Using single device: {device_str}")
+                    ref_model = OpsMMEmbeddingV1(
+                        self.model_path,
+                        device=device_str,
+                        attn_implementation=None,
+                        load_in_4bit=self.load_in_4bit,
+                        load_in_8bit=self.load_in_8bit,
+                        torch_dtype=self.torch_dtype,
+                    )
+                    self.models.append(ref_model)
+
             # Read the embedding dimension
-            self.vector_dim = int(getattr(self.model.base_model.config, "hidden_size"))
+            self.vector_dim = int(getattr(ref_model.base_model.config, "hidden_size"))
             logger.info(f"Embedding dimension: {self.vector_dim}")
             logger.info("Embedding model loaded successfully")
         except Exception as e:
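Note on the two code paths above: shard_model_across_gpus=True loads a single copy spread over every visible GPU (useful when the model does not fit on one card), while the default loads a full replica per GPU and splits batches across them. A minimal usage sketch under those semantics; the model path and flag combinations below are illustrative assumptions, not part of the patch:

    # Sketch: selecting the two multi-GPU modes (placeholder model path).
    from multimodal_retrieval_local import MultimodalRetrievalLocal

    # Data parallel: one full replica per GPU; batches are split across replicas.
    retrieval_dp = MultimodalRetrievalLocal(
        model_path="/models/ops-mm-embedding-v1",  # assumption: local checkpoint dir
        use_all_gpus=True,
        shard_model_across_gpus=False,
    )

    # Sharded: one copy spread over all visible GPUs via device_map="auto".
    retrieval_sharded = MultimodalRetrievalLocal(
        model_path="/models/ops-mm-embedding-v1",  # assumption: local checkpoint dir
        use_all_gpus=True,
        shard_model_across_gpus=True,
        load_in_8bit=True,  # optional: quantize to cut per-GPU memory
    )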
logger.info("嵌入模型加载成功") except Exception as e: @@ -158,33 +206,103 @@ class MultimodalRetrievalLocal: logger.error(f"元数据加载失败: {str(e)}") def encode_text(self, text: Union[str, List[str]]) -> np.ndarray: - """编码文本为向量(使用 OpsMMEmbeddingV1)""" + """编码文本为向量(使用 OpsMMEmbeddingV1,多GPU并行)""" if isinstance(text, str): text = [text] - with torch.inference_mode(): - emb = self.model.get_text_embeddings(texts=text) - text_embeddings = emb.detach().float().cpu().numpy() - # emb 已经做过 L2 归一化,这里保持一致 - return text_embeddings[0] if len(text) == 1 else text_embeddings + if len(text) == 0: + return np.zeros((0, self.vector_dim)) + + # 单条或单模型直接处理 + if len(self.models) == 1 or len(text) == 1: + with torch.inference_mode(): + emb = self.models[0].get_text_embeddings(texts=text) + arr = emb.detach().float().cpu().numpy() + return arr[0] if len(text) == 1 else arr + + # 多GPU并行:按设备平均切分 + shard_count = len(self.models) + shards: List[List[str]] = [] + for i in range(shard_count): + shards.append([]) + for idx, t in enumerate(text): + shards[idx % shard_count].append(t) + + results: List[np.ndarray] = [None] * shard_count # type: ignore + + def run_shard(model: OpsMMEmbeddingV1, shard_texts: List[str], shard_idx: int): + if len(shard_texts) == 0: + results[shard_idx] = np.zeros((0, self.vector_dim)) + return + with torch.inference_mode(): + emb = model.get_text_embeddings(texts=shard_texts) + results[shard_idx] = emb.detach().float().cpu().numpy() + + with ThreadPoolExecutor(max_workers=shard_count) as ex: + for i, (model, shard_texts) in enumerate(zip(self.models, shards)): + ex.submit(run_shard, model, shard_texts, i) + + # 还原原始顺序(按 round-robin 拼接) + merged: List[np.ndarray] = [] + max_len = max(len(s) for s in shards) + for k in range(max_len): + for i in range(shard_count): + if k < len(shards[i]) and results[i].shape[0] > 0: + merged.append(results[i][k:k+1]) + if len(merged) == 0: + return np.zeros((0, self.vector_dim)) + return np.vstack(merged) def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray: - """编码图像为向量(使用 OpsMMEmbeddingV1)""" + """编码图像为向量(使用 OpsMMEmbeddingV1,多GPU并行)""" try: # 规范为列表 - images: List[Image.Image] if isinstance(image, Image.Image): - images = [image] + images: List[Image.Image] = [image] else: images = image if not images: logger.error("encode_image: 图像列表为空") return np.zeros((0, self.vector_dim)) + # 强制为 RGB rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images] - with torch.inference_mode(): - emb = self.model.get_image_embeddings(images=rgb_images) - image_embeddings = emb.detach().float().cpu().numpy() - return image_embeddings + + # 单张或单模型直接处理 + if len(self.models) == 1 or len(rgb_images) == 1: + with torch.inference_mode(): + emb = self.models[0].get_image_embeddings(images=rgb_images) + return emb.detach().float().cpu().numpy() + + # 多GPU并行:按设备平均切分 + shard_count = len(self.models) + shards: List[List[Image.Image]] = [[] for _ in range(shard_count)] + for idx, img in enumerate(rgb_images): + shards[idx % shard_count].append(img) + + results: List[np.ndarray] = [None] * shard_count # type: ignore + + def run_shard(model: OpsMMEmbeddingV1, shard_imgs: List[Image.Image], shard_idx: int): + if len(shard_imgs) == 0: + results[shard_idx] = np.zeros((0, self.vector_dim)) + return + with torch.inference_mode(): + emb = model.get_image_embeddings(images=shard_imgs) + results[shard_idx] = emb.detach().float().cpu().numpy() + + with ThreadPoolExecutor(max_workers=shard_count) as ex: + for i, (model, shard_imgs) in enumerate(zip(self.models, 
diff --git a/ops_mm_embedding_v1.py b/ops_mm_embedding_v1.py
index 482af79..c5fc930 100644
--- a/ops_mm_embedding_v1.py
+++ b/ops_mm_embedding_v1.py
@@ -18,17 +18,36 @@ class OpsMMEmbeddingV1(nn.Module):
         device: str = "cuda",
         max_length: Optional[int] = None,
         attn_implementation: Optional[str] = None,
+        device_map: Optional[str] = None,
+        load_in_4bit: bool = False,
+        load_in_8bit: bool = False,
+        torch_dtype: Optional[torch.dtype] = torch.bfloat16,
     ):
         super().__init__()
         self.device = device
         self.max_length = max_length
         self.default_instruction = "You are a helpful assistant."
-        self.base_model = AutoModelForImageTextToText.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
+        load_kwargs = dict(
+            torch_dtype=torch_dtype,
             low_cpu_mem_usage=True,
             attn_implementation=attn_implementation,
-        ).to(self.device)
+        )
+        # Quantization options (4/8-bit loading requires bitsandbytes)
+        if load_in_4bit:
+            load_kwargs["load_in_4bit"] = True
+        if load_in_8bit:
+            load_kwargs["load_in_8bit"] = True
+        if device_map is not None:
+            load_kwargs["device_map"] = device_map
+
+        self.base_model = AutoModelForImageTextToText.from_pretrained(
+            model_name,
+            **load_kwargs,
+        )
+        # Only move to a single device when the model is not sharded via device_map
+        if device_map is None:
+            self.base_model = self.base_model.to(self.device)
         self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
         self.processor.tokenizer.padding_side = "left"
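A caveat on the quantization flags above: recent transformers releases deprecate passing load_in_4bit/load_in_8bit directly to from_pretrained in favor of a BitsAndBytesConfig. A sketch of the equivalent load under that newer API, assuming bitsandbytes is installed (the checkpoint path is a placeholder):

    # Sketch: quantized load via BitsAndBytesConfig (not part of the patch itself).
    import torch
    from transformers import AutoModelForImageTextToText, BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForImageTextToText.from_pretrained(
        "/models/ops-mm-embedding-v1",   # assumption: local checkpoint dir
        quantization_config=quant_config,
        device_map="auto",               # shard across visible GPUs
        low_cpu_mem_usage=True,
    )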
diff --git a/run_server.sh b/run_server.sh
new file mode 100644
index 0000000..aa7c2a9
--- /dev/null
+++ b/run_server.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Select GPUs 0 and 1 for sharded (device_map="auto") serving
+export CUDA_VISIBLE_DEVICES=0,1
+
+# Unbuffered stdout for real-time logs
+export PYTHONUNBUFFERED=1
+
+# Start the local web app
+exec python3 web_app_local.py "$@"
diff --git a/templates/local_index.html b/templates/local_index.html
index 1071f2f..0e65449 100644
--- a/templates/local_index.html
+++ b/templates/local_index.html
@@ -824,88 +824,67 @@
         updateDataStats();
     }
 
-    // Batch-upload images
-    async function uploadBatchImages(files) {
-        try {
-            const progressDiv = document.getElementById('imageUploadProgress');
-            const progressBar = progressDiv.querySelector('.progress-bar');
-            const progressText = document.getElementById('imageProgressText');
-
-            progressDiv.style.display = 'block';
-            progressText.textContent = `0/${files.length}`;
-            progressBar.style.width = '0%';
-
-            showAlert('info', `Uploading ${files.length} images...`);
-
-            let successCount = 0;
-
-            for (let i = 0; i < files.length; i++) {
-                const formData = new FormData();
-                formData.append('image', files[i]);
-
-                const response = await fetch('/api/add_image', {
-                    method: 'POST',
-                    body: formData
-                });
-
-                const data = await response.json();
-                if (data.success) {
-                    successCount++;
-                } else {
-                    console.error(`Image ${files[i].name} failed to upload: ${data.error}`);
-                }
-
-                // Update progress
-                const progress = Math.round(((i + 1) / files.length) * 100);
-                progressBar.style.width = `${progress}%`;
-                progressText.textContent = `${i + 1}/${files.length}`;
-            }
-
-            showAlert('success', `Uploaded ${successCount}/${files.length} images`);
-            // Auto-save the index
-            await autoSaveIndex();
-            updateDataStats();
-        } catch (error) {
-            showAlert('danger', `Image upload failed: ${error.message}`);
-        } finally {
-            setTimeout(() => {
-                document.getElementById('imageUploadProgress').style.display = 'none';
-            }, 2000);
-        }
-    }
+    // Batch-upload images (one call to the batch API so the server can use multiple GPUs)
+    async function uploadBatchImages(files) {
+        try {
+            const progressDiv = document.getElementById('imageUploadProgress');
+            const progressBar = progressDiv.querySelector('.progress-bar');
+            const progressText = document.getElementById('imageProgressText');
+
+            progressDiv.style.display = 'block';
+            progressText.textContent = `0/${files.length}`;
+            progressBar.style.width = '10%';
+
+            showAlert('info', `Uploading ${files.length} images in one batch...`);
+
+            const formData = new FormData();
+            for (let i = 0; i < files.length; i++) {
+                formData.append('images', files[i]);
+            }
+
+            const response = await fetch('/api/add_images_batch', {
+                method: 'POST',
+                body: formData
+            });
+            const data = await response.json();
+            if (!data.success) {
+                throw new Error(data.error || 'Batch upload failed');
+            }
+            progressBar.style.width = '100%';
+            progressText.textContent = `${files.length}/${files.length}`;
+            showAlert('success', data.message || `Uploaded ${files.length} images`);
+            await autoSaveIndex();
+            updateDataStats();
+        } catch (error) {
+            showAlert('danger', `Image upload failed: ${error.message}`);
+        } finally {
+            setTimeout(() => {
+                document.getElementById('imageUploadProgress').style.display = 'none';
+            }, 1500);
+        }
+    }
 
-    // Batch-upload texts
-    async function uploadBatchTexts(texts) {
-        try {
-            showAlert('info', `Uploading ${texts.length} texts...`);
-
-            for (let i = 0; i < texts.length; i++) {
-                const response = await fetch('/api/add_text', {
-                    method: 'POST',
-                    headers: {'Content-Type': 'application/json'},
-                    body: JSON.stringify({text: texts[i]})
-                });
-
-                const data = await response.json();
-                if (!data.success) {
-                    throw new Error(`Text #${i + 1} failed to upload: ${data.error}`);
-                }
-            }
-
-            showAlert('success', `Uploaded ${texts.length} texts`);
-            // Auto-save the index
-            await autoSaveIndex();
-            updateDataStats();
-        } catch (error) {
-            showAlert('danger', `Text upload failed: ${error.message}`);
-        }
-    }
+    // Batch-upload texts (one call to the batch API so the server can use multiple GPUs)
+    async function uploadBatchTexts(texts) {
+        try {
+            showAlert('info', `Uploading ${texts.length} texts in one batch...`);
+            const response = await fetch('/api/add_texts_batch', {
+                method: 'POST',
+                headers: {'Content-Type': 'application/json'},
+                body: JSON.stringify({texts})
+            });
+            const data = await response.json();
+            if (!data.success) {
+                throw new Error(data.error || 'Batch upload failed');
+            }
+            showAlert('success', data.message || `Uploaded ${texts.length} texts`);
+            await autoSaveIndex();
+            updateDataStats();
+        } catch (error) {
+            showAlert('danger', `Text upload failed: ${error.message}`);
+        }
+    }
 
     // Auto-save the index
     async function autoSaveIndex() {
         try {
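The batch endpoint the page now calls can also be exercised from a script, e.g. for bulk indexing. A minimal sketch using requests; the host, port, and file names are assumptions, not values taken from the patch:

    # Sketch: batch image upload from Python (assumed host/port and file names).
    import requests

    paths = ["cat.jpg", "dog.jpg"]  # placeholder file names
    # Repeating the "images" field maps to request.files.getlist('images') server-side.
    files = [("images", (p, open(p, "rb"), "image/jpeg")) for p in paths]
    resp = requests.post("http://localhost:5000/api/add_images_batch", files=files)
    print(resp.json())  # expected shape: {"success": true, "image_ids": [...], ...}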
"""批量添加图像(多卡并行友好)""" + try: + # 读取多文件 + files = request.files.getlist('images') + if not files: + return jsonify({"success": False, "error": "没有上传文件"}), 400 + + retrieval = init_retrieval_system() + + images = [] + metadatas = [] + image_paths = [] + + for file in files: + if file and allowed_file(file.filename): + image_data = file.read() + filename = secure_filename(file.filename) + image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename) + with open(image_path, 'wb') as f: + f.write(image_data) + try: + img = Image.open(BytesIO(image_data)) + if img.mode != 'RGB': + img = img.convert('RGB') + except Exception as e: + return jsonify({"success": False, "error": f"图像格式不支持: {filename}: {str(e)}"}), 400 + images.append(img) + metadatas.append({ + "filename": filename, + "timestamp": time.time(), + "source": "web_upload_batch", + "size": len(image_data), + "local_path": image_path + }) + image_paths.append(image_path) + else: + return jsonify({"success": False, "error": f"不支持的文件类型: {file.filename}"}), 400 + + image_ids = retrieval.add_images(images, metadatas, image_paths) + retrieval.save_index() + + stats = retrieval.get_stats() + return jsonify({ + "success": True, + "message": f"批量添加成功: {len(image_ids)}/{len(images)}", + "image_ids": image_ids, + "debug": { + "server_time": datetime.now(timezone.utc).isoformat(), + "vector_dimension": stats.get("vector_dimension"), + "total_vectors": stats.get("total_vectors"), + } + }) + except Exception as e: + logger.error(f"批量添加图像失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + finally: + file_handler.cleanup_all_temp_files() + +@app.route('/api/add_texts_batch', methods=['POST']) +def add_texts_batch(): + """批量添加文本(多卡并行友好)""" + try: + data = request.json + if not data or 'texts' not in data: + return jsonify({"success": False, "error": "缺少 texts 数组"}), 400 + texts = [t for t in data.get('texts', []) if isinstance(t, str) and t.strip()] + if not texts: + return jsonify({"success": False, "error": "文本数组为空"}), 400 + + retrieval = init_retrieval_system() + metadatas = [{"timestamp": time.time(), "source": "web_upload_batch"} for _ in texts] + text_ids = retrieval.add_texts(texts, metadatas) + retrieval.save_index() + + stats = retrieval.get_stats() + return jsonify({ + "success": True, + "message": f"批量添加成功: {len(text_ids)}/{len(texts)}", + "text_ids": text_ids, + "debug": { + "server_time": datetime.now(timezone.utc).isoformat(), + "vector_dimension": stats.get("vector_dimension"), + "total_vectors": stats.get("total_vectors"), + } + }) + except Exception as e: + logger.error(f"批量添加文本失败: {str(e)}") + return jsonify({"success": False, "error": str(e)}), 500 + @app.route('/api/add_image', methods=['POST']) def add_image(): """添加图像"""