"""
|
|
This file will compute the min, max, mean, and standard deviation of each datasets
|
|
in `pretrain_datasets.json` or `pretrain_datasets.json`.
|
|
"""

import json
import argparse
import os

# from multiprocessing import Pool, Manager

import tensorflow as tf
import numpy as np
from tqdm import tqdm

from data.vla_dataset import VLADataset
from data.hdf5_vla_dataset import HDF5VLADataset
from data.preprocess import generate_json_state


# Process each dataset to get the statistics
@tf.autograph.experimental.do_not_convert
def process_dataset(name_dataset_pair):
    # print(f"PID {os.getpid()} processing {name_dataset_pair[0]}")
    dataset_iter = name_dataset_pair[1]

    MAX_EPISODES = 100000
    EPS = 1e-8
    # For debugging
    # MAX_EPISODES = 10
    episode_cnt = 0
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None
    state_max = None
    state_min = None
    for episode in dataset_iter:
        episode_cnt += 1
        if episode_cnt % 1000 == 0:
            print(f"Processing episode {episode_cnt}/{MAX_EPISODES}")
        if episode_cnt > MAX_EPISODES:
            break
        episode_dict = episode["episode_dict"]
        dataset_name = episode["dataset_name"]

        res_tup = generate_json_state(episode_dict, dataset_name)
        states = res_tup[1]

        # Convert to numpy
        states = states.numpy()

        # Zero out the values that are close to zero
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0
        # Accumulate the per-dimension non-zero count
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)

        # Update the running statistics
        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))

    # Clamp the non-zero count to at least one to avoid division by zero
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))

    # The std normalizes the (near-zero-masked) squared sum by the non-zero count rather than
    # the total count, so dimensions that are populated only occasionally keep a std on the
    # scale of their non-zero values.
    result = {
        "dataset_name": name_dataset_pair[0],
        "state_mean": (state_sum / state_cnt).tolist(),
        "state_std": np.sqrt(
            np.maximum(
                (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                np.zeros_like(state_sum_sq),
            )).tolist(),
        "state_min": state_min.tolist(),
        "state_max": state_max.tolist(),
    }

    return result
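

# A minimal, self-contained sketch (not part of the original pipeline) illustrating the
# masked-std formula used in process_dataset() and process_hdf5_dataset(): the squared sum of
# near-zero-masked values is normalized by the non-zero count, so a dimension that is non-zero
# in only a few steps gets a std reflecting the scale of those values instead of being deflated
# by the zeros. `_demo_masked_std` is a hypothetical helper added purely for illustration.
def _demo_masked_std():
    EPS = 1e-8
    states = np.array([
        [0.0, 1.0, 0.0],
        [0.0, 3.0, 0.0],
        [0.0, 5.0, 2.0],
    ])
    z_states = states.copy()
    z_states[np.abs(states) <= EPS] = 0
    nz_cnt = np.maximum(np.sum(np.abs(states) > EPS, axis=0), 1)
    cnt = states.shape[0]
    mean = states.sum(axis=0) / cnt
    var = (z_states**2).sum(axis=0) / nz_cnt - mean**2 * (cnt / nz_cnt)
    std = np.sqrt(np.maximum(var, 0))
    # Dim 0 is all zeros (std 0), dim 1 matches np.std([1, 3, 5]) ~ 1.63, and the sparse
    # dim 2 also comes out ~ 1.63, larger than np.std([0, 0, 2]) ~ 0.94.
    print("mean:", mean, "std:", std)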


def process_hdf5_dataset(vla_dataset):
    EPS = 1e-8
    episode_cnt = 0
    state_sum = 0
    state_sum_sq = 0
    z_state_sum = 0
    z_state_sum_sq = 0
    state_cnt = 0
    nz_state_cnt = None
    state_max = None
    state_min = None
    for i in tqdm(range(len(vla_dataset))):
        episode = vla_dataset.get_item(i, state_only=True)
        episode_cnt += 1

        states = episode["state"]

        # Zero out the values that are close to zero
        z_states = states.copy()
        z_states[np.abs(states) <= EPS] = 0
        # Accumulate the per-dimension non-zero count
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)

        # Update the running statistics
        state_sum += np.sum(states, axis=0)
        state_sum_sq += np.sum(states**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]
        if state_max is None:
            state_max = np.max(states, axis=0)
            state_min = np.min(states, axis=0)
        else:
            state_max = np.maximum(state_max, np.max(states, axis=0))
            state_min = np.minimum(state_min, np.min(states, axis=0))

    # Clamp the non-zero count to at least one to avoid division by zero
    nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))

    result = {
        "dataset_name": vla_dataset.get_dataset_name(),
        "state_mean": (state_sum / state_cnt).tolist(),
        "state_std": np.sqrt(
            np.maximum(
                (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
                np.zeros_like(state_sum_sq),
            )).tolist(),
        "state_min": state_min.tolist(),
        "state_max": state_max.tolist(),
    }

    return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Multiprocessing is currently buggy
    # parser.add_argument('--n_workers', type=int, default=1,
    #                     help="Number of parallel workers.")
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        help="Whether to load the pretrain dataset or finetune dataset.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="configs/dataset_stat.json",
        help="JSON file path to save the dataset statistics.",
    )
    parser.add_argument(
        "--skip_exist",
        action="store_true",
        help="Whether to skip the existing dataset statistics.",
    )
    parser.add_argument(
        "--hdf5_dataset",
        action="store_true",
        help="Whether to load the dataset from the HDF5 files.",
    )
    args = parser.parse_args()

    if args.hdf5_dataset:
        vla_dataset = HDF5VLADataset()
        dataset_name = vla_dataset.get_dataset_name()

        try:
            with open(args.save_path, "r") as f:
                results = json.load(f)
        except FileNotFoundError:
            results = {}
        if args.skip_exist and dataset_name in results:
            print(f"Skipping existing {dataset_name} dataset statistics")
        else:
            print(f"Processing {dataset_name} dataset")
            result = process_hdf5_dataset(vla_dataset)
            results[result["dataset_name"]] = result
            with open(args.save_path, "w") as f:
                json.dump(results, f, indent=4)
        print("All datasets have been processed.")
        os._exit(0)

    vla_dataset = VLADataset(seed=0, dataset_type=args.dataset_type, repeat=False)
    name_dataset_pairs = vla_dataset.name2dataset.items()
    # num_workers = args.n_workers

    for name_dataset_pair in tqdm(name_dataset_pairs):
        try:
            with open(args.save_path, "r") as f:
                results = json.load(f)
        except FileNotFoundError:
            results = {}

        if args.skip_exist and name_dataset_pair[0] in results:
            print(f"Skipping existing {name_dataset_pair[0]} dataset statistics")
            continue
        print(f"Processing {name_dataset_pair[0]} dataset")

        result = process_dataset(name_dataset_pair)

        results[result["dataset_name"]] = result

        # Save the results to the JSON file after each dataset (to allow resuming)
        with open(args.save_path, "w") as f:
            json.dump(results, f, indent=4)

    print("All datasets have been processed.")

    # with Manager() as manager:
    #     # Create shared dictionary and lock through the manager, accessible by all processes
    #     progress = manager.dict(processed=0, results={})
    #     progress_lock = manager.Lock()

    #     # Callback function to update progress
    #     def update_progress(result):
    #         with progress_lock:
    #             progress['processed'] += 1
    #             print(f"{result['dataset_name']} - {progress['processed']}/{len(name_dataset_pairs)} datasets have been processed")
    #             # Append the result to the shared dictionary
    #             progress['results'][result["dataset_name"]] = result

    #     with Pool(num_workers) as p:
    #         for name_dataset_pair in name_dataset_pairs:
    #             p.apply_async(process_dataset, args=(name_dataset_pair,), callback=update_progress)

    #         # Close the pool and wait for the work to finish
    #         p.close()
    #         p.join()

    #     # Save the results in the json file
    #     with open(args.save_path, 'w') as f:
    #         json.dump(progress['results'], f, indent=4)