d-robotics-rdt/client.py
2025-11-02 12:25:30 +08:00

301 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
RDT 推理服务器测试客户端
使用模拟数据测试 get_actions 接口
"""
import numpy as np
import logging
import argparse
import time
from cloud_helper import Client
# Root logging for this CLI script: INFO and above, one timestamped line per record.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by the test helpers in this file.
logger = logging.getLogger(__name__)
def create_mock_observation(
    state_dim=6,
    img_history_size=2,
    img_height=480,
    img_width=640,
    num_cameras=3,
    camera_names=None,
):
    """Build a synthetic observation dict for exercising the inference server.

    Args:
        state_dim: Length of the robot state vector (number of joints).
        img_history_size: Number of stacked frames per camera.
        img_height: Image height in pixels.
        img_width: Image width in pixels.
        num_cameras: How many cameras, taken from the front of
            ``camera_names``, to include.
        camera_names: Optional sequence of camera names. Defaults to
            ``("cam_high", "cam_left_wrist", "cam_right_wrist")``.

    Returns:
        A dict with:
          * ``"state"``: float32 array of shape ``(state_dim,)`` drawn
            uniformly from [-180, 180) (degrees).
          * ``"images.<camera name>"`` for each selected camera: uint8 array
            of shape ``(img_history_size, img_height, img_width, 3)``
            containing a deterministic colour gradient.

    Raises:
        ValueError: If ``num_cameras`` is negative or larger than the number
            of available camera names. Silently slicing here would send
            fewer (or, for negatives, end-sliced) cameras than requested.
    """
    if camera_names is None:
        camera_names = ("cam_high", "cam_left_wrist", "cam_right_wrist")
    camera_names = list(camera_names)
    if not 0 <= num_cameras <= len(camera_names):
        raise ValueError(
            f"num_cameras must be between 0 and {len(camera_names)}, "
            f"got {num_cameras}"
        )

    observation = {}

    # 1. Mock robot state (e.g. joint angles), uniformly in [-180, 180) degrees.
    state = np.random.uniform(-180, 180, size=(state_dim,)).astype(np.float32)
    observation["state"] = state

    # 2. Mock camera images. Arrays are placed in the dict as-is; the original
    #    author notes msgpack_numpy serializes numpy arrays on the wire.
    for cam_idx, cam_name in enumerate(camera_names[:num_cameras]):
        frames = []
        for t in range(img_history_size):
            img = np.zeros((img_height, img_width, 3), dtype=np.uint8)
            # Vary the gradient per time step and per camera so frames differ.
            color_shift = (t * 50 + cam_idx * 100) % 255
            img[:, :, 0] = np.linspace(color_shift, 255, img_width, dtype=np.uint8)  # R: horizontal ramp
            img[:, :, 1] = np.linspace(0, 255 - color_shift, img_height, dtype=np.uint8)[:, None]  # G: vertical ramp
            img[:, :, 2] = 128  # B: constant
            frames.append(img)
        # Stack to (img_history_size, H, W, 3).
        observation[f"images.{cam_name}"] = np.stack(frames, axis=0)

    return observation
def create_test_batch(
    observation,
    instruction="pick up the object and place it in the box",
    use_instruction_index=False,
):
    """Wrap an observation and an instruction into one request payload.

    Args:
        observation: Observation dict, e.g. from ``create_mock_observation``.
        instruction: Instruction text, or an instruction index (an int, or a
            string of digits such as ``"3"``) when ``use_instruction_index``
            is set.
        use_instruction_index: If true, send an integer index instead of the
            text. Non-numeric instructions (such as the default sentence)
            map to index 0, matching the previous behaviour.

    Returns:
        ``{"observation": observation, "instruction": <str or int>}``.
    """
    if use_instruction_index:
        instruction = _instruction_index(instruction)
    return {"observation": observation, "instruction": instruction}


def _instruction_index(instruction, default=0):
    """Interpret ``instruction`` as a non-negative int index, else ``default``."""
    # bool is a subclass of int, but True/False are not meaningful indices.
    if isinstance(instruction, bool):
        return default
    if isinstance(instruction, (int, np.integer)):
        index = int(instruction)  # plain int so the payload serializes simply
    elif isinstance(instruction, str):
        try:
            index = int(instruction)
        except ValueError:
            # Free-text instruction: keep the historical fallback of index 0.
            return default
    else:
        return default
    return index if index >= 0 else default
def test_single_request(client, args):
    """Send one ``get_actions`` request built from mock data and log the reply.

    Args:
        client: Connected client; only ``client.call_endpoint`` is used.
        args: Parsed CLI namespace providing ``state_dim``,
            ``img_history_size``, ``img_height``, ``img_width``,
            ``num_cameras``, ``instruction`` and ``use_index``.

    Returns:
        True if ``call_endpoint`` returned; False if it raised. Errors while
        merely summarizing the reply are no longer reported as request
        failures.
    """
    logger.info("=" * 60)
    logger.info("开始单次请求测试")
    logger.info("=" * 60)

    # Build mock observation data.
    observation = create_mock_observation(
        state_dim=args.state_dim,
        img_history_size=args.img_history_size,
        img_height=args.img_height,
        img_width=args.img_width,
        num_cameras=args.num_cameras,
    )
    logger.info("模拟观测数据:")
    logger.info(f" - state shape: {observation['state'].shape}")
    for key, value in observation.items():
        if key.startswith("images."):
            logger.info(f" - {key} shape: {value.shape}")

    # Build the request payload.
    batch = create_test_batch(
        observation,
        instruction=args.instruction,
        use_instruction_index=args.use_index,
    )

    logger.info(f"发送指令: {batch['instruction']}")
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # (clock adjustments) and is coarse on some platforms.
    start_time = time.perf_counter()
    try:
        action = client.call_endpoint("get_actions", batch)
    except Exception as e:
        logger.error(f"✗ 请求失败: {e}")
        return False
    elapsed_time = time.perf_counter() - start_time

    logger.info(f"✓ 请求成功! 耗时: {elapsed_time*1000:.2f} ms")
    _log_action_summary(action)
    return True


def _log_action_summary(action):
    """Log shape, dtype, value range and a small preview of an action reply."""
    # NOTE(review): assumes the server replies with numeric array-like data;
    # np.asarray also accepts plain lists.
    action = np.asarray(action)
    logger.info(f" - action shape: {action.shape}")
    logger.info(f" - action dtype: {action.dtype}")
    if action.size == 0:
        # min()/max() raise on empty arrays; nothing further to preview.
        return
    logger.info(f" - action range: [{action.min():.3f}, {action.max():.3f}]")
    logger.info(" - action preview (前3个时间步的前3个维度):")
    # Treat axis 0 as time. A 1-D reply is viewed as a single step so that
    # indexing [t, :dims] cannot fail; ndim >= 2 is left unchanged.
    rows = np.atleast_2d(action)
    preview_steps = min(3, rows.shape[0])
    preview_dims = min(3, rows.shape[1])
    for t in range(preview_steps):
        logger.info(f" t={t}: {rows[t, :preview_dims]}")
def test_multiple_requests(client, args):
    """Send ``args.num_requests`` identical requests sequentially and log stats.

    One mock observation/batch is built up front and reused for every
    request, so the loop measures request latency rather than data creation.

    Args:
        client: Connected client; only ``client.call_endpoint`` is used.
        args: Parsed CLI namespace providing ``num_requests`` plus the
            mock-data and instruction options.

    Returns:
        None. Results are reported via logging only.
    """
    logger.info("=" * 60)
    logger.info(f"开始连续请求测试 (共 {args.num_requests} 次)")
    logger.info("=" * 60)
    if args.num_requests <= 0:
        # Guards the success-rate division below against ZeroDivisionError.
        logger.warning("请求数量必须为正整数, 跳过连续请求测试")
        return

    # Build the payload once, outside the timed loop.
    observation = create_mock_observation(
        state_dim=args.state_dim,
        img_history_size=args.img_history_size,
        img_height=args.img_height,
        img_width=args.img_width,
        num_cameras=args.num_cameras,
    )
    batch = create_test_batch(
        observation,
        instruction=args.instruction,
        use_instruction_index=args.use_index,
    )

    latencies = []  # seconds, successful requests only
    for request_no in range(1, args.num_requests + 1):
        # perf_counter: monotonic, high-resolution clock for durations.
        start_time = time.perf_counter()
        try:
            client.call_endpoint("get_actions", batch)
        except Exception as e:
            logger.error(f"第 {request_no} 次请求失败: {e}")
            continue
        latencies.append(time.perf_counter() - start_time)
        if request_no % 10 == 0:
            logger.info(f"已完成 {request_no}/{args.num_requests} 次请求")

    success_count = len(latencies)

    # Summary statistics.
    logger.info("=" * 60)
    logger.info("性能统计:")
    logger.info(f" - 总请求数: {args.num_requests}")
    logger.info(f" - 成功数: {success_count}")
    logger.info(f" - 失败数: {args.num_requests - success_count}")
    logger.info(f" - 成功率: {success_count/args.num_requests*100:.1f}%")
    if not latencies:
        return

    latency_arr = np.array(latencies)
    logger.info(f" - 平均延迟: {np.mean(latency_arr)*1000:.2f} ms")
    logger.info(f" - 中位数延迟: {np.median(latency_arr)*1000:.2f} ms")
    logger.info(f" - 最小延迟: {np.min(latency_arr)*1000:.2f} ms")
    logger.info(f" - 最大延迟: {np.max(latency_arr)*1000:.2f} ms")
    # Sequential throughput = successes / total time spent in successful calls.
    total_time = sum(latencies)
    if total_time > 0:
        logger.info(f" - 吞吐量: {success_count/total_time:.2f} requests/s")
def test_different_instructions(client, args):
    """Send one request per built-in instruction string and log each outcome.

    The same mock observation is reused for every instruction. ``args.use_index``
    is intentionally ignored because this mode exercises free-text instructions.

    Args:
        client: Connected client; only ``client.call_endpoint`` is used.
        args: Parsed CLI namespace providing the mock-data options.

    Returns:
        None. Results are reported via logging only.
    """
    logger.info("=" * 60)
    logger.info("测试不同指令")
    logger.info("=" * 60)

    instructions = (
        "pick up the red cube",
        "place the object on the table",
        "move to the left",
        "grasp the bottle",
        "open the drawer",
    )
    observation = create_mock_observation(
        state_dim=args.state_dim,
        img_history_size=args.img_history_size,
        img_height=args.img_height,
        img_width=args.img_width,
        num_cameras=args.num_cameras,
    )

    total = len(instructions)
    for idx, instruction in enumerate(instructions, start=1):
        logger.info(f"\n测试指令 {idx}/{total}: '{instruction}'")
        batch = create_test_batch(observation, instruction=instruction)
        # perf_counter: monotonic, high-resolution clock for durations.
        start_time = time.perf_counter()
        try:
            action = client.call_endpoint("get_actions", batch)
        except Exception as e:
            logger.error(f" ✗ 失败: {e}")
            continue
        elapsed_time = time.perf_counter() - start_time
        # np.shape works for ndarrays and plain sequences alike, so an
        # unexpected reply type is not misreported as a failed request.
        logger.info(f" ✓ 成功 | 耗时: {elapsed_time*1000:.2f} ms | action shape: {np.shape(action)}")
def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def build_parser():
    """Return the command-line parser for this test client."""
    parser = argparse.ArgumentParser(description="RDT 推理服务器测试客户端")
    # Connection options.
    parser.add_argument("--host", type=str, default="localhost", help="服务器地址")
    parser.add_argument("--port", type=int, default=8001, help="服务器端口")
    # Test mode.
    parser.add_argument("--mode", type=str, default="single",
                        choices=["single", "multiple", "instructions"],
                        help="测试模式: single(单次), multiple(多次), instructions(不同指令)")
    parser.add_argument("--num-requests", type=_positive_int, default=50,
                        help="多次测试的请求数量")
    # Mock-data options; all must be >= 1 so the generated arrays are valid.
    parser.add_argument("--state-dim", type=_positive_int, default=6, help="状态向量维度")
    parser.add_argument("--img-history-size", type=_positive_int, default=2, help="图像历史长度")
    parser.add_argument("--img-height", type=_positive_int, default=480, help="图像高度")
    parser.add_argument("--img-width", type=_positive_int, default=640, help="图像宽度")
    parser.add_argument("--num-cameras", type=_positive_int, default=3, help="相机数量 (与服务器配置一致)")
    # Instruction options.
    parser.add_argument("--instruction", type=str,
                        default="pick up the object and place it in the box",
                        help="测试指令")
    parser.add_argument("--use-index", action="store_true",
                        help="使用指令索引而非字符串")
    return parser


def main():
    """Parse CLI args, connect to the server and run the selected test mode.

    Returns:
        Process exit status:
          * 0 on success.
          * 1 if connecting failed, a single-mode request failed, or the test
            raised.
          * 130 if interrupted with Ctrl-C.
    """
    args = build_parser().parse_args()

    # Connect to the server.
    logger.info(f"正在连接到 {args.host}:{args.port} ...")
    try:
        client = Client(host=args.host, port=args.port)
    except Exception as e:
        logger.error(f"✗ 连接失败: {e}")
        return 1
    logger.info("✓ 连接成功!")

    # Dispatch on --mode (argparse `choices` guarantees the key exists).
    runners = {
        "single": test_single_request,
        "multiple": test_multiple_requests,
        "instructions": test_different_instructions,
    }
    try:
        result = runners[args.mode](client, args)
    except KeyboardInterrupt:
        logger.info("\n测试被用户中断")
        return 130
    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"测试过程中发生错误: {e}")
        return 1
    # Only test_single_request reports a boolean outcome; other modes return None.
    return 1 if result is False else 0


if __name__ == "__main__":
    raise SystemExit(main())