"""Standalone online-inference server exposing a lerobot policy via cloud_helper.Server.

Reads its configuration from /workspace/embolab/params/build_task.json, loads the
requested pretrained policy, and serves action chunks over the `get_actions`
endpoint. Shuts itself down after a configurable period of inactivity.
"""

import json
import os
import signal
import threading
import time

# Fix: these flags must be exported BEFORE torch / the HF-backed lerobot stack
# are imported, otherwise the libraries may have already read the environment
# (device visibility and offline mode would silently not apply).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
# os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"

import torch

from cloud_helper import Server
from lerobot.policies.factory import get_policy_class

with open("/workspace/embolab/params/build_task.json") as f:
    task_configs = json.load(f)

model = task_configs["online_infer"]["model"]
assert model in ["smolvla", "act"], f"Unsupported model: {model}"

checkpoint = task_configs["online_infer"]["checkpoint_path"]
if not checkpoint.endswith("/pretrained_model"):
    checkpoint += "/pretrained_model"
    print(f"Adjusted checkpoint path to: {checkpoint}")

server_port = task_configs["online_infer"].get("port", 8080)


class LerobotInferenceServer:
    """Wraps a pretrained lerobot policy behind a simple RPC server.

    A daemon watchdog thread terminates the process (via SIGINT) when no
    inference request has been served for `timeout` seconds.
    """

    def __init__(
        self,
        checkpoint: str,
        policy_type: str = "smolvla",
        host: str = "localhost",
        port: int = 5555,
        device: str = "cuda",
        timeout: int = 3600,
    ):
        """Load the policy from `checkpoint` and start the inactivity watchdog.

        Args:
            checkpoint: Path to a `pretrained_model` directory.
            policy_type: One of the policy names known to lerobot's factory.
            host / port: Bind address for the RPC server.
            device: Torch device the policy runs on.
            timeout: Seconds of inactivity before self-shutdown.
        """
        self.server = Server(host, port)
        self.policy_type = policy_type
        policy_class = get_policy_class(self.policy_type)
        self.policy = policy_class.from_pretrained(checkpoint)
        self.device = device
        self.policy.to(self.device)
        # Fix: serving must run in eval mode (disables dropout / BN updates).
        self.policy.eval()
        print(f"Loaded {self.policy_type.upper()} policy from {checkpoint}")

        self.timeout = timeout
        self.last_activity = time.time()
        self.stop_event = threading.Event()
        self.monitor_thread = threading.Thread(target=self.watchout, daemon=True)
        self.monitor_thread.start()

    def watchout(self):
        """Watchdog loop: force process exit after `timeout` seconds of inactivity."""
        while not self.stop_event.is_set():
            time.sleep(6)  # Check every 6 seconds
            elapsed = time.time() - self.last_activity
            if elapsed > self.timeout:
                # Fix: the original message string was broken by a raw newline
                # inside the literal; reconstructed as a single-line message.
                print(f"No activity for {elapsed:.0f} seconds. Shutting down due to timeout.")
                # Force exit since loop_forever might block
                os.kill(os.getpid(), signal.SIGINT)

    def get_actions(self, batch):
        """Run the policy on one observation batch and return an action chunk.

        Expected `batch` layout (values are numpy arrays, images HWC uint8):
            {
                "observation": {
                    "state": ...,
                    "images.front": ...,
                    "images.wrist": ...,
                },
                "instruction": ...,
            }

        Returns:
            numpy array of shape (B, chunk_size, action_dim).
        """
        obs = {}
        for k, v in batch["observation"].items():
            if k.startswith("images.") and v is not None:
                img = v.astype("float32") / 255.0
                img = img.transpose(2, 0, 1)  # HWC -> CHW
                obs[f"observation.{k}"] = torch.from_numpy(img).unsqueeze(0).to(self.device)
            elif k == "state":
                obs[f"observation.{k}"] = (
                    torch.from_numpy(v.astype("float32")).unsqueeze(0).to(self.device)
                )
        obs["task"] = batch["instruction"]
        # Fix: no autograd graph is needed for serving; inference_mode avoids
        # the per-request memory/compute overhead of tracking gradients.
        with torch.inference_mode():
            action_chunk = self.policy.predict_action_chunk(obs)
        self.last_activity = time.time()
        return action_chunk.cpu().numpy()  # (B, chunk_size, action_dim)

    def run(self):
        """Register the endpoint and block serving requests forever."""
        self.server.register_endpoint("get_actions", self.get_actions)
        print(f"Lerobot {self.policy_type.upper()} Server is running...")
        self.server.loop_forever()


if __name__ == "__main__":
    server = LerobotInferenceServer(
        checkpoint=checkpoint,
        policy_type=model,
        host="0.0.0.0",
        port=server_port,
        timeout=3600,
    )
    server.run()