# SmolVLA_Tools/RoboTwin_Policy/GPU_LeRobot_SmolVLA_Cloud4VLA.py
# Snapshot: 2026-03-03 16:24:02 +08:00 (199 lines, 8.3 KiB, Python)
#
# Note from the source viewer: this file contains Unicode characters that might
# be confused with other characters (full-width punctuation in the Chinese
# comments). This is intentional and can be ignored.

# Copyright (c) 2025, Cauchy WuChao, D-Robotics.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# Process-wide environment defaults. They are set before `torch` is imported
# (below) so the CUDA runtime only ever sees the selected GPU. setdefault()
# lets values already exported by the launching shell take precedence instead
# of being silently overwritten.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
# Outbound HTTPS goes through a proxy on the local network
# (presumably needed for model/hub downloads — confirm).
os.environ.setdefault("HTTPS_PROXY", "http://192.168.16.68:18000")
import argparse
import logging
import time
import torch
import json
import gzip
import torch
import numpy as np
import cv2
import tempfile
import tarfile
import shutil
from contextlib import contextmanager
from tools import measure_time, show_data_summary
from binary_protocol import dict_to_binary, binary_to_dict, decode_images_jpeg
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from flask import Flask, Response, request
logging.basicConfig(
    level=logging.INFO,
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
)
logger = logging.getLogger("SmolVLA_Server")
app = Flask(__name__)

# Serving state shared by the request handlers. main() populates both at
# startup and /update_policy replaces them. A `global` statement at module
# scope is a no-op, so the names are simply declared here; functions that
# rebind them declare `global` themselves.
smolvla_policy: "SmolVLAPolicy | None" = None  # policy used by /infer; None until loaded
device: "torch.device | None" = None  # torch device the policy and its inputs are placed on
# Observation entries the client sends as encoded images (decoded with OpenCV before inference).
_CAMERA_KEYS = (
    'observation.images.cam_high',
    'observation.images.cam_left_wrist',
    'observation.images.cam_right_wrist',
)


def _log_elapsed(label, start):
    """Log, in bold red, the milliseconds elapsed since ``start`` (a ``time.time()`` stamp)."""
    logger.info("\033[1;31m" + f"{label} time = {1000*(time.time() - start):.2f} ms" + "\033[0m")


def _binary_response(payload, status):
    """Encode ``payload`` with ``dict_to_binary`` into an octet-stream response."""
    return Response(dict_to_binary(payload), status=status, mimetype='application/octet-stream')


def _decode_camera_images(data):
    """Replace each camera entry of ``data`` in place with a (1, 3, H, W) float32 array in [0, 1].

    Raises:
        KeyError: a camera entry is missing from ``data``.
        ValueError: OpenCV could not decode a camera entry.
    """
    for key in _CAMERA_KEYS:
        image = cv2.imdecode(data[key], cv2.IMREAD_COLOR)
        # imdecode signals failure by returning None rather than raising.
        if image is None:
            raise ValueError(f"Failed to decode image for '{key}'")
        # NOTE(review): cv2 decodes to BGR channel order and no RGB conversion is
        # done here — confirm this matches the training-time preprocessing.
        # (H, W, 3) -> (1, 3, H, W); uint8 -> float32 scaled to [0, 1].
        data[key] = image.transpose(2, 0, 1)[np.newaxis].astype(np.float32) / 255.0


def _to_observation(data, target_device):
    """Convert every numpy array in ``data`` to a tensor on ``target_device``; keep other values as-is."""
    return {
        key: torch.from_numpy(value).to(target_device) if isinstance(value, np.ndarray) else value
        for key, value in data.items()
    }


@app.route('/infer', methods=['POST'])
def infer():
    """Run one SmolVLA inference on a binary-encoded observation.

    The request body is decoded with ``binary_to_dict``. The three camera
    entries are decoded with ``cv2.imdecode`` into (1, 3, H, W) float32 arrays in
    [0, 1]. Every numpy array is then moved to ``device`` as a tensor, and the
    resulting dict is passed to ``predict_action_chunk``.

    Returns:
        An ``application/octet-stream`` response encoded with ``dict_to_binary``:
        200 with ``{"status": "ok", "action_chunk": ndarray}`` on success,
        503 with ``{"status": "error", "message": ...}`` before a policy is loaded,
        500 with the same error shape if decoding or inference raises.
    """
    # Snapshot the serving state once so the whole request uses a consistent pair.
    policy, target_device = smolvla_policy, device
    if policy is None:
        return _binary_response({"status": "error", "message": "Service not ready"}, 503)
    try:
        start = time.time()
        raw_data = request.data
        _log_elapsed("request.data", start)

        start = time.time()
        data = binary_to_dict(raw_data)
        _log_elapsed("binary_to_dict", start)

        start = time.time()
        _decode_camera_images(data)
        _log_elapsed("jpeg_decode", start)

        start = time.time()
        obs = _to_observation(data, target_device)
        _log_elapsed("obs,np2tensor", start)

        start = time.time()
        with torch.no_grad():
            action_chunk = policy.predict_action_chunk(obs).detach().cpu().numpy()
        _log_elapsed(f"{target_device} inference", start)

        start = time.time()
        response_blob = dict_to_binary({"status": "ok", 'action_chunk': action_chunk})
        _log_elapsed("dict_to_binary", start)
    except Exception as e:
        # Top-level request boundary: report failures in the binary protocol the
        # client parses instead of letting Flask emit an HTML 500 page.
        logger.error(f"Inference failed: {e}", exc_info=True)
        return _binary_response({"status": "error", "message": str(e)}, 500)
    return Response(response_blob, status=200, mimetype='application/octet-stream')
def _is_within_directory(directory, target):
    """Return True if ``target`` is ``directory`` itself or a path beneath it."""
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    # Compare whole path components, not characters: a character prefix would
    # wrongly treat "/tmp/x" as containing "/tmp/xy/evil".
    return os.path.commonpath([abs_directory, abs_target]) == abs_directory


def _safe_extract_tar(tar, path):
    """Extract all members of ``tar`` into ``path``, refusing any that would land outside it.

    Guards against path traversal (CVE-2007-4559-style archives): member names
    and symlink/hardlink targets must stay under ``path``. When the running
    Python provides ``tarfile.data_filter``, the stdlib ``"data"`` filter is also
    applied.

    Raises:
        ValueError: a member name or link target escapes ``path``.
        tarfile.TarError: the ``"data"`` filter rejected a member.
    """
    for member in tar.getmembers():
        if not _is_within_directory(path, os.path.join(path, member.name)):
            raise ValueError("Attempted Path Traversal in Tar File")
        if member.issym():
            # Symlink targets are resolved relative to the link's own directory.
            link_target = os.path.join(path, os.path.dirname(member.name), member.linkname)
        elif member.islnk():
            # Hardlink targets name another member, relative to the archive root.
            link_target = os.path.join(path, member.linkname)
        else:
            continue
        if not _is_within_directory(path, link_target):
            raise ValueError("Attempted Path Traversal in Tar File")
    if hasattr(tarfile, "data_filter"):
        tar.extractall(path, filter="data")
    else:
        tar.extractall(path)


def _find_model_dir(root):
    """Return the directory holding the extracted checkpoint.

    If ``root`` contains exactly one entry and it is a directory (an archive
    packed with a top-level folder), that directory is returned; otherwise ``root``.
    """
    entries = os.listdir(root)
    if len(entries) == 1 and os.path.isdir(os.path.join(root, entries[0])):
        return os.path.join(root, entries[0])
    return root


@app.route('/update_policy', methods=['POST'])
def update_policy():
    """Hot-swap the served policy with a checkpoint uploaded as a ``.tar`` archive.

    Form fields:
        device: optional torch device string (e.g. ``"cuda:0"``) for the new
            policy; when omitted or blank, the currently active device is reused.
        file: the ``.tar`` archive. It must contain ``config.json`` and
            ``model.safetensors``, either at its root or inside a single top-level folder.

    Returns:
        A ``(dict, status)`` tuple: 200 on success, 400 for an invalid request,
        500 if extraction or loading fails. On any failure the previously loaded
        policy and device stay in service unchanged.
    """
    global smolvla_policy, device
    # Resolve the target device locally; the global is only committed after a
    # successful load so /infer never pairs the old policy with a new device.
    device_str = request.form.get('device', '').strip()
    if device_str:
        try:
            new_device = torch.device(device_str)
        except (RuntimeError, ValueError) as e:
            return {"error": f"Invalid device '{device_str}': {e}"}, 400
    elif device is not None:
        new_device = device
    else:
        return {"error": "No device specified"}, 400

    if 'file' not in request.files:
        return {"error": "No file part"}, 400
    file = request.files['file']
    if file.filename == '':
        return {"error": "No selected file"}, 400
    if not file.filename.endswith('.tar'):
        return {"error": "Only .tar (uncompressed archive) is allowed"}, 400

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            logger.info(f"Extracting uploaded policy to temporary directory: {temp_dir}")
            # Store the archive beside (not inside) the extraction root, so a member
            # named like the archive cannot overwrite it mid-read and the root
            # contains only extracted content.
            tar_path = os.path.join(temp_dir, 'uploaded.tar')
            extract_dir = os.path.join(temp_dir, 'extracted')
            os.makedirs(extract_dir)
            file.save(tar_path)
            with tarfile.open(tar_path, 'r') as tar:
                _safe_extract_tar(tar, extract_dir)

            model_dir = _find_model_dir(extract_dir)
            missing = sorted({'config.json', 'model.safetensors'} - set(os.listdir(model_dir)))
            if missing:
                return {"error": f"Missing required files: {missing}"}, 400

            logger.info(f"Loading new policy from: {model_dir}")
            new_policy = SmolVLAPolicy.from_pretrained(model_dir).to(new_device).eval()

        # Commit policy and device together, only after the load succeeded.
        smolvla_policy = new_policy
        device = new_device
        logger.info("Policy updated successfully!")
        return {"message": "Policy updated successfully"}, 200
    except Exception as e:
        logger.error(f"Failed to update policy: {e}", exc_info=True)
        return {"error": str(e)}, 500
# @measure_time(logger)
# def decode_request_data(raw_data, compressed: bool = False, device: torch.device = torch.device('cpu')):
# raw_data = gzip.decompress(raw_data) if compressed else raw_data
# return deserialize_dict(json.loads(raw_data.decode('utf-8')), device=device)
# @measure_time(logger)
# def encode_request_data(data: dict, compressed: bool = False):
# resp_bytes = json.dumps(serialize_dict(data)).encode('utf-8')
# resp_bytes = gzip.compress(resp_bytes) if compressed else resp_bytes
# headers = {'Content-Type': 'application/json'}
# headers['compressed'] = 'gzip' if compressed else 'raw'
# return app.response_class(response=resp_bytes, status=200, headers=headers)
def main():
    """Parse CLI options, load the initial policy, and serve the HTTP API (blocks until shutdown)."""
    global smolvla_policy, device
    parser = argparse.ArgumentParser(
        description="Serve a LeRobot SmolVLA policy over HTTP (/infer, /update_policy).")
    parser.add_argument(
        '--smolvla-path', type=str,
        default="/home/chao.wu/SmolVLA_RoboTwin2_BPU/train_result_rtw15fps_1instruction/adjust_bottle_clean_50_fps15_1instruction/checkpoints/040000/pretrained_model",
        help="Directory of the pretrained SmolVLA checkpoint loaded at startup.")
    parser.add_argument(
        '--device', type=str, default="cuda:0",
        help="Torch device for the policy and its inputs (CUDA indices are relative to CUDA_VISIBLE_DEVICES).")
    parser.add_argument(
        '--host', type=str, default='127.0.0.1',
        help="Interface to bind; the default accepts loopback connections only.")
    parser.add_argument('--port', type=int, default=60002, help="TCP port to listen on.")
    opt = parser.parse_args()
    logger.info(opt)

    device = torch.device(opt.device)
    logger.info(f"Loading model from {opt.smolvla_path} ...")
    smolvla_policy = SmolVLAPolicy.from_pretrained(opt.smolvla_path).to(device).eval()
    # threaded=False: the development server handles one request at a time, so
    # /infer and /update_policy never run concurrently against the globals.
    app.run(host=opt.host, port=opt.port, threaded=False, debug=False)


if __name__ == "__main__":
    main()