2026-03-03 16:24:02 +08:00

44 lines
1.4 KiB
Python

import os
from functools import wraps
import time
import numpy as np
import torch
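# Earlier variant of measure_time, kept for reference. It used a `logger`
# name taken from the enclosing scope; the version below receives the
# logger as an argument instead.
# def measure_time(func):
#     @wraps(func)
#     def wrapper(*args, **kwargs):
#         begin_time = time.time()
#         result = func(*args, **kwargs)
#         elapsed_ms = 1000 * (time.time() - begin_time)
#         logger.info("\033[1;31m" + f"{func.__name__}: {elapsed_ms:.2f} ms" + "\033[0m")
#         return result
#     return wrapper
# from functools import wraps
# import time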
def measure_time(logger):
    """Build a decorator that logs how long each call to a function takes.

    Args:
        logger: Object exposing an ``info(message)`` method. It receives one
            message per call that returns normally.

    Returns:
        A decorator. The decorated function forwards all positional and
        keyword arguments, returns the wrapped function's result unchanged,
        and logs ``"<function name>: <elapsed> ms"`` wrapped in the ANSI
        escape codes ``\\033[1;31m`` ... ``\\033[0m``.

    Note:
        If the wrapped function raises, the exception propagates and nothing
        is logged for that call.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter() is monotonic and high-resolution, so the measured
            # interval cannot be skewed by system clock adjustments the way a
            # time.time() difference can.
            begin_time = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed_ms = 1000 * (time.perf_counter() - begin_time)
            logger.info("\033[1;31m" + f"{func.__name__}: {elapsed_ms:.2f} ms" + "\033[0m")
            return result
        return wrapper
    return decorator
def show_data_summary(data):
    """Print a one-line summary for every entry of a dict-like dataset.

    NumPy arrays and torch tensors are summarized as key, shape, dtype,
    Python type and ``min~max`` value range (four decimals); tensors also
    show their device. Zero-size arrays and tensors show ``empty`` in place
    of the range. Any other value is printed as-is followed by its type.

    Args:
        data: Mapping from names to values; only ``data.items()`` is used.
    """
    for k, v in data.items():
        if isinstance(v, np.ndarray):
            # min()/max() raise on zero-size arrays, so report those as empty
            # instead of aborting the whole summary.
            value_range = f"{v.min():.4f}~{v.max():.4f}" if v.size else "empty"
            print(f"{k}: {v.shape} {v.dtype} {type(v)} {value_range}")
        elif isinstance(v, torch.Tensor):
            # Same guard for tensors with no elements.
            value_range = f"{v.min():.4f}~{v.max():.4f}" if v.numel() else "empty"
            print(f"{k}: {v.shape} {v.dtype} {type(v)} {value_range}, {v.device}")
        else:
            print(f"{k}: {v} {type(v)}")