Add multi-card (multi-GPU) run support

This commit is contained in:
eust-w 2025-09-22 19:47:14 +08:00
parent cadbab7541
commit 73ce51c611
5 changed files with 324 additions and 105 deletions

View File

@ -9,6 +9,7 @@ import torch
import numpy as np
from PIL import Image
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
from concurrent.futures import ThreadPoolExecutor
from typing import List, Union, Tuple, Dict, Any, Optional
import os
import json
@ -33,7 +34,11 @@ class MultimodalRetrievalLocal:
use_all_gpus: bool = True,
gpu_ids: List[int] = None,
min_memory_gb: int = 12,
index_path: str = "local_faiss_index"):
index_path: str = "local_faiss_index",
shard_model_across_gpus: bool = False,
load_in_4bit: bool = False,
load_in_8bit: bool = False,
torch_dtype: Optional[torch.dtype] = torch.bfloat16):
"""
初始化多模态检索系统
@ -46,6 +51,10 @@ class MultimodalRetrievalLocal:
"""
self.model_path = model_path
self.index_path = index_path
self.shard_model_across_gpus = shard_model_across_gpus
self.load_in_4bit = load_in_4bit
self.load_in_8bit = load_in_8bit
self.torch_dtype = torch_dtype
# 检查模型路径
if not os.path.exists(model_path):
@ -116,14 +125,53 @@ class MultimodalRetrievalLocal:
"""加载多模态嵌入模型 OpsMMEmbeddingV1"""
logger.info(f"加载本地多模态嵌入模型: {self.model_path}")
try:
device_str = "cuda" if self.use_gpu else "cpu"
self.model = OpsMMEmbeddingV1(
self.model_path,
device=device_str,
attn_implementation=None,
)
self.models: List[OpsMMEmbeddingV1] = []
if self.use_gpu and len(self.gpu_ids) > 1 and self.shard_model_across_gpus:
# Tensor-parallel sharding using device_map='auto' across visible GPUs
logger.info(f"启用模型跨GPU切片(shard),使用设备: {self.gpu_ids}")
# Rely on CUDA_VISIBLE_DEVICES to constrain which GPUs are visible, or use accelerate's auto mapping
ref_model = OpsMMEmbeddingV1(
self.model_path,
device="cuda",
attn_implementation=None,
device_map="auto",
load_in_4bit=self.load_in_4bit,
load_in_8bit=self.load_in_8bit,
torch_dtype=self.torch_dtype,
)
# For sharded model, we keep a single logical model reference
self.models.append(ref_model)
else:
if self.use_gpu and len(self.gpu_ids) > 1:
logger.info(f"检测到多GPU可用: {self.gpu_ids},为每张卡加载一个模型副本(数据并行)")
for gid in self.gpu_ids:
device_str = f"cuda:{gid}"
self.models.append(
OpsMMEmbeddingV1(
self.model_path,
device=device_str,
attn_implementation=None,
load_in_4bit=self.load_in_4bit,
load_in_8bit=self.load_in_8bit,
torch_dtype=self.torch_dtype,
)
)
ref_model = self.models[0]
else:
device_str = "cuda" if self.use_gpu else "cpu"
logger.info(f"使用单设备: {device_str}")
ref_model = OpsMMEmbeddingV1(
self.model_path,
device=device_str,
attn_implementation=None,
load_in_4bit=self.load_in_4bit,
load_in_8bit=self.load_in_8bit,
torch_dtype=self.torch_dtype,
)
self.models.append(ref_model)
# 获取向量维度
self.vector_dim = int(getattr(self.model.base_model.config, "hidden_size"))
self.vector_dim = int(getattr(ref_model.base_model.config, "hidden_size"))
logger.info(f"向量维度: {self.vector_dim}")
logger.info("嵌入模型加载成功")
except Exception as e:
@ -158,33 +206,103 @@ class MultimodalRetrievalLocal:
logger.error(f"元数据加载失败: {str(e)}")
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
    """Encode text into embedding vectors with OpsMMEmbeddingV1, fanning the
    batch out across multiple GPU model replicas when available (data parallel).

    Args:
        text: A single string or a list of strings.

    Returns:
        np.ndarray: shape (dim,) for a single input string, otherwise
        (len(text), dim); (0, dim) for an empty list. The model already
        L2-normalizes its embeddings, so no extra normalization happens here.
    """
    if isinstance(text, str):
        text = [text]
    if len(text) == 0:
        return np.zeros((0, self.vector_dim))
    # Single replica or single item: fan-out would only add overhead.
    if len(self.models) == 1 or len(text) == 1:
        with torch.inference_mode():
            emb = self.models[0].get_text_embeddings(texts=text)
        arr = emb.detach().float().cpu().numpy()
        return arr[0] if len(text) == 1 else arr
    # Multi-GPU data parallelism: round-robin the inputs over the replicas.
    shard_count = len(self.models)
    shards: List[List[str]] = [[] for _ in range(shard_count)]
    for idx, t in enumerate(text):
        shards[idx % shard_count].append(t)
    results: List[Optional[np.ndarray]] = [None] * shard_count

    def run_shard(model, shard_texts: List[str], shard_idx: int) -> None:
        # Runs in a worker thread; each replica lives on its own device.
        if not shard_texts:
            results[shard_idx] = np.zeros((0, self.vector_dim))
            return
        with torch.inference_mode():
            emb = model.get_text_embeddings(texts=shard_texts)
        results[shard_idx] = emb.detach().float().cpu().numpy()

    with ThreadPoolExecutor(max_workers=shard_count) as ex:
        futures = [
            ex.submit(run_shard, model, shard_texts, i)
            for i, (model, shard_texts) in enumerate(zip(self.models, shards))
        ]
        # Propagate worker exceptions instead of silently dropping them.
        # (A dropped exception would leave results[i] as None and crash
        # later with a confusing AttributeError during the merge.)
        for fut in futures:
            fut.result()
    # Re-interleave shard outputs back into the original round-robin order.
    merged: List[np.ndarray] = []
    max_len = max(len(s) for s in shards)
    for k in range(max_len):
        for i in range(shard_count):
            if k < len(shards[i]):
                merged.append(results[i][k:k + 1])
    if not merged:
        return np.zeros((0, self.vector_dim))
    return np.vstack(merged)
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
    """Encode image(s) into embedding vectors with OpsMMEmbeddingV1, fanning
    the batch out across multiple GPU model replicas when available.

    Args:
        image: A single PIL image or a list of PIL images.

    Returns:
        np.ndarray of shape (len(images), dim). Returns an empty (0, dim)
        array on empty input or on failure (errors are logged, not raised,
        to keep this a best-effort path for the web upload flow).
    """
    try:
        # Normalize to a list of images.
        images: List[Image.Image] = [image] if isinstance(image, Image.Image) else image
        if not images:
            logger.error("encode_image: 图像列表为空")
            return np.zeros((0, self.vector_dim))
        # The model expects RGB input.
        rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images]
        # Single replica or single image: fan-out would only add overhead.
        if len(self.models) == 1 or len(rgb_images) == 1:
            with torch.inference_mode():
                emb = self.models[0].get_image_embeddings(images=rgb_images)
            return emb.detach().float().cpu().numpy()
        # Multi-GPU data parallelism: round-robin the images over the replicas.
        shard_count = len(self.models)
        shards: List[List[Image.Image]] = [[] for _ in range(shard_count)]
        for idx, img in enumerate(rgb_images):
            shards[idx % shard_count].append(img)
        results: List[Optional[np.ndarray]] = [None] * shard_count

        def run_shard(model, shard_imgs: List[Image.Image], shard_idx: int) -> None:
            # Runs in a worker thread; each replica lives on its own device.
            if not shard_imgs:
                results[shard_idx] = np.zeros((0, self.vector_dim))
                return
            with torch.inference_mode():
                emb = model.get_image_embeddings(images=shard_imgs)
            results[shard_idx] = emb.detach().float().cpu().numpy()

        with ThreadPoolExecutor(max_workers=shard_count) as ex:
            futures = [
                ex.submit(run_shard, model, shard_imgs, i)
                for i, (model, shard_imgs) in enumerate(zip(self.models, shards))
            ]
            # Propagate worker exceptions instead of silently dropping them;
            # they are then reported by the outer except below.
            for fut in futures:
                fut.result()
        # Re-interleave shard outputs back into the original round-robin order.
        merged: List[np.ndarray] = []
        max_len = max(len(s) for s in shards)
        for k in range(max_len):
            for i in range(shard_count):
                if k < len(shards[i]):
                    merged.append(results[i][k:k + 1])
        if not merged:
            return np.zeros((0, self.vector_dim))
        return np.vstack(merged)
    except Exception as e:
        logger.error(f"encode_image: 异常: {str(e)}")
        return np.zeros((0, self.vector_dim))

View File

@ -18,17 +18,36 @@ class OpsMMEmbeddingV1(nn.Module):
device: str = "cuda",
max_length: Optional[int] = None,
attn_implementation: Optional[str] = None,
device_map: Optional[str] = None,
load_in_4bit: bool = False,
load_in_8bit: bool = False,
torch_dtype: Optional[torch.dtype] = torch.bfloat16,
):
super().__init__()
self.device = device
self.max_length = max_length
self.default_instruction = "You are a helpful assistant."
self.base_model = AutoModelForImageTextToText.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
from transformers import AutoModelForImageTextToText
load_kwargs = dict(
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
attn_implementation=attn_implementation,
).to(self.device)
)
# Quantization options (requires bitsandbytes for 4/8-bit)
if load_in_4bit:
load_kwargs["load_in_4bit"] = True
if load_in_8bit:
load_kwargs["load_in_8bit"] = True
if device_map is not None:
load_kwargs["device_map"] = device_map
self.base_model = AutoModelForImageTextToText.from_pretrained(
model_name,
**load_kwargs,
)
# Only move to a single device when not using tensor-parallel sharding
if device_map is None:
self.base_model = self.base_model.to(self.device)
self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
self.processor.tokenizer.padding_side = "left"

11
run_server.sh Normal file
View File

@ -0,0 +1,11 @@
#!/usr/bin/env bash
# Launch the local web app with the model sharded across GPUs 0 and 1.
set -euo pipefail

# Restrict the process to GPUs 0 and 1 for tensor-parallel sharding.
export CUDA_VISIBLE_DEVICES=0,1

# Keep stdout unbuffered so logs stream in real time.
export PYTHONUNBUFFERED=1

exec python3 web_app_local.py "$@"

View File

@ -824,87 +824,66 @@
updateDataStats();
}
// 批量上传图片调用批量API以利用多卡并行
// Batch-upload images through the batch API in a single request so the
// server can embed them in parallel across GPUs (instead of one request
// per file, which serializes the work).
async function uploadBatchImages(files) {
    try {
        const progressDiv = document.getElementById('imageUploadProgress');
        const progressBar = progressDiv.querySelector('.progress-bar');
        const progressText = document.getElementById('imageProgressText');

        progressDiv.style.display = 'block';
        progressText.textContent = `0/${files.length}`;
        // One request means no per-file progress; show an indeterminate start.
        progressBar.style.width = '10%';

        showAlert('info', `正在批量上传 ${files.length} 张图片...`);

        // All files go into one multipart form under the 'images' field.
        const formData = new FormData();
        for (let i = 0; i < files.length; i++) {
            formData.append('images', files[i]);
        }

        const response = await fetch('/api/add_images_batch', {
            method: 'POST',
            body: formData
        });
        const data = await response.json();
        if (!data.success) {
            throw new Error(data.error || '批量上传失败');
        }

        progressBar.style.width = '100%';
        progressText.textContent = `${files.length}/${files.length}`;
        showAlert('success', data.message || `成功上传 ${files.length} 张图片`);
        await autoSaveIndex();
        updateDataStats();
    } catch (error) {
        showAlert('danger', `图片上传失败: ${error.message}`);
    } finally {
        // Hide the progress bar shortly after finishing, success or failure.
        setTimeout(() => {
            document.getElementById('imageUploadProgress').style.display = 'none';
        }, 1500);
    }
}
// 批量上传文本调用批量API以利用多卡并行
// Sends all texts in one request so the server can embed them in parallel.
async function uploadBatchTexts(texts) {
    try {
        showAlert('info', `正在批量上传 ${texts.length} 条文本...`);

        const payload = JSON.stringify({texts});
        const response = await fetch('/api/add_texts_batch', {
            method: 'POST',
            headers: {'Content-Type': 'application/json'},
            body: payload
        });
        const data = await response.json();

        if (data.success) {
            showAlert('success', data.message || `成功上传 ${texts.length} 条文本`);
            await autoSaveIndex();
            updateDataStats();
        } else {
            throw new Error(data.error || '批量上传失败');
        }
    } catch (error) {
        showAlert('danger', `文本上传失败: ${error.message}`);
    }
}
// 自动保存索引函数
async function autoSaveIndex() {

View File

@ -94,7 +94,8 @@ def init_retrieval_system():
retrieval_system = MultimodalRetrievalLocal(
model_path=model_path,
use_all_gpus=True,
index_path=app.config['INDEX_PATH']
index_path=app.config['INDEX_PATH'],
shard_model_across_gpus=True
)
logger.info("多模态检索系统初始化完成")
@ -181,6 +182,97 @@ def add_text():
# 清理临时文件
file_handler.cleanup_all_temp_files()
@app.route('/api/add_images_batch', methods=['POST'])
def add_images_batch():
    """Batch-add images in a single request (multi-GPU parallel friendly).

    Expects a multipart form with one or more files under the 'images'
    field. Each file is saved into UPLOAD_FOLDER, decoded to RGB, embedded
    and indexed. Returns the new image ids plus index stats as JSON, or a
    400/500 JSON error. Any invalid file aborts the whole batch.
    """
    try:
        files = request.files.getlist('images')
        if not files:
            return jsonify({"success": False, "error": "没有上传文件"}), 400
        retrieval = init_retrieval_system()
        images = []
        metadatas = []
        image_paths = []
        for file in files:
            # Guard clause: reject the batch on the first unsupported file.
            if not (file and allowed_file(file.filename)):
                return jsonify({"success": False, "error": f"不支持的文件类型: {file.filename}"}), 400
            image_data = file.read()
            filename = secure_filename(file.filename)
            image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
            with open(image_path, 'wb') as f:
                f.write(image_data)
            try:
                img = Image.open(BytesIO(image_data))
                if img.mode != 'RGB':
                    img = img.convert('RGB')
            except Exception as e:
                # Remove the file we just wrote so a bad upload leaves no
                # orphan on disk.
                try:
                    os.remove(image_path)
                except OSError:
                    pass
                # Report the actual filename (the previous message printed a
                # literal "(unknown)" placeholder instead of the file name).
                return jsonify({"success": False, "error": f"图像格式不支持: {filename}: {str(e)}"}), 400
            images.append(img)
            metadatas.append({
                "filename": filename,
                "timestamp": time.time(),
                "source": "web_upload_batch",
                "size": len(image_data),
                "local_path": image_path
            })
            image_paths.append(image_path)
        image_ids = retrieval.add_images(images, metadatas, image_paths)
        retrieval.save_index()
        stats = retrieval.get_stats()
        return jsonify({
            "success": True,
            "message": f"批量添加成功: {len(image_ids)}/{len(images)}",
            "image_ids": image_ids,
            "debug": {
                "server_time": datetime.now(timezone.utc).isoformat(),
                "vector_dimension": stats.get("vector_dimension"),
                "total_vectors": stats.get("total_vectors"),
            }
        })
    except Exception as e:
        logger.error(f"批量添加图像失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
    finally:
        file_handler.cleanup_all_temp_files()
@app.route('/api/add_texts_batch', methods=['POST'])
def add_texts_batch():
    """Batch-add texts in a single request (multi-GPU parallel friendly).

    Expects a JSON body with a 'texts' array of non-empty strings. Returns
    the new text ids plus index stats as JSON, or a 400/500 JSON error.
    """
    try:
        payload = request.json
        if not payload or 'texts' not in payload:
            return jsonify({"success": False, "error": "缺少 texts 数组"}), 400
        # Keep only real, non-blank strings.
        texts = [item for item in payload.get('texts', [])
                 if isinstance(item, str) and item.strip()]
        if not texts:
            return jsonify({"success": False, "error": "文本数组为空"}), 400
        retrieval = init_retrieval_system()
        metadatas = [{"timestamp": time.time(), "source": "web_upload_batch"}
                     for _ in texts]
        text_ids = retrieval.add_texts(texts, metadatas)
        retrieval.save_index()
        stats = retrieval.get_stats()
        debug_info = {
            "server_time": datetime.now(timezone.utc).isoformat(),
            "vector_dimension": stats.get("vector_dimension"),
            "total_vectors": stats.get("total_vectors"),
        }
        return jsonify({
            "success": True,
            "message": f"批量添加成功: {len(text_ids)}/{len(texts)}",
            "text_ids": text_ids,
            "debug": debug_info,
        })
    except Exception as e:
        logger.error(f"批量添加文本失败: {str(e)}")
        return jsonify({"success": False, "error": str(e)}), 500
@app.route('/api/add_image', methods=['POST'])
def add_image():
"""添加图像"""