399 lines
16 KiB
Python
399 lines
16 KiB
Python
import numpy as np
|
|
import tensorflow as tf
|
|
import yaml
|
|
|
|
from data.preprocess import generate_json_state
|
|
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
|
|
|
# Read the config
|
|
with open("configs/base.yaml", "r") as file:
|
|
config = yaml.safe_load(file)
|
|
# Load some constants from the config
|
|
IMG_HISTORY_SIZE = config["common"]["img_history_size"]
|
|
if IMG_HISTORY_SIZE < 1:
|
|
raise ValueError("Config `img_history_size` must be at least 1.")
|
|
ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
|
|
if ACTION_CHUNK_SIZE < 1:
|
|
raise ValueError("Config `action_chunk_size` must be at least 1.")
|
|
|
|
|
|
@tf.function
|
|
def process_episode(epsd: dict, dataset_name: str, image_keys: list, image_mask: list) -> dict:
|
|
"""
|
|
Process an episode to extract the frames and the json content.
|
|
"""
|
|
# Frames of each camera
|
|
# Ugly code due to tf's poor compatibility
|
|
frames_0 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
|
|
frames_1 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
|
|
frames_2 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
|
|
frames_3 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
|
|
# Traverse the episode to collect...
|
|
for step in iter(epsd["steps"]):
|
|
# Parse the image
|
|
frames_0 = frames_0.write(
|
|
frames_0.size(),
|
|
tf.cond(
|
|
tf.equal(image_mask[0], 1),
|
|
lambda: step["observation"][image_keys[0]],
|
|
lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
|
|
),
|
|
)
|
|
# Very ugly code due to tf's poor compatibility
|
|
frames_1 = frames_1.write(
|
|
frames_1.size(),
|
|
tf.cond(
|
|
tf.equal(image_mask[1], 1),
|
|
lambda: step["observation"][image_keys[1]],
|
|
lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
|
|
),
|
|
)
|
|
frames_2 = frames_2.write(
|
|
frames_2.size(),
|
|
tf.cond(
|
|
tf.equal(image_mask[2], 1),
|
|
lambda: step["observation"][image_keys[2]],
|
|
lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
|
|
),
|
|
)
|
|
frames_3 = frames_3.write(
|
|
frames_3.size(),
|
|
tf.cond(
|
|
tf.equal(image_mask[3], 1),
|
|
lambda: step["observation"][image_keys[3]],
|
|
lambda: tf.zeros([0, 0, 0], dtype=tf.uint8),
|
|
),
|
|
)
|
|
|
|
# Calculate the past_frames_0 for each step
|
|
# Each step has a window of previous frames with size IMG_HISTORY_SIZE
|
|
# Use the first state to pad the frames
|
|
# past_frames_0 will have shape (num_steps, IMG_HISTORY_SIZE, height, width, channels)
|
|
frames_0 = frames_0.stack()
|
|
first_frame = tf.expand_dims(frames_0[0], axis=0)
|
|
first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
|
|
padded_frames_0 = tf.concat([first_frame, frames_0], axis=0)
|
|
indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_0)[0] + IMG_HISTORY_SIZE)
|
|
past_frames_0 = tf.map_fn(lambda i: padded_frames_0[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
|
|
frames_0_time_mask = tf.ones([tf.shape(frames_0)[0]], dtype=tf.bool)
|
|
padded_frames_0_time_mask = tf.pad(
|
|
frames_0_time_mask,
|
|
[[IMG_HISTORY_SIZE - 1, 0]],
|
|
"CONSTANT",
|
|
constant_values=False,
|
|
)
|
|
past_frames_0_time_mask = tf.map_fn(
|
|
lambda i: padded_frames_0_time_mask[i - IMG_HISTORY_SIZE:i],
|
|
indices,
|
|
dtype=tf.bool,
|
|
)
|
|
|
|
# For past_frames_1
|
|
frames_1 = frames_1.stack()
|
|
first_frame = tf.expand_dims(frames_1[0], axis=0)
|
|
first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
|
|
padded_frames_1 = tf.concat([first_frame, frames_1], axis=0)
|
|
indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_1)[0] + IMG_HISTORY_SIZE)
|
|
past_frames_1 = tf.map_fn(lambda i: padded_frames_1[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
|
|
frames_1_time_mask = tf.ones([tf.shape(frames_1)[0]], dtype=tf.bool)
|
|
padded_frames_1_time_mask = tf.pad(
|
|
frames_1_time_mask,
|
|
[[IMG_HISTORY_SIZE - 1, 0]],
|
|
"CONSTANT",
|
|
constant_values=False,
|
|
)
|
|
past_frames_1_time_mask = tf.map_fn(
|
|
lambda i: padded_frames_1_time_mask[i - IMG_HISTORY_SIZE:i],
|
|
indices,
|
|
dtype=tf.bool,
|
|
)
|
|
|
|
# For past_frames_2
|
|
frames_2 = frames_2.stack()
|
|
first_frame = tf.expand_dims(frames_2[0], axis=0)
|
|
first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
|
|
padded_frames_2 = tf.concat([first_frame, frames_2], axis=0)
|
|
indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_2)[0] + IMG_HISTORY_SIZE)
|
|
past_frames_2 = tf.map_fn(lambda i: padded_frames_2[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
|
|
frames_2_time_mask = tf.ones([tf.shape(frames_2)[0]], dtype=tf.bool)
|
|
padded_frames_2_time_mask = tf.pad(
|
|
frames_2_time_mask,
|
|
[[IMG_HISTORY_SIZE - 1, 0]],
|
|
"CONSTANT",
|
|
constant_values=False,
|
|
)
|
|
past_frames_2_time_mask = tf.map_fn(
|
|
lambda i: padded_frames_2_time_mask[i - IMG_HISTORY_SIZE:i],
|
|
indices,
|
|
dtype=tf.bool,
|
|
)
|
|
|
|
# For past_frames_3
|
|
frames_3 = frames_3.stack()
|
|
first_frame = tf.expand_dims(frames_3[0], axis=0)
|
|
first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE - 1, axis=0)
|
|
padded_frames_3 = tf.concat([first_frame, frames_3], axis=0)
|
|
indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_3)[0] + IMG_HISTORY_SIZE)
|
|
past_frames_3 = tf.map_fn(lambda i: padded_frames_3[i - IMG_HISTORY_SIZE:i], indices, dtype=tf.uint8)
|
|
frames_3_time_mask = tf.ones([tf.shape(frames_3)[0]], dtype=tf.bool)
|
|
padded_frames_3_time_mask = tf.pad(
|
|
frames_3_time_mask,
|
|
[[IMG_HISTORY_SIZE - 1, 0]],
|
|
"CONSTANT",
|
|
constant_values=False,
|
|
)
|
|
past_frames_3_time_mask = tf.map_fn(
|
|
lambda i: padded_frames_3_time_mask[i - IMG_HISTORY_SIZE:i],
|
|
indices,
|
|
dtype=tf.bool,
|
|
)
|
|
|
|
# Creat the ids for each step
|
|
step_id = tf.range(0, tf.shape(frames_0)[0])
|
|
|
|
return {
|
|
"dataset_name": dataset_name,
|
|
"episode_dict": epsd,
|
|
"step_id": step_id,
|
|
"past_frames_0": past_frames_0,
|
|
"past_frames_0_time_mask": past_frames_0_time_mask,
|
|
"past_frames_1": past_frames_1,
|
|
"past_frames_1_time_mask": past_frames_1_time_mask,
|
|
"past_frames_2": past_frames_2,
|
|
"past_frames_2_time_mask": past_frames_2_time_mask,
|
|
"past_frames_3": past_frames_3,
|
|
"past_frames_3_time_mask": past_frames_3_time_mask,
|
|
}
|
|
|
|
|
|
@tf.function
|
|
def bgr_to_rgb(epsd: dict):
|
|
"""
|
|
Convert BGR images to RGB images.
|
|
"""
|
|
past_frames_0 = epsd["past_frames_0"]
|
|
past_frames_0 = tf.cond(
|
|
tf.equal(tf.shape(past_frames_0)[-1], 3),
|
|
lambda: tf.stack(
|
|
[past_frames_0[..., 2], past_frames_0[..., 1], past_frames_0[..., 0]],
|
|
axis=-1,
|
|
),
|
|
lambda: past_frames_0,
|
|
)
|
|
|
|
past_frames_1 = epsd["past_frames_1"]
|
|
past_frames_1 = tf.cond(
|
|
tf.equal(tf.shape(past_frames_1)[-1], 3),
|
|
lambda: tf.stack(
|
|
[past_frames_1[..., 2], past_frames_1[..., 1], past_frames_1[..., 0]],
|
|
axis=-1,
|
|
),
|
|
lambda: past_frames_1,
|
|
)
|
|
|
|
past_frames_2 = epsd["past_frames_2"]
|
|
past_frames_2 = tf.cond(
|
|
tf.equal(tf.shape(past_frames_2)[-1], 3),
|
|
lambda: tf.stack(
|
|
[past_frames_2[..., 2], past_frames_2[..., 1], past_frames_2[..., 0]],
|
|
axis=-1,
|
|
),
|
|
lambda: past_frames_2,
|
|
)
|
|
|
|
past_frames_3 = epsd["past_frames_3"]
|
|
past_frames_3 = tf.cond(
|
|
tf.equal(tf.shape(past_frames_3)[-1], 3),
|
|
lambda: tf.stack(
|
|
[past_frames_3[..., 2], past_frames_3[..., 1], past_frames_3[..., 0]],
|
|
axis=-1,
|
|
),
|
|
lambda: past_frames_3,
|
|
)
|
|
|
|
return {
|
|
"dataset_name": epsd["dataset_name"],
|
|
"episode_dict": epsd["episode_dict"],
|
|
"step_id": epsd["step_id"],
|
|
"past_frames_0": past_frames_0,
|
|
"past_frames_0_time_mask": epsd["past_frames_0_time_mask"],
|
|
"past_frames_1": past_frames_1,
|
|
"past_frames_1_time_mask": epsd["past_frames_1_time_mask"],
|
|
"past_frames_2": past_frames_2,
|
|
"past_frames_2_time_mask": epsd["past_frames_2_time_mask"],
|
|
"past_frames_3": past_frames_3,
|
|
"past_frames_3_time_mask": epsd["past_frames_3_time_mask"],
|
|
}
|
|
|
|
|
|
def flatten_episode(episode: dict) -> tf.data.Dataset:
|
|
"""
|
|
Flatten the episode to a list of steps.
|
|
"""
|
|
episode_dict = episode["episode_dict"]
|
|
dataset_name = episode["dataset_name"]
|
|
|
|
json_content, states, masks = generate_json_state(episode_dict, dataset_name)
|
|
|
|
# Calculate the past_states for each step
|
|
# Each step has a window of previous states with size ACTION_CHUNK_SIZE
|
|
# Use the first state to pad the states
|
|
# past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
|
|
first_state = tf.expand_dims(states[0], axis=0)
|
|
first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE - 1, axis=0)
|
|
padded_states = tf.concat([first_state, states], axis=0)
|
|
indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
|
|
past_states = tf.map_fn(lambda i: padded_states[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32)
|
|
states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
|
|
padded_states_time_mask = tf.pad(
|
|
states_time_mask,
|
|
[[ACTION_CHUNK_SIZE - 1, 0]],
|
|
"CONSTANT",
|
|
constant_values=False,
|
|
)
|
|
past_states_time_mask = tf.map_fn(
|
|
lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
|
|
indices,
|
|
dtype=tf.bool,
|
|
)
|
|
|
|
# Calculate the future_states for each step
|
|
# Each step has a window of future states with size ACTION_CHUNK_SIZE
|
|
# Use the last state to pad the states
|
|
# future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
|
|
last_state = tf.expand_dims(states[-1], axis=0)
|
|
last_state = tf.repeat(last_state, ACTION_CHUNK_SIZE, axis=0)
|
|
padded_states = tf.concat([states, last_state], axis=0)
|
|
indices = tf.range(1, tf.shape(states)[0] + 1)
|
|
future_states = tf.map_fn(lambda i: padded_states[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32)
|
|
states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
|
|
padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
|
|
future_states_time_mask = tf.map_fn(
|
|
lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
|
|
indices,
|
|
dtype=tf.bool,
|
|
)
|
|
|
|
# Calculate the mean and std for state
|
|
state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
|
|
state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
|
|
state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
|
|
state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)
|
|
|
|
state_norm = tf.math.reduce_mean(tf.math.square(states), axis=0, keepdims=True)
|
|
state_norm = tf.math.sqrt(state_norm)
|
|
state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)
|
|
|
|
# Create a list of steps
|
|
step_data = []
|
|
for i in range(tf.shape(states)[0]):
|
|
step_data.append({
|
|
"step_id": episode["step_id"][i],
|
|
"json_content": json_content,
|
|
"state_chunk": past_states[i],
|
|
"state_chunk_time_mask": past_states_time_mask[i],
|
|
"action_chunk": future_states[i],
|
|
"action_chunk_time_mask": future_states_time_mask[i],
|
|
"state_vec_mask": masks[i],
|
|
"past_frames_0": episode["past_frames_0"][i],
|
|
"past_frames_0_time_mask": episode["past_frames_0_time_mask"][i],
|
|
"past_frames_1": episode["past_frames_1"][i],
|
|
"past_frames_1_time_mask": episode["past_frames_1_time_mask"][i],
|
|
"past_frames_2": episode["past_frames_2"][i],
|
|
"past_frames_2_time_mask": episode["past_frames_2_time_mask"][i],
|
|
"past_frames_3": episode["past_frames_3"][i],
|
|
"past_frames_3_time_mask": episode["past_frames_3_time_mask"][i],
|
|
"state_std": state_std[i],
|
|
"state_mean": state_mean[i],
|
|
"state_norm": state_norm[i],
|
|
})
|
|
|
|
return step_data
|
|
|
|
|
|
def flatten_episode_agilex(episode: dict) -> tf.data.Dataset:
|
|
"""
|
|
Flatten the episode to a list of steps.
|
|
"""
|
|
episode_dict = episode["episode_dict"]
|
|
dataset_name = episode["dataset_name"]
|
|
|
|
json_content, states, masks, acts = generate_json_state(episode_dict, dataset_name)
|
|
|
|
# Calculate the past_states for each step
|
|
# Each step has a window of previous states with size ACTION_CHUNK_SIZE
|
|
# Use the first state to pad the states
|
|
# past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
|
|
first_state = tf.expand_dims(states[0], axis=0)
|
|
first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE - 1, axis=0)
|
|
padded_states = tf.concat([first_state, states], axis=0)
|
|
indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
|
|
past_states = tf.map_fn(lambda i: padded_states[i - ACTION_CHUNK_SIZE:i], indices, dtype=tf.float32)
|
|
states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
|
|
padded_states_time_mask = tf.pad(
|
|
states_time_mask,
|
|
[[ACTION_CHUNK_SIZE - 1, 0]],
|
|
"CONSTANT",
|
|
constant_values=False,
|
|
)
|
|
past_states_time_mask = tf.map_fn(
|
|
lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
|
|
indices,
|
|
dtype=tf.bool,
|
|
)
|
|
|
|
# NOTE bg the future states shall be actions
|
|
# Calculate the future_states for each step
|
|
# Each step has a window of future states with size ACTION_CHUNK_SIZE
|
|
# Use the last action to pad the states
|
|
# future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
|
|
last_act = tf.expand_dims(acts[-1], axis=0)
|
|
last_act = tf.repeat(last_act, ACTION_CHUNK_SIZE, axis=0)
|
|
padded_states = tf.concat([acts, last_act], axis=0)
|
|
# indices = tf.range(1, tf.shape(states)[0] + 1)
|
|
indices = tf.range(0, tf.shape(acts)[0]) # NOTE time 0 action = time 1 state
|
|
future_states = tf.map_fn(lambda i: padded_states[i:i + ACTION_CHUNK_SIZE], indices, dtype=tf.float32)
|
|
states_time_mask = tf.ones([tf.shape(acts)[0]], dtype=tf.bool)
|
|
padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
|
|
future_states_time_mask = tf.map_fn(
|
|
lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
|
|
indices,
|
|
dtype=tf.bool,
|
|
)
|
|
|
|
# Calculate the std and mean for state
|
|
state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
|
|
state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
|
|
state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
|
|
state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)
|
|
|
|
state_norm = tf.math.reduce_mean(tf.math.square(acts), axis=0, keepdims=True)
|
|
state_norm = tf.math.sqrt(state_norm)
|
|
state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)
|
|
|
|
# Create a list of steps
|
|
step_data = []
|
|
for i in range(tf.shape(states)[0]):
|
|
step_data.append({
|
|
"step_id": episode["step_id"][i],
|
|
"json_content": json_content,
|
|
"state_chunk": past_states[i],
|
|
"state_chunk_time_mask": past_states_time_mask[i],
|
|
"action_chunk": future_states[i],
|
|
"action_chunk_time_mask": future_states_time_mask[i],
|
|
"state_vec_mask": masks[i],
|
|
"past_frames_0": episode["past_frames_0"][i],
|
|
"past_frames_0_time_mask": episode["past_frames_0_time_mask"][i],
|
|
"past_frames_1": episode["past_frames_1"][i],
|
|
"past_frames_1_time_mask": episode["past_frames_1_time_mask"][i],
|
|
"past_frames_2": episode["past_frames_2"][i],
|
|
"past_frames_2_time_mask": episode["past_frames_2_time_mask"][i],
|
|
"past_frames_3": episode["past_frames_3"][i],
|
|
"past_frames_3_time_mask": episode["past_frames_3_time_mask"][i],
|
|
"state_std": state_std[i],
|
|
"state_mean": state_mean[i],
|
|
"state_norm": state_norm[i],
|
|
})
|
|
|
|
return step_data
|