44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
import os
|
|
|
|
from functools import wraps
|
|
import time
|
|
import numpy as np
|
|
import torch
|
|
|
|
# def measure_time(func):
|
|
# @wraps(func)
|
|
# def wrapper(*args, **kwargs):
|
|
# begin_time = time.time()
|
|
# result = func(* args, **kwargs)
|
|
# elapsed_ms = 1000 * (time.time() - begin_time)
|
|
# logger.info("\033[1;31m" + f"{func.__name__}: {elapsed_ms:.2f} ms" + "\033[0m")
|
|
# return result
|
|
# return wrapper
|
|
|
|
# from functools import wraps
|
|
# import time
|
|
|
|
def measure_time(logger):
    """Return a decorator that logs how long each call of the wrapped function takes.

    Args:
        logger: Any object with an ``info(msg)`` method (e.g. a
            ``logging.Logger``). It receives one message per call that
            returns normally.

    Returns:
        A decorator. The wrapped function keeps the original's name and
        docstring (via ``functools.wraps``) and returns the original result
        unchanged. If the wrapped function raises, nothing is logged and the
        exception propagates.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic and high-resolution; time.time() follows
            # the wall clock and can jump (even backwards) on clock adjustments.
            begin_time = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed_ms = 1000 * (time.perf_counter() - begin_time)
            # ANSI escapes: "\033[1;31m" = bold red, "\033[0m" = reset.
            logger.info("\033[1;31m" + f"{func.__name__}: {elapsed_ms:.2f} ms" + "\033[0m")
            return result

        return wrapper

    return decorator
|
|
|
|
def show_data_summary(data):
    """Print one summary line per entry of a mapping.

    Args:
        data: A mapping (anything with ``.items()``) from key to value.

    For each entry:
      * ``np.ndarray``: shape, dtype, type and ``min~max`` range.
      * ``torch.Tensor``: the same, followed by the tensor's device.
      * anything else: the value itself and its type.

    Zero-size arrays/tensors print ``(empty)`` instead of a range, because
    ``min()``/``max()`` raise on empty inputs in both NumPy and PyTorch.
    """
    for k, v in data.items():
        if isinstance(v, np.ndarray):
            head = f"{k}: {v.shape} {v.dtype} {type(v)}"
            if v.size == 0:
                print(f"{head} (empty)")
            else:
                print(f"{head} {v.min():.4f}~{v.max():.4f}")
        elif isinstance(v, torch.Tensor):
            head = f"{k}: {v.shape} {v.dtype} {type(v)}"
            if v.numel() == 0:
                print(f"{head} (empty), {v.device}")
            else:
                print(f"{head} {v.min():.4f}~{v.max():.4f}, {v.device}")
        else:
            print(f"{k}: {v} {type(v)}")
|
|
|