Cauchy Update
This commit is contained in:
commit
3eb02c372a
0
.gitignore
vendored
Normal file
0
.gitignore
vendored
Normal file
6
README.md
Normal file
6
README.md
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
# SmolVLA Tools
|
||||||
|
|
||||||
|
[飞书文档 Link](https://horizonrobotics.feishu.cn/docx/QnODdFmqboTg5PxyFXWc5lnMnMb)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
168
RoboTwin_Policy/BPU_LeRobot_SmolVLA_Cloud4VLA_BPU4A.py
Normal file
168
RoboTwin_Policy/BPU_LeRobot_SmolVLA_Cloud4VLA_BPU4A.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
# Copyright (c) 2025, Cauchy WuChao, D-Robotics.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
|
||||||
|
# os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
import tarfile
|
||||||
|
from tools import measure_time, show_data_summary
|
||||||
|
from binary_protocol import dict_to_binary, binary_to_dict
|
||||||
|
from request_tools import send_inference_request
|
||||||
|
|
||||||
|
from pyCauchyKesai import CauchyKesai
|
||||||
|
from flask import Flask, Response, request
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
# lerobot.constants
|
||||||
|
ACTION = "action"
|
||||||
|
OBS_STATE = "observation.state"
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level = logging.INFO,
|
||||||
|
format = '[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
|
||||||
|
datefmt='%H:%M:%S')
|
||||||
|
logger = logging.getLogger("SmolVLA_Server")
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
global smolvla_policy_a, normalize_params, device, action_feature_dim, vl_server_url
|
||||||
|
smolvla_policy_a = None
|
||||||
|
normalize_params = None
|
||||||
|
device = None
|
||||||
|
action_feature_dim = None
|
||||||
|
vl_server_url = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/infer', methods=['POST'])
def infer():
    """Action-expert inference endpoint (BPU side).

    Receives a pickled observation dict, forwards it to the cloud VL server
    to obtain the prefix KV cache and initial noise, runs the flow-matching
    Euler integration on the BPU action expert, unnormalizes the result and
    returns the action chunk as a pickled binary blob.
    """
    global smolvla_policy_a, normalize_params, device, action_feature_dim, vl_server_url

    def unnormalize_outputs(batch: dict[str, Tensor]) -> dict[str, Tensor]:
        # Invert output normalization: x * std + mean.
        batch[ACTION] = batch[ACTION] * normalize_params["unnormalize_outputs.std"] + normalize_params["unnormalize_outputs.mean"]
        return batch

    # global check
    if smolvla_policy_a is None:
        return Response(dict_to_binary({"status": "error", "message": "Service not ready"}),
                        status=503, mimetype='application/octet-stream')

    # binary_to_dict
    begin_time_part = time.time()
    raw_data = request.data
    logger.info(f"request.data time = {1000*(time.time() - begin_time_part):.2f} ms")
    begin_time_part = time.time()
    data = binary_to_dict(raw_data)
    # The previous version branched on isinstance(..., np.ndarray) but both
    # branches were identical; a shallow copy preserves the exact behavior.
    obs = dict(data)
    logger.info(f"binary_to_dict time = {1000*(time.time() - begin_time_part):.2f} ms")

    begin_time = time.time()

    # Cloud inference: obtain the prefix KV caches from the VL server.
    begin_time_part = time.time()
    diffusion = send_inference_request(obs, url=f"{vl_server_url}", timeout=10, max_retries=3)
    logger.info(f"Cloud4VL inference+post time = {1000*(time.time() - begin_time_part):.2f} ms")

    # Parse the cloud result and run BPU inference A (flow matching).
    begin_time_part = time.time()
    num_steps = diffusion['num_steps']
    # FIX: np.bool was deprecated in NumPy 1.20 and removed in 1.24; use the
    # builtin bool as the dtype (identical result on all NumPy versions).
    prefix_pad_masks = diffusion['prefix_pad_masks'].astype(bool)
    x_t = diffusion['x_t'].astype(np.float32)
    past_key_values = []
    # 16 transformer layers, key+value per layer — TODO confirm 16 matches
    # the exported model's layer count (it must agree with the VL server).
    for i in range(16):
        past_key_values.append(diffusion[f'k_{i}'].astype(np.float32))
        past_key_values.append(diffusion[f'v_{i}'].astype(np.float32))

    # Flow matching: integrate from t=1 down to t=0 with fixed Euler steps;
    # the -dt/2 threshold makes the loop run exactly num_steps iterations
    # despite float accumulation error.
    dt = -1.0 / num_steps
    time_fm = np.array([1.0], dtype=np.float32)
    while time_fm >= -dt / 2:
        v_t = smolvla_policy_a([prefix_pad_masks, x_t, time_fm, *past_key_values])[0]
        # Euler step
        x_t += dt * v_t
        time_fm += dt
    logger.info(f"BPU4A 10 flow matching time = {1000*(time.time() - begin_time_part):.2f} ms")

    # Post Process: keep only the real action dims, then unnormalize.
    begin_time_part = time.time()
    actions = torch.from_numpy(x_t)
    actions = actions[:, :, :action_feature_dim]
    actions = unnormalize_outputs({ACTION: actions})[ACTION]
    action_chunk = actions.numpy()
    logger.info(f"Post Process time = {1000*(time.time() - begin_time_part):.2f} ms")

    logger.info("\033[1;31m" + f"Cloud4VL, BPU4A, e2e inference time = {1000*(time.time() - begin_time):.2f} ms" + "\033[0m")

    # dict_to_binary
    response_blob = dict_to_binary({"status": "ok", 'action_chunk': action_chunk})

    return Response(response_blob, status=200, mimetype='application/octet-stream')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse CLI options, load the BPU action-expert model plus the state
    normalization parameters, and serve the Flask app (blocking call)."""
    parser = argparse.ArgumentParser()
    # Default artifact directory for the quantized/exported SmolVLA action expert.
    root_path = "/root/ssd/SmolVLA_BPU/LeRobot_SmolVLA_Quantization_Small/LeRobot_SmolVLA_Quantization/ls_export_adjust_bottle_randomized_500_fps15_1instruction/board_outputs_all"
    parser.add_argument('--action-hbm-model', type=str, default=f"{root_path}/action_expert_featuremaps.hbm", help="")
    parser.add_argument('--state-normalize', type=str, default=f"{root_path}/state_normalize_unnormalize.pt", help="")
    parser.add_argument('--action_feature_dim', type=int, default=14, help="LeRobot So101: 6, 14")
    parser.add_argument('--vl-server-url', type=str, default="http://10.112.20.37:50002/infer_vl", help="")
    # parser.add_argument('--vl-server-url', type=str, default="http://120.48.157.2:50002/infer_vl", help="")
    parser.add_argument('--device', type=str, default="cpu", help="")
    parser.add_argument('--port', type=int, default=60002, help="")
    opt = parser.parse_args()
    logger.info(opt)

    # Publish configuration to the module-level globals read by infer().
    global smolvla_policy_a, normalize_params, device, action_feature_dim, vl_server_url
    device = torch.device(opt.device)
    normalize_params = torch.load(opt.state_normalize)
    action_feature_dim = opt.action_feature_dim
    vl_server_url = opt.vl_server_url

    logger.info("Loading model ...")
    smolvla_policy_a = CauchyKesai(opt.action_hbm_model)
    # NOTE(review): .s() is an opaque pyCauchyKesai call — presumably prints a
    # model summary or warms up the BPU; confirm against the library docs.
    smolvla_policy_a.s()

    # threaded=False: infer() mutates shared BPU state, so requests must be serial.
    app.run(host='0.0.0.0', port=opt.port, threaded=False, debug=False)


if __name__ == "__main__":
    main()
|
||||||
199
RoboTwin_Policy/GPU_LeRobot_SmolVLA_Cloud4VLA.py
Normal file
199
RoboTwin_Policy/GPU_LeRobot_SmolVLA_Cloud4VLA.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
# Copyright (c) 2025, Cauchy WuChao, D-Robotics.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
|
||||||
|
os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import gzip
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
import tarfile
|
||||||
|
import shutil
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from tools import measure_time, show_data_summary
|
||||||
|
from binary_protocol import dict_to_binary, binary_to_dict, decode_images_jpeg
|
||||||
|
|
||||||
|
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||||
|
from flask import Flask, Response, request
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level = logging.INFO,
|
||||||
|
format = '[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
|
||||||
|
datefmt='%H:%M:%S')
|
||||||
|
logger = logging.getLogger("SmolVLA_Server")
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
global smolvla_policy, device
|
||||||
|
smolvla_policy = None
|
||||||
|
device = None
|
||||||
|
|
||||||
|
@app.route('/infer', methods=['POST'])
def infer():
    """Full-policy inference endpoint (GPU side).

    Receives a pickled observation dict with JPEG-compressed camera images,
    decodes the images, runs SmolVLA's predict_action_chunk on `device`,
    and returns the action chunk as a pickled binary blob.
    """
    # global check
    global smolvla_policy, device
    if smolvla_policy is None:
        return Response(dict_to_binary({"status": "error", "message": "Service not ready"}),
                        status=503, mimetype='application/octet-stream')

    # binary_to_dict
    begin_time_part = time.time()
    raw_data = request.data
    logger.info("\033[1;31m" + f"request.data time = {1000*(time.time() - begin_time_part):.2f} ms" + "\033[0m")

    begin_time_part = time.time()
    data = binary_to_dict(raw_data)
    logger.info("\033[1;31m" + f"binary_to_dict time = {1000*(time.time() - begin_time_part):.2f} ms" + "\033[0m")

    begin_time_part = time.time()
    # Use the shared helper (already imported from binary_protocol) instead of
    # a hand-rolled per-key cv2.imdecode loop: identical (H,W,3)->(1,3,H,W)
    # float32 [0,1] output, but it does not KeyError when a camera is absent.
    data = decode_images_jpeg(data)
    logger.info("\033[1;31m" + f"jpeg_decode time = {1000*(time.time() - begin_time_part):.2f} ms" + "\033[0m")

    begin_time_part = time.time()
    obs = {}
    for k in data.keys():
        if isinstance(data[k], np.ndarray):
            obs[k] = torch.from_numpy(data[k]).to(device)
        else:
            obs[k] = data[k]
    logger.info("\033[1;31m" + f"obs,np2tensor time = {1000*(time.time() - begin_time_part):.2f} ms" + "\033[0m")

    # inference
    begin_time = time.time()
    with torch.no_grad():
        action_chunk = smolvla_policy.predict_action_chunk(obs).detach().cpu().numpy()
    logger.info("\033[1;31m" + f"{device} inference time = {1000*(time.time() - begin_time):.2f} ms" + "\033[0m")

    # dict_to_binary
    begin_time = time.time()
    response_blob = dict_to_binary({"status": "ok", 'action_chunk': action_chunk})
    logger.info("\033[1;31m" + f"dict_to_binary time = {1000*(time.time() - begin_time):.2f} ms" + "\033[0m")

    return Response(response_blob, status=200, mimetype='application/octet-stream')
|
@app.route('/update_policy', methods=['POST'])
def update_policy():
    """Hot-swap the global SmolVLA policy from an uploaded .tar archive.

    Form fields: 'device' (torch device string) and 'file' (an uncompressed
    .tar containing config.json + model.safetensors, optionally one level
    deep). On success the global policy and device are replaced atomically
    at the end; on any failure the previous policy/device stay in effect.
    """
    global smolvla_policy, device

    # Validate the requested device BEFORE touching global state, so a bad
    # request cannot leave `device` changed while the old policy is loaded.
    device_str = request.form.get('device', '').strip()
    if not device_str:
        return {"error": "Missing 'device' form field"}, 400
    try:
        new_device = torch.device(device_str)
    except RuntimeError as e:
        return {"error": f"Invalid device '{device_str}': {e}"}, 400

    # update model
    if 'file' not in request.files:
        return {"error": "No file part"}, 400
    file = request.files['file']
    if file.filename == '':
        return {"error": "No selected file"}, 400
    if not file.filename.endswith('.tar'):
        return {"error": "Only .tar (uncompressed archive) is allowed"}, 400

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            logger.info(f"Extracting uploaded policy to temporary directory: {temp_dir}")

            # Save then extract the uploaded .tar.
            tar_path = os.path.join(temp_dir, 'uploaded.tar')
            file.save(tar_path)

            with tarfile.open(tar_path, 'r') as tar:
                # Path-traversal guard (CVE-2007-4559 class). realpath +
                # commonpath is robust against '..' members, symlinks, and
                # sibling-prefix tricks (e.g. /tmp/abc vs /tmp/abcevil) that
                # defeat the naive os.path.commonprefix string check.
                def _is_within_directory(directory, target):
                    abs_directory = os.path.realpath(directory)
                    abs_target = os.path.realpath(target)
                    return os.path.commonpath([abs_directory, abs_target]) == abs_directory

                for member in tar.getmembers():
                    member_path = os.path.join(temp_dir, member.name)
                    if not _is_within_directory(temp_dir, member_path):
                        raise Exception("Attempted Path Traversal in Tar File")
                tar.extractall(temp_dir)

            # Drop the archive, keep only the extracted contents.
            os.remove(tar_path)

            # Locate the actual model directory (archives often add one level).
            extracted_items = os.listdir(temp_dir)
            if len(extracted_items) == 1 and os.path.isdir(os.path.join(temp_dir, extracted_items[0])):
                model_dir = os.path.join(temp_dir, extracted_items[0])
            else:
                model_dir = temp_dir

            # Verify the required checkpoint files exist.
            required_files = {'config.json', 'model.safetensors'}
            if not required_files.issubset(set(os.listdir(model_dir))):
                return {"error": f"Missing required files: {required_files}"}, 400

            # Load the new policy; only commit globals once it succeeded.
            logger.info(f"Loading new policy from: {model_dir}")
            new_policy = SmolVLAPolicy.from_pretrained(model_dir).to(new_device).eval()

            device = new_device
            smolvla_policy = new_policy
            logger.info("Policy updated successfully!")

            return {"message": "Policy updated successfully"}, 200

    except Exception as e:
        logger.error(f"Failed to update policy: {e}", exc_info=True)
        return {"error": str(e)}, 500
|
||||||
|
|
||||||
|
|
||||||
|
# @measure_time(logger)
|
||||||
|
# def decode_request_data(raw_data, compressed: bool = False, device: torch.device = torch.device('cpu')):
|
||||||
|
# raw_data = gzip.decompress(raw_data) if compressed else raw_data
|
||||||
|
# return deserialize_dict(json.loads(raw_data.decode('utf-8')), device=device)
|
||||||
|
|
||||||
|
# @measure_time(logger)
|
||||||
|
# def encode_request_data(data: dict, compressed: bool = False):
|
||||||
|
# resp_bytes = json.dumps(serialize_dict(data)).encode('utf-8')
|
||||||
|
# resp_bytes = gzip.compress(resp_bytes) if compressed else resp_bytes
|
||||||
|
# headers = {'Content-Type': 'application/json'}
|
||||||
|
# headers['compressed'] = 'gzip' if compressed else 'raw'
|
||||||
|
# return app.response_class(response=resp_bytes, status=200, headers=headers)
|
||||||
|
|
||||||
|
def main():
    """Parse CLI options, load the SmolVLA policy, and serve the Flask app
    (blocking call)."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--smolvla-path', type=str,
        default="/home/chao.wu/SmolVLA_RoboTwin2_BPU/train_result_rtw15fps_1instruction/adjust_bottle_clean_50_fps15_1instruction/checkpoints/040000/pretrained_model",
        help="")
    arg_parser.add_argument('--device', type=str, default="cuda:0", help="")
    arg_parser.add_argument('--port', type=int, default=60002, help="")
    options = arg_parser.parse_args()
    logger.info(options)

    # Publish the loaded policy and device to the globals read by the routes.
    global smolvla_policy, device
    device = torch.device(options.device)
    smolvla_policy = SmolVLAPolicy.from_pretrained(options.smolvla_path).to(device).eval()

    # Bound to loopback: only clients on this host can reach the server.
    app.run(host='127.0.0.1', port=options.port, threaded=False, debug=False)


if __name__ == "__main__":
    main()
|
||||||
230
RoboTwin_Policy/GPU_LeRobot_SmolVLA_Cloud4VL_BPU4A.py
Normal file
230
RoboTwin_Policy/GPU_LeRobot_SmolVLA_Cloud4VL_BPU4A.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
# Copyright (c) 2025, Cauchy WuChao, D-Robotics.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
|
||||||
|
os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
import tarfile
|
||||||
|
from tools import measure_time, show_data_summary
|
||||||
|
from binary_protocol import dict_to_binary, binary_to_dict, decode_images_jpeg
|
||||||
|
|
||||||
|
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
|
||||||
|
from flask import Flask, Response, request
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level = logging.INFO,
|
||||||
|
format = '[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
|
||||||
|
datefmt='%H:%M:%S')
|
||||||
|
logger = logging.getLogger("SmolVLA_Server")
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
global smolvla_policy, device
|
||||||
|
smolvla_policy = None
|
||||||
|
device = None
|
||||||
|
|
||||||
|
@app.route('/infer_vl', methods=['POST'])
def infer_vl():
    """VL-prefix endpoint: run only the vision-language half of SmolVLA and
    return the prefix KV cache plus initial noise, so a remote action expert
    (e.g. the BPU server) can run the flow-matching loop itself."""
    # global check
    global smolvla_policy, device
    # Snapshot the global so a concurrent /update_policy swap cannot change
    # the policy object mid-request.
    policy = smolvla_policy
    if policy is None:
        return Response(dict_to_binary({"status": "error", "message": "Service not ready"}),
            status=503, mimetype='application/octet-stream')
    # try:
    if True:
        # binary_to_dict
        begin_time = time.time()
        begin_time_part = time.time()
        raw_data = request.data
        logger.info(f"request.data time = {1000*(time.time() - begin_time_part):.2f} ms")
        begin_time_part = time.time()
        data = binary_to_dict(raw_data)
        logger.info(f"binary_to_dict time = {1000*(time.time() - begin_time_part):.2f} ms")
        begin_time_part = time.time()
        # Decode the three fixed camera streams in place. NOTE(review): this
        # raises KeyError if any camera key is missing from the payload.
        for k in ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']:
            data[k] = cv2.imdecode(data[k], cv2.IMREAD_COLOR).transpose(2, 0, 1)[np.newaxis].astype(np.float32) / 255.0 # (H, W, 3) -> (1, 3, H, W), uint8 -> float32 [0,1]
        logger.info(f"jpeg_decode time = {1000*(time.time() - begin_time_part):.2f} ms")
        # show_data_summary(data)
        begin_time_part = time.time()
        # Move arrays/tensors onto the inference device; pass everything else
        # (e.g. the language instruction string) through untouched.
        obs = {}
        for k in data.keys():
            if isinstance(data[k], np.ndarray):
                obs[k] = torch.from_numpy(data[k]).to(device)
            elif isinstance(data[k], torch.Tensor):
                obs[k] = data[k].to(device)
            else:
                obs[k] = data[k]
        logger.info(f"obs,np2tensor time = {1000*(time.time() - begin_time_part):.2f} ms")
        # show_data_summary(obs)
        # VL inference
        with torch.no_grad():
            begin_time = time.time()
            # only state need normalize
            batch = policy.normalize_inputs(obs)
            # prepare inputs
            images, img_masks = policy.prepare_images(batch)
            state = policy.prepare_state(batch)
            lang_tokens, lang_masks = policy.prepare_language(batch)
            # Single-observation requests only — TODO confirm callers never batch.
            bsize = 1
            # noise: the flow-matching starting point x_t at t=1, sampled
            # here so the remote action expert is fully deterministic.
            actions_shape = (bsize, policy.model.config.chunk_size, policy.model.config.max_action_dim)
            noise = torch.normal(
                mean=0.0,
                std=1.0,
                size=actions_shape,
                dtype=torch.float32,
                device=device,
            )
            # prepare tokens: embed images + language + state into the prefix.
            prefix_embs, prefix_pad_masks, prefix_att_masks = policy.model.embed_prefix(
                images, img_masks, lang_tokens, lang_masks, state=state
            )
            prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
            prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

            # get kv cache: one prefix-only forward pass with fill_kv_cache=True.
            _, past_key_values = policy.model.vlm_with_expert.forward(
                attention_mask=prefix_att_2d_masks,
                position_ids=prefix_position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, None],
                use_cache=policy.model.config.use_cache,
                fill_kv_cache=True,
            )
        logger.info("\033[1;31m" + f"GPU VL inference time = {1000*(time.time() - begin_time):.2f} ms" + "\033[0m")

        begin_time = time.time()
        # prepare diffusion output: flatten KV cache per layer into k_i/v_i
        # float32 numpy arrays so the client can reconstruct them by index.
        diffusion = {'status': 'ok'}
        diffusion['num_steps'] = policy.model.config.num_steps
        diffusion['prefix_pad_masks'] = prefix_pad_masks.detach().cpu().numpy()
        diffusion['x_t'] = noise.detach().cpu().numpy()
        for i in range(len(past_key_values)):
            diffusion[f'k_{i}'] = past_key_values[i]['key_states'].detach().to(torch.float32).cpu().numpy()
            diffusion[f'v_{i}'] = past_key_values[i]['value_states'].detach().to(torch.float32).cpu().numpy()

        # show_data_summary(diffusion)
        # dict_to_binary
        response_blob = dict_to_binary(diffusion)
        logger.info(f"dict_to_binary time = {1000*(time.time() - begin_time):.2f} ms")

        return Response(response_blob, status=200, mimetype='application/octet-stream')

    # except Exception as e:
    #     logger.error(e)
    #     return Response(dict_to_binary({"status": "error", "message": str(e)}),
    #         status=500, mimetype='application/octet-stream')
|
||||||
|
|
||||||
|
@app.route('/update_policy', methods=['POST'])
def update_policy():
    """Hot-swap the global SmolVLA policy from an uploaded .tar archive.

    Form fields: 'device' (torch device string) and 'file' (an uncompressed
    .tar containing config.json + model.safetensors, optionally one level
    deep). On success the global policy and device are replaced atomically
    at the end; on any failure the previous policy/device stay in effect.
    """
    global smolvla_policy, device

    # Validate the requested device BEFORE touching global state, so a bad
    # request cannot leave `device` changed while the old policy is loaded.
    device_str = request.form.get('device', '').strip()
    if not device_str:
        return {"error": "Missing 'device' form field"}, 400
    try:
        new_device = torch.device(device_str)
    except RuntimeError as e:
        return {"error": f"Invalid device '{device_str}': {e}"}, 400

    # update model
    if 'file' not in request.files:
        return {"error": "No file part"}, 400
    file = request.files['file']
    if file.filename == '':
        return {"error": "No selected file"}, 400
    if not file.filename.endswith('.tar'):
        return {"error": "Only .tar (uncompressed archive) is allowed"}, 400

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            logger.info(f"Extracting uploaded policy to temporary directory: {temp_dir}")

            # Save then extract the uploaded .tar.
            tar_path = os.path.join(temp_dir, 'uploaded.tar')
            file.save(tar_path)

            with tarfile.open(tar_path, 'r') as tar:
                # Path-traversal guard (CVE-2007-4559 class). realpath +
                # commonpath is robust against '..' members, symlinks, and
                # sibling-prefix tricks (e.g. /tmp/abc vs /tmp/abcevil) that
                # defeat the naive os.path.commonprefix string check.
                def _is_within_directory(directory, target):
                    abs_directory = os.path.realpath(directory)
                    abs_target = os.path.realpath(target)
                    return os.path.commonpath([abs_directory, abs_target]) == abs_directory

                for member in tar.getmembers():
                    member_path = os.path.join(temp_dir, member.name)
                    if not _is_within_directory(temp_dir, member_path):
                        raise Exception("Attempted Path Traversal in Tar File")
                tar.extractall(temp_dir)

            # Drop the archive, keep only the extracted contents.
            os.remove(tar_path)

            # Locate the actual model directory (archives often add one level).
            extracted_items = os.listdir(temp_dir)
            if len(extracted_items) == 1 and os.path.isdir(os.path.join(temp_dir, extracted_items[0])):
                model_dir = os.path.join(temp_dir, extracted_items[0])
            else:
                model_dir = temp_dir

            # Verify the required checkpoint files exist.
            required_files = {'config.json', 'model.safetensors'}
            if not required_files.issubset(set(os.listdir(model_dir))):
                return {"error": f"Missing required files: {required_files}"}, 400

            # Load the new policy; only commit globals once it succeeded.
            logger.info(f"Loading new policy from: {model_dir}")
            new_policy = SmolVLAPolicy.from_pretrained(model_dir).to(new_device).eval()

            device = new_device
            smolvla_policy = new_policy
            logger.info("Policy updated successfully!")

            return {"message": "Policy updated successfully"}, 200

    except Exception as e:
        logger.error(f"Failed to update policy: {e}", exc_info=True)
        return {"error": str(e)}, 500
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse CLI options, load the SmolVLA policy, and serve the VL-prefix
    Flask app (blocking call)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--smolvla-path', type=str, default="/home/chao.wu/SmolVLA_RoboTwin2_BPU/train_result_rtw15fps_1instruction/adjust_bottle_randomized_500_fps15_1instruction/checkpoints/040000/pretrained_model", help="")
    parser.add_argument('--device', type=str, default="cuda:0", help="")
    parser.add_argument('--port', type=int, default=50002, help="")
    # FIX: the previous version had a dead assignment (host = '127.0.0.1'
    # immediately overwritten by host = '0.0.0.0'); expose the bind address
    # as an option with the previously-effective default.
    parser.add_argument('--host', type=str, default="0.0.0.0",
                        help="bind address; 0.0.0.0 exposes the service on all interfaces")
    opt = parser.parse_args()
    logger.info(opt)

    # Publish the loaded policy and device to the globals read by the routes.
    global smolvla_policy, device
    device = torch.device(opt.device)
    smolvla_policy = SmolVLAPolicy.from_pretrained(opt.smolvla_path).to(device).eval()

    app.run(host=opt.host, port=opt.port, threaded=False, debug=False)


if __name__ == "__main__":
    main()
|
||||||
1
RoboTwin_Policy/__init__.py
Normal file
1
RoboTwin_Policy/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .deploy_policy import *
|
||||||
139
RoboTwin_Policy/binary_protocol.py
Normal file
139
RoboTwin_Policy/binary_protocol.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import logging
|
||||||
|
from typing import Any, List
|
||||||
|
from tools import measure_time, show_data_summary
|
||||||
|
# Pickle protocol notes: protocol 5 (Python 3.8+) supports out-of-band
# buffers, which is very efficient for large numpy arrays; any recent
# protocol is still far faster than JSON for this payload shape.
# NOTE: HIGHEST_PROTOCOL below is pinned to 4 for compatibility with
# Python < 3.8 peers, so out-of-band transfer is not actually used here.
|
||||||
|
HIGHEST_PROTOCOL = 4
|
||||||
|
FIX_IMPORTS = False
|
||||||
|
|
||||||
|
# 图像 key 列表,用于 JPEG 压缩/解码匹配
|
||||||
|
IMAGE_KEYS: List[str] = [
|
||||||
|
'observation.images.cam_high',
|
||||||
|
'observation.images.cam_left_wrist',
|
||||||
|
'observation.images.cam_right_wrist',
|
||||||
|
]
|
||||||
|
|
||||||
|
JPEG_QUALITY = 95
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level = logging.INFO,
|
||||||
|
format = '[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
|
||||||
|
datefmt='%H:%M:%S')
|
||||||
|
logger = logging.getLogger("SmolVLA_Server")
|
||||||
|
|
||||||
|
|
||||||
|
def jpeg_encode(img_chw: np.ndarray) -> bytes:
    """Encode a (1, 3, H, W) image array to JPEG bytes.

    Input: shape (1, 3, H, W), dtype uint8 [0,255] or float32 [0,1].
    Output: JPEG-compressed bytes at module-level JPEG_QUALITY.

    Raises:
        RuntimeError: if cv2 fails to encode the image.
    """
    # (1, 3, H, W) -> (H, W, 3), then normalize dtype to uint8 if needed.
    hwc = img_chw[0].transpose(1, 2, 0)
    if hwc.dtype != np.uint8:
        hwc = (hwc * 255.0).clip(0, 255).astype(np.uint8)
    ok, encoded = cv2.imencode('.jpg', hwc, [cv2.IMWRITE_JPEG_QUALITY, JPEG_QUALITY])
    if not ok:
        raise RuntimeError("cv2.imencode JPEG failed")
    return encoded.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
def jpeg_decode(jpeg_bytes) -> np.ndarray:
    """Decode JPEG bytes (or a uint8 np.ndarray buffer) back into a
    (1, 3, H, W) float32 image array with values in [0, 1].

    Raises:
        RuntimeError: if the buffer is not a decodable image.
    """
    if isinstance(jpeg_bytes, np.ndarray):
        buf = jpeg_bytes
    else:
        buf = np.frombuffer(jpeg_bytes, dtype=np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)  # (H, W, 3) uint8 BGR
    # FIX: cv2.imdecode signals failure by returning None, not raising; fail
    # loudly here instead of crashing later with AttributeError on .transpose
    # (consistent with the explicit RuntimeError in jpeg_encode).
    if img is None:
        raise RuntimeError("cv2.imdecode JPEG failed")
    # (H, W, 3) -> (1, 3, H, W), uint8 -> float32 [0,1]
    return img.transpose(2, 0, 1)[np.newaxis].astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
|
||||||
|
def encode_images_jpeg(data: dict) -> dict:
    """Return a copy of *data* where every IMAGE_KEYS entry holding an
    np.ndarray is replaced by JPEG bytes (pre-send compression).
    Non-image fields are passed through unchanged.
    """
    return {
        key: jpeg_encode(value) if (key in IMAGE_KEYS and isinstance(value, np.ndarray)) else value
        for key, value in data.items()
    }
|
||||||
|
|
||||||
|
|
||||||
|
def decode_images_jpeg(data: dict) -> dict:
    """Return a copy of *data* where every IMAGE_KEYS entry holding JPEG
    bytes (or a uint8 np.ndarray buffer) is decoded back into an image array
    (post-receive decompression). Non-image fields are passed through
    unchanged.
    """
    return {
        key: jpeg_decode(value) if (key in IMAGE_KEYS and isinstance(value, (bytes, np.ndarray))) else value
        for key, value in data.items()
    }
|
||||||
|
|
||||||
|
|
||||||
|
# @measure_time(logger)
|
||||||
|
def dict_to_binary(data: Any) -> bytes:
    """Serialize a mixed object (numpy arrays, strings, nested dicts) into a
    raw binary stream via pickle. No compression — optimized for raw speed.
    """
    # Protocol and fix_imports are centralized in the module-level
    # HIGHEST_PROTOCOL / FIX_IMPORTS constants.
    blob = pickle.dumps(data, protocol=HIGHEST_PROTOCOL, fix_imports=FIX_IMPORTS)
    return blob
|
||||||
|
|
||||||
|
|
||||||
|
# @measure_time(logger)
def binary_to_dict(data: bytes) -> Any:
    """Restore the original dict object from a binary stream.

    SECURITY NOTE(review): pickle.loads executes arbitrary code when fed
    untrusted input — only use this protocol over a trusted link.
    """
    return pickle.loads(data, fix_imports=FIX_IMPORTS)
|
||||||
|
|
||||||
|
# 简单的自测
|
||||||
|
if __name__ == "__main__":
    import time

    def _rand_img():
        # Random NCHW float32 image in [0, 1], matching the camera feeds.
        return np.random.randint(0, 255, (1, 3, 480, 640), dtype=np.uint8).astype(np.float32) / 255.0

    # JPEG round-trip sanity check.
    print("=== JPEG encode/decode test ===")
    img_orig = _rand_img()
    jpeg_bytes = jpeg_encode(img_orig)
    img_restored = jpeg_decode(jpeg_bytes)
    print(f"Original shape: {img_orig.shape}, JPEG size: {len(jpeg_bytes)/1024:.2f} KB")
    print(f"Restored shape: {img_restored.shape}, max abs diff: {np.abs(img_orig - img_restored).max():.4f}")

    test_data = {
        "instruction": "pick up the cup",
        "state": np.array([0.1, 0.2, 0.3], dtype=np.float32),
        "observation.images.cam_high": _rand_img(),
        "observation.images.cam_left_wrist": _rand_img(),
        "observation.images.cam_right_wrist": _rand_img(),
    }

    # End-to-end serialization pipeline benchmark.
    print("\n=== Full pipeline test (encode_images_jpeg -> dict_to_binary -> binary_to_dict -> decode_images_jpeg) ===")
    start = time.time()
    encoded = encode_images_jpeg(test_data)
    blob = dict_to_binary(encoded)
    t1 = time.time() - start

    start = time.time()
    restored_encoded = binary_to_dict(blob)
    restored = decode_images_jpeg(restored_encoded)
    t2 = time.time() - start

    print(f"Serialize+JPEG time: {t1*1000:.2f} ms, Size: {len(blob)/1024:.2f} KB")
    print(f"Deserialize+JPEG time: {t2*1000:.2f} ms")
    for k in test_data.keys():
        if isinstance(test_data[k], np.ndarray):
            print(f"{k}: {test_data[k].shape} -> {restored[k].shape}")
|
||||||
111
RoboTwin_Policy/deploy_policy.py
Normal file
111
RoboTwin_Policy/deploy_policy.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
# import packages and module here
|
||||||
|
import sys, os
|
||||||
|
sys.path.append("/home/chao.wu/SmolVLA_RoboTwin2_BPU/LeRobot_SmolVLA_Server_Fast_JPEG")
|
||||||
|
|
||||||
|
from request_tools import upload_policy, send_inference_request
|
||||||
|
from tools import show_data_summary
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# current_file_path = os.path.abspath(__file__)
|
||||||
|
# parent_directory = os.path.dirname(current_file_path)
|
||||||
|
|
||||||
|
class SmolVLA_Client:
    """Thin HTTP client that forwards observations to a remote SmolVLA
    inference server and returns the predicted action chunk."""

    def __init__(self, server_url, server_device, chunk_slice=32, instruction=None):
        # server_device is kept for interface compatibility; the policy is
        # assumed to be already loaded on the server (see upload_policy).
        self.url = server_url.rstrip('/')
        self.chunk_slice = chunk_slice
        self.instruction = instruction
        # upload_policy(model_path, f"{self.url}/update_policy", server_device)

    def get_action(self, obs):
        '''
        Query the server with one observation dict and return the first
        `chunk_slice` steps of the predicted action chunk.

        Observed obs layout at runtime:
          task: str
          observation.images.cam_*: (1, 3, 480, 640) float32 [0, 1] or JPEG buffer
          observation.state: (1, 14) float32
        '''
        if self.instruction is not None:
            # Override the environment instruction with the fixed task prompt.
            obs["instruction"] = self.instruction  # str
            obs["task"] = obs['instruction']
        show_data_summary(obs)

        return send_inference_request(obs, url=f"{self.url}/infer")['action_chunk'][0, :self.chunk_slice, :]

    def preprocess(self, img: np.ndarray) -> np.ndarray:
        """Resize to 640x480 (nearest neighbor) and JPEG-encode.

        Returns the uint8 buffer produced by cv2.imencode (an ndarray; the
        server-side jpeg decoder accepts both ndarray buffers and bytes).
        The previous ``-> bytes`` annotation was wrong.
        """
        img_resized = cv2.resize(img, (640, 480), interpolation=cv2.INTER_NEAREST)
        success, buf = cv2.imencode('.jpg', img_resized, [cv2.IMWRITE_JPEG_QUALITY, 95])
        if not success:  # bug fix: the success flag was previously ignored
            raise RuntimeError("cv2.imencode JPEG failed")
        return buf
|
||||||
|
def get_model(usr_args):
    """Build a SmolVLA_Client from the user args; the task name (underscores
    replaced by spaces) becomes the language instruction."""
    print(f"[Cauchy] {usr_args = }")
    task_instruction = usr_args["task_name"].replace("_", " ")
    client = SmolVLA_Client(
        usr_args["server_url"],
        usr_args["server_device"],
        instruction=task_instruction,
    )
    return client
|
||||||
|
|
||||||
|
def eval(TASK_ENV, model, observation):
    """Run one policy step: build the observation dict, query the remote
    SmolVLA server, and execute the returned action chunk on TASK_ENV.

    Raw observation layout (per the original inline notes):
      observation["observation"][<cam>]["rgb"]: np.uint8, 0~255, (240, 320, 3)
      observation["joint_action"]["vector"]:    np.float64, (14,)
    """
    obs = {}
    obs["instruction"] = TASK_ENV.get_instruction()  # str
    obs["task"] = obs['instruction']
    # Resize + JPEG-compress each camera feed before sending over HTTP.
    obs['observation.images.cam_high'] = model.preprocess(observation["observation"]["head_camera"]["rgb"])
    obs['observation.images.cam_left_wrist'] = model.preprocess(observation["observation"]["left_camera"]["rgb"])
    obs['observation.images.cam_right_wrist'] = model.preprocess(observation["observation"]["right_camera"]["rgb"])
    obs['observation.state'] = torch.from_numpy(observation["joint_action"]["vector"]).unsqueeze(0).float().numpy()

    actions = model.get_action(obs)

    for action in actions:  # Execute each step of the action chunk
        TASK_ENV.take_action(action)
    TASK_ENV.get_obs()
    return
    # Fix: removed the unreachable legacy RDT-style code that used to follow
    # this return (it referenced an undefined `encode_obs` and an
    # observation-window API this client does not have).
||||||
|
def reset_model(model):
    """No-op: the remote client keeps no per-episode state to reset.

    Fix: removed an unreachable call to a misspelled
    ``model.reset_obsrvationwindows()`` that followed the return.
    """
    return
|
||||||
8
RoboTwin_Policy/deploy_policy.yml
Normal file
8
RoboTwin_Policy/deploy_policy.yml
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Basic experiment configuration
# null values are expected to be supplied at launch time
# (see eval_smolvla_client.sh, which passes them via --overrides).
policy_name: null
task_name: null
task_config: null
ckpt_setting: null
seed: null
instruction_type: seen
policy_conda_env: null
|
||||||
13
RoboTwin_Policy/eval_smolvla_client.sh
Normal file
13
RoboTwin_Policy/eval_smolvla_client.sh
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
|
||||||
|
cd ../.. # Optional, depending on the RoboTwin checkout layout — run from the repo root
export CUDA_VISIBLE_DEVICES=2

# Evaluate the RoboTwin_Policy SmolVLA client against a running inference
# server. stderr is filtered to drop the noisy "svulkan2 ... OIDN" warnings.
python3 script/eval_policy.py --config policy/SmolVLA_Client_BPUA/deploy_policy.yml \
    --overrides \
    --policy_name RoboTwin_Policy \
    --task_name adjust_bottle \
    --task_config "demo_randomized" \
    --seed 0 \
    --server_url "http://127.0.0.1:60002" \
    --server_device cuda:0 \
    2> >(grep -v "svulkan2.*OIDN" >&2)
|
||||||
121
RoboTwin_Policy/request_tools.py
Normal file
121
RoboTwin_Policy/request_tools.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
|
||||||
|
import json
|
||||||
|
import gzip
|
||||||
|
import torch
|
||||||
|
import requests
|
||||||
|
import tarfile
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
HIGHEST_PROTOCOL = 4
|
||||||
|
FIX_IMPORTS = False
|
||||||
|
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
import gzip
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from tools import measure_time, show_data_summary
|
||||||
|
from binary_protocol import dict_to_binary, binary_to_dict
|
||||||
|
|
||||||
|
def send_inference_request(
    data_dict: Dict[str, Any],
    url: str = 'http://127.0.0.1:50000/infer',
    timeout: int = 10,
    max_retries: int = 30,
    retry_delay: float = 2
) -> Dict[str, Any]:
    """Send an inference request (binary pickle protocol) with timeout and
    automatic retries. Compatible with Python 3.10/3.12 and NumPy 2.x.

    :param data_dict: input data dict (may contain numpy arrays)
    :param url: inference service address
    :param timeout: per-request HTTP timeout in seconds
    :param max_retries: maximum number of retries
    :param retry_delay: wait before each retry in seconds
    :return: decoded response dict (may contain numpy arrays)
    :raises RuntimeError: retries exhausted, or a non-200 response
    """
    # 1. Serialize the input payload.
    try:
        req_body = dict_to_binary(data_dict)
    except Exception as e:
        # Fix: chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to serialize request data: {e}") from e

    # 2. Binary-stream headers.
    headers = {
        'Content-Type': 'application/octet-stream'
    }

    # 3. POST with retries.
    last_exception: Optional[Exception] = None
    for attempt in range(max_retries + 1):
        try:
            resp = requests.post(url, data=req_body, headers=headers, timeout=timeout)
            if resp.status_code == 200:
                try:
                    return binary_to_dict(resp.content)
                except Exception as deserialize_err:
                    # NOTE: raised inside the outer try, so — as in the
                    # original — a corrupt response is caught below and retried.
                    raise RuntimeError(
                        f"Failed to deserialize response: {deserialize_err}"
                    ) from deserialize_err
            else:
                error_msg = f"HTTP {resp.status_code}"
                try:
                    err_data = binary_to_dict(resp.content)
                    if isinstance(err_data, dict) and "message" in err_data:
                        error_msg += f": {err_data['message']}"
                    else:
                        error_msg += f": {resp.text[:200]}"
                except Exception:  # fix: was a bare `except:`
                    error_msg += f": {resp.text[:200]}"
                print(f"[Attempt {attempt + 1}/{max_retries + 1}] Server error: {error_msg}")
                last_exception = RuntimeError(error_msg)

        except requests.exceptions.Timeout as e:
            last_exception = e
            print(f"[Attempt {attempt + 1}/{max_retries + 1}] Request timeout: {e}")

        except requests.exceptions.ConnectionError as e:
            last_exception = e
            print(f"[Attempt {attempt + 1}/{max_retries + 1}] Connection error: {e}")

        except requests.exceptions.RequestException as e:
            last_exception = e
            print(f"[Attempt {attempt + 1}/{max_retries + 1}] Network error: {e}")

        except Exception as e:
            last_exception = e
            print(f"[Attempt {attempt + 1}/{max_retries + 1}] Unexpected error: {e}")

        # Not the last attempt: wait before retrying.
        if attempt < max_retries:
            time.sleep(retry_delay)

    raise RuntimeError(f"Failed after {max_retries + 1} attempts. Last error: {last_exception}")
|
||||||
|
|
||||||
|
def upload_policy(policy_dir: str, server_url: str = "http://localhost:50001/update_policy", device: str = 'cpu'):
    """Pack *policy_dir* into a .tar archive and upload it to the server.

    Returns True iff the server answered HTTP 200.
    """
    import tempfile

    with tempfile.NamedTemporaryFile(suffix='.tar') as tmp_tar:
        # arcname='' puts the directory contents at the tar root,
        # avoiding an extra nesting level on the server side.
        with tarfile.open(tmp_tar.name, 'w') as archive:
            archive.add(policy_dir, arcname='')

        # Stream the archive to the update endpoint.
        with open(tmp_tar.name, 'rb') as fh:
            payload = {'file': ('policy.tar', fh, 'application/x-tar')}
            resp = requests.post(server_url, files=payload, data={'device': device})
        print(f"{resp = }")
        return resp.status_code == 200
|
||||||
43
RoboTwin_Policy/tools.py
Normal file
43
RoboTwin_Policy/tools.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from functools import wraps
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# def measure_time(func):
|
||||||
|
# @wraps(func)
|
||||||
|
# def wrapper(*args, **kwargs):
|
||||||
|
# begin_time = time.time()
|
||||||
|
# result = func(* args, **kwargs)
|
||||||
|
# elapsed_ms = 1000 * (time.time() - begin_time)
|
||||||
|
# logger.info("\033[1;31m" + f"{func.__name__}: {elapsed_ms:.2f} ms" + "\033[0m")
|
||||||
|
# return result
|
||||||
|
# return wrapper
|
||||||
|
|
||||||
|
# from functools import wraps
|
||||||
|
# import time
|
||||||
|
|
||||||
|
def measure_time(logger):
    """Return a decorator that logs the wrapped call's wall-clock time (ms)
    through the given *logger*."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            started_at = time.time()
            result = func(*args, **kwargs)
            elapsed_ms = (time.time() - started_at) * 1000
            # Red ANSI highlight so timings stand out in the log stream.
            logger.info("\033[1;31m" + f"{func.__name__}: {elapsed_ms:.2f} ms" + "\033[0m")
            return result
        return wrapper
    return decorator
|
||||||
|
|
||||||
|
def show_data_summary(data):
    """Print a one-line shape/dtype/range summary per entry of *data*
    (numpy arrays and torch tensors); other values are printed verbatim."""
    for key, value in data.items():
        if isinstance(value, np.ndarray):
            print(f"{key}: {value.shape} {value.dtype} {type(value)} {value.min():.4f}~{value.max():.4f}")
        elif isinstance(value, torch.Tensor):
            # Tensors additionally report which device they live on.
            print(f"{key}: {value.shape} {value.dtype} {type(value)} {value.min():.4f}~{value.max():.4f}, {value.device}")
        else:
            print(f"{key}: {value} {type(value)}")
|
||||||
|
|
||||||
168
dataset_tools/convert_robotwin2raw_to_robotwin_pi0_aloha_hdf5.py
Normal file
168
dataset_tools/convert_robotwin2raw_to_robotwin_pi0_aloha_hdf5.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
'''
|
||||||
|
source: https://github.com/RoboTwin-Platform/RoboTwin/blob/main/policy/pi0/scripts/process_data.py
|
||||||
|
function:
|
||||||
|
- 来源: RoboTwin2 collect data生成的数据
|
||||||
|
- 输出: HDF5格式的数据, 是一种中间的表达
|
||||||
|
description:
|
||||||
|
- 对图像的处理: 解码 + 重采样图像(缩放) + 编码
|
||||||
|
- 构建状态(observation)序列
|
||||||
|
- 指令json只保存instructions/episode*.json中seen的部分, 不保存unseen的部分
|
||||||
|
- 保存新格式 HDF5 + 指令 JSON
|
||||||
|
- 新增tqdm进度条
|
||||||
|
- 新增结束后显示源数据集和处理后数据集大小的打印
|
||||||
|
'''
|
||||||
|
import os
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import argparse
|
||||||
|
import yaml, json
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
def main():
    """CLI entry: convert a RoboTwin2 raw dataset into the intermediate HDF5
    layout, then print the on-disk sizes of the source and output sets.

    Fix: removed the unused local ``extra_cauchy = "procrssed_"``.
    """
    parser = argparse.ArgumentParser(description="Process some episodes.")
    parser.add_argument("--dataset", type=str, default="/home/chao.wu/SmolVLA_RoboTwin2_BPU/RoboTwin2_dataset_mini/stack_blocks_three/aloha-agilex_randomized_500", help="robotwin2 raw dataset path",)
    # Expected raw dataset layout:
    #   ./dataset/
    #   |-- _traj_data
    #   |-- data
    #   |-- instructions
    #   |-- scene_info.json
    #   |-- seed.txt
    #   `-- video
    parser.add_argument("--dist", type=str, default="/home/chao.wu/SmolVLA_RoboTwin2_BPU/RoboTwin2_dataset_mini/stack_blocks_three/procrssed_aloha-agilex_randomized_500", help="target dataset path",)
    opt = parser.parse_args()
    print(f'read data from path: {opt.dataset}.')
    # One episode<N>.hdf5 per entry of the data/ directory.
    expert_data_num = len(os.listdir(os.path.join(opt.dataset, 'data')))
    data_transform(opt.dataset, expert_data_num, opt.dist)
    os.system(f"du -sh {opt.dataset} {opt.dist}")
|
||||||
|
|
||||||
|
def load_hdf5(dataset_path):
    """Read one raw RoboTwin2 episode file.

    Returns (left_gripper, left_arm, right_gripper, right_arm, image_dict),
    where image_dict maps camera name -> stored rgb data. Exits the process
    if the file does not exist (original behavior).
    """
    if not os.path.isfile(dataset_path):
        print(f"Dataset does not exist at \n{dataset_path}\n")
        exit()

    with h5py.File(dataset_path, "r") as root:
        left_gripper = root["/joint_action/left_gripper"][()]
        left_arm = root["/joint_action/left_arm"][()]
        right_gripper = root["/joint_action/right_gripper"][()]
        right_arm = root["/joint_action/right_arm"][()]
        # One rgb stream per camera group under /observation.
        image_dict = {
            cam_name: root[f"/observation/{cam_name}/rgb"][()]
            for cam_name in root[f"/observation/"].keys()
        }

    return left_gripper, left_arm, right_gripper, right_arm, image_dict
|
||||||
|
|
||||||
|
|
||||||
|
def images_encoding(imgs):
    """JPEG-encode a list of images.

    Returns (encoded, max_len): the per-image JPEG byte strings and the
    longest encoding's length. Downstream stores these with an HDF5
    fixed-width ``S{max_len}`` dtype, which pads shorter entries itself.

    Fixes: the cv2.imencode success flag is now checked, and the old
    ``padded_data`` list — built but never returned — is no longer computed.
    """
    encoded = []
    max_len = 0
    for img in imgs:
        success, encoded_image = cv2.imencode(".jpg", img)
        if not success:
            raise RuntimeError("cv2.imencode JPEG failed")
        jpeg_data = encoded_image.tobytes()
        encoded.append(jpeg_data)
        max_len = max(max_len, len(jpeg_data))
    return encoded, max_len
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_config(task_name):
    """Load ./task_config/<task_name>.yml and return the parsed mapping."""
    config_path = f"./task_config/{task_name}.yml"
    with open(config_path, "r", encoding="utf-8") as fh:
        return yaml.load(fh.read(), Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
|
||||||
|
def data_transform(path, episode_num, save_path):
    """Convert `episode_num` raw RoboTwin2 episodes under *path* into the
    intermediate per-episode HDF5 layout under *save_path*.

    Per episode this writes:
      save_path/episode_<i>/instructions.json  (only the "seen" prompts)
      save_path/episode_<i>/episode_<i>.hdf5   (qpos/action + JPEG images)

    Returns the number of episodes processed.
    """
    begin = 0
    os.makedirs(save_path, exist_ok=True)
    # for i in range(episode_num):
    for i in tqdm(range(episode_num), desc="Process", ncols=100):
        # Only the "seen" instruction set is carried over; "unseen" is dropped.
        desc_type = "seen"
        instruction_data_path = os.path.join(path, "instructions", f"episode{i}.json")
        with open(instruction_data_path, "r") as f_instr:
            instruction_dict = json.load(f_instr)
        instructions = instruction_dict[desc_type]
        save_instructions_json = {"instructions": instructions}
        os.makedirs(os.path.join(save_path, f"episode_{i}"), exist_ok=True)
        with open(os.path.join(os.path.join(save_path, f"episode_{i}"), "instructions.json"), "w") as f:
            json.dump(save_instructions_json, f, indent=2)
        left_gripper_all, left_arm_all, right_gripper_all, right_arm_all, image_dict = (load_hdf5(
            os.path.join(path, "data", f"episode{i}.hdf5")))
        qpos = []
        actions = []
        cam_high = []
        cam_right_wrist = []
        cam_left_wrist = []
        left_arm_dim = []
        right_arm_dim = []

        last_state = None  # NOTE(review): assigned but never used below
        for j in range(0, left_gripper_all.shape[0]):
            left_gripper, left_arm, right_gripper, right_arm = (
                left_gripper_all[j],
                left_arm_all[j],
                right_gripper_all[j],
                right_arm_all[j],
            )

            # 14-dim state: left arm joints + left gripper + right arm joints + right gripper
            state = np.array(left_arm.tolist() + [left_gripper] + right_arm.tolist() + [right_gripper])  # joints angle

            state = state.astype(np.float32)

            # All timesteps except the last contribute an observation;
            # all except the first contribute an action — so observation t
            # aligns with action t+1 (the next state as target).
            if j != left_gripper_all.shape[0] - 1:
                qpos.append(state)

                camera_high_bits = image_dict["head_camera"][j]
                camera_high = cv2.imdecode(np.frombuffer(camera_high_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_high_resized = cv2.resize(camera_high, (640, 480))
                cam_high.append(camera_high_resized)

                camera_right_wrist_bits = image_dict["right_camera"][j]
                camera_right_wrist = cv2.imdecode(np.frombuffer(camera_right_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_right_wrist_resized = cv2.resize(camera_right_wrist, (640, 480))
                cam_right_wrist.append(camera_right_wrist_resized)

                camera_left_wrist_bits = image_dict["left_camera"][j]
                camera_left_wrist = cv2.imdecode(np.frombuffer(camera_left_wrist_bits, np.uint8), cv2.IMREAD_COLOR)
                camera_left_wrist_resized = cv2.resize(camera_left_wrist, (640, 480))
                cam_left_wrist.append(camera_left_wrist_resized)

            if j != 0:
                action = state
                actions.append(action)
                left_arm_dim.append(left_arm.shape[0])
                right_arm_dim.append(right_arm.shape[0])

        hdf5path = os.path.join(save_path, f"episode_{i}/episode_{i}.hdf5")

        with h5py.File(hdf5path, "w") as f:
            f.create_dataset("action", data=np.array(actions))
            obs = f.create_group("observations")
            obs.create_dataset("qpos", data=np.array(qpos))
            obs.create_dataset("left_arm_dim", data=np.array(left_arm_dim))
            obs.create_dataset("right_arm_dim", data=np.array(right_arm_dim))
            image = obs.create_group("images")
            # Images stored JPEG-compressed as fixed-width byte strings
            # (S{max_len}); h5py pads the shorter entries.
            cam_high_enc, len_high = images_encoding(cam_high)
            cam_right_wrist_enc, len_right = images_encoding(cam_right_wrist)
            cam_left_wrist_enc, len_left = images_encoding(cam_left_wrist)
            image.create_dataset("cam_high", data=cam_high_enc, dtype=f"S{len_high}")
            image.create_dataset("cam_right_wrist", data=cam_right_wrist_enc, dtype=f"S{len_right}")
            image.create_dataset("cam_left_wrist", data=cam_left_wrist_enc, dtype=f"S{len_left}")

        begin += 1
        # print(f"proccess {i} success!")

    return begin
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -0,0 +1,332 @@
|
|||||||
|
"""
|
||||||
|
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
|
||||||
|
|
||||||
|
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
|
||||||
|
|
||||||
|
source: https://github.com/RoboTwin-Platform/RoboTwin/blob/main/policy/pi0/examples/aloha_real/convert_aloha_data_to_lerobot_robotwin.py
|
||||||
|
function:
|
||||||
|
description:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
os.environ["XDG_CACHE_HOME"] = "/home/chao.wu/SmolVLA_RoboTwin2_BPU"
|
||||||
|
os.environ["SVT_LOG_LEVEL"] = "0"
|
||||||
|
import dataclasses
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset # LeRobot 0.3.3
|
||||||
|
# from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import json
|
||||||
|
import fnmatch
|
||||||
|
import argparse
|
||||||
|
from datasets import disable_progress_bars
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry: port a processed Aloha-style HDF5 dataset (output of the
    raw converter) into the LeRobot dataset v2.0 format."""
    parser = argparse.ArgumentParser(description="Process some episodes.")
    parser.add_argument("--dataset", type=str, default="/home/chao.wu/SmolVLA_RoboTwin2_BPU/RoboTwin2_dataset_mini/stack_blocks_three/procrssed_aloha-agilex_clean_50", help="robotwin2 raw dataset path",)
    '''
    ./dataset/
    |-- episode_0
    |   |-- episode_0.hdf5
    |   `-- instructions.json
    |-- ...
    `-- episode_9
    '''
    parser.add_argument("--repo-id", type=str, default="stack_blocks_three_aloha-agilex_clean_50", help="target dataset path",)
    parser.add_argument("--mode", type=str, default="video", help="image / video",)
    parser.add_argument("--instruction", type=str, default="", help="string like: stack_blocks_three.",)
    opt = parser.parse_args()
    disable_progress_bars()  # silence the datasets library's built-in progress bars
    # Underscored task names become natural-language prompts; an empty
    # --instruction means "sample one per episode from instructions.json".
    opt.instruction = opt.instruction.replace("_", " ")
    opt.instruction = None if opt.instruction=="" else opt.instruction
    port_aloha(raw_dir=opt.dataset, repo_id=opt.repo_id, mode=opt.mode, instruction=opt.instruction)
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    """Settings used when creating a LeRobotDataset (see create_empty_dataset)."""
    use_videos: bool = True          # store camera streams as encoded video
    tolerance_s: float = 0.0001      # timestamp tolerance in seconds
    image_writer_processes: int = 10 # parallel image-writer processes
    image_writer_threads: int = 10   # threads per image-writer process
    video_backend: str | None = None # None -> LeRobot's default backend
    fps: int = 15                    # frame rate recorded in the dataset
|
||||||
|
|
||||||
|
|
||||||
|
def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
) -> LeRobotDataset:
    """Create an empty LeRobotDataset with the dual-arm Aloha feature schema.

    :param repo_id: dataset repository id
    :param robot_type: robot type string recorded in the dataset metadata
    :param mode: store camera streams as "video" or "image"
    :param has_velocity: also declare an observation.velocity feature
    :param has_effort: also declare an observation.effort feature
    :param dataset_config: writer/fps settings

    Fix: removed a second, unreachable ``LeRobotDataset.create`` call that
    followed the return and referenced the undefined name ``HF_LEROBOT_HOME``.
    """
    motors = [
        "left_waist",
        "left_shoulder",
        "left_elbow",
        "left_forearm_roll",
        "left_wrist_angle",
        "left_wrist_rotate",
        "left_gripper",
        "right_waist",
        "right_shoulder",
        "right_elbow",
        "right_forearm_roll",
        "right_wrist_angle",
        "right_wrist_rotate",
        "right_gripper",
    ]

    cameras = [
        "cam_high",
        "cam_left_wrist",
        "cam_right_wrist",
    ]

    def _motor_feature():
        # One float32 per motor, labelled with the motor names.
        return {
            "dtype": "float32",
            "shape": (len(motors), ),
            "names": [
                motors,
            ],
        }

    features = {
        "observation.state": _motor_feature(),
        "action": _motor_feature(),
    }

    if has_velocity:
        features["observation.velocity"] = _motor_feature()

    if has_effort:
        features["observation.effort"] = _motor_feature()

    for cam in cameras:
        features[f"observation.images.{cam}"] = {
            "dtype": mode,
            "shape": (3, 480, 640),
            "names": [
                "channels",
                "height",
                "width",
            ],
        }

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=dataset_config.fps,
        robot_type=robot_type,
        features=features,
        # use_videos=dataset_config.use_videos,
        # tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        # video_backend=dataset_config.video_backend,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_cameras(hdf5_files: list[Path]) -> list[str]:
    """List the camera streams present in the first episode file.

    Depth channels are skipped — they are not currently handled.
    """
    with h5py.File(hdf5_files[0], "r") as ep:
        return [name for name in ep["/observations/images"].keys() if "depth" not in name]  # noqa: SIM118
|
||||||
|
|
||||||
|
|
||||||
|
def has_velocity(hdf5_files: list[Path]) -> bool:
    """True iff the first episode file stores joint velocities."""
    with h5py.File(hdf5_files[0], "r") as first_ep:
        return "/observations/qvel" in first_ep
|
||||||
|
|
||||||
|
|
||||||
|
def has_effort(hdf5_files: list[Path]) -> bool:
    """True iff the first episode file stores joint efforts."""
    with h5py.File(hdf5_files[0], "r") as first_ep:
        return "/observations/effort" in first_ep
|
||||||
|
|
||||||
|
|
||||||
|
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
    """Load every requested camera stream from an open episode file.

    Handles both uncompressed (T, H, W, C) arrays and per-frame JPEG blobs;
    in either case the full stream is materialized in RAM.
    """
    imgs_per_cam = {}
    for camera in cameras:
        dset = ep[f"/observations/images/{camera}"]
        if dset.ndim == 4:
            # Uncompressed: read the whole (T, H, W, C) array at once.
            imgs_per_cam[camera] = dset[:]
        else:
            import cv2

            # Compressed: decode one JPEG frame at a time.
            frames = [
                cv2.imdecode(np.frombuffer(blob, np.uint8), cv2.IMREAD_COLOR)
                for blob in dset
            ]
            imgs_per_cam[camera] = np.array(frames)
    return imgs_per_cam
|
||||||
|
|
||||||
|
|
||||||
|
def load_raw_episode_data(
    ep_path: Path,
) -> tuple[
    dict[str, np.ndarray],
    torch.Tensor,
    torch.Tensor,
    torch.Tensor | None,
    torch.Tensor | None,
]:
    """Read one episode file.

    Returns (images per camera, qpos state, action, velocity or None,
    effort or None) — the optional channels exist only in some recordings.
    """
    with h5py.File(ep_path, "r") as ep:
        state = torch.from_numpy(ep["/observations/qpos"][:])
        action = torch.from_numpy(ep["/action"][:])

        velocity = None
        if "/observations/qvel" in ep:
            velocity = torch.from_numpy(ep["/observations/qvel"][:])

        effort = None
        if "/observations/effort" in ep:
            effort = torch.from_numpy(ep["/observations/effort"][:])

        imgs_per_cam = load_raw_images_per_camera(
            ep,
            ["cam_high", "cam_left_wrist", "cam_right_wrist"],
        )

    return imgs_per_cam, state, action, velocity, effort
|
||||||
|
|
||||||
|
|
||||||
|
def populate_dataset(
    dataset: LeRobotDataset,
    hdf5_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
    instruction: str | None = None,
) -> LeRobotDataset:
    """Append every requested episode from `hdf5_files` to `dataset`.

    Args:
        dataset: Target LeRobot dataset (mutated in place).
        hdf5_files: Candidate episode files; `episodes` indexes into this list.
        task: Unused here; kept for interface compatibility with callers.
        episodes: Episode indices to import; defaults to all files.
        instruction: Fixed language instruction for every frame. When None,
            each episode samples one instruction from the `instructions.json`
            sitting next to its HDF5 file.

    Returns:
        The same `dataset` instance, for chaining.
    """
    if episodes is None:
        episodes = range(len(hdf5_files))

    for ep_idx in tqdm.tqdm(episodes, desc="Process", ncols=100):
        ep_path = hdf5_files[ep_idx]

        imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
        num_frames = state.shape[0]

        # Resolve the prompt for THIS episode. The original code wrote the
        # sampled instruction back into the `instruction` parameter, so every
        # episode after the first silently reused the first episode's randomly
        # chosen prompt; a per-episode local fixes that.
        ep_instruction = instruction
        if ep_instruction is None:
            json_path = os.path.join(os.path.dirname(ep_path), "instructions.json")
            with open(json_path, "r", encoding="utf-8") as f_instr:
                instruction_dict = json.load(f_instr)
            ep_instruction = np.random.choice(instruction_dict["instructions"])

        for i in range(num_frames):
            frame = {
                "observation.state": state[i],
                "action": action[i],
            }

            for camera, img_array in imgs_per_cam.items():
                frame[f"observation.images.{camera}"] = img_array[i]

            if velocity is not None:
                frame["observation.velocity"] = velocity[i]
            if effort is not None:
                frame["observation.effort"] = effort[i]
            dataset.add_frame(frame, task=ep_instruction)
        dataset.save_episode()

    return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def port_aloha(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = False,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "video",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
    instruction: str | None = None,
):
    """Convert a directory tree of ALOHA HDF5 episodes into a LeRobot dataset.

    NOTE(review): when `raw_dir` does not exist, the download step below is
    commented out, so the directory walk will just find zero files — confirm
    this is intended before relying on `raw_repo_id`.
    """
    if not Path(raw_dir).exists():
        if raw_repo_id is None:
            raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
        # download_raw(raw_dir, repo_id=raw_repo_id)

    # Recursively collect every episode file under raw_dir.
    hdf5_files = [
        os.path.join(root, filename)
        for root, _, files in os.walk(raw_dir)
        for filename in fnmatch.filter(files, '*.hdf5')
    ]

    empty_dataset = create_empty_dataset(
        repo_id,
        robot_type="mobile_aloha" if is_mobile else "aloha",
        mode=mode,
        has_effort=has_effort(hdf5_files),
        has_velocity=has_velocity(hdf5_files),
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        empty_dataset,
        hdf5_files,
        task=task,
        episodes=episodes,
        instruction=instruction,
    )

    if push_to_hub:
        dataset.push_to_hub()
||||||
|
|
||||||
|
# Script entry point: run the ALOHA -> LeRobot conversion when executed directly.
if __name__ == "__main__":
    main()
|
||||||
742
export_tools/export_smolvla.py
Normal file
742
export_tools/export_smolvla.py
Normal file
@ -0,0 +1,742 @@
|
|||||||
|
"""
|
||||||
|
SmolVLA Model Export Tool
|
||||||
|
|
||||||
|
This module exports SmolVLA models for BPU deployment, including:
|
||||||
|
- Vision encoder with connector
|
||||||
|
- VLM expert model (KV cache generation)
|
||||||
|
- Action expert model (denoising)
|
||||||
|
- Calibration data preparation
|
||||||
|
- Configuration file generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import copy
|
||||||
|
import random
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple, Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from lerobot.policies.smolvla.modeling_smolvla import (
|
||||||
|
SmolVLAPolicy,
|
||||||
|
make_att_2d_masks,
|
||||||
|
)
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
|
||||||
|
|
||||||
|
# Constants
# NOTE(review): parse_arguments() declares its own local repo_id with the same
# literal value — keep the two in sync when changing this.
DEFAULT_REPO_ID = "adjust_bottle_clean_50_fps15_1instruction"  # default dataset/model repo id
DEFAULT_DEVICE = "cuda:0"  # default torch device for running the policy
DEFAULT_NUM_CALIBRATION_SAMPLES = 10  # dataset samples drawn for quantization calibration

# Model component names
# Artifact/folder name stems shared by the exporter, directory manager and configs.
ALL_OUTPUT_NAME = "board_outputs_all"  # folder gathering every deployable artifact
TEST_DATA_NAME = "e2e_test_datas"  # folder for end-to-end test data
STATE_NORM_NAME = "state_normalize_unnormalize"  # saved normalization params file stem
STATE_PROJ_NAME = "state_proj"  # state projection sub-model name
VISION_ENCODER_NAME = "vlm_vision_encoder_with_connecter"  # vision encoder (+connector) sub-model
VLM_EXPERT_MODEL_NAME = "vlm_expert"  # VLM prefix / KV-cache sub-model
ACTION_EXPERT_MODEL_NAME = "action_expert"  # denoising action expert sub-model
VLM_EXPERT_EMBEDDING = "language_embedding_matrix"  # exported token embedding matrix stem
||||||
|
|
||||||
|
def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments for the export pipeline.

    Returns:
        Namespace with model/dataset paths, export directory, BPU target
        architecture, compile jobs, debug flag, device and calibration size.
    """
    repo_id = DEFAULT_REPO_ID  # consistency: reuse the module-level default
    parser = argparse.ArgumentParser(description="Export SmolVLA models for BPU deployment")
    parser.add_argument("--repo-id", type=str, default=repo_id, help="Repository ID for the model and dataset")
    parser.add_argument("--smolvla-model-path", type=str, default=f"/home/chao.wu/SmolVLA_RoboTwin2_BPU/train_result_rtw15fps_1instruction/{repo_id}/checkpoints/040000/pretrained_model", help="Path to pretrained SmolVLA model")
    parser.add_argument("--lerobot-dataset-path", type=str, default=f"/home/chao.wu/SmolVLA_RoboTwin2_BPU/huggingface/lerobot/rtw15fps_1instruction/{repo_id}", help="Path to LeRobot dataset")
    parser.add_argument("--export-path", type=str, default="export_dir", help="Output directory for exported models")
    parser.add_argument("--jobs", type=int, default=32, help="Number of parallel jobs for compilation")
    parser.add_argument("--march", type=str, default="nash-m", help="Target architecture for BPU")
    parser.add_argument("--debug", type=str, default="False", choices=["True", "False"], help="Enable debug mode")
    parser.add_argument("--device", type=str, default=DEFAULT_DEVICE, help="Device to use (e.g., 'cuda:0', 'cpu')")
    parser.add_argument("--num-calibration-samples", type=int, default=DEFAULT_NUM_CALIBRATION_SAMPLES, help="Number of calibration samples to use")
    # Bug fix: the original called parser.parse_args([]), which parses an empty
    # argument list and therefore silently ignored every real command-line
    # argument, always returning the defaults. Parse sys.argv as advertised.
    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: pin the GPU/proxy environment, then run the exporter."""
    # NOTE(review): this overrides CUDA_VISIBLE_DEVICES for the whole process —
    # confirm device "4" is the intended card for export runs.
    os.environ["CUDA_VISIBLE_DEVICES"] = "4"
    os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"

    args = parse_arguments()

    # Gather the exporter configuration in one place, then run the pipeline.
    exporter_kwargs = dict(
        model_path=args.smolvla_model_path,
        dataset_path=args.lerobot_dataset_path,
        export_path=args.export_path,
        device=args.device,
        num_calibration_samples=args.num_calibration_samples,
        march=args.march,
        jobs=args.jobs,
        debug=args.debug,
    )
    SmolVLAExporter(**exporter_kwargs).export_all()
|
||||||
|
|
||||||
|
|
||||||
|
class BPUKVCache(nn.Module):
    """Wrapper that runs the VLM prefix pass and flattens the resulting
    KV cache into a single list [k_0, v_0, k_1, v_1, ...] — the layout the
    BPU export expects as ONNX outputs."""

    def __init__(self, policy: SmolVLAPolicy):
        super().__init__()
        self.policy = policy

    def forward(
        self,
        prefix_att_2d_masks: torch.Tensor,
        prefix_position_ids: torch.Tensor,
        prefix_embs: torch.Tensor,
    ) -> List[torch.Tensor]:
        """Run the prefix (vision + language) pass and return the per-layer
        key/value tensors, interleaved layer by layer."""
        _, past_key_values = self.policy.model.vlm_with_expert.forward(
            attention_mask=prefix_att_2d_masks,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=self.policy.model.config.use_cache,
            fill_kv_cache=True,
        )

        # Interleave: [k_0, v_0, k_1, v_1, ...]; index-based access keeps this
        # valid whether past_key_values is a list or an int-keyed dict.
        return [
            past_key_values[layer][field]
            for layer in range(len(past_key_values))
            for field in ("key_states", "value_states")
        ]
|
||||||
|
|
||||||
|
|
||||||
|
class BPUDenoise(nn.Module):
    """Wrapper exposing one denoising step with the KV cache unrolled into
    flat k_i / v_i arguments, so the model can be exported to ONNX with a
    fixed tensor-only signature."""

    def __init__(self, policy: SmolVLAPolicy):
        super().__init__()
        self.policy = policy

    def forward(
        self,
        prefix_pad_masks: torch.Tensor,
        x_t: torch.Tensor,
        expanded_time: torch.Tensor,
        k_0: torch.Tensor, v_0: torch.Tensor,
        k_1: torch.Tensor, v_1: torch.Tensor,
        k_2: torch.Tensor, v_2: torch.Tensor,
        k_3: torch.Tensor, v_3: torch.Tensor,
        k_4: torch.Tensor, v_4: torch.Tensor,
        k_5: torch.Tensor, v_5: torch.Tensor,
        k_6: torch.Tensor, v_6: torch.Tensor,
        k_7: torch.Tensor, v_7: torch.Tensor,
        k_8: torch.Tensor, v_8: torch.Tensor,
        k_9: torch.Tensor, v_9: torch.Tensor,
        k_10: torch.Tensor, v_10: torch.Tensor,
        k_11: torch.Tensor, v_11: torch.Tensor,
        k_12: torch.Tensor, v_12: torch.Tensor,
        k_13: torch.Tensor, v_13: torch.Tensor,
        k_14: torch.Tensor, v_14: torch.Tensor,
        k_15: torch.Tensor, v_15: torch.Tensor,
    ) -> torch.Tensor:
        """Perform one denoising step, returning the predicted velocity v_t."""
        keys = [k_0, k_1, k_2, k_3, k_4, k_5, k_6, k_7,
                k_8, k_9, k_10, k_11, k_12, k_13, k_14, k_15]
        values = [v_0, v_1, v_2, v_3, v_4, v_5, v_6, v_7,
                  v_8, v_9, v_10, v_11, v_12, v_13, v_14, v_15]

        # Rebuild the dict layout denoise_step expects; only as many layers as
        # the action expert actually has are consumed.
        num_layers = len(self.policy.model.vlm_with_expert.lm_expert.layers)
        past_key_values = {
            idx: {"key_states": keys[idx], "value_states": values[idx]}
            for idx in range(num_layers)
        }

        return self.policy.model.denoise_step(
            prefix_pad_masks,
            past_key_values,
            x_t,
            expanded_time,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class DirectoryManager:
    """Creates and owns the on-disk workspace layout for the export run:
    one shared artifact folder, one test-data folder, and a `<model>_ws`
    workspace (with cal/ and output/ subfolders per tensor) for each
    exported sub-model."""

    def __init__(self, export_path: str, num_layers: int):
        self.export_path = Path(export_path)
        self.num_layers = num_layers
        self._create_directories()

    def _create_directories(self):
        """Create every directory the export pipeline writes into."""
        # Shared top-level folders.
        for top in (ALL_OUTPUT_NAME, TEST_DATA_NAME):
            (self.export_path / top).mkdir(parents=True, exist_ok=True)

        # KV tensor names interleave per layer: k_0, v_0, k_1, v_1, ...
        kv_names = [
            f"{prefix}_{i}"
            for i in range(self.num_layers)
            for prefix in ("k", "v")
        ]

        self._create_model_workspace(STATE_PROJ_NAME, ["state"], ["state_output"])
        self._create_model_workspace(
            VLM_EXPERT_MODEL_NAME,
            ["prefix_att_2d_masks", "prefix_position_ids", "prefix_embs"],
            kv_names,
        )
        # The action expert consumes the denoising inputs plus the full cache.
        self._create_model_workspace(
            ACTION_EXPERT_MODEL_NAME,
            ["prefix_pad_masks", "x_t", "expanded_time"] + kv_names,
            ["x_t_output"],
        )
        self._create_model_workspace(
            VISION_ENCODER_NAME, ["pixel_values"], ["hidden_state"]
        )

    def _create_model_workspace(
        self, model_name: str, input_names: List[str], output_names: List[str]
    ):
        """Create `<model>_ws/cal/<input>` and `<model>_ws/output/<output>` dirs."""
        ws_path = self.export_path / f"{model_name}_ws"
        ws_path.mkdir(parents=True, exist_ok=True)

        for subdir, names in (("cal", input_names), ("output", output_names)):
            sub_path = ws_path / subdir
            sub_path.mkdir(exist_ok=True)
            for name in names:
                (sub_path / name).mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigGenerator:
    """Renders hb_compile YAML configs and companion build scripts."""

    @staticmethod
    def generate_yaml_config(
        model_name: str,
        input_names: List[str],
        march: str,
        jobs: int,
        debug: str,
        extra_node_config: Dict[str, Dict] = None,
    ) -> str:
        """Render the hb_compile YAML configuration for one ONNX model.

        All inputs are treated as pre-normalized NCHW feature maps with
        calibration data under `./cal/<input_name>`.

        NOTE(review): the emitted quant_config contains two "model_config"
        entries (model_output_type float32 vs int16); whichever the toolchain
        reads last wins — confirm which is intended.
        """
        joined_inputs = ";".join(input_names) + ";"
        per_input_layout = "NCHW;" * len(input_names)
        per_input_type = "featuremap;" * len(input_names)
        per_input_norm = "no_preprocess;" * len(input_names)
        joined_cal_dirs = ";".join(f"./cal/{name}" for name in input_names) + ";"

        # Optional per-node overrides forcing float32 quantization.
        node_config_str = ""
        if extra_node_config:
            entries = ",\n".join(
                f'            "{node}": {{"qtype": "float32"}}'
                for node in extra_node_config
            )
            node_config_str = f"""
        "node_config": {{
{entries}
        }}"""

        yaml_content = f"""model_parameters:
  onnx_model: {model_name}.onnx
  march: {march}
  layer_out_dump: False
  working_dir: bpu_output
  output_model_file_prefix: {model_name}_featuremaps
  enable_vpu: True
input_parameters:
  input_name: {joined_inputs}
  input_layout_rt: {per_input_layout}
  input_layout_train: {per_input_layout}
  input_type_rt: {per_input_type}
  input_type_train: {per_input_type}
  norm_type: {per_input_norm}
calibration_parameters:
  cal_data_dir: '{joined_cal_dirs}'
  quant_config: {{
    "model_config": {{
      "all_node_type": "int16",
      "model_output_type": "float32",
      "activation": {{
        "calibration_type": ["max"],
        "num_bin": [1024, 2048, 4096],
        "max_num_bin": 16384,
        "max_percentile": 1.0,
        "per_channel": true,
        "asymmetric": [true]
      }},
      "weight": {{
        "bias_correction": {{
          "metric": "mae"
        }}
      }},
      "modelwise_search": {{
        "metric": "mae"
      }}
    }},
    "model_config": {{
      "all_node_type": "int16",
      "model_output_type": "int16",
    }},
    "op_config": {{
      "ReduceMean": {{"qtype": "int16"}},
      "Sub": {{"qtype": "int16"}},
      "Softmax": {{"qtype": "int16"}}
    }},{node_config_str}
  }}
compiler_parameters:
  extra_params: {{'input_no_padding': True, 'output_no_padding': True}}
  jobs: {jobs}
  compile_mode: 'latency'
  debug: {debug}
  advice: 1
  optimize_level: 'O2'
"""
        return yaml_content

    @staticmethod
    def generate_bash_script(model_name: str) -> str:
        """Render the build script: compile the model, open up permissions,
        and copy the resulting .hbm into the shared artifact folder."""
        return f"""hb_compile --config config.yaml
chmod 777 ./*
chmod 777 ./*/*
chmod 777 ./*/*/*
cp bpu_output/{model_name}_featuremaps.hbm ../{ALL_OUTPUT_NAME}
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SmolVLAExporter:
    """Main exporter class for SmolVLA models.

    Orchestrates the whole BPU export pipeline: loads a pretrained SmolVLA
    policy and a LeRobot dataset, exports the language embedding matrix and
    state normalization parameters, traces the denoising (action expert)
    model to ONNX, dumps calibration tensors for quantization, and writes
    the hb_compile config/build files for each sub-model.
    """

    def __init__(
        self,
        model_path: str,
        dataset_path: str,
        export_path: str,
        device: str = DEFAULT_DEVICE,
        num_calibration_samples: int = DEFAULT_NUM_CALIBRATION_SAMPLES,
        march: str = "nash-m",
        jobs: int = 32,
        debug: str = "False",
    ):
        # model_path: directory of the pretrained SmolVLA checkpoint.
        # dataset_path: root of the LeRobot dataset used for calibration.
        # export_path: output directory for all exported artifacts.
        # march/jobs/debug: forwarded verbatim into the hb_compile YAML.
        self.model_path = model_path
        self.dataset_path = dataset_path
        self.export_path = Path(export_path)
        self.device = torch.device(device)
        self.num_calibration_samples = num_calibration_samples
        self.march = march
        self.jobs = jobs
        self.debug = debug

        # Load policy and dataset
        self.policy = self._load_policy()
        self.dataset = self._load_dataset()
        self.data_indices = self._select_calibration_samples()

        # Setup directories (one workspace per exported sub-model)
        num_layers = len(self.policy.model.vlm_with_expert.lm_expert.layers)
        self.dir_manager = DirectoryManager(export_path, num_layers)

        # Define input/output names shared by ONNX export, calibration
        # dumping and YAML generation. KV names interleave as k_i, v_i.
        self.input_names_kv = [
            "prefix_att_2d_masks",
            "prefix_position_ids",
            "prefix_embs",
        ]
        self.output_names_kv = []
        for i in range(num_layers):
            self.output_names_kv.append(f"k_{i}")
            self.output_names_kv.append(f"v_{i}")
        self.input_names_denoise = ["prefix_pad_masks", "x_t", "expanded_time"]
        self.input_names_denoise.extend(self.output_names_kv)
        self.output_names_denoise = ["x_t_output"]

    def _load_policy(self) -> SmolVLAPolicy:
        """Load pretrained policy model (float32, eval mode, on self.device)."""
        policy = SmolVLAPolicy.from_pretrained(self.model_path)
        return policy.to(self.device).float().eval()

    def _load_dataset(self) -> LeRobotDataset:
        """Load LeRobot dataset.

        NOTE(review): repo_id "Foo/Bar" is a placeholder; loading appears to
        rely entirely on the local `root=` path — confirm this is intended.
        """
        dataset = LeRobotDataset(repo_id="Foo/Bar", root=self.dataset_path)
        print(f"Dataset loaded: {len(dataset)} samples")
        return dataset

    def _select_calibration_samples(self) -> List[int]:
        """Randomly select calibration sample indices (without replacement)."""
        return random.sample(range(len(self.dataset)), self.num_calibration_samples)

    def export_language_embedding(self):
        """Export the VLM token embedding matrix as a .npy file."""
        embedding_matrix = (
            self.policy.model.vlm_with_expert.vlm.model.text_model.embed_tokens.weight.detach()
            .cpu()
            .float()
            .numpy()
        )
        print(
            f"Language embedding shape: {embedding_matrix.shape}, dtype: {embedding_matrix.dtype}"
        )

        output_path = self.export_path / ALL_OUTPUT_NAME / f"{VLM_EXPERT_EMBEDDING}.npy"
        np.save(output_path, embedding_matrix)

    def export_normalization_params(self):
        """Export state normalization/unnormalization mean/std tensors as a .pt file."""
        params = {
            "normalize_inputs.mean": self.policy.normalize_inputs.buffer_observation_state.mean.data.detach().cpu(),
            "normalize_inputs.std": self.policy.normalize_inputs.buffer_observation_state.std.data.detach().cpu(),
            "unnormalize_outputs.mean": self.policy.unnormalize_outputs.buffer_action.mean.data.detach().cpu(),
            "unnormalize_outputs.std": self.policy.unnormalize_outputs.buffer_action.std.data.detach().cpu(),
        }

        output_path = self.export_path / ALL_OUTPUT_NAME / f"{STATE_NORM_NAME}.pt"
        torch.save(params, output_path)

    def export_denoise_model(self):
        """Export the denoising (action expert) model to ONNX.

        Runs a full diffusion trace on one dataset sample first so that the
        final `input_tensors` tuple used for tracing has realistic values.
        """
        # Prepare sample data
        sample_data = self.dataset[0]
        obs = self._prepare_observation(sample_data)
        batch = self.policy.normalize_inputs(copy.deepcopy(obs))

        # Prepare inputs
        images, img_masks = self.policy.prepare_images(batch)
        state = self.policy.prepare_state(batch)
        lang_tokens, lang_masks = self.policy.prepare_language(batch)

        # Generate prefix embeddings
        prefix_embs, prefix_pad_masks, prefix_att_masks = (
            self.policy.model.embed_prefix(
                images, img_masks, lang_tokens, lang_masks, state=state
            )
        )
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        # Generate KV cache
        m_kv = BPUKVCache(self.policy)
        with torch.no_grad():
            kv_caches = m_kv(prefix_att_2d_masks, prefix_position_ids, prefix_embs)

        # Prepare denoising inputs
        noise = self._generate_noise(batch_size=1)
        m_denoise = BPUDenoise(self.policy)
        m_denoise.eval()

        # Integrate the flow from t=1 down to t=0 with fixed step dt < 0;
        # x_t is updated in place each step.
        dt = -1.0 / self.policy.model.config.num_steps
        dt = torch.tensor(dt, dtype=torch.float32, device=self.device)
        x_t = noise
        time = torch.tensor(1.0, dtype=torch.float32, device=self.device)
        while time >= -dt / 2:
            expanded_time = time.expand(1)
            input_tensors = (prefix_pad_masks, x_t, expanded_time, *kv_caches)
            with torch.no_grad():
                v_t = m_denoise(*input_tensors)
            x_t += dt * v_t
            time += dt

        # Trace with the inputs from the LAST diffusion step.
        onnx_path = (
            self.export_path
            / f"{ACTION_EXPERT_MODEL_NAME}_ws"
            / f"{ACTION_EXPERT_MODEL_NAME}.onnx"
        )
        torch.onnx.export(
            m_denoise,
            input_tensors,
            onnx_path,
            export_params=True,
            opset_version=19,
            do_constant_folding=True,
            input_names=self.input_names_denoise,
            output_names=self.output_names_denoise,
            dynamic_axes=None,
            dynamo=False,
        )
        print(f"Denoising model exported to {onnx_path}")

    def prepare_calibration_data(self):
        """Prepare calibration data for all models.

        Dumps KV-cache inputs/outputs for every selected sample; for the
        denoising model, every 5th sample keeps all diffusion steps while the
        others keep a subset, to bound the calibration set size.
        """
        m_kv = BPUKVCache(self.policy)
        m_denoise = BPUDenoise(self.policy)

        kv_cnt = 0
        denoise_cnt = 0

        for idx in tqdm(self.data_indices, desc="Preparing calibration data"):
            sample_data = self.dataset[idx]
            obs = self._prepare_observation(sample_data)
            batch = self.policy.normalize_inputs(copy.deepcopy(obs))

            # Prepare inputs
            images, img_masks = self.policy.prepare_images(batch)
            state = self.policy.prepare_state(batch)
            lang_tokens, lang_masks = self.policy.prepare_language(batch)

            # Generate prefix embeddings
            prefix_embs, prefix_pad_masks, prefix_att_masks = (
                self.policy.model.embed_prefix(
                    images, img_masks, lang_tokens, lang_masks, state=state
                )
            )
            prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
            prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

            # Save KV cache inputs and outputs
            self._save_calibration_tensors(
                [prefix_att_2d_masks, prefix_position_ids, prefix_embs],
                self.input_names_kv,
                VLM_EXPERT_MODEL_NAME,
                "cal",
                kv_cnt,
            )

            with torch.no_grad():
                kv_caches = m_kv(prefix_att_2d_masks, prefix_position_ids, prefix_embs)

            self._save_calibration_tensors(
                kv_caches, self.output_names_kv, VLM_EXPERT_MODEL_NAME, "output", kv_cnt
            )
            kv_cnt += 1

            # Diffusion loop for denoising calibration
            noise = self._generate_noise(batch_size=1)
            dt = -1.0 / self.policy.model.config.num_steps
            dt = torch.tensor(dt, dtype=torch.float32, device=self.device)
            x_t = noise
            time = torch.tensor(1.0, dtype=torch.float32, device=self.device)

            if kv_cnt % 5 == 0:  # keep the whole trace as calibration data
                while time >= -dt / 2:
                    expanded_time = time.expand(1)
                    input_tensors = (prefix_pad_masks, x_t, expanded_time, *kv_caches)
                    with torch.no_grad():
                        v_t = m_denoise(*input_tensors)
                    self._save_calibration_tensors(
                        input_tensors, self.input_names_denoise,
                        ACTION_EXPERT_MODEL_NAME, "cal", denoise_cnt,
                    )
                    self._save_calibration_tensors(
                        [v_t], self.output_names_denoise,
                        ACTION_EXPERT_MODEL_NAME, "output", denoise_cnt,
                    )
                    x_t += dt * v_t
                    time += dt
                    denoise_cnt += 1
            else:  # keep part of the trace as calibration data
                # NOTE(review): original comment said "keep half", but the
                # cnt_ls % 3 filter below actually keeps every third step.
                cnt_ls = 0
                while time >= -dt / 2:
                    expanded_time = time.expand(1)
                    input_tensors = (prefix_pad_masks, x_t, expanded_time, *kv_caches)
                    with torch.no_grad():
                        v_t = m_denoise(*input_tensors)
                    cnt_ls += 1
                    if cnt_ls % 3 == 0:
                        self._save_calibration_tensors(
                            input_tensors, self.input_names_denoise,
                            ACTION_EXPERT_MODEL_NAME, "cal", denoise_cnt,
                        )
                        self._save_calibration_tensors(
                            [v_t], self.output_names_denoise,
                            ACTION_EXPERT_MODEL_NAME, "output", denoise_cnt,
                        )
                        denoise_cnt += 1
                    x_t += dt * v_t
                    time += dt

        print(
            f"Calibration data prepared: {kv_cnt} KV samples, {denoise_cnt} denoise samples"
        )

    def generate_config_files(self):
        """Generate YAML configs and bash scripts for compilation."""
        config_gen = ConfigGenerator()

        # VLM expert config
        yaml_content = config_gen.generate_yaml_config(
            VLM_EXPERT_MODEL_NAME,
            self.input_names_kv,
            self.march,
            self.jobs,
            self.debug,
        )
        self._write_config_files(VLM_EXPERT_MODEL_NAME, yaml_content)

        # Action expert config with extra node configuration: these nodes
        # (timestep embedding ops) are forced to float32 quantization.
        extra_node_config = [
            "/Unsqueeze",
            "/Mul",
            "/Cos",
            "/Sin",
            "/Concat",
            "/Cast",
            "/Unsqueeze_1",
        ]
        yaml_content = config_gen.generate_yaml_config(
            ACTION_EXPERT_MODEL_NAME,
            self.input_names_denoise,
            self.march,
            self.jobs,
            self.debug,
            extra_node_config,
        )
        self._write_config_files(ACTION_EXPERT_MODEL_NAME, yaml_content)

    def _prepare_observation(self, data: Dict) -> Dict[str, Any]:
        """Prepare observation dictionary (batch of 1) from a dataset sample.

        NOTE(review): "observation.state" is filled from data["action"],
        not data["observation.state"] — confirm this substitution is
        intentional for calibration purposes.
        """
        obs = {
            "instruction": data["task"],
            "task": data["task"],
            "observation.images.cam_high": data["observation.images.cam_high"]
            .unsqueeze(0)
            .to(self.device),
            "observation.images.cam_left_wrist": data[
                "observation.images.cam_left_wrist"
            ]
            .unsqueeze(0)
            .to(self.device),
            "observation.images.cam_right_wrist": data[
                "observation.images.cam_right_wrist"
            ]
            .unsqueeze(0)
            .to(self.device),
            "observation.state": data["action"].unsqueeze(0).to(self.device),
        }
        return obs

    def _generate_noise(self, batch_size: int) -> torch.Tensor:
        """Generate standard-normal noise of shape (batch, chunk_size, max_action_dim)."""
        actions_shape = (
            batch_size,
            self.policy.model.config.chunk_size,
            self.policy.model.config.max_action_dim,
        )
        return torch.normal(
            mean=0.0,
            std=1.0,
            size=actions_shape,
            dtype=torch.float32,
            device=self.device,
        )

    def _save_calibration_tensors(
        self,
        tensors: List[torch.Tensor],
        names: List[str],
        model_name: str,
        subdir: str,
        index: int,
    ):
        """Save calibration tensors to `<model>_ws/<subdir>/<name>/<index>.npy`."""
        for tensor, name in zip(tensors, names):
            output_path = (
                self.export_path / f"{model_name}_ws" / subdir / name / f"{index}.npy"
            )
            np.save(output_path, tensor.detach().cpu().numpy())

    def _write_config_files(self, model_name: str, yaml_content: str):
        """Write YAML config and bash build script into a model's workspace."""
        ws_path = self.export_path / f"{model_name}_ws"

        # Write YAML config
        yaml_path = ws_path / "config.yaml"
        with open(yaml_path, "w", encoding="utf-8") as f:
            f.write(yaml_content)

        # Write bash script
        bash_content = ConfigGenerator.generate_bash_script(model_name)
        bash_path = ws_path / "build.bash"
        with open(bash_path, "w", encoding="utf-8") as f:
            f.write(bash_content)

        print(f"Config files written for {model_name}")

    def export_all(self):
        """Run the complete export pipeline in dependency order."""
        print("Starting SmolVLA export...")
        print(f"Export path: {self.export_path}")

        print("\n[1/5] Exporting language embedding...")
        self.export_language_embedding()

        print("\n[2/5] Exporting normalization parameters...")
        self.export_normalization_params()

        print("\n[3/5] Exporting denoising model...")
        self.export_denoise_model()

        print("\n[4/5] Preparing calibration data...")
        self.prepare_calibration_data()

        print("\n[5/5] Generating configuration files...")
        self.generate_config_files()

        print("\n✓ Export completed successfully!")
|
||||||
|
# Script entry point: run the full BPU export pipeline when executed directly.
if __name__ == "__main__":
    main()
|
||||||
Loading…
x
Reference in New Issue
Block a user