# Lerobot/docker/infer.py
import json
import os
import signal
import threading
import time
import torch
from cloud_helper import Server
from lerobot.policies.factory import get_policy_class
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
# os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"
# Load the task configuration and resolve which model/checkpoint to serve.
with open("/workspace/embolab/params/build_task.json") as f:
    task_configs = json.load(f)

model = task_configs["online_infer"]["model"]
assert model in ["smolvla", "act"], f"Unsupported model: {model}"

checkpoint = task_configs["online_infer"]["checkpoint_path"]
if not checkpoint.endswith("/pretrained_model"):
    checkpoint += "/pretrained_model"
    print(f"Adjusted checkpoint path to: {checkpoint}")

server_port = task_configs["online_infer"].get("port", 8080)
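
# For reference, a minimal build_task.json satisfying the keys read above might
# look like this (illustrative values only; the real file is generated elsewhere
# in the pipeline):
#
#   {
#       "online_infer": {
#           "model": "act",
#           "checkpoint_path": "/workspace/outputs/train/act/checkpoints/last",
#           "port": 8080
#       }
#   }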


class LerobotInferenceServer:
    """Serves action chunks from a pretrained LeRobot policy over a cloud_helper Server."""

    def __init__(
        self,
        checkpoint: str,
        policy_type: str = "smolvla",
        host: str = "localhost",
        port: int = 5555,
        device: str = "cuda",
        timeout: int = 3600,
    ):
        self.server = Server(host, port)
        self.policy_type = policy_type

        # Resolve the policy class by name and load the pretrained weights.
        policy_class = get_policy_class(self.policy_type)
        self.policy = policy_class.from_pretrained(checkpoint)
        self.device = device
        self.policy.to(self.device)
        self.policy.eval()  # inference mode: disable dropout etc.
        print(f"Loaded {self.policy_type.upper()} policy from {checkpoint}")

        # Watchdog: shut the server down after `timeout` seconds without requests.
        self.timeout = timeout
        self.last_activity = time.time()
        self.stop_event = threading.Event()
        self.monitor_thread = threading.Thread(target=self.watchout, daemon=True)
        self.monitor_thread.start()

    def watchout(self):
        """Background watchdog: exit the process if no request arrives within the timeout."""
        while not self.stop_event.is_set():
            time.sleep(6)  # check every 6 seconds
            elapsed = time.time() - self.last_activity
            if elapsed > self.timeout:
                print(f"No activity for {elapsed:.0f} seconds. Shutting down due to timeout.")
                # Force exit since loop_forever might block.
                os.kill(os.getpid(), signal.SIGINT)

    def get_actions(self, batch):
        """Run one inference step.

        Expected request format:
            batch = {
                "observation": {
                    "state": ...,           # 1-D float array
                    "images.front": ...,    # HWC uint8
                    "images.wrist": ...,    # HWC uint8
                },
                "instruction": ...,         # natural-language task string
            }
        """
        obs = {}
        for k, v in batch["observation"].items():
            if k.startswith("images.") and v is not None:
                # Normalize to [0, 1], reorder HWC -> CHW, and add a batch dimension.
                img = v.astype("float32") / 255.0
                img = img.transpose(2, 0, 1)
                img = torch.from_numpy(img).unsqueeze(0).to(self.device)
                obs[f"observation.{k}"] = img
            elif k == "state":
                tensor = torch.from_numpy(v.astype("float32")).unsqueeze(0).to(self.device)
                obs[f"observation.{k}"] = tensor
        obs["task"] = batch["instruction"]

        with torch.no_grad():
            action_chunk = self.policy.predict_action_chunk(obs)
        self.last_activity = time.time()
        return action_chunk.cpu().numpy()  # (B, chunk_size, action_dim)

    def run(self):
        self.server.register_endpoint("get_actions", self.get_actions)
        print(f"Lerobot {self.policy_type.upper()} Server is running...")
        self.server.loop_forever()


if __name__ == "__main__":
    server = LerobotInferenceServer(
        checkpoint=checkpoint, policy_type=model, host="0.0.0.0", port=server_port, timeout=3600
    )
    server.run()
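

# A minimal client-side sketch. Assumptions (not confirmed by this file):
# cloud_helper exposes a Client counterpart to Server with a
# `call(endpoint, payload)` method, and images travel as HWC uint8 numpy arrays,
# which is the format get_actions expects. Check cloud_helper for the real
# client API before using this.
#
#   import numpy as np
#   from cloud_helper import Client  # hypothetical
#
#   client = Client("localhost", 8080)
#   batch = {
#       "observation": {
#           "state": np.zeros(6, dtype=np.float32),            # example state dim
#           "images.front": np.zeros((480, 640, 3), np.uint8),
#           "images.wrist": np.zeros((480, 640, 3), np.uint8),
#       },
#       "instruction": "pick up the cube",
#   }
#   actions = client.call("get_actions", batch)  # -> (1, chunk_size, action_dim)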