# Utilities for turning RLDS/TFDS robot episodes into episode metadata plus
# per-step unified state/action vectors during dataset preprocessing.
import json

import tensorflow as tf
import yaml

from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN
from data.preprocess_scripts import *
from data.utils import capitalize_and_period
# The dataset without state
|
|
DATASET_NAMES_NO_STATE = [
|
|
"nyu_door_opening_surprising_effectiveness",
|
|
"usc_cloth_sim_converted_externally_to_rlds",
|
|
"cmu_franka_exploration_dataset_converted_externally_to_rlds",
|
|
"imperialcollege_sawyer_wrist_cam",
|
|
]
|
|
|
|
# Read the image keys of each dataset
|
|
with open("configs/dataset_img_keys.json", "r") as file:
|
|
IMAGE_KEYS = json.load(file)
|
|
# Read the config
|
|
with open("configs/base.yaml", "r") as file:
|
|
config = yaml.safe_load(file)
|
|
|
|
|
|
def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str, base_concat=None, base_format=None) -> "tuple[tf.Tensor, tf.Tensor]":
    """
    Assemble the unified state/action vector (and its validity mask) from the
    arm components and, optionally, the base components.

    Args:
        arm_concat: 1-D tensor of arm values, one per name in `arm_format`.
        arm_format: Comma-separated names, each a key of STATE_VEC_IDX_MAPPING.
        base_concat: Optional 1-D tensor of base values.
        base_format: Comma-separated names for `base_concat`; must be provided
            whenever `base_concat` is.

    Returns:
        (state_vec, mask_vec): two float32 tensors of length STATE_VEC_LEN.
        `mask_vec` is 1.0 at every index that was filled and 0.0 elsewhere.
    """
    state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
    mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)

    # Assemble the arm state
    state_vec, mask_vec = _scatter_component(state_vec, mask_vec, arm_concat, arm_format)

    # Assemble the base state if it exists
    if base_concat is not None:
        state_vec, mask_vec = _scatter_component(state_vec, mask_vec, base_concat, base_format)
    return state_vec, mask_vec


def _scatter_component(state_vec: tf.Tensor, mask_vec: tf.Tensor, values, fmt: str) -> "tuple[tf.Tensor, tf.Tensor]":
    """Scatter `values` into `state_vec` at the indices named by `fmt`, and
    mark those indices as valid (1.0) in `mask_vec`."""
    names = fmt.split(",")
    # Compute the target indices once; tensor_scatter_nd_update tolerates
    # duplicate indices (last write wins).
    indices = [[STATE_VEC_IDX_MAPPING[name]] for name in names]
    state_vec = tf.tensor_scatter_nd_update(state_vec, indices, tf.cast(values, tf.float32))
    mask_vec = tf.tensor_scatter_nd_update(mask_vec, indices, tf.ones(len(names), dtype=tf.float32))
    return state_vec, mask_vec
@tf.autograph.experimental.do_not_convert
def _generate_json_state_agilex(episode: dict, dataset_name: str):
    """
    Generate the episode metadata dict plus per-step state, mask, and action
    tensors for an "agilex" episode.

    Returns:
        (episode_metadata, episode_states, episode_masks, episode_acts) —
        the last three are tensors stacked along the step axis.
    """
    # Load some constants from the config and validate them up front.
    IMG_HISTORY_SIZE = config["common"]["img_history_size"]
    if IMG_HISTORY_SIZE < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    ACTION_CHUNK_SIZE = config["common"]["action_chunk_size"]
    if ACTION_CHUNK_SIZE < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    # Initialize the episode_metadata
    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    base_act = None
    # Base action of the previous step; stands in for the current base state.
    last_base_act = None
    episode_states = []
    episode_acts = []
    episode_masks = []
    # Whether steps carry a base action; decided once, on the first step.
    has_base = None
    for step_id, step in enumerate(iter(episode["steps"])):
        # Parse the action
        action = step["action"]
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]

        # Parse the state
        state = step["observation"]

        arm_format = state["format"].numpy().decode("utf-8")
        base_format = None
        if has_base:
            act_format = action["format"].numpy().decode("utf-8")
            # The base portion is the suffix of the action format from "base" on.
            base_formate_idx = act_format.find("base")
            base_format = act_format[base_formate_idx:]

        arm_state = state["arm_concat"]
        base_state = None
        if has_base:
            if last_base_act is None:
                # First step: no previous base action, use zeros of the same shape.
                base_state = base_act * 0
            else:
                base_state = last_base_act
            last_base_act = base_act

        # Assemble the state vector
        state_vec, mask_vec = assemble_state_vec(arm_state, arm_format, base_state, base_format)

        # NOTE(review): the action vector reuses `base_state` (the previous
        # step's base action) rather than the current `base_act`, and this call
        # also overwrites the `mask_vec` produced by the state assembly above —
        # confirm both are intentional.
        act_vec, mask_vec = assemble_state_vec(action["arm_concat"], arm_format, base_state, base_format)

        episode_states.append(state_vec)
        episode_masks.append(mask_vec)
        episode_acts.append(act_vec)

        # Parse the task instruction
        instr = step["observation"]["natural_language_instruction"]
        instr = instr.numpy().decode("utf-8")
        instr = capitalize_and_period(instr)

        # Keep the first step's instruction for the whole episode.
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

    # NOTE(review): `step_id` is the index of the LAST step, so "#steps" ends
    # up as len(steps) - 1 — confirm this off-by-one is the intended
    # convention (it is shared by the sibling generators).
    episode_metadata["#steps"] = step_id

    episode_states = tf.stack(episode_states)
    episode_masks = tf.stack(episode_masks)
    episode_acts = tf.stack(episode_acts)

    return episode_metadata, episode_states, episode_masks, episode_acts
@tf.autograph.experimental.do_not_convert
def _generate_json_state(episode: dict, dataset_name: str):
    """
    Build the episode metadata dict together with the stacked per-step state
    and mask tensors for a standard (state-carrying) dataset episode.
    """
    # Validate the relevant config constants before touching the episode.
    img_history_size = config["common"]["img_history_size"]
    if img_history_size < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    action_chunk_size = config["common"]["action_chunk_size"]
    if action_chunk_size < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    states, masks = [], []
    cur_base_act = None
    prev_base_act = None
    has_base = None

    for step_id, step in enumerate(iter(episode["steps"])):
        action = step["action"]
        observation = step["observation"]

        # Decide once, on the first step, whether a mobile base is present.
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            cur_base_act = action["base_concat"]

        arm_format = observation["format"].numpy().decode("utf-8")

        base_format = None
        if has_base:
            act_format = action["format"].numpy().decode("utf-8")
            # The base portion is the suffix of the action format from "base" on.
            base_format = act_format[act_format.find("base"):]

        base_state = None
        if has_base:
            # The base "state" is the previous step's base action
            # (zeros on the very first step).
            base_state = cur_base_act * 0 if prev_base_act is None else prev_base_act
            prev_base_act = cur_base_act

        state_vec, mask_vec = assemble_state_vec(
            observation["arm_concat"], arm_format, base_state, base_format
        )
        states.append(state_vec)
        masks.append(mask_vec)

        # Normalize the instruction; keep the first step's for the episode.
        instr = capitalize_and_period(
            observation["natural_language_instruction"].numpy().decode("utf-8")
        )
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

    episode_metadata["#steps"] = step_id

    return episode_metadata, tf.stack(states), tf.stack(masks)
@tf.autograph.experimental.do_not_convert
def _generate_json_state_nostate_ds(episode: dict, dataset_name: str):
    """
    Build the episode metadata dict together with the stacked per-step state
    and mask tensors for a dataset that carries no proprioceptive state: the
    previous step's action stands in for the state (zeros on the first step).
    """
    # Validate the relevant config constants before touching the episode.
    img_history_size = config["common"]["img_history_size"]
    if img_history_size < 1:
        raise ValueError("Config `img_history_size` must be at least 1.")
    action_chunk_size = config["common"]["action_chunk_size"]
    if action_chunk_size < 1:
        raise ValueError("Config `action_chunk_size` must be at least 1.")

    episode_metadata = {"dataset_name": dataset_name, "#steps": 0, "instruction": None}

    states, masks = [], []
    prev_arm_act = None
    prev_base_act = None
    has_base = None

    for step_id, step in enumerate(iter(episode["steps"])):
        action = step["action"]

        # Decide once, on the first step, whether a mobile base is present.
        if has_base is None:
            has_base = "base_concat" in action
        if has_base:
            base_act = action["base_concat"]
            if prev_base_act is None:
                prev_base_act = base_act * 0  # zero "previous action" on step 0

        arm_act = action["arm_concat"]
        if prev_arm_act is None:
            prev_arm_act = arm_act * 0  # zero "previous action" on step 0

        # The action format doubles as the state format here.
        act_format = action["format"].numpy().decode("utf-8")

        # Substitute the previous step's action for the missing state.
        if has_base:
            pseudo_state = tf.concat([prev_arm_act, prev_base_act], axis=0)
        else:
            pseudo_state = prev_arm_act
        state_vec, mask_vec = assemble_state_vec(pseudo_state, act_format)

        states.append(state_vec)
        masks.append(mask_vec)

        # Normalize the instruction; keep the first step's for the episode.
        instr = capitalize_and_period(
            step["observation"]["natural_language_instruction"].numpy().decode("utf-8")
        )
        if episode_metadata["instruction"] is None:
            episode_metadata["instruction"] = instr

        # Remember this step's action for the next iteration.
        prev_arm_act = arm_act
        if has_base:
            prev_base_act = base_act

    episode_metadata["#steps"] = step_id

    return episode_metadata, tf.stack(states), tf.stack(masks)
@tf.autograph.experimental.do_not_convert
def generate_json_state(episode: dict, dataset_name: str):
    """
    Generate the json dict and state tensors for an episode, dispatching to
    the generator that matches `dataset_name`.
    """
    # Dataset names coming out of tf.data arrive as tensors; normalize first.
    if isinstance(dataset_name, tf.Tensor):
        dataset_name = dataset_name.numpy().decode("utf-8")

    # Dataset-specific preprocessing module, brought into scope by the star
    # import from data.preprocess_scripts.
    preprocess_module = globals()[dataset_name]
    episode["steps"] = episode["steps"].map(preprocess_module.process_step)

    if dataset_name == "agilex":
        generator = _generate_json_state_agilex
    elif dataset_name in DATASET_NAMES_NO_STATE:
        generator = _generate_json_state_nostate_ds
    else:
        generator = _generate_json_state
    return generator(episode, dataset_name)