test_gpu_scripts/modules/training_sim.py

527 lines
22 KiB
Python

"""Training simulation module - LLM training workload with PyTorch."""
import gc
import json
import math
import os
import shutil
import subprocess
import sys
import tempfile
import time
from datetime import datetime
from typing import Optional

from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.table import Table
# PyTorch is optional: the training simulation only runs when torch can be
# imported AND torch.cuda.is_available() reports a usable CUDA device.
try:
    import torch

    TORCH_AVAILABLE = bool(torch.cuda.is_available())
except ImportError:
    TORCH_AVAILABLE = False
class TrainingSim:
    """LLM training-throughput benchmark driven by the ``training`` config section.

    Execution paths, tried in order:

    1. ``mode == "ddp"`` (the default) with more than one GPU: the synthetic
       1.5B Transformer trained with one DDP process per GPU via torchrun.
       Its result is returned unless it did not pass and ``allow_fallback``
       is set.
    2. A Hugging Face causal LM named by ``training.model``, trained in-process.
    3. A synthetic Transformer trained in-process (single GPU or
       ``DataParallel``), used when ``transformers`` is missing or path 2 fails.

    Every path returns a flat result dict with a boolean ``passed`` verdict.
    """

    # Seconds torchrun gets to shut down its workers after being asked to exit
    # (on timeout or interrupt) before it is killed outright.
    _STOP_GRACE_SEC = 60.0

    # Worker executed by torchrun: one process per GPU, rank 0 prints a single
    # TRAINING_DDP_JSON=<json> line that _run_synthetic_ddp parses.
    _DDP_WORKER_SCRIPT = r'''
import json
import math
import os
import time
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    requested_batch = int(os.environ["TRAIN_BATCH_SIZE"])
    local_batch = max(1, requested_batch // world_size)
    # Samples actually processed per step across all ranks; differs from the
    # requested batch when it is not a multiple of world_size.
    global_batch = local_batch * world_size
    seq_length = int(os.environ["TRAIN_SEQ_LENGTH"])
    num_steps = int(os.environ["TRAIN_NUM_STEPS"])
    warmup_steps = int(os.environ.get("TRAIN_WARMUP_STEPS", "5"))
    total_steps = num_steps + warmup_steps
    dtype_name = os.environ.get("TRAIN_DTYPE", "bf16")
    dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}.get(dtype_name, torch.bfloat16)
    hidden_size = 4096
    num_layers = 6
    num_heads = 32
    vocab_size = 32000

    class SyntheticTransformer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, hidden_size)
            self.layers = torch.nn.ModuleList([
                torch.nn.TransformerEncoderLayer(
                    d_model=hidden_size,
                    nhead=num_heads,
                    dim_feedforward=hidden_size * 4,
                    batch_first=True,
                    dtype=dtype,
                ) for _ in range(num_layers)
            ])
            self.head = torch.nn.Linear(hidden_size, vocab_size, dtype=dtype)

        def forward(self, x):
            h = self.embed(x).to(dtype)
            for layer in self.layers:
                h = layer(h)
            return self.head(h)

    model = SyntheticTransformer().cuda()
    total_params = sum(p.numel() for p in model.parameters())
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    input_ids = torch.randint(0, vocab_size, (local_batch, seq_length), device="cuda")
    step_times = []
    last_loss = torch.tensor(float("nan"), device="cuda")
    torch.cuda.reset_peak_memory_stats(local_rank)
    for _ in range(total_steps):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.amp.autocast("cuda", dtype=dtype, enabled=dtype in (torch.float16, torch.bfloat16)):
            logits = model(input_ids)
            loss = torch.nn.functional.cross_entropy(logits.reshape(-1, vocab_size), input_ids.reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.synchronize()
        step_times.append(time.perf_counter() - t0)
        last_loss = loss.detach()
    peak_mem = torch.tensor(torch.cuda.max_memory_allocated(local_rank) / 1024**3, device="cuda")
    dist.all_reduce(peak_mem, op=dist.ReduceOp.MAX)
    finite = torch.tensor(1 if math.isfinite(float(last_loss.item())) else 0, device="cuda")
    dist.all_reduce(finite, op=dist.ReduceOp.MIN)
    if dist.get_rank() == 0:
        measured_steps = step_times[warmup_steps:] if len(step_times) > warmup_steps else step_times
        avg_step = sum(measured_steps) / len(measured_steps)
        jitter = max(abs(v - avg_step) / avg_step * 100 for v in measured_steps) if avg_step else 0.0
        throughput = global_batch * seq_length / avg_step if avg_step else 0.0
        print("TRAINING_DDP_JSON=" + json.dumps({
            "model": "synthetic_transformer_1.5b",
            "total_params_m": round(total_params / 1e6, 1),
            "num_layers": num_layers,
            "hidden_size": hidden_size,
            "gpu_count": world_size,
            "dtype": dtype_name,
            "batch_size": global_batch,
            "requested_batch_size": requested_batch,
            "local_batch_size": local_batch,
            "seq_length": seq_length,
            "num_steps": num_steps,
            "warmup_steps": warmup_steps,
            "total_steps": total_steps,
            "avg_step_time_ms": round(avg_step * 1000, 1),
            "throughput_tokens_per_sec": round(throughput, 0),
            "throughput_samples_per_sec": round(global_batch / avg_step, 2) if avg_step else 0,
            "peak_memory_gb": round(float(peak_mem.item()), 2),
            "final_loss": round(float(last_loss.item()), 4),
            "step_jitter_pct": round(jitter, 2),
            "distributed_mode": "ddp",
            "loss_finite": bool(int(finite.item())),
        }), flush=True)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
'''

    def __init__(self, config: dict):
        """Keep the full config; only its ``training`` section is consulted."""
        self.config = config
        self.console = Console()
        self.train_cfg = config.get("training", {})

    def run(self) -> dict:
        """Run the training benchmark and return a flat result dict.

        Returns ``{"error": "pytorch_not_available"}`` when PyTorch/CUDA is
        unavailable; otherwise the result of the first applicable path (see the
        class docstring).
        """
        if not TORCH_AVAILABLE:
            self.console.print("[yellow]PyTorch not available - skipping training simulation[/yellow]")
            return {"error": "pytorch_not_available"}

        gpu_count = torch.cuda.device_count()
        model_name = self.train_cfg.get("model", "gpt2")
        batch_size = int(self.train_cfg.get("batch_size", 8))
        seq_length = int(self.train_cfg.get("seq_length", 2048))
        # At least one measured step: every path divides by the step count.
        num_steps = max(1, int(self.train_cfg.get("num_steps", 50)))
        warmup_steps = self._warmup_steps()

        dtype_map = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }
        dtype_str = str(self.train_cfg.get("dtype", "bf16")).lower()
        if dtype_str not in dtype_map:
            # Report the dtype actually used rather than the unrecognised name.
            self.console.print(f"[yellow]Unknown dtype {dtype_str!r} - using bf16[/yellow]")
            dtype_str = "bf16"
        dtype = dtype_map[dtype_str]

        self.console.print("[cyan]Training Simulation[/cyan]")
        self.console.print(f" Model: {model_name} | Batch: {batch_size} | Seq: {seq_length} | "
                           f"DType: {dtype_str} | Steps: {num_steps} | Warmup: {warmup_steps} | GPUs: {gpu_count}")

        if self.train_cfg.get("mode", "ddp") == "ddp" and gpu_count > 1:
            ddp_result = self._run_synthetic_ddp(gpu_count, batch_size, seq_length, num_steps, dtype_str)
            if ddp_result.get("passed") or not self.train_cfg.get("allow_fallback", False):
                return ddp_result
            self.console.print("[yellow]DDP synthetic training did not pass, falling back to single-process path[/yellow]")

        try:
            from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: F401  (availability probe)
        except ImportError:
            self.console.print("[yellow]transformers not installed - using synthetic model[/yellow]")
            return self._run_synthetic(gpu_count, batch_size, seq_length, num_steps, dtype)

        try:
            return self._run_hf(model_name, gpu_count, batch_size, seq_length,
                                num_steps, warmup_steps, dtype, dtype_str)
        except Exception as e:  # any HF failure (download, OOM, ...) degrades to synthetic
            self.console.print(f"[yellow]Model loading/training failed: {e}[/yellow]")
        # The exception (and its traceback, which references the failed model and
        # optimizer) is released once the except clause ends; collect it so the
        # synthetic run does not start with that memory still allocated.
        gc.collect()
        torch.cuda.empty_cache()
        return self._run_synthetic(gpu_count, batch_size, seq_length, num_steps, dtype)

    def _run_hf(self, model_name: str, gpu_count: int, batch_size: int, seq_length: int,
                num_steps: int, warmup_steps: int, dtype, dtype_str: str) -> dict:
        """Train a Hugging Face causal LM on random tokens and measure it.

        Any exception from ``transformers``/PyTorch propagates; ``run`` catches
        it and falls back to the synthetic model.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.console.print(f" Loading {model_name}...")
        multi_gpu = gpu_count > 1
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if multi_gpu else None,
        )
        if not multi_gpu:
            # Without a device_map the weights are loaded on the CPU; move them to
            # the GPU so the benchmark does not time CPU training.
            model = model.cuda()
        total_params = sum(p.numel() for p in model.parameters())
        self.console.print(f" Parameters: {total_params / 1e6:.1f}M")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Created once on model.device instead of being copied every step.
        input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, seq_length), device=model.device)
        attention_mask = torch.ones_like(input_ids)
        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
        # Weights are already in `dtype`; autocast is enabled only for the
        # half-precision types and no loss scaling is applied.
        use_autocast = dtype in (torch.float16, torch.bfloat16)

        total_steps = num_steps + warmup_steps
        step_times: list[float] = []
        mem_usage: list[float] = []
        self._reset_peak_memory(gpu_count)
        with self._progress() as progress:
            task = progress.add_task("Training steps...", total=total_steps)
            for _ in range(total_steps):
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                if use_autocast:
                    with torch.amp.autocast("cuda", dtype=dtype):
                        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
                else:
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                torch.cuda.synchronize()
                step_times.append(time.perf_counter() - t0)
                # With device_map="auto" the model spans several GPUs, so take the
                # peak over every visible device, not just the current one.
                mem_usage.append(self._peak_memory_gb(gpu_count))
                self._reset_peak_memory(gpu_count)
                progress.advance(task)

        measured = self._measured_steps(step_times, warmup_steps)
        avg_step_time = sum(measured) / len(measured)
        throughput = batch_size * seq_length / avg_step_time
        jitter = self._jitter_pct(measured)
        peak_mem = round(max(mem_usage, default=0), 2)
        final_loss = float(loss.item())
        require_distributed = self.train_cfg.get("require_distributed", True)
        passed = self._acceptance_pass(throughput, jitter, peak_mem, final_loss)
        if require_distributed:
            # In-process training never satisfies the DDP acceptance requirement.
            passed = False
        return {
            "model": model_name,
            "total_params_m": round(total_params / 1e6, 1),
            "gpu_count": gpu_count,
            "dtype": dtype_str,
            "batch_size": batch_size,
            "seq_length": seq_length,
            "num_steps": num_steps,
            "warmup_steps": warmup_steps,
            "total_steps": total_steps,
            "avg_step_time_ms": round(avg_step_time * 1000, 1),
            "throughput_tokens_per_sec": round(throughput, 0),
            "throughput_samples_per_sec": round(batch_size / avg_step_time, 2),
            "peak_memory_gb": peak_mem,
            "final_loss": round(final_loss, 4),
            "step_jitter_pct": round(jitter, 2),
            "distributed_mode": "device_map" if multi_gpu else "single_gpu",
            "loss_finite": math.isfinite(final_loss),
            "passed": passed,
            "acceptance_gap": "8-GPU DDP was not used" if require_distributed else "",
            "timestamp": datetime.now().isoformat(),
        }

    def _run_synthetic_ddp(self, gpu_count: int, batch_size: int, seq_length: int,
                           num_steps: int, dtype_str: str) -> dict:
        """Run the 1.5B synthetic Transformer with one torchrun process per GPU.

        Returns rank 0's ``TRAINING_DDP_JSON=`` payload plus the acceptance
        verdict. A missing torchrun, a timeout, a non-zero exit, or an
        unparsable payload yields a dict with ``passed=False`` and ``error``.
        """
        torchrun = os.path.join(os.path.dirname(sys.executable), "torchrun")
        if not os.path.isfile(torchrun):
            torchrun = shutil.which("torchrun") or ""
        if not torchrun:
            return self._ddp_failure(gpu_count, "torchrun not found")

        env = {
            **os.environ,
            "TRAIN_BATCH_SIZE": str(batch_size),
            "TRAIN_SEQ_LENGTH": str(seq_length),
            "TRAIN_NUM_STEPS": str(num_steps),
            "TRAIN_WARMUP_STEPS": str(self._warmup_steps()),
            "TRAIN_DTYPE": dtype_str,
            "NCCL_DEBUG": os.environ.get("NCCL_DEBUG", "WARN"),
        }
        timeout = int(self.train_cfg.get("timeout_sec", max(600, num_steps * 180)))

        tmp = tempfile.NamedTemporaryFile("w", suffix="_training_ddp.py", delete=False)
        try:
            with tmp:
                tmp.write(self._DDP_WORKER_SCRIPT)
            cmd = [torchrun, f"--nproc_per_node={gpu_count}", tmp.name]
            self.console.print(f" Running synthetic 1.5B DDP via torchrun ({gpu_count} processes)...")
            outcome = self._run_with_timeout(cmd, env, timeout)
        finally:
            try:
                os.unlink(tmp.name)
            except OSError:
                pass

        if outcome is None:
            return self._ddp_failure(gpu_count, "training_ddp_timeout")
        returncode, stdout, stderr = outcome

        marker = "TRAINING_DDP_JSON="
        payload = None
        for line in (stdout + "\n" + stderr).splitlines():
            if marker in line:
                payload = line.split(marker, 1)[1].strip()
        if returncode != 0 or not payload:
            return self._ddp_failure(gpu_count, (stderr or stdout or "training_ddp_failed")[-1000:])
        try:
            result = json.loads(payload)
        except json.JSONDecodeError as exc:
            return self._ddp_failure(gpu_count, f"training_ddp_bad_payload: {exc}")

        loss_value = float(result.get("final_loss", "nan"))
        passed = self._acceptance_pass(
            float(result.get("throughput_tokens_per_sec", 0)),
            float(result.get("step_jitter_pct", 999)),
            float(result.get("peak_memory_gb", 999)),
            loss_value,
        ) and bool(result.get("loss_finite", False)) and result.get("gpu_count") == gpu_count
        result.update({
            "passed": passed,
            "timestamp": datetime.now().isoformat(),
        })
        return result

    @staticmethod
    def _ddp_failure(gpu_count: int, error: str) -> dict:
        """Build the failed-result dict shared by every DDP error path."""
        return {
            "model": "synthetic_transformer_1.5b",
            "gpu_count": gpu_count,
            "distributed_mode": "ddp",
            "passed": False,
            "error": error,
            "timestamp": datetime.now().isoformat(),
        }

    @classmethod
    def _run_with_timeout(cls, cmd: list[str], env: dict, timeout: float,
                          grace_sec: Optional[float] = None) -> Optional[tuple[int, str, str]]:
        """Run ``cmd`` capturing text output.

        Returns ``(returncode, stdout, stderr)``, or ``None`` if ``timeout``
        seconds elapse. ``subprocess.run`` kills the child outright on timeout;
        here the child is first asked to exit via ``terminate()`` so a launcher
        such as torchrun gets a chance to stop the workers it spawned. It is
        killed only if still alive after ``grace_sec`` seconds.
        """
        if grace_sec is None:
            grace_sec = cls._STOP_GRACE_SEC
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env)
        try:
            stdout, stderr = proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            cls._stop_process(proc, grace_sec)
            return None
        except BaseException:
            # e.g. KeyboardInterrupt: do not leave the child running.
            cls._stop_process(proc, grace_sec)
            raise
        return proc.returncode, stdout, stderr

    @staticmethod
    def _stop_process(proc: subprocess.Popen, grace_sec: float) -> None:
        """Terminate ``proc``, escalating to ``kill()`` after ``grace_sec`` seconds."""
        try:
            proc.terminate()
        except OSError:  # the process already exited
            pass
        try:
            # Keep draining the pipes so the child cannot block on a full pipe
            # while shutting down; retrying communicate() after a timeout is
            # supported by subprocess.
            proc.communicate(timeout=grace_sec)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
            for stream in (proc.stdout, proc.stderr):
                if stream is not None:
                    stream.close()

    def _run_synthetic(self, gpu_count, batch_size, seq_length, num_steps, dtype) -> dict:
        """Train a synthetic Transformer in-process on random tokens.

        Uses ``torch.nn.DataParallel`` when ``gpu_count > 1``, otherwise one
        GPU. ``dtype`` is a ``torch.dtype`` applied to the layer weights.
        """
        self.console.print(" Running synthetic training benchmark...")
        num_steps = max(1, int(num_steps))
        hidden_size = 4096
        num_layers = 6
        num_heads = 32
        vocab_size = 32000

        class SyntheticTransformer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.embed = torch.nn.Embedding(vocab_size, hidden_size)
                self.layers = torch.nn.ModuleList([
                    torch.nn.TransformerEncoderLayer(
                        d_model=hidden_size, nhead=num_heads,
                        dim_feedforward=hidden_size * 4,
                        batch_first=True,
                        dtype=dtype,
                    ) for _ in range(num_layers)
                ])
                self.head = torch.nn.Linear(hidden_size, vocab_size, dtype=dtype)

            def forward(self, x):
                h = self.embed(x).to(dtype)
                for layer in self.layers:
                    h = layer(h)
                return self.head(h)

        model = SyntheticTransformer()
        total_params = sum(p.numel() for p in model.parameters())
        self.console.print(f" Synthetic params: {total_params / 1e6:.1f}M")
        if gpu_count > 1:
            model = torch.nn.DataParallel(model).cuda()
            distributed_mode = "data_parallel"
        else:
            model = model.cuda()
            distributed_mode = "single_gpu"
        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda()

        warmup_steps = self._warmup_steps()
        total_steps = num_steps + warmup_steps
        step_times: list[float] = []
        mem_usage: list[float] = []
        # Start from clean peak stats so earlier allocations (e.g. a failed
        # Hugging Face attempt) do not inflate the reported peak.
        self._reset_peak_memory(gpu_count)
        with self._progress() as progress:
            task = progress.add_task("Synthetic training...", total=total_steps)
            for _ in range(total_steps):
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                logits = model(input_ids)
                loss = torch.nn.functional.cross_entropy(
                    logits.reshape(-1, vocab_size), input_ids.reshape(-1)
                )
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                torch.cuda.synchronize()
                step_times.append(time.perf_counter() - t0)
                mem_usage.append(self._peak_memory_gb(gpu_count))
                self._reset_peak_memory(gpu_count)
                progress.advance(task)

        measured = self._measured_steps(step_times, warmup_steps)
        avg_step_time = sum(measured) / len(measured)
        throughput = batch_size * seq_length / avg_step_time
        jitter = self._jitter_pct(measured)
        peak_mem = round(max(mem_usage, default=0), 2)
        final_loss = float(loss.item())
        require_distributed = self.train_cfg.get("require_distributed", True)
        passed = self._acceptance_pass(throughput, jitter, peak_mem, final_loss)
        if require_distributed:
            # In-process training never satisfies the DDP acceptance requirement.
            passed = False
        # Report the same short dtype names ("bf16", ...) as the other paths.
        dtype_name = {
            torch.float32: "fp32",
            torch.float16: "fp16",
            torch.bfloat16: "bf16",
        }.get(dtype, str(dtype).replace("torch.", ""))
        return {
            "model": "synthetic_transformer",
            "total_params_m": round(total_params / 1e6, 1),
            "num_layers": num_layers,
            "hidden_size": hidden_size,
            "gpu_count": gpu_count,
            "dtype": dtype_name,
            "batch_size": batch_size,
            "seq_length": seq_length,
            "num_steps": num_steps,
            "warmup_steps": warmup_steps,
            "total_steps": total_steps,
            "avg_step_time_ms": round(avg_step_time * 1000, 1),
            "throughput_tokens_per_sec": round(throughput, 0),
            "throughput_samples_per_sec": round(batch_size / avg_step_time, 2),
            "peak_memory_gb": peak_mem,
            "final_loss": round(final_loss, 4),
            "step_jitter_pct": round(jitter, 2),
            "distributed_mode": distributed_mode,
            "loss_finite": math.isfinite(final_loss),
            "passed": passed,
            "acceptance_gap": "8-GPU DDP was not used" if require_distributed else "",
            "timestamp": datetime.now().isoformat(),
        }

    def _warmup_steps(self) -> int:
        """Return the configured warmup step count (default 5), clamped to >= 0."""
        return max(0, int(self.train_cfg.get("warmup_steps", 5)))

    def _progress(self) -> "Progress":
        """Build the rich progress bar shared by the in-process training loops."""
        return Progress(
            SpinnerColumn(), TextColumn("[progress.description]{task.description}"),
            BarColumn(), TextColumn("{task.completed}/{task.total}"),
            TimeElapsedColumn(), console=self.console,
        )

    @staticmethod
    def _reset_peak_memory(gpu_count: int) -> None:
        """Reset CUDA peak-memory statistics on devices ``0..gpu_count-1``."""
        for idx in range(gpu_count):
            torch.cuda.reset_peak_memory_stats(idx)

    @staticmethod
    def _peak_memory_gb(gpu_count: int) -> float:
        """Return the largest peak allocation (GiB) across devices ``0..gpu_count-1``."""
        peak = max((torch.cuda.max_memory_allocated(idx) for idx in range(gpu_count)), default=0)
        return peak / 1024**3

    @staticmethod
    def _measured_steps(step_times: list[float], warmup_steps: int) -> list[float]:
        """Drop the warmup steps, unless that would leave nothing to measure."""
        return step_times[warmup_steps:] if len(step_times) > warmup_steps else step_times

    @staticmethod
    def _jitter_pct(step_times: list[float]) -> float:
        """Return the largest deviation from the mean step time, as a percentage."""
        if not step_times:
            return 0.0
        mean = sum(step_times) / len(step_times)
        return max(abs(v - mean) / mean * 100 for v in step_times) if mean else 0.0

    def _acceptance_pass(self, throughput: float, jitter: float, peak_mem: float, loss_value: float) -> bool:
        """Check the measured metrics against the configured acceptance thresholds."""
        return (
            throughput >= float(self.train_cfg.get("min_tokens_per_sec", 45000))
            and jitter <= float(self.train_cfg.get("max_step_jitter_pct", 3))
            and peak_mem <= float(self.train_cfg.get("max_peak_memory_gb", 70))
            and math.isfinite(loss_value)
        )

    @staticmethod
    def print_results(results: dict, console: "Optional[Console]" = None):
        """Render a ``run`` result dict as a rich table.

        A result carrying an ``error`` key is printed as a single error line.
        Missing metrics are shown as ``N/A`` without a unit suffix.
        """
        c = console or Console()
        if "error" in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return

        def with_unit(key: str, unit: str) -> str:
            value = results.get(key)
            return "N/A" if value is None else f"{value}{unit}"

        c.print("\n[bold cyan]Training Simulation Results[/bold cyan]")
        table = Table(box=None, padding=(0, 1))
        table.add_column("Metric", style="bold")
        table.add_column("Value")
        metrics = [
            ("Model", results.get("model", "N/A")),
            ("Parameters", with_unit("total_params_m", "M")),
            ("GPU Count", str(results.get("gpu_count", "N/A"))),
            ("DType", results.get("dtype", "N/A")),
            ("Batch Size", str(results.get("batch_size", "N/A"))),
            ("Seq Length", str(results.get("seq_length", "N/A"))),
            ("Steps", str(results.get("num_steps", "N/A"))),
            ("Warmup Steps", str(results.get("warmup_steps", "N/A"))),
            ("Avg Step Time", with_unit("avg_step_time_ms", " ms")),
            ("Throughput", with_unit("throughput_tokens_per_sec", " tokens/s")),
            ("Samples/sec", with_unit("throughput_samples_per_sec", "")),
            ("Peak Memory", with_unit("peak_memory_gb", " GB")),
            ("Final Loss", str(results.get("final_loss", "N/A"))),
            ("Step Jitter", with_unit("step_jitter_pct", "%")),
            ("Distributed Mode", results.get("distributed_mode", "N/A")),
        ]
        if results.get("acceptance_gap"):
            metrics.append(("Acceptance Gap", results["acceptance_gap"]))
        metrics.append(("Verdict", "PASS" if results.get("passed") else "FAIL"))
        for label, val in metrics:
            table.add_row(label, str(val))
        c.print(table)