test_gpu_scripts/scripts/pytorch_fp8_path_bench.py

278 lines
8.4 KiB
Python
Executable File

#!/usr/bin/env python3
"""Compare FP8 GEMM paths used for H100/H200 acceptance debugging.
Paths:
A. torch._scaled_mm eager, default accumulation
B. torch._scaled_mm eager, use_fast_accum=True
C. CUDA Graph replay of torch._scaled_mm(out=..., use_fast_accum=True)
D. Transformer Engine Linear under fp8_autocast, when installed
"""
from __future__ import annotations
import argparse
import json
import statistics
import sys
import time
from typing import Any, Callable
import torch
def tflops_from_ms(matrix_size: int, iterations: int, elapsed_ms: float) -> float:
    """Convert a timed batch of square GEMMs into achieved TFLOP/s.

    One ``matrix_size`` x ``matrix_size`` by ``matrix_size`` x ``matrix_size``
    GEMM is counted as ``2 * matrix_size**3`` floating-point operations
    (a multiply and an add per inner-product term).

    Args:
        matrix_size: Side length of the square operands.
        iterations: Number of GEMMs covered by ``elapsed_ms``.
        elapsed_ms: Total elapsed time in milliseconds.

    Returns:
        Throughput in TFLOP/s (1e12 FLOP per second).
    """
    total_flops = 2.0 * matrix_size * matrix_size * matrix_size * iterations
    elapsed_seconds = elapsed_ms / 1000.0
    return total_flops / elapsed_seconds / 1e12
def cuda_event_bench(
    name: str,
    matrix_size: int,
    iterations: int,
    warmup: int,
    func: Callable[[int], Any],
) -> dict[str, Any]:
    """Time ``iterations`` calls of ``func`` with CUDA events after a warmup.

    ``func`` receives the loop index (restarting at 0 for the timed loop) and
    its return value is discarded. The CUDA events are recorded on the current
    stream around the timed loop; host wall-clock time is captured alongside.

    Args:
        name: Label stored in the result's ``"name"`` field.
        matrix_size: Square GEMM size used to derive TFLOP/s.
        iterations: Number of timed calls; must be >= 1.
        warmup: Number of untimed calls made before timing.
        func: Workload invoked once per call with the loop index.

    Returns:
        A result dict with ``status == "ok"``, total and per-iteration event
        time, wall time, and achieved TFLOP/s.

    Raises:
        ValueError: If ``iterations`` < 1; the per-iteration figure would
            otherwise divide by zero after the warmup had already run.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")
    for i in range(warmup):
        func(i)
    # Drain warmup work so it does not leak into the timed window.
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    wall_start = time.perf_counter()
    start.record()
    for i in range(iterations):
        func(i)
    end.record()
    # elapsed_time() is only valid once both events have completed.
    torch.cuda.synchronize()
    wall_elapsed = time.perf_counter() - wall_start
    elapsed_ms = start.elapsed_time(end)
    return {
        "name": name,
        "status": "ok",
        "matrix_size": matrix_size,
        "iterations": iterations,
        "warmup": warmup,
        "event_ms_total": round(elapsed_ms, 3),
        "event_us_per_iter": round(elapsed_ms * 1000.0 / iterations, 3),
        "wall_ms_total": round(wall_elapsed * 1000.0, 3),
        "tflops": round(tflops_from_ms(matrix_size, iterations, elapsed_ms), 1),
    }
def make_fp8_inputs(matrix_size: int, pools: int, device: str) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
a = [
torch.randn(matrix_size, matrix_size, device=device, dtype=torch.float32).to(torch.float8_e4m3fn)
for _ in range(pools)
]
b = [
torch.randn(matrix_size, matrix_size, device=device, dtype=torch.float32).to(torch.float8_e4m3fn)
for _ in range(pools)
]
torch.cuda.synchronize()
return a, b
def _bench_graph_replay(
    args: argparse.Namespace,
    device: str,
    static_a: torch.Tensor,
    static_b_t: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
) -> dict[str, Any]:
    """Capture one fast-accum ``_scaled_mm`` in a CUDA Graph and time its replay (path C).

    Any failure while warming up, capturing or replaying is reported as a
    ``status == "unavailable"`` entry rather than raised. Allocating the output
    buffer happens before that guard, so an OOM there still propagates.
    """
    name = "C_cuda_graph_scaled_mm_fast_accum"
    # Allocated before capture so every replay writes into the same buffer.
    graph_out = torch.empty(
        (args.matrix_size, args.matrix_size),
        device=device,
        dtype=torch.bfloat16,
    )

    def gemm_into_out() -> None:
        torch._scaled_mm(
            static_a,
            static_b_t,
            scale_a=scale_a,
            scale_b=scale_b,
            out_dtype=torch.bfloat16,
            use_fast_accum=True,
            out=graph_out,
        )

    try:
        # Warm up on a side stream before capture, then re-join the current stream.
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            for _ in range(max(3, args.warmup // 2)):
                gemm_into_out()
        torch.cuda.current_stream().wait_stream(side_stream)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            gemm_into_out()

        def graph_replay(_: int) -> None:
            graph.replay()

        result = cuda_event_bench(name, args.matrix_size, args.iterations, 3, graph_replay)
    except Exception as exc:  # noqa: BLE001
        return {
            "name": name,
            "status": "unavailable",
            "reason": f"{type(exc).__name__}: {exc}",
        }
    # Unlike A/B, the captured graph always reads the same operands.
    result["note"] = "Replays a single captured input pair; paths A/B rotate through --pools inputs."
    return result


def bench_scaled_mm(args: argparse.Namespace) -> list[dict[str, Any]]:
    """Benchmark ``torch._scaled_mm`` eagerly (paths A, B) and via CUDA Graph replay (path C).

    Args:
        args: Parsed CLI namespace; reads ``gpu_index``, ``matrix_size``,
            ``pools``, ``iterations`` and ``warmup``.

    Returns:
        One result dict per path, in A, B, C order. A failing eager path yields
        ``status == "error"`` and a failing graph path ``status ==
        "unavailable"``, each with a ``reason``, instead of aborting the run.

    Raises:
        torch.cuda.OutOfMemoryError: If an eager path, input allocation or the
            graph output allocation runs out of memory.
    """
    device = f"cuda:{args.gpu_index}"
    torch.cuda.set_device(args.gpu_index)
    # Unit scales: make_fp8_inputs casts the random data to FP8 without rescaling.
    scale_a = torch.tensor(1.0, device=device)
    scale_b = torch.tensor(1.0, device=device)
    pools_a, pools_b = make_fp8_inputs(args.matrix_size, args.pools, device)

    def make_eager(use_fast_accum: bool) -> Callable[[int], torch.Tensor]:
        """Build an eager GEMM closure that rotates through the input pools."""

        def run(i: int) -> torch.Tensor:
            idx = i % args.pools
            # .T of a contiguous matrix is column-major, which appears to be the
            # mat2 layout _scaled_mm expects.
            return torch._scaled_mm(
                pools_a[idx],
                pools_b[idx].T,
                scale_a=scale_a,
                scale_b=scale_b,
                out_dtype=torch.bfloat16,
                use_fast_accum=use_fast_accum,
            )

        return run

    results: list[dict[str, Any]] = []
    for name, use_fast_accum in (
        ("A_eager_scaled_mm_default", False),
        ("B_eager_scaled_mm_fast_accum", True),
    ):
        try:
            results.append(
                cuda_event_bench(
                    name,
                    args.matrix_size,
                    args.iterations,
                    args.warmup,
                    make_eager(use_fast_accum),
                )
            )
        except torch.cuda.OutOfMemoryError:
            # Leave OOM to the caller's dedicated handling, as before.
            raise
        except Exception as exc:  # noqa: BLE001
            # e.g. a GPU or shape _scaled_mm rejects: record it instead of crashing.
            results.append(
                {
                    "name": name,
                    "status": "error",
                    "reason": f"{type(exc).__name__}: {exc}",
                }
            )

    results.append(_bench_graph_replay(args, device, pools_a[0], pools_b[0].T, scale_a, scale_b))
    return results
def bench_transformer_engine(args: argparse.Namespace) -> dict[str, Any]:
    """Time a Transformer Engine ``Linear`` forward under ``fp8_autocast`` (path D).

    The forward runs under ``torch.no_grad()`` so that, like paths A-C, it is a
    forward-only measurement. With autograd enabled, each call would also
    build an autograd graph (and TE may prepare extra tensors for backward)
    that nothing ever consumes.

    Args:
        args: Parsed CLI namespace; reads ``gpu_index``, ``matrix_size``,
            ``iterations`` and ``warmup``.

    Returns:
        The ``cuda_event_bench`` result plus a ``note`` on success.
        ``status == "unavailable"`` if Transformer Engine cannot be imported.
        ``status == "error"`` if building the layer/input or benchmarking fails.
    """
    name = "D_transformer_engine_fp8_linear"
    try:
        import transformer_engine.pytorch as te  # type: ignore[import-not-found]
        from transformer_engine.common.recipe import DelayedScaling, Format  # type: ignore[import-not-found]
    except Exception as exc:  # noqa: BLE001
        return {
            "name": name,
            "status": "unavailable",
            "reason": f"{type(exc).__name__}: {exc}",
        }

    device = f"cuda:{args.gpu_index}"
    try:
        # Construction is guarded too: e.g. a TE build that rejects a keyword
        # should yield an error entry, not abort the whole report.
        x = torch.randn(args.matrix_size, args.matrix_size, device=device, dtype=torch.bfloat16)
        layer = te.Linear(
            args.matrix_size,
            args.matrix_size,
            bias=False,
            params_dtype=torch.bfloat16,
            device=device,
        )
        recipe = DelayedScaling(fp8_format=Format.HYBRID)

        def run(_: int) -> torch.Tensor:
            with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=recipe):
                return layer(x)

        result = cuda_event_bench(
            name,
            args.matrix_size,
            args.iterations,
            args.warmup,
            run,
        )
    except Exception as exc:  # noqa: BLE001
        return {
            "name": name,
            "status": "error",
            "reason": f"{type(exc).__name__}: {exc}",
        }
    result["note"] = (
        "Transformer Engine Linear forward under fp8_autocast and torch.no_grad(); "
        "includes TE module/cast overhead."
    )
    return result
def main() -> int:
    """Run every FP8 path on one GPU and print a JSON report to stdout.

    Returns:
        Process exit status: 0 when the report was produced; 1 when CUDA or
        ``torch._scaled_mm`` is unavailable, ``--gpu-index`` names no visible
        device, or a benchmark aborted (CUDA OOM or another unexpected error).
        Out-of-range numeric options make argparse exit with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Compare FP8 GEMM paths (torch._scaled_mm eager / CUDA Graph, Transformer Engine).",
    )
    parser.add_argument("--matrix-size", type=int, default=8192, help="side length of the square GEMM operands")
    parser.add_argument("--warmup", type=int, default=20, help="untimed warmup calls before timing")
    parser.add_argument("--iterations", type=int, default=100, help="timed calls per path")
    parser.add_argument("--gpu-index", type=int, default=0, help="CUDA device index to benchmark")
    parser.add_argument("--pools", type=int, default=4, help="distinct input matrices rotated by the eager paths")
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside a benchmark
    # (iterations=0 divides by zero; pools=0 leaves no inputs to rotate through).
    for flag, value, minimum in (
        ("--matrix-size", args.matrix_size, 1),
        ("--warmup", args.warmup, 0),
        ("--iterations", args.iterations, 1),
        ("--gpu-index", args.gpu_index, 0),
        ("--pools", args.pools, 1),
    ):
        if value < minimum:
            parser.error(f"{flag} must be >= {minimum}, got {value}")

    if not torch.cuda.is_available():
        print(json.dumps({"error": "cuda unavailable"}, indent=2))
        return 1
    if not hasattr(torch, "_scaled_mm") or not hasattr(torch, "float8_e4m3fn"):
        print(json.dumps({"error": "torch FP8 _scaled_mm unavailable"}, indent=2))
        return 1
    device_count = torch.cuda.device_count()
    if args.gpu_index >= device_count:
        message = f"gpu index {args.gpu_index} out of range ({device_count} CUDA device(s) visible)"
        print(json.dumps({"error": message}, indent=2))
        return 1

    torch.cuda.set_device(args.gpu_index)
    props = torch.cuda.get_device_properties(args.gpu_index)
    payload = {
        "source": "pytorch_fp8_path_bench",
        "torch": torch.__version__,
        "cuda": torch.version.cuda,
        "gpu_index": args.gpu_index,
        "gpu_name": props.name,
        "matrix_size": args.matrix_size,
        "warmup": args.warmup,
        "iterations": args.iterations,
        "pools": args.pools,
        "results": [],
    }
    try:
        payload["results"].extend(bench_scaled_mm(args))
        payload["results"].append(bench_transformer_engine(args))
    except torch.cuda.OutOfMemoryError as exc:
        payload["error"] = f"CUDA OOM: {exc}"
        print(json.dumps(payload, indent=2))
        return 1
    except Exception as exc:  # noqa: BLE001
        import traceback

        # Keep stdout machine-readable (with any results gathered so far) and
        # send the traceback to stderr for debugging.
        traceback.print_exc()
        payload["error"] = f"{type(exc).__name__}: {exc}"
        print(json.dumps(payload, indent=2))
        return 1

    ok_values = [r["tflops"] for r in payload["results"] if r.get("status") == "ok"]
    if ok_values:
        payload["summary"] = {
            "max_tflops": round(max(ok_values), 1),
            "min_tflops": round(min(ok_values), 1),
            "mean_tflops": round(statistics.mean(ok_values), 1),
        }
    print(json.dumps(payload, indent=2))
    return 0
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())