update
This commit is contained in:
parent
78702c7f47
commit
7a85849f59
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
300
client.py
Normal file
300
client.py
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
RDT 推理服务器测试客户端
|
||||||
|
使用模拟数据测试 get_actions 接口
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
from cloud_helper import Client
|
||||||
|
|
||||||
|
# Logging setup: timestamped INFO-level messages for all test output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_mock_observation(
    state_dim=6,
    img_history_size=2,
    img_height=480,
    img_width=640,
    num_cameras=3
):
    """Build a synthetic observation dict for exercising the server.

    Args:
        state_dim: number of joint values in the state vector.
        img_history_size: frames of image history per camera.
        img_height: frame height in pixels.
        img_width: frame width in pixels.
        num_cameras: how many of the known cameras to populate.

    Returns:
        dict with a float32 "state" vector in [-180, 180] and, per camera,
        an "images.<name>" uint8 array of shape (img_history_size, H, W, 3).
    """
    # 1. Random joint state, uniform in [-180, 180] degrees.
    obs = {
        "state": np.random.uniform(-180, 180, size=(state_dim,)).astype(np.float32)
    }

    # 2. Synthetic color-gradient frames per camera.
    #    msgpack_numpy serializes the numpy arrays transparently on send.
    camera_names = ("cam_high", "cam_left_wrist", "cam_right_wrist")

    for cam_idx, cam_name in enumerate(camera_names[:num_cameras]):
        history = []
        for step in range(img_history_size):
            frame = np.zeros((img_height, img_width, 3), dtype=np.uint8)

            # Distinct gradient per (timestep, camera) pair so frames differ.
            shift = (step * 50 + cam_idx * 100) % 255
            frame[:, :, 0] = np.linspace(shift, 255, img_width, dtype=np.uint8)  # R
            frame[:, :, 1] = np.linspace(0, 255 - shift, img_height, dtype=np.uint8)[:, None]  # G
            frame[:, :, 2] = 128  # B

            history.append(frame)

        # Stack into (IMG_HISTORY_SIZE, H, W, 3).
        obs[f"images.{cam_name}"] = np.stack(history, axis=0)

    return obs
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_batch(
    observation,
    instruction="pick up the object and place it in the box",
    use_instruction_index=False
):
    """Wrap an observation and instruction into a request payload.

    Args:
        observation: observation dict (see create_mock_observation).
        instruction: free-form instruction string; ignored when
            use_instruction_index is True.
        use_instruction_index: send the fixed instruction index 0
            instead of the string.

    Returns:
        dict with "observation" and "instruction" keys, ready to send
        to the get_actions endpoint.
    """
    instr = 0 if use_instruction_index else instruction
    return {"observation": observation, "instruction": instr}
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_request(client, args):
    """Run one get_actions round-trip with mock data; log shapes/latency.

    Returns True on success, False when the request raised. `client` is a
    cloud_helper.Client; `args` is the argparse namespace from main().
    """
    logger.info("=" * 60)
    logger.info("开始单次请求测试")
    logger.info("=" * 60)

    # Build a mock observation matching the server's expected shapes.
    observation = create_mock_observation(
        state_dim=args.state_dim,
        img_history_size=args.img_history_size,
        img_height=args.img_height,
        img_width=args.img_width,
        num_cameras=args.num_cameras
    )

    logger.info(f"模拟观测数据:")
    logger.info(f" - state shape: {observation['state'].shape}")
    for key in observation.keys():
        if key.startswith("images."):
            logger.info(f" - {key} shape: {observation[key].shape}")

    # Wrap observation + instruction into the request payload.
    batch = create_test_batch(
        observation,
        instruction=args.instruction,
        use_instruction_index=args.use_index
    )

    # Send the request and time the round trip.
    logger.info(f"发送指令: {batch['instruction']}")
    start_time = time.time()

    try:
        action = client.call_endpoint("get_actions", batch)
        elapsed_time = time.time() - start_time

        logger.info(f"✓ 请求成功! 耗时: {elapsed_time*1000:.2f} ms")
        logger.info(f" - action shape: {action.shape}")
        logger.info(f" - action dtype: {action.dtype}")
        logger.info(f" - action range: [{action.min():.3f}, {action.max():.3f}]")
        logger.info(f" - action preview (前3个时间步的前3个维度):")
        # Preview at most the first 3 timesteps x 3 dims of the action chunk.
        preview_steps = min(3, action.shape[0])
        preview_dims = min(3, action.shape[1])
        for t in range(preview_steps):
            logger.info(f" t={t}: {action[t, :preview_dims]}")

        return True

    except Exception as e:
        logger.error(f"✗ 请求失败: {e}")
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_requests(client, args):
    """Fire args.num_requests identical requests and report latency stats.

    The observation/batch are built once and reused, so this measures
    server + transport latency rather than client-side data generation.
    """
    logger.info("=" * 60)
    logger.info(f"开始连续请求测试 (共 {args.num_requests} 次)")
    logger.info("=" * 60)

    # Build the observation once up front and reuse it for every request.
    observation = create_mock_observation(
        state_dim=args.state_dim,
        img_history_size=args.img_history_size,
        img_height=args.img_height,
        img_width=args.img_width,
        num_cameras=args.num_cameras
    )

    batch = create_test_batch(
        observation,
        instruction=args.instruction,
        use_instruction_index=args.use_index
    )

    success_count = 0
    total_time = 0
    latencies = []

    for i in range(args.num_requests):
        try:
            start_time = time.time()
            action = client.call_endpoint("get_actions", batch)
            elapsed_time = time.time() - start_time

            success_count += 1
            total_time += elapsed_time
            latencies.append(elapsed_time)

            # Progress log every 10 requests.
            if (i + 1) % 10 == 0:
                logger.info(f"已完成 {i + 1}/{args.num_requests} 次请求")

        except Exception as e:
            logger.error(f"第 {i+1} 次请求失败: {e}")

    # Summary statistics over the successful requests.
    logger.info("=" * 60)
    logger.info("性能统计:")
    logger.info(f" - 总请求数: {args.num_requests}")
    logger.info(f" - 成功数: {success_count}")
    logger.info(f" - 失败数: {args.num_requests - success_count}")
    logger.info(f" - 成功率: {success_count/args.num_requests*100:.1f}%")

    if latencies:
        latencies = np.array(latencies)
        logger.info(f" - 平均延迟: {np.mean(latencies)*1000:.2f} ms")
        logger.info(f" - 中位数延迟: {np.median(latencies)*1000:.2f} ms")
        logger.info(f" - 最小延迟: {np.min(latencies)*1000:.2f} ms")
        logger.info(f" - 最大延迟: {np.max(latencies)*1000:.2f} ms")
        logger.info(f" - 吞吐量: {success_count/total_time:.2f} requests/s")
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_instructions(client, args):
    """Send one get_actions request per instruction in a fixed sample list.

    Reuses a single mock observation so only the instruction varies
    between requests; each result is logged, failures don't abort the loop.
    """
    logger.info("=" * 60)
    logger.info("测试不同指令")
    logger.info("=" * 60)

    instructions = [
        "pick up the red cube",
        "place the object on the table",
        "move to the left",
        "grasp the bottle",
        "open the drawer"
    ]

    # One shared observation; only the instruction changes per request.
    observation = create_mock_observation(
        state_dim=args.state_dim,
        img_history_size=args.img_history_size,
        img_height=args.img_height,
        img_width=args.img_width,
        num_cameras=args.num_cameras
    )

    for i, instruction in enumerate(instructions):
        logger.info(f"\n测试指令 {i+1}/{len(instructions)}: '{instruction}'")
        batch = create_test_batch(observation, instruction=instruction)

        try:
            start_time = time.time()
            action = client.call_endpoint("get_actions", batch)
            elapsed_time = time.time() - start_time

            logger.info(f" ✓ 成功 | 耗时: {elapsed_time*1000:.2f} ms | action shape: {action.shape}")

        except Exception as e:
            logger.error(f" ✗ 失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse CLI args, connect to the inference server, run the chosen test."""
    parser = argparse.ArgumentParser(description="RDT 推理服务器测试客户端")

    # Connection parameters
    parser.add_argument("--host", type=str, default="localhost", help="服务器地址")
    parser.add_argument("--port", type=int, default=8001, help="服务器端口")

    # Test mode selection
    parser.add_argument("--mode", type=str, default="single",
                        choices=["single", "multiple", "instructions"],
                        help="测试模式: single(单次), multiple(多次), instructions(不同指令)")
    parser.add_argument("--num-requests", type=int, default=50,
                        help="多次测试的请求数量")

    # Data parameters — must match the server's configured input shapes.
    parser.add_argument("--state-dim", type=int, default=6, help="状态向量维度")
    parser.add_argument("--img-history-size", type=int, default=2, help="图像历史长度")
    parser.add_argument("--img-height", type=int, default=480, help="图像高度")
    parser.add_argument("--img-width", type=int, default=640, help="图像宽度")
    parser.add_argument("--num-cameras", type=int, default=3, help="相机数量 (与服务器配置一致)")

    # Instruction parameters
    parser.add_argument("--instruction", type=str,
                        default="pick up the object and place it in the box",
                        help="测试指令")
    parser.add_argument("--use-index", action="store_true",
                        help="使用指令索引而非字符串")

    args = parser.parse_args()

    # Connect to the server; bail out early on failure.
    logger.info(f"正在连接到 {args.host}:{args.port} ...")
    try:
        client = Client(host=args.host, port=args.port)
        logger.info("✓ 连接成功!")
    except Exception as e:
        logger.error(f"✗ 连接失败: {e}")
        return

    # Dispatch on the selected test mode.
    try:
        if args.mode == "single":
            test_single_request(client, args)
        elif args.mode == "multiple":
            test_multiple_requests(client, args)
        elif args.mode == "instructions":
            test_different_instructions(client, args)

    except KeyboardInterrupt:
        logger.info("\n测试被用户中断")
    except Exception as e:
        logger.error(f"测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()
||||||
|
|
||||||
162
cloud_helper.py
Normal file
162
cloud_helper.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import zmq
|
||||||
|
import msgpack
|
||||||
|
import msgpack_numpy as m
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from typing import Any, Callable
|
||||||
|
import zstandard as zstd
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)


# Shared zstd (de)compressors used by _pack/_unpack for every message.
# Level 12 favors compression ratio over speed.
compresser = zstd.ZstdCompressor(level=12)
decompresser = zstd.ZstdDecompressor()
|
||||||
|
|
||||||
|
|
||||||
|
def _pack(data: Any) -> bytes:
    """Serialize `data` with numpy-aware msgpack, then zstd-compress it."""
    packed = msgpack.packb(data, default=m.encode, use_bin_type=True)
    return compresser.compress(packed)
|
||||||
|
|
||||||
|
|
||||||
|
def _unpack(data: bytes) -> Any:
    """Zstd-decompress `data`, then deserialize with numpy-aware msgpack."""
    raw = decompresser.decompress(data)
    return msgpack.unpackb(raw, object_hook=m.decode, raw=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Server:
    """Minimal ZeroMQ REP server dispatching named endpoints.

    Wire protocol (zstd-compressed msgpack, numpy-aware):
        Request:  {"command": str, "data": Any}
        Response: {"status": "ok" | "error", "data": result or error message}
    """

    def __init__(self, host: str = "*", port: int = 5555):
        self.host = host
        self.port = port

        self.context = zmq.Context()
        # REP socket: strict recv/send alternation, one request at a time.
        self.socket = self.context.socket(zmq.REP)
        self.socket.bind(f"tcp://{self.host}:{self.port}")
        logger.info(f"Server started at tcp://{self.host}:{self.port}")

        # command name -> handler. Handlers take the request's data payload,
        # or no argument when the request carried data=None (see handle_once).
        self.endpoints: dict[str, Callable[[Any], Any]] = {}

    def register_endpoint(self, command: str, func: Callable[[Any], Any]):
        """Map a command name to a handler; re-registering overwrites."""
        self.endpoints[command] = func
        logger.info(f"Registered endpoint: {command} -> {func}")

    def return_error(self, message: str) -> None:
        # Error reply: "data" carries the human-readable message.
        self.socket.send(_pack({"status": "error", "data": message}))

    def return_ok(self, data: Any) -> None:
        self.socket.send(_pack({"status": "ok", "data": data}))

    def handle_once(self) -> None:
        """Receive one request, dispatch it, and send exactly one reply."""
        message = self.socket.recv()
        message = _unpack(message)

        cmd = message.get("command")
        data = message.get("data")

        logger.info("Received Command: %s", cmd)

        handler = self.endpoints.get(cmd)

        if handler is not None:
            try:
                # NOTE(review): data=None is indistinguishable from "no data";
                # such requests invoke the handler with zero arguments.
                if data is None:
                    response = handler()
                else:
                    response = handler(data)
                self.return_ok(response)
            except Exception as e:
                # Report handler failures to the client instead of crashing
                # the serve loop; the REP socket still gets its reply.
                logger.error(f"Error handling command {cmd}: {e}")
                self.return_error(str(e))
        else:
            logger.warning(f"Unknown command: {cmd}")
            self.return_error(f"Unknown command: {cmd}")

    def loop_forever(self):
        """Serve requests until KeyboardInterrupt; always release ZMQ resources."""
        try:
            while True:
                self.handle_once()

        except KeyboardInterrupt:
            logger.info("Server shutting down...")

        finally:
            self.socket.close()
            self.context.term()
|
||||||
|
|
||||||
|
|
||||||
|
class Client:
    """ZeroMQ REQ client for Server; call_endpoint is a blocking RPC."""

    def __init__(self, host: str = "localhost", port: int = 5555):
        self.context = zmq.Context()
        # REQ socket: each send must be followed by exactly one recv.
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect(f"tcp://{host}:{port}")
        logger.info(f"Client connected to tcp://{host}:{port}")

    def call_endpoint(self, command: str, data=None):
        """Send a command and block for the server's reply.

        Returns the result payload when the server answers status "ok";
        raises Exception carrying the server-reported message otherwise.
        """
        self.socket.send(_pack({"command": command, "data": data}))
        message = self.socket.recv()
        message = _unpack(message)

        if message.get("status") == "ok":
            return message.get("data")
        else:
            logger.error(f"Error from server: {message.get('data')}")
            raise Exception(f"Error from server: {message.get('data')}")
|
||||||
|
|
||||||
|
|
||||||
|
def freq_control(freq: int = 25):
    """Decorator factory that throttles a function to at most `freq` calls/s.

    After each call, sleeps for whatever remains of the 1/freq period so
    back-to-back invocations (e.g. a robot control loop) run at a fixed rate.
    The wrapped function's return value is passed through unchanged.

    Args:
        freq: target call rate in Hz (calls per second).

    Returns:
        A decorator producing a rate-limited wrapper of the target function.
    """
    from functools import wraps  # local import: keeps module-level deps unchanged

    def decorator(func):
        @wraps(func)  # fix: preserve __name__/__doc__ of the wrapped function
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            elapsed_time = time.time() - start_time
            # Sleep only for the time left in this period; never negative.
            time.sleep(max(0, (1.0 / freq) - elapsed_time))
            return result

        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Self-test / demo: run `python cloud_helper.py server` in one shell
    # and `python cloud_helper.py client` in another.
    import sys
    from time import sleep

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    assert (len(sys.argv) == 2) and ((mode := sys.argv[1]) in ("server", "client")), (
        "Usage: python service.py [server|client]"
    )

    ## Protocol:
    # Request: { "command": str, "data": Any }
    # Response: { "status": "ok" | "error", "data": Any if status=="ok" else str (ErrorMsg) }

    if mode == "server":
        # Register three demo endpoints and serve until Ctrl-C.
        server = Server()
        server.register_endpoint("ping", lambda: "pong")
        server.register_endpoint("echo", lambda x: x)
        server.register_endpoint("add", lambda data: data["a"] + data["b"])
        server.loop_forever()

    elif mode == "client":
        # Exercise each demo endpoint in a loop until an error occurs.
        client = Client()
        while True:
            try:
                response = client.call_endpoint("ping")
                print(f"Response from server: {response}")
                response = client.call_endpoint("echo", "Hello, World!")
                print(f"Response from server: {response}")
                response = client.call_endpoint("add", {"a": 5, "b": 10})
                print(f"Response from server: {response}")

                sleep(0.2)

            except Exception as e:
                print(f"Error: {e}")
                break
|
||||||
43
lerobot2rdt/Dockerfile
Normal file
43
lerobot2rdt/Dockerfile
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
|
||||||
|
# CUDA 11.8 + cuDNN 8 devel base (internal registry mirror of the NVIDIA image).
FROM registry.d-robotics.cc/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04
# ccr-29eug8s3-pub.cnc.bj.baidubce.com/public/cuda:11.8.0-cudnn8-devel-ubuntu22.04
WORKDIR /app

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV TZ=Asia/Shanghai

# Switch apt to the Tsinghua Ubuntu mirror for faster access from CN networks.
RUN sed -i 's/archive.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list && \
    sed -i 's/security.ubuntu.com/mirrors.tuna.tsinghua.edu.cn/g' /etc/apt/sources.list

# Python 3.10 from the deadsnakes PPA plus OpenCV/ffmpeg runtime libraries.
RUN apt-get update --allow-unauthenticated && apt-get install -y \
    software-properties-common \
    && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update \
    && apt-get install -y \
    python3.10 \
    python3.10-dev \
    python3-pip \
    python3.10-distutils \
    libgl1-mesa-glx \
    libglib2.0-0 \
    wget \
    libsm6 \
    libxext6 \
    ffmpeg \
    && rm -rf /var/lib/apt/lists/*

# Make `python3` resolve to the freshly installed 3.10.
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1

COPY . /app/

RUN python3 -m pip install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple

# RUN pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
RUN pip install torch==2.1.0 torchvision==0.16.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
RUN pip install packaging==24.0

# RUN mkdir -p /app/dataset/input /app/dataset/output

# Conversion job entry point; expects input/config.json mounted under /app.
ENTRYPOINT ["bash", "convert.sh"]
|
||||||
60
lerobot2rdt/convert.sh
Normal file
60
lerobot2rdt/convert.sh
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
|
||||||
|
# Driver for LeRobot -> RDT dataset conversion inside the Docker image.
# Reads job parameters from input/config.json, runs lerobot2rdt.py, and on
# success writes a job summary via generate_output.py.

BEGIN_TIME=$(date +%s)

CONFIG_FILE="input/config.json"
echo "CONFIG_FILE_PATH: $CONFIG_FILE"

# Read values directly from the config.json using python - no more nested key error by using a helper script
TASK_ID=$(python3 read_json.py "$CONFIG_FILE" "task_id")
DATA_DIR=$(python3 read_json.py "$CONFIG_FILE" "data_dir")
OUTPUT_DIR=$(python3 read_json.py "$CONFIG_FILE" "output_dir")
EPISODE_NUM=$(python3 read_json.py "$CONFIG_FILE" "episode_num")
GPU=$(python3 read_json.py "$CONFIG_FILE" "gpu")
T5_PATH="/weights/t5-v1_1-xxl"
NO_LANGUAGE=$(python3 read_json.py "$CONFIG_FILE" "no_language")

# For the camera keys, extract them in a way that avoids the error about 'images_info.key.*' not found
CAM_HIGH_KEY=$(python3 -c "import json; print(json.load(open('$CONFIG_FILE'))['images_info']['key'].get('cam_high', ''))")
CAM_RIGHT_WRIST_KEY=$(python3 -c "import json; print(json.load(open('$CONFIG_FILE'))['images_info']['key'].get('cam_right_wrist', ''))")

# create output path
if [ ! -d "$OUTPUT_DIR/$TASK_ID" ]; then
    mkdir -p "$OUTPUT_DIR/$TASK_ID"
    echo "Created output directory: $OUTPUT_DIR/$TASK_ID"
else
    echo "Output directory already exists: $OUTPUT_DIR/$TASK_ID"
fi

# Run the converter; the only difference between the branches is --no_language.
if [ "$NO_LANGUAGE" = "true" ]; then
    python3 lerobot2rdt.py \
        --data_dir $DATA_DIR \
        --output_dir $OUTPUT_DIR/$TASK_ID \
        --episode_num $EPISODE_NUM \
        --gpu $GPU \
        --t5_path $T5_PATH \
        --cam_high_key $CAM_HIGH_KEY \
        --cam_right_wrist_key $CAM_RIGHT_WRIST_KEY \
        --no_language
    status=$?
else
    python3 lerobot2rdt.py \
        --data_dir $DATA_DIR \
        --output_dir $OUTPUT_DIR/$TASK_ID \
        --episode_num $EPISODE_NUM \
        --gpu $GPU \
        --t5_path $T5_PATH \
        --cam_high_key $CAM_HIGH_KEY \
        --cam_right_wrist_key $CAM_RIGHT_WRIST_KEY
    status=$?
fi

END_TIME=$(date +%s)
echo "END_TIME: $END_TIME"
echo "TOTAL_TIME: $((END_TIME - BEGIN_TIME))"

# Only emit the summary when the conversion itself succeeded.
if [ $status -eq 0 ]; then
    python3 generate_output.py $CONFIG_FILE $((END_TIME - BEGIN_TIME))
else
    echo "lerobot2rdt.py exited with status $status, skipping generate_output.py"
fi
|
||||||
|
|
||||||
26
lerobot2rdt/generate_output.py
Normal file
26
lerobot2rdt/generate_output.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def generate_output(input_config, time):
    """Write a job summary output.json next to the converted episodes.

    Reads the job config, mirrors its key fields (plus the elapsed
    conversion time in seconds) into <output_dir>/<task_id>/output.json.

    Args:
        input_config: path to the job's config.json.
        time: elapsed conversion time in seconds.
    """
    with open(input_config, "r") as f:
        config = json.load(f)

    task_output_dir = os.path.join(config["output_dir"], str(config["task_id"]))
    # Ensure the output directory exists before writing the output file
    os.makedirs(task_output_dir, exist_ok=True)

    summary = {
        "task_id": config["task_id"],
        "convert_time": time,
        "data_dir": config["data_dir"],
        "output_dir": task_output_dir,
        "episode_num": config["episode_num"],
        "no_language": config["no_language"],
    }

    with open(os.path.join(task_output_dir, "output.json"), "w") as f:
        json.dump(summary, f)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Usage: python generate_output.py <config.json> <elapsed_seconds>
    input_config = sys.argv[1]
    time = int(sys.argv[2])
    generate_output(input_config, time)
|
||||||
368
lerobot2rdt/lerobot2rdt.py
Normal file
368
lerobot2rdt/lerobot2rdt.py
Normal file
@ -0,0 +1,368 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
LeRobot到RDT数据转换脚本
|
||||||
|
|
||||||
|
LeRobot机器人结构:
|
||||||
|
- 5个关节 (shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll)
|
||||||
|
- 1个夹爪 (gripper)
|
||||||
|
- 总计:6个自由度 (6DOF)
|
||||||
|
|
||||||
|
维度映射(匹配RDT训练代码):
|
||||||
|
- left_arm_dim = 0 (单臂机器人,左臂不存在)
|
||||||
|
- right_arm_dim = 6 (5关节 + 1夹爪,映射到RDT的right_arm部分)
|
||||||
|
- 状态向量:6维 [joint1, joint2, joint3, joint4, joint5, gripper]
|
||||||
|
- RDT索引映射:right_arm_joint_0_pos到right_arm_joint_5_pos (索引0-5)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import argparse
|
||||||
|
import yaml
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
|
||||||
|
current_dir = os.path.dirname(__file__)
|
||||||
|
sys.path.append(os.path.join(current_dir, ".."))
|
||||||
|
from models.multimodal_encoder.t5_encoder import T5Embedder
|
||||||
|
|
||||||
|
def extract_frames_from_video(video_path, output_dir, episode_idx):
    """Decode a video into a list of 640x480 BGR frames via ffmpeg.

    Dumps frames at 30 fps into a scratch directory, reads them back with
    OpenCV in filename (temporal) order, resizes each to 640x480, and
    returns the list. Returns [] when the video is missing or extraction
    fails.

    Fix over the original: the scratch directory is now always removed
    (try/finally), whereas it previously leaked on ffmpeg failure or on
    any exception before cleanup.

    Args:
        video_path: path to the source .mp4.
        output_dir: directory under which the scratch dir is created.
        episode_idx: episode index, used to name the scratch dir.
    """
    import shutil  # local import: cleanup helper, keeps module deps unchanged

    if not os.path.exists(video_path):
        print(f" No video file: {video_path}")
        return []

    temp_dir = os.path.join(output_dir, f"temp_frames_{episode_idx}")
    os.makedirs(temp_dir, exist_ok=True)

    output_pattern = os.path.join(temp_dir, "frame_%04d.jpg")

    try:
        # '-q:v 2' = high JPEG quality; '-y' overwrites any stale frames.
        cmd = [
            'ffmpeg', '-i', video_path,
            '-vf', 'fps=30',
            '-q:v', '2',
            output_pattern,
            '-y'
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            print(f" Failed to extract frames with ffmpeg: {result.stderr}")
            return []

        frames = []
        # Sort so frames come back in temporal order (frame_0001, ...).
        frame_files = sorted(f for f in os.listdir(temp_dir) if f.endswith('.jpg'))

        for frame_file in frame_files:
            frame = cv2.imread(os.path.join(temp_dir, frame_file))
            if frame is not None:
                frames.append(cv2.resize(frame, (640, 480)))

        print(f" Successfully extracted {len(frames)} frames")
        return frames

    except Exception as e:
        print(f" Error extracting frames: {e}")
        return []

    finally:
        # Always clean up the scratch directory, even when ffmpeg or
        # decoding failed (the original only removed it on full success).
        shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
def load_lerobot_episode(data_dir, episode_idx, output_dir, cam_high_key="high", cam_right_wrist_key="arm"):
    """Load one LeRobot episode: actions, states, and camera frames.

    LeRobot data layout:
    - action: 6-dim [shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper]
    - observation.state: 6-dim [shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper]
    - images: overhead camera + wrist/arm camera

    Returns a dict with 'actions', 'qpos' (float32 arrays), the two frame
    lists, and 'episode_length'; returns None when the parquet is missing.
    """
    parquet_path = os.path.join(data_dir, "data/chunk-000", f"episode_{episode_idx:06d}.parquet")
    if not os.path.exists(parquet_path):
        print(f"Episode {episode_idx} parquet file does not exist: {parquet_path}")
        return None

    df = pd.read_parquet(parquet_path)

    actions = []
    qpos = []

    # Normalize each row's action/state to float32 numpy arrays
    # (parquet cells may deserialize as lists or ndarrays).
    for i in range(len(df)):
        action = df['action'].iloc[i]
        state = df['observation.state'].iloc[i]

        if isinstance(action, np.ndarray):
            actions.append(action.astype(np.float32))
        else:
            actions.append(np.array(action, dtype=np.float32))

        if isinstance(state, np.ndarray):
            qpos.append(state.astype(np.float32))
        else:
            qpos.append(np.array(state, dtype=np.float32))

    # Camera video paths follow LeRobot's observation.images.<key> layout.
    high_cam_path = os.path.join(data_dir, f"videos/chunk-000/observation.images.{cam_high_key}", f"episode_{episode_idx:06d}.mp4")
    arm_cam_path = os.path.join(data_dir, f"videos/chunk-000/observation.images.{cam_right_wrist_key}", f"episode_{episode_idx:06d}.mp4")

    print(f" Extracting high camera frames...")
    high_images = extract_frames_from_video(high_cam_path, output_dir, episode_idx)

    print(f" Extracting arm camera frames...")
    arm_images = extract_frames_from_video(arm_cam_path, output_dir, episode_idx)

    # Align frame counts with the parquet rows: truncate extras, then pad
    # by repeating the last frame (only when at least one frame exists).
    target_frames = len(df)
    if len(high_images) > target_frames:
        high_images = high_images[:target_frames]
    if len(arm_images) > target_frames:
        arm_images = arm_images[:target_frames]

    while len(high_images) < target_frames and high_images:
        high_images.append(high_images[-1])
    while len(arm_images) < target_frames and arm_images:
        arm_images.append(arm_images[-1])

    return {
        'actions': np.array(actions),
        'qpos': np.array(qpos),
        'high_images': high_images,
        'arm_images': arm_images,
        'episode_length': len(df)
    }
|
||||||
|
|
||||||
|
def images_encoding(imgs):
    """JPEG-encode a list of BGR frames.

    Args:
        imgs: list of HxWx3 uint8 images (OpenCV BGR order).

    Returns:
        (encode_data, max_len): the per-frame JPEG byte strings and the
        length of the longest one. Frames that fail to encode contribute
        an empty byte string.

    Fix over the original: it also built a zero-padded copy of every frame
    (ljust to max_len) that was never returned — that O(n * max_len) dead
    work is removed. Pad at the call site if fixed-length records are
    needed for HDF5 storage.
    """
    if not imgs:
        return [], 0

    encode_data = []
    max_len = 0

    for i, img in enumerate(imgs):
        success, encoded_image = cv2.imencode(".jpg", img)
        if success:
            jpeg_data = encoded_image.tobytes()
            encode_data.append(jpeg_data)
            max_len = max(max_len, len(jpeg_data))
        else:
            print(f" Image encoding failed: {i}")
            encode_data.append(b"")

    return encode_data, max_len
|
||||||
|
|
||||||
|
def load_task_instructions(data_dir):
    """Load per-task natural-language instructions from meta/tasks.jsonl.

    Args:
        data_dir: root of the LeRobot dataset.

    Returns:
        List of instruction strings in file order, or None when the
        tasks file does not exist.
    """
    tasks_file = os.path.join(data_dir, "meta/tasks.jsonl")
    if not os.path.exists(tasks_file):
        print(f"Warning: tasks file not found: {tasks_file}")
        return None

    instructions = []
    with open(tasks_file, 'r') as f:
        for line in f:
            stripped = line.strip()
            # Skip blank lines; each remaining line is one JSON task record.
            if stripped:
                instructions.append(json.loads(stripped)["task"])

    print(f" 加载了 {len(instructions)} 个任务指令")
    return instructions
|
||||||
|
|
||||||
|
def encode_language_instruction(instruction_text, t5_embedder, device):
    """Encode an instruction string into per-token T5 embeddings.

    Args:
        instruction_text: natural-language instruction.
        t5_embedder: object exposing get_text_embeddings([str]) ->
            (text_embeds, attn_mask) torch tensors.
        device: unused here; kept for interface compatibility.

    Returns:
        np.float32 array of shape (num_valid_tokens, hidden_dim); on
        failure, a float32 zero array of shape (1, 4096) (T5-XXL width).
    """
    try:
        text_embeds, attn_mask = t5_embedder.get_text_embeddings([instruction_text])

        # Keep only the non-padding token embeddings of this single sample.
        valid_embeds = text_embeds[0][attn_mask[0]].float()
        return valid_embeds.cpu().numpy()

    except Exception as e:
        print(f" Language encoding failed: {e}")
        # Fix: float32 to match the success path's dtype (was float64).
        return np.zeros((1, 4096), dtype=np.float32)
|
||||||
|
|
||||||
|
def convert_lerobot_to_rdt(data_dir, output_dir, episode_num, gpu=0, no_language=False, t5_path=None, cam_high_key="high", cam_right_wrist_key="arm"):
    """Convert LeRobot episodes to RDT-format HDF5 files.

    For each of the first *episode_num* episodes: loads the episode via
    ``load_lerobot_episode``, writes actions/qpos/encoded camera images and
    per-timestep arm-dimension info into ``episode_{i}/episode_{i}.hdf5``
    under *output_dir*, and (unless *no_language*) saves a T5 embedding of
    the first task instruction as ``instructions/lang_embed_0.pt``.

    Args:
        data_dir: LeRobot dataset root (must contain meta/tasks.jsonl for
            language processing).
        output_dir: destination root; created if missing.
        episode_num: number of episodes to process (indices 0..episode_num-1).
        gpu: CUDA device index used for the T5 encoder when CUDA is available.
        no_language: skip all language/instruction processing.
        t5_path: pretrained T5 model path passed to T5Embedder.
        cam_high_key / cam_right_wrist_key: camera keys forwarded to
            ``load_lerobot_episode``. NOTE(review): defaults here ("high",
            "arm") differ from the CLI defaults in main() -- confirm intended.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    print(f"Start converting LeRobot data to RDT format...")
    print(f"Data source: {data_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Processing episode number: {episode_num}")
    print(f"GPU device: {gpu}")

    # NOTE(review): scene_name is computed but never used below.
    scene_name = os.path.basename(data_dir)

    # Load task instructions up front; None when file is missing or skipped.
    instructions = None
    if not no_language:
        instructions = load_task_instructions(data_dir)

    # Initialize the T5 text encoder once for all episodes; on failure fall
    # back to no-language mode rather than aborting the conversion.
    t5_embedder = None
    if not no_language and instructions:
        try:
            print(f" Initializing T5 encoder...")
            t5_embedder = T5Embedder(
                from_pretrained=t5_path,
                device=f"cuda:{gpu}" if torch.cuda.is_available() else "cpu",
                model_max_length=1024,
                use_offload_folder=None,
            )
            print(f" T5 encoder initialized successfully")
        except Exception as e:
            print(f" T5 encoder initialization failed: {e}")
            print(f" Will skip language processing")
            no_language = True

    for i in range(episode_num):
        print(f"Processing episode {i}...")

        episode_data = load_lerobot_episode(data_dir, i, output_dir, cam_high_key=cam_high_key, cam_right_wrist_key=cam_right_wrist_key)
        if episode_data is None:
            # Episode failed to load; skip it rather than aborting the batch.
            print(f"Skipping episode {i}")
            continue

        episode_output_dir = os.path.join(output_dir, f"episode_{i}")
        if not os.path.exists(episode_output_dir):
            os.makedirs(episode_output_dir)

        hdf5_path = os.path.join(episode_output_dir, f"episode_{i}.hdf5")

        with h5py.File(hdf5_path, "w") as f:
            f.create_dataset("action", data=episode_data['actions'])

            obs = f.create_group("observations")
            obs.create_dataset("qpos", data=episode_data['qpos'])

            image = obs.create_group("images")

            # JPEG-encode camera streams and store them as fixed-width byte
            # strings (dtype S{max_len}); see images_encoding().
            if episode_data['high_images']:
                print(f" Encoding high camera images...")
                high_enc, len_high = images_encoding(episode_data['high_images'])
                if high_enc and len_high > 0:
                    image.create_dataset("cam_high", data=high_enc, dtype=f"S{len_high}")
                    print(f" Saved high camera images: {len(episode_data['high_images'])} frames")
                else:
                    print(f" Warning: High camera images encoding failed")

            if episode_data['arm_images']:
                print(f" Encoding arm camera images...")
                arm_enc, len_arm = images_encoding(episode_data['arm_images'])
                if arm_enc and len_arm > 0:
                    image.create_dataset("cam_right_wrist", data=arm_enc, dtype=f"S{len_arm}")
                    print(f" Saved arm camera images: {len(episode_data['arm_images'])} frames")
                else:
                    print(f" Warning: Arm camera images encoding failed")

            # Robot dimension info (LeRobot: 5 joints + 1 gripper).
            # Following process_data.py's convention, the dimension info is
            # recorded once per timestep.
            # LeRobot here is a single-arm robot with only a right arm:
            # 5 joints + 1 gripper = 6 dims; left arm: 0 dims.

            # Record dimension info for every timestep.
            left_arm_dim = [0] * len(episode_data['actions'])  # left arm: 0 dims (single-arm robot)
            right_arm_dim = [6] * len(episode_data['actions'])  # right arm: 6 dims (5 joints + 1 gripper)

            obs.create_dataset("left_arm_dim", data=np.array(left_arm_dim))
            obs.create_dataset("right_arm_dim", data=np.array(right_arm_dim))

        print(f" Episode {i} converted successfully: {hdf5_path}")
        print(f" Data length: {episode_data['episode_length']}")
        print(f" Action shape: {episode_data['actions'].shape}")
        print(f" Qpos shape: {episode_data['qpos'].shape}")
        print(f" High camera frames: {len(episode_data['high_images'])}")
        print(f" Arm camera frames: {len(episode_data['arm_images'])}")

        # Language embedding: NOTE(review) instructions[0] is used for every
        # episode -- confirm a single shared instruction is intended.
        if not no_language and t5_embedder and instructions:
            print(f" Processing language instructions...")
            try:
                instruction = instructions[0]

                language_features = encode_language_instruction(instruction, t5_embedder, f"cuda:{gpu}")

                instructions_dir = os.path.join(episode_output_dir, "instructions")
                if not os.path.exists(instructions_dir):
                    os.makedirs(instructions_dir)

                lang_embed_path = os.path.join(instructions_dir, "lang_embed_0.pt")
                torch.save(torch.from_numpy(language_features), lang_embed_path)

                print(f" Language instruction encoded successfully: {instruction}")
                print(f" Language features saved to: {lang_embed_path}")
                print(f" Language features shape: {language_features.shape}, data type: {language_features.dtype}")

            except Exception as e:
                print(f" Language instruction processing failed: {e}")

    print(f"\nConversion completed! Processed {episode_num} episodes")
    print(f"Output directory: {output_dir}")
def main():
    """CLI entry point: parse arguments, validate the environment, and run
    the LeRobot -> RDT conversion.

    Pre-flight checks:
      * the data directory and its meta/info.json exist;
      * ffmpeg is on PATH (needed to extract video frames);
      * the requested episode count does not exceed what the dataset holds.
    """
    parser = argparse.ArgumentParser(description="Convert LeRobot data to RDT format")
    parser.add_argument("--data_dir", type=str, required=True,
                        help="LeRobot data directory path")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Output directory path")
    parser.add_argument("--episode_num", type=int, default=10,
                        help="Number of episodes to process")
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU device ID")
    parser.add_argument("--no_language", action="store_true",
                        help="Skip language processing")
    parser.add_argument("--cam_high_key", type=str, default="cam_high",
                        help="High camera key")
    parser.add_argument("--cam_right_wrist_key", type=str, default="cam_right_wrist",
                        help="Right wrist camera key")
    # NOTE(review): parsed but currently unused -- convert_lerobot_to_rdt()
    # takes no left-wrist key; kept so existing command lines keep working.
    parser.add_argument("--cam_left_wrist_key", type=str, default="cam_left_wrist",
                        help="Left wrist camera key")
    parser.add_argument("--t5_path", type=str, required=True,
                        help="T5 model path")

    args = parser.parse_args()

    if not os.path.exists(args.data_dir):
        print(f"Error: Data directory does not exist: {args.data_dir}")
        return

    meta_file = os.path.join(args.data_dir, "meta/info.json")
    if not os.path.exists(meta_file):
        print(f"Error: Meta information file not found: {meta_file}")
        return

    # Verify ffmpeg is available; video-frame extraction depends on it.
    try:
        subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True)
        print("ffmpeg is available, will use ffmpeg to extract video frames")
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Warning: ffmpeg is not available, image data may not be extracted correctly")
        print("Please install ffmpeg: conda install -c conda-forge ffmpeg=6.1")
        return

    # BUG FIX: meta/info.json is a JSON file -- parse it with the json module
    # instead of yaml.safe_load (which only incidentally accepts most JSON).
    with open(meta_file, 'r') as f:
        meta_info = json.load(f)

    # Clamp the requested episode count to what the dataset actually holds.
    total_episodes = meta_info.get('total_episodes', 10)
    if args.episode_num > total_episodes:
        print(f"Warning: Requested episode number ({args.episode_num}) exceeds available number ({total_episodes})")
        args.episode_num = total_episodes

    convert_lerobot_to_rdt(
        args.data_dir,
        args.output_dir,
        args.episode_num,
        args.gpu,
        args.no_language,
        args.t5_path,
        args.cam_high_key,
        args.cam_right_wrist_key,
    )
|
# Script entry point: run the LeRobot -> RDT conversion CLI.
if __name__ == "__main__":
    main()
||||||
20
lerobot2rdt/read_json.py
Normal file
20
lerobot2rdt/read_json.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
|
def read_json_value(file_path, key):
    """Print the value stored under *key* in the JSON file at *file_path*.

    Prints a not-found message when the key is absent (or maps to null).
    """
    with open(file_path, "r") as file:
        parsed = json.load(file)
        found = parsed.get(key)
        if found is None:
            print(f"Key '{key}' not found in {file_path}")
        else:
            print(found)
# Command-line entry point: python read_json.py <file_path> <key>
if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python read_json.py <file_path> <key>")
        sys.exit(1)

    file_path = sys.argv[1]
    key = sys.argv[2]
    read_json_value(file_path, key)
24
lerobot2rdt/requirements.txt
Normal file
24
lerobot2rdt/requirements.txt
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
numpy<2.0
|
||||||
|
packaging==24.0
|
||||||
|
deepspeed==0.14.2
|
||||||
|
accelerate==0.30.1
|
||||||
|
diffusers==0.27.2
|
||||||
|
timm==1.0.3
|
||||||
|
transformers==4.41.0
|
||||||
|
sentencepiece==0.2.0
|
||||||
|
h5py==3.11.0
|
||||||
|
opencv-python==4.9.0.80
|
||||||
|
imgaug==0.4.0
|
||||||
|
pytz==2022.1
|
||||||
|
huggingface_hub==0.23.0
|
||||||
|
pandas==2.3.3
|
||||||
|
|
||||||
|
# requirements_data.txt
|
||||||
|
# tfds-nightly==4.9.4.dev202402070044
|
||||||
|
gsutil==5.27
|
||||||
|
tensorflow==2.15.0.post1
|
||||||
|
pillow==10.2.0
|
||||||
|
pyyaml==6.0.1
|
||||||
|
tensorflow-graphics==2021.12.3
|
||||||
|
imageio==2.34.0
|
||||||
|
imageio-ffmpeg==0.4.9
|
||||||
|
Before Width: | Height: | Size: 726 KiB After Width: | Height: | Size: 726 KiB |
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user