#!/usr/bin/env python3
"""Compare FP8 GEMM paths used for H100/H200 acceptance debugging.

Paths:
  A. torch._scaled_mm eager, default accumulation
  B. torch._scaled_mm eager, use_fast_accum=True
  C. CUDA Graph replay of torch._scaled_mm(out=..., use_fast_accum=True)
  D. Transformer Engine Linear under fp8_autocast, when installed
"""

from __future__ import annotations

import argparse
import json
import statistics
import sys
import time
from typing import Any, Callable

import torch


def tflops_from_ms(matrix_size: int, iterations: int, elapsed_ms: float) -> float:
    """Convert the time taken by a batch of square GEMMs into TFLOP/s.

    Every GEMM is counted as ``2 * matrix_size ** 3`` floating-point operations.

    Args:
        matrix_size: Edge length N of the N x N x N matrix multiply.
        iterations: Number of GEMMs executed during the timed region.
        elapsed_ms: Duration of the timed region in milliseconds.

    Returns:
        Throughput in TFLOP/s. ``0.0`` is returned when ``elapsed_ms`` is not
        positive, because no rate can be derived from such a measurement and
        dividing by it would raise ``ZeroDivisionError``.
    """
    if elapsed_ms <= 0.0:
        return 0.0
    flops = 2.0 * matrix_size * matrix_size * matrix_size * iterations
    return flops / (elapsed_ms / 1000.0) / 1e12


def cuda_event_bench(
    name: str,
    matrix_size: int,
    iterations: int,
    warmup: int,
    func: Callable[[int], Any],
) -> dict[str, Any]:
    """Time ``iterations`` calls of ``func`` with CUDA events and wall clock.

    ``func`` is first called ``warmup`` times (untimed), then the device is
    synchronized and the timed loop runs between two recorded CUDA events.

    Args:
        name: Label stored in the result under ``"name"``.
        matrix_size: GEMM edge length, used only for the TFLOP/s figure.
        iterations: Number of timed calls; must be at least 1.
        warmup: Number of untimed calls made before timing starts.
        func: Callable invoked with the zero-based call index.

    Returns:
        A result dict with ``"status": "ok"``, the run parameters, the CUDA
        event time (total ms and per-iteration us), the wall-clock total in
        ms, and the derived TFLOP/s.

    Raises:
        ValueError: If ``iterations`` is smaller than 1. Checked before any
            work is launched so a bad argument cannot surface later as a
            ``ZeroDivisionError`` in the per-iteration average.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    for i in range(warmup):
        func(i)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    wall_start = time.perf_counter()
    start.record()
    for i in range(iterations):
        func(i)
    end.record()
    # The end event must have completed before elapsed_time() can be read.
    torch.cuda.synchronize()
    wall_elapsed = time.perf_counter() - wall_start
    elapsed_ms = start.elapsed_time(end)

    return {
        "name": name,
        "status": "ok",
        "matrix_size": matrix_size,
        "iterations": iterations,
        "warmup": warmup,
        "event_ms_total": round(elapsed_ms, 3),
        "event_us_per_iter": round(elapsed_ms * 1000.0 / iterations, 3),
        "wall_ms_total": round(wall_elapsed * 1000.0, 3),
        "tflops": round(tflops_from_ms(matrix_size, iterations, elapsed_ms), 1),
    }


def make_fp8_inputs(
    matrix_size: int, pools: int, device: str
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Create ``pools`` pairs of random square ``float8_e4m3fn`` matrices.

    Values are drawn with ``torch.randn`` in float32 on ``device`` and then
    converted to ``torch.float8_e4m3fn``.

    Args:
        matrix_size: Edge length of every matrix.
        pools: Number of matrices per operand list; must be at least 1.
        device: Torch device string the tensors are created on.

    Returns:
        ``(a, b)``, two lists of ``pools`` tensors each.

    Raises:
        ValueError: If ``pools`` is smaller than 1. Empty pools would otherwise
            only fail later, when the benchmark indexes into them.
    """
    if pools < 1:
        raise ValueError(f"pools must be >= 1, got {pools}")

    def random_fp8_pool() -> list[torch.Tensor]:
        return [
            torch.randn(
                matrix_size, matrix_size, device=device, dtype=torch.float32
            ).to(torch.float8_e4m3fn)
            for _ in range(pools)
        ]

    a = random_fp8_pool()
    b = random_fp8_pool()
    torch.cuda.synchronize()
    return a, b


def bench_scaled_mm(args: argparse.Namespace) -> list[dict[str, Any]]:
    """Benchmark the three ``torch._scaled_mm`` paths (A, B and C).

    Paths A and B call ``torch._scaled_mm`` eagerly and rotate through
    ``args.pools`` input pairs. Path C captures one fast-accumulation call on
    input pair 0 into a CUDA graph and times ``graph.replay()``.

    Args:
        args: Parsed options; reads ``gpu_index``, ``matrix_size``, ``pools``,
            ``iterations`` and ``warmup``.

    Returns:
        One result dict per path, in A, B, C order. If anything in the CUDA
        graph path raises, its entry has ``"status": "unavailable"`` and a
        ``"reason"`` instead of timings.
    """
    device = f"cuda:{args.gpu_index}"
    torch.cuda.set_device(args.gpu_index)
    scale_a = torch.tensor(1.0, device=device)
    scale_b = torch.tensor(1.0, device=device)
    pools_a, pools_b = make_fp8_inputs(args.matrix_size, args.pools, device)
    pool_count = len(pools_a)

    # Keyword sets shared by every call so the paths differ only where intended.
    default_kwargs: dict[str, Any] = {
        "scale_a": scale_a,
        "scale_b": scale_b,
        "out_dtype": torch.bfloat16,
    }
    fast_kwargs: dict[str, Any] = {**default_kwargs, "use_fast_accum": True}

    # NOTE(review): the second operand is passed as a transposed view,
    # presumably because _scaled_mm expects a column-major mat2 - confirm
    # against the torch version in use.
    def eager_default(i: int) -> torch.Tensor:
        idx = i % pool_count
        return torch._scaled_mm(pools_a[idx], pools_b[idx].T, **default_kwargs)

    def eager_fast(i: int) -> torch.Tensor:
        idx = i % pool_count
        return torch._scaled_mm(pools_a[idx], pools_b[idx].T, **fast_kwargs)

    results: list[dict[str, Any]] = [
        cuda_event_bench(
            "A_eager_scaled_mm_default",
            args.matrix_size,
            args.iterations,
            args.warmup,
            eager_default,
        ),
        cuda_event_bench(
            "B_eager_scaled_mm_fast_accum",
            args.matrix_size,
            args.iterations,
            args.warmup,
            eager_fast,
        ),
    ]

    graph_name = "C_cuda_graph_scaled_mm_fast_accum"
    graph_out = torch.empty(
        (args.matrix_size, args.matrix_size),
        device=device,
        dtype=torch.bfloat16,
    )
    # The graph always reads these two tensors and writes graph_out.
    static_a = pools_a[0]
    static_b_t = pools_b[0].T

    try:
        # Run the op a few times on a side stream before capturing it.
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(max(3, args.warmup // 2)):
                torch._scaled_mm(static_a, static_b_t, out=graph_out, **fast_kwargs)
        torch.cuda.current_stream().wait_stream(side_stream)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            torch._scaled_mm(static_a, static_b_t, out=graph_out, **fast_kwargs)

        def graph_replay(_: int) -> None:
            graph.replay()

        results.append(
            cuda_event_bench(
                graph_name,
                args.matrix_size,
                args.iterations,
                3,  # fixed replay warm-up, independent of args.warmup
                graph_replay,
            )
        )
    except Exception as exc:  # noqa: BLE001
        # Best effort: a failing graph path must not discard the A/B results.
        results.append(
            {
                "name": graph_name,
                "status": "unavailable",
                "reason": f"{type(exc).__name__}: {exc}",
            }
        )

    return results
def _te_failure(status: str, exc: BaseException) -> dict[str, Any]:
    """Build the path-D result entry describing why it produced no timings."""
    return {
        "name": "D_transformer_engine_fp8_linear",
        "status": status,
        "reason": f"{type(exc).__name__}: {exc}",
    }


def bench_transformer_engine(args: argparse.Namespace) -> dict[str, Any]:
    """Benchmark path D: a Transformer Engine ``Linear`` under ``fp8_autocast``.

    Args:
        args: Parsed options; reads ``gpu_index``, ``matrix_size``,
            ``iterations`` and ``warmup``.

    Returns:
        The timing dict from ``cuda_event_bench`` plus a ``"note"`` on
        success. If Transformer Engine cannot be imported the entry has
        ``"status": "unavailable"``; if building the layer or running the
        benchmark fails it has ``"status": "error"``. Both carry a
        ``"reason"`` string.

    Raises:
        torch.cuda.OutOfMemoryError: If allocating the input or the layer runs
            out of memory, so the caller can report a run-level OOM.
    """
    try:
        import transformer_engine.pytorch as te  # type: ignore[import-not-found]
        from transformer_engine.common.recipe import (  # type: ignore[import-not-found]
            DelayedScaling,
            Format,
        )
    except Exception as exc:  # noqa: BLE001
        return _te_failure("unavailable", exc)

    device = f"cuda:{args.gpu_index}"
    try:
        x = torch.randn(
            args.matrix_size, args.matrix_size, device=device, dtype=torch.bfloat16
        )
        layer = te.Linear(
            args.matrix_size,
            args.matrix_size,
            bias=False,
            params_dtype=torch.bfloat16,
            device=device,
        )
        recipe = DelayedScaling(fp8_format=Format.HYBRID)
    except torch.cuda.OutOfMemoryError:
        # Allocation OOM keeps propagating, as it did before setup was guarded.
        raise
    except Exception as exc:  # noqa: BLE001
        # Any other setup failure is reported as a result entry so the
        # results already collected for the other paths are still printed.
        return _te_failure("error", exc)

    def run(_: int) -> torch.Tensor:
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            return layer(x)

    try:
        result = cuda_event_bench(
            "D_transformer_engine_fp8_linear",
            args.matrix_size,
            args.iterations,
            args.warmup,
            run,
        )
    except Exception as exc:  # noqa: BLE001
        return _te_failure("error", exc)

    result["note"] = "Transformer Engine Linear forward under fp8_autocast; includes TE module/cast overhead."
    return result


def _argument_errors(args: argparse.Namespace) -> list[str]:
    """Return one message per command-line value outside its valid range."""
    minimums = (
        ("--matrix-size", args.matrix_size, 1),
        ("--warmup", args.warmup, 0),
        ("--iterations", args.iterations, 1),
        ("--gpu-index", args.gpu_index, 0),
        ("--pools", args.pools, 1),
    )
    return [
        f"{flag} must be >= {minimum} (got {value})"
        for flag, value, minimum in minimums
        if value < minimum
    ]


def main(argv: list[str] | None = None) -> int:
    """Run every benchmark path and print one JSON report to stdout.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        ``0`` after a completed run. ``1`` when an argument is out of range,
        CUDA or the FP8 ``_scaled_mm`` API is unavailable, the GPU index does
        not exist, or a CUDA out-of-memory error interrupts the run. In every
        ``1`` case a JSON object with an ``"error"`` key is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--matrix-size", type=int, default=8192)
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--iterations", type=int, default=100)
    parser.add_argument("--gpu-index", type=int, default=0)
    parser.add_argument("--pools", type=int, default=4)
    args = parser.parse_args(argv)

    # Reject bad values first: zero iterations or pools would otherwise only
    # fail deep inside a benchmark, as a traceback instead of a JSON error.
    problems = _argument_errors(args)
    if problems:
        message = "invalid arguments: " + "; ".join(problems)
        print(json.dumps({"error": message}, indent=2))
        return 1

    if not torch.cuda.is_available():
        print(json.dumps({"error": "cuda unavailable"}, indent=2))
        return 1
    if not hasattr(torch, "_scaled_mm") or not hasattr(torch, "float8_e4m3fn"):
        print(json.dumps({"error": "torch FP8 _scaled_mm unavailable"}, indent=2))
        return 1

    device_count = torch.cuda.device_count()
    if args.gpu_index >= device_count:
        message = (
            f"gpu index {args.gpu_index} out of range "
            f"({device_count} CUDA device(s) visible)"
        )
        print(json.dumps({"error": message}, indent=2))
        return 1

    torch.cuda.set_device(args.gpu_index)
    props = torch.cuda.get_device_properties(args.gpu_index)
    payload: dict[str, Any] = {
        "source": "pytorch_fp8_path_bench",
        "torch": torch.__version__,
        "cuda": torch.version.cuda,
        "gpu_index": args.gpu_index,
        "gpu_name": props.name,
        "matrix_size": args.matrix_size,
        "warmup": args.warmup,
        "iterations": args.iterations,
        "results": [],
    }

    try:
        payload["results"].extend(bench_scaled_mm(args))
        payload["results"].append(bench_transformer_engine(args))
    except torch.cuda.OutOfMemoryError as exc:
        # Results gathered before the OOM stay in the printed payload.
        payload["error"] = f"CUDA OOM: {exc}"
        print(json.dumps(payload, indent=2))
        return 1

    ok_values = [r["tflops"] for r in payload["results"] if r.get("status") == "ok"]
    if ok_values:
        payload["summary"] = {
            "max_tflops": round(max(ok_values), 1),
            "min_tflops": round(min(ok_values), 1),
            "mean_tflops": round(statistics.mean(ok_values), 1),
        }
    print(json.dumps(payload, indent=2))
    return 0


if __name__ == "__main__":
    sys.exit(main())