import argparse
import json
import os
import random
import signal
import sys
import time
from multiprocessing import Process

import numpy as np
import tensorflow as tf
import yaml

from data.vla_dataset import VLADataset
from data.filelock import FileLock

# The producer only writes files to disk; it does not need a GPU
tf.config.set_visible_devices([], "GPU")

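# The `dataset` section of configs/base.yaml is expected to provide the
# three keys read below. The values here are illustrative, not the repo's
# defaults:
#
#   dataset:
#     buf_path: /path/to/buffer
#     buf_num_chunks: 512
#     buf_chunk_size: 512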
# Read the config
with open("configs/base.yaml", "r") as file:
    config = yaml.safe_load(file)
# Load some constants from the config
BUF_PATH = config["dataset"]["buf_path"]
BUF_NUM_CHUNKS = config["dataset"]["buf_num_chunks"]
if BUF_NUM_CHUNKS < 1:
    raise ValueError("Config `buf_num_chunks` must be at least 1.")
BUF_CHUNK_SIZE = config["dataset"]["buf_chunk_size"]
if BUF_CHUNK_SIZE < 1:
    raise ValueError("Config `buf_chunk_size` must be at least 1.")

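# On-disk buffer layout implied by the helpers below (a sketch, derived
# from how this file reads and writes the buffer):
#
#   BUF_PATH/
#     chunk_0/
#       dirty_bit                # uint8[BUF_CHUNK_SIZE]; 1 = consumed, 0 = fresh
#       json_content_{i}.json    # per-sample metadata, i in [0, BUF_CHUNK_SIZE)
#       sample_{i}.npz           # per-sample tensors
#     chunk_1/
#     ...
#     chunk_{BUF_NUM_CHUNKS - 1}/
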
def get_dirty_item(chunk_dir):
    """
    Get indexes of dirty items in a chunk.
    """
    dirty_bit = read_dirty_bit(chunk_dir)
    return np.where(dirty_bit)[0].tolist()


def get_clean_item(chunk_dir):
    """
    Get indexes of clean items in a chunk.
    """
    dirty_bit = read_dirty_bit(chunk_dir)
    return np.where(1 - dirty_bit)[0].tolist()

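# For orientation, a minimal sketch of the consumer side of this protocol
# (the real consumer lives elsewhere in the repo; `load_sample` is a
# hypothetical helper mirroring `save_sample` below):
#
#   def consume_one(chunk_dir):
#       clean_idxs = get_clean_item(chunk_dir)
#       if not clean_idxs:
#           return None
#       idx = random.choice(clean_idxs)
#       sample = load_sample(chunk_dir, idx)  # hypothetical
#       dirty_bit = read_dirty_bit(chunk_dir)
#       dirty_bit[idx] = 1                    # mark as consumed
#       save_dirty_bit(chunk_dir, dirty_bit)
#       return sample
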
def save_dirty_bit(chunk_dir, dirty_bit):
    """
    Save the dirty bit to the chunk directory.
    Retries on failure for up to 10 seconds.
    """
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                file.write(dirty_bit.tobytes())
            lock.release_lock()
            return
        except KeyboardInterrupt:
            lock.release_lock()
            raise
        except BaseException:
            lock.release_lock()
            continue
    # Only warn instead of raising, so that the producer stays alive
    print("Failed to save dirty bit.")

def read_dirty_bit(chunk_dir):
    """
    Read the dirty bit from the chunk directory.
    Retries on failure for up to 10 seconds.
    """
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_read_lock()
            with open(file_path, "rb") as file:
                dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
            lock.release_lock()
            assert len(dirty_bit) == BUF_CHUNK_SIZE
            return dirty_bit
        except KeyboardInterrupt:
            lock.release_lock()
            raise
        except BaseException:
            lock.release_lock()
            continue
    # If we failed to read the dirty bit, treat the whole chunk as dirty
    # (i.e., safe for the producer to overwrite) for robustness
    return np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)

def save_sample(step_dict, chunk_dir, chunk_item_idx):
    """
    Save a sample to the chunk directory.
    Retries on failure for up to 10 seconds.
    """
    # Keys of the tensors to store in the npz file
    tensor_keys = [
        "step_id",
        "state_chunk", "state_chunk_time_mask",
        "action_chunk", "action_chunk_time_mask",
        "state_vec_mask",
        "past_frames_0", "past_frames_0_time_mask",
        "past_frames_1", "past_frames_1_time_mask",
        "past_frames_2", "past_frames_2_time_mask",
        "past_frames_3", "past_frames_3_time_mask",
        "state_std", "state_mean", "state_norm",
    ]
    time_stmp = time.time()
    while time.time() - time_stmp < 10.0:
        try:
            locks = []
            # Save the json content
            json_content = step_dict["json_content"]
            file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_write_lock()
            with open(file_path, "w") as file:
                json.dump(json_content, file, indent=4)
            lock.release_lock()
            # Save all other tensors in an npz
            file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
            lock = FileLock(file_path)
            locks.append(lock)
            lock.acquire_write_lock()
            with open(file_path, "wb") as file:
                np.savez(file, **{key: step_dict[key].numpy() for key in tensor_keys})
            lock.release_lock()
            return
        except KeyboardInterrupt:
            for lock in locks:
                lock.release_lock()
            raise
        except BaseException:
            for lock in locks:
                lock.release_lock()
            continue
    # Only warn instead of raising, so that the producer stays alive
    print("Failed to save sample.")

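# For reference, a saved sample can be read back roughly like this
# (a sketch; a real reader should also take the corresponding read locks):
#
#   with open(os.path.join(chunk_dir, f"sample_{idx}.npz"), "rb") as f:
#       npz = np.load(f)
#       tensors = {k: npz[k] for k in npz.files}
#   with open(os.path.join(chunk_dir, f"json_content_{idx}.json"), "r") as f:
#       meta = json.load(f)
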
def run_producer(seed, num_workers, worker_id, fill_up, clean_dirty, dataset_type):
    """
    Run the producer.

    The producer first fills up the buffer with samples, then keeps
    replacing dirty samples (i.e., samples that have already been read
    by the consumer) with new ones.
    """
    vla_dataset = VLADataset(seed=seed, dataset_type=dataset_type)
    # Partition the chunks evenly across workers; e.g., with
    # BUF_NUM_CHUNKS = 32 and num_workers = 4, worker 1 owns chunks [8, 16)
    chunk_start_idx = worker_id * BUF_NUM_CHUNKS // num_workers
    chunk_end_idx = (worker_id + 1) * BUF_NUM_CHUNKS // num_workers
    if fill_up:
        print(f"Worker {worker_id}: Start filling up the buffer...")
    elif clean_dirty:
        # Only refresh the dirty bits
        print(f"Worker {worker_id}: Start refreshing the dirty bits...")
        for chunk_idx in range(chunk_start_idx, chunk_end_idx):
            chunk_dir = os.path.join(BUF_PATH, f"chunk_{chunk_idx}")
            dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
            save_dirty_bit(chunk_dir, dirty_bit)
        print(f"Worker {worker_id}: Refreshed the dirty bits.")

    fill_chunk_idx = chunk_start_idx
    fill_chunk_item_idx = 0
    dirty_chunk_idx = chunk_start_idx
    dirty_chunk_item_idxs = []
    time_stmp = time.time()
    for episode_steps in vla_dataset:
        for step in episode_steps:
            if fill_up and fill_chunk_idx < chunk_end_idx:
                # Fill up the buffer
                chunk_dir = os.path.join(BUF_PATH, f"chunk_{fill_chunk_idx}")
                if fill_chunk_item_idx == 0:
                    # Create a new chunk
                    os.makedirs(chunk_dir, exist_ok=True)
                    # Write the dirty bit of size BUF_CHUNK_SIZE
                    dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
                    save_dirty_bit(chunk_dir, dirty_bit)

                # Save the sample
                save_sample(step, chunk_dir, fill_chunk_item_idx)

                # Report progress every 10 chunks (and on the last chunk)
                local_fill_chunk_idx = fill_chunk_idx - chunk_start_idx
                local_num_chunks = chunk_end_idx - chunk_start_idx
                if (local_fill_chunk_idx % 10 == 0
                        or local_fill_chunk_idx == local_num_chunks - 1) and fill_chunk_item_idx == 0:
                    print(f"Worker {worker_id}: Filled up chunk {local_fill_chunk_idx + 1}/{local_num_chunks}")
                fill_chunk_item_idx += 1
                if fill_chunk_item_idx == BUF_CHUNK_SIZE:
                    fill_chunk_idx += 1
                    fill_chunk_item_idx = 0
                    if fill_chunk_idx == chunk_end_idx:
                        print(f"Worker {worker_id}: Buffer filled up. Start replacing dirty samples...")

            else:
                # Search for a chunk that contains dirty items to replace
                while len(dirty_chunk_item_idxs) == 0:
                    dirty_chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
                    dirty_chunk_item_idxs = get_dirty_item(dirty_chunk_dir)
                    # Print the dirty ratio at most every two seconds
                    if time.time() - time_stmp > 2.0:
                        dirty_ratio = len(dirty_chunk_item_idxs) / BUF_CHUNK_SIZE
                        print(f"Worker {worker_id}: Dirty Ratio for Chunk {dirty_chunk_idx}: {dirty_ratio:.2f}")
                        time_stmp = time.time()

                    if len(dirty_chunk_item_idxs) > 0:
                        # Lock the chunk by marking all its items dirty,
                        # so the consumer skips it while we rewrite it
                        dirty_bit = np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
                        save_dirty_bit(dirty_chunk_dir, dirty_bit)

                    # Iterate over the chunks, wrapping around this worker's range
                    dirty_chunk_idx += 1
                    if dirty_chunk_idx == chunk_end_idx:
                        dirty_chunk_idx = chunk_start_idx

                # Replace the dirty item; note that `dirty_chunk_idx` has
                # already advanced past the chunk we locked, so use
                # `dirty_chunk_dir` rather than recomputing the path from it
                dirty_item_idx = dirty_chunk_item_idxs.pop()
                save_sample(step, dirty_chunk_dir, dirty_item_idx)

                # If we have replaced all dirty items in the chunk
                if len(dirty_chunk_item_idxs) == 0:
                    # Unlock the chunk by clearing its dirty bits
                    dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
                    save_dirty_bit(dirty_chunk_dir, dirty_bit)
                    print(f"Worker {worker_id}: Replaced dirty chunk at {dirty_chunk_dir}.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_workers",
        type=int,
        default=2,
        help="Number of parallel workers. It should be less than or equal to the number of chunks.",
    )
    parser.add_argument(
        "--fill_up",
        action="store_true",
        help="Whether to fill up the buffer before replacing dirty samples.",
    )
    parser.add_argument(
        "--clean_dirty",
        action="store_true",
        help="Whether to clean the dirty bits before replacing dirty samples. "
        "This option is ignored when `fill_up` is set.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed. If not set, the seed will be randomly generated.",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        help="Whether to load the pretrain dataset or the finetune dataset.",
    )

    # Run the producer
    args = parser.parse_args()
    if args.seed is not None:
        print(f"Base seed: {args.seed}")
        random.seed(args.seed)

    # Derive one seed per worker (randint is inclusive on both ends,
    # so cap at 2**32 - 1 to stay within the 32-bit seed range)
    processes = []
    process_seeds = [random.randint(0, 2**32 - 1) for _ in range(args.n_workers)]
    print(f"Process seeds: {process_seeds}")

    def signal_handler(sig, frame):
        print("Ctrl+C received. Terminating child processes...")
        for p in processes:
            p.terminate()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    for worker_id in range(args.n_workers):
        p = Process(
            target=run_producer,
            args=(
                process_seeds[worker_id],
                args.n_workers,
                worker_id,
                args.fill_up,
                args.clean_dirty,
                args.dataset_type,
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
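# Example invocations (assuming this file lives at data/producer.py and is
# run from the repo root; the module path may differ in your checkout):
#
#   python -m data.producer --n_workers 4 --fill_up --seed 42
#   python -m data.producer --n_workers 4 --clean_dirty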