#!/usr/bin/env python3
"""
RDT inference server test client.

Exercises the server's ``get_actions`` endpoint with mock observation data.
"""

import argparse
import logging
import sys
import time

import numpy as np

from cloud_helper import Client

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def create_mock_observation(
    state_dim=6,
    img_history_size=2,
    img_height=480,
    img_width=640,
    num_cameras=3
):
    """Build a mock observation dict.

    Args:
        state_dim: Length of the state vector (number of joints).
        img_history_size: Number of image frames generated per camera.
        img_height: Image height in pixels.
        img_width: Image width in pixels.
        num_cameras: Number of cameras to generate. At most three are
            produced, because only three camera names are known here.

    Returns:
        A dict with key ``"state"`` (float32 array of shape ``(state_dim,)``)
        and one ``"images.<cam_name>"`` entry per camera, holding a uint8
        array of shape ``(img_history_size, img_height, img_width, 3)``.
    """
    observation = {}

    # 1. Mock robot state (e.g. joint angles), drawn uniformly from [-180, 180).
    state = np.random.uniform(-180, 180, size=(state_dim,)).astype(np.float32)
    observation["state"] = state

    # 2. Mock camera images.
    # NOTE(review): numpy arrays are passed as-is; serialization is presumably
    # handled by the Client (the original note says msgpack_numpy) — confirm.
    camera_names = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
    if num_cameras > len(camera_names):
        # Extra cameras used to be dropped silently; make the truncation visible.
        logger.warning(
            f"num_cameras={num_cameras} 超过可用相机名数量 {len(camera_names)}, "
            f"仅生成 {len(camera_names)} 个相机的图像"
        )
    for i, cam_name in enumerate(camera_names[:num_cameras]):
        images = []
        for t in range(img_history_size):
            # Each (time step, camera) pair gets a distinct colour gradient.
            img = np.zeros((img_height, img_width, 3), dtype=np.uint8)
            color_shift = (t * 50 + i * 100) % 255
            # R: horizontal gradient, broadcast across every row.
            img[:, :, 0] = np.linspace(color_shift, 255, img_width, dtype=np.uint8)
            # G: vertical gradient, broadcast across every column.
            img[:, :, 1] = np.linspace(0, 255 - color_shift, img_height, dtype=np.uint8)[:, None]
            # B: constant.
            img[:, :, 2] = 128
            images.append(img)
        # Stack into (img_history_size, H, W, 3).
        observation[f"images.{cam_name}"] = np.stack(images, axis=0)

    return observation


def create_test_batch(
    observation,
    instruction="pick up the object and place it in the box",
    use_instruction_index=False
):
    """Wrap an observation and an instruction into a request payload.

    Args:
        observation: Observation dict, e.g. from ``create_mock_observation``.
        instruction: Instruction string to send.
        use_instruction_index: If True, send the integer index ``0`` instead
            of ``instruction`` (``instruction`` is then ignored).

    Returns:
        A dict with keys ``"observation"`` and ``"instruction"``.
    """
    return {
        "observation": observation,
        "instruction": 0 if use_instruction_index else instruction,
    }


def _log_action_summary(action):
    """Log shape, dtype, value range and a short preview of an action array."""
    action = np.asarray(action)
    logger.info(f" - action shape: {action.shape}")
    logger.info(f" - action dtype: {action.dtype}")
    if action.size == 0:
        # min()/max() raise on empty arrays; nothing more to show.
        return
    logger.info(f" - action range: [{action.min():.3f}, {action.max():.3f}]")
    logger.info(" - action preview (前3个时间步的前3个维度):")
    # atleast_2d lets a 1-D action (a single step) be previewed as well;
    # indexing shape[1] directly crashed on 1-D output.
    preview = np.atleast_2d(action)
    preview_steps = min(3, preview.shape[0])
    preview_dims = min(3, preview.shape[1])
    for t in range(preview_steps):
        logger.info(f" t={t}: {preview[t, :preview_dims]}")


def test_single_request(client, args):
    """Send one ``get_actions`` request and log details of the response.

    Args:
        client: Connected ``Client`` instance.
        args: Parsed CLI arguments.

    Returns:
        True if the request succeeded, False otherwise.
    """
    logger.info("=" * 60)
    logger.info("开始单次请求测试")
    logger.info("=" * 60)

    observation = create_mock_observation(
        state_dim=args.state_dim,
        img_history_size=args.img_history_size,
        img_height=args.img_height,
        img_width=args.img_width,
        num_cameras=args.num_cameras
    )
    logger.info("模拟观测数据:")
    logger.info(f" - state shape: {observation['state'].shape}")
    for key, value in observation.items():
        if key.startswith("images."):
            logger.info(f" - {key} shape: {value.shape}")

    batch = create_test_batch(
        observation,
        instruction=args.instruction,
        use_instruction_index=args.use_index
    )

    logger.info(f"发送指令: {batch['instruction']}")
    # perf_counter is monotonic and high-resolution; time.time() is neither.
    start_time = time.perf_counter()
    try:
        action = client.call_endpoint("get_actions", batch)
    except Exception as e:
        logger.error(f"✗ 请求失败: {e}")
        return False
    elapsed_time = time.perf_counter() - start_time

    logger.info(f"✓ 请求成功! 耗时: {elapsed_time*1000:.2f} ms")
    _log_action_summary(action)
    return True


def test_multiple_requests(client, args):
    """Send ``args.num_requests`` identical requests and log latency stats.

    Args:
        client: Connected ``Client`` instance.
        args: Parsed CLI arguments.

    Returns:
        True if every request succeeded, False otherwise.
    """
    logger.info("=" * 60)
    logger.info(f"开始连续请求测试 (共 {args.num_requests} 次)")
    logger.info("=" * 60)

    if args.num_requests <= 0:
        # The success-rate computation below divides by num_requests.
        logger.error(f"请求数量必须为正整数, 当前为 {args.num_requests}")
        return False

    # Build the payload once so only request latency is measured.
    observation = create_mock_observation(
        state_dim=args.state_dim,
        img_history_size=args.img_history_size,
        img_height=args.img_height,
        img_width=args.img_width,
        num_cameras=args.num_cameras
    )
    batch = create_test_batch(
        observation,
        instruction=args.instruction,
        use_instruction_index=args.use_index
    )

    latencies = []
    for i in range(args.num_requests):
        start_time = time.perf_counter()
        try:
            client.call_endpoint("get_actions", batch)
        except Exception as e:
            logger.error(f"第 {i+1} 次请求失败: {e}")
            continue
        latencies.append(time.perf_counter() - start_time)
        if (i + 1) % 10 == 0:
            logger.info(f"已完成 {i + 1}/{args.num_requests} 次请求")

    success_count = len(latencies)

    logger.info("=" * 60)
    logger.info("性能统计:")
    logger.info(f" - 总请求数: {args.num_requests}")
    logger.info(f" - 成功数: {success_count}")
    logger.info(f" - 失败数: {args.num_requests - success_count}")
    logger.info(f" - 成功率: {success_count/args.num_requests*100:.1f}%")

    if latencies:
        latency_arr = np.array(latencies)
        # Throughput is based on the summed request latency, not wall-clock
        # time of the whole loop.
        total_time = float(latency_arr.sum())
        logger.info(f" - 平均延迟: {np.mean(latency_arr)*1000:.2f} ms")
        logger.info(f" - 中位数延迟: {np.median(latency_arr)*1000:.2f} ms")
        logger.info(f" - 最小延迟: {np.min(latency_arr)*1000:.2f} ms")
        logger.info(f" - 最大延迟: {np.max(latency_arr)*1000:.2f} ms")
        if total_time > 0:
            logger.info(f" - 吞吐量: {success_count/total_time:.2f} requests/s")

    return success_count == args.num_requests


def test_different_instructions(client, args):
    """Send the same observation with several different instruction strings.

    Args:
        client: Connected ``Client`` instance.
        args: Parsed CLI arguments.

    Returns:
        True if every request succeeded, False otherwise.
    """
    logger.info("=" * 60)
    logger.info("测试不同指令")
    logger.info("=" * 60)

    instructions = [
        "pick up the red cube",
        "place the object on the table",
        "move to the left",
        "grasp the bottle",
        "open the drawer"
    ]

    observation = create_mock_observation(
        state_dim=args.state_dim,
        img_history_size=args.img_history_size,
        img_height=args.img_height,
        img_width=args.img_width,
        num_cameras=args.num_cameras
    )

    all_ok = True
    for i, instruction in enumerate(instructions):
        logger.info(f"\n测试指令 {i+1}/{len(instructions)}: '{instruction}'")
        batch = create_test_batch(observation, instruction=instruction)
        start_time = time.perf_counter()
        try:
            action = client.call_endpoint("get_actions", batch)
        except Exception as e:
            logger.error(f" ✗ 失败: {e}")
            all_ok = False
            continue
        elapsed_time = time.perf_counter() - start_time
        logger.info(f" ✓ 成功 | 耗时: {elapsed_time*1000:.2f} ms | action shape: {np.shape(action)}")

    return all_ok


def main():
    """Parse CLI arguments, connect to the server and run the selected test.

    Returns:
        Process exit code: 0 on success, 1 on connection/test failure,
        130 when interrupted by the user (conventional SIGINT exit status).
    """
    parser = argparse.ArgumentParser(description="RDT 推理服务器测试客户端")

    # Connection parameters
    parser.add_argument("--host", type=str, default="localhost", help="服务器地址")
    parser.add_argument("--port", type=int, default=8001, help="服务器端口")

    # Test mode
    parser.add_argument("--mode", type=str, default="single",
                        choices=["single", "multiple", "instructions"],
                        help="测试模式: single(单次), multiple(多次), instructions(不同指令)")
    parser.add_argument("--num-requests", type=int, default=50, help="多次测试的请求数量")

    # Mock-data parameters
    parser.add_argument("--state-dim", type=int, default=6, help="状态向量维度")
    parser.add_argument("--img-history-size", type=int, default=2, help="图像历史长度")
    parser.add_argument("--img-height", type=int, default=480, help="图像高度")
    parser.add_argument("--img-width", type=int, default=640, help="图像宽度")
    parser.add_argument("--num-cameras", type=int, default=3, help="相机数量 (与服务器配置一致)")

    # Instruction parameters
    parser.add_argument("--instruction", type=str,
                        default="pick up the object and place it in the box",
                        help="测试指令")
    parser.add_argument("--use-index", action="store_true", help="使用指令索引而非字符串")

    args = parser.parse_args()

    logger.info(f"正在连接到 {args.host}:{args.port} ...")
    try:
        client = Client(host=args.host, port=args.port)
    except Exception as e:
        logger.error(f"✗ 连接失败: {e}")
        return 1
    logger.info("✓ 连接成功!")

    # argparse `choices` guarantees args.mode is one of these keys.
    tests = {
        "single": test_single_request,
        "multiple": test_multiple_requests,
        "instructions": test_different_instructions,
    }
    try:
        passed = tests[args.mode](client, args)
    except KeyboardInterrupt:
        logger.info("\n测试被用户中断")
        return 130
    except Exception as e:
        # logger.exception appends the traceback through the logging handler.
        logger.exception(f"测试过程中发生错误: {e}")
        return 1

    return 0 if passed else 1


if __name__ == "__main__":
    # Propagate the result so shells/CI can detect failures.
    sys.exit(main())