527 lines
22 KiB
Python
527 lines
22 KiB
Python
"""Training simulation module - LLM training workload with PyTorch."""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
import subprocess
|
|
import shutil
|
|
import math
|
|
from datetime import datetime
|
|
from typing import Optional
|
|
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn
|
|
|
|
# NOTE: despite the name, this flag is True only when torch imports AND reports
# a usable CUDA device; a CPU-only torch install leaves it False.
try:
    import torch

    TORCH_AVAILABLE = bool(torch.cuda.is_available())
except ImportError:
    TORCH_AVAILABLE = False
|
|
|
|
|
|
class TrainingSim:
    """Benchmark an LLM-style training workload on the local CUDA GPUs.

    All settings are read from ``config["training"]``. Keys consulted by this
    class: ``model``, ``batch_size``, ``seq_length``, ``num_steps``,
    ``warmup_steps``, ``dtype`` (``fp32``/``fp16``/``bf16``), ``mode``,
    ``allow_fallback``, ``require_distributed``, ``timeout_sec``,
    ``min_tokens_per_sec``, ``max_step_jitter_pct`` and ``max_peak_memory_gb``.

    ``run()`` tries, in order: a torchrun-launched synthetic DDP job (multi-GPU
    only), a Hugging Face pretrained model, and an in-process synthetic
    Transformer.
    """

    # Prefix of the single stdout line on which rank 0 reports its metrics.
    _DDP_MARKER = "TRAINING_DDP_JSON="

    # Worker program written to a temp file and launched with torchrun, one
    # process per GPU. It is configured only through TRAIN_* environment
    # variables and reports metrics as JSON on a line starting with _DDP_MARKER.
    _DDP_WORKER_SCRIPT = r'''
import json
import math
import os
import time
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")

    global_batch = int(os.environ["TRAIN_BATCH_SIZE"])
    local_batch = max(1, global_batch // world_size)
    # Samples really processed per step. This differs from global_batch when
    # the requested batch is smaller than, or not a multiple of, world_size,
    # so throughput must be derived from this value.
    effective_batch = local_batch * world_size
    seq_length = int(os.environ["TRAIN_SEQ_LENGTH"])
    num_steps = int(os.environ["TRAIN_NUM_STEPS"])
    warmup_steps = int(os.environ.get("TRAIN_WARMUP_STEPS", "5"))
    total_steps = num_steps + warmup_steps
    dtype_name = os.environ.get("TRAIN_DTYPE", "bf16")
    dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}.get(dtype_name, torch.bfloat16)

    hidden_size = 4096
    num_layers = 6
    num_heads = 32
    vocab_size = 32000

    class SyntheticTransformer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, hidden_size)
            self.layers = torch.nn.ModuleList([
                torch.nn.TransformerEncoderLayer(
                    d_model=hidden_size,
                    nhead=num_heads,
                    dim_feedforward=hidden_size * 4,
                    batch_first=True,
                    dtype=dtype,
                ) for _ in range(num_layers)
            ])
            self.head = torch.nn.Linear(hidden_size, vocab_size, dtype=dtype)

        def forward(self, x):
            h = self.embed(x).to(dtype)
            for layer in self.layers:
                h = layer(h)
            return self.head(h)

    model = SyntheticTransformer().cuda()
    total_params = sum(p.numel() for p in model.parameters())
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    input_ids = torch.randint(0, vocab_size, (local_batch, seq_length), device="cuda")
    step_times = []
    last_loss = torch.tensor(float("nan"), device="cuda")
    torch.cuda.reset_peak_memory_stats(local_rank)

    for _ in range(total_steps):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.amp.autocast("cuda", dtype=dtype, enabled=dtype in (torch.float16, torch.bfloat16)):
            logits = model(input_ids)
            loss = torch.nn.functional.cross_entropy(logits.reshape(-1, vocab_size), input_ids.reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.synchronize()
        step_times.append(time.perf_counter() - t0)
        last_loss = loss.detach()

    peak_mem = torch.tensor(torch.cuda.max_memory_allocated(local_rank) / 1024**3, device="cuda")
    dist.all_reduce(peak_mem, op=dist.ReduceOp.MAX)
    finite = torch.tensor(1 if math.isfinite(float(last_loss.item())) else 0, device="cuda")
    dist.all_reduce(finite, op=dist.ReduceOp.MIN)

    if dist.get_rank() == 0:
        measured_steps = step_times[warmup_steps:] if len(step_times) > warmup_steps else step_times
        avg_step = sum(measured_steps) / len(measured_steps) if measured_steps else 0.0
        jitter = max(abs(v - avg_step) / avg_step * 100 for v in measured_steps) if avg_step else 0.0
        throughput = effective_batch * seq_length / avg_step if avg_step else 0.0
        print("TRAINING_DDP_JSON=" + json.dumps({
            "model": "synthetic_transformer_1.5b",
            "total_params_m": round(total_params / 1e6, 1),
            "num_layers": num_layers,
            "hidden_size": hidden_size,
            "gpu_count": world_size,
            "dtype": dtype_name,
            "batch_size": global_batch,
            "local_batch_size": local_batch,
            "effective_batch_size": effective_batch,
            "seq_length": seq_length,
            "num_steps": num_steps,
            "warmup_steps": warmup_steps,
            "total_steps": total_steps,
            "avg_step_time_ms": round(avg_step * 1000, 1),
            "throughput_tokens_per_sec": round(throughput, 0),
            "throughput_samples_per_sec": round(effective_batch / avg_step, 2) if avg_step else 0,
            "peak_memory_gb": round(float(peak_mem.item()), 2),
            "final_loss": round(float(last_loss.item()), 4),
            "step_jitter_pct": round(jitter, 2),
            "distributed_mode": "ddp",
            "loss_finite": bool(int(finite.item())),
        }), flush=True)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
'''

    def __init__(self, config: dict):
        """Store *config* and cache its ``"training"`` section (``{}`` if absent)."""
        self.config = config
        self.console = Console()
        self.train_cfg = config.get("training", {})

    def run(self) -> dict:
        """Run the training benchmark and return a metrics dict.

        Returns ``{"error": "pytorch_not_available"}`` when the module-level
        ``TORCH_AVAILABLE`` flag is false. Otherwise the result contains the
        measured metrics plus a boolean ``"passed"`` verdict.
        """
        if not TORCH_AVAILABLE:
            self.console.print("[yellow]PyTorch not available - skipping training simulation[/yellow]")
            return {"error": "pytorch_not_available"}

        cfg = self.train_cfg
        gpu_count = torch.cuda.device_count()
        model_name = cfg.get("model", "gpt2")
        batch_size = cfg.get("batch_size", 8)
        seq_length = cfg.get("seq_length", 2048)
        num_steps = cfg.get("num_steps", 50)
        warmup_steps = int(cfg.get("warmup_steps", 5))
        dtype_str = cfg.get("dtype", "bf16")

        dtype_map = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }
        # Unknown dtype names silently fall back to bf16.
        dtype = dtype_map.get(dtype_str, torch.bfloat16)

        self.console.print("[cyan]Training Simulation[/cyan]")
        self.console.print(
            f" Model: {model_name} | Batch: {batch_size} | Seq: {seq_length} | "
            f"DType: {dtype_str} | Steps: {num_steps} | Warmup: {warmup_steps} | GPUs: {gpu_count}"
        )

        if cfg.get("mode", "ddp") == "ddp" and gpu_count > 1:
            ddp_result = self._run_synthetic_ddp(gpu_count, batch_size, seq_length, num_steps, dtype_str)
            # A failed DDP run is final unless the config explicitly allows fallback.
            if ddp_result.get("passed") or not cfg.get("allow_fallback", False):
                return ddp_result
            self.console.print("[yellow]DDP synthetic training failed, falling back to single-process synthetic path[/yellow]")

        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer
        except ImportError:
            self.console.print("[yellow]transformers not installed - using synthetic model[/yellow]")
            return self._run_synthetic(gpu_count, batch_size, seq_length, num_steps, dtype)

        model = None
        stage = "loading"
        try:
            self.console.print(f" Loading {model_name}...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=dtype,
                device_map="auto" if gpu_count > 1 else None,
            )
            if gpu_count <= 1:
                # device_map=None leaves the weights on the CPU; without this
                # move the "GPU benchmark" would silently train on the CPU.
                model = model.to("cuda")

            total_params = sum(p.numel() for p in model.parameters())
            self.console.print(f" Parameters: {total_params / 1e6:.1f}M")

            # The tokenizer is only needed for the range of valid token ids.
            vocab_size = AutoTokenizer.from_pretrained(model_name).vocab_size

            stage = "training"
            total_steps = num_steps + warmup_steps
            step_times, mem_usage, final_loss = self._train_pretrained(
                model, vocab_size, batch_size, seq_length, total_steps, gpu_count, dtype
            )
            header = {
                "model": model_name,
                "total_params_m": round(total_params / 1e6, 1),
                "gpu_count": gpu_count,
                "dtype": dtype_str,
                "batch_size": batch_size,
                "seq_length": seq_length,
                "num_steps": num_steps,
                "warmup_steps": warmup_steps,
                "total_steps": total_steps,
            }
            return self._local_result(
                header, "device_map", step_times, mem_usage, final_loss,
                warmup_steps, batch_size, seq_length,
            )
        except Exception as e:
            # Deliberately broad: any failure on the pretrained path degrades to
            # the synthetic benchmark instead of aborting the whole run.
            self.console.print(f"[yellow]Model {stage} failed: {e}[/yellow]")

        # Reached only after a failure. The fallback is started outside the
        # except block so the traceback (which pins the failed model's tensors)
        # is gone, and the cache is emptied before the synthetic model allocates.
        model = None
        torch.cuda.empty_cache()
        return self._run_synthetic(gpu_count, batch_size, seq_length, num_steps, dtype)

    def _train_pretrained(self, model, vocab_size: int, batch_size: int, seq_length: int,
                          total_steps: int, gpu_count: int, dtype) -> tuple:
        """Train *model* on one fixed random batch for *total_steps* steps.

        Returns ``(step_times, mem_usage, final_loss)`` as produced by
        ``_timed_steps``.
        """
        # Inputs are moved to the device once; the batch never changes.
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).to(model.device)
        attention_mask = torch.ones_like(input_ids)

        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
        # No GradScaler here: the weights themselves are loaded in `dtype`, and
        # GradScaler cannot unscale fp16 gradients of fp16 parameters.
        use_autocast = dtype in (torch.float16, torch.bfloat16)

        def train_step():
            with torch.amp.autocast("cuda", dtype=dtype if use_autocast else None, enabled=use_autocast):
                loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            return loss

        return self._timed_steps("Training steps...", total_steps, gpu_count, train_step)

    def _timed_steps(self, description: str, total_steps: int, gpu_count: int, step_fn) -> tuple:
        """Call *step_fn* ``total_steps`` times under a progress bar, timing each call.

        *step_fn* performs one full optimisation step and returns its loss
        tensor. Returns ``(step_times, mem_usage, final_loss)`` where
        ``step_times`` are wall-clock seconds per step, ``mem_usage`` the
        per-step peak allocation in GiB, and ``final_loss`` the last loss as a
        float (NaN when no step ran).
        """
        step_times = []
        mem_usage = []
        loss = None

        with Progress(
            SpinnerColumn(), TextColumn("[progress.description]{task.description}"),
            BarColumn(), TextColumn("{task.completed}/{task.total}"),
            TimeElapsedColumn(), console=self.console,
        ) as progress:
            task = progress.add_task(description, total=total_steps)

            for _ in range(total_steps):
                # Synchronise on both sides so the timer covers the GPU work,
                # not just the asynchronous kernel launches.
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                loss = step_fn()
                torch.cuda.synchronize()
                step_times.append(time.perf_counter() - t0)

                mem_usage.append(self._peak_memory_gb(gpu_count))
                progress.advance(task)

        final_loss = float(loss.item()) if loss is not None else float("nan")
        return step_times, mem_usage, final_loss

    @staticmethod
    def _peak_memory_gb(gpu_count: int) -> float:
        """Return the largest per-device peak allocation in GiB, then reset the counters."""
        peak = max((torch.cuda.max_memory_allocated(i) for i in range(gpu_count)), default=0) / 1024**3
        for i in range(gpu_count):
            torch.cuda.reset_peak_memory_stats(i)
        return peak

    @staticmethod
    def _timing_summary(step_times: list, warmup_steps: int, tokens_per_step: float) -> tuple:
        """Return ``(avg_step_seconds, tokens_per_second, jitter_pct)``.

        The first *warmup_steps* entries are discarded unless that would leave
        nothing, in which case every step is used. With no steps at all the
        result is ``(0.0, 0.0, 0.0)`` rather than a division by zero.
        """
        measured = step_times[warmup_steps:] if len(step_times) > warmup_steps else step_times
        if not measured:
            return 0.0, 0.0, 0.0
        avg_step_time = sum(measured) / len(measured)
        throughput = tokens_per_step / avg_step_time if avg_step_time else 0.0
        return avg_step_time, throughput, TrainingSim._jitter_pct(measured)

    def _local_result(self, header: dict, distributed_mode: str, step_times: list, mem_usage: list,
                      final_loss: float, warmup_steps: int, batch_size: int, seq_length: int) -> dict:
        """Build the result dict for an in-process (non-DDP) run.

        *header* supplies the descriptive leading fields; the measured metrics,
        verdict and timestamp are appended after it.
        """
        avg_step_time, throughput, jitter = self._timing_summary(
            step_times, warmup_steps, batch_size * seq_length
        )
        peak_mem = round(max(mem_usage) if mem_usage else 0, 2)
        require_distributed = self.train_cfg.get("require_distributed", True)

        passed = self._acceptance_pass(throughput, jitter, peak_mem, final_loss)
        if require_distributed:
            # In-process paths never satisfy a distributed-training requirement.
            passed = False

        return {
            **header,
            "avg_step_time_ms": round(avg_step_time * 1000, 1),
            "throughput_tokens_per_sec": round(throughput, 0),
            "throughput_samples_per_sec": round(batch_size / avg_step_time, 2) if avg_step_time else 0,
            "peak_memory_gb": peak_mem,
            "final_loss": round(final_loss, 4),
            "step_jitter_pct": round(jitter, 2),
            "distributed_mode": distributed_mode,
            "loss_finite": math.isfinite(final_loss),
            "passed": passed,
            "acceptance_gap": "8-GPU DDP was not used" if require_distributed else "",
            "timestamp": datetime.now().isoformat(),
        }

    @staticmethod
    def _ddp_failure(gpu_count: int, error: str) -> dict:
        """Return the failure result reported when the DDP job could not produce metrics."""
        return {
            "model": "synthetic_transformer_1.5b",
            "gpu_count": gpu_count,
            "distributed_mode": "ddp",
            "passed": False,
            "error": error,
            "timestamp": datetime.now().isoformat(),
        }

    @staticmethod
    def _extract_ddp_payload(output: str) -> Optional[dict]:
        """Parse the worker's metrics line out of its combined output.

        The last line containing the marker wins. Returns ``None`` when no
        marker is present or when the text after it is not a JSON object.
        """
        marker = TrainingSim._DDP_MARKER
        payload = None
        for line in output.splitlines():
            if marker in line:
                payload = line.split(marker, 1)[1].strip()
        if not payload:
            return None
        try:
            parsed = json.loads(payload)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
        return parsed if isinstance(parsed, dict) else None

    def _run_synthetic_ddp(self, gpu_count: int, batch_size: int, seq_length: int,
                           num_steps: int, dtype_str: str) -> dict:
        """Run the 1.5B synthetic Transformer with one process per GPU.

        Launches ``_DDP_WORKER_SCRIPT`` through ``torchrun`` and returns the
        metrics reported by rank 0 plus ``"passed"`` and ``"timestamp"``. When
        torchrun is missing, cannot be launched, times out, exits non-zero or
        prints no usable metrics, a failure dict with ``"error"`` is returned.
        """
        # Prefer the torchrun that sits next to the running interpreter.
        torchrun = os.path.join(os.path.dirname(sys.executable), "torchrun")
        if not os.path.isfile(torchrun):
            torchrun = shutil.which("torchrun") or ""
        if not torchrun:
            return self._ddp_failure(gpu_count, "torchrun not found")

        with tempfile.NamedTemporaryFile("w", suffix="_training_ddp.py", delete=False) as tmp:
            tmp.write(self._DDP_WORKER_SCRIPT)
        script_path = tmp.name

        env = {
            **os.environ,
            "TRAIN_BATCH_SIZE": str(batch_size),
            "TRAIN_SEQ_LENGTH": str(seq_length),
            "TRAIN_NUM_STEPS": str(num_steps),
            "TRAIN_WARMUP_STEPS": str(int(self.train_cfg.get("warmup_steps", 5))),
            "TRAIN_DTYPE": dtype_str,
            "NCCL_DEBUG": os.environ.get("NCCL_DEBUG", "WARN"),
        }
        cmd = [torchrun, f"--nproc_per_node={gpu_count}", script_path]
        self.console.print(f" Running synthetic 1.5B DDP via torchrun ({gpu_count} processes)...")
        try:
            timeout = int(self.train_cfg.get("timeout_sec", max(600, num_steps * 180)))
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, env=env)
        except subprocess.TimeoutExpired:
            return self._ddp_failure(gpu_count, "training_ddp_timeout")
        except OSError as exc:
            return self._ddp_failure(gpu_count, f"torchrun launch failed: {exc}")
        finally:
            # Best-effort cleanup; a leftover temp script is harmless.
            try:
                os.unlink(script_path)
            except OSError:
                pass

        result = self._extract_ddp_payload(r.stdout + "\n" + r.stderr)
        if r.returncode != 0 or result is None:
            return self._ddp_failure(gpu_count, (r.stderr or r.stdout or "training_ddp_failed")[-1000:])

        loss_value = float(result.get("final_loss", "nan"))
        # Besides the metric thresholds, every rank must report a finite loss
        # and the job must have used all requested GPUs.
        passed = self._acceptance_pass(
            float(result.get("throughput_tokens_per_sec", 0)),
            float(result.get("step_jitter_pct", 999)),
            float(result.get("peak_memory_gb", 999)),
            loss_value,
        ) and bool(result.get("loss_finite", False)) and result.get("gpu_count") == gpu_count
        result.update({
            "passed": passed,
            "timestamp": datetime.now().isoformat(),
        })
        return result

    def _run_synthetic(self, gpu_count, batch_size, seq_length, num_steps, dtype) -> dict:
        """Benchmark an in-process synthetic Transformer.

        Uses ``torch.nn.DataParallel`` when more than one GPU is visible,
        otherwise a single GPU. *dtype* is the torch dtype of the encoder
        layers and output head.
        """
        self.console.print(" Running synthetic training benchmark...")

        hidden_size = 4096
        num_layers = 6
        num_heads = 32
        vocab_size = 32000

        class SyntheticTransformer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.embed = torch.nn.Embedding(vocab_size, hidden_size)
                self.layers = torch.nn.ModuleList([
                    torch.nn.TransformerEncoderLayer(
                        d_model=hidden_size, nhead=num_heads,
                        dim_feedforward=hidden_size * 4,
                        batch_first=True,
                        dtype=dtype,
                    ) for _ in range(num_layers)
                ])
                self.head = torch.nn.Linear(hidden_size, vocab_size, dtype=dtype)

            def forward(self, x):
                h = self.embed(x).to(dtype)
                for layer in self.layers:
                    h = layer(h)
                return self.head(h)

        model = SyntheticTransformer()
        total_params = sum(p.numel() for p in model.parameters())
        self.console.print(f" Synthetic params: {total_params / 1e6:.1f}M")

        if gpu_count > 1:
            model = torch.nn.DataParallel(model).cuda()
            distributed_mode = "data_parallel"
        else:
            model = model.cuda()
            distributed_mode = "single_gpu"
        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

        input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda()

        warmup_steps = int(self.train_cfg.get("warmup_steps", 5))
        total_steps = num_steps + warmup_steps

        def train_step():
            logits = model(input_ids)
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, vocab_size), input_ids.view(-1)
            )
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            return loss

        step_times, mem_usage, final_loss = self._timed_steps(
            "Synthetic training...", total_steps, gpu_count, train_step
        )

        header = {
            "model": "synthetic_transformer",
            "total_params_m": round(total_params / 1e6, 1),
            "num_layers": num_layers,
            "hidden_size": hidden_size,
            "gpu_count": gpu_count,
            "dtype": str(dtype).replace("torch.", ""),
            "batch_size": batch_size,
            "seq_length": seq_length,
            "num_steps": num_steps,
            "warmup_steps": warmup_steps,
            "total_steps": total_steps,
        }
        return self._local_result(
            header, distributed_mode, step_times, mem_usage, final_loss,
            warmup_steps, batch_size, seq_length,
        )

    @staticmethod
    def _jitter_pct(step_times: list[float]) -> float:
        """Return the largest deviation from the mean step time, as a percentage of the mean.

        Returns 0.0 for an empty list or a zero mean.
        """
        if not step_times:
            return 0.0
        mean = sum(step_times) / len(step_times)
        return max(abs(v - mean) / mean * 100 for v in step_times) if mean else 0.0

    def _acceptance_pass(self, throughput: float, jitter: float, peak_mem: float, loss_value: float) -> bool:
        """Return True when every metric meets its configured threshold.

        Thresholds come from ``train_cfg``: ``min_tokens_per_sec`` (default
        45000), ``max_step_jitter_pct`` (default 3) and ``max_peak_memory_gb``
        (default 70). The loss must also be finite.
        """
        return (
            throughput >= float(self.train_cfg.get("min_tokens_per_sec", 45000))
            and jitter <= float(self.train_cfg.get("max_step_jitter_pct", 3))
            and peak_mem <= float(self.train_cfg.get("max_peak_memory_gb", 70))
            and math.isfinite(loss_value)
        )

    @staticmethod
    def print_results(results: dict, console: Optional["Console"] = None):
        """Print *results* as a two-column table, or just the error if one is present.

        A fresh ``Console`` is created when *console* is not given. Missing
        metrics are shown as ``N/A``.
        """
        c = console or Console()
        if "error" in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return

        c.print("\n[bold cyan]Training Simulation Results[/bold cyan]")

        table = Table(box=None, padding=(0, 1))
        table.add_column("Metric", style="bold")
        table.add_column("Value")

        metrics = [
            ("Model", results.get("model", "N/A")),
            ("Parameters", f"{results.get('total_params_m', 'N/A')}M"),
            ("GPU Count", str(results.get("gpu_count", "N/A"))),
            ("DType", results.get("dtype", "N/A")),
            ("Batch Size", str(results.get("batch_size", "N/A"))),
            ("Seq Length", str(results.get("seq_length", "N/A"))),
            ("Steps", str(results.get("num_steps", "N/A"))),
            ("Warmup Steps", str(results.get("warmup_steps", "N/A"))),
            ("Avg Step Time", f"{results.get('avg_step_time_ms', 'N/A')} ms"),
            ("Throughput", f"{results.get('throughput_tokens_per_sec', 'N/A')} tokens/s"),
            ("Samples/sec", f"{results.get('throughput_samples_per_sec', 'N/A')}"),
            ("Peak Memory", f"{results.get('peak_memory_gb', 'N/A')} GB"),
            ("Final Loss", str(results.get("final_loss", "N/A"))),
            ("Step Jitter", f"{results.get('step_jitter_pct', 'N/A')}%"),
            ("Distributed Mode", results.get("distributed_mode", "N/A")),
            ("Verdict", "PASS" if results.get("passed") else "FAIL"),
        ]
        for label, val in metrics:
            table.add_row(label, str(val))

        c.print(table)
|