301 lines
9.8 KiB
Python
301 lines
9.8 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
RDT 推理服务器测试客户端
|
||
使用模拟数据测试 get_actions 接口
|
||
"""
|
||
|
||
import numpy as np
|
||
import logging
|
||
import argparse
|
||
import time
|
||
from cloud_helper import Client
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format="%(asctime)s - %(levelname)s - %(message)s"
|
||
)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def create_mock_observation(
|
||
state_dim=6,
|
||
img_history_size=2,
|
||
img_height=480,
|
||
img_width=640,
|
||
num_cameras=3
|
||
):
|
||
"""创建模拟的观测数据
|
||
|
||
Args:
|
||
state_dim: 状态向量维度(关节数量)
|
||
img_history_size: 图像历史长度
|
||
img_height: 图像高度
|
||
img_width: 图像宽度
|
||
num_cameras: 相机数量
|
||
|
||
Returns:
|
||
observation: 包含状态和图像的观测字典
|
||
"""
|
||
observation = {}
|
||
|
||
# 1. 创建模拟的机器人状态(关节角度等)
|
||
# 范围在 [-180, 180] 度之间
|
||
state = np.random.uniform(-180, 180, size=(state_dim,)).astype(np.float32)
|
||
observation["state"] = state
|
||
|
||
# 2. 创建模拟的相机图像
|
||
# 注意:msgpack_numpy 会自动处理 numpy 数组的序列化
|
||
camera_names = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
|
||
|
||
for i, cam_name in enumerate(camera_names[:num_cameras]):
|
||
# 创建彩色渐变图像作为模拟数据
|
||
images = []
|
||
for t in range(img_history_size):
|
||
# 为每个时间步创建不同颜色的图像
|
||
img = np.zeros((img_height, img_width, 3), dtype=np.uint8)
|
||
|
||
# 创建彩色渐变效果
|
||
color_shift = (t * 50 + i * 100) % 255
|
||
img[:, :, 0] = np.linspace(color_shift, 255, img_width, dtype=np.uint8) # R
|
||
img[:, :, 1] = np.linspace(0, 255 - color_shift, img_height, dtype=np.uint8)[:, None] # G
|
||
img[:, :, 2] = 128 # B
|
||
|
||
images.append(img)
|
||
|
||
# 堆叠为 (IMG_HISTORY_SIZE, H, W, 3) 格式
|
||
observation[f"images.{cam_name}"] = np.stack(images, axis=0)
|
||
|
||
return observation
|
||
|
||
|
||
def create_test_batch(
|
||
observation,
|
||
instruction="pick up the object and place it in the box",
|
||
use_instruction_index=False
|
||
):
|
||
"""创建完整的测试批次数据
|
||
|
||
Args:
|
||
observation: 观测数据字典
|
||
instruction: 指令字符串或索引
|
||
use_instruction_index: 是否使用指令索引而非字符串
|
||
|
||
Returns:
|
||
batch: 完整的请求数据
|
||
"""
|
||
batch = {
|
||
"observation": observation,
|
||
"instruction": 0 if use_instruction_index else instruction
|
||
}
|
||
return batch
|
||
|
||
|
||
def test_single_request(client, args):
|
||
"""测试单次请求"""
|
||
logger.info("=" * 60)
|
||
logger.info("开始单次请求测试")
|
||
logger.info("=" * 60)
|
||
|
||
# 创建模拟数据
|
||
observation = create_mock_observation(
|
||
state_dim=args.state_dim,
|
||
img_history_size=args.img_history_size,
|
||
img_height=args.img_height,
|
||
img_width=args.img_width,
|
||
num_cameras=args.num_cameras
|
||
)
|
||
|
||
logger.info(f"模拟观测数据:")
|
||
logger.info(f" - state shape: {observation['state'].shape}")
|
||
for key in observation.keys():
|
||
if key.startswith("images."):
|
||
logger.info(f" - {key} shape: {observation[key].shape}")
|
||
|
||
# 创建请求批次
|
||
batch = create_test_batch(
|
||
observation,
|
||
instruction=args.instruction,
|
||
use_instruction_index=args.use_index
|
||
)
|
||
|
||
# 发送请求
|
||
logger.info(f"发送指令: {batch['instruction']}")
|
||
start_time = time.time()
|
||
|
||
try:
|
||
action = client.call_endpoint("get_actions", batch)
|
||
elapsed_time = time.time() - start_time
|
||
|
||
logger.info(f"✓ 请求成功! 耗时: {elapsed_time*1000:.2f} ms")
|
||
logger.info(f" - action shape: {action.shape}")
|
||
logger.info(f" - action dtype: {action.dtype}")
|
||
logger.info(f" - action range: [{action.min():.3f}, {action.max():.3f}]")
|
||
logger.info(f" - action preview (前3个时间步的前3个维度):")
|
||
preview_steps = min(3, action.shape[0])
|
||
preview_dims = min(3, action.shape[1])
|
||
for t in range(preview_steps):
|
||
logger.info(f" t={t}: {action[t, :preview_dims]}")
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"✗ 请求失败: {e}")
|
||
return False
|
||
|
||
|
||
def test_multiple_requests(client, args):
|
||
"""测试多次连续请求(性能测试)"""
|
||
logger.info("=" * 60)
|
||
logger.info(f"开始连续请求测试 (共 {args.num_requests} 次)")
|
||
logger.info("=" * 60)
|
||
|
||
# 预先创建观测数据
|
||
observation = create_mock_observation(
|
||
state_dim=args.state_dim,
|
||
img_history_size=args.img_history_size,
|
||
img_height=args.img_height,
|
||
img_width=args.img_width,
|
||
num_cameras=args.num_cameras
|
||
)
|
||
|
||
batch = create_test_batch(
|
||
observation,
|
||
instruction=args.instruction,
|
||
use_instruction_index=args.use_index
|
||
)
|
||
|
||
success_count = 0
|
||
total_time = 0
|
||
latencies = []
|
||
|
||
for i in range(args.num_requests):
|
||
try:
|
||
start_time = time.time()
|
||
action = client.call_endpoint("get_actions", batch)
|
||
elapsed_time = time.time() - start_time
|
||
|
||
success_count += 1
|
||
total_time += elapsed_time
|
||
latencies.append(elapsed_time)
|
||
|
||
if (i + 1) % 10 == 0:
|
||
logger.info(f"已完成 {i + 1}/{args.num_requests} 次请求")
|
||
|
||
except Exception as e:
|
||
logger.error(f"第 {i+1} 次请求失败: {e}")
|
||
|
||
# 统计结果
|
||
logger.info("=" * 60)
|
||
logger.info("性能统计:")
|
||
logger.info(f" - 总请求数: {args.num_requests}")
|
||
logger.info(f" - 成功数: {success_count}")
|
||
logger.info(f" - 失败数: {args.num_requests - success_count}")
|
||
logger.info(f" - 成功率: {success_count/args.num_requests*100:.1f}%")
|
||
|
||
if latencies:
|
||
latencies = np.array(latencies)
|
||
logger.info(f" - 平均延迟: {np.mean(latencies)*1000:.2f} ms")
|
||
logger.info(f" - 中位数延迟: {np.median(latencies)*1000:.2f} ms")
|
||
logger.info(f" - 最小延迟: {np.min(latencies)*1000:.2f} ms")
|
||
logger.info(f" - 最大延迟: {np.max(latencies)*1000:.2f} ms")
|
||
logger.info(f" - 吞吐量: {success_count/total_time:.2f} requests/s")
|
||
|
||
|
||
def test_different_instructions(client, args):
|
||
"""测试不同的指令"""
|
||
logger.info("=" * 60)
|
||
logger.info("测试不同指令")
|
||
logger.info("=" * 60)
|
||
|
||
instructions = [
|
||
"pick up the red cube",
|
||
"place the object on the table",
|
||
"move to the left",
|
||
"grasp the bottle",
|
||
"open the drawer"
|
||
]
|
||
|
||
observation = create_mock_observation(
|
||
state_dim=args.state_dim,
|
||
img_history_size=args.img_history_size,
|
||
img_height=args.img_height,
|
||
img_width=args.img_width,
|
||
num_cameras=args.num_cameras
|
||
)
|
||
|
||
for i, instruction in enumerate(instructions):
|
||
logger.info(f"\n测试指令 {i+1}/{len(instructions)}: '{instruction}'")
|
||
batch = create_test_batch(observation, instruction=instruction)
|
||
|
||
try:
|
||
start_time = time.time()
|
||
action = client.call_endpoint("get_actions", batch)
|
||
elapsed_time = time.time() - start_time
|
||
|
||
logger.info(f" ✓ 成功 | 耗时: {elapsed_time*1000:.2f} ms | action shape: {action.shape}")
|
||
|
||
except Exception as e:
|
||
logger.error(f" ✗ 失败: {e}")
|
||
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description="RDT 推理服务器测试客户端")
|
||
|
||
# 连接参数
|
||
parser.add_argument("--host", type=str, default="localhost", help="服务器地址")
|
||
parser.add_argument("--port", type=int, default=8001, help="服务器端口")
|
||
|
||
# 测试模式
|
||
parser.add_argument("--mode", type=str, default="single",
|
||
choices=["single", "multiple", "instructions"],
|
||
help="测试模式: single(单次), multiple(多次), instructions(不同指令)")
|
||
parser.add_argument("--num-requests", type=int, default=50,
|
||
help="多次测试的请求数量")
|
||
|
||
# 数据参数
|
||
parser.add_argument("--state-dim", type=int, default=6, help="状态向量维度")
|
||
parser.add_argument("--img-history-size", type=int, default=2, help="图像历史长度")
|
||
parser.add_argument("--img-height", type=int, default=480, help="图像高度")
|
||
parser.add_argument("--img-width", type=int, default=640, help="图像宽度")
|
||
parser.add_argument("--num-cameras", type=int, default=3, help="相机数量 (与服务器配置一致)")
|
||
|
||
# 指令参数
|
||
parser.add_argument("--instruction", type=str,
|
||
default="pick up the object and place it in the box",
|
||
help="测试指令")
|
||
parser.add_argument("--use-index", action="store_true",
|
||
help="使用指令索引而非字符串")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 连接服务器
|
||
logger.info(f"正在连接到 {args.host}:{args.port} ...")
|
||
try:
|
||
client = Client(host=args.host, port=args.port)
|
||
logger.info("✓ 连接成功!")
|
||
except Exception as e:
|
||
logger.error(f"✗ 连接失败: {e}")
|
||
return
|
||
|
||
# 根据模式运行测试
|
||
try:
|
||
if args.mode == "single":
|
||
test_single_request(client, args)
|
||
elif args.mode == "multiple":
|
||
test_multiple_requests(client, args)
|
||
elif args.mode == "instructions":
|
||
test_different_instructions(client, args)
|
||
|
||
except KeyboardInterrupt:
|
||
logger.info("\n测试被用户中断")
|
||
except Exception as e:
|
||
logger.error(f"测试过程中发生错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|
||
|