import json
import random

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import yaml

from data.episode_transform import (
    process_episode,
    flatten_episode,
    flatten_episode_agilex,
    bgr_to_rgb,
)
from data.utils import dataset_to_path
from data.preprocess_scripts import *

# Producer does not need GPU
tf.config.set_visible_devices([], "GPU")

OPENX_EMBOD_DIR = "data/datasets/openx_embod"

# Datasets that are not loaded from the Open X-Embodiment directory;
# these are handled by their own modules in data/preprocess_scripts
DATASET_NAMES_NOOPENX = [
    "aloha_mobile",
    "aloha_static",
    "roboset",
    "agilex",
    "rh20t",
    "calvin",
    "bridgev2",
]

# Read the config
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
# Load some constants from the config
EPSD_LEN_THRESH_LOW = config["dataset"]["epsd_len_thresh_low"]
EPSD_LEN_THRESH_HIGH = config["dataset"]["epsd_len_thresh_high"]
# Read the image keys of each dataset
with open("configs/dataset_img_keys.json", "r") as file:
    IMAGE_KEYS = json.load(file)
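# --- Hedged sketch (illustrative; not copied from the real config) ---
# Judging from how IMAGE_KEYS is indexed below, each entry is assumed to map a
# dataset name to its camera image keys plus a mask marking which views are
# valid, roughly like:
# {
#     "some_dataset": {
#         "image_keys": ["exterior_image", "wrist_image"],
#         "image_mask": [1, 1]
#     }
# }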


class VLADataset:
    """
    This class is used to sample episodes from the embodiment datasets.
    """

    def __init__(self, seed, dataset_type, repeat=True):
        """
        seed: the random seed
        dataset_type: 'pretrain' or 'finetune', i.e. which collection of datasets to load
        repeat: whether to repeat the datasets to infinite length
        """
        dataset_names_cfg = ("configs/pretrain_datasets.json"
                             if dataset_type == "pretrain"
                             else "configs/finetune_datasets.json")
        with open(dataset_names_cfg, "r") as file:
            DATASET_NAMES = json.load(file)
        self.dataset_names = DATASET_NAMES
        sample_weights_cfg = ("configs/pretrain_sample_weights.json"
                              if dataset_type == "pretrain"
                              else "configs/finetune_sample_weights.json")
        # Load the sample weights
        with open(sample_weights_cfg, "r") as file:
            SAMPLE_WEIGHTS = json.load(file)
        self.openx_dir = OPENX_EMBOD_DIR
        self.epsd_len_thresh_low = EPSD_LEN_THRESH_LOW
        self.epsd_len_thresh_high = EPSD_LEN_THRESH_HIGH
        self.repeat = repeat

        # Set the random seed
        tf.random.set_seed(seed)
        np.random.seed(seed)

        # Weight of each dataset in the collection when sampling
        sample_weights = []

        self.name2dataset = {}
        for dataset_name in self.dataset_names:
            if dataset_name in DATASET_NAMES_NOOPENX:
                dataset = globals()[dataset_name].load_dataset(seed)
            else:
                dataset_path = dataset_to_path(dataset_name, self.openx_dir)
                dataset = tfds.builder_from_directory(builder_dir=dataset_path)
                dataset = dataset.as_dataset(split="all", shuffle_files=True)

            # You can add filters for other datasets here
            if dataset_name == "kuka":
                dataset = dataset.filter(lambda x: x["success"])
            elif dataset_name == "bc_z":
                dataset = dataset.filter(lambda x: tf.math.greater(
                    next(iter(x["steps"]))["observation"]["episode_success"],
                    0.5,
                ))
            elif dataset_name == "ucsd_pick_and_place_dataset_converted_externally_to_rlds":
                dataset = dataset.filter(lambda x: x["episode_metadata"]["success"])
            elif dataset_name == "utokyo_xarm_bimanual_converted_externally_to_rlds":
                # Only preserve the meaningful episodes
                dataset = dataset.filter(lambda x: tf.math.equal(
                    next(iter(x["steps"]))["language_instruction"],
                    tf.constant("Unfold a wrinkled towel."),
                ))

            # Note: using cache() here causes an unexpected crash, so it is avoided
            # dataset = dataset.map().cache().shuffle().repeat()
            dataset = dataset.map(lambda x: process_episode(
                x,
                dataset_name,
                IMAGE_KEYS[dataset_name]["image_keys"],
                IMAGE_KEYS[dataset_name]["image_mask"],
            ))

            # Change BGR to RGB if needed
            if dataset_name == "fmb":
                dataset = dataset.map(bgr_to_rgb)

            if self.repeat:
                dataset = dataset.repeat()
            self.name2dataset[dataset_name] = iter(dataset)
            sample_weights.append(SAMPLE_WEIGHTS[dataset_name])

        # Normalize the sample weights
        sample_weights = np.array(sample_weights)
        self.sample_weights = sample_weights / np.sum(sample_weights)
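        # e.g. raw weights [2, 1, 1] would yield sampling probabilities
        # [0.5, 0.25, 0.25] (illustrative numbers only)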

    def __iter__(self):
        """
        Sample episodes one at a time, indefinitely; each iteration yields
        the flattened steps of a single episode.
        """
        while True:
            dataset_name = np.random.choice(self.dataset_names, p=self.sample_weights)
            episode = next(self.name2dataset[dataset_name])
            if dataset_name == "agilex":
                episode_steps = flatten_episode_agilex(episode)
            else:
                episode_steps = flatten_episode(episode)
            # Skip episodes that are too short
            if len(episode_steps) < self.epsd_len_thresh_low:
                continue
            # Randomly subsample episodes that are too long
            if len(episode_steps) > self.epsd_len_thresh_high:
                episode_steps = random.sample(episode_steps, self.epsd_len_thresh_high)

            yield episode_steps
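

# --- Hedged usage sketch (illustrative; not part of the original pipeline) ---
# A minimal way to drive the sampler outside of training, e.g. to sanity-check
# how many steps each sampled episode contains. The helper name and its
# default arguments are assumptions for illustration only.
def preview_episode_lengths(seed=0, dataset_type="pretrain", num_episodes=3):
    """Draw a few episodes and print the number of steps in each."""
    sampler = iter(VLADataset(seed, dataset_type))
    for _ in range(num_episodes):
        episode_steps = next(sampler)
        print(f"Sampled an episode with {len(episode_steps)} steps")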


if __name__ == "__main__":
    dataset = VLADataset(0, "finetune")
    for episode in dataset:
        print(episode[0])
        break