231 lines
9.4 KiB
Python
231 lines
9.4 KiB
Python
# Copyright (c) 2025, Cauchy WuChao, D-Robotics.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
|
||
import os
|
||
# Process-wide defaults, set before torch is imported further down.
# ``setdefault`` keeps any value already exported by the launching shell or
# service manager, so a deployment can pick another GPU or proxy without
# editing this file; when nothing is set the values below apply as before.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")
os.environ.setdefault("HTTPS_PROXY", "http://192.168.16.68:18000")
|
||
import argparse
|
||
import logging
|
||
import time
|
||
import torch
|
||
import copy
|
||
|
||
import torch
|
||
import numpy as np
|
||
import cv2
|
||
|
||
import tempfile
|
||
import tarfile
|
||
from tools import measure_time, show_data_summary
|
||
from binary_protocol import dict_to_binary, binary_to_dict, decode_images_jpeg
|
||
|
||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
|
||
from flask import Flask, Response, request
|
||
|
||
# Root logging configuration: INFO level, millisecond-resolution timestamps.
logging.basicConfig(
    level=logging.INFO,
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
)
logger = logging.getLogger("SmolVLA_Server")

# Flask application that exposes the HTTP endpoints defined below.
app = Flask(__name__)

# Shared server state. Both names start out unset; they are assigned by
# main() at startup and rebound by the /update_policy endpoint.
smolvla_policy = None
device = None
|
||
|
||
# Camera streams that arrive JPEG-encoded and are decoded before inference.
_IMAGE_KEYS = (
    'observation.images.cam_high',
    'observation.images.cam_left_wrist',
    'observation.images.cam_right_wrist',
)


class _BadObservation(Exception):
    """Raised when the request payload cannot be turned into an observation."""


def _binary_response(payload, status):
    """Serialize *payload* with ``dict_to_binary`` and wrap it in a Flask response."""
    return Response(dict_to_binary(payload), status=status, mimetype='application/octet-stream')


def _log_elapsed(label, start, highlight=False):
    """Log the milliseconds elapsed since *start* (a ``time.perf_counter`` value)."""
    message = f"{label} = {1000 * (time.perf_counter() - start):.2f} ms"
    if highlight:
        # Bold red ANSI escape so the inference timing stands out in the console.
        message = "\033[1;31m" + message + "\033[0m"
    logger.info(message)


def _load_observation(raw_data, target_device):
    """Decode the request body into an observation dict placed on *target_device*.

    Raises:
        _BadObservation: if a required camera image is missing or is not
            decodable image data.
    """
    start = time.perf_counter()
    data = binary_to_dict(raw_data)
    _log_elapsed("binary_to_dict time", start)

    start = time.perf_counter()
    for key in _IMAGE_KEYS:
        if key not in data:
            raise _BadObservation(f"Missing required image: {key}")
        # NOTE(review): assumes data[key] is a 1-D uint8 buffer holding the
        # encoded JPEG, as cv2.imdecode requires -- confirm against the client.
        image = cv2.imdecode(data[key], cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imdecode signals undecodable input by returning None.
            raise _BadObservation(f"Could not decode image: {key}")
        # (H, W, 3) uint8 -> (1, 3, H, W) float32 in [0, 1]
        data[key] = image.transpose(2, 0, 1)[np.newaxis].astype(np.float32) / 255.0
    _log_elapsed("jpeg_decode time", start)

    start = time.perf_counter()
    obs = {}
    for key, value in data.items():
        if isinstance(value, np.ndarray):
            obs[key] = torch.from_numpy(value).to(target_device)
        elif isinstance(value, torch.Tensor):
            obs[key] = value.to(target_device)
        else:
            # Non-array entries are passed through untouched.
            obs[key] = value
    _log_elapsed("obs,np2tensor time", start)
    return obs


def _run_prefix(policy, obs, target_device):
    """Run the prefix forward pass of *policy* and return what the client needs.

    Returns:
        A ``(prefix_pad_masks, noise, past_key_values)`` tuple.
    """
    with torch.no_grad():
        # Only the state needs normalizing.
        batch = policy.normalize_inputs(obs)

        images, img_masks = policy.prepare_images(batch)
        state = policy.prepare_state(batch)
        lang_tokens, lang_masks = policy.prepare_language(batch)

        # Each image is decoded with a leading batch axis of 1, so the noise
        # tensor is created with the same batch size.
        bsize = 1
        actions_shape = (bsize, policy.model.config.chunk_size, policy.model.config.max_action_dim)
        noise = torch.normal(
            mean=0.0,
            std=1.0,
            size=actions_shape,
            dtype=torch.float32,
            device=target_device,
        )

        prefix_embs, prefix_pad_masks, prefix_att_masks = policy.model.embed_prefix(
            images, img_masks, lang_tokens, lang_masks, state=state
        )
        prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
        prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

        # Fill the key/value cache from the prefix tokens.
        _, past_key_values = policy.model.vlm_with_expert.forward(
            attention_mask=prefix_att_2d_masks,
            position_ids=prefix_position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=policy.model.config.use_cache,
            fill_kv_cache=True,
        )
    return prefix_pad_masks, noise, past_key_values


@app.route('/infer_vl', methods=['POST'])
def infer_vl():
    """Run the vision-language prefix pass and return its KV cache to the client.

    The request body is a ``binary_to_dict`` blob holding the observation; the
    three camera entries in ``_IMAGE_KEYS`` are decoded with OpenCV. The
    response is a ``dict_to_binary`` blob containing ``status``, ``num_steps``,
    ``prefix_pad_masks``, the initial noise ``x_t`` and one ``k_<i>``/``v_<i>``
    pair per cache entry.

    Every outcome is reported in the binary protocol so the client can always
    parse the reply:
        200 -- success
        400 -- a camera image is missing or undecodable
        500 -- any other failure (logged with traceback)
        503 -- no policy is loaded yet
    """
    # Snapshot both globals once so the whole request uses one policy/device pair.
    policy, target_device = smolvla_policy, device
    if policy is None:
        return _binary_response({"status": "error", "message": "Service not ready"}, 503)

    try:
        start = time.perf_counter()
        raw_data = request.data
        _log_elapsed("request.data time", start)

        obs = _load_observation(raw_data, target_device)

        start = time.perf_counter()
        prefix_pad_masks, noise, past_key_values = _run_prefix(policy, obs, target_device)
        if target_device is not None and target_device.type == "cuda":
            # CUDA kernels are launched asynchronously; wait for them so the
            # logged figure covers the computation and not just the launches.
            torch.cuda.synchronize(target_device)
        _log_elapsed("GPU VL inference time", start, highlight=True)

        start = time.perf_counter()
        diffusion = {
            'status': 'ok',
            'num_steps': policy.model.config.num_steps,
            'prefix_pad_masks': prefix_pad_masks.detach().cpu().numpy(),
            'x_t': noise.detach().cpu().numpy(),
        }
        # Index by integer instead of iterating directly: the cache is looked
        # up by position, and iterating a mapping would yield keys, not entries.
        for i in range(len(past_key_values)):
            entry = past_key_values[i]
            diffusion[f'k_{i}'] = entry['key_states'].detach().to(torch.float32).cpu().numpy()
            diffusion[f'v_{i}'] = entry['value_states'].detach().to(torch.float32).cpu().numpy()

        response_blob = dict_to_binary(diffusion)
        _log_elapsed("dict_to_binary time", start)

        return Response(response_blob, status=200, mimetype='application/octet-stream')

    except _BadObservation as e:
        logger.error(f"Bad observation: {e}")
        return _binary_response({"status": "error", "message": str(e)}, 400)
    except Exception as e:
        # Top-level boundary: keep the traceback in the log and answer in the
        # binary protocol instead of Flask's default HTML error page.
        logger.exception(e)
        return _binary_response({"status": "error", "message": str(e)}, 500)
|
||
|
||
def _is_within_directory(directory, target):
    """Return True when *target* is *directory* itself or a path beneath it.

    Uses ``os.path.commonpath`` (component-wise) rather than
    ``os.path.commonprefix`` (character-wise); the latter would accept
    ``/tmp/abcd/x`` as being inside ``/tmp/abc``.
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    try:
        return os.path.commonpath([abs_directory, abs_target]) == abs_directory
    except ValueError:
        # Raised e.g. for paths on different drives: certainly not inside.
        return False


def _safe_extract(tar, path):
    """Extract *tar* into *path*, refusing members that could escape it.

    Raises:
        Exception: if a member's name resolves outside *path*, or -- on
            interpreters without tarfile extraction filters -- if a member is
            anything other than a regular file or directory.
    """
    members = tar.getmembers()
    has_data_filter = hasattr(tarfile, "data_filter")
    for member in members:
        if not _is_within_directory(path, os.path.join(path, member.name)):
            raise Exception("Attempted Path Traversal in Tar File")
        # Without the "data" filter nothing vets link targets or device nodes,
        # so only plain files and directories are accepted in that case.
        if not has_data_filter and not (member.isfile() or member.isdir()):
            raise Exception(f"Unsupported member type in Tar File: {member.name}")

    if has_data_filter:
        tar.extractall(path, members, filter="data")
    else:
        tar.extractall(path, members)


@app.route('/update_policy', methods=['POST'])
def update_policy():
    """Replace the served policy with one uploaded as an uncompressed .tar.

    Form fields:
        device: torch device string for the new policy (required).
        file:   ``.tar`` archive holding ``config.json`` and
                ``model.safetensors``, either at the archive root or inside a
                single top-level directory.

    Returns:
        A ``(json_dict, status)`` pair: 200 on success, 400 for an invalid
        request or archive layout, 500 if extraction or loading fails.

    The global policy and device are only rebound after the new policy has
    been loaded successfully, so a failed update leaves the running service
    untouched.
    """
    global smolvla_policy, device

    # Validate the requested device without touching the global yet.
    device_name = request.form.get('device', '').strip()
    if not device_name:
        return {"error": "No device specified"}, 400
    try:
        new_device = torch.device(device_name)
    except (RuntimeError, TypeError, ValueError) as e:
        return {"error": f"Invalid device '{device_name}': {e}"}, 400

    # Validate the upload.
    if 'file' not in request.files:
        return {"error": "No file part"}, 400
    upload = request.files['file']
    if not upload.filename:
        return {"error": "No selected file"}, 400
    if not upload.filename.endswith('.tar'):
        return {"error": "Only .tar (uncompressed archive) is allowed"}, 400

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            logger.info(f"Extracting uploaded policy to temporary directory: {temp_dir}")

            # Save the archive, then unpack it next to itself.
            tar_path = os.path.join(temp_dir, 'uploaded.tar')
            upload.save(tar_path)

            with tarfile.open(tar_path, 'r') as tar:
                # Guard against path traversal (CVE-2007-4559 class of issues).
                _safe_extract(tar, temp_dir)

            # Drop the archive so only the extracted content remains.
            os.remove(tar_path)

            # The archive may wrap the model in one extra directory level.
            extracted_items = os.listdir(temp_dir)
            if len(extracted_items) == 1 and os.path.isdir(os.path.join(temp_dir, extracted_items[0])):
                model_dir = os.path.join(temp_dir, extracted_items[0])
            else:
                model_dir = temp_dir

            # Make sure the files needed to load the policy are present.
            required_files = {'config.json', 'model.safetensors'}
            missing_files = required_files - set(os.listdir(model_dir))
            if missing_files:
                return {"error": f"Missing required files: {sorted(missing_files)}"}, 400

            logger.info(f"Loading new policy from: {model_dir}")
            new_policy = SmolVLAPolicy.from_pretrained(model_dir).to(new_device).eval()

            # Swap policy and device together, only now that loading succeeded.
            smolvla_policy, device = new_policy, new_device
            logger.info("Policy updated successfully!")

        return {"message": "Policy updated successfully"}, 200

    except Exception as e:
        logger.error(f"Failed to update policy: {e}", exc_info=True)
        return {"error": str(e)}, 500
|
||
|
||
|
||
def main():
    """Parse command-line options, load the initial policy and start the server.

    Options:
        --smolvla-path: directory passed to ``SmolVLAPolicy.from_pretrained``.
        --device:       torch device string the policy is moved to.
        --host:         interface the Flask server binds to.
        --port:         TCP port the Flask server listens on.
    """
    global smolvla_policy, device

    parser = argparse.ArgumentParser()
    parser.add_argument('--smolvla-path', type=str, default="/home/chao.wu/SmolVLA_RoboTwin2_BPU/train_result_rtw15fps_1instruction/adjust_bottle_randomized_500_fps15_1instruction/checkpoints/040000/pretrained_model", help="")
    parser.add_argument('--device', type=str, default="cuda:0", help="")
    # NOTE(review): the default binds every interface, and /update_policy
    # accepts model uploads without authentication -- pass --host 127.0.0.1
    # unless remote clients are required.
    parser.add_argument('--host', type=str, default="0.0.0.0",
                        help="Interface to bind; use 127.0.0.1 to accept local connections only.")
    parser.add_argument('--port', type=int, default=50002, help="")
    opt = parser.parse_args()
    logger.info(opt)

    device = torch.device(opt.device)
    smolvla_policy = SmolVLAPolicy.from_pretrained(opt.smolvla_path).to(device).eval()

    # Single-threaded server: requests are handled one at a time.
    app.run(host=opt.host, port=opt.port, threaded=False, debug=False)


if __name__ == "__main__":
    main()
|