# Copyright (c) 2025, Cauchy WuChao, D-Robotics.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

# Must be set before torch/CUDA initialization so that only GPU 5 is
# visible to this process -- do not move below `import torch`.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
# os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"

import argparse
import logging
import tarfile   # NOTE(review): appears unused in this file -- confirm before removing
import tempfile  # NOTE(review): appears unused in this file -- confirm before removing
import time

import numpy as np
import torch
from torch import Tensor

from flask import Flask, Response, request
from pyCauchyKesai import CauchyKesai

from binary_protocol import dict_to_binary, binary_to_dict
from request_tools import send_inference_request
from tools import measure_time, show_data_summary

# lerobot.constants: batch keys used by the normalization helpers.
ACTION = "action"
OBS_STATE = "observation.state"

logging.basicConfig(
    level=logging.INFO,
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S')
logger = logging.getLogger("SmolVLA_Server")

app = Flask(__name__)

# Shared service state, populated by main() before the server starts.
# (The original module-level `global` statement was a no-op -- `global`
# only has an effect inside a function -- so plain assignments suffice.)
smolvla_policy_a = None    # BPU action-expert model handle (CauchyKesai)
normalize_params = None    # mean/std tensors for action unnormalization
device = None              # torch.device used for CPU-side post-processing
action_feature_dim = None  # number of action dims kept from the model output
vl_server_url = None       # URL of the cloud VL (prefix KV-cache) server
@app.route('/infer', methods=['POST'])
def infer():
    """Run one end-to-end inference request.

    Pipeline: decode the binary request payload, forward the observation to
    the cloud VL server to obtain prefix KV caches, run ``num_steps`` Euler
    flow-matching iterations on the local BPU action expert, then crop and
    unnormalize the resulting action chunk.

    Returns:
        A Flask ``Response`` with an ``application/octet-stream`` body from
        ``dict_to_binary``: ``{"status": "ok", "action_chunk": ndarray}`` on
        success, ``{"status": "error", "message": ...}`` with HTTP 503 when
        the model is not loaded yet, or HTTP 500 on any other failure.
    """
    global smolvla_policy_a, normalize_params, device, action_feature_dim, vl_server_url

    def unnormalize_outputs(batch: dict[str, Tensor]) -> dict[str, Tensor]:
        # Invert the z-score normalization applied during training.
        batch[ACTION] = (batch[ACTION] * normalize_params["unnormalize_outputs.std"]
                         + normalize_params["unnormalize_outputs.mean"])
        return batch

    # Reject requests until main() has loaded the model.
    if smolvla_policy_a is None:
        return Response(dict_to_binary({"status": "error", "message": "Service not ready"}),
                        status=503, mimetype='application/octet-stream')

    try:
        # Decode the binary request payload.
        begin_time_part = time.time()
        raw_data = request.data
        logger.info(f"request.data time = {1000*(time.time() - begin_time_part):.2f} ms")

        begin_time_part = time.time()
        data = binary_to_dict(raw_data)
        # show_data_summary(data)
        # Shallow copy. The original branched on isinstance(v, np.ndarray)
        # but both branches were identical, so a plain copy is equivalent.
        obs = dict(data)
        logger.info(f"binary_to_dict time = {1000*(time.time() - begin_time_part):.2f} ms")
        # show_data_summary(obs)

        begin_time = time.time()

        # Cloud inference: obtain the VL prefix KV caches for this observation.
        begin_time_part = time.time()
        diffusion = send_inference_request(obs, url=f"{vl_server_url}", timeout=10, max_retries=3)
        logger.info(f"Cloud4VL inference+post time = {1000*(time.time() - begin_time_part):.2f} ms")
        # show_data_summary(diffusion)

        # Parse the cloud result and run BPU inference on the action expert
        # (num_steps flow-matching iterations).
        begin_time_part = time.time()
        num_steps = diffusion['num_steps']
        # np.bool was removed in NumPy 1.24 -- use the builtin bool dtype.
        prefix_pad_masks = diffusion['prefix_pad_masks'].astype(bool)
        x_t = diffusion['x_t'].astype(np.float32)
        past_key_values = []
        # presumably 16 transformer layers, interleaved key/value -- must
        # match the cloud server's layout; TODO confirm against VL server.
        for i in range(16):
            past_key_values.append(diffusion[f'k_{i}'].astype(np.float32))
            past_key_values.append(diffusion[f'v_{i}'].astype(np.float32))

        # Euler integration of the flow field from t = 1.0 down to t = 0.0.
        dt = -1.0 / num_steps
        time_fm = np.array([1.0]).astype(np.float32)
        while time_fm >= -dt / 2:
            v_t = smolvla_policy_a([prefix_pad_masks, x_t, time_fm, *past_key_values])[0]
            # Euler step
            x_t += dt * v_t
            time_fm += dt
        # Log the actual step count (the original hard-coded "10").
        logger.info(f"BPU4A {num_steps} flow matching time = {1000*(time.time() - begin_time_part):.2f} ms")

        # Post-process: crop to the robot's action dimension and unnormalize.
        begin_time_part = time.time()
        actions = torch.from_numpy(x_t)
        actions = actions[:, :, :action_feature_dim]
        actions = unnormalize_outputs({ACTION: actions})[ACTION]
        action_chunk = actions.numpy()
        logger.info(f"Post Process time = {1000*(time.time() - begin_time_part):.2f} ms")

        logger.info("\033[1;31m" + f"Cloud4VL, BPU4A, e2e inference time = {1000*(time.time() - begin_time):.2f} ms" + "\033[0m")

        # Encode the response payload.
        response_blob = dict_to_binary({"status": "ok", 'action_chunk': action_chunk})
        return Response(response_blob, status=200, mimetype='application/octet-stream')

    except Exception as e:
        # Restored from the commented-out handler: always answer in the
        # binary protocol so clients never see Flask's HTML error page.
        logger.exception(e)
        return Response(dict_to_binary({"status": "error", "message": str(e)}),
                        status=500, mimetype='application/octet-stream')
def main():
    """Parse CLI options, load the BPU action-expert model, start the server.

    Populates the module-level service state (model handle, normalization
    parameters, device, action dimension, VL server URL), then blocks in
    Flask's ``app.run()``.
    """
    parser = argparse.ArgumentParser()
    root_path = "/root/ssd/SmolVLA_BPU/LeRobot_SmolVLA_Quantization_Small/LeRobot_SmolVLA_Quantization/ls_export_adjust_bottle_randomized_500_fps15_1instruction/board_outputs_all"
    parser.add_argument('--action-hbm-model', type=str, default=f"{root_path}/action_expert_featuremaps.hbm",
                        help="Path to the compiled BPU action-expert model (.hbm)")
    parser.add_argument('--state-normalize', type=str, default=f"{root_path}/state_normalize_unnormalize.pt",
                        help="Path to the torch file holding (un)normalization mean/std")
    # Dash-style spelling added for consistency with the other flags; the
    # original underscore spelling is kept as an alias for compatibility.
    parser.add_argument('--action-feature-dim', '--action_feature_dim', dest='action_feature_dim',
                        type=int, default=14, help="LeRobot So101: 6, 14")
    parser.add_argument('--vl-server-url', type=str, default="http://10.112.20.37:50002/infer_vl",
                        help="URL of the cloud VL (prefix KV-cache) inference endpoint")
    # parser.add_argument('--vl-server-url', type=str, default="http://120.48.157.2:50002/infer_vl", help="")
    parser.add_argument('--device', type=str, default="cpu",
                        help="torch device used for post-processing")
    parser.add_argument('--port', type=int, default=60002,
                        help="TCP port to serve on")
    opt = parser.parse_args()
    logger.info(opt)

    global smolvla_policy_a, normalize_params, device, action_feature_dim, vl_server_url
    device = torch.device(opt.device)
    # NOTE(review): torch.load unpickles arbitrary objects -- only point
    # --state-normalize at trusted checkpoint files.
    normalize_params = torch.load(opt.state_normalize)
    action_feature_dim = opt.action_feature_dim
    vl_server_url = opt.vl_server_url

    logger.info("Loading model ...")
    smolvla_policy_a = CauchyKesai(opt.action_hbm_model)
    # presumably prints a model summary -- TODO confirm against pyCauchyKesai API
    smolvla_policy_a.s()

    # threaded=False: the BPU model handle is not known to be thread-safe.
    app.run(host='0.0.0.0', port=opt.port, threaded=False, debug=False)
# Script entry point: configure and launch the SmolVLA action server.
if __name__ == "__main__":
    main()