SmolVLA_Tools/RoboTwin_Policy/GPU_LeRobot_SmolVLA_Cloud4VL_BPU4A.py
2026-03-03 16:24:02 +08:00

231 lines
9.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2025, Cauchy WuChao, D-Robotics.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

# Pin the visible GPU and the outbound proxy BEFORE importing torch /
# downloading weights, so CUDA enumeration and HF downloads pick them up.
# NOTE(review): hard-coded GPU index and proxy address — consider making
# these configurable for other deployments.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"

import argparse
import copy
import logging
import tarfile
import tempfile
import time

import cv2
import numpy as np
import torch
from flask import Flask, Response, request

from tools import measure_time, show_data_summary
from binary_protocol import dict_to_binary, binary_to_dict, decode_images_jpeg
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks

logging.basicConfig(
    level=logging.INFO,
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S')
logger = logging.getLogger("SmolVLA_Server")

app = Flask(__name__)

# Shared service state: set in main(), hot-swapped by /update_policy, read by
# /infer_vl.  (The original module-level `global` statement was a no-op and
# has been dropped; a duplicate `import torch` was also removed.)
smolvla_policy = None
device = None
@app.route('/infer_vl', methods=['POST'])
def infer_vl():
    """Run the SmolVLA vision-language (prefix) stage on one observation.

    Expects a binary-encoded dict (see ``binary_protocol``) with three
    JPEG-compressed camera images plus state/language fields, and returns a
    binary blob containing the diffusion-expert inputs: initial noise ``x_t``,
    ``prefix_pad_masks``, ``num_steps`` and the per-layer KV cache
    (``k_i`` / ``v_i``), so a downstream accelerator can run the action
    denoising loop without re-running the VLM.

    Returns:
        200 with the binary diffusion payload on success,
        503 if no policy is loaded yet,
        500 with a binary ``{"status": "error", ...}`` payload on failure.
    """
    global smolvla_policy, device
    policy = smolvla_policy
    if policy is None:
        return Response(dict_to_binary({"status": "error", "message": "Service not ready"}),
                        status=503, mimetype='application/octet-stream')
    try:
        # ---- decode request -------------------------------------------------
        begin_time_part = time.time()
        raw_data = request.data
        logger.info(f"request.data time = {1000*(time.time() - begin_time_part):.2f} ms")

        begin_time_part = time.time()
        data = binary_to_dict(raw_data)
        logger.info(f"binary_to_dict time = {1000*(time.time() - begin_time_part):.2f} ms")

        # JPEG -> (1, 3, H, W) float32 in [0, 1]: channel-first, batched,
        # scaled the way the policy expects its image inputs.
        begin_time_part = time.time()
        for k in ('observation.images.cam_high',
                  'observation.images.cam_left_wrist',
                  'observation.images.cam_right_wrist'):
            img = cv2.imdecode(data[k], cv2.IMREAD_COLOR)
            data[k] = img.transpose(2, 0, 1)[np.newaxis].astype(np.float32) / 255.0
        logger.info(f"jpeg_decode time = {1000*(time.time() - begin_time_part):.2f} ms")

        # ---- move everything onto the inference device ----------------------
        begin_time_part = time.time()
        obs = {}
        for k, v in data.items():
            if isinstance(v, np.ndarray):
                obs[k] = torch.from_numpy(v).to(device)
            elif isinstance(v, torch.Tensor):
                obs[k] = v.to(device)
            else:
                obs[k] = v
        logger.info(f"obs,np2tensor time = {1000*(time.time() - begin_time_part):.2f} ms")

        # ---- VL (prefix) inference ------------------------------------------
        with torch.no_grad():
            begin_time = time.time()
            # Only the state stream needs normalization; images were scaled above.
            batch = policy.normalize_inputs(obs)
            images, img_masks = policy.prepare_images(batch)
            state = policy.prepare_state(batch)
            lang_tokens, lang_masks = policy.prepare_language(batch)
            bsize = 1
            # Initial diffusion noise for the action chunk.
            actions_shape = (bsize, policy.model.config.chunk_size,
                             policy.model.config.max_action_dim)
            noise = torch.normal(mean=0.0, std=1.0, size=actions_shape,
                                 dtype=torch.float32, device=device)
            # Embed images + language + state into the prefix token sequence.
            prefix_embs, prefix_pad_masks, prefix_att_masks = policy.model.embed_prefix(
                images, img_masks, lang_tokens, lang_masks, state=state
            )
            prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
            prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
            # Single VLM forward pass to fill the KV cache; the client's
            # diffusion expert reuses it for every denoising step.
            _, past_key_values = policy.model.vlm_with_expert.forward(
                attention_mask=prefix_att_2d_masks,
                position_ids=prefix_position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, None],
                use_cache=policy.model.config.use_cache,
                fill_kv_cache=True,
            )
            logger.info("\033[1;31m" + f"GPU VL inference time = {1000*(time.time() - begin_time):.2f} ms" + "\033[0m")

        # ---- serialize the diffusion-expert inputs --------------------------
        begin_time = time.time()
        diffusion = {
            'status': 'ok',
            'num_steps': policy.model.config.num_steps,
            'prefix_pad_masks': prefix_pad_masks.detach().cpu().numpy(),
            'x_t': noise.detach().cpu().numpy(),
        }
        # KV tensors are cast to float32 so the binary protocol stays
        # dtype-agnostic regardless of the model's compute dtype.
        for i, layer_kv in enumerate(past_key_values):
            diffusion[f'k_{i}'] = layer_kv['key_states'].detach().to(torch.float32).cpu().numpy()
            diffusion[f'v_{i}'] = layer_kv['value_states'].detach().to(torch.float32).cpu().numpy()
        response_blob = dict_to_binary(diffusion)
        logger.info(f"dict_to_binary time = {1000*(time.time() - begin_time):.2f} ms")
        return Response(response_blob, status=200, mimetype='application/octet-stream')
    except Exception as e:
        # Was disabled via `if True:` in the original; restored so the client
        # receives a structured binary error payload instead of a bare 500.
        logger.exception("infer_vl failed")
        return Response(dict_to_binary({"status": "error", "message": str(e)}),
                        status=500, mimetype='application/octet-stream')
@app.route('/update_policy', methods=['POST'])
def update_policy():
    """Hot-swap the served SmolVLA policy from an uploaded .tar checkpoint.

    Form fields:
        device: torch device string (e.g. "cuda:0") to load the model onto.
        file:   uncompressed .tar archive containing at least ``config.json``
                and ``model.safetensors`` (optionally inside one top-level
                directory).

    Returns a JSON message with 200 on success, 400 on a malformed request,
    or 500 if extraction/loading fails.
    """
    global smolvla_policy, device

    # Validate the whole request before mutating any serving state, so a bad
    # upload can never leave the server with a half-updated device/policy.
    device_str = request.form.get('device', '').strip()
    if not device_str:
        return {"error": "Missing 'device' form field"}, 400
    if 'file' not in request.files:
        return {"error": "No file part"}, 400
    file = request.files['file']
    if file.filename == '':
        return {"error": "No selected file"}, 400
    if not file.filename.endswith('.tar'):
        return {"error": "Only .tar (uncompressed archive) is allowed"}, 400

    try:
        new_device = torch.device(device_str)
        with tempfile.TemporaryDirectory() as temp_dir:
            logger.info(f"Extracting uploaded policy to temporary directory: {temp_dir}")
            # Save and unpack the uploaded archive.
            tar_path = os.path.join(temp_dir, 'uploaded.tar')
            file.save(tar_path)
            with tarfile.open(tar_path, 'r') as tar:
                # Guard against path traversal (CVE-2007-4559-style members).
                # os.path.commonpath compares whole path components, unlike the
                # string-wise os.path.commonprefix used before, which would
                # accept /tmp/ab as a "prefix" of /tmp/abc.
                def is_within_directory(directory, target):
                    abs_directory = os.path.abspath(directory)
                    abs_target = os.path.abspath(target)
                    return os.path.commonpath([abs_directory, abs_target]) == abs_directory

                def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
                    for member in tar.getmembers():
                        if not is_within_directory(path, os.path.join(path, member.name)):
                            raise Exception("Attempted Path Traversal in Tar File")
                    tar.extractall(path, members, numeric_owner=numeric_owner)

                safe_extract(tar, path=temp_dir)
            # Drop the archive so only the extracted content remains.
            os.remove(tar_path)

            # The archive may wrap the checkpoint in a single top-level folder.
            extracted_items = os.listdir(temp_dir)
            if len(extracted_items) == 1 and os.path.isdir(os.path.join(temp_dir, extracted_items[0])):
                model_dir = os.path.join(temp_dir, extracted_items[0])
            else:
                model_dir = temp_dir

            # Verify the checkpoint has the files from_pretrained() needs,
            # and report exactly which ones are absent.
            required_files = {'config.json', 'model.safetensors'}
            missing = required_files - set(os.listdir(model_dir))
            if missing:
                return {"error": f"Missing required files: {sorted(missing)}"}, 400

            # Fully load the new policy first, then swap both globals together.
            logger.info(f"Loading new policy from: {model_dir}")
            new_policy = SmolVLAPolicy.from_pretrained(model_dir).to(new_device).eval()
            device = new_device
            smolvla_policy = new_policy
            logger.info("Policy updated successfully!")
            return {"message": "Policy updated successfully"}, 200
    except Exception as e:
        logger.error(f"Failed to update policy: {e}", exc_info=True)
        return {"error": str(e)}, 500
def main():
    """Parse CLI arguments, load the initial SmolVLA policy, and serve Flask."""
    parser = argparse.ArgumentParser(description="SmolVLA VL-prefix inference server")
    parser.add_argument('--smolvla-path', type=str,
                        default="/home/chao.wu/SmolVLA_RoboTwin2_BPU/train_result_rtw15fps_1instruction/adjust_bottle_randomized_500_fps15_1instruction/checkpoints/040000/pretrained_model",
                        help="Path to the pretrained SmolVLA checkpoint directory.")
    parser.add_argument('--device', type=str, default="cuda:0",
                        help="torch device to run inference on, e.g. cuda:0 or cpu.")
    parser.add_argument('--port', type=int, default=50002,
                        help="TCP port for the Flask server.")
    # The original had a dead `host = '127.0.0.1'` immediately overwritten by
    # '0.0.0.0'; the effective value is now a configurable default instead.
    parser.add_argument('--host', type=str, default='0.0.0.0',
                        help="Interface to bind; 0.0.0.0 exposes the server on all interfaces.")
    opt = parser.parse_args()
    logger.info(opt)

    global smolvla_policy, device
    device = torch.device(opt.device)
    smolvla_policy = SmolVLAPolicy.from_pretrained(opt.smolvla_path).to(device).eval()

    # threaded=False keeps requests serialized — presumably to avoid
    # concurrent GPU inference on one model instance (TODO confirm).
    app.run(host=opt.host, port=opt.port, threaded=False, debug=False)


if __name__ == "__main__":
    main()