942 lines
32 KiB
Python
942 lines
32 KiB
Python
#!/home/lin/software/miniconda3/envs/aloha/bin/python
|
|
# -- coding: UTF-8
|
|
"""
|
|
#!/usr/bin/python3
|
|
"""
|
|
|
|
import argparse
|
|
import sys
|
|
import threading
|
|
import time
|
|
import yaml
|
|
from collections import deque
|
|
|
|
import numpy as np
|
|
import rospy
|
|
import torch
|
|
from cv_bridge import CvBridge
|
|
from geometry_msgs.msg import Twist
|
|
from nav_msgs.msg import Odometry
|
|
from PIL import Image as PImage
|
|
from sensor_msgs.msg import Image, JointState
|
|
from std_msgs.msg import Header
|
|
import cv2
|
|
|
|
from scripts.agilex_model import create_model
|
|
|
|
# Camera keys in the order the model consumes them: front (high), right wrist, left wrist.
CAMERA_NAMES = ["cam_high", "cam_right_wrist", "cam_left_wrist"]

# Rolling window of the most recent observations; created lazily (deque(maxlen=2))
# by update_observation_window().
observation_window = None

# Pre-encoded language-instruction embeddings; assigned in model_inference().
lang_embeddings = None

# debug: placeholder for preloaded images; appears unused by active code in this file.
preload_images = None
|
|
|
|
|
|
# Default vision-encoder checkpoint, used when args does not provide
# ``pretrained_vision_encoder_name_or_path``.
_DEFAULT_VISION_ENCODER = "google/siglip-so400m-patch14-384"


def make_policy(args):
    """Load the YAML model config and build the RDT policy.

    Side effect: the parsed config dict is stored on ``args.config``.

    Args:
        args: namespace providing ``config_path``, ``pretrained_model_name_or_path``
            and ``ctrl_freq``. It may optionally provide
            ``pretrained_vision_encoder_name_or_path``; otherwise the SigLIP default is used.

    Returns:
        The model returned by ``create_model`` (weights in bfloat16).
    """
    with open(args.config_path, "r", encoding="utf-8") as fp:
        args.config = yaml.safe_load(fp)

    # The text encoder is not loaded here: language embeddings are pre-encoded
    # and loaded separately (see --lang_embeddings_path).
    vision_encoder = (getattr(args, "pretrained_vision_encoder_name_or_path", None)
                      or _DEFAULT_VISION_ENCODER)

    return create_model(
        args=args.config,
        dtype=torch.bfloat16,
        pretrained=args.pretrained_model_name_or_path,
        pretrained_vision_encoder_name_or_path=vision_encoder,
        control_frequency=args.ctrl_freq,
    )
|
|
|
|
|
|
def set_seed(seed):
    """Seed the NumPy and PyTorch random number generators for reproducible runs."""
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(seed)
|
|
|
|
|
|
def interpolate_action(args, prev_action, cur_action):
    """Split the move from ``prev_action`` to ``cur_action`` into evenly spaced sub-steps.

    The per-joint limit ``args.arm_steps_length`` is applied to both arms, laid out as
    [left joints..., right joints...]. The joint with the farthest to travel (relative
    to its limit) decides how many sub-steps are needed.

    Returns:
        An array with one row per sub-step. It is just ``cur_action[np.newaxis, :]``
        when one step suffices; otherwise the linearly interpolated points after
        ``prev_action``, ending at ``cur_action``.
    """
    per_joint_limit = np.array(args.arm_steps_length)
    limits = np.concatenate((per_joint_limit, per_joint_limit), axis=0)
    n_sub = np.max(np.ceil(np.abs(cur_action - prev_action) / limits).astype(int))
    if n_sub > 1:
        # Drop the first linspace point: it equals prev_action, which was already executed.
        return np.linspace(prev_action, cur_action, n_sub + 1)[1:]
    return cur_action[np.newaxis, :]
|
|
|
|
|
|
def get_config(args):
    """Assemble the runtime settings dict consumed by the inference loop.

    Keys: ``episode_len`` (from --max_publish_step), ``state_dim`` (fixed at 14,
    i.e. two 7-joint arms), ``chunk_size`` (from --chunk_size) and ``camera_names``.
    """
    return dict(
        episode_len=args.max_publish_step,
        state_dim=14,
        chunk_size=args.chunk_size,
        camera_names=CAMERA_NAMES,
    )
|
|
|
|
|
|
def get_ros_observation(args, ros_operator):
    """Block until ``ros_operator`` yields a time-synchronized frame, then return it.

    Polls ``ros_operator.get_frame()`` at ``args.publish_rate`` Hz. It prints one
    warning per call when the first attempt fails to sync.

    Returns:
        (img_front, img_left, img_right, puppet_arm_left, puppet_arm_right). The depth
        images and base odometry from ``get_frame`` are discarded.

    Raises:
        rospy.ROSInterruptException: if ROS shuts down before a frame is synchronized.
            ``rate.sleep()`` already raises this on shutdown. Previously the function
            could also fall out of the loop and return None, which crashed the caller's
            tuple unpacking.
    """
    rate = rospy.Rate(args.publish_rate)
    warned = False
    while not rospy.is_shutdown():
        result = ros_operator.get_frame()
        if result:
            (img_front, img_left, img_right,
             _front_depth, _left_depth, _right_depth,
             puppet_arm_left, puppet_arm_right, _robot_base) = result
            return (img_front, img_left, img_right, puppet_arm_left, puppet_arm_right)
        if not warned:
            print("syn fail when get_ros_observation")
            warned = True
        rate.sleep()
    raise rospy.ROSInterruptException("ROS shut down before a synchronized observation was received")
|
|
|
|
|
|
def update_observation_window(args, config, ros_operator):
    """Append the newest synchronized observation to the global two-frame window.

    On first use, the module-level ``observation_window`` is created as
    ``deque(maxlen=2)`` and seeded with an all-None placeholder frame, so
    ``observation_window[-2]`` exists for the very first inference.

    Each appended frame is ``{"qpos": Tensor, "images": {camera_name: ndarray}}``:
    - ``qpos`` is the left-arm positions followed by the right-arm positions, as float32 on CUDA.
    - ``images`` is keyed by ``config["camera_names"]`` in [front, right, left] order.
    """
    global observation_window

    def jpeg_mapping(img):
        # JPEG round-trip so inference images carry the same compression as the training data.
        _, encoded = cv2.imencode(".jpg", img)
        raw_bytes = encoded.tobytes()
        return cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)

    names = config["camera_names"]
    front_key, right_key, left_key = names[0], names[1], names[2]

    if observation_window is None:
        observation_window = deque(maxlen=2)
        observation_window.append({
            "qpos": None,
            "images": {front_key: None, right_key: None, left_key: None},
        })

    img_front, img_left, img_right, puppet_arm_left, puppet_arm_right = get_ros_observation(args, ros_operator)

    front_img = jpeg_mapping(img_front)
    left_img = jpeg_mapping(img_left)
    right_img = jpeg_mapping(img_right)

    joint_positions = np.concatenate(
        (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)),
        axis=0,
    )
    qpos_tensor = torch.from_numpy(joint_positions).float().cuda()

    observation_window.append({
        "qpos": qpos_tensor,
        "images": {front_key: front_img, right_key: right_img, left_key: left_img},
    })
|
|
|
|
|
|
def inference_fn(args, config, policy, t):
    """Run one RDT forward pass on the current observation window.

    The original wrapped this in a ``while not rospy.is_shutdown()`` loop that always
    returned on its first iteration. When ROS was already shut down it fell through
    and returned None, which broke the caller's ``.copy()``. The loop has been removed.

    Args:
        args: unused; kept for interface compatibility.
        config: dict whose ``camera_names`` lists [front, right, left] keys.
        policy: model exposing ``step(proprio=..., images=..., text_embeds=...)``.
        t: current time step; unused (previously only read by debug code).

    Returns:
        NumPy array of actions with the batch dimension squeezed. It is presumably
        [chunk, 14] in [left, right] arm order, per the model's output; confirm
        against the model.
    """
    camera_names = config["camera_names"]
    previous_frame, latest_frame = observation_window[-2], observation_window[-1]

    # Six images: previous frame then latest frame, each in [front, right, left] order.
    image_arrs = [
        frame["images"][camera_names[i]]
        for frame in (previous_frame, latest_frame)
        for i in range(3)
    ]
    # The placeholder frame holds None images; pass those through as None.
    images = [PImage.fromarray(arr) if arr is not None else None for arr in image_arrs]

    # Latest joint state [14] -> [1, 14] (batch of one).
    proprio = latest_frame["qpos"].unsqueeze(0)

    actions = policy.step(proprio=proprio, images=images, text_embeds=lang_embeddings)
    return actions.squeeze(0).cpu().numpy()
|
|
|
|
|
|
def model_inference(args, config, ros_operator):
    """Load the policy and language embeddings, home the arms, then run closed-loop control.

    The arms are first driven to a staging pose (left0/right0). After the operator presses
    Enter they move to the start pose (left1/right1).

    The control loop then runs until ROS shuts down. A new action chunk is inferred every
    ``chunk_size`` steps; each step's action is optionally interpolated and published at
    ``args.publish_rate`` Hz. After ``episode_len`` steps the step counter restarts at 0.
    """
    global lang_embeddings

    policy = make_policy(args)

    lang_dict = torch.load(args.lang_embeddings_path)
    print(f"Running with instruction: \"{lang_dict['instruction']}\" from \"{lang_dict['name']}\"")
    lang_embeddings = lang_dict["embeddings"]

    max_publish_step = config["episode_len"]
    chunk_size = config["chunk_size"]

    # Staging pose, reached before the operator confirms.
    left0 = [
        -0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156,
        -0.00286102294921875, 0.00095367431640625, 3.557830810546875,
    ]
    right0 = [
        -0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656,
        -0.00476837158203125, -0.00209808349609375, 3.557830810546875,
    ]
    # Start pose for the policy rollout.
    left1 = [
        -0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156,
        -0.00286102294921875, 0.00095367431640625, -0.3393220901489258,
    ]
    right1 = [
        -0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156,
        -0.00286102294921875, 0.00095367431640625, -0.3397035598754883,
    ]
    ros_operator.puppet_arm_publish_continuous(left0, right0)
    input("Press enter to continue")
    ros_operator.puppet_arm_publish_continuous(left1, right1)

    # The arms were just driven to the start pose, so that is the initial "previous action".
    # It is derived from left1/right1 instead of repeating the literals, so the two cannot drift.
    pre_action = np.zeros(config["state_dim"])
    pre_action[:14] = np.array(left1 + right1)

    with torch.inference_mode():
        while not rospy.is_shutdown():
            t = 0
            rate = rospy.Rate(args.publish_rate)
            action_buffer = np.zeros([chunk_size, config["state_dim"]])

            while t < max_publish_step and not rospy.is_shutdown():
                update_observation_window(args, config, ros_operator)

                # Refill the action chunk at the start of each chunk.
                if t % chunk_size == 0:
                    action_buffer = inference_fn(args, config, policy, t).copy()

                action = action_buffer[t % chunk_size]
                if args.use_actions_interpolation:
                    interp_actions = interpolate_action(args, pre_action, action)
                else:
                    interp_actions = action[np.newaxis, :]

                for act in interp_actions:
                    if not args.disable_puppet_arm:
                        ros_operator.puppet_arm_publish(act[:7], act[7:14])
                    if args.use_robot_base:
                        # NOTE(review): with a 14-dim action (state_dim == 14) act[14:16] is
                        # empty, and robot_base_publish will fail on vel[0]. Confirm the model
                        # emits base velocities before enabling --use_robot_base.
                        ros_operator.robot_base_publish(act[14:16])
                    rate.sleep()

                t += 1
                print("Published Step", t)
                pre_action = action.copy()
|
|
|
|
|
|
# ROS operator class
class RosOperator:
    """Bridge between ROS topics and the inference loop.

    - Subscriber callbacks buffer incoming camera images, puppet-arm joint states and
      base odometry in bounded deques.
    - ``get_frame`` pops one time-synchronized set of messages from those buffers.
    - The ``*_publish*`` methods send JointState / Twist commands.
    """

    # Per-topic buffer capacity; once full, appending discards the oldest message.
    BUFFER_SIZE = 2000
    # Joint names written into every outgoing JointState command (7 joints per arm).
    _JOINT_NAMES = ("joint0", "joint1", "joint2", "joint3", "joint4", "joint5", "joint6")

    def __init__(self, args):
        """Store ``args``, create buffers and lock (``init``), then start ROS I/O (``init_ros``)."""
        # Declared up front; populated by init() / init_ros().
        self.robot_base_deque = None
        self.puppet_arm_right_deque = None
        self.puppet_arm_left_deque = None
        self.img_front_deque = None
        self.img_right_deque = None
        self.img_left_deque = None
        self.img_front_depth_deque = None
        self.img_right_depth_deque = None
        self.img_left_depth_deque = None
        self.bridge = None
        self.puppet_arm_left_publisher = None
        self.puppet_arm_right_publisher = None
        self.robot_base_publisher = None
        self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_lock = None
        self.args = args
        self.init()
        self.init_ros()

    def init(self):
        """Create the CvBridge, the bounded message buffers and the publish-stop lock."""
        self.bridge = CvBridge()
        size = self.BUFFER_SIZE
        self.img_left_deque = deque(maxlen=size)
        self.img_right_deque = deque(maxlen=size)
        self.img_front_deque = deque(maxlen=size)
        self.img_left_depth_deque = deque(maxlen=size)
        self.img_right_depth_deque = deque(maxlen=size)
        self.img_front_depth_deque = deque(maxlen=size)
        self.puppet_arm_left_deque = deque(maxlen=size)
        self.puppet_arm_right_deque = deque(maxlen=size)
        self.robot_base_deque = deque(maxlen=size)
        # Stop signal for puppet_arm_publish_continuous. The lock stays held during normal
        # operation and is released only to ask a running publisher to stop.
        self.puppet_arm_publish_lock = threading.Lock()
        self.puppet_arm_publish_lock.acquire()

    def _make_joint_state(self, positions):
        """Build a time-stamped JointState for joints joint0..joint6 at ``positions``."""
        msg = JointState()
        msg.header = Header()
        msg.header.stamp = rospy.Time.now()
        msg.name = list(self._JOINT_NAMES)
        msg.position = positions
        return msg

    def puppet_arm_publish(self, left, right):
        """Publish one joint command per arm, sharing a single header stamp."""
        msg = self._make_joint_state(left)
        self.puppet_arm_left_publisher.publish(msg)
        msg.position = right
        self.puppet_arm_right_publisher.publish(msg)

    def robot_base_publish(self, vel):
        """Publish a Twist with linear.x = vel[0] and angular.z = vel[1]; other components are zero."""
        vel_msg = Twist()
        vel_msg.linear.x = vel[0]
        vel_msg.linear.y = 0
        vel_msg.linear.z = 0
        vel_msg.angular.x = 0
        vel_msg.angular.y = 0
        vel_msg.angular.z = vel[1]
        self.robot_base_publisher.publish(vel_msg)

    def _wait_for_arm_state(self, rate):
        """Poll (sleeping on ``rate``) until both puppet-arm buffers hold a message.

        Returns:
            (left_positions, right_positions) as lists. Either may be None if ROS shuts
            down first.
        """
        left_arm = right_arm = None
        while not rospy.is_shutdown():
            if self.puppet_arm_left_deque:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if self.puppet_arm_right_deque:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is not None and right_arm is not None:
                break
            rate.sleep()
        return left_arm, right_arm

    def _step_towards(self, current, target, direction):
        """Move ``current`` in place toward ``target`` by at most arm_steps_length per joint.

        A joint closer than its step limit snaps exactly to its target. Every other joint
        advances one limit in its fixed ``direction`` (+1 / -1).

        Returns:
            True if any joint still had to be clamped, i.e. motion is not finished.
        """
        steps = self.args.arm_steps_length
        moving = False
        for i, goal in enumerate(target):
            if abs(goal - current[i]) < steps[i]:
                current[i] = goal
            else:
                current[i] += direction[i] * steps[i]
                moving = True
        return moving

    def puppet_arm_publish_continuous(self, left, right):
        """Drive both arms from their current state to (left, right) in rate-limited steps.

        Publishes at ``args.publish_rate`` Hz until every joint reaches its target, ROS
        shuts down, or a stop is requested through ``puppet_arm_publish_lock``.
        """
        rate = rospy.Rate(self.args.publish_rate)
        left_arm, right_arm = self._wait_for_arm_state(rate)
        if left_arm is None or right_arm is None:
            # ROS shut down before any arm state arrived: nothing to move from.
            return

        # Direction of travel is fixed from the starting pose. A joint already at its
        # target gets -1, which is harmless because it snaps to the target.
        left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
        right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]

        moving = True
        step = 0
        while moving and not rospy.is_shutdown():
            # The lock is only acquirable after a stop request released it.
            if self.puppet_arm_publish_lock.acquire(False):
                return
            left_moving = self._step_towards(left_arm, left, left_symbol)
            right_moving = self._step_towards(right_arm, right, right_symbol)
            moving = left_moving or right_moving

            msg = self._make_joint_state(left_arm)
            self.puppet_arm_left_publisher.publish(msg)
            msg.position = right_arm
            self.puppet_arm_right_publisher.publish(msg)
            step += 1
            print("puppet_arm_publish_continuous:", step)
            rate.sleep()

    def puppet_arm_publish_linear(self, left, right):
        """Move both arms to (left, right) along a 100-point linear trajectory at 200 Hz.

        The last joint of each arm (presumably the gripper) is set straight to its target
        at every point rather than interpolated.
        """
        num_step = 100
        rate = rospy.Rate(200)
        left_arm, right_arm = self._wait_for_arm_state(rate)
        if left_arm is None or right_arm is None:
            # ROS shut down before any arm state arrived: nothing to move from.
            return

        traj_left_list = np.linspace(left_arm, left, num_step)
        traj_right_list = np.linspace(right_arm, right, num_step)
        for traj_left, traj_right in zip(traj_left_list, traj_right_list):
            traj_left[-1] = left[-1]
            traj_right[-1] = right[-1]
            msg = self._make_joint_state(traj_left)
            self.puppet_arm_left_publisher.publish(msg)
            msg.position = traj_right
            self.puppet_arm_right_publisher.publish(msg)
            rate.sleep()

    def puppet_arm_publish_continuous_thread(self, left, right):
        """Stop any running background publisher, then start a new one toward (left, right)."""
        if self.puppet_arm_publish_thread is not None:
            # Releasing the lock asks the running publisher to stop. After it exits,
            # re-acquire without blocking so the lock is held again (it may already be
            # held if the worker acquired it on its way out).
            self.puppet_arm_publish_lock.release()
            self.puppet_arm_publish_thread.join()
            self.puppet_arm_publish_lock.acquire(False)
            self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_thread = threading.Thread(
            target=self.puppet_arm_publish_continuous, args=(left, right))
        self.puppet_arm_publish_thread.start()

    @staticmethod
    def _pop_synced(buffer, frame_time):
        """Discard messages stamped before ``frame_time``, then pop and return the next one."""
        while buffer[0].header.stamp.to_sec() < frame_time:
            buffer.popleft()
        return buffer.popleft()

    def get_frame(self):
        """Pop one time-synchronized set of messages from the buffers.

        The sync point ``frame_time`` is the oldest of the newest stamps across the
        enabled image streams (color, plus depth when ``args.use_depth_image``). Every
        enabled buffer must hold a message stamped at or after ``frame_time``: the image
        streams, both puppet arms, and base odometry when ``args.use_robot_base``.
        Older messages are discarded.

        Returns:
            False if any required buffer is empty or lags behind ``frame_time``. Otherwise
            the tuple (img_front, img_left, img_right, img_front_depth, img_left_depth,
            img_right_depth, puppet_arm_left, puppet_arm_right, robot_base). Images are
            converted with CvBridge ("passthrough"); depth images and robot_base are None
            when disabled.
        """
        use_depth = self.args.use_depth_image
        use_base = self.args.use_robot_base

        image_buffers = [self.img_left_deque, self.img_right_deque, self.img_front_deque]
        if use_depth:
            image_buffers += [self.img_left_depth_deque, self.img_right_depth_deque, self.img_front_depth_deque]
        if any(len(buf) == 0 for buf in image_buffers):
            return False

        frame_time = min(buf[-1].header.stamp.to_sec() for buf in image_buffers)

        required = image_buffers + [self.puppet_arm_left_deque, self.puppet_arm_right_deque]
        if use_base:
            required.append(self.robot_base_deque)
        if any(len(buf) == 0 or buf[-1].header.stamp.to_sec() < frame_time for buf in required):
            return False

        def pop_image(buffer):
            return self.bridge.imgmsg_to_cv2(self._pop_synced(buffer, frame_time), "passthrough")

        img_left = pop_image(self.img_left_deque)
        img_right = pop_image(self.img_right_deque)
        img_front = pop_image(self.img_front_deque)
        puppet_arm_left = self._pop_synced(self.puppet_arm_left_deque, frame_time)
        puppet_arm_right = self._pop_synced(self.puppet_arm_right_deque, frame_time)

        img_left_depth = img_right_depth = img_front_depth = None
        if use_depth:
            img_left_depth = pop_image(self.img_left_depth_deque)
            img_right_depth = pop_image(self.img_right_depth_deque)
            img_front_depth = pop_image(self.img_front_depth_deque)

        robot_base = self._pop_synced(self.robot_base_deque, frame_time) if use_base else None

        return (
            img_front,
            img_left,
            img_right,
            img_front_depth,
            img_left_depth,
            img_right_depth,
            puppet_arm_left,
            puppet_arm_right,
            robot_base,
        )

    # Subscriber callbacks. Each buffer is a bounded deque (see init), so append
    # evicts the oldest message once BUFFER_SIZE is reached.
    def img_left_callback(self, msg):
        self.img_left_deque.append(msg)

    def img_right_callback(self, msg):
        self.img_right_deque.append(msg)

    def img_front_callback(self, msg):
        self.img_front_deque.append(msg)

    def img_left_depth_callback(self, msg):
        self.img_left_depth_deque.append(msg)

    def img_right_depth_callback(self, msg):
        self.img_right_depth_deque.append(msg)

    def img_front_depth_callback(self, msg):
        self.img_front_depth_deque.append(msg)

    def puppet_arm_left_callback(self, msg):
        self.puppet_arm_left_deque.append(msg)

    def puppet_arm_right_callback(self, msg):
        self.puppet_arm_right_deque.append(msg)

    def robot_base_callback(self, msg):
        self.robot_base_deque.append(msg)

    def init_ros(self):
        """Initialize the ROS node, subscribe to all input topics and create the command publishers."""
        rospy.init_node("joint_state_publisher", anonymous=True)

        subscriptions = [
            (self.args.img_left_topic, Image, self.img_left_callback),
            (self.args.img_right_topic, Image, self.img_right_callback),
            (self.args.img_front_topic, Image, self.img_front_callback),
        ]
        if self.args.use_depth_image:
            subscriptions += [
                (self.args.img_left_depth_topic, Image, self.img_left_depth_callback),
                (self.args.img_right_depth_topic, Image, self.img_right_depth_callback),
                (self.args.img_front_depth_topic, Image, self.img_front_depth_callback),
            ]
        subscriptions += [
            (self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback),
            (self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback),
            (self.args.robot_base_topic, Odometry, self.robot_base_callback),
        ]
        for topic, msg_type, callback in subscriptions:
            rospy.Subscriber(topic, msg_type, callback, queue_size=1000, tcp_nodelay=True)

        self.puppet_arm_left_publisher = rospy.Publisher(
            self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
        self.puppet_arm_right_publisher = rospy.Publisher(
            self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
        self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
|
|
|
|
|
|
def get_arguments():
    """Parse the command-line options of the RDT inference node.

    ``--pretrained_model_name_or_path`` and ``--lang_embeddings_path`` are required.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_publish_step", type=int, default=10000,
                        help="Maximum number of action publishing steps")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")

    # ROS topic names; each option's help text is its own name.
    # (--robot_base_cmd_topic previously had the copy-pasted help "robot_base_topic".)
    topic_defaults = (
        ("img_front_topic", "/camera_f/color/image_raw"),
        ("img_left_topic", "/camera_l/color/image_raw"),
        ("img_right_topic", "/camera_r/color/image_raw"),
        ("img_front_depth_topic", "/camera_f/depth/image_raw"),
        ("img_left_depth_topic", "/camera_l/depth/image_raw"),
        ("img_right_depth_topic", "/camera_r/depth/image_raw"),
        ("puppet_arm_left_cmd_topic", "/master/joint_left"),
        ("puppet_arm_right_cmd_topic", "/master/joint_right"),
        ("puppet_arm_left_topic", "/puppet/joint_left"),
        ("puppet_arm_right_topic", "/puppet/joint_right"),
        ("robot_base_topic", "/odom_raw"),
        ("robot_base_cmd_topic", "/cmd_vel"),
    )
    for name, default in topic_defaults:
        parser.add_argument(f"--{name}", type=str, default=default, help=name)

    parser.add_argument("--use_robot_base", action="store_true",
                        help="Whether to use the robot base to move around")
    parser.add_argument("--publish_rate", type=int, default=30,
                        help="The rate at which to publish the actions")
    parser.add_argument("--ctrl_freq", type=int, default=25,
                        help="The control frequency of the robot")
    parser.add_argument("--chunk_size", type=int, default=64, help="Action chunk size")
    # One limit per arm joint (7 values), applied to both arms. nargs=7 makes a CLI
    # override yield a list; without it a single float was stored, which crashes every
    # consumer that indexes or concatenates this value.
    parser.add_argument("--arm_steps_length", type=float, nargs=7,
                        default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2],
                        help="The maximum change allowed for each joint per timestep")
    parser.add_argument("--use_actions_interpolation", action="store_true",
                        help="Whether to interpolate the actions if the difference is too large")
    parser.add_argument("--use_depth_image", action="store_true",
                        help="Whether to use depth images")
    parser.add_argument("--disable_puppet_arm", action="store_true",
                        help="Whether to disable the puppet arm. This is useful for safely debugging")
    parser.add_argument("--config_path", type=str, default="configs/base.yaml",
                        help="Path to the config file")
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True,
                        help="Name or path to the pretrained model")
    parser.add_argument("--lang_embeddings_path", type=str, required=True,
                        help="Path to the pre-encoded language instruction embeddings")

    return parser.parse_args()
|
|
|
|
|
|
def main():
    """Script entry point.

    Parses the CLI, starts ROS I/O, optionally seeds the RNGs, then runs the control loop.
    """
    cli_args = get_arguments()
    operator = RosOperator(cli_args)
    seed = cli_args.seed
    if seed is not None:
        set_seed(seed)
    model_inference(cli_args, get_config(cli_args), operator)
|
|
|
|
|
|
# Run the node only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|