"""Client helpers for a remote inference service.

``send_inference_request`` posts a serialized dict to the inference endpoint
with retries; ``upload_policy`` packs a policy directory into a .tar archive
and uploads it.
"""
# NOTE(review): several of the imports below (gzip, json, pickle, logging,
# numpy, torch, tools.*) are not used in this module. They are kept so that
# import-time side effects and any re-exports importers rely on are unchanged.
import json
import gzip
import tarfile
import pickle
import logging
import time
import tempfile
from typing import Optional, Dict, Any

import torch
import requests
import numpy as np

from tools import measure_time, show_data_summary
from binary_protocol import dict_to_binary, binary_to_dict

# Not referenced in this module. The names mirror pickle's HIGHEST_PROTOCOL /
# fix_imports options -- presumably serialization settings; confirm whether
# binary_protocol or importers of this module depend on them.
HIGHEST_PROTOCOL = 4
FIX_IMPORTS = False


def _format_http_error(resp: requests.Response) -> str:
    """Build a short human-readable description of a non-200 response.

    Prefers the ``"message"`` field of a body encoded with the binary
    protocol; otherwise falls back to the first 200 characters of the text body.
    """
    error_msg = f"HTTP {resp.status_code}"
    try:
        err_data = binary_to_dict(resp.content)
    except Exception:
        # Body is not in the binary protocol (e.g. a plain-text/HTML error page).
        return f"{error_msg}: {resp.text[:200]}"
    if isinstance(err_data, dict) and "message" in err_data:
        return f"{error_msg}: {err_data['message']}"
    return f"{error_msg}: {resp.text[:200]}"


def send_inference_request(
    data_dict: Dict[str, Any],
    url: str = 'http://127.0.0.1:50000/infer',
    timeout: int = 10,
    max_retries: int = 30,
    retry_delay: float = 2,
) -> Dict[str, Any]:
    """Send an inference request over the binary protocol, with timeout and retries.

    The original author documented the protocol as pickle-based and the
    client as targeting Python 3.10/3.12 and NumPy 2.x.

    :param data_dict: Input data dict (may contain numpy arrays); serialized
        once with ``dict_to_binary`` and re-sent on every attempt.
    :param url: Inference service endpoint.
    :param timeout: Per-request HTTP timeout in seconds.
    :param max_retries: Number of retries after the first attempt, so up to
        ``max_retries + 1`` requests are sent in total.
    :param retry_delay: Seconds to sleep between attempts.
    :return: The response body decoded with ``binary_to_dict``.
    :raises RuntimeError: If the request cannot be serialized, if a 200
        response cannot be deserialized (not retried), or if every attempt
        ends in a network error or a non-200 response.
    """
    # 1. Serialize the input once up front.
    try:
        req_body = dict_to_binary(data_dict)
    except Exception as e:
        raise RuntimeError(f"Failed to serialize request data: {e}") from e

    # 2. The body is a raw binary stream.
    headers = {'Content-Type': 'application/octet-stream'}

    # 3. Send, retrying on network errors and non-200 responses.
    total_attempts = max_retries + 1
    last_exception: Optional[Exception] = None
    for attempt in range(total_attempts):
        tag = f"[Attempt {attempt + 1}/{total_attempts}]"
        try:
            resp = requests.post(url, data=req_body, headers=headers, timeout=timeout)
        except requests.exceptions.Timeout as e:
            last_exception = e
            print(f"{tag} Request timeout: {e}")
        except requests.exceptions.ConnectionError as e:
            last_exception = e
            print(f"{tag} Connection error: {e}")
        except requests.exceptions.RequestException as e:
            last_exception = e
            print(f"{tag} Network error: {e}")
        except Exception as e:
            last_exception = e
            print(f"{tag} Unexpected error: {e}")
        else:
            if resp.status_code == 200:
                # A 200 body that fails to decode is a protocol mismatch, not a
                # transient fault, so fail immediately instead of re-sending.
                # NOTE(review): if binary_protocol is pickle-based, decoding
                # untrusted bytes can execute code -- only use trusted servers.
                try:
                    return binary_to_dict(resp.content)
                except Exception as deserialize_err:
                    raise RuntimeError(
                        f"Failed to deserialize response: {deserialize_err}"
                    ) from deserialize_err
            error_msg = _format_http_error(resp)
            print(f"{tag} Server error: {error_msg}")
            last_exception = RuntimeError(error_msg)

        # Sleep only if another attempt will follow.
        if attempt < max_retries:
            time.sleep(retry_delay)

    raise RuntimeError(
        f"Failed after {max_retries + 1} attempts. Last error: {last_exception}"
    ) from last_exception


def upload_policy(
    policy_dir: str,
    server_url: str = "http://localhost:50001/update_policy",
    device: str = 'cpu',
    timeout: Optional[float] = None,
) -> bool:
    """Pack ``policy_dir`` into an uncompressed .tar archive and upload it.

    The archive is sent as multipart form field ``file`` (filename
    ``policy.tar``), together with a ``device`` form field.

    :param policy_dir: Directory whose contents are archived.
    :param server_url: Upload endpoint.
    :param device: Value sent in the ``device`` form field.
    :param timeout: Optional HTTP timeout in seconds; ``None`` (the default)
        waits indefinitely, matching the previous behaviour.
    :return: ``True`` iff the server responded with HTTP 200.
    """
    # Write the tar into an anonymous on-disk temp file and upload from the
    # same handle. Reopening a NamedTemporaryFile by name fails on Windows.
    with tempfile.TemporaryFile(suffix='.tar') as tmp_tar:
        with tarfile.open(fileobj=tmp_tar, mode='w') as tar:
            # arcname='' puts the directory's contents at the archive root,
            # avoiding an extra level of nesting.
            tar.add(policy_dir, arcname='')
        # Closing the TarFile does not close an external fileobj; rewind it
        # so requests reads the archive from the beginning.
        tmp_tar.seek(0)
        files = {'file': ('policy.tar', tmp_tar, 'application/x-tar')}
        resp = requests.post(
            server_url, files=files, data={'device': device}, timeout=timeout
        )
    print(f"{resp = }")
    return resp.status_code == 200