# SmolVLA_Tools/RoboTwin_Policy/BPU_LeRobot_SmolVLA_Cloud4VLA_BPU4A.py
# 2026-03-03 16:24:02 +08:00
#
# 169 lines
# 6.5 KiB
# Python

# Copyright (c) 2025, Cauchy WuChao, D-Robotics.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
# os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"
import argparse
import logging
import time
import torch
import numpy as np
import tempfile
import tarfile
from tools import measure_time, show_data_summary
from binary_protocol import dict_to_binary, binary_to_dict
from request_tools import send_inference_request
from pyCauchyKesai import CauchyKesai
from flask import Flask, Response, request
from torch import Tensor
# Feature keys mirrored from lerobot.constants so this server does not need
# the lerobot package installed.
ACTION = "action"
OBS_STATE = "observation.state"

logging.basicConfig(
    level=logging.INFO,
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S')
logger = logging.getLogger("SmolVLA_Server")

app = Flask(__name__)

# Shared runtime state, populated once by main() before the server starts.
# A `global` statement is a no-op at module scope, so plain assignments are
# used here; handlers that rebind these declare `global` themselves.
smolvla_policy_a = None    # CauchyKesai handle for the BPU action-expert model
normalize_params = None    # tensors loaded from --state-normalize (std/mean)
device = None              # torch.device built from --device
action_feature_dim = None  # number of action dims kept from the model output
vl_server_url = None       # URL of the cloud VL (prefix/KV-cache) server
@app.route('/infer', methods=['POST'])
def infer():
    """Run one end-to-end inference step.

    Pipeline: decode the binary observation payload, forward it to the cloud
    VL server (which returns transformer KV caches and an initial sample),
    run the flow-matching denoising loop on the local BPU model, un-normalize
    the resulting action chunk, and return it as a binary blob.
    """
    global smolvla_policy_a, normalize_params, device, action_feature_dim, vl_server_url

    def unnormalize_outputs(batch: dict[str, Tensor]) -> dict[str, Tensor]:
        # Inverse of the training-time normalization: x * std + mean.
        batch[ACTION] = batch[ACTION] * normalize_params["unnormalize_outputs.std"] + normalize_params["unnormalize_outputs.mean"]
        return batch

    # Refuse traffic until main() has finished loading the model.
    if smolvla_policy_a is None:
        return Response(dict_to_binary({"status": "error", "message": "Service not ready"}),
                        status=503, mimetype='application/octet-stream')
    try:
        # Decode the request body into an observation dict.
        begin_time_part = time.time()
        raw_data = request.data
        logger.info(f"request.data time = {1000*(time.time() - begin_time_part):.2f} ms")
        begin_time_part = time.time()
        obs = binary_to_dict(raw_data)
        logger.info(f"binary_to_dict time = {1000*(time.time() - begin_time_part):.2f} ms")

        begin_time = time.time()
        # Cloud inference: returns KV caches (k_0..k_15 / v_0..v_15), the
        # initial sample x_t, the prefix padding mask, and the step count.
        begin_time_part = time.time()
        diffusion = send_inference_request(obs, url=f"{vl_server_url}", timeout=10, max_retries=3)
        logger.info(f"Cloud4VL inference+post time = {1000*(time.time() - begin_time_part):.2f} ms")

        # Parse the cloud result and run the BPU flow-matching loop.
        begin_time_part = time.time()
        num_steps = diffusion['num_steps']
        # np.bool was removed in NumPy 1.24 — use the builtin bool instead.
        prefix_pad_masks = diffusion['prefix_pad_masks'].astype(bool)
        x_t = diffusion['x_t'].astype(np.float32)
        past_key_values = []
        for i in range(16):
            past_key_values.append(diffusion[f'k_{i}'].astype(np.float32))
            past_key_values.append(diffusion[f'v_{i}'].astype(np.float32))
        # Integrate the learned velocity field from t=1 down to t=0.
        dt = -1.0 / num_steps
        time_fm = np.array([1.0]).astype(np.float32)
        while time_fm >= -dt / 2:
            v_t = smolvla_policy_a([prefix_pad_masks, x_t, time_fm, *past_key_values])[0]
            # Euler step
            x_t += dt * v_t
            time_fm += dt
        logger.info(f"BPU4A 10 flow matching time = {1000*(time.time() - begin_time_part):.2f} ms")

        # Post-process: keep the first action_feature_dim dims, un-normalize.
        begin_time_part = time.time()
        actions = torch.from_numpy(x_t)
        actions = actions[:, :, :action_feature_dim]
        actions = unnormalize_outputs({ACTION: actions})[ACTION]
        action_chunk = actions.numpy()
        logger.info(f"Post Process time = {1000*(time.time() - begin_time_part):.2f} ms")
        logger.info("\033[1;31m" + f"Cloud4VL, BPU4A, e2e inference time = {1000*(time.time() - begin_time):.2f} ms" + "\033[0m")

        response_blob = dict_to_binary({"status": "ok", 'action_chunk': action_chunk})
        return Response(response_blob, status=200, mimetype='application/octet-stream')
    except Exception as e:
        # Surface failures as a structured binary 500 instead of Flask's
        # default HTML error page, so clients can always decode the response.
        logger.exception(e)
        return Response(dict_to_binary({"status": "error", "message": str(e)}),
                        status=500, mimetype='application/octet-stream')
def main():
    """CLI entry point: parse options, load model state, start the server."""
    root_path = "/root/ssd/SmolVLA_BPU/LeRobot_SmolVLA_Quantization_Small/LeRobot_SmolVLA_Quantization/ls_export_adjust_bottle_randomized_500_fps15_1instruction/board_outputs_all"
    parser = argparse.ArgumentParser()
    parser.add_argument('--action-hbm-model', type=str, default=f"{root_path}/action_expert_featuremaps.hbm", help="")
    parser.add_argument('--state-normalize', type=str, default=f"{root_path}/state_normalize_unnormalize.pt", help="")
    parser.add_argument('--action_feature_dim', type=int, default=14, help="LeRobot So101: 6, 14")
    parser.add_argument('--vl-server-url', type=str, default="http://10.112.20.37:50002/infer_vl", help="")
    parser.add_argument('--device', type=str, default="cpu", help="")
    parser.add_argument('--port', type=int, default=60002, help="")
    args = parser.parse_args()
    logger.info(args)

    # Publish runtime state read by the /infer handler.
    global smolvla_policy_a, normalize_params, device, action_feature_dim, vl_server_url
    device = torch.device(args.device)
    # NOTE(review): torch.load unpickles arbitrary objects — only point
    # --state-normalize at trusted files.
    normalize_params = torch.load(args.state_normalize)
    action_feature_dim = args.action_feature_dim
    vl_server_url = args.vl_server_url

    logger.info("Loading model ...")
    smolvla_policy_a = CauchyKesai(args.action_hbm_model)
    smolvla_policy_a.s()

    # NOTE(review): threaded=False — presumably the BPU handle is not
    # thread-safe; confirm before enabling concurrency.
    app.run(host='0.0.0.0', port=args.port, threaded=False, debug=False)


if __name__ == "__main__":
    main()