"""GPU stress test module — wraps gpu-burn for long-running stability tests."""

import glob
import os
import re
import shutil
import signal
import subprocess
import time
from datetime import datetime

from rich.console import Console
from rich.table import Table
from rich.live import Live
from rich.text import Text

from modules.gpu_specs import resolve_tools_dir


class StressTest:
    """Run a GPU stress test via gpu-burn, falling back to a PyTorch matmul loop.

    Settings are read from the ``"stress"`` section of *config*:
    ``duration_sec`` (default 60), ``use_doubles`` (False),
    ``use_tensor_cores`` (True), ``memory_pct`` (90, PyTorch mode only)
    and ``gpus`` ("all", otherwise passed to gpu-burn's ``-i``).
    """

    # gpu-burn's end-of-run summary lines, e.g. "GPU 0: OK" / "GPU 1: FAULTY".
    _GPU_BURN_SUMMARY_RE = re.compile(r"\bGPU\s*\d+\s*:\s*(?:OK|FAULTY)\b")

    def __init__(self, config: dict):
        self.config = config
        self.console = Console()
        self.stress_cfg = config.get("stress", {})
        self.tools_dir = resolve_tools_dir(config)

    def _find_gpu_burn(self) -> str:
        """Return the path of an executable ``gpu_burn``, or ``""`` if none is found.

        Search order: ``$PATH``, then ``<tools_dir>/gpu-burn/gpu_burn``, then any
        ``gpu_burn`` file nested under ``<tools_dir>/gpu-burn``.
        """
        p = shutil.which("gpu_burn")
        if p:
            return p
        base = os.path.join(self.tools_dir, "gpu-burn")
        local = os.path.join(base, "gpu_burn")
        if os.path.isfile(local) and os.access(local, os.X_OK):
            return local
        for m in glob.glob(os.path.join(base, "**", "gpu_burn"), recursive=True):
            # isfile(): a directory named gpu_burn would also pass the X_OK check.
            if os.path.isfile(m) and os.access(m, os.X_OK):
                return m
        return ""

    def run(self) -> dict:
        """Run the configured stress test and return its result dict.

        gpu-burn is preferred. If it is not installed, or it fails before half of
        the requested duration has elapsed (e.g. OOM at start-up), the PyTorch
        stress test is used instead. If that fallback is impossible because
        PyTorch/CUDA is unavailable, the gpu-burn failure is returned so its
        diagnostics are not lost.
        """
        cfg = self.stress_cfg
        duration_sec = cfg.get("duration_sec", 60)
        use_doubles = cfg.get("use_doubles", False)
        use_tensor_cores = cfg.get("use_tensor_cores", True)
        memory_pct = cfg.get("memory_pct", 90)
        target_gpus = cfg.get("gpus", "all")

        gpu_burn = self._find_gpu_burn()
        if not gpu_burn:
            self.console.print("[yellow]gpu_burn not found, using PyTorch stress test[/yellow]")
            return self._run_pytorch_stress(duration_sec, memory_pct)

        result = self._run_gpu_burn(gpu_burn, duration_sec, use_doubles, use_tensor_cores, target_gpus)
        # Only an *early* failure triggers the fallback; a run that lasted at least
        # half the requested time (including a timeout) is a genuine result.
        if not result.get("passed") and result.get("elapsed_sec", 0) < duration_sec * 0.5:
            self.console.print("\n[yellow]gpu-burn exited early (possible OOM), switching to PyTorch stress test[/yellow]")
            self.console.print("[dim]PyTorch mode dynamically adapts to available memory[/dim]\n")
            fallback = self._run_pytorch_stress(duration_sec, memory_pct)
            if fallback.get("error") == "pytorch_not_available":
                self.console.print("[yellow]PyTorch unavailable, reporting the gpu-burn result[/yellow]")
                return result
            return fallback
        return result

    @classmethod
    def _is_gpu_result_line(cls, line: str) -> bool:
        """True if *line* is a per-GPU verdict line from gpu-burn's output."""
        upper = line.upper()
        if "GPU" in line and ("PASS" in upper or "FAIL" in upper):
            return True
        return bool(cls._GPU_BURN_SUMMARY_RE.search(line))

    @staticmethod
    def _is_failed_gpu_line(line: str) -> bool:
        """True if a per-GPU verdict line reports a failure ("FAIL" or gpu-burn's "FAULTY")."""
        upper = line.upper()
        return "FAIL" in upper or "FAULTY" in upper

    @staticmethod
    def _kill_process_group(proc: subprocess.Popen) -> None:
        """SIGKILL the process group led by *proc*; fall back to killing *proc* alone."""
        try:
            os.killpg(proc.pid, signal.SIGKILL)
        except (AttributeError, OSError):
            # No killpg/SIGKILL on this platform, or the group is already gone.
            try:
                proc.kill()
            except OSError:
                pass  # best effort: the process has already exited

    @staticmethod
    def _gpu_burn_failure(duration: int, t0: float, error: str) -> dict:
        """Build the result dict for a gpu-burn run that did not complete normally."""
        return {
            "source": "gpu-burn",
            "passed": False,
            "duration_sec": duration,
            "elapsed_sec": round(time.time() - t0, 1),
            "error": error,
            "timestamp": datetime.now().isoformat(),
        }

    def _run_gpu_burn(self, gpu_burn: str, duration: int, doubles: bool, tensor_cores: bool, target_gpus: str) -> dict:
        """Run gpu-burn for *duration* seconds and summarise the outcome.

        Args:
            gpu_burn: path to the gpu_burn executable.
            duration: burn time in seconds (gpu-burn's positional argument).
            doubles: pass ``-d`` (FP64).
            tensor_cores: pass ``-tc``.
            target_gpus: ``"all"``, or a value forwarded to ``-i``.

        Returns:
            A dict with ``source``, ``passed``, ``duration_sec``, ``elapsed_sec``
            and ``timestamp``. A completed run also has ``gpu_results`` and
            ``raw_output_tail``. A run that timed out or could not be launched has
            ``error`` instead. ``passed`` requires exit code 0 and no per-GPU
            FAIL/FAULTY line.
        """
        self.console.print(f"[cyan]GPU Stress Test via gpu-burn ({duration}s)[/cyan]")
        cmd = [gpu_burn]
        if doubles:
            cmd.append("-d")
        if tensor_cores:
            cmd.append("-tc")
        if target_gpus != "all":
            cmd.extend(["-i", str(target_gpus)])
        cmd.append(str(duration))

        t0 = time.time()
        try:
            with subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                errors="replace",
                # Own session/process group, so a timeout can kill every process
                # gpu-burn started rather than only the direct child.
                start_new_session=True,
            ) as proc:
                try:
                    stdout, stderr = proc.communicate(timeout=duration + 120)
                except subprocess.TimeoutExpired:
                    self._kill_process_group(proc)
                    proc.wait()
                    return self._gpu_burn_failure(duration, t0, "timeout")
                except BaseException:
                    # e.g. Ctrl-C: in its own session gpu-burn no longer receives the
                    # terminal's SIGINT, so stop it explicitly before unwinding.
                    self._kill_process_group(proc)
                    raise
        except Exception as e:
            # Launch failures: missing/non-executable binary, exec format error, ...
            return self._gpu_burn_failure(duration, t0, str(e))

        elapsed = round(time.time() - t0, 1)
        output = stdout + stderr
        gpu_results = [
            line
            for line in (raw.strip() for raw in output.split("\n"))
            if self._is_gpu_result_line(line)
        ]
        any_failed = any(self._is_failed_gpu_line(line) for line in gpu_results)
        return {
            "source": "gpu-burn",
            "passed": proc.returncode == 0 and not any_failed,
            "duration_sec": duration,
            "elapsed_sec": elapsed,
            "gpu_results": gpu_results,
            "raw_output_tail": output[-500:] if output else "",
            "timestamp": datetime.now().isoformat(),
        }

    def _run_pytorch_stress(self, duration: int, memory_pct: int = 90) -> dict:
        """Keep every visible CUDA device busy with back-to-back FP32 matmuls.

        Args:
            duration: wall-clock seconds to run (measured from before allocation).
            memory_pct: upper bound on each GPU's allocation as a percentage of its
                total memory, further capped at 95% of currently free memory.

        Returns:
            ``{"error": "pytorch_not_available"}`` if torch or CUDA is missing.
            Otherwise a dict with ``source``, ``passed``, ``duration_sec``,
            ``elapsed_sec``, ``gpu_status`` ({gpu_index: "PASS"/"FAIL"}) and
            ``timestamp``, plus ``error`` when a RuntimeError (e.g. CUDA OOM)
            aborted the run.
        """
        try:
            import torch
            if not torch.cuda.is_available():
                return {"error": "pytorch_not_available"}
        except ImportError:
            return {"error": "pytorch_not_available"}

        gpu_count = torch.cuda.device_count()
        self.console.print(f"[cyan]PyTorch Stress Test ({duration}s, {gpu_count} GPUs, target {memory_pct}% memory)[/cyan]")

        # Cap the matrix so one pass stays short: 2 * 4096^3 FLOPs at ~67 TFLOPS
        # FP32 (H100/H200) is ~2 ms. Uncapped, 141 GB of HBM yields side ~131K and
        # a single matmul takes ~68 s, overshooting a 60 s request by far.
        # NOTE: because of this cap, memory_pct only bounds the allocation; the
        # test does not fill GPU memory.
        MAX_COMPUTE_SIDE = 4096

        gpu_status = {}
        inputs = {}
        outputs = {}
        t0 = time.time()
        try:
            for i in range(gpu_count):
                with torch.cuda.device(i):
                    # Actual free memory, accounting for other processes.
                    free_mem, total_mem = torch.cuda.mem_get_info(i)
                    target_mem = int(total_mem * memory_pct / 100)
                    # Cap at actual free memory with a 5% safety margin.
                    alloc_bytes = min(target_mem, int(free_mem * 0.95))
                    # Two FP32 (4-byte) side x side buffers: matmul input and output.
                    mem_side = int((alloc_bytes / 4 / 2) ** 0.5)
                    side = min(mem_side, MAX_COMPUTE_SIDE)

                    actual_mem_mb = 2 * side * side * 4 / 1024 / 1024
                    total_mem_mb = total_mem / 1024 / 1024
                    free_mem_mb = free_mem / 1024 / 1024
                    self.console.print(
                        f" [dim]GPU {i}: total {total_mem_mb:.0f}MB, free {free_mem_mb:.0f}MB, "
                        f"alloc {actual_mem_mb:.0f}MB ({actual_mem_mb/total_mem_mb*100:.0f}%) - "
                        f"matrix {side}x{side}[/dim]"
                    )
                    inputs[i] = torch.randn(side, side, device=f"cuda:{i}", dtype=torch.float32)
                    outputs[i] = torch.empty_like(inputs[i])

            self.console.print(f"\n[cyan]Starting stress test for {duration} seconds...[/cyan]")
            next_report = 10
            while time.time() - t0 < duration:
                # Enqueue one matmul per GPU without waiting in between so all GPUs
                # run concurrently; syncing after each GPU would serialize them.
                # Writing into a fixed output buffer keeps the operands finite
                # (feeding A @ A.T back into itself overflows to inf/NaN within a
                # few passes) and avoids a new allocation on every pass.
                for i in range(gpu_count):
                    with torch.cuda.device(i):
                        torch.matmul(inputs[i], inputs[i].T, out=outputs[i])
                # One sync per pass, waiting for all GPUs.
                for i in range(gpu_count):
                    with torch.cuda.device(i):
                        torch.cuda.synchronize()

                # Progress roughly every 10 seconds, even if a pass straddles the mark.
                elapsed_now = time.time() - t0
                if elapsed_now >= next_report:
                    self.console.print(f" [dim]Running {int(elapsed_now)}s / {duration}s[/dim]")
                    next_report = (int(elapsed_now) // 10 + 1) * 10

            for i in range(gpu_count):
                gpu_status[i] = "PASS"
        except RuntimeError as e:
            error_msg = str(e)
            self.console.print(f"\n[red]Stress test error: {error_msg}[/red]")
            for i in range(gpu_count):
                gpu_status.setdefault(i, "FAIL")
            return {
                "source": "pytorch",
                "passed": False,
                "duration_sec": duration,
                "elapsed_sec": round(time.time() - t0, 1),
                "error": error_msg,
                "gpu_status": gpu_status,
                "timestamp": datetime.now().isoformat(),
            }
        finally:
            inputs.clear()
            outputs.clear()
            torch.cuda.empty_cache()

        return {
            "source": "pytorch",
            "passed": True,
            "duration_sec": duration,
            "elapsed_sec": round(time.time() - t0, 1),
            "gpu_status": gpu_status,
            "timestamp": datetime.now().isoformat(),
        }

    @staticmethod
    def print_results(results: dict, console: Console = None):
        """Pretty-print a result dict produced by :meth:`run`.

        Args:
            results: the dict returned by ``run``.
            console: rich console to print to; a new one is created when omitted.
        """
        c = console or Console()
        # A dict without "source" means no test ran at all
        # (e.g. {"error": "pytorch_not_available"}); only the error is meaningful.
        if "error" in results and "source" not in results:
            c.print(f"[bold red]Error: {results['error']}[/bold red]")
            return

        passed = results.get("passed", False)
        source = results.get("source", "unknown")
        duration = results.get("duration_sec", "?")
        elapsed = results.get("elapsed_sec", "?")

        verdict = "[bold green]✓ Stress Test PASSED[/bold green]" if passed else "[bold red]✗ Stress Test FAILED[/bold red]"
        c.print(f"\n{verdict} [dim](via {source})[/dim]")
        c.print(f" Target duration: {duration}s | Actual: {elapsed}s")

        gpu_results = results.get("gpu_results", [])
        if gpu_results:
            c.print("\n Per-GPU results:")
            for line in gpu_results:
                if StressTest._is_failed_gpu_line(line):
                    c.print(f" [red]{line}[/red]")
                else:
                    c.print(f" [green]{line}[/green]")

        gpu_status = results.get("gpu_status", {})
        if gpu_status:
            c.print("\n Per-GPU status:")
            for gid, status in sorted(gpu_status.items()):
                color = "green" if status == "PASS" else "red"
                c.print(f" GPU {gid}: [{color}]{status}[/{color}]")

        if results.get("error"):
            c.print(f" [red]Error: {results['error']}[/red]")