import json
import os
import threading
import time

import torch
import uvicorn
from fastapi import FastAPI

from cloud_helper import Server
from lerobot.policies.factory import get_policy_class

# Pin the GPU and force offline mode so the HF hub is never contacted at load time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
# os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"

with open("/workspace/embolab/params/build_task.json") as f:
    task_configs = json.load(f)

model = task_configs["online_infer"]["model"]
# Explicit raise instead of `assert`: asserts are stripped under `python -O`.
if model not in ("smolvla", "act"):
    raise ValueError(f"Unsupported model: {model}")

checkpoint = task_configs["online_infer"]["checkpoint_path"]
if not checkpoint.endswith("/pretrained_model"):
    # lerobot checkpoints keep weights under a `pretrained_model/` subdirectory.
    checkpoint += "/pretrained_model"
    print(f"Adjusted checkpoint path to: {checkpoint}")

server_port = task_configs["online_infer"].get("port", 8080)


class LerobotInferenceServer:
    """Serve action-chunk inference for a lerobot policy.

    Requests arrive through a ``cloud_helper.Server`` endpoint
    (``get_actions``); a FastAPI sidecar exposes ``/health`` and
    ``/idle_time`` probes on a separate HTTP port.
    """

    def __init__(
        self,
        checkpoint: str,
        policy_type: str = "smolvla",
        host: str = "localhost",
        port: int = 5555,
        device: str = "cuda",
        http_port: int = 80,
    ):
        """Load the policy onto `device` and prepare (but do not start) the servers.

        Args:
            checkpoint: Path to a `.../pretrained_model` directory.
            policy_type: lerobot policy family name ("smolvla" or "act").
            host/port: Bind address for the inference RPC server.
            device: torch device string for the policy.
            http_port: Port for the FastAPI health/idle HTTP server.
        """
        self.server = Server(host, port)
        self.policy_type = policy_type
        policy_class = get_policy_class(self.policy_type)
        self.policy = policy_class.from_pretrained(checkpoint)
        self.device = device
        self.policy.to(self.device)
        # Inference-only: disable dropout/batchnorm training behavior.
        self.policy.eval()
        print(f"Loaded {self.policy_type.upper()} policy from {checkpoint}")
        self.last_activity = time.time()
        # FIX: previously accepted but never stored, so the HTTP server
        # silently ignored the caller's choice and always used port 80.
        self.http_port = http_port
        self.fastapi = FastAPI()

    def get_actions(self, batch):
        """Run one inference step and return an action chunk.

        Expected request layout (values are numpy arrays):
            batch = {
                "observation": {
                    "state": ...,           # 1-D float state vector
                    "images.front": ...,    # HWC uint8
                    "images.wrist": ...,    # HWC uint8
                },
                "instruction": ...,         # task string for the policy
            }

        Returns:
            numpy array of shape (B, chunk_size, action_dim).
        """
        obs = {}
        for key, value in batch["observation"].items():
            if key.startswith("images.") and value is not None:
                # HWC uint8 -> normalized CHW float tensor with batch dim.
                image = value.astype("float32") / 255.0
                image = image.transpose(2, 0, 1)
                obs[f"observation.{key}"] = (
                    torch.from_numpy(image).unsqueeze(0).to(self.device)
                )
            elif key == "state":
                obs[f"observation.{key}"] = (
                    torch.from_numpy(value.astype("float32"))
                    .unsqueeze(0)
                    .to(self.device)
                )
        obs["task"] = batch["instruction"]
        # FIX: run under inference_mode so no autograd graph is built per request.
        with torch.inference_mode():
            action_chunk = self.policy.predict_action_chunk(obs)
        self.last_activity = time.time()
        return action_chunk.cpu().numpy()  # (B, chunk_size, action_dim)

    def get_idle_time(self):
        """Seconds elapsed since the last successful inference request."""
        return time.time() - self.last_activity

    def run(self):
        """Register endpoints, start the HTTP sidecar, and block serving RPCs."""
        self.server.register_endpoint("get_actions", self.get_actions)

        @self.fastapi.get("/health")
        def health_check():
            return {"status": 0}

        @self.fastapi.get("/idle_time")
        def idle_time():
            return {"status": 0, "idle_time": self.get_idle_time()}

        def start_fastapi(app, port: int = 80):
            """Start FastAPI in a dedicated thread."""
            print(f"Starting FastAPI HTTP server on port {port}...")
            uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")

        # FIX: forward the configured HTTP port (was dropped, always defaulting to 80).
        threading.Thread(
            target=start_fastapi, args=(self.fastapi, self.http_port), daemon=True
        ).start()

        print(f"Lerobot {self.policy_type.upper()} Server is running...")
        self.server.loop_forever()


if __name__ == "__main__":
    server = LerobotInferenceServer(
        checkpoint=checkpoint, policy_type=model, host="0.0.0.0", port=server_port
    )
    server.run()