"""Training simulation module - LLM training workload with PyTorch."""

import gc
import json
import math
import os
import shutil
import subprocess
import sys
import tempfile
import time
from datetime import datetime
from typing import Optional

from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.table import Table

# True only when torch imports AND a CUDA device is visible; every benchmark
# path below issues CUDA calls unconditionally.
TORCH_AVAILABLE = False
try:
    import torch

    if torch.cuda.is_available():
        TORCH_AVAILABLE = True
except ImportError:
    pass


# Worker executed once per GPU by torchrun (see TrainingSim._run_synthetic_ddp).
# Rank 0 prints a single JSON line prefixed with TRAINING_DDP_JSON=.
_DDP_WORKER_SCRIPT = r'''
import json
import math
import os
import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")

    global_batch = int(os.environ["TRAIN_BATCH_SIZE"])
    local_batch = max(1, global_batch // world_size)
    # Samples actually processed per step across all ranks. This differs from
    # the requested global batch when it is not a multiple of world_size, so
    # throughput must be computed from it, not from global_batch.
    effective_batch = local_batch * world_size
    seq_length = int(os.environ["TRAIN_SEQ_LENGTH"])
    num_steps = int(os.environ["TRAIN_NUM_STEPS"])
    warmup_steps = int(os.environ.get("TRAIN_WARMUP_STEPS", "5"))
    total_steps = num_steps + warmup_steps
    dtype_name = os.environ.get("TRAIN_DTYPE", "bf16")
    dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}.get(dtype_name, torch.bfloat16)

    hidden_size = 4096
    num_layers = 6
    num_heads = 32
    vocab_size = 32000

    class SyntheticTransformer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, hidden_size)
            self.layers = torch.nn.ModuleList([
                torch.nn.TransformerEncoderLayer(
                    d_model=hidden_size,
                    nhead=num_heads,
                    dim_feedforward=hidden_size * 4,
                    batch_first=True,
                    dtype=dtype,
                )
                for _ in range(num_layers)
            ])
            self.head = torch.nn.Linear(hidden_size, vocab_size, dtype=dtype)

        def forward(self, x):
            h = self.embed(x).to(dtype)
            for layer in self.layers:
                h = layer(h)
            return self.head(h)

    model = SyntheticTransformer().cuda()
    total_params = sum(p.numel() for p in model.parameters())
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    input_ids = torch.randint(0, vocab_size, (local_batch, seq_length), device="cuda")

    step_times = []
    last_loss = torch.tensor(float("nan"), device="cuda")
    use_autocast = dtype in (torch.float16, torch.bfloat16)
    torch.cuda.reset_peak_memory_stats(local_rank)
    for _ in range(total_steps):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.amp.autocast("cuda", dtype=dtype, enabled=use_autocast):
            logits = model(input_ids)
            loss = torch.nn.functional.cross_entropy(logits.reshape(-1, vocab_size), input_ids.reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        torch.cuda.synchronize()
        step_times.append(time.perf_counter() - t0)
        last_loss = loss.detach()

    # Cluster-wide peak memory (max over ranks) and loss finiteness (all ranks).
    peak_mem = torch.tensor(torch.cuda.max_memory_allocated(local_rank) / 1024**3, device="cuda")
    dist.all_reduce(peak_mem, op=dist.ReduceOp.MAX)
    finite = torch.tensor(1 if math.isfinite(float(last_loss.item())) else 0, device="cuda")
    dist.all_reduce(finite, op=dist.ReduceOp.MIN)

    if dist.get_rank() == 0:
        measured_steps = step_times[warmup_steps:] if len(step_times) > warmup_steps else step_times
        avg_step = sum(measured_steps) / len(measured_steps) if measured_steps else 0.0
        jitter = max(abs(v - avg_step) / avg_step * 100 for v in measured_steps) if avg_step else 0.0
        throughput = effective_batch * seq_length / avg_step if avg_step else 0.0
        print("TRAINING_DDP_JSON=" + json.dumps({
            "model": "synthetic_transformer_1.5b",
            "total_params_m": round(total_params / 1e6, 1),
            "num_layers": num_layers,
            "hidden_size": hidden_size,
            "gpu_count": world_size,
            "dtype": dtype_name,
            "batch_size": global_batch,
            "effective_batch_size": effective_batch,
            "local_batch_size": local_batch,
            "seq_length": seq_length,
            "num_steps": num_steps,
            "warmup_steps": warmup_steps,
            "total_steps": total_steps,
            "avg_step_time_ms": round(avg_step * 1000, 1),
            "throughput_tokens_per_sec": round(throughput, 0),
            "throughput_samples_per_sec": round(effective_batch / avg_step, 2) if avg_step else 0,
            "peak_memory_gb": round(float(peak_mem.item()), 2),
            "final_loss": round(float(last_loss.item()), 4),
            "step_jitter_pct": round(jitter, 2),
            "distributed_mode": "ddp",
            "loss_finite": bool(int(finite.item())),
        }), flush=True)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
'''

# Seconds torchrun gets to shut down after SIGTERM before it is SIGKILLed.
_TERMINATE_GRACE_SEC = 60


class TrainingSim:
    """Benchmark an LLM-style training loop and judge it against acceptance thresholds.

    Settings are read from ``config["training"]``. Keys used in this module:
    ``model``, ``batch_size``, ``seq_length``, ``num_steps``, ``warmup_steps``,
    ``dtype``, ``mode``, ``allow_fallback``, ``require_distributed``,
    ``timeout_sec``, ``min_tokens_per_sec``, ``max_step_jitter_pct`` and
    ``max_peak_memory_gb``.
    """

    def __init__(self, config: dict):
        self.config = config
        self.console = Console()
        self.train_cfg = config.get("training", {})

    def run(self) -> dict:
        """Run the training benchmark and return a flat result dict.

        Path selection:
          1. ``mode == "ddp"`` (the default) with more than one GPU: the
             synthetic ~1.5B model under torchrun/DDP. Its result is returned
             unless it failed and ``allow_fallback`` is true.
          2. Otherwise a ``transformers`` causal LM (``model``, default
             ``"gpt2"``) trained on random tokens.
          3. If ``transformers`` is missing, or the HF run raises, the
             in-process synthetic model (single GPU or ``DataParallel``).

        Returns ``{"error": "pytorch_not_available"}`` when torch or CUDA is
        unavailable.
        """
        if not TORCH_AVAILABLE:
            self.console.print("[yellow]PyTorch not available - skipping training simulation[/yellow]")
            return {"error": "pytorch_not_available"}

        gpu_count = torch.cuda.device_count()
        model_name = self.train_cfg.get("model", "gpt2")
        batch_size = int(self.train_cfg.get("batch_size", 8))
        seq_length = int(self.train_cfg.get("seq_length", 2048))
        # At least one step is always run: with zero steps the averages below
        # would divide by zero and no loss would ever be computed.
        num_steps = max(1, int(self.train_cfg.get("num_steps", 50)))
        warmup_steps = self._warmup_steps()
        dtype_str = self.train_cfg.get("dtype", "bf16")
        dtype_map = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }
        dtype = dtype_map.get(dtype_str, torch.bfloat16)

        self.console.print("[cyan]Training Simulation[/cyan]")
        self.console.print(
            f" Model: {model_name} | Batch: {batch_size} | Seq: {seq_length} | "
            f"DType: {dtype_str} | Steps: {num_steps} | Warmup: {warmup_steps} | GPUs: {gpu_count}"
        )

        if self.train_cfg.get("mode", "ddp") == "ddp" and gpu_count > 1:
            ddp_result = self._run_synthetic_ddp(gpu_count, batch_size, seq_length, num_steps, dtype_str)
            if ddp_result.get("passed") or not self.train_cfg.get("allow_fallback", False):
                return ddp_result
            self.console.print(
                "[yellow]DDP synthetic training failed, falling back to single-process synthetic path[/yellow]"
            )

        try:
            # Availability probe only; _run_hf performs its own import.
            from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: F401
        except ImportError:
            self.console.print("[yellow]transformers not installed - using synthetic model[/yellow]")
            return self._run_synthetic(gpu_count, batch_size, seq_length, num_steps, dtype)

        try:
            return self._run_hf(
                model_name, gpu_count, batch_size, seq_length, num_steps, warmup_steps, dtype, dtype_str
            )
        except Exception as e:
            self.console.print(f"[yellow]Model loading failed: {e}[/yellow]")

        # The fallback deliberately runs after the except block has exited.
        # ``e`` and its traceback (which references the failed run's frame,
        # and therefore its model/optimizer tensors) are released at the end
        # of the handler, so that GPU memory can be freed before the
        # synthetic model is allocated.
        gc.collect()
        torch.cuda.empty_cache()
        return self._run_synthetic(gpu_count, batch_size, seq_length, num_steps, dtype)

    def _run_hf(
        self,
        model_name: str,
        gpu_count: int,
        batch_size: int,
        seq_length: int,
        num_steps: int,
        warmup_steps: int,
        dtype,
        dtype_str: str,
    ) -> dict:
        """Train a ``transformers`` causal LM on random tokens and time each step.

        Exceptions from model loading or the training loop propagate to
        ``run``, which falls back to the synthetic benchmark.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.console.print(f" Loading {model_name}...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if gpu_count > 1 else None,
        )
        if gpu_count <= 1:
            # Without a device_map, from_pretrained loads the weights on the
            # CPU; move them so the benchmark measures the GPU.
            model = model.cuda()
        total_params = sum(p.numel() for p in model.parameters())
        self.console.print(f" Parameters: {total_params / 1e6:.1f}M")

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Inputs are transferred once, outside the timed loop.
        input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, seq_length)).to(model.device)
        attention_mask = torch.ones_like(input_ids)

        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
        # No GradScaler: the parameters are already loaded in ``dtype``, and
        # torch.amp.GradScaler cannot unscale fp16 gradients.
        use_autocast = dtype in (torch.float16, torch.bfloat16)

        step_times = []
        mem_usage = []
        loss = None
        total_steps = num_steps + warmup_steps
        with self._progress() as progress:
            task = progress.add_task("Training steps...", total=total_steps)
            for _ in range(total_steps):
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                if use_autocast:
                    with torch.amp.autocast("cuda", dtype=dtype):
                        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
                else:
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                torch.cuda.synchronize()
                step_times.append(time.perf_counter() - t0)
                # device_map="auto" can shard the model, so take the max over all GPUs.
                mem_usage.append(self._peak_memory_gb(gpu_count))
                progress.advance(task)

        measured_steps = self._measured_steps(step_times, warmup_steps)
        avg_step_time = sum(measured_steps) / len(measured_steps)
        throughput = batch_size * seq_length / avg_step_time
        jitter = self._jitter_pct(measured_steps)
        peak_mem = round(max(mem_usage) if mem_usage else 0, 2)
        final_loss = float(loss.item()) if loss is not None else float("nan")
        require_distributed = self.train_cfg.get("require_distributed", True)
        # This path never uses DDP, so it cannot pass when DDP is required.
        passed = self._acceptance_pass(throughput, jitter, peak_mem, final_loss) and not require_distributed

        return {
            "model": model_name,
            "total_params_m": round(total_params / 1e6, 1),
            "gpu_count": gpu_count,
            "dtype": dtype_str,
            "batch_size": batch_size,
            "seq_length": seq_length,
            "num_steps": num_steps,
            "warmup_steps": warmup_steps,
            "total_steps": total_steps,
            "avg_step_time_ms": round(avg_step_time * 1000, 1),
            "throughput_tokens_per_sec": round(throughput, 0),
            "throughput_samples_per_sec": round(batch_size / avg_step_time, 2),
            "peak_memory_gb": peak_mem,
            "final_loss": round(final_loss, 4),
            "step_jitter_pct": round(jitter, 2),
            "distributed_mode": "device_map",
            "loss_finite": math.isfinite(final_loss),
            "passed": passed,
            "acceptance_gap": "8-GPU DDP was not used" if require_distributed else "",
            "timestamp": datetime.now().isoformat(),
        }

    def _run_synthetic_ddp(self, gpu_count: int, batch_size: int, seq_length: int, num_steps: int,
                           dtype_str: str) -> dict:
        """Run the 1.5B synthetic Transformer with one process per GPU.

        Writes ``_DDP_WORKER_SCRIPT`` to a temporary file and launches it with
        ``torchrun --nproc_per_node=<gpu_count>``. It then parses the JSON that
        rank 0 prints after the ``TRAINING_DDP_JSON=`` marker. Launch, timeout,
        non-zero exit and malformed-payload failures are all returned as a
        result with ``passed: False`` and an ``error`` string; none are raised.
        """

        def failure(error: str) -> dict:
            return {
                "model": "synthetic_transformer_1.5b",
                "gpu_count": gpu_count,
                "distributed_mode": "ddp",
                "passed": False,
                "error": error,
                "timestamp": datetime.now().isoformat(),
            }

        # Prefer the torchrun that sits next to the running interpreter.
        torchrun = os.path.join(os.path.dirname(sys.executable), "torchrun")
        if not os.path.isfile(torchrun):
            torchrun = shutil.which("torchrun") or ""
        if not torchrun:
            return failure("torchrun not found")

        env = {
            **os.environ,
            "TRAIN_BATCH_SIZE": str(batch_size),
            "TRAIN_SEQ_LENGTH": str(seq_length),
            "TRAIN_NUM_STEPS": str(num_steps),
            "TRAIN_WARMUP_STEPS": str(self._warmup_steps()),
            "TRAIN_DTYPE": dtype_str,
            "NCCL_DEBUG": os.environ.get("NCCL_DEBUG", "WARN"),
        }
        timeout = int(self.train_cfg.get("timeout_sec", max(600, num_steps * 180)))

        tmp = tempfile.NamedTemporaryFile("w", suffix="_training_ddp.py", delete=False)
        try:
            with tmp:
                tmp.write(_DDP_WORKER_SCRIPT)
            cmd = [torchrun, f"--nproc_per_node={gpu_count}", tmp.name]
            self.console.print(f" Running synthetic 1.5B DDP via torchrun ({gpu_count} processes)...")
            try:
                proc = subprocess.Popen(
                    cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env
                )
            except OSError as e:
                return failure(f"torchrun launch failed: {e}")
            try:
                stdout, stderr = proc.communicate(timeout=timeout)
            except subprocess.TimeoutExpired:
                # Send SIGTERM first so torchrun's launcher gets a chance to shut
                # down its workers. subprocess.run(timeout=...) would SIGKILL
                # only torchrun, which can leave worker processes running on the
                # GPUs. Escalate to SIGKILL if torchrun does not exit in time.
                proc.terminate()
                try:
                    proc.communicate(timeout=_TERMINATE_GRACE_SEC)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    proc.communicate()
                return failure("training_ddp_timeout")
            returncode = proc.returncode
        finally:
            try:
                os.unlink(tmp.name)
            except OSError:
                pass

        marker = "TRAINING_DDP_JSON="
        payload = None
        for line in (stdout + "\n" + stderr).splitlines():
            if marker in line:
                payload = line.split(marker, 1)[1].strip()
        if returncode != 0 or not payload:
            return failure((stderr or stdout or "training_ddp_failed")[-1000:])
        try:
            result = json.loads(payload)
        except json.JSONDecodeError as e:
            return failure(f"invalid training DDP payload: {e}")

        loss_value = float(result.get("final_loss", "nan"))
        passed = (
            self._acceptance_pass(
                float(result.get("throughput_tokens_per_sec", 0)),
                float(result.get("step_jitter_pct", 999)),
                float(result.get("peak_memory_gb", 999)),
                loss_value,
            )
            and bool(result.get("loss_finite", False))
            and result.get("gpu_count") == gpu_count
        )
        result.update({
            "passed": passed,
            "timestamp": datetime.now().isoformat(),
        })
        return result

    def _run_synthetic(self, gpu_count, batch_size, seq_length, num_steps, dtype) -> dict:
        """Train the synthetic Transformer in-process (one GPU, or DataParallel across several).

        ``dtype`` is a torch dtype applied to the encoder layers and output head.
        """
        self.console.print(" Running synthetic training benchmark...")
        hidden_size = 4096
        num_layers = 6
        num_heads = 32
        vocab_size = 32000

        class SyntheticTransformer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.embed = torch.nn.Embedding(vocab_size, hidden_size)
                self.layers = torch.nn.ModuleList([
                    torch.nn.TransformerEncoderLayer(
                        d_model=hidden_size,
                        nhead=num_heads,
                        dim_feedforward=hidden_size * 4,
                        batch_first=True,
                        dtype=dtype,
                    )
                    for _ in range(num_layers)
                ])
                self.head = torch.nn.Linear(hidden_size, vocab_size, dtype=dtype)

            def forward(self, x):
                h = self.embed(x).to(dtype)
                for layer in self.layers:
                    h = layer(h)
                return self.head(h)

        model = SyntheticTransformer()
        total_params = sum(p.numel() for p in model.parameters())
        self.console.print(f" Synthetic params: {total_params / 1e6:.1f}M")

        distributed_mode = "single_gpu"
        if gpu_count > 1:
            model = torch.nn.DataParallel(model).cuda()
            distributed_mode = "data_parallel"
        else:
            model = model.cuda()
        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        input_ids = torch.randint(0, vocab_size, (batch_size, seq_length)).cuda()

        step_times = []
        mem_usage = []
        loss = None
        warmup_steps = self._warmup_steps()
        total_steps = num_steps + warmup_steps
        with self._progress() as progress:
            task = progress.add_task("Synthetic training...", total=total_steps)
            for _ in range(total_steps):
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                logits = model(input_ids)
                loss = torch.nn.functional.cross_entropy(logits.view(-1, vocab_size), input_ids.view(-1))
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                torch.cuda.synchronize()
                step_times.append(time.perf_counter() - t0)
                mem_usage.append(self._peak_memory_gb(gpu_count))
                progress.advance(task)

        measured_steps = self._measured_steps(step_times, warmup_steps)
        avg_step_time = sum(measured_steps) / len(measured_steps)
        throughput = batch_size * seq_length / avg_step_time
        jitter = self._jitter_pct(measured_steps)
        peak_mem = round(max(mem_usage) if mem_usage else 0, 2)
        final_loss = float(loss.item()) if loss is not None else float("nan")
        require_distributed = self.train_cfg.get("require_distributed", True)
        # This path never uses DDP, so it cannot pass when DDP is required.
        passed = self._acceptance_pass(throughput, jitter, peak_mem, final_loss) and not require_distributed

        return {
            "model": "synthetic_transformer",
            "total_params_m": round(total_params / 1e6, 1),
            "num_layers": num_layers,
            "hidden_size": hidden_size,
            "gpu_count": gpu_count,
            "dtype": str(dtype).replace("torch.", ""),
            "batch_size": batch_size,
            "seq_length": seq_length,
            "num_steps": num_steps,
            "warmup_steps": warmup_steps,
            "total_steps": total_steps,
            "avg_step_time_ms": round(avg_step_time * 1000, 1),
            "throughput_tokens_per_sec": round(throughput, 0),
            "throughput_samples_per_sec": round(batch_size / avg_step_time, 2),
            "peak_memory_gb": peak_mem,
            "final_loss": round(final_loss, 4),
            "step_jitter_pct": round(jitter, 2),
            "distributed_mode": distributed_mode,
            "loss_finite": math.isfinite(final_loss),
            "passed": passed,
            "acceptance_gap": "8-GPU DDP was not used" if require_distributed else "",
            "timestamp": datetime.now().isoformat(),
        }

    def _warmup_steps(self) -> int:
        """Return the configured warmup step count (default 5), clamped to be non-negative."""
        return max(0, int(self.train_cfg.get("warmup_steps", 5)))

    def _progress(self) -> Progress:
        """Build the spinner/bar/counter/elapsed-time display shared by the training loops."""
        return Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            TimeElapsedColumn(),
            console=self.console,
        )

    @staticmethod
    def _peak_memory_gb(gpu_count: int) -> float:
        """Return the max peak allocation (GiB) over CUDA devices ``0..gpu_count-1``, then reset those peaks."""
        devices = range(gpu_count)
        peak = max((torch.cuda.max_memory_allocated(i) for i in devices), default=0) / 1024**3
        for i in devices:
            torch.cuda.reset_peak_memory_stats(i)
        return peak

    @staticmethod
    def _measured_steps(step_times: list[float], warmup_steps: int) -> list[float]:
        """Drop the warmup steps; if the run was not longer than warmup, keep every step."""
        return step_times[warmup_steps:] if len(step_times) > warmup_steps else step_times

    @staticmethod
    def _jitter_pct(step_times: list[float]) -> float:
        """Return the largest deviation of any step from the mean, as a percentage of the mean."""
        if not step_times:
            return 0.0
        mean = sum(step_times) / len(step_times)
        return max(abs(v - mean) / mean * 100 for v in step_times) if mean else 0.0

    def _acceptance_pass(self, throughput: float, jitter: float, peak_mem: float, loss_value: float) -> bool:
        """Check the metrics against the thresholds configured in ``train_cfg``.

        Passes when throughput >= ``min_tokens_per_sec`` (default 45000),
        jitter <= ``max_step_jitter_pct`` (default 3),
        peak memory <= ``max_peak_memory_gb`` (default 70),
        and the loss is finite.
        """
        return (
            throughput >= float(self.train_cfg.get("min_tokens_per_sec", 45000))
            and jitter <= float(self.train_cfg.get("max_step_jitter_pct", 3))
            and peak_mem <= float(self.train_cfg.get("max_peak_memory_gb", 70))
            and math.isfinite(loss_value)
        )

    @staticmethod
    def print_results(results: dict, console: Optional[Console] = None):
        """Render a result dict from ``run`` as a table, or print its error and return."""
        c = console or Console()
        if "error" in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return

        c.print("\n[bold cyan]Training Simulation Results[/bold cyan]")
        table = Table(box=None, padding=(0, 1))
        table.add_column("Metric", style="bold")
        table.add_column("Value")

        params = results.get("total_params_m")
        metrics = [
            ("Model", results.get("model", "N/A")),
            ("Parameters", f"{params}M" if params is not None else "N/A"),
            ("GPU Count", str(results.get("gpu_count", "N/A"))),
            ("DType", results.get("dtype", "N/A")),
            ("Batch Size", str(results.get("batch_size", "N/A"))),
            ("Seq Length", str(results.get("seq_length", "N/A"))),
            ("Steps", str(results.get("num_steps", "N/A"))),
            ("Warmup Steps", str(results.get("warmup_steps", "N/A"))),
            ("Avg Step Time", f"{results.get('avg_step_time_ms', 'N/A')} ms"),
            ("Throughput", f"{results.get('throughput_tokens_per_sec', 'N/A')} tokens/s"),
            ("Samples/sec", f"{results.get('throughput_samples_per_sec', 'N/A')}"),
            ("Peak Memory", f"{results.get('peak_memory_gb', 'N/A')} GB"),
            ("Final Loss", str(results.get("final_loss", "N/A"))),
            ("Step Jitter", f"{results.get('step_jitter_pct', 'N/A')}%"),
            ("Distributed Mode", results.get("distributed_mode", "N/A")),
            ("Verdict", "PASS" if results.get("passed") else "FAIL"),
        ]
        # Explain a FAIL that was forced by the distributed requirement.
        if results.get("acceptance_gap"):
            metrics.append(("Acceptance Gap", results["acceptance_gap"]))

        for label, val in metrics:
            table.add_row(label, str(val))
        c.print(table)