# Copyright (c) 2025, Cauchy WuChao, D-Robotics.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flask server exposing the SmolVLA vision-language (prefix) pass.

Endpoints:
    POST /infer_vl       -- decode one observation, run the VLM prefix pass,
                            return the diffusion starting noise, pad masks and
                            per-layer KV cache as a binary_protocol blob.
    POST /update_policy  -- hot-swap the serving policy from an uploaded
                            uncompressed .tar checkpoint.
"""
import os

# NOTE(review): hard-coded GPU pin and proxy are deployment-specific leftovers;
# consider moving them to the launch environment instead of source code.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"

import argparse
import copy
import logging
import tarfile
import tempfile
import time

import cv2
import numpy as np
import torch
from flask import Flask, Response, request

from binary_protocol import dict_to_binary, binary_to_dict, decode_images_jpeg
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy, make_att_2d_masks
from tools import measure_time, show_data_summary

logging.basicConfig(
    level=logging.INFO,
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S')
logger = logging.getLogger("SmolVLA_Server")

app = Flask(__name__)

# Shared service state; replaced atomically (both together) by /update_policy.
smolvla_policy = None
device = None

# Camera streams the client is expected to send as JPEG-encoded byte arrays.
_CAMERA_KEYS = (
    'observation.images.cam_high',
    'observation.images.cam_left_wrist',
    'observation.images.cam_right_wrist',
)


@app.route('/infer_vl', methods=['POST'])
def infer_vl():
    """Run the SmolVLA vision-language prefix pass on one observation.

    Request body: a binary_protocol blob with JPEG camera images and state.
    Response: a binary_protocol blob containing 'status', 'num_steps',
    'prefix_pad_masks', the starting noise 'x_t' and per-layer 'k_i'/'v_i'
    KV-cache tensors (float32 numpy), for client-side denoising.
    Returns 503 before a policy is loaded, 500 (binary error blob) on failure.
    """
    policy = smolvla_policy
    if policy is None:
        return Response(dict_to_binary({"status": "error", "message": "Service not ready"}),
                        status=503, mimetype='application/octet-stream')

    try:
        begin_time_part = time.time()
        raw_data = request.data
        logger.info(f"request.data time = {1000*(time.time() - begin_time_part):.2f} ms")

        begin_time_part = time.time()
        data = binary_to_dict(raw_data)
        logger.info(f"binary_to_dict time = {1000*(time.time() - begin_time_part):.2f} ms")

        begin_time_part = time.time()
        for k in _CAMERA_KEYS:
            # JPEG bytes -> (H, W, 3) uint8 -> (1, 3, H, W) float32 in [0, 1]
            decoded = cv2.imdecode(data[k], cv2.IMREAD_COLOR)
            data[k] = decoded.transpose(2, 0, 1)[np.newaxis].astype(np.float32) / 255.0
        logger.info(f"jpeg_decode time = {1000*(time.time() - begin_time_part):.2f} ms")
        # show_data_summary(data)

        begin_time_part = time.time()
        obs = {}
        for k, v in data.items():
            if isinstance(v, np.ndarray):
                obs[k] = torch.from_numpy(v).to(device)
            elif isinstance(v, torch.Tensor):
                obs[k] = v.to(device)
            else:
                obs[k] = v
        logger.info(f"obs,np2tensor time = {1000*(time.time() - begin_time_part):.2f} ms")
        # show_data_summary(obs)

        # VL inference
        with torch.no_grad():
            begin_time = time.time()
            # Only the state stream needs normalization.
            batch = policy.normalize_inputs(obs)

            # Prepare model inputs.
            images, img_masks = policy.prepare_images(batch)
            state = policy.prepare_state(batch)
            lang_tokens, lang_masks = policy.prepare_language(batch)

            # Starting noise for the client-side flow-matching/denoising loop.
            bsize = 1
            actions_shape = (bsize, policy.model.config.chunk_size,
                             policy.model.config.max_action_dim)
            noise = torch.normal(
                mean=0.0,
                std=1.0,
                size=actions_shape,
                dtype=torch.float32,
                device=device,
            )

            # Embed the multimodal prefix (images + language + state).
            prefix_embs, prefix_pad_masks, prefix_att_masks = policy.model.embed_prefix(
                images, img_masks, lang_tokens, lang_masks, state=state
            )
            prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
            prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1

            # Prefill once to obtain the KV cache; the client reuses it for
            # every denoising step without re-running the VLM.
            _, past_key_values = policy.model.vlm_with_expert.forward(
                attention_mask=prefix_att_2d_masks,
                position_ids=prefix_position_ids,
                past_key_values=None,
                inputs_embeds=[prefix_embs, None],
                use_cache=policy.model.config.use_cache,
                fill_kv_cache=True,
            )
        logger.info("\033[1;31m" + f"GPU VL inference time = {1000*(time.time() - begin_time):.2f} ms" + "\033[0m")

        begin_time = time.time()
        # Pack the diffusion payload for the client.
        diffusion = {'status': 'ok'}
        diffusion['num_steps'] = policy.model.config.num_steps
        diffusion['prefix_pad_masks'] = prefix_pad_masks.detach().cpu().numpy()
        diffusion['x_t'] = noise.detach().cpu().numpy()
        for i, layer_kv in enumerate(past_key_values):
            diffusion[f'k_{i}'] = layer_kv['key_states'].detach().to(torch.float32).cpu().numpy()
            diffusion[f'v_{i}'] = layer_kv['value_states'].detach().to(torch.float32).cpu().numpy()
        # show_data_summary(diffusion)

        response_blob = dict_to_binary(diffusion)
        logger.info(f"dict_to_binary time = {1000*(time.time() - begin_time):.2f} ms")
        return Response(response_blob, status=200, mimetype='application/octet-stream')
    except Exception as e:
        # Return a protocol-level error blob instead of Flask's HTML 500 page
        # so the binary client can decode the failure.
        logger.error(e, exc_info=True)
        return Response(dict_to_binary({"status": "error", "message": str(e)}),
                        status=500, mimetype='application/octet-stream')


def _is_within_directory(directory, target):
    """Return True iff *target* resolves inside *directory* (traversal guard)."""
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    # commonpath (unlike commonprefix) compares whole path components, so
    # "/tmp/ab" is not falsely treated as containing "/tmp/abc".
    return os.path.commonpath([abs_directory, abs_target]) == abs_directory


def _safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    """Extract *tar* into *path*, rejecting members that would escape it.

    Guards against tar path traversal (CVE-2007-4559 class issues).
    """
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not _is_within_directory(path, member_path):
            raise Exception("Attempted Path Traversal in Tar File")
    tar.extractall(path, members, numeric_owner=numeric_owner)


@app.route('/update_policy', methods=['POST'])
def update_policy():
    """Hot-swap the serving policy from an uploaded uncompressed .tar file.

    Form fields: 'device' (torch device string) and 'file' (.tar archive
    containing config.json + model.safetensors, optionally nested one
    directory deep). The globals `smolvla_policy` and `device` are replaced
    together only after the new policy loads successfully, so a failed
    upload cannot leave the device pointing away from the live policy.
    """
    global smolvla_policy, device

    device_str = request.form.get('device', '').strip()
    if not device_str:
        return {"error": "Missing 'device' form field"}, 400

    if 'file' not in request.files:
        return {"error": "No file part"}, 400
    file = request.files['file']
    if file.filename == '':
        return {"error": "No selected file"}, 400
    if not file.filename.endswith('.tar'):
        return {"error": "Only .tar (uncompressed archive) is allowed"}, 400

    try:
        new_device = torch.device(device_str)
        with tempfile.TemporaryDirectory() as temp_dir:
            logger.info(f"Extracting uploaded policy to temporary directory: {temp_dir}")

            # Save and unpack the .tar archive.
            tar_path = os.path.join(temp_dir, 'uploaded.tar')
            file.save(tar_path)
            with tarfile.open(tar_path, 'r') as tar:
                _safe_extract(tar, path=temp_dir)
            # Drop the archive, keep only the extracted content.
            os.remove(tar_path)

            # The archive may wrap the checkpoint in one extra directory level.
            extracted_items = os.listdir(temp_dir)
            if len(extracted_items) == 1 and os.path.isdir(os.path.join(temp_dir, extracted_items[0])):
                model_dir = os.path.join(temp_dir, extracted_items[0])
            else:
                model_dir = temp_dir

            # Verify the required checkpoint files are present.
            required_files = {'config.json', 'model.safetensors'}
            if not required_files.issubset(set(os.listdir(model_dir))):
                return {"error": f"Missing required files: {required_files}"}, 400

            logger.info(f"Loading new policy from: {model_dir}")
            new_policy = SmolVLAPolicy.from_pretrained(model_dir).to(new_device).eval()

        # Publish only after the load fully succeeded.
        smolvla_policy = new_policy
        device = new_device
        logger.info("Policy updated successfully!")
        return {"message": "Policy updated successfully"}, 200

    except Exception as e:
        logger.error(f"Failed to update policy: {e}", exc_info=True)
        return {"error": str(e)}, 500


def main():
    """Parse CLI arguments, load the initial policy and start the server."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--smolvla-path', type=str,
                        default="/home/chao.wu/SmolVLA_RoboTwin2_BPU/train_result_rtw15fps_1instruction/adjust_bottle_randomized_500_fps15_1instruction/checkpoints/040000/pretrained_model",
                        help="")
    parser.add_argument('--device', type=str, default="cuda:0", help="")
    parser.add_argument('--port', type=int, default=50002, help="")
    opt = parser.parse_args()
    logger.info(opt)

    global smolvla_policy, device
    device = torch.device(opt.device)
    smolvla_policy = SmolVLAPolicy.from_pretrained(opt.smolvla_path).to(device).eval()

    # Bind on all interfaces so remote robot clients can reach the server.
    app.run(host='0.0.0.0', port=opt.port, threaded=False, debug=False)


if __name__ == "__main__":
    main()