"""GPU stress test module — gpu-burn or PyTorch GEMM with telemetry."""

import glob
import os
import re
import shutil
import subprocess
import threading
import time
from datetime import datetime

from rich.console import Console
from rich.table import Table
from rich.live import Live
from rich.text import Text

from modules.gpu_specs import resolve_tools_dir

# gpu-burn ends a run with one verdict line per GPU (e.g. "GPU 0: OK" or
# "GPU 0: FAULTY"); PASS/FAIL wording is accepted as well.
_GPU_VERDICT_RE = re.compile(r"\b(?:OK|FAULTY|PASS\w*|FAIL\w*)\b", re.IGNORECASE)
_GPU_FAILED_RE = re.compile(r"FAULTY|FAIL", re.IGNORECASE)
# NVIDIA Xid reports in the kernel log look like "NVRM: Xid (PCI:...): 79, ...".
# Word boundaries keep unrelated words such as "oxide" from matching.
_XID_RE = re.compile(r"\bXID\b", re.IGNORECASE)


def _is_gpu_result_line(line: str) -> bool:
    """Return True if *line* looks like a per-GPU verdict in gpu-burn output."""
    return "GPU" in line and _GPU_VERDICT_RE.search(line) is not None


def _gpu_line_failed(line: str) -> bool:
    """Return True if a per-GPU verdict line reports a failure (FAIL/FAULTY)."""
    return _GPU_FAILED_RE.search(line) is not None


class StressTest:
    """Run a GPU burn-in and judge it against telemetry thresholds.

    Two backends exist: the external ``gpu_burn`` binary (opt-in via
    ``stress.use_gpu_burn``) and a PyTorch GEMM loop, which is the default and
    also the fallback when gpu-burn fails early. All settings are read from the
    ``stress`` section of *config*.
    """

    def __init__(self, config: dict):
        self.config = config
        self.console = Console()
        # `or {}` also covers an explicit None value for "stress".
        self.stress_cfg = config.get("stress") or {}
        self.tools_dir = resolve_tools_dir(config)

    def _find_gpu_burn(self) -> str:
        """Return the path of an executable ``gpu_burn``, or ``""`` if none is found.

        Search order: ``$PATH``, then ``<tools_dir>/gpu-burn/gpu_burn``, then any
        ``gpu_burn`` file anywhere below ``<tools_dir>/gpu-burn``.
        """
        found = shutil.which("gpu_burn")
        if found:
            return found
        base = os.path.join(self.tools_dir, "gpu-burn")
        candidates = [os.path.join(base, "gpu_burn")]
        candidates += glob.glob(os.path.join(base, "**", "gpu_burn"), recursive=True)
        for path in candidates:
            if os.path.isfile(path) and os.access(path, os.X_OK):
                return path
        return ""

    def run(self) -> dict:
        """Run the configured stress test and return its result dict.

        gpu-burn is used only when ``use_gpu_burn`` is set and the binary is
        found. If it fails before half of the requested duration has elapsed
        (e.g. OOM), the PyTorch backend is run instead.
        """
        cfg = self.stress_cfg
        duration_sec = cfg.get("duration_sec", 60)
        use_doubles = cfg.get("use_doubles", False)
        use_tensor_cores = cfg.get("use_tensor_cores", True)
        memory_pct = cfg.get("memory_pct", 90)
        target_gpus = cfg.get("gpus", "all")

        gpu_burn = self._find_gpu_burn() if cfg.get("use_gpu_burn", False) else ""
        if gpu_burn:
            result = self._run_gpu_burn(gpu_burn, duration_sec, use_doubles, use_tensor_cores, target_gpus)
            # Timeout/exception results carry no elapsed_sec, so they also count as "early".
            if not result.get("passed") and result.get("elapsed_sec", 0) < duration_sec * 0.5:
                self.console.print("\n[yellow]gpu-burn exited early (possible OOM), switching to PyTorch stress test[/yellow]")
                self.console.print("[dim]PyTorch mode dynamically adapts to available memory[/dim]\n")
                return self._run_pytorch_stress(duration_sec, memory_pct)
            return result

        self.console.print("[yellow]Using PyTorch stress test[/yellow]")
        return self._run_pytorch_stress(duration_sec, memory_pct)

    # ------------------------------------------------------------------ helpers

    def _telemetry_interval(self) -> float:
        """Return the telemetry sampling interval in seconds (never below 0.1 s)."""
        # float() keeps sub-second values such as 0.5 (int() turned them into 0);
        # the floor keeps 0 or negative values from making the sampler busy-loop.
        return max(0.1, float(self.stress_cfg.get("telemetry_interval_sec", 1)))

    def _start_sampler(self, telemetry: list, stop_event: threading.Event, interval: float) -> threading.Thread:
        """Start a daemon thread running ``_sample_telemetry`` and return it."""
        sampler = threading.Thread(
            target=self._sample_telemetry,
            args=(telemetry, stop_event, interval),
            daemon=True,
        )
        sampler.start()
        return sampler

    @staticmethod
    def _stop_sampler(sampler: threading.Thread | None, stop_event: threading.Event, interval: float) -> None:
        """Signal the sampler to stop and wait (bounded) for it to exit. Safe to call twice."""
        stop_event.set()
        if sampler is not None:
            sampler.join(timeout=interval + 1)

    # ---------------------------------------------------------------- backends

    def _run_gpu_burn(self, gpu_burn: str, duration: int, doubles: bool, tensor_cores: bool, target_gpus: str) -> dict:
        """Run gpu-burn for *duration* seconds while sampling telemetry.

        The run passes when gpu-burn exits 0, no per-GPU verdict line reports
        FAIL/FAULTY, and the telemetry evaluation passes. gpu-burn yields no
        per-pass TFLOPS here, so the TFLOPS-jitter criterion is not applied.
        Errors (timeout, failure to launch) are returned as a failed result
        rather than raised.
        """
        self.console.print(f"[cyan]GPU Stress Test via gpu-burn ({duration}s)[/cyan]")
        cmd = [gpu_burn]
        if doubles:
            cmd.append("-d")
        if tensor_cores:
            cmd.append("-tc")
        if target_gpus != "all":
            cmd.extend(["-i", str(target_gpus)])
        cmd.append(str(duration))

        t0 = time.time()
        xid_before = self._collect_xid_events()
        interval = self._telemetry_interval()
        telemetry: list = []
        stop_sampling = threading.Event()
        sampler = self._start_sampler(telemetry, stop_sampling, interval)

        def finish_telemetry() -> dict:
            # Stop the sampler first so no sample lands while we evaluate.
            self._stop_sampler(sampler, stop_sampling, interval)
            xid_events = self._new_xid_events(xid_before, self._collect_xid_events())
            return self._evaluate_telemetry(list(telemetry), [], xid_events, jitter_applicable=False)

        try:
            proc = subprocess.run(cmd, capture_output=True, text=True, timeout=duration + 120)
        except subprocess.TimeoutExpired:
            return {
                "source": "gpu-burn",
                "passed": False,
                "duration_sec": duration,
                "error": "timeout",
                "telemetry": finish_telemetry(),
                "timestamp": datetime.now().isoformat(),
            }
        except Exception as e:  # launch failures (permissions, missing libs, ...) become a failed result
            return {
                "source": "gpu-burn",
                "passed": False,
                "error": str(e),
                "telemetry": finish_telemetry(),
                "timestamp": datetime.now().isoformat(),
            }
        finally:
            stop_sampling.set()

        elapsed = round(time.time() - t0, 1)
        telemetry_summary = finish_telemetry()
        output = (proc.stdout or "") + (proc.stderr or "")
        gpu_results = [line.strip() for line in output.splitlines() if _is_gpu_result_line(line.strip())]
        faulty_gpus = [line for line in gpu_results if _gpu_line_failed(line)]
        passed = proc.returncode == 0 and not faulty_gpus and telemetry_summary.get("passed", False)

        return {
            "source": "gpu-burn",
            "passed": passed,
            "duration_sec": duration,
            "elapsed_sec": elapsed,
            "gpu_results": gpu_results,
            "telemetry": telemetry_summary,
            "raw_output_tail": output[-500:] if output else "",
            "timestamp": datetime.now().isoformat(),
        }

    def _run_pytorch_stress(self, duration: int, memory_pct: int = 90) -> dict:
        """Stress every visible CUDA device with a repeated GEMM for *duration* seconds.

        Each GPU gets three ``matrix_size`` x ``matrix_size`` matrices in the
        configured dtype, plus an fp16 "ballast" buffer. The ballast brings the
        allocation up to *memory_pct* % of total memory, capped at 90% of the
        memory free at start. Every pass queues one matmul per GPU, then
        synchronizes all GPUs and records the aggregate TFLOPS for the jitter check.

        Returns:
            A result dict with ``source == "pytorch"``, or
            ``{"error": "pytorch_not_available"}`` when torch is missing or CUDA
            is unavailable. CUDA RuntimeErrors (e.g. OOM) produce a failed result
            that also carries the error message.
        """
        try:
            import torch
        except ImportError:
            return {"error": "pytorch_not_available"}
        if not torch.cuda.is_available():
            return {"error": "pytorch_not_available"}

        gpu_count = torch.cuda.device_count()
        self.console.print(f"[cyan]PyTorch Stress Test ({duration}s, {gpu_count} GPUs, target {memory_pct}% memory)[/cyan]")

        dtype_name = self.stress_cfg.get("dtype", "bf16")
        matrix_size = int(self.stress_cfg.get("matrix_size", 8192))
        interval = self._telemetry_interval()
        dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
        dtype = dtype_map.get(dtype_name, torch.bfloat16)
        elem_size = torch.tensor([], dtype=dtype).element_size()
        # An n x n by n x n matmul is ~2*n^3 FLOPs; one pass runs one per GPU.
        flops_per_pass = gpu_count * 2 * (matrix_size ** 3)

        gpu_status: dict = {}
        telemetry: list = []
        pass_tflops: list = []
        # Created before the try so the finally block can always release them.
        tensors: dict = {}
        ballast: dict = {}
        stop_sampling = threading.Event()
        sampler = None
        t0 = time.time()
        xid_before = self._collect_xid_events()

        try:
            sampler = self._start_sampler(telemetry, stop_sampling, interval)

            for i in range(gpu_count):
                with torch.cuda.device(i):
                    free_mem, total_mem = torch.cuda.mem_get_info(i)
                    compute_bytes = matrix_size * matrix_size * elem_size * 3
                    target_mem = min(int(total_mem * memory_pct / 100), int(free_mem * 0.90))
                    ballast_bytes = max(0, target_mem - compute_bytes)
                    if ballast_bytes:
                        # Ballast is fp16, i.e. 2 bytes per element.
                        ballast[i] = torch.empty(ballast_bytes // 2, device=f"cuda:{i}", dtype=torch.float16)

                    actual_mem_mb = (compute_bytes + ballast_bytes) / 1024 / 1024
                    total_mem_mb = total_mem / 1024 / 1024
                    free_mem_mb = free_mem / 1024 / 1024
                    self.console.print(
                        f" [dim]GPU {i}: total {total_mem_mb:.0f}MB, free {free_mem_mb:.0f}MB, "
                        f"alloc {actual_mem_mb:.0f}MB ({actual_mem_mb/total_mem_mb*100:.0f}%) - "
                        f"{dtype_name} matrix {matrix_size}x{matrix_size}[/dim]"
                    )

                    tensors[i] = (
                        torch.randn(matrix_size, matrix_size, device=f"cuda:{i}", dtype=dtype),
                        torch.randn(matrix_size, matrix_size, device=f"cuda:{i}", dtype=dtype),
                        torch.empty(matrix_size, matrix_size, device=f"cuda:{i}", dtype=dtype),
                    )

            self.console.print(f"\n[cyan]Starting stress test for {duration} seconds...[/cyan]")

            last_report = 0
            while time.time() - t0 < duration:
                loop_start = time.perf_counter()
                # Queue a matmul on every GPU before synchronizing any of them;
                # synchronizing per GPU would serialize the GPUs and overshoot the duration.
                # Indexing (instead of unpacking into locals) avoids keeping the last
                # GPU's tensors referenced after the loop, so empty_cache() can free them.
                for i in range(gpu_count):
                    with torch.cuda.device(i):
                        torch.matmul(tensors[i][0], tensors[i][1], out=tensors[i][2])
                # One sync per GPU per pass; the GPUs run concurrently meanwhile.
                for i in range(gpu_count):
                    torch.cuda.synchronize(i)

                loop_elapsed = time.perf_counter() - loop_start
                current_elapsed = time.time() - t0
                if loop_elapsed > 0:
                    pass_tflops.append({
                        "elapsed_sec": current_elapsed,
                        "tflops": flops_per_pass / loop_elapsed / 1e12,
                    })

                # Progress line roughly every 10 seconds.
                if int(current_elapsed) != int(last_report) and int(current_elapsed) % 10 == 0:
                    self.console.print(f" [dim]Running {int(current_elapsed)}s / {duration}s[/dim]")
                    last_report = current_elapsed

            for i in range(gpu_count):
                gpu_status[i] = "PASS"

        except RuntimeError as e:
            error_msg = str(e)
            self.console.print(f"\n[red]Stress test error: {error_msg}[/red]")
            for i in range(gpu_count):
                gpu_status.setdefault(i, "FAIL")
            self._stop_sampler(sampler, stop_sampling, interval)
            return {
                "source": "pytorch",
                "passed": False,
                "duration_sec": duration,
                "elapsed_sec": round(time.time() - t0, 1),
                "error": error_msg,
                "gpu_status": gpu_status,
                "telemetry": self._evaluate_telemetry(
                    list(telemetry),
                    pass_tflops,
                    self._new_xid_events(xid_before, self._collect_xid_events()),
                ),
                "timestamp": datetime.now().isoformat(),
            }
        finally:
            self._stop_sampler(sampler, stop_sampling, interval)
            tensors.clear()
            ballast.clear()
            torch.cuda.empty_cache()

        elapsed = round(time.time() - t0, 1)
        xid_events = self._new_xid_events(xid_before, self._collect_xid_events())
        telemetry_summary = self._evaluate_telemetry(list(telemetry), pass_tflops, xid_events)
        passed = all(v == "PASS" for v in gpu_status.values()) and telemetry_summary.get("passed", False)

        return {
            "source": "pytorch",
            "passed": passed,
            "duration_sec": duration,
            "elapsed_sec": elapsed,
            "gpu_status": gpu_status,
            "telemetry": telemetry_summary,
            "timestamp": datetime.now().isoformat(),
        }

    # --------------------------------------------------------------- telemetry

    def _sample_telemetry(self, telemetry: list, stop_event: threading.Event, interval: float):
        """Append an nvidia-smi sample to *telemetry* every *interval* s until *stop_event* is set.

        Each sample is ``{"time": <epoch seconds>, "gpus": [{"index", "temp_c",
        "power_w", "throttle"}, ...]}``. Sampling is best-effort: if nvidia-smi
        cannot be run, or any row has an unparsable numeric field (e.g. ``[N/A]``),
        that whole sample is skipped and sampling continues at the next tick.
        """
        query = "index,temperature.gpu,power.draw,clocks_throttle_reasons.active"
        while not stop_event.is_set():
            try:
                r = subprocess.run(
                    ["nvidia-smi", f"--query-gpu={query}", "--format=csv,noheader,nounits"],
                    capture_output=True, text=True, timeout=10,
                )
                if r.returncode == 0:
                    sample = {"time": time.time(), "gpus": []}
                    for line in r.stdout.splitlines():
                        parts = [p.strip() for p in line.split(",")]
                        if len(parts) >= 4:
                            sample["gpus"].append({
                                "index": int(parts[0]),
                                "temp_c": float(parts[1]),
                                "power_w": float(parts[2]),
                                "throttle": parts[3],
                            })
                    telemetry.append(sample)
            except (OSError, subprocess.SubprocessError, ValueError):
                # Missing/hung nvidia-smi or an unparsable row: skip this tick.
                pass
            stop_event.wait(interval)

    def _collect_xid_events(self) -> list[str]:
        """Return stripped kernel-log lines that mention an Xid; ``[]`` if dmesg is unavailable.

        dmesg often needs privileges; any failure to read it yields ``[]``.
        """
        try:
            r = subprocess.run(
                ["dmesg", "--color=never"],
                capture_output=True, text=True, errors="replace", timeout=10,
            )
        except (OSError, subprocess.SubprocessError):
            return []
        if r.returncode != 0:
            return []
        return [line.strip() for line in r.stdout.splitlines() if _XID_RE.search(line)]

    @staticmethod
    def _new_xid_events(before: list[str], after: list[str]) -> list[str]:
        """Return the lines of *after* that were not already present in *before*, in order."""
        seen = set(before)
        return [line for line in after if line not in seen]

    def _evaluate_telemetry(
        self,
        telemetry: list,
        pass_tflops: list,
        xid_events: list[str] | None = None,
        jitter_applicable: bool = True,
    ) -> dict:
        """Judge sampled telemetry against the ``stress`` thresholds.

        Args:
            telemetry: samples produced by ``_sample_telemetry``.
            pass_tflops: per-pass throughput, either ``{"elapsed_sec", "tflops"}``
                dicts or bare numbers (bare numbers are never treated as warm-up).
            xid_events: kernel-log Xid lines that appeared during the run.
            jitter_applicable: False when the backend reports no per-pass TFLOPS
                (gpu-burn). The TFLOPS-jitter requirement is then skipped instead
                of failing for lack of samples.

        Returns:
            A summary dict whose ``passed`` is True exactly when ``failures`` is empty.
        """
        cfg = self.stress_cfg
        max_temp = float(cfg.get("max_temp_c", 80))
        max_delta = float(cfg.get("max_temp_delta_c", 5))
        min_power = float(cfg.get("min_power_watts", 630))
        max_jitter = float(cfg.get("max_tflops_jitter_pct", 5))
        require_jitter = jitter_applicable and bool(cfg.get("require_tflops_jitter", True))
        duration = float(cfg.get("duration_sec", 60))
        requested_warmup = float(cfg.get("warmup_sec", 60))
        # Warm-up never discards more than the first 20% of the run.
        warmup_sec = min(requested_warmup, max(0.0, duration * 0.2))
        min_steady_samples = int(cfg.get("min_steady_samples", 10))
        xid_events = xid_events or []

        # Drop warm-up samples, unless that leaves too few to judge.
        if telemetry:
            first_time = telemetry[0].get("time", 0)
            steady_telemetry = [s for s in telemetry if s.get("time", 0) - first_time >= warmup_sec]
        else:
            steady_telemetry = []
        evaluation_samples = steady_telemetry if len(steady_telemetry) >= min_steady_samples else telemetry

        temps: dict = {}
        powers: dict = {}
        throttle_bad = []
        for sample in evaluation_samples:
            for g in sample.get("gpus", []):
                idx = g["index"]
                temps.setdefault(idx, []).append(g["temp_c"])
                powers.setdefault(idx, []).append(g["power_w"])
                try:
                    bitmask = int(str(g["throttle"]), 16)
                except ValueError:
                    bitmask = 0  # non-hex value such as "[Not Supported]"
                # Bit 0x1 is the "GPU idle" reason, which is not a throttle event.
                real_throttle = bitmask & ~0x1
                if real_throttle:
                    throttle_bad.append({
                        "gpu": idx,
                        "throttle": g["throttle"],
                        "real_throttle": f"0x{real_throttle:x}",
                    })

        max_temps = {idx: max(vals) for idx, vals in temps.items() if vals}
        avg_powers = {idx: sum(vals) / len(vals) for idx, vals in powers.items() if vals}
        temp_delta = (max(max_temps.values()) - min(max_temps.values())) if len(max_temps) >= 2 else 0

        steady_tflops = []
        for item in pass_tflops:
            if isinstance(item, dict):
                if float(item.get("elapsed_sec", 0)) >= warmup_sec:
                    steady_tflops.append(float(item.get("tflops", 0)))
            else:
                steady_tflops.append(float(item))
        # Runs shorter than the warm-up still get a jitter estimate from every pass.
        if len(steady_tflops) < 2 and pass_tflops:
            steady_tflops = [
                float(item.get("tflops", 0)) if isinstance(item, dict) else float(item)
                for item in pass_tflops
            ]
        jitter = 0
        if steady_tflops:
            mean = sum(steady_tflops) / len(steady_tflops)
            jitter = max(abs(v - mean) / mean * 100 for v in steady_tflops) if mean else 0

        failures = []
        temp_failures = {idx: v for idx, v in max_temps.items() if v > max_temp}
        power_failures = {idx: v for idx, v in avg_powers.items() if v < min_power}
        if not evaluation_samples:
            failures.append("no telemetry samples available for evaluation")
        if temp_failures:
            failures.append(
                "max temperature above threshold: "
                + ", ".join(f"GPU {idx} {val:.1f}C" for idx, val in sorted(temp_failures.items()))
            )
        if temp_delta > max_delta:
            failures.append(f"GPU temperature delta {temp_delta:.1f}C exceeds {max_delta:.1f}C")
        if power_failures:
            failures.append(
                "average steady-state power below threshold: "
                + ", ".join(f"GPU {idx} {val:.1f}W" for idx, val in sorted(power_failures.items()))
            )
        if throttle_bad:
            failures.append(
                f"non-idle throttle reasons observed in {len(throttle_bad)} samples "
                f"(first: GPU {throttle_bad[0]['gpu']} {throttle_bad[0]['real_throttle']})"
            )
        if xid_events:
            failures.append(f"{len(xid_events)} new XID/NVRM XID events observed")
        if require_jitter and len(steady_tflops) < 2:
            failures.append(
                f"insufficient steady TFLOPS samples for jitter evaluation: {len(steady_tflops)} < 2"
            )
        if jitter > max_jitter:
            failures.append(f"TFLOPS jitter {jitter:.2f}% exceeds {max_jitter:.2f}%")

        # Every criterion above contributes a failure message, so "no failures" is the verdict.
        passed = not failures

        return {
            "passed": passed,
            "samples": len(telemetry),
            "steady_samples": len(evaluation_samples),
            "warmup_sec": round(warmup_sec, 1),
            "max_temp_c": {k: round(v, 1) for k, v in max_temps.items()},
            "avg_power_w": {k: round(v, 1) for k, v in avg_powers.items()},
            "temp_delta_c": round(temp_delta, 1),
            "throttle_events": throttle_bad[:20],
            "throttle_event_count": len(throttle_bad),
            "xid_events": xid_events[-20:],
            "tflops_jitter_pct": round(jitter, 2),
            "steady_tflops_samples": len(steady_tflops),
            "failures": failures,
            "thresholds": {
                "max_temp_c": max_temp,
                "max_temp_delta_c": max_delta,
                "min_power_w": min_power,
                "max_tflops_jitter_pct": max_jitter,
                "require_tflops_jitter": require_jitter,
                "warmup_sec": requested_warmup,
                "min_steady_samples": min_steady_samples,
            },
        }

    # ----------------------------------------------------------------- output

    @staticmethod
    def print_results(results: dict, console: Console | None = None):
        """Pretty-print a result dict returned by ``run`` to *console* (a new Console by default)."""
        c = console or Console()
        # A bare {"error": ...} means no backend ran (e.g. PyTorch/CUDA unavailable),
        # so there is nothing else to show. Backend results carrying an error still
        # have a verdict and telemetry worth printing; their error is shown last.
        if "error" in results and "source" not in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return

        passed = results.get("passed", False)
        source = results.get("source", "unknown")
        duration = results.get("duration_sec", "?")
        elapsed = results.get("elapsed_sec", "?")

        verdict = "[bold green]✓ Stress Test PASSED[/bold green]" if passed else "[bold red]✗ Stress Test FAILED[/bold red]"
        c.print(f"\n{verdict} [dim](via {source})[/dim]")
        c.print(f" Target duration: {duration}s | Actual: {elapsed}s")

        gpu_results = results.get("gpu_results", [])
        if gpu_results:
            c.print("\n Per-GPU results:")
            for line in gpu_results:
                color = "red" if _gpu_line_failed(line) else "green"
                c.print(f" [{color}]{line}[/{color}]")

        gpu_status = results.get("gpu_status", {})
        if gpu_status:
            c.print("\n Per-GPU status:")
            for gid, status in sorted(gpu_status.items()):
                color = "green" if status == "PASS" else "red"
                c.print(f" GPU {gid}: [{color}]{status}[/{color}]")

        telemetry = results.get("telemetry") or {}
        if telemetry:
            c.print("\n Telemetry:")
            c.print(f" Samples: {telemetry.get('samples', 0)} total, {telemetry.get('steady_samples', 0)} evaluated after {telemetry.get('warmup_sec', 0)}s warmup")
            c.print(f" Avg steady power: {telemetry.get('avg_power_w', {})}")
            c.print(f" Max steady temp: {telemetry.get('max_temp_c', {})}")
            c.print(f" Temp delta: {telemetry.get('temp_delta_c', 'N/A')} C")
            c.print(f" TFLOPS jitter: {telemetry.get('tflops_jitter_pct', 'N/A')}%")
            c.print(f" Throttle events: {telemetry.get('throttle_event_count', len(telemetry.get('throttle_events', [])))}")
            c.print(f" XID events: {len(telemetry.get('xid_events', []))}")
            failures = telemetry.get("failures", [])
            if failures:
                c.print(" [red]Failure reasons:[/red]")
                for reason in failures:
                    c.print(f" [red]- {reason}[/red]")

        if results.get("error"):
            c.print(f" [red]Error: {results['error']}[/red]")