add multi card run

This commit is contained in:
eust-w 2025-09-22 19:47:14 +08:00
parent cadbab7541
commit 73ce51c611
5 changed files with 324 additions and 105 deletions

View File

@ -9,6 +9,7 @@ import torch
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from ops_mm_embedding_v1 import OpsMMEmbeddingV1 from ops_mm_embedding_v1 import OpsMMEmbeddingV1
from concurrent.futures import ThreadPoolExecutor
from typing import List, Union, Tuple, Dict, Any, Optional from typing import List, Union, Tuple, Dict, Any, Optional
import os import os
import json import json
@ -33,7 +34,11 @@ class MultimodalRetrievalLocal:
use_all_gpus: bool = True, use_all_gpus: bool = True,
gpu_ids: List[int] = None, gpu_ids: List[int] = None,
min_memory_gb: int = 12, min_memory_gb: int = 12,
index_path: str = "local_faiss_index"): index_path: str = "local_faiss_index",
shard_model_across_gpus: bool = False,
load_in_4bit: bool = False,
load_in_8bit: bool = False,
torch_dtype: Optional[torch.dtype] = torch.bfloat16):
""" """
初始化多模态检索系统 初始化多模态检索系统
@ -46,6 +51,10 @@ class MultimodalRetrievalLocal:
""" """
self.model_path = model_path self.model_path = model_path
self.index_path = index_path self.index_path = index_path
self.shard_model_across_gpus = shard_model_across_gpus
self.load_in_4bit = load_in_4bit
self.load_in_8bit = load_in_8bit
self.torch_dtype = torch_dtype
# 检查模型路径 # 检查模型路径
if not os.path.exists(model_path): if not os.path.exists(model_path):
@ -116,14 +125,53 @@ class MultimodalRetrievalLocal:
"""加载多模态嵌入模型 OpsMMEmbeddingV1""" """加载多模态嵌入模型 OpsMMEmbeddingV1"""
logger.info(f"加载本地多模态嵌入模型: {self.model_path}") logger.info(f"加载本地多模态嵌入模型: {self.model_path}")
try: try:
device_str = "cuda" if self.use_gpu else "cpu" self.models: List[OpsMMEmbeddingV1] = []
self.model = OpsMMEmbeddingV1( if self.use_gpu and len(self.gpu_ids) > 1 and self.shard_model_across_gpus:
# Tensor-parallel sharding using device_map='auto' across visible GPUs
logger.info(f"启用模型跨GPU切片(shard),使用设备: {self.gpu_ids}")
# Rely on CUDA_VISIBLE_DEVICES to constrain which GPUs are visible, or use accelerate's auto mapping
ref_model = OpsMMEmbeddingV1(
self.model_path,
device="cuda",
attn_implementation=None,
device_map="auto",
load_in_4bit=self.load_in_4bit,
load_in_8bit=self.load_in_8bit,
torch_dtype=self.torch_dtype,
)
# For sharded model, we keep a single logical model reference
self.models.append(ref_model)
else:
if self.use_gpu and len(self.gpu_ids) > 1:
logger.info(f"检测到多GPU可用: {self.gpu_ids},为每张卡加载一个模型副本(数据并行)")
for gid in self.gpu_ids:
device_str = f"cuda:{gid}"
self.models.append(
OpsMMEmbeddingV1(
self.model_path, self.model_path,
device=device_str, device=device_str,
attn_implementation=None, attn_implementation=None,
load_in_4bit=self.load_in_4bit,
load_in_8bit=self.load_in_8bit,
torch_dtype=self.torch_dtype,
) )
)
ref_model = self.models[0]
else:
device_str = "cuda" if self.use_gpu else "cpu"
logger.info(f"使用单设备: {device_str}")
ref_model = OpsMMEmbeddingV1(
self.model_path,
device=device_str,
attn_implementation=None,
load_in_4bit=self.load_in_4bit,
load_in_8bit=self.load_in_8bit,
torch_dtype=self.torch_dtype,
)
self.models.append(ref_model)
# 获取向量维度 # 获取向量维度
self.vector_dim = int(getattr(self.model.base_model.config, "hidden_size")) self.vector_dim = int(getattr(ref_model.base_model.config, "hidden_size"))
logger.info(f"向量维度: {self.vector_dim}") logger.info(f"向量维度: {self.vector_dim}")
logger.info("嵌入模型加载成功") logger.info("嵌入模型加载成功")
except Exception as e: except Exception as e:
@ -158,33 +206,103 @@ class MultimodalRetrievalLocal:
logger.error(f"元数据加载失败: {str(e)}") logger.error(f"元数据加载失败: {str(e)}")
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray: def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
"""编码文本为向量(使用 OpsMMEmbeddingV1""" """编码文本为向量(使用 OpsMMEmbeddingV1多GPU并行"""
if isinstance(text, str): if isinstance(text, str):
text = [text] text = [text]
if len(text) == 0:
return np.zeros((0, self.vector_dim))
# 单条或单模型直接处理
if len(self.models) == 1 or len(text) == 1:
with torch.inference_mode(): with torch.inference_mode():
emb = self.model.get_text_embeddings(texts=text) emb = self.models[0].get_text_embeddings(texts=text)
text_embeddings = emb.detach().float().cpu().numpy() arr = emb.detach().float().cpu().numpy()
# emb 已经做过 L2 归一化,这里保持一致 return arr[0] if len(text) == 1 else arr
return text_embeddings[0] if len(text) == 1 else text_embeddings
# 多GPU并行按设备平均切分
shard_count = len(self.models)
shards: List[List[str]] = []
for i in range(shard_count):
shards.append([])
for idx, t in enumerate(text):
shards[idx % shard_count].append(t)
results: List[np.ndarray] = [None] * shard_count # type: ignore
def run_shard(model: OpsMMEmbeddingV1, shard_texts: List[str], shard_idx: int):
if len(shard_texts) == 0:
results[shard_idx] = np.zeros((0, self.vector_dim))
return
with torch.inference_mode():
emb = model.get_text_embeddings(texts=shard_texts)
results[shard_idx] = emb.detach().float().cpu().numpy()
with ThreadPoolExecutor(max_workers=shard_count) as ex:
for i, (model, shard_texts) in enumerate(zip(self.models, shards)):
ex.submit(run_shard, model, shard_texts, i)
# 还原原始顺序(按 round-robin 拼接)
merged: List[np.ndarray] = []
max_len = max(len(s) for s in shards)
for k in range(max_len):
for i in range(shard_count):
if k < len(shards[i]) and results[i].shape[0] > 0:
merged.append(results[i][k:k+1])
if len(merged) == 0:
return np.zeros((0, self.vector_dim))
return np.vstack(merged)
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray: def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
"""编码图像为向量(使用 OpsMMEmbeddingV1""" """编码图像为向量(使用 OpsMMEmbeddingV1多GPU并行"""
try: try:
# 规范为列表 # 规范为列表
images: List[Image.Image]
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
images = [image] images: List[Image.Image] = [image]
else: else:
images = image images = image
if not images: if not images:
logger.error("encode_image: 图像列表为空") logger.error("encode_image: 图像列表为空")
return np.zeros((0, self.vector_dim)) return np.zeros((0, self.vector_dim))
# 强制为 RGB # 强制为 RGB
rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images] rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images]
# 单张或单模型直接处理
if len(self.models) == 1 or len(rgb_images) == 1:
with torch.inference_mode(): with torch.inference_mode():
emb = self.model.get_image_embeddings(images=rgb_images) emb = self.models[0].get_image_embeddings(images=rgb_images)
image_embeddings = emb.detach().float().cpu().numpy() return emb.detach().float().cpu().numpy()
return image_embeddings
# 多GPU并行按设备平均切分
shard_count = len(self.models)
shards: List[List[Image.Image]] = [[] for _ in range(shard_count)]
for idx, img in enumerate(rgb_images):
shards[idx % shard_count].append(img)
results: List[np.ndarray] = [None] * shard_count # type: ignore
def run_shard(model: OpsMMEmbeddingV1, shard_imgs: List[Image.Image], shard_idx: int):
if len(shard_imgs) == 0:
results[shard_idx] = np.zeros((0, self.vector_dim))
return
with torch.inference_mode():
emb = model.get_image_embeddings(images=shard_imgs)
results[shard_idx] = emb.detach().float().cpu().numpy()
with ThreadPoolExecutor(max_workers=shard_count) as ex:
for i, (model, shard_imgs) in enumerate(zip(self.models, shards)):
ex.submit(run_shard, model, shard_imgs, i)
# 还原原始顺序(按 round-robin 拼接)
merged: List[np.ndarray] = []
max_len = max(len(s) for s in shards)
for k in range(max_len):
for i in range(shard_count):
if k < len(shards[i]) and results[i].shape[0] > 0:
merged.append(results[i][k:k+1])
if len(merged) == 0:
return np.zeros((0, self.vector_dim))
return np.vstack(merged)
except Exception as e: except Exception as e:
logger.error(f"encode_image: 异常: {str(e)}") logger.error(f"encode_image: 异常: {str(e)}")
return np.zeros((0, self.vector_dim)) return np.zeros((0, self.vector_dim))

View File

@ -18,17 +18,36 @@ class OpsMMEmbeddingV1(nn.Module):
device: str = "cuda", device: str = "cuda",
max_length: Optional[int] = None, max_length: Optional[int] = None,
attn_implementation: Optional[str] = None, attn_implementation: Optional[str] = None,
device_map: Optional[str] = None,
load_in_4bit: bool = False,
load_in_8bit: bool = False,
torch_dtype: Optional[torch.dtype] = torch.bfloat16,
): ):
super().__init__() super().__init__()
self.device = device self.device = device
self.max_length = max_length self.max_length = max_length
self.default_instruction = "You are a helpful assistant." self.default_instruction = "You are a helpful assistant."
self.base_model = AutoModelForImageTextToText.from_pretrained( from transformers import AutoModelForImageTextToText
model_name, load_kwargs = dict(
torch_dtype=torch.bfloat16, torch_dtype=torch_dtype,
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
attn_implementation=attn_implementation, attn_implementation=attn_implementation,
).to(self.device) )
# Quantization options (requires bitsandbytes for 4/8-bit)
if load_in_4bit:
load_kwargs["load_in_4bit"] = True
if load_in_8bit:
load_kwargs["load_in_8bit"] = True
if device_map is not None:
load_kwargs["device_map"] = device_map
self.base_model = AutoModelForImageTextToText.from_pretrained(
model_name,
**load_kwargs,
)
# Only move to a single device when not using tensor-parallel sharding
if device_map is None:
self.base_model = self.base_model.to(self.device)
self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28) self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
self.processor.tokenizer.padding_side = "left" self.processor.tokenizer.padding_side = "left"

11
run_server.sh Normal file
View File

@ -0,0 +1,11 @@
#!/usr/bin/env bash
set -euo pipefail
# Select GPUs 0 and 1 for tensor-parallel sharding
export CUDA_VISIBLE_DEVICES=0,1
# Unbuffered stdout for real-time logs
export PYTHONUNBUFFERED=1
# Start the local web app
exec python3 web_app_local.py "$@"

View File

@ -824,7 +824,7 @@
updateDataStats(); updateDataStats();
} }
// 批量上传图片 // 批量上传图片调用批量API以利用多卡并行
async function uploadBatchImages(files) { async function uploadBatchImages(files) {
try { try {
const progressDiv = document.getElementById('imageUploadProgress'); const progressDiv = document.getElementById('imageUploadProgress');
@ -833,36 +833,26 @@
progressDiv.style.display = 'block'; progressDiv.style.display = 'block';
progressText.textContent = `0/${files.length}`; progressText.textContent = `0/${files.length}`;
progressBar.style.width = '0%'; progressBar.style.width = '10%';
showAlert('info', `正在上传${files.length}张图片...`); showAlert('info', `正在批量上传 ${files.length} 张图片...`);
let successCount = 0;
for (let i = 0; i < files.length; i++) {
const formData = new FormData(); const formData = new FormData();
formData.append('image', files[i]); for (let i = 0; i < files.length; i++) {
formData.append('images', files[i]);
}
const response = await fetch('/api/add_image', { const response = await fetch('/api/add_images_batch', {
method: 'POST', method: 'POST',
body: formData body: formData
}); });
const data = await response.json(); const data = await response.json();
if (data.success) { if (!data.success) {
successCount++; throw new Error(data.error || '批量上传失败');
} else {
console.error(`图片 ${files[i].name} 上传失败: ${data.error}`);
} }
progressBar.style.width = '100%';
// 更新进度 progressText.textContent = `${files.length}/${files.length}`;
const progress = Math.round(((i + 1) / files.length) * 100); showAlert('success', data.message || `成功上传 ${files.length} 张图片`);
progressBar.style.width = `${progress}%`;
progressText.textContent = `${i + 1}/${files.length}`;
}
showAlert('success', `成功上传 ${successCount}/${files.length} 张图片`);
// 自动保存索引
await autoSaveIndex(); await autoSaveIndex();
updateDataStats(); updateDataStats();
} catch (error) { } catch (error) {
@ -870,40 +860,29 @@
} finally { } finally {
setTimeout(() => { setTimeout(() => {
document.getElementById('imageUploadProgress').style.display = 'none'; document.getElementById('imageUploadProgress').style.display = 'none';
}, 2000); }, 1500);
} }
// 旧代码已删除
// 旧代码已删除
} }
// 批量上传文本 // 批量上传文本调用批量API以利用多卡并行
async function uploadBatchTexts(texts) { async function uploadBatchTexts(texts) {
try { try {
showAlert('info', `正在上传${texts.length}条文本...`); showAlert('info', `正在批量上传 ${texts.length} 条文本...`);
const response = await fetch('/api/add_texts_batch', {
for (let i = 0; i < texts.length; i++) {
const response = await fetch('/api/add_text', {
method: 'POST', method: 'POST',
headers: {'Content-Type': 'application/json'}, headers: {'Content-Type': 'application/json'},
body: JSON.stringify({text: texts[i]}) body: JSON.stringify({texts})
}); });
const data = await response.json(); const data = await response.json();
if (!data.success) { if (!data.success) {
throw new Error(`第${i+1}条文本上传失败: ${data.error}`); throw new Error(data.error || '批量上传失败');
} }
} showAlert('success', data.message || `成功上传 ${texts.length} 条文本`);
showAlert('success', `成功上传${texts.length}条文本`);
// 自动保存索引
await autoSaveIndex(); await autoSaveIndex();
updateDataStats(); updateDataStats();
} catch (error) { } catch (error) {
showAlert('danger', `文本上传失败: ${error.message}`); showAlert('danger', `文本上传失败: ${error.message}`);
} }
// 已替换为新的API调用
// 旧代码已删除
// 已删除
} }
// 自动保存索引函数 // 自动保存索引函数

View File

@ -94,7 +94,8 @@ def init_retrieval_system():
retrieval_system = MultimodalRetrievalLocal( retrieval_system = MultimodalRetrievalLocal(
model_path=model_path, model_path=model_path,
use_all_gpus=True, use_all_gpus=True,
index_path=app.config['INDEX_PATH'] index_path=app.config['INDEX_PATH'],
shard_model_across_gpus=True
) )
logger.info("多模态检索系统初始化完成") logger.info("多模态检索系统初始化完成")
@ -181,6 +182,97 @@ def add_text():
# 清理临时文件 # 清理临时文件
file_handler.cleanup_all_temp_files() file_handler.cleanup_all_temp_files()
@app.route('/api/add_images_batch', methods=['POST'])
def add_images_batch():
"""批量添加图像(多卡并行友好)"""
try:
# 读取多文件
files = request.files.getlist('images')
if not files:
return jsonify({"success": False, "error": "没有上传文件"}), 400
retrieval = init_retrieval_system()
images = []
metadatas = []
image_paths = []
for file in files:
if file and allowed_file(file.filename):
image_data = file.read()
filename = secure_filename(file.filename)
image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
with open(image_path, 'wb') as f:
f.write(image_data)
try:
img = Image.open(BytesIO(image_data))
if img.mode != 'RGB':
img = img.convert('RGB')
except Exception as e:
return jsonify({"success": False, "error": f"图像格式不支持: {filename}: {str(e)}"}), 400
images.append(img)
metadatas.append({
"filename": filename,
"timestamp": time.time(),
"source": "web_upload_batch",
"size": len(image_data),
"local_path": image_path
})
image_paths.append(image_path)
else:
return jsonify({"success": False, "error": f"不支持的文件类型: {file.filename}"}), 400
image_ids = retrieval.add_images(images, metadatas, image_paths)
retrieval.save_index()
stats = retrieval.get_stats()
return jsonify({
"success": True,
"message": f"批量添加成功: {len(image_ids)}/{len(images)}",
"image_ids": image_ids,
"debug": {
"server_time": datetime.now(timezone.utc).isoformat(),
"vector_dimension": stats.get("vector_dimension"),
"total_vectors": stats.get("total_vectors"),
}
})
except Exception as e:
logger.error(f"批量添加图像失败: {str(e)}")
return jsonify({"success": False, "error": str(e)}), 500
finally:
file_handler.cleanup_all_temp_files()
@app.route('/api/add_texts_batch', methods=['POST'])
def add_texts_batch():
"""批量添加文本(多卡并行友好)"""
try:
data = request.json
if not data or 'texts' not in data:
return jsonify({"success": False, "error": "缺少 texts 数组"}), 400
texts = [t for t in data.get('texts', []) if isinstance(t, str) and t.strip()]
if not texts:
return jsonify({"success": False, "error": "文本数组为空"}), 400
retrieval = init_retrieval_system()
metadatas = [{"timestamp": time.time(), "source": "web_upload_batch"} for _ in texts]
text_ids = retrieval.add_texts(texts, metadatas)
retrieval.save_index()
stats = retrieval.get_stats()
return jsonify({
"success": True,
"message": f"批量添加成功: {len(text_ids)}/{len(texts)}",
"text_ids": text_ids,
"debug": {
"server_time": datetime.now(timezone.utc).isoformat(),
"vector_dimension": stats.get("vector_dimension"),
"total_vectors": stats.get("total_vectors"),
}
})
except Exception as e:
logger.error(f"批量添加文本失败: {str(e)}")
return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/add_image', methods=['POST']) @app.route('/api/add_image', methods=['POST'])
def add_image(): def add_image():
"""添加图像""" """添加图像"""