121 lines
4.2 KiB
Python
121 lines
4.2 KiB
Python
|
||
import json
|
||
import gzip
|
||
import torch
|
||
import requests
|
||
import tarfile
|
||
import pickle
|
||
import numpy as np
|
||
import logging
|
||
from typing import Any
|
||
|
||
|
||
# NOTE(review): these names mirror the ``protocol`` / ``fix_imports`` arguments
# of ``pickle.dump``; presumably they configure the pickle-based wire format
# used elsewhere — confirm against their users.
HIGHEST_PROTOCOL: int = 4  # assumed: pickle protocol version to emit
FIX_IMPORTS: bool = False  # assumed: disable pickle's Python 2 name mapping
|
||
|
||
import time
|
||
import requests
|
||
import gzip
|
||
import json
|
||
import torch
|
||
from typing import Optional, Dict, Any
|
||
|
||
from tools import measure_time, show_data_summary
|
||
from binary_protocol import dict_to_binary, binary_to_dict
|
||
|
||
def _describe_http_error(resp: Any) -> str:
    """Return ``"HTTP <status>: <detail>"`` for a non-200 inference response.

    The detail is the ``message`` field when the body decodes (via
    ``binary_to_dict``) to a dict containing that key; otherwise it is the
    first 200 characters of the raw body text.
    """
    error_msg = f"HTTP {resp.status_code}"
    try:
        err_data = binary_to_dict(resp.content)
        if isinstance(err_data, dict) and "message" in err_data:
            return f"{error_msg}: {err_data['message']}"
    except Exception:
        # Best effort: the error body may not be in the binary protocol at all
        # (presumably e.g. a plain-text/HTML error page), so fall back to the
        # raw text below instead of failing while reporting an error.
        pass
    return f"{error_msg}: {resp.text[:200]}"


def send_inference_request(
    data_dict: Dict[str, Any],
    url: str = 'http://127.0.0.1:50000/infer',
    timeout: int = 10,
    max_retries: int = 30,
    retry_delay: float = 2
) -> Dict[str, Any]:
    """Send a binary (pickle-protocol) inference request, retrying on failure.

    Compatible with Python 3.10/3.12 and NumPy 2.x.

    Timeouts, connection/network errors, any other exception from the HTTP
    call, and non-200 responses are retried up to ``max_retries`` times, with
    ``retry_delay`` seconds between attempts. A 200 response whose body cannot
    be decoded is NOT retried: the server already answered, so re-sending the
    request would not help.

    Security: the wire format is described as pickle-based. If that holds,
    decoding a response can execute arbitrary code, so only point ``url`` at
    a trusted server.

    :param data_dict: Input data dict (may contain numpy arrays).
    :param url: Inference service endpoint.
    :param timeout: Per-request HTTP timeout in seconds.
    :param max_retries: Maximum number of retries after the first attempt.
        Negative values are treated as 0, so at least one request is sent.
    :param retry_delay: Seconds to wait before each retry.
    :return: Decoded response dict (may contain numpy arrays).
    :raises RuntimeError: If the request cannot be serialized, a 200 response
        cannot be deserialized, or every attempt fails.
    """
    # 1. Serialize once up front; a failure here is a caller/data error rather
    #    than a transient one, so it is never retried.
    try:
        req_body = dict_to_binary(data_dict)
    except Exception as e:
        raise RuntimeError(f"Failed to serialize request data: {e}") from e

    # 2. The body is a raw binary stream.
    headers = {'Content-Type': 'application/octet-stream'}

    # 3. Send, retrying transient failures.
    total_attempts = max(max_retries, 0) + 1
    last_exception: Optional[Exception] = None
    for attempt in range(1, total_attempts + 1):
        tag = f"[Attempt {attempt}/{total_attempts}]"
        try:
            resp = requests.post(url, data=req_body, headers=headers, timeout=timeout)
        except requests.exceptions.Timeout as e:
            last_exception = e
            print(f"{tag} Request timeout: {e}")
        except requests.exceptions.ConnectionError as e:
            last_exception = e
            print(f"{tag} Connection error: {e}")
        except requests.exceptions.RequestException as e:
            last_exception = e
            print(f"{tag} Network error: {e}")
        except Exception as e:
            last_exception = e
            print(f"{tag} Unexpected error: {e}")
        else:
            if resp.status_code == 200:
                # Decoding lives outside the retry try/except above so that a
                # decode failure propagates to the caller instead of being
                # caught by the catch-all handler and triggering a re-send.
                try:
                    return binary_to_dict(resp.content)
                except Exception as deserialize_err:
                    raise RuntimeError(
                        f"Failed to deserialize response: {deserialize_err}"
                    ) from deserialize_err
            error_msg = _describe_http_error(resp)
            print(f"{tag} Server error: {error_msg}")
            last_exception = RuntimeError(error_msg)

        # Wait only between attempts, never after the final one.
        if attempt < total_attempts:
            time.sleep(retry_delay)

    raise RuntimeError(
        f"Failed after {total_attempts} attempts. Last error: {last_exception}"
    ) from last_exception
|
||
|
||
def upload_policy(
    policy_dir: str,
    server_url: str = "http://localhost:50001/update_policy",
    device: str = 'cpu',
    timeout: Optional[float] = None,
):
    """Pack ``policy_dir`` into a .tar archive and upload it to the server.

    The archive is built in an anonymous temporary file on disk. It is sent as
    the multipart field ``file`` (filename ``policy.tar``, content type
    ``application/x-tar``), with ``device`` as a form field.

    :param policy_dir: Directory whose contents are archived. Its entries are
        placed at the archive root rather than under a ``policy_dir/`` prefix.
    :param server_url: Upload endpoint.
    :param device: Value sent in the ``device`` form field.
    :param timeout: Optional ``requests`` timeout in seconds. ``None`` (the
        default) means no timeout.
    :return: ``True`` if the server answered HTTP 200, otherwise ``False``.
    """
    import tempfile

    # Write the archive through the open file object rather than reopening the
    # temp file by name: reopening a NamedTemporaryFile while it is still open
    # fails on Windows.
    with tempfile.TemporaryFile(suffix='.tar') as tmp_tar:
        with tarfile.open(fileobj=tmp_tar, mode='w') as tar:
            # arcname='' places the directory's contents directly at the
            # archive root, avoiding an extra level of nesting.
            # NOTE(review): this also emits an entry for the root directory
            # itself with an empty name; confirm the server's extraction
            # accepts it.
            tar.add(policy_dir, arcname='')
        # TarFile leaves an externally supplied fileobj open; rewind it so the
        # upload reads the archive from its first byte.
        tmp_tar.seek(0)

        files = {'file': ('policy.tar', tmp_tar, 'application/x-tar')}
        resp = requests.post(
            server_url, files=files, data={'device': device}, timeout=timeout
        )

    print(f"{resp = }")
    return resp.status_code == 200