SmolVLA_Tools/RoboTwin_Policy/binary_protocol.py
2026-03-03 16:24:02 +08:00

139 lines
4.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pickle
import numpy as np
import cv2
import logging
from typing import Any, List
from tools import measure_time, show_data_summary
# Pickle protocol used for the wire format (this is NOT pickle.HIGHEST_PROTOCOL).
# Protocol 4 (available since Python 3.4) already serializes large numpy arrays
# compactly and is much faster than JSON. Protocol 5 (Python 3.8+, PEP 574)
# would add out-of-band buffers, but those only help when a `buffer_callback`
# (dumps) / `buffers=` (loads) pair is used, which this module does not do.
# NOTE(review): the value is pinned to 4, presumably so peers on Python < 3.8
# can still unpickle payloads; confirm both ends before raising it to 5.
HIGHEST_PROTOCOL = 4
# Passed as `fix_imports` to pickle.dumps/loads: no Python 2 name remapping.
FIX_IMPORTS = False
# Keys whose values are images; these are JPEG-compressed before sending
# (encode_images_jpeg) and decoded after receiving (decode_images_jpeg).
IMAGE_KEYS: List[str] = [
    'observation.images.cam_high',
    'observation.images.cam_left_wrist',
    'observation.images.cam_right_wrist',
]
# JPEG quality (0-100) passed to cv2.imencode via IMWRITE_JPEG_QUALITY.
JPEG_QUALITY = 95
# NOTE: this configures the root logger at import time (basicConfig is a no-op
# if the root logger already has handlers).
logging.basicConfig(
    level=logging.INFO,
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S')
logger = logging.getLogger("SmolVLA_Server")
def jpeg_encode(img_chw: np.ndarray) -> bytes:
    """
    Encode a (1, 3, H, W) image array as JPEG bytes.

    Input: shape (1, 3, H, W); dtype uint8 in [0, 255], or floating point
    in [0, 1]. Only the first batch entry (index 0) is encoded.
    Output: JPEG bytes at quality JPEG_QUALITY.

    Channels are handed to OpenCV as-is (OpenCV treats them as BGR); the
    decoder returns them in that same order, so a round trip through
    jpeg_decode preserves channel order.

    Raises:
        RuntimeError: if cv2.imencode reports failure.
    """
    # (1, 3, H, W) -> (H, W, 3); the transpose is a strided view, so make it
    # contiguous before handing it to OpenCV.
    img = np.ascontiguousarray(img_chw[0].transpose(1, 2, 0))
    if np.issubdtype(img.dtype, np.floating):
        # Round instead of truncating: float32(k / 255) * 255 frequently lands
        # just below k, and a plain astype() would turn that into k - 1.
        img = np.rint(img * 255.0).clip(0, 255).astype(np.uint8)
    elif img.dtype != np.uint8:
        # Non-uint8 integer input is assumed to already be on the 0-255 scale
        # (scaling it by 255 would saturate almost every pixel).
        img = img.clip(0, 255).astype(np.uint8)
    success, buf = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, JPEG_QUALITY])
    if not success:
        raise RuntimeError("cv2.imencode JPEG failed")
    return buf.tobytes()
def jpeg_decode(jpeg_bytes) -> np.ndarray:
    """
    Decode JPEG data back into a (1, 3, H, W) float32 image in [0, 1].

    Args:
        jpeg_bytes: encoded JPEG as a bytes-like object (bytes, bytearray,
            memoryview) or a uint8 np.ndarray buffer.

    Returns:
        np.ndarray of shape (1, 3, H, W), dtype float32, values in [0, 1].
        Channels are in the order OpenCV's decoder produces (BGR), which is
        the order the encoder was given.

    Raises:
        RuntimeError: if the buffer is empty or cannot be decoded.
    """
    if isinstance(jpeg_bytes, np.ndarray):
        buf = jpeg_bytes
    else:
        buf = np.frombuffer(jpeg_bytes, dtype=np.uint8)
    if buf.size == 0:
        # cv2.imdecode asserts on an empty buffer; fail with a clear message instead.
        raise RuntimeError("cv2.imdecode JPEG failed: empty buffer")
    # IMREAD_COLOR always yields 3 channels, even for grayscale JPEGs.
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)  # (H, W, 3) uint8 BGR
    if img is None:
        # imdecode signals corrupt / non-image data by returning None rather
        # than raising; surface it instead of crashing on .transpose below.
        raise RuntimeError("cv2.imdecode JPEG failed")
    # (H, W, 3) -> (1, 3, H, W), uint8 -> float32 [0, 1]
    return img.transpose(2, 0, 1)[np.newaxis].astype(np.float32) / 255.0
def encode_images_jpeg(data: dict) -> dict:
    """
    Build a new dict in which the image fields of *data* are JPEG-compressed.

    An entry is replaced by jpeg_encode(value) when its key is listed in
    IMAGE_KEYS and its value is an np.ndarray; every other entry is copied
    over unchanged. Meant to be applied right before sending. *data* itself
    is not modified.
    """
    def _maybe_encode(key, value):
        # Only ndarray values under a known image key get compressed.
        is_image = key in IMAGE_KEYS and isinstance(value, np.ndarray)
        return jpeg_encode(value) if is_image else value

    return {key: _maybe_encode(key, value) for key, value in data.items()}
def _is_jpeg_buffer(value) -> bool:
    """Return True if *value* looks like an encoded JPEG buffer rather than an image array."""
    # Raw bytes as produced by jpeg_encode / encode_images_jpeg.
    if isinstance(value, (bytes, bytearray, memoryview)):
        return True
    # A uint8 buffer array, shape (N,) or (N, 1). A decoded / uncompressed
    # image such as (1, 3, H, W) must not be fed to cv2.imdecode.
    return (
        isinstance(value, np.ndarray)
        and value.dtype == np.uint8
        and (value.ndim == 1 or (value.ndim == 2 and value.shape[1] == 1))
    )


def decode_images_jpeg(data: dict) -> dict:
    """
    Return a copy of *data* with its JPEG-encoded image fields decoded.

    An entry is decoded with jpeg_decode when its key is in IMAGE_KEYS and its
    value is an encoded buffer (bytes-like, or a 1-D / single-column uint8
    np.ndarray). Every other value, including an image array that was sent
    uncompressed, is passed through unchanged, as are all non-image fields.
    Meant to be applied right after receiving. *data* itself is not modified.
    """
    return {
        key: jpeg_decode(value) if _is_jpeg_buffer(value) and key in IMAGE_KEYS else value
        for key, value in data.items()
    }
# @measure_time(logger)
def dict_to_binary(data: Any) -> bytes:
    """
    Serialize *data* (dicts / lists mixing numpy arrays, strings, numbers, ...)
    into a pickle byte string using the module's wire protocol settings.

    No compression is applied here, to keep serialization as fast as possible;
    compress images first (e.g. with encode_images_jpeg) if payload size matters.
    """
    # Protocol is passed positionally; fix_imports=FIX_IMPORTS disables
    # Python 2 name remapping.
    serialized = pickle.dumps(data, HIGHEST_PROTOCOL, fix_imports=FIX_IMPORTS)
    return serialized
# @measure_time(logger)
def binary_to_dict(data: bytes) -> Any:
    """
    Restore the object serialized by dict_to_binary.

    Args:
        data: pickle byte string produced by dict_to_binary.

    Returns:
        The deserialized object (normally the original dict).

    Security:
        pickle.loads can execute arbitrary code embedded in the payload.
        Only call this on bytes received from a trusted peer.
    """
    # SECURITY(review): this unpickles data received from another process or
    # host. If the listening socket can be reached by untrusted clients, this
    # is a remote-code-execution vector; restrict access or switch to a
    # restricted Unpickler / non-pickle format.
    return pickle.loads(data, fix_imports=FIX_IMPORTS)
# Simple self-test
if __name__ == "__main__":
    import time

    def _random_image() -> np.ndarray:
        """Random (1, 3, 480, 640) float32 image with values k / 255, k in [0, 255]."""
        # randint's upper bound is exclusive, so 256 is needed to include 255.
        return np.random.randint(0, 256, (1, 3, 480, 640), dtype=np.uint8).astype(np.float32) / 255.0

    # JPEG encode/decode round trip on a single image.
    print("=== JPEG encode/decode test ===")
    img_orig = _random_image()
    jpeg_bytes = jpeg_encode(img_orig)
    img_restored = jpeg_decode(jpeg_bytes)
    print(f"Original shape: {img_orig.shape}, JPEG size: {len(jpeg_bytes)/1024:.2f} KB")
    print(f"Restored shape: {img_restored.shape}, max abs diff: {np.abs(img_orig - img_restored).max():.4f}")

    test_data = {
        "instruction": "pick up the cup",
        "state": np.array([0.1, 0.2, 0.3], dtype=np.float32),
        **{key: _random_image() for key in IMAGE_KEYS},
    }

    print("\n=== Full pipeline test (encode_images_jpeg -> dict_to_binary -> binary_to_dict -> decode_images_jpeg) ===")
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    encoded = encode_images_jpeg(test_data)
    blob = dict_to_binary(encoded)
    t1 = time.perf_counter() - start

    start = time.perf_counter()
    restored_encoded = binary_to_dict(blob)
    restored = decode_images_jpeg(restored_encoded)
    t2 = time.perf_counter() - start

    print(f"Serialize+JPEG time: {t1*1000:.2f} ms, Size: {len(blob)/1024:.2f} KB")
    print(f"Deserialize+JPEG time: {t2*1000:.2f} ms")
    for k, original in test_data.items():
        if isinstance(original, np.ndarray):
            diff = np.abs(original - restored[k]).max()
            print(f"{k}: {original.shape} -> {restored[k].shape}, max abs diff: {diff:.4f}")