refactor: remove hardcoding, fix AMP bug, unify English output
- Fix AMP autocast: bf16 now uses torch.amp.autocast (was skipped) - Fix NCCL threshold: unknown GPU gets 10 GB/s floor instead of 0 - Fix PCIe health check: use specs-driven pcie_gen, not hardcoded Gen4 - Remove hardcoded GPU lists: dynamic banner, CLI choices, version - Unknown GPU efficiency displays N/A instead of 0% - Unify all console output to English (stress_test, gpu_tester) - Use importlib.metadata for runtime version resolution - Remove dir="/tmp" from tempfile (use system default) 🤖 Generated with [Qoder][https://qoder.com]
This commit is contained in:
parent
f2158f6cd3
commit
fefef8e03b
@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""GPU Training Server Test Suite (A100/A800/H100/H200/B200/B300) - Main CLI Entry Point."""
|
"""GPU Training Server Test Suite - Main CLI Entry Point."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
@ -74,17 +74,19 @@ DEFAULT_CONFIG = {
|
|||||||
"tools": {"install_dir": "/opt/gpu-test-tools"},
|
"tools": {"install_dir": "/opt/gpu-test-tools"},
|
||||||
}
|
}
|
||||||
|
|
||||||
BANNER = r"""
|
def _build_banner() -> str:
|
||||||
[bold cyan]
|
gpu_list = " / ".join(g.upper() for g in get_supported_gpus())
|
||||||
╔══════════════════════════════════════════════════════╗
|
return (
|
||||||
║ ║
|
"[bold cyan]\n"
|
||||||
║ GPU Training Server Test Suite ║
|
"╔══════════════════════════════════════════════════════════╗\n"
|
||||||
║ Diagnostics & Benchmarking Tool ║
|
"║ ║\n"
|
||||||
║ Supports: A100 / A800 / H100 / H200 / B200 / B300 ║
|
"║ GPU Training Server Test Suite ║\n"
|
||||||
║ ║
|
"║ Diagnostics & Benchmarking Tool ║\n"
|
||||||
╚══════════════════════════════════════════════════════════╝
|
f"║ Supports: {gpu_list:<40s} ║\n"
|
||||||
[/bold cyan]
|
"║ ║\n"
|
||||||
"""
|
"╚══════════════════════════════════════════════════════════╝\n"
|
||||||
|
"[/bold cyan]"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_config() -> dict:
|
def load_config() -> dict:
|
||||||
@ -112,7 +114,7 @@ def interactive_menu(config: dict):
|
|||||||
"""Run interactive menu loop."""
|
"""Run interactive menu loop."""
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
console.print(BANNER)
|
console.print(_build_banner())
|
||||||
|
|
||||||
gpu_type = detect_gpu_type()
|
gpu_type = detect_gpu_type()
|
||||||
gpu_label = get_gpu_label(gpu_type)
|
gpu_label = get_gpu_label(gpu_type)
|
||||||
@ -310,7 +312,7 @@ def _run_full_suite(config: dict, console: Console) -> dict:
|
|||||||
|
|
||||||
# Summary
|
# Summary
|
||||||
console.print("\n" + "=" * 60)
|
console.print("\n" + "=" * 60)
|
||||||
# 只统计测试结果,排除 timestamp 等元数据
|
# Only count test results, exclude metadata like timestamp
|
||||||
test_results = {k: v for k, v in all_results.items() if k != "timestamp"}
|
test_results = {k: v for k, v in all_results.items() if k != "timestamp"}
|
||||||
passed = sum(1 for v in test_results.values() if not isinstance(v, dict) or "error" not in v)
|
passed = sum(1 for v in test_results.values() if not isinstance(v, dict) or "error" not in v)
|
||||||
total = len(test_results)
|
total = len(test_results)
|
||||||
@ -320,8 +322,9 @@ def _run_full_suite(config: dict, console: Console) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
gpu_list_str = " / ".join(g.upper() for g in get_supported_gpus())
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="GPU Training Server Test Suite (A100/A800/H100/H200/B200/B300)",
|
description=f"GPU Training Server Test Suite ({gpu_list_str})",
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
epilog="""
|
epilog="""
|
||||||
Examples:
|
Examples:
|
||||||
@ -349,7 +352,7 @@ Examples:
|
|||||||
parser.add_argument("--config", default=None, help="Path to config YAML file")
|
parser.add_argument("--config", default=None, help="Path to config YAML file")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--gpu-type",
|
"--gpu-type",
|
||||||
choices=["auto", "a100", "a800", "h100", "h200", "b200", "b300"],
|
choices=["auto"] + get_supported_gpus(),
|
||||||
default="auto",
|
default="auto",
|
||||||
help="Override GPU type detection",
|
help="Override GPU type detection",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -151,13 +151,13 @@ class Benchmark:
|
|||||||
# (nvlink_bandwidth_gbps is bidirectional, so per-direction = /2)
|
# (nvlink_bandwidth_gbps is bidirectional, so per-direction = /2)
|
||||||
nvlink_bw = self.specs.get("nvlink_bandwidth_gbps", 0)
|
nvlink_bw = self.specs.get("nvlink_bandwidth_gbps", 0)
|
||||||
d2d_peak = nvlink_bw / 2 if nvlink_bw else 0
|
d2d_peak = nvlink_bw / 2 if nvlink_bw else 0
|
||||||
d2d_efficiency = (d2d_bw / d2d_peak) * 100 if (d2d_bw and d2d_peak) else 0
|
d2d_efficiency = round((d2d_bw / d2d_peak) * 100, 1) if (d2d_bw and d2d_peak) else None
|
||||||
|
|
||||||
# H2D/D2H goes through PCIe — estimate peak from PCIe gen
|
# H2D/D2H goes through PCIe — estimate peak from PCIe gen
|
||||||
pcie_gen = self.specs.get("pcie_gen", 4)
|
pcie_gen = self.specs.get("pcie_gen", 0)
|
||||||
pcie_peak = {3: 16, 4: 32, 5: 64, 6: 128}.get(pcie_gen, 32) # GB/s x16
|
pcie_peak = {3: 16, 4: 32, 5: 64, 6: 128}.get(pcie_gen, 32) if pcie_gen > 0 else 0 # GB/s x16
|
||||||
h2d_efficiency = (h2d_bw / pcie_peak) * 100 if (h2d_bw and pcie_peak) else 0
|
h2d_efficiency = round((h2d_bw / pcie_peak) * 100, 1) if (h2d_bw and pcie_peak) else None
|
||||||
d2h_efficiency = (d2h_bw / pcie_peak) * 100 if (d2h_bw and pcie_peak) else 0
|
d2h_efficiency = round((d2h_bw / pcie_peak) * 100, 1) if (d2h_bw and pcie_peak) else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"memory": {
|
"memory": {
|
||||||
@ -165,14 +165,14 @@ class Benchmark:
|
|||||||
"h2d_bandwidth_gbps": round(h2d_bw, 1),
|
"h2d_bandwidth_gbps": round(h2d_bw, 1),
|
||||||
"d2h_bandwidth_gbps": round(d2h_bw, 1),
|
"d2h_bandwidth_gbps": round(d2h_bw, 1),
|
||||||
"d2d_bandwidth_gbps": round(d2d_bw, 1),
|
"d2d_bandwidth_gbps": round(d2d_bw, 1),
|
||||||
"h2d_peak_gbps": pcie_peak,
|
"h2d_peak_gbps": pcie_peak if pcie_peak else None,
|
||||||
"d2h_peak_gbps": pcie_peak,
|
"d2h_peak_gbps": pcie_peak if pcie_peak else None,
|
||||||
"d2d_peak_gbps": round(d2d_peak, 1),
|
"d2d_peak_gbps": round(d2d_peak, 1) if d2d_peak else None,
|
||||||
"h2d_efficiency_pct": round(h2d_efficiency, 1),
|
"h2d_efficiency_pct": h2d_efficiency,
|
||||||
"d2h_efficiency_pct": round(d2h_efficiency, 1),
|
"d2h_efficiency_pct": d2h_efficiency,
|
||||||
"d2d_efficiency_pct": round(d2d_efficiency, 1),
|
"d2d_efficiency_pct": d2d_efficiency,
|
||||||
"peak_bandwidth_gbps": self.specs["memory_bandwidth_gbps"],
|
"peak_bandwidth_gbps": self.specs["memory_bandwidth_gbps"],
|
||||||
"efficiency_pct": round(d2d_efficiency, 1),
|
"efficiency_pct": d2d_efficiency,
|
||||||
"results_by_test": results_by_test,
|
"results_by_test": results_by_test,
|
||||||
"per_gpu": [],
|
"per_gpu": [],
|
||||||
}
|
}
|
||||||
@ -276,7 +276,7 @@ class Benchmark:
|
|||||||
|
|
||||||
best_d2d = max(v["d2d_gbps"] for v in bandwidth_by_size.values())
|
best_d2d = max(v["d2d_gbps"] for v in bandwidth_by_size.values())
|
||||||
peak_bw = self.specs["memory_bandwidth_gbps"]
|
peak_bw = self.specs["memory_bandwidth_gbps"]
|
||||||
efficiency = (best_d2d / peak_bw) * 100 if peak_bw else 0.0
|
efficiency = round((best_d2d / peak_bw) * 100, 1) if peak_bw else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"memory": {
|
"memory": {
|
||||||
@ -285,7 +285,7 @@ class Benchmark:
|
|||||||
"d2h_bandwidth_gbps": round(max(v["d2h_gbps"] for v in bandwidth_by_size.values()), 1),
|
"d2h_bandwidth_gbps": round(max(v["d2h_gbps"] for v in bandwidth_by_size.values()), 1),
|
||||||
"d2d_bandwidth_gbps": round(best_d2d, 1),
|
"d2d_bandwidth_gbps": round(best_d2d, 1),
|
||||||
"peak_bandwidth_gbps": self.specs["memory_bandwidth_gbps"],
|
"peak_bandwidth_gbps": self.specs["memory_bandwidth_gbps"],
|
||||||
"efficiency_pct": round(efficiency, 1),
|
"efficiency_pct": efficiency,
|
||||||
"test_sizes_mb": test_sizes_mb,
|
"test_sizes_mb": test_sizes_mb,
|
||||||
"bandwidth_by_size": bandwidth_by_size,
|
"bandwidth_by_size": bandwidth_by_size,
|
||||||
"per_gpu": [],
|
"per_gpu": [],
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
"""GPU specifications database for NVIDIA datacenter GPUs (A100/A800/H100/H200/B200/B300)."""
|
"""GPU specifications database for NVIDIA datacenter GPUs."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
"""Hardware health monitoring module for NVIDIA datacenter GPUs (A100/A800/H100/H200/B200/B300)."""
|
"""Hardware health monitoring module for NVIDIA datacenter GPUs."""
|
||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import shutil
|
import shutil
|
||||||
@ -115,7 +115,11 @@ class HealthCheck:
|
|||||||
|
|
||||||
pg = self._safe_int(pcie_gens[i] if i < len(pcie_gens) else 0)
|
pg = self._safe_int(pcie_gens[i] if i < len(pcie_gens) else 0)
|
||||||
pw = self._safe_int(pcie_widths[i] if i < len(pcie_widths) else 0)
|
pw = self._safe_int(pcie_widths[i] if i < len(pcie_widths) else 0)
|
||||||
pcie_ok = pg >= 4 and pw >= 8
|
expected_gen = self.specs.get("pcie_gen", 0)
|
||||||
|
if expected_gen > 0:
|
||||||
|
pcie_ok = pg >= expected_gen and pw >= 16
|
||||||
|
else:
|
||||||
|
pcie_ok = pw >= 8 # unknown GPU: just check width
|
||||||
if not pcie_ok:
|
if not pcie_ok:
|
||||||
overall_pass = False
|
overall_pass = False
|
||||||
checks["pcie_link"] = {"gen": pg, "width": pw, "status": "PASS" if pcie_ok else "WARN"}
|
checks["pcie_link"] = {"gen": pg, "width": pw, "status": "PASS" if pcie_ok else "WARN"}
|
||||||
|
|||||||
@ -79,9 +79,17 @@ class NCCLTest:
|
|||||||
if self.nccl_cfg.get("test_sendrecv", False):
|
if self.nccl_cfg.get("test_sendrecv", False):
|
||||||
tests.append(("sendrecv_perf", "SendRecv"))
|
tests.append(("sendrecv_perf", "SendRecv"))
|
||||||
|
|
||||||
default_min_bw = self.specs.get("nvlink_bandwidth_gbps", 900) * 0.4
|
nvlink_bw = self.specs.get("nvlink_bandwidth_gbps", 0)
|
||||||
|
if nvlink_bw > 0:
|
||||||
|
default_min_bw = nvlink_bw * 0.4
|
||||||
|
else:
|
||||||
|
# Conservative floor: any working NVLink should exceed 10 GB/s
|
||||||
|
default_min_bw = 10
|
||||||
min_bw = self.nccl_cfg.get("min_bandwidth_gbps") or round(default_min_bw)
|
min_bw = self.nccl_cfg.get("min_bandwidth_gbps") or round(default_min_bw)
|
||||||
|
|
||||||
|
if self.gpu_type == "unknown":
|
||||||
|
self.console.print("[yellow]Unknown GPU — using conservative bandwidth thresholds[/yellow]")
|
||||||
|
|
||||||
# Strategy: try nccl-tests binary directly (single-node, -g N),
|
# Strategy: try nccl-tests binary directly (single-node, -g N),
|
||||||
# then mpirun, then torchrun fallback
|
# then mpirun, then torchrun fallback
|
||||||
results = {}
|
results = {}
|
||||||
@ -317,7 +325,7 @@ except Exception as e:
|
|||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
"""
|
"""
|
||||||
import tempfile
|
import tempfile
|
||||||
tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, dir="/tmp")
|
tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False)
|
||||||
tmp.write(code)
|
tmp.write(code)
|
||||||
tmp.close()
|
tmp.close()
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,12 @@ from datetime import datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
from importlib.metadata import version as _pkg_version
|
||||||
|
__version__ = _pkg_version("gpu-server-test-suite")
|
||||||
|
except Exception:
|
||||||
|
__version__ = "0.2.0"
|
||||||
|
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
@ -368,7 +374,7 @@ class ReportGenerator:
|
|||||||
|
|
||||||
# --- Footer ---
|
# --- Footer ---
|
||||||
lines.append("---")
|
lines.append("---")
|
||||||
lines.append(f"*Generated by GPU Test Suite v0.2.0*")
|
lines.append(f"*Generated by GPU Test Suite v{__version__}*")
|
||||||
|
|
||||||
content = "\n".join(lines)
|
content = "\n".join(lines)
|
||||||
with open(output, "w") as f:
|
with open(output, "w") as f:
|
||||||
|
|||||||
@ -49,13 +49,13 @@ class StressTest:
|
|||||||
gpu_burn = self._find_gpu_burn()
|
gpu_burn = self._find_gpu_burn()
|
||||||
|
|
||||||
if gpu_burn:
|
if gpu_burn:
|
||||||
# 尝试使用 gpu-burn
|
# Try gpu-burn first
|
||||||
result = self._run_gpu_burn(gpu_burn, duration_sec, use_doubles, use_tensor_cores, target_gpus)
|
result = self._run_gpu_burn(gpu_burn, duration_sec, use_doubles, use_tensor_cores, target_gpus)
|
||||||
|
|
||||||
# 如果 gpu-burn 失败(例如显存不足),自动 fallback 到 PyTorch
|
# If gpu-burn fails (e.g. OOM), auto-fallback to PyTorch
|
||||||
if not result.get("passed") and result.get("elapsed_sec", 0) < duration_sec * 0.5:
|
if not result.get("passed") and result.get("elapsed_sec", 0) < duration_sec * 0.5:
|
||||||
self.console.print("\n[yellow]gpu-burn 提前退出(可能显存不足),自动切换到 PyTorch 压力测试[/yellow]")
|
self.console.print("\n[yellow]gpu-burn exited early (possible OOM), switching to PyTorch stress test[/yellow]")
|
||||||
self.console.print("[dim]PyTorch 模式会根据实际可用显存动态调整,更稳定[/dim]\n")
|
self.console.print("[dim]PyTorch mode dynamically adapts to available memory[/dim]\n")
|
||||||
return self._run_pytorch_stress(duration_sec, memory_pct)
|
return self._run_pytorch_stress(duration_sec, memory_pct)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@ -134,18 +134,16 @@ class StressTest:
|
|||||||
tensors = {}
|
tensors = {}
|
||||||
for i in range(gpu_count):
|
for i in range(gpu_count):
|
||||||
with torch.cuda.device(i):
|
with torch.cuda.device(i):
|
||||||
# 获取实际可用显存(考虑其他进程已占用的部分)
|
# Get actual free memory (accounting for other processes)
|
||||||
free_mem, total_mem = torch.cuda.mem_get_info(i)
|
free_mem, total_mem = torch.cuda.mem_get_info(i)
|
||||||
|
|
||||||
# 根据配置的 memory_pct 计算分配大小
|
# Calculate allocation from configured memory_pct
|
||||||
# 例如:memory_pct=90 表示使用总显存的 90%
|
|
||||||
target_mem = int(total_mem * memory_pct / 100)
|
target_mem = int(total_mem * memory_pct / 100)
|
||||||
|
|
||||||
# 但不能超过实际可用显存(留出 5% 安全余量)
|
# Cap at actual free memory with 5% safety margin
|
||||||
alloc_bytes = min(target_mem, int(free_mem * 0.95))
|
alloc_bytes = min(target_mem, int(free_mem * 0.95))
|
||||||
|
|
||||||
# matmul(A, A.T) 需要 2x 输入显存(输入 + 输出)
|
# matmul(A, A.T) needs 2x input memory (input + output)
|
||||||
# 所以分配 sqrt(alloc_bytes/4/2) 大小的方阵
|
|
||||||
side = int((alloc_bytes / 4 / 2) ** 0.5) # float32 = 4 bytes
|
side = int((alloc_bytes / 4 / 2) ** 0.5) # float32 = 4 bytes
|
||||||
|
|
||||||
actual_mem_mb = side * side * 4 / 1024 / 1024
|
actual_mem_mb = side * side * 4 / 1024 / 1024
|
||||||
@ -153,13 +151,13 @@ class StressTest:
|
|||||||
free_mem_mb = free_mem / 1024 / 1024
|
free_mem_mb = free_mem / 1024 / 1024
|
||||||
|
|
||||||
self.console.print(
|
self.console.print(
|
||||||
f" [dim]GPU {i}: 总显存 {total_mem_mb:.0f}MB, 可用 {free_mem_mb:.0f}MB, "
|
f" [dim]GPU {i}: total {total_mem_mb:.0f}MB, free {free_mem_mb:.0f}MB, "
|
||||||
f"分配 {actual_mem_mb:.0f}MB ({actual_mem_mb/total_mem_mb*100:.0f}%) - "
|
f"alloc {actual_mem_mb:.0f}MB ({actual_mem_mb/total_mem_mb*100:.0f}%) - "
|
||||||
f"矩阵 {side}x{side}[/dim]"
|
f"matrix {side}x{side}[/dim]"
|
||||||
)
|
)
|
||||||
tensors[i] = torch.randn(side, side, device=f"cuda:{i}", dtype=torch.float32)
|
tensors[i] = torch.randn(side, side, device=f"cuda:{i}", dtype=torch.float32)
|
||||||
|
|
||||||
self.console.print(f"\n[cyan]开始压力测试,持续 {duration} 秒...[/cyan]")
|
self.console.print(f"\n[cyan]Starting stress test for {duration} seconds...[/cyan]")
|
||||||
|
|
||||||
elapsed_check = 0
|
elapsed_check = 0
|
||||||
while time.time() - t0 < duration:
|
while time.time() - t0 < duration:
|
||||||
@ -169,10 +167,10 @@ class StressTest:
|
|||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
# 每 10 秒显示一次进度
|
# Show progress every 10 seconds
|
||||||
current_elapsed = time.time() - t0
|
current_elapsed = time.time() - t0
|
||||||
if int(current_elapsed) != int(elapsed_check) and int(current_elapsed) % 10 == 0:
|
if int(current_elapsed) != int(elapsed_check) and int(current_elapsed) % 10 == 0:
|
||||||
self.console.print(f" [dim]已运行 {int(current_elapsed)}s / {duration}s[/dim]")
|
self.console.print(f" [dim]Running {int(current_elapsed)}s / {duration}s[/dim]")
|
||||||
elapsed_check = current_elapsed
|
elapsed_check = current_elapsed
|
||||||
|
|
||||||
for i in range(gpu_count):
|
for i in range(gpu_count):
|
||||||
@ -180,7 +178,7 @@ class StressTest:
|
|||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
self.console.print(f"\n[red]压力测试出错: {error_msg}[/red]")
|
self.console.print(f"\n[red]Stress test error: {error_msg}[/red]")
|
||||||
for i in range(gpu_count):
|
for i in range(gpu_count):
|
||||||
if i not in gpu_status:
|
if i not in gpu_status:
|
||||||
gpu_status[i] = "FAIL"
|
gpu_status[i] = "FAIL"
|
||||||
|
|||||||
@ -77,7 +77,7 @@ class TrainingSim:
|
|||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
|
||||||
|
|
||||||
if dtype in (torch.float16, torch.bfloat16):
|
if dtype in (torch.float16, torch.bfloat16):
|
||||||
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == torch.float16))
|
scaler = torch.amp.GradScaler("cuda", enabled=(dtype == torch.float16))
|
||||||
|
|
||||||
step_times = []
|
step_times = []
|
||||||
mem_usage = []
|
mem_usage = []
|
||||||
@ -96,8 +96,8 @@ class TrainingSim:
|
|||||||
input_ids = input_ids.to(model.device)
|
input_ids = input_ids.to(model.device)
|
||||||
attention_mask = attention_mask.to(model.device)
|
attention_mask = attention_mask.to(model.device)
|
||||||
|
|
||||||
if dtype in (torch.float16, torch.bfloat16) and dtype != torch.bfloat16:
|
if dtype in (torch.float16, torch.bfloat16):
|
||||||
with torch.cuda.amp.autocast(dtype=dtype):
|
with torch.amp.autocast("cuda", dtype=dtype):
|
||||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
|
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
|
||||||
loss = outputs.loss
|
loss = outputs.loss
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user