199 lines
8.3 KiB
Python
199 lines
8.3 KiB
Python
# Copyright (c) 2025, Cauchy WuChao, D-Robotics.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
|
||
os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"
|
||
import argparse
|
||
import logging
|
||
import time
|
||
import torch
|
||
import json
|
||
import gzip
|
||
|
||
import torch
|
||
import numpy as np
|
||
import cv2
|
||
|
||
import tempfile
|
||
import tarfile
|
||
import shutil
|
||
from contextlib import contextmanager
|
||
from tools import measure_time, show_data_summary
|
||
from binary_protocol import dict_to_binary, binary_to_dict, decode_images_jpeg
|
||
|
||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||
from flask import Flask, Response, request
|
||
|
||
logging.basicConfig(
    level=logging.INFO,
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
)
logger = logging.getLogger("SmolVLA_Server")

app = Flask(__name__)

# Process-wide state shared by the request handlers. ``main()`` fills both in
# before the server starts, and ``/update_policy`` may replace them at runtime.
# (Handlers that rebind these names declare ``global`` themselves; a ``global``
# statement at module level has no effect, so none is used here.)

# Currently served SmolVLA policy; ``None`` until a model has been loaded.
smolvla_policy = None
# torch.device the policy lives on; request tensors are moved onto it.
device = None
|
||
|
||
# Request keys whose values are encoded camera frames decoded server-side.
_IMAGE_KEYS = (
    'observation.images.cam_high',
    'observation.images.cam_left_wrist',
    'observation.images.cam_right_wrist',
)


def _log_elapsed(label, start):
    """Log (in bold red) the milliseconds elapsed since ``start``, a ``time.time()`` value."""
    logger.info("\033[1;31m" + f"{label} time = {1000*(time.time() - start):.2f} ms" + "\033[0m")


def _binary_response(payload, status):
    """Serialise ``payload`` with ``dict_to_binary`` into an octet-stream Response."""
    return Response(dict_to_binary(payload), status=status, mimetype='application/octet-stream')


def _decode_image(key, buf):
    """Decode one encoded image buffer into a float32 ``(1, 3, H, W)`` array scaled to [0, 1].

    Args:
        key: Request key the buffer came from (used only in the error message).
        buf: Encoded image buffer as received in the request, passed to ``cv2.imdecode``.

    Raises:
        ValueError: if OpenCV cannot decode ``buf``.
    """
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imdecode reports undecodable data by returning None, not by raising.
        raise ValueError(f"Failed to decode image for key '{key}'")
    # (H, W, 3) uint8 -> (1, 3, H, W) float32 in [0, 1].
    # NOTE(review): IMREAD_COLOR yields BGR channel order; confirm the model expects BGR.
    return img.transpose(2, 0, 1)[np.newaxis].astype(np.float32) / 255.0


@app.route('/infer', methods=['POST'])
def infer():
    """Run one policy inference on a binary-encoded observation.

    The request body is decoded with ``binary_to_dict``. The three camera keys
    in ``_IMAGE_KEYS`` are decoded from their encoded form. Every numpy array
    is moved to ``device`` as a tensor. Other values (e.g. text) are passed
    through unchanged.

    Returns:
        An ``application/octet-stream`` Response whose body is produced by
        ``dict_to_binary``:
        * 200 with ``{"status": "ok", "action_chunk": ndarray}`` on success;
        * 400 when image keys are missing;
        * 503 when no policy is loaded yet;
        * 500 with ``{"status": "error", "message": ...}`` on any other failure.
    """
    if smolvla_policy is None:
        return _binary_response({"status": "error", "message": "Service not ready"}, 503)

    # Request-level boundary: any failure is logged and reported to the client
    # in the same binary envelope, instead of Flask's default HTML error page.
    try:
        t0 = time.time()
        raw_data = request.data
        _log_elapsed("request.data", t0)

        t0 = time.time()
        data = binary_to_dict(raw_data)
        _log_elapsed("binary_to_dict", t0)

        missing = [k for k in _IMAGE_KEYS if k not in data]
        if missing:
            return _binary_response(
                {"status": "error", "message": f"Missing image keys: {missing}"}, 400)

        t0 = time.time()
        for k in _IMAGE_KEYS:
            data[k] = _decode_image(k, data[k])
        _log_elapsed("jpeg_decode", t0)

        t0 = time.time()
        obs = {
            k: torch.from_numpy(v).to(device) if isinstance(v, np.ndarray) else v
            for k, v in data.items()
        }
        _log_elapsed("obs,np2tensor", t0)

        t0 = time.time()
        with torch.no_grad():
            action_chunk = smolvla_policy.predict_action_chunk(obs).detach().cpu().numpy()
        _log_elapsed(f"{device} inference", t0)

        t0 = time.time()
        response_blob = dict_to_binary({"status": "ok", 'action_chunk': action_chunk})
        _log_elapsed("dict_to_binary", t0)
    except Exception as e:
        logger.error(f"Inference failed: {e}", exc_info=True)
        return _binary_response({"status": "error", "message": str(e)}, 500)

    return Response(response_blob, status=200, mimetype='application/octet-stream')
|
||
|
||
# Files that an uploaded checkpoint directory must contain.
_REQUIRED_POLICY_FILES = frozenset({'config.json', 'model.safetensors'})


def _is_within_directory(directory, target):
    """Return True if ``target`` is ``directory`` itself or lies inside it.

    Compares whole path components via ``os.path.commonpath``. The
    character-wise ``os.path.commonprefix`` would wrongly treat
    ``/tmp/abc-evil`` as being inside ``/tmp/abc``.
    """
    abs_directory = os.path.realpath(directory)
    abs_target = os.path.realpath(target)
    return os.path.commonpath([abs_directory, abs_target]) == abs_directory


def _safe_extract(tar, path):
    """Extract every member of ``tar`` into ``path`` after rejecting unsafe entries.

    Guards against path traversal (CVE-2007-4559 class). The following must
    stay inside ``path``:
    * member names;
    * symlink targets;
    * hard-link targets.

    Device and FIFO members are refused. All members are checked before
    anything is written.

    Raises:
        ValueError: if any member is unsafe.
    """
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not _is_within_directory(path, member_path):
            raise ValueError(f"Attempted Path Traversal in Tar File: {member.name}")
        if member.issym():
            # Symlink targets are relative to the link's own directory (or absolute).
            link_target = os.path.join(os.path.dirname(member_path), member.linkname)
        elif member.islnk():
            # Hard-link targets name another member, relative to the archive root.
            link_target = os.path.join(path, member.linkname)
        else:
            link_target = None
        if link_target is not None and not _is_within_directory(path, link_target):
            raise ValueError(f"Tar member {member.name!r} links outside the extraction directory")
        if member.isdev():
            raise ValueError(f"Tar member {member.name!r} is a device or FIFO")

    if hasattr(tarfile, 'data_filter'):
        # Newer Pythons (3.12+ and security backports) ship extraction filters;
        # use the strict 'data' filter as defence in depth.
        tar.extractall(path, filter='data')
    else:
        tar.extractall(path)


def _find_model_dir(root):
    """Return the single top-level sub-directory of ``root`` if that is all it holds, else ``root``."""
    entries = os.listdir(root)
    if len(entries) == 1 and os.path.isdir(os.path.join(root, entries[0])):
        return os.path.join(root, entries[0])
    return root


@app.route('/update_policy', methods=['POST'])
def update_policy():
    """Hot-swap the served policy with one uploaded as an uncompressed ``.tar``.

    Form fields:
        device: Optional torch device string for the new policy. When absent
            or blank, the currently configured device is reused (CPU if none).
        file: The ``.tar`` archive. It must contain ``config.json`` and
            ``model.safetensors``, either at its root or inside a single
            top-level directory.

    Returns:
        A ``(json_dict, status)`` tuple:
        * 200 on success;
        * 400 for a bad request or an invalid/unsafe archive;
        * 500 if loading the policy fails.

        The served policy and device are only replaced after the new policy
        has loaded successfully.
    """
    global smolvla_policy, device

    # Parse into a local first; committing to the global before the model
    # loads would leave model and device out of sync if the upload fails.
    device_str = request.form.get('device', '').strip()
    try:
        if device_str:
            new_device = torch.device(device_str)
        elif device is not None:
            new_device = device
        else:
            new_device = torch.device('cpu')
    except RuntimeError as e:
        return {"error": f"Invalid device {device_str!r}: {e}"}, 400

    if 'file' not in request.files:
        return {"error": "No file part"}, 400
    file = request.files['file']
    if file.filename == '':
        return {"error": "No selected file"}, 400
    if not file.filename.lower().endswith('.tar'):
        return {"error": "Only .tar (uncompressed archive) is allowed"}, 400

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            # Keep the uploaded archive outside the extraction directory so a
            # member with the same name cannot overwrite it mid-extraction.
            tar_path = os.path.join(temp_dir, 'uploaded.tar')
            extract_dir = os.path.join(temp_dir, 'extracted')
            os.mkdir(extract_dir)
            logger.info(f"Extracting uploaded policy to temporary directory: {extract_dir}")
            file.save(tar_path)

            try:
                with tarfile.open(tar_path, 'r') as tar:
                    _safe_extract(tar, extract_dir)
            except (tarfile.TarError, ValueError) as e:
                logger.warning(f"Rejected uploaded archive: {e}")
                return {"error": f"Invalid archive: {e}"}, 400
            # Free the disk space of the archive before loading the model.
            os.remove(tar_path)

            # The archive may wrap the checkpoint in one extra directory level.
            model_dir = _find_model_dir(extract_dir)
            missing = _REQUIRED_POLICY_FILES - set(os.listdir(model_dir))
            if missing:
                return {"error": f"Missing required files: {sorted(missing)}"}, 400

            logger.info(f"Loading new policy from: {model_dir}")
            new_policy = SmolVLAPolicy.from_pretrained(model_dir).to(new_device).eval()
    except Exception as e:
        logger.error(f"Failed to update policy: {e}", exc_info=True)
        return {"error": str(e)}, 500

    # Commit model and device together, only after a successful load.
    smolvla_policy = new_policy
    device = new_device
    logger.info("Policy updated successfully!")
    return {"message": "Policy updated successfully"}, 200
|
||
|
||
|
||
# @measure_time(logger)
|
||
# def decode_request_data(raw_data, compressed: bool = False, device: torch.device = torch.device('cpu')):
|
||
# raw_data = gzip.decompress(raw_data) if compressed else raw_data
|
||
# return deserialize_dict(json.loads(raw_data.decode('utf-8')), device=device)
|
||
|
||
# @measure_time(logger)
|
||
# def encode_request_data(data: dict, compressed: bool = False):
|
||
# resp_bytes = json.dumps(serialize_dict(data)).encode('utf-8')
|
||
# resp_bytes = gzip.compress(resp_bytes) if compressed else resp_bytes
|
||
# headers = {'Content-Type': 'application/json'}
|
||
# headers['compressed'] = 'gzip' if compressed else 'raw'
|
||
# return app.response_class(response=resp_bytes, status=200, headers=headers)
|
||
|
||
def main(argv=None):
    """Run the SmolVLA inference server.

    Parses command-line options and loads the initial policy onto the
    requested device. Then serves ``/infer`` and ``/update_policy`` with Flask.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets ``argparse``
            read ``sys.argv[1:]``, so ``main()`` behaves exactly as before.
    """
    global smolvla_policy, device

    parser = argparse.ArgumentParser(description="SmolVLA inference server")
    parser.add_argument(
        '--smolvla-path', type=str,
        default="/home/chao.wu/SmolVLA_RoboTwin2_BPU/train_result_rtw15fps_1instruction/adjust_bottle_clean_50_fps15_1instruction/checkpoints/040000/pretrained_model",
        help="Pretrained SmolVLA policy passed to SmolVLAPolicy.from_pretrained at startup.")
    # Note: this module sets CUDA_VISIBLE_DEVICES at import time, so 'cuda:0'
    # refers to the first *visible* GPU, not necessarily physical GPU 0.
    parser.add_argument('--device', type=str, default="cuda:0",
                        help="Torch device for the policy, e.g. 'cuda:0' or 'cpu'.")
    parser.add_argument('--host', type=str, default='127.0.0.1',
                        help="Interface to bind (default: loopback only).")
    parser.add_argument('--port', type=int, default=60002,
                        help="TCP port to listen on.")
    opt = parser.parse_args(argv)
    logger.info(opt)

    device = torch.device(opt.device)
    smolvla_policy = SmolVLAPolicy.from_pretrained(opt.smolvla_path).to(device).eval()

    # threaded=False: requests are handled one at a time. The handlers read and
    # rebind the module-level policy/device without any locking.
    app.run(host=opt.host, port=opt.port, threaded=False, debug=False)


if __name__ == "__main__":
    main()