import pickle
import numpy as np
import cv2
import logging
from typing import Any, List
from tools import measure_time, show_data_summary

# Pickle protocol used by dict_to_binary().
# NOTE: despite the constant's name this is pinned to protocol 4 (Python 3.4+), not
# pickle.HIGHEST_PROTOCOL. Protocol 5 (Python 3.8+, PEP 574) adds out-of-band buffers
# that can help with large numpy arrays, but it requires Python >= 3.8 on the loading
# side; pinning to 4 presumably keeps older peers able to load the payload — confirm
# before bumping. Either protocol is much faster than JSON for numpy payloads.
HIGHEST_PROTOCOL = 4
# fix_imports only maps Python 2 module names for protocols 0-2; with protocol 4 it has
# no effect on dumps(), and on loads() it only matters for old protocol 0-2 pickles.
FIX_IMPORTS = False

# Dict keys whose values are images; these are JPEG-compressed before sending and
# decoded after receiving.
IMAGE_KEYS: List[str] = [
    'observation.images.cam_high',
    'observation.images.cam_left_wrist',
    'observation.images.cam_right_wrist',
]
# JPEG quality passed to cv2.imencode (0-100, higher = better quality, larger payload).
JPEG_QUALITY = 95

logging.basicConfig(
    level=logging.INFO,
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S')
logger = logging.getLogger("SmolVLA_Server")


def jpeg_encode(img_chw: np.ndarray) -> bytes:
    """Encode a single (1, 3, H, W) image array as JPEG bytes.

    Args:
        img_chw: Array of shape (1, 3, H, W). Either uint8 in [0, 255], or a
            float array in [0, 1] (scaled by 255, rounded, and clipped).

    Returns:
        The JPEG-encoded image (quality ``JPEG_QUALITY``).

    Raises:
        ValueError: If ``img_chw`` is not shaped (1, 3, H, W); a larger batch
            would otherwise be silently truncated to its first frame.
        RuntimeError: If OpenCV fails to encode the image.

    Note:
        OpenCV interprets the channel axis as BGR. ``jpeg_decode`` returns the
        channels in that same order, so an encode/decode round trip keeps the
        caller's channel order.
    """
    if img_chw.ndim != 4 or img_chw.shape[0] != 1 or img_chw.shape[1] != 3:
        raise ValueError(
            f"expected image of shape (1, 3, H, W), got {img_chw.shape}")
    # (1, 3, H, W) -> (H, W, 3); make it contiguous so OpenCV gets a plain HWC buffer.
    img = np.ascontiguousarray(img_chw[0].transpose(1, 2, 0))
    if img.dtype != np.uint8:
        # Round instead of truncating: in float32, (k / 255) * 255 is often
        # k - epsilon, which astype() alone would floor to k - 1.
        img = np.rint(img * 255.0).clip(0, 255).astype(np.uint8)
    success, buf = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, JPEG_QUALITY])
    if not success:
        raise RuntimeError("cv2.imencode JPEG failed")
    return buf.tobytes()


def jpeg_decode(jpeg_bytes) -> np.ndarray:
    """Decode JPEG data back into a (1, 3, H, W) float32 image in [0, 1].

    Args:
        jpeg_bytes: JPEG data as ``bytes``/``bytearray``/``memoryview`` or as a
            uint8 ``np.ndarray`` buffer.

    Returns:
        Array of shape (1, 3, H, W), dtype float32, values in [0, 1]. Channels
        are in OpenCV's order, i.e. the order ``jpeg_encode`` received them in.

    Raises:
        RuntimeError: If OpenCV cannot decode the data.
    """
    if isinstance(jpeg_bytes, np.ndarray):
        buf = jpeg_bytes
    else:
        buf = np.frombuffer(jpeg_bytes, dtype=np.uint8)
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)  # (H, W, 3) uint8
    if img is None:
        # cv2.imdecode reports undecodable input by returning None, not by raising.
        raise RuntimeError("cv2.imdecode JPEG failed")
    # (H, W, 3) -> (1, 3, H, W), uint8 -> float32 in [0, 1]
    return img.transpose(2, 0, 1)[np.newaxis].astype(np.float32) / 255.0


def encode_images_jpeg(data: dict) -> dict:
    """Return a copy of *data* with image arrays replaced by JPEG bytes (pre-send compression).

    Only values under a key in ``IMAGE_KEYS`` that are ``np.ndarray`` are encoded;
    every other entry is copied through unchanged. The input dict is not modified.
    """
    out = {}
    for k, v in data.items():
        if k in IMAGE_KEYS and isinstance(v, np.ndarray):
            out[k] = jpeg_encode(v)
        else:
            out[k] = v
    return out


def _is_jpeg_payload(value: Any) -> bool:
    """Return True if *value* is encoded JPEG data rather than a raw image array."""
    if isinstance(value, (bytes, bytearray, memoryview)):
        return True
    # Only a flat uint8 buffer can hold JPEG data; a raw (1, 3, H, W) image that was
    # sent without compression must pass through instead of crashing cv2.imdecode.
    return isinstance(value, np.ndarray) and value.dtype == np.uint8 and value.ndim == 1


def decode_images_jpeg(data: dict) -> dict:
    """Return a copy of *data* with JPEG payloads decoded back to image arrays (post-receive).

    Values under a key in ``IMAGE_KEYS`` that hold JPEG data (bytes-like, or a
    1-D uint8 array) are decoded with ``jpeg_decode``; all other entries,
    including image keys that already hold raw arrays, are copied through
    unchanged. The input dict is not modified.
    """
    out = {}
    for k, v in data.items():
        if k in IMAGE_KEYS and _is_jpeg_payload(v):
            out[k] = jpeg_decode(v)
        else:
            out[k] = v
    return out


# @measure_time(logger)
def dict_to_binary(data: Any) -> bytes:
    """Serialize a mixed object (numpy arrays, strings, dicts, ...) to bytes.

    No compression is applied; pickle is used for speed. Uses protocol
    ``HIGHEST_PROTOCOL`` (pinned to 4, see the constant's comment).
    """
    return pickle.dumps(data, protocol=HIGHEST_PROTOCOL, fix_imports=FIX_IMPORTS)


# @measure_time(logger)
def binary_to_dict(data: bytes) -> Any:
    """Restore the object serialized by ``dict_to_binary``.

    SECURITY: ``pickle.loads`` can execute arbitrary code while loading. Only
    call this on data from a trusted peer (e.g. an authenticated, private
    connection); never on bytes received from untrusted sources.
    """
    return pickle.loads(data, fix_imports=FIX_IMPORTS)


# Simple self-test
if __name__ == "__main__":
    import time

    def _random_image() -> np.ndarray:
        # randint's upper bound is exclusive, so 256 is needed to include 255.
        pixels = np.random.randint(0, 256, (1, 3, 480, 640), dtype=np.uint8)
        return pixels.astype(np.float32) / 255.0

    # JPEG encode/decode round trip
    print("=== JPEG encode/decode test ===")
    img_orig = _random_image()
    jpeg_bytes = jpeg_encode(img_orig)
    img_restored = jpeg_decode(jpeg_bytes)
    print(f"Original shape: {img_orig.shape}, JPEG size: {len(jpeg_bytes)/1024:.2f} KB")
    print(f"Restored shape: {img_restored.shape}, max abs diff: {np.abs(img_orig - img_restored).max():.4f}")

    test_data = {
        "instruction": "pick up the cup",
        "state": np.array([0.1, 0.2, 0.3], dtype=np.float32),
        "observation.images.cam_high": _random_image(),
        "observation.images.cam_left_wrist": _random_image(),
        "observation.images.cam_right_wrist": _random_image(),
    }

    print("\n=== Full pipeline test (encode_images_jpeg -> dict_to_binary -> binary_to_dict -> decode_images_jpeg) ===")
    start = time.perf_counter()
    encoded = encode_images_jpeg(test_data)
    blob = dict_to_binary(encoded)
    t1 = time.perf_counter() - start

    start = time.perf_counter()
    restored_encoded = binary_to_dict(blob)
    restored = decode_images_jpeg(restored_encoded)
    t2 = time.perf_counter() - start

    print(f"Serialize+JPEG time: {t1*1000:.2f} ms, Size: {len(blob)/1024:.2f} KB")
    print(f"Deserialize+JPEG time: {t2*1000:.2f} ms")
    for k, v in test_data.items():
        if isinstance(v, np.ndarray):
            print(f"{k}: {v.shape} -> {restored[k].shape}")