# (extraction-artifact metadata removed: "112 lines, 5.1 KiB, Python")
# import packages and modules here
import os
import sys

# The project server package lives outside the repo; it must be on sys.path
# BEFORE the request_tools / tools imports below.
sys.path.append("/home/chao.wu/SmolVLA_RoboTwin2_BPU/LeRobot_SmolVLA_Server_Fast_JPEG")

from request_tools import upload_policy, send_inference_request
from tools import show_data_summary

import numpy as np
import cv2
import torch
import torch.nn.functional as F

from tqdm import tqdm

# current_file_path = os.path.abspath(__file__)
# parent_directory = os.path.dirname(current_file_path)


class SmolVLA_Client:
    """Thin HTTP client for a remote SmolVLA policy server.

    Sends observation dicts to ``{server_url}/infer`` and returns the
    predicted action chunk truncated to ``chunk_slice`` steps.
    """

    def __init__(self, server_url, server_device, chunk_slice=32, instruction=None):
        """
        Args:
            server_url: Base URL of the inference server; a trailing '/' is stripped.
            server_device: Server-side device string. Currently unused here
                (the upload_policy call is disabled) but kept for interface
                compatibility.
            chunk_slice: Number of action steps kept from each predicted chunk.
            instruction: Optional fixed language instruction injected into
                every observation sent to the server.
        """
        self.url = server_url.rstrip('/')
        self.chunk_slice = chunk_slice
        self.instruction = instruction
        # NOTE(review): policy upload is disabled; the server is assumed to be
        # pre-loaded with the model.
        # upload_policy(model_path, f"{self.url}/update_policy", server_device)

    def get_action(self, obs):
        """Query the server for an action chunk.

        Sample ``obs`` contents (from an observed trace):
            task: str instruction
            observation.images.cam_high / cam_left_wrist / cam_right_wrist:
                images as produced by ``preprocess`` (JPEG-encoded)
            observation.state: (1, 14) float32 state

        Returns:
            ``action_chunk[0, :chunk_slice, :]`` from the server response.
        """
        if self.instruction is not None:
            obs["instruction"] = self.instruction  # str
        # Bug fix: the original unconditionally read obs['instruction'] and
        # raised KeyError when no instruction was configured and the caller
        # supplied none; mirror it into 'task' only when present.
        if "instruction" in obs:
            obs["task"] = obs["instruction"]
        show_data_summary(obs)
        result = send_inference_request(obs, url=f"{self.url}/infer")
        return result['action_chunk'][0, :self.chunk_slice, :]

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Resize *img* to 640x480 (nearest neighbor) and JPEG-encode it.

        Returns:
            A 1-D uint8 ndarray holding the JPEG byte stream (cv2.imencode
            output). The original ``-> bytes`` annotation was wrong:
            cv2.imencode returns an ndarray, not bytes.

        Raises:
            ValueError: if JPEG encoding fails (imencode returned False).
        """
        img_resized = cv2.resize(img, (640, 480), interpolation=cv2.INTER_NEAREST)
        ok, buf = cv2.imencode('.jpg', img_resized, [cv2.IMWRITE_JPEG_QUALITY, 95])
        if not ok:
            raise ValueError("JPEG encoding failed")
        return buf
def get_model(usr_args):
    """Build a SmolVLA_Client from the user-argument dict.

    The language instruction is derived from ``task_name`` by replacing
    underscores with spaces; the client connects to the configured server.
    """
    print(f"[Cauchy] {usr_args = }")
    task_instruction = usr_args["task_name"].replace("_", " ")
    return SmolVLA_Client(
        usr_args["server_url"],
        usr_args["server_device"],
        instruction=task_instruction,
    )
def eval(TASK_ENV, model, observation):
    """Run one inference/execution cycle against the environment.

    Builds an observation dict from the simulator output, queries the remote
    policy for an action chunk, executes each action step, then refreshes the
    environment observation once.

    NOTE: the name shadows the builtin ``eval``; it is kept because the
    external evaluation harness looks it up by this name.

    Args:
        TASK_ENV: Environment exposing get_instruction(), take_action(a),
            and get_obs().
        model: SmolVLA_Client-like object with preprocess() and get_action().
        observation: Raw simulator observation. Per the original notes:
            camera frames are uint8 HxWx3 RGB arrays under
            observation["observation"][...]["rgb"], and the 14-dim joint
            state is under observation["joint_action"]["vector"].
    """
    obs = {}
    obs["instruction"] = TASK_ENV.get_instruction()  # str
    obs["task"] = obs["instruction"]
    obs["observation.images.cam_high"] = model.preprocess(observation["observation"]["head_camera"]["rgb"])
    obs["observation.images.cam_left_wrist"] = model.preprocess(observation["observation"]["left_camera"]["rgb"])
    obs["observation.images.cam_right_wrist"] = model.preprocess(observation["observation"]["right_camera"]["rgb"])
    # (1, 14) float32 state; the torch round-trip is kept to preserve the
    # exact float64 -> float32 cast the server was trained against.
    obs["observation.state"] = torch.from_numpy(observation["joint_action"]["vector"]).unsqueeze(0).float().numpy()

    actions = model.get_action(obs)

    for action in actions:  # Execute each step of the action chunk
        TASK_ENV.take_action(action)
    TASK_ENV.get_obs()
    return
    # Bug fix: the original had ~80 lines of unreachable code after this
    # return, referencing undefined names (encode_obs, model.observation_window,
    # model.rdt_step). That dead tail has been removed.
def reset_model(model):
    """Reset per-episode model state (currently a no-op).

    SmolVLA_Client keeps no observation window client-side, so there is
    nothing to clear. The original contained an unreachable
    ``model.reset_obsrvationwindows()`` call (note the typo) after the
    return; that dead line has been removed.
    """
    return