SmolVLA_Tools/RoboTwin_Policy/request_tools.py
2026-03-03 16:24:02 +08:00

121 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import gzip
import torch
import requests
import tarfile
import pickle
import numpy as np
import logging
from typing import Any
# NOTE(review): neither constant is referenced by the code visible in this file.
# Their names mirror pickle's `protocol` / `fix_imports` arguments, and the
# request path uses a "binary Pickle protocol", so these are presumably pickle
# settings shared with binary_protocol — confirm before changing them.
HIGHEST_PROTOCOL = 4
FIX_IMPORTS = False
import time
import requests
import gzip
import json
import torch
from typing import Optional, Dict, Any
from tools import measure_time, show_data_summary
from binary_protocol import dict_to_binary, binary_to_dict
def _format_http_error(resp: "requests.Response") -> str:
    """Return a short, human-readable description of a non-200 response.

    The body is first decoded with ``binary_to_dict``; if that yields a dict
    with a ``"message"`` key, the message is used. Otherwise (including when
    decoding fails) the first 200 characters of the raw response text are used.
    """
    error_msg = f"HTTP {resp.status_code}"
    try:
        err_data = binary_to_dict(resp.content)
    except Exception:
        # Body is not in the binary format (e.g. plain text or an HTML page).
        err_data = None
    if isinstance(err_data, dict) and "message" in err_data:
        return f"{error_msg}: {err_data['message']}"
    return f"{error_msg}: {resp.text[:200]}"


def send_inference_request(
    data_dict: Dict[str, Any],
    url: str = 'http://127.0.0.1:50000/infer',
    timeout: float = 10,
    max_retries: int = 30,
    retry_delay: float = 2,
) -> Dict[str, Any]:
    """Send an inference request (binary Pickle protocol) with timeout and automatic retries.

    Compatible with Python 3.10/3.12 and NumPy 2.x.

    :param data_dict: Input data dict (may contain numpy arrays).
    :param url: Inference service endpoint.
    :param timeout: Timeout for a single HTTP request, in seconds.
    :param max_retries: Number of retries after the first attempt. Negative
        values are treated as 0, so at least one request is always sent.
    :param retry_delay: Seconds to wait before each retry.
    :return: Decoded response dict (may contain numpy arrays).
    :raises RuntimeError: If the request data cannot be serialized; if a 200
        response cannot be decoded (not retried: resending the same payload
        cannot fix a format mismatch); or if every attempt ends in a network
        error or a non-200 response.
    """
    # 1. Serialize once; the same payload is reused for every attempt.
    try:
        req_body = dict_to_binary(data_dict)
    except Exception as e:
        raise RuntimeError(f"Failed to serialize request data: {e}") from e

    # 2. The body is a raw binary stream.
    headers = {'Content-Type': 'application/octet-stream'}

    # 3. Send, retrying on network errors and non-200 responses.
    total_attempts = max(max_retries, 0) + 1
    last_exception: Optional[Exception] = None
    for attempt in range(total_attempts):
        prefix = f"[Attempt {attempt + 1}/{total_attempts}]"
        try:
            resp = requests.post(url, data=req_body, headers=headers, timeout=timeout)
        # Timeout is checked before ConnectionError because ConnectTimeout
        # subclasses both.
        except requests.exceptions.Timeout as e:
            last_exception = e
            print(f"{prefix} Request timeout: {e}")
        except requests.exceptions.ConnectionError as e:
            last_exception = e
            print(f"{prefix} Connection error: {e}")
        except requests.exceptions.RequestException as e:
            last_exception = e
            print(f"{prefix} Network error: {e}")
        except Exception as e:
            last_exception = e
            print(f"{prefix} Unexpected error: {e}")
        else:
            if resp.status_code == 200:
                # Decoding happens outside the retry handlers above so that a
                # malformed success payload surfaces immediately.
                try:
                    return binary_to_dict(resp.content)
                except Exception as deserialize_err:
                    raise RuntimeError(
                        f"Failed to deserialize response: {deserialize_err}"
                    ) from deserialize_err
            error_msg = _format_http_error(resp)
            print(f"{prefix} Server error: {error_msg}")
            last_exception = RuntimeError(error_msg)

        # Sleep only if another attempt will follow.
        if attempt + 1 < total_attempts:
            time.sleep(retry_delay)

    raise RuntimeError(
        f"Failed after {total_attempts} attempts. Last error: {last_exception}"
    ) from last_exception
def upload_policy(
    policy_dir: str,
    server_url: str = "http://localhost:50001/update_policy",
    device: str = 'cpu',
    timeout: Optional[float] = None,
) -> bool:
    """Pack ``policy_dir`` into an uncompressed .tar archive and upload it to the server.

    :param policy_dir: Directory whose contents are archived.
    :param server_url: Upload endpoint. The archive is sent as multipart field
        ``file`` (filename ``policy.tar``, type ``application/x-tar``).
    :param device: Value sent in the ``device`` form field.
    :param timeout: Optional ``requests`` timeout in seconds. ``None`` (the
        default) waits indefinitely, as before this parameter existed.
    :return: True if the server answered HTTP 200, otherwise False.
    """
    import tempfile

    # Build the archive through the open temp-file handle rather than reopening
    # the file by name: reopening a still-open NamedTemporaryFile by name is
    # not supported on Windows.
    with tempfile.TemporaryFile(suffix='.tar') as tmp_tar:
        # arcname='' places the directory's contents directly at the archive
        # root, avoiding an extra level of nesting.
        with tarfile.open(fileobj=tmp_tar, mode='w') as tar:
            tar.add(policy_dir, arcname='')
        # Closing the TarFile leaves the external fileobj open; rewind to read it.
        tmp_tar.seek(0)
        files = {'file': ('policy.tar', tmp_tar, 'application/x-tar')}
        resp = requests.post(
            server_url, files=files, data={'device': device}, timeout=timeout
        )
    print(f"{resp = }")
    return resp.status_code == 200