✨ add multi card run
This commit is contained in:
parent
cadbab7541
commit
73ce51c611
@ -9,6 +9,7 @@ import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from ops_mm_embedding_v1 import OpsMMEmbeddingV1
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Union, Tuple, Dict, Any, Optional
|
||||
import os
|
||||
import json
|
||||
@ -33,7 +34,11 @@ class MultimodalRetrievalLocal:
|
||||
use_all_gpus: bool = True,
|
||||
gpu_ids: List[int] = None,
|
||||
min_memory_gb: int = 12,
|
||||
index_path: str = "local_faiss_index"):
|
||||
index_path: str = "local_faiss_index",
|
||||
shard_model_across_gpus: bool = False,
|
||||
load_in_4bit: bool = False,
|
||||
load_in_8bit: bool = False,
|
||||
torch_dtype: Optional[torch.dtype] = torch.bfloat16):
|
||||
"""
|
||||
初始化多模态检索系统
|
||||
|
||||
@ -46,6 +51,10 @@ class MultimodalRetrievalLocal:
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.index_path = index_path
|
||||
self.shard_model_across_gpus = shard_model_across_gpus
|
||||
self.load_in_4bit = load_in_4bit
|
||||
self.load_in_8bit = load_in_8bit
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
# 检查模型路径
|
||||
if not os.path.exists(model_path):
|
||||
@ -116,14 +125,53 @@ class MultimodalRetrievalLocal:
|
||||
"""加载多模态嵌入模型 OpsMMEmbeddingV1"""
|
||||
logger.info(f"加载本地多模态嵌入模型: {self.model_path}")
|
||||
try:
|
||||
device_str = "cuda" if self.use_gpu else "cpu"
|
||||
self.model = OpsMMEmbeddingV1(
|
||||
self.models: List[OpsMMEmbeddingV1] = []
|
||||
if self.use_gpu and len(self.gpu_ids) > 1 and self.shard_model_across_gpus:
|
||||
# Tensor-parallel sharding using device_map='auto' across visible GPUs
|
||||
logger.info(f"启用模型跨GPU切片(shard),使用设备: {self.gpu_ids}")
|
||||
# Rely on CUDA_VISIBLE_DEVICES to constrain which GPUs are visible, or use accelerate's auto mapping
|
||||
ref_model = OpsMMEmbeddingV1(
|
||||
self.model_path,
|
||||
device="cuda",
|
||||
attn_implementation=None,
|
||||
device_map="auto",
|
||||
load_in_4bit=self.load_in_4bit,
|
||||
load_in_8bit=self.load_in_8bit,
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
# For sharded model, we keep a single logical model reference
|
||||
self.models.append(ref_model)
|
||||
else:
|
||||
if self.use_gpu and len(self.gpu_ids) > 1:
|
||||
logger.info(f"检测到多GPU,可用: {self.gpu_ids},为每张卡加载一个模型副本(数据并行)")
|
||||
for gid in self.gpu_ids:
|
||||
device_str = f"cuda:{gid}"
|
||||
self.models.append(
|
||||
OpsMMEmbeddingV1(
|
||||
self.model_path,
|
||||
device=device_str,
|
||||
attn_implementation=None,
|
||||
load_in_4bit=self.load_in_4bit,
|
||||
load_in_8bit=self.load_in_8bit,
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
)
|
||||
ref_model = self.models[0]
|
||||
else:
|
||||
device_str = "cuda" if self.use_gpu else "cpu"
|
||||
logger.info(f"使用单设备: {device_str}")
|
||||
ref_model = OpsMMEmbeddingV1(
|
||||
self.model_path,
|
||||
device=device_str,
|
||||
attn_implementation=None,
|
||||
load_in_4bit=self.load_in_4bit,
|
||||
load_in_8bit=self.load_in_8bit,
|
||||
torch_dtype=self.torch_dtype,
|
||||
)
|
||||
self.models.append(ref_model)
|
||||
|
||||
# 获取向量维度
|
||||
self.vector_dim = int(getattr(self.model.base_model.config, "hidden_size"))
|
||||
self.vector_dim = int(getattr(ref_model.base_model.config, "hidden_size"))
|
||||
logger.info(f"向量维度: {self.vector_dim}")
|
||||
logger.info("嵌入模型加载成功")
|
||||
except Exception as e:
|
||||
@ -158,33 +206,103 @@ class MultimodalRetrievalLocal:
|
||||
logger.error(f"元数据加载失败: {str(e)}")
|
||||
|
||||
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
||||
"""编码文本为向量(使用 OpsMMEmbeddingV1)"""
|
||||
"""编码文本为向量(使用 OpsMMEmbeddingV1,多GPU并行)"""
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
if len(text) == 0:
|
||||
return np.zeros((0, self.vector_dim))
|
||||
|
||||
# 单条或单模型直接处理
|
||||
if len(self.models) == 1 or len(text) == 1:
|
||||
with torch.inference_mode():
|
||||
emb = self.model.get_text_embeddings(texts=text)
|
||||
text_embeddings = emb.detach().float().cpu().numpy()
|
||||
# emb 已经做过 L2 归一化,这里保持一致
|
||||
return text_embeddings[0] if len(text) == 1 else text_embeddings
|
||||
emb = self.models[0].get_text_embeddings(texts=text)
|
||||
arr = emb.detach().float().cpu().numpy()
|
||||
return arr[0] if len(text) == 1 else arr
|
||||
|
||||
# 多GPU并行:按设备平均切分
|
||||
shard_count = len(self.models)
|
||||
shards: List[List[str]] = []
|
||||
for i in range(shard_count):
|
||||
shards.append([])
|
||||
for idx, t in enumerate(text):
|
||||
shards[idx % shard_count].append(t)
|
||||
|
||||
results: List[np.ndarray] = [None] * shard_count # type: ignore
|
||||
|
||||
def run_shard(model: OpsMMEmbeddingV1, shard_texts: List[str], shard_idx: int):
|
||||
if len(shard_texts) == 0:
|
||||
results[shard_idx] = np.zeros((0, self.vector_dim))
|
||||
return
|
||||
with torch.inference_mode():
|
||||
emb = model.get_text_embeddings(texts=shard_texts)
|
||||
results[shard_idx] = emb.detach().float().cpu().numpy()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=shard_count) as ex:
|
||||
for i, (model, shard_texts) in enumerate(zip(self.models, shards)):
|
||||
ex.submit(run_shard, model, shard_texts, i)
|
||||
|
||||
# 还原原始顺序(按 round-robin 拼接)
|
||||
merged: List[np.ndarray] = []
|
||||
max_len = max(len(s) for s in shards)
|
||||
for k in range(max_len):
|
||||
for i in range(shard_count):
|
||||
if k < len(shards[i]) and results[i].shape[0] > 0:
|
||||
merged.append(results[i][k:k+1])
|
||||
if len(merged) == 0:
|
||||
return np.zeros((0, self.vector_dim))
|
||||
return np.vstack(merged)
|
||||
|
||||
def encode_image(self, image: Union[Image.Image, List[Image.Image]]) -> np.ndarray:
|
||||
"""编码图像为向量(使用 OpsMMEmbeddingV1)"""
|
||||
"""编码图像为向量(使用 OpsMMEmbeddingV1,多GPU并行)"""
|
||||
try:
|
||||
# 规范为列表
|
||||
images: List[Image.Image]
|
||||
if isinstance(image, Image.Image):
|
||||
images = [image]
|
||||
images: List[Image.Image] = [image]
|
||||
else:
|
||||
images = image
|
||||
if not images:
|
||||
logger.error("encode_image: 图像列表为空")
|
||||
return np.zeros((0, self.vector_dim))
|
||||
|
||||
# 强制为 RGB
|
||||
rgb_images = [img.convert('RGB') if img.mode != 'RGB' else img for img in images]
|
||||
|
||||
# 单张或单模型直接处理
|
||||
if len(self.models) == 1 or len(rgb_images) == 1:
|
||||
with torch.inference_mode():
|
||||
emb = self.model.get_image_embeddings(images=rgb_images)
|
||||
image_embeddings = emb.detach().float().cpu().numpy()
|
||||
return image_embeddings
|
||||
emb = self.models[0].get_image_embeddings(images=rgb_images)
|
||||
return emb.detach().float().cpu().numpy()
|
||||
|
||||
# 多GPU并行:按设备平均切分
|
||||
shard_count = len(self.models)
|
||||
shards: List[List[Image.Image]] = [[] for _ in range(shard_count)]
|
||||
for idx, img in enumerate(rgb_images):
|
||||
shards[idx % shard_count].append(img)
|
||||
|
||||
results: List[np.ndarray] = [None] * shard_count # type: ignore
|
||||
|
||||
def run_shard(model: OpsMMEmbeddingV1, shard_imgs: List[Image.Image], shard_idx: int):
|
||||
if len(shard_imgs) == 0:
|
||||
results[shard_idx] = np.zeros((0, self.vector_dim))
|
||||
return
|
||||
with torch.inference_mode():
|
||||
emb = model.get_image_embeddings(images=shard_imgs)
|
||||
results[shard_idx] = emb.detach().float().cpu().numpy()
|
||||
|
||||
with ThreadPoolExecutor(max_workers=shard_count) as ex:
|
||||
for i, (model, shard_imgs) in enumerate(zip(self.models, shards)):
|
||||
ex.submit(run_shard, model, shard_imgs, i)
|
||||
|
||||
# 还原原始顺序(按 round-robin 拼接)
|
||||
merged: List[np.ndarray] = []
|
||||
max_len = max(len(s) for s in shards)
|
||||
for k in range(max_len):
|
||||
for i in range(shard_count):
|
||||
if k < len(shards[i]) and results[i].shape[0] > 0:
|
||||
merged.append(results[i][k:k+1])
|
||||
if len(merged) == 0:
|
||||
return np.zeros((0, self.vector_dim))
|
||||
return np.vstack(merged)
|
||||
except Exception as e:
|
||||
logger.error(f"encode_image: 异常: {str(e)}")
|
||||
return np.zeros((0, self.vector_dim))
|
||||
|
||||
@ -18,17 +18,36 @@ class OpsMMEmbeddingV1(nn.Module):
|
||||
device: str = "cuda",
|
||||
max_length: Optional[int] = None,
|
||||
attn_implementation: Optional[str] = None,
|
||||
device_map: Optional[str] = None,
|
||||
load_in_4bit: bool = False,
|
||||
load_in_8bit: bool = False,
|
||||
torch_dtype: Optional[torch.dtype] = torch.bfloat16,
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
self.default_instruction = "You are a helpful assistant."
|
||||
self.base_model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
from transformers import AutoModelForImageTextToText
|
||||
load_kwargs = dict(
|
||||
torch_dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
attn_implementation=attn_implementation,
|
||||
).to(self.device)
|
||||
)
|
||||
# Quantization options (requires bitsandbytes for 4/8-bit)
|
||||
if load_in_4bit:
|
||||
load_kwargs["load_in_4bit"] = True
|
||||
if load_in_8bit:
|
||||
load_kwargs["load_in_8bit"] = True
|
||||
if device_map is not None:
|
||||
load_kwargs["device_map"] = device_map
|
||||
|
||||
self.base_model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_name,
|
||||
**load_kwargs,
|
||||
)
|
||||
# Only move to a single device when not using tensor-parallel sharding
|
||||
if device_map is None:
|
||||
self.base_model = self.base_model.to(self.device)
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28)
|
||||
self.processor.tokenizer.padding_side = "left"
|
||||
|
||||
11
run_server.sh
Normal file
11
run_server.sh
Normal file
@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Select GPUs 0 and 1 for tensor-parallel sharding
|
||||
export CUDA_VISIBLE_DEVICES=0,1
|
||||
|
||||
# Unbuffered stdout for real-time logs
|
||||
export PYTHONUNBUFFERED=1
|
||||
|
||||
# Start the local web app
|
||||
exec python3 web_app_local.py "$@"
|
||||
@ -824,7 +824,7 @@
|
||||
updateDataStats();
|
||||
}
|
||||
|
||||
// 批量上传图片
|
||||
// 批量上传图片(调用批量API以利用多卡并行)
|
||||
async function uploadBatchImages(files) {
|
||||
try {
|
||||
const progressDiv = document.getElementById('imageUploadProgress');
|
||||
@ -833,36 +833,26 @@
|
||||
|
||||
progressDiv.style.display = 'block';
|
||||
progressText.textContent = `0/${files.length}`;
|
||||
progressBar.style.width = '0%';
|
||||
progressBar.style.width = '10%';
|
||||
|
||||
showAlert('info', `正在上传${files.length}张图片...`);
|
||||
showAlert('info', `正在批量上传 ${files.length} 张图片...`);
|
||||
|
||||
let successCount = 0;
|
||||
|
||||
for (let i = 0; i < files.length; i++) {
|
||||
const formData = new FormData();
|
||||
formData.append('image', files[i]);
|
||||
for (let i = 0; i < files.length; i++) {
|
||||
formData.append('images', files[i]);
|
||||
}
|
||||
|
||||
const response = await fetch('/api/add_image', {
|
||||
const response = await fetch('/api/add_images_batch', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
if (data.success) {
|
||||
successCount++;
|
||||
} else {
|
||||
console.error(`图片 ${files[i].name} 上传失败: ${data.error}`);
|
||||
if (!data.success) {
|
||||
throw new Error(data.error || '批量上传失败');
|
||||
}
|
||||
|
||||
// 更新进度
|
||||
const progress = Math.round(((i + 1) / files.length) * 100);
|
||||
progressBar.style.width = `${progress}%`;
|
||||
progressText.textContent = `${i + 1}/${files.length}`;
|
||||
}
|
||||
|
||||
showAlert('success', `成功上传 ${successCount}/${files.length} 张图片`);
|
||||
// 自动保存索引
|
||||
progressBar.style.width = '100%';
|
||||
progressText.textContent = `${files.length}/${files.length}`;
|
||||
showAlert('success', data.message || `成功上传 ${files.length} 张图片`);
|
||||
await autoSaveIndex();
|
||||
updateDataStats();
|
||||
} catch (error) {
|
||||
@ -870,40 +860,29 @@
|
||||
} finally {
|
||||
setTimeout(() => {
|
||||
document.getElementById('imageUploadProgress').style.display = 'none';
|
||||
}, 2000);
|
||||
}, 1500);
|
||||
}
|
||||
// 旧代码已删除
|
||||
// 旧代码已删除
|
||||
}
|
||||
|
||||
// 批量上传文本
|
||||
// 批量上传文本(调用批量API以利用多卡并行)
|
||||
async function uploadBatchTexts(texts) {
|
||||
try {
|
||||
showAlert('info', `正在上传${texts.length}条文本...`);
|
||||
|
||||
for (let i = 0; i < texts.length; i++) {
|
||||
const response = await fetch('/api/add_text', {
|
||||
showAlert('info', `正在批量上传 ${texts.length} 条文本...`);
|
||||
const response = await fetch('/api/add_texts_batch', {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: JSON.stringify({text: texts[i]})
|
||||
body: JSON.stringify({texts})
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
if (!data.success) {
|
||||
throw new Error(`第${i+1}条文本上传失败: ${data.error}`);
|
||||
throw new Error(data.error || '批量上传失败');
|
||||
}
|
||||
}
|
||||
|
||||
showAlert('success', `成功上传${texts.length}条文本`);
|
||||
// 自动保存索引
|
||||
showAlert('success', data.message || `成功上传 ${texts.length} 条文本`);
|
||||
await autoSaveIndex();
|
||||
updateDataStats();
|
||||
} catch (error) {
|
||||
showAlert('danger', `文本上传失败: ${error.message}`);
|
||||
}
|
||||
// 已替换为新的API调用
|
||||
// 旧代码已删除
|
||||
// 已删除
|
||||
}
|
||||
|
||||
// 自动保存索引函数
|
||||
|
||||
@ -94,7 +94,8 @@ def init_retrieval_system():
|
||||
retrieval_system = MultimodalRetrievalLocal(
|
||||
model_path=model_path,
|
||||
use_all_gpus=True,
|
||||
index_path=app.config['INDEX_PATH']
|
||||
index_path=app.config['INDEX_PATH'],
|
||||
shard_model_across_gpus=True
|
||||
)
|
||||
|
||||
logger.info("多模态检索系统初始化完成")
|
||||
@ -181,6 +182,97 @@ def add_text():
|
||||
# 清理临时文件
|
||||
file_handler.cleanup_all_temp_files()
|
||||
|
||||
@app.route('/api/add_images_batch', methods=['POST'])
|
||||
def add_images_batch():
|
||||
"""批量添加图像(多卡并行友好)"""
|
||||
try:
|
||||
# 读取多文件
|
||||
files = request.files.getlist('images')
|
||||
if not files:
|
||||
return jsonify({"success": False, "error": "没有上传文件"}), 400
|
||||
|
||||
retrieval = init_retrieval_system()
|
||||
|
||||
images = []
|
||||
metadatas = []
|
||||
image_paths = []
|
||||
|
||||
for file in files:
|
||||
if file and allowed_file(file.filename):
|
||||
image_data = file.read()
|
||||
filename = secure_filename(file.filename)
|
||||
image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
|
||||
with open(image_path, 'wb') as f:
|
||||
f.write(image_data)
|
||||
try:
|
||||
img = Image.open(BytesIO(image_data))
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
except Exception as e:
|
||||
return jsonify({"success": False, "error": f"图像格式不支持: {filename}: {str(e)}"}), 400
|
||||
images.append(img)
|
||||
metadatas.append({
|
||||
"filename": filename,
|
||||
"timestamp": time.time(),
|
||||
"source": "web_upload_batch",
|
||||
"size": len(image_data),
|
||||
"local_path": image_path
|
||||
})
|
||||
image_paths.append(image_path)
|
||||
else:
|
||||
return jsonify({"success": False, "error": f"不支持的文件类型: {file.filename}"}), 400
|
||||
|
||||
image_ids = retrieval.add_images(images, metadatas, image_paths)
|
||||
retrieval.save_index()
|
||||
|
||||
stats = retrieval.get_stats()
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"message": f"批量添加成功: {len(image_ids)}/{len(images)}",
|
||||
"image_ids": image_ids,
|
||||
"debug": {
|
||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||
"vector_dimension": stats.get("vector_dimension"),
|
||||
"total_vectors": stats.get("total_vectors"),
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"批量添加图像失败: {str(e)}")
|
||||
return jsonify({"success": False, "error": str(e)}), 500
|
||||
finally:
|
||||
file_handler.cleanup_all_temp_files()
|
||||
|
||||
@app.route('/api/add_texts_batch', methods=['POST'])
|
||||
def add_texts_batch():
|
||||
"""批量添加文本(多卡并行友好)"""
|
||||
try:
|
||||
data = request.json
|
||||
if not data or 'texts' not in data:
|
||||
return jsonify({"success": False, "error": "缺少 texts 数组"}), 400
|
||||
texts = [t for t in data.get('texts', []) if isinstance(t, str) and t.strip()]
|
||||
if not texts:
|
||||
return jsonify({"success": False, "error": "文本数组为空"}), 400
|
||||
|
||||
retrieval = init_retrieval_system()
|
||||
metadatas = [{"timestamp": time.time(), "source": "web_upload_batch"} for _ in texts]
|
||||
text_ids = retrieval.add_texts(texts, metadatas)
|
||||
retrieval.save_index()
|
||||
|
||||
stats = retrieval.get_stats()
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"message": f"批量添加成功: {len(text_ids)}/{len(texts)}",
|
||||
"text_ids": text_ids,
|
||||
"debug": {
|
||||
"server_time": datetime.now(timezone.utc).isoformat(),
|
||||
"vector_dimension": stats.get("vector_dimension"),
|
||||
"total_vectors": stats.get("total_vectors"),
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"批量添加文本失败: {str(e)}")
|
||||
return jsonify({"success": False, "error": str(e)}), 500
|
||||
|
||||
@app.route('/api/add_image', methods=['POST'])
|
||||
def add_image():
|
||||
"""添加图像"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user