import os
from functools import wraps
import time

import numpy as np
import torch


def measure_time(logger):
    """Return a decorator that logs the wrapped function's elapsed time via ``logger``.

    Args:
        logger: Any object with an ``info(msg)`` method (e.g. a ``logging.Logger``).

    Returns:
        A decorator. Each call of the decorated function logs
        ``"<func name>: <elapsed> ms"`` (two decimals) wrapped in the ANSI
        escape codes ``\\033[1;31m`` ... ``\\033[0m``, then returns the
        function's result unchanged. If the function raises, nothing is logged
        and the exception propagates.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic and high-resolution; time.time() is
            # wall-clock time and can jump when the system clock is adjusted.
            begin_time = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed_ms = 1000 * (time.perf_counter() - begin_time)
            logger.info("\033[1;31m" + f"{func.__name__}: {elapsed_ms:.2f} ms" + "\033[0m")
            return result
        return wrapper
    return decorator


def show_data_summary(data):
    """Print one summary line per entry of ``data``.

    Args:
        data: A mapping (anything with ``.items()``).
            - ``np.ndarray`` values: prints shape, dtype, type and the
              ``min~max`` range (4 decimals).
            - ``torch.Tensor`` values: as above, plus the tensor's device.
            - Zero-size arrays/tensors print ``(empty)`` in place of the range.
            - Any other value is printed followed by its type.
    """
    for k, v in data.items():
        if isinstance(v, np.ndarray):
            # min()/max() on a zero-size array raise ValueError (no identity).
            value_range = "(empty)" if v.size == 0 else f"{v.min():.4f}~{v.max():.4f}"
            print(f"{k}: {v.shape} {v.dtype} {type(v)} {value_range}")
        elif isinstance(v, torch.Tensor):
            # Tensor.min()/max() raise RuntimeError on a tensor with no elements.
            value_range = "(empty)" if v.numel() == 0 else f"{v.min():.4f}~{v.max():.4f}"
            print(f"{k}: {v.shape} {v.dtype} {type(v)} {value_range}, {v.device}")
        else:
            print(f"{k}: {v} {type(v)}")