139 lines
4.9 KiB
Python
139 lines
4.9 KiB
Python
import pickle
|
||
import numpy as np
|
||
import cv2
|
||
import logging
|
||
from typing import Any, List
|
||
from tools import measure_time, show_data_summary
|
||
# Pickle protocol used by dict_to_binary. Pickle is used instead of JSON
# because it serializes numpy arrays natively and much faster.
#
# Despite its name, this constant is pinned to protocol 4 (available since
# Python 3.4) rather than pickle.HIGHEST_PROTOCOL. Presumably this keeps
# payloads readable by peers on older interpreters — confirm before raising it.
# Protocol 5 (Python 3.8+) would add out-of-band buffers for large numpy
# arrays, but those only help when a buffer_callback is supplied to
# pickle.dumps, which dict_to_binary does not do.
HIGHEST_PROTOCOL = 4

# Passed as fix_imports to pickle.dumps / pickle.loads. It only affects
# Python 2 compatibility (module-name mapping for protocols 0-2), so with
# protocol 4 it has no practical effect.
FIX_IMPORTS = False

# Observation keys that carry camera images. Only values stored under these
# keys are JPEG-encoded by encode_images_jpeg and decoded by
# decode_images_jpeg; every other key passes through untouched.
IMAGE_KEYS: List[str] = ['observation.images.cam_high',
                         'observation.images.cam_left_wrist',
                         'observation.images.cam_right_wrist']

# Quality setting handed to cv2.imencode via cv2.IMWRITE_JPEG_QUALITY.
JPEG_QUALITY = 95

# Configure the root logger on import so every record prints as
# "[name] [HH:MM:SS.mmm] [LEVEL] message" at INFO level and above.
logging.basicConfig(
    format='[%(name)s] [%(asctime)s.%(msecs)03d] [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
    level=logging.INFO,
)

# Named logger for this module's messages.
logger = logging.getLogger("SmolVLA_Server")


def jpeg_encode(img_chw: np.ndarray) -> bytes:
    """Encode a single (1, 3, H, W) image array as JPEG bytes.

    Args:
        img_chw: Array of shape (1, 3, H, W). ``uint8`` input is encoded
            as-is (values 0-255). Any other dtype is treated as floats in
            [0, 1]; those values are scaled by 255, rounded to the nearest
            integer and clipped to [0, 255].

    Returns:
        The JPEG file contents as ``bytes``, encoded with quality
        ``JPEG_QUALITY``.

    Raises:
        ValueError: If ``img_chw`` is not shaped (1, 3, H, W).
        RuntimeError: If OpenCV fails to encode the image.

    Note:
        The channel order is handed to OpenCV unchanged. OpenCV interprets it
        as BGR, and ``jpeg_decode`` returns that same order, so the round
        trip preserves channels even for RGB input.
    """
    if img_chw.ndim != 4 or img_chw.shape[0] != 1 or img_chw.shape[1] != 3:
        # Checked explicitly: with a batch > 1, only img_chw[0] would be
        # encoded and the remaining images would be silently dropped.
        raise ValueError(
            f"jpeg_encode expects shape (1, 3, H, W), got {img_chw.shape}"
        )

    # (1, 3, H, W) -> (H, W, 3)
    img = img_chw[0].transpose(1, 2, 0)

    if img.dtype != np.uint8:
        # Round instead of truncating: float32 round-off can leave k/255*255
        # just below k (e.g. 36.99998), which astype() alone floors to k-1.
        img = np.rint(img * 255.0).clip(0, 255).astype(np.uint8)

    # The transpose produces a strided view; give OpenCV a C-contiguous buffer.
    img = np.ascontiguousarray(img)

    success, buf = cv2.imencode('.jpg', img, [cv2.IMWRITE_JPEG_QUALITY, JPEG_QUALITY])
    if not success:
        raise RuntimeError("cv2.imencode JPEG failed")
    return buf.tobytes()


def jpeg_decode(jpeg_bytes) -> np.ndarray:
    """Decode a JPEG payload into a (1, 3, H, W) float32 image in [0, 1].

    Args:
        jpeg_bytes: The encoded JPEG, either as a bytes-like object
            (``bytes`` / ``bytearray`` / ``memoryview``) or as an
            ``np.uint8`` array holding the file bytes.

    Returns:
        Array of shape (1, 3, H, W), dtype float32, values in [0, 1].
        Channels are returned in the order OpenCV decodes them (BGR), which is
        the same order that was passed to ``jpeg_encode``.

    Raises:
        RuntimeError: If the payload is empty or OpenCV cannot decode it.
    """
    if isinstance(jpeg_bytes, np.ndarray):
        buf = jpeg_bytes
    else:
        buf = np.frombuffer(jpeg_bytes, dtype=np.uint8)

    # OpenCV rejects an empty buffer with its own cv2.error rather than
    # returning None; check up front so every failure raises RuntimeError.
    if buf.size == 0:
        raise RuntimeError("cv2.imdecode JPEG failed: empty buffer")

    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)  # (H, W, 3) uint8, BGR
    if img is None:
        # imdecode signals corrupt or unsupported data by returning None.
        raise RuntimeError("cv2.imdecode JPEG failed")

    # (H, W, 3) -> (1, 3, H, W), uint8 -> float32 in [0, 1]
    return img.transpose(2, 0, 1)[np.newaxis].astype(np.float32) / 255.0


def encode_images_jpeg(data: dict) -> dict:
    """Return a new dict in which camera images are replaced by JPEG bytes.

    Intended to run just before sending. An entry is passed through
    ``jpeg_encode`` only when its key is listed in ``IMAGE_KEYS`` and its
    value is an ``np.ndarray``. Every other entry is copied unchanged. The
    input dict itself is not modified, and key order is preserved.
    """
    return {
        key: (jpeg_encode(value)
              if key in IMAGE_KEYS and isinstance(value, np.ndarray)
              else value)
        for key, value in data.items()
    }


def _is_jpeg_payload(value: Any) -> bool:
    """Return True if *value* can be an encoded JPEG buffer, not a decoded image."""
    if isinstance(value, (bytes, bytearray, memoryview)):
        return True
    # Some transports hand byte strings back as uint8 arrays, either flat
    # (N,) or (N, 1) like the array cv2.imencode returns. Float arrays and
    # arrays with 3+ dimensions are already-decoded images that cv2.imdecode
    # cannot read, so they are passed through instead of crashing the decoder.
    return (
        isinstance(value, np.ndarray)
        and value.dtype == np.uint8
        and value.ndim <= 2
    )


def decode_images_jpeg(data: dict) -> dict:
    """Return a new dict in which JPEG-encoded camera images are decoded.

    Intended to run right after receiving. For keys listed in ``IMAGE_KEYS``,
    values that look like JPEG payloads (bytes-like, or a uint8 array with
    at most 2 dimensions) are decoded with ``jpeg_decode`` into (1, 3, H, W)
    float32 arrays. Everything else, including image arrays that were sent
    uncompressed, is copied unchanged. The input dict is not modified.
    """
    out = {}
    for key, value in data.items():
        # Cheap type check first; only JPEG-looking values under image keys
        # are decoded.
        if _is_jpeg_payload(value) and key in IMAGE_KEYS:
            out[key] = jpeg_decode(value)
        else:
            out[key] = value
    return out


# @measure_time(logger)
def dict_to_binary(data: Any) -> bytes:
    """Pickle *data* into a byte string for transport.

    *data* may mix numpy arrays, strings and (nested) dicts. Nothing is
    compressed at this stage because speed is the priority. Callers that
    want smaller image payloads run encode_images_jpeg on the dict first.
    """
    # fix_imports only matters for protocols below 3 (Python 2 name mapping).
    # No buffer_callback is given, so numpy buffers are serialized in-band.
    serialized = pickle.dumps(data, fix_imports=FIX_IMPORTS, protocol=HIGHEST_PROTOCOL)
    return serialized


# @measure_time(logger)
def binary_to_dict(data: bytes) -> Any:
    """Unpickle a byte string produced by ``dict_to_binary``.

    Args:
        data: Pickle payload (any bytes-like object accepted by pickle.loads).

    Returns:
        The reconstructed object, normally the dict that was serialized.

    Security:
        ``pickle.loads`` can execute arbitrary code embedded in the payload.
        Only call this on data from a trusted peer, never on bytes that an
        untrusted client can send.
    """
    # SECURITY(review): this module appears to back a network server. If
    # untrusted clients can reach it, replace pickle with a non-executable
    # format, or authenticate payloads (e.g. HMAC) before unpickling.
    return pickle.loads(data, fix_imports=FIX_IMPORTS)


# Simple self-test: run this file directly to exercise the JPEG codec and the
# full encode -> pickle -> unpickle -> decode pipeline.
if __name__ == "__main__":
    import time

    def _random_image() -> np.ndarray:
        """Return a random (1, 3, 480, 640) float32 image with values k/255."""
        # randint's upper bound is exclusive: use 256 so that 255 can occur.
        pixels = np.random.randint(0, 256, (1, 3, 480, 640), dtype=np.uint8)
        return pixels.astype(np.float32) / 255.0

    # JPEG encode/decode round trip on a single image.
    print("=== JPEG encode/decode test ===")
    img_orig = _random_image()
    jpeg_bytes = jpeg_encode(img_orig)
    img_restored = jpeg_decode(jpeg_bytes)
    print(f"Original shape: {img_orig.shape}, JPEG size: {len(jpeg_bytes)/1024:.2f} KB")
    print(f"Restored shape: {img_restored.shape}, max abs diff: {np.abs(img_orig - img_restored).max():.4f}")

    test_data = {
        "instruction": "pick up the cup",
        "state": np.array([0.1, 0.2, 0.3], dtype=np.float32),
        "observation.images.cam_high": _random_image(),
        "observation.images.cam_left_wrist": _random_image(),
        "observation.images.cam_right_wrist": _random_image(),
    }

    print("\n=== Full pipeline test (encode_images_jpeg -> dict_to_binary -> binary_to_dict -> decode_images_jpeg) ===")
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # time and can jump, which makes it unreliable for millisecond timings.
    start = time.perf_counter()
    encoded = encode_images_jpeg(test_data)
    blob = dict_to_binary(encoded)
    t1 = time.perf_counter() - start

    start = time.perf_counter()
    restored_encoded = binary_to_dict(blob)
    restored = decode_images_jpeg(restored_encoded)
    t2 = time.perf_counter() - start

    print(f"Serialize+JPEG time: {t1*1000:.2f} ms, Size: {len(blob)/1024:.2f} KB")
    print(f"Deserialize+JPEG time: {t2*1000:.2f} ms")
    for k, v in test_data.items():
        if isinstance(v, np.ndarray):
            print(f"{k}: {v.shape} -> {restored[k].shape}")
            # Shapes must survive the round trip (pixel values are lossy).
            assert restored[k].shape == v.shape, f"shape mismatch for {k}"