113 lines
3.5 KiB
Python
113 lines
3.5 KiB
Python
import json
|
|
import os
|
|
import threading
|
|
import time
|
|
|
|
import torch
|
|
|
|
from fastapi import FastAPI
|
|
import uvicorn
|
|
|
|
from cloud_helper import Server
|
|
from lerobot.policies.factory import get_policy_class
|
|
|
|
# GPU selection: restrict this process to the first CUDA device.
# NOTE(review): this is set *after* `import torch` above; it usually still works
# because torch initializes CUDA lazily, but it is only guaranteed to take
# effect if no CUDA context exists yet — consider moving it above the torch
# import. TODO confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Force Hugging Face hub/datasets into offline mode so checkpoint loading
# never touches the network (weights are expected to exist locally).
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
# os.environ["HTTPS_PROXY"] = "http://192.168.16.68:18000"
|
|
|
|
# Load the task configuration that drives this inference server.
with open("/workspace/embolab/params/build_task.json") as f:
    task_configs = json.load(f)

# Policy architecture to serve. Validate with an explicit raise rather than
# `assert`, which is stripped under `python -O`.
model = task_configs["online_infer"]["model"]
if model not in ("smolvla", "act"):
    raise ValueError(f"Unsupported model: {model}")

# Lerobot checkpoints keep the exported weights under `pretrained_model/`;
# append that suffix when the config points at the run directory instead.
checkpoint = task_configs["online_infer"]["checkpoint_path"]
if not checkpoint.endswith("/pretrained_model"):
    checkpoint += "/pretrained_model"
    print(f"Adjusted checkpoint path to: {checkpoint}")

# Port for the action (cloud_helper) server; the FastAPI probe server binds
# its own port separately.
server_port = task_configs["online_infer"].get("port", 8080)
|
|
|
|
|
|
class LerobotInferenceServer:
    """Serve action-chunk predictions from a Lerobot policy.

    Two network surfaces are exposed:
      * a ``cloud_helper.Server`` endpoint (``"get_actions"``) that performs
        the actual inference, and
      * a FastAPI app, run in a daemon thread, with ``/health`` and
        ``/idle_time`` probes for liveness/idleness monitoring.
    """

    def __init__(
        self,
        checkpoint: str,
        policy_type: str = "smolvla",
        host: str = "localhost",
        port: int = 5555,
        device: str = "cuda",
        http_port: int = 80,
    ):
        """Load the policy and prepare both servers (does not start them).

        Args:
            checkpoint: Path to a ``pretrained_model`` checkpoint directory.
            policy_type: Lerobot policy name (e.g. "smolvla" or "act").
            host: Bind address for the action server.
            port: Bind port for the action server.
            device: Torch device string used for inference.
            http_port: Port for the FastAPI probe server.
        """
        self.server = Server(host, port)
        self.policy_type = policy_type
        policy_class = get_policy_class(self.policy_type)
        self.policy = policy_class.from_pretrained(checkpoint)
        self.device = device
        self.policy.to(self.device)
        # Inference only: make dropout/batchnorm behave deterministically.
        self.policy.eval()
        print(f"Loaded {self.policy_type.upper()} policy from {checkpoint}")

        self.last_activity = time.time()
        self.fastapi = FastAPI()
        # Bug fix: `http_port` was previously accepted but silently ignored —
        # the HTTP server always bound to the hard-coded default of 80.
        self.http_port = http_port

    def get_actions(self, batch):
        """Run the policy on one observation batch and return an action chunk.

        Expected ``batch`` layout (numpy arrays):
            {
                "observation": {
                    "state": ...,
                    "images.front": ...,   # HWC uint8
                    "images.wrist": ...,
                },
                "instruction": ...,
            }

        Returns:
            numpy array of shape (B, chunk_size, action_dim).
        """
        obs = {}
        for key, value in batch["observation"].items():
            if key.startswith("images.") and value is not None:
                # uint8 HWC -> float32 CHW in [0, 1], plus a batch dimension.
                img = value.astype("float32") / 255.0
                img = img.transpose(2, 0, 1)
                obs[f"observation.{key}"] = (
                    torch.from_numpy(img).unsqueeze(0).to(self.device)
                )
            elif key == "state":
                obs[f"observation.{key}"] = (
                    torch.from_numpy(value.astype("float32"))
                    .unsqueeze(0)
                    .to(self.device)
                )
        obs["task"] = batch["instruction"]

        # No gradients are needed at inference time; skipping autograd
        # bookkeeping saves time and GPU memory.
        with torch.no_grad():
            action_chunk = self.policy.predict_action_chunk(obs)

        self.last_activity = time.time()
        return action_chunk.cpu().numpy()  # (B, chunk_size, action_dim)

    def get_idle_time(self):
        """Return seconds elapsed since the last completed inference call."""
        return time.time() - self.last_activity

    def run(self):
        """Register endpoints, start the HTTP probe thread, and serve forever."""
        self.server.register_endpoint("get_actions", self.get_actions)

        @self.fastapi.get("/health")
        def health_check():
            return {"status": 0}

        @self.fastapi.get("/idle_time")
        def idle_time():
            return {"status": 0, "idle_time": self.get_idle_time()}

        def start_fastapi(app, port: int = 80):
            """Run the FastAPI app (blocking) — intended for a daemon thread."""
            print(f"Starting FastAPI HTTP server on port {port}...")
            uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")

        # Bug fix: forward the configured port instead of relying on
        # start_fastapi's hard-coded default of 80.
        threading.Thread(
            target=start_fastapi,
            args=(self.fastapi, self.http_port),
            daemon=True,
        ).start()

        print(f"Lerobot {self.policy_type.upper()} Server is running...")
        self.server.loop_forever()
|
|
|
|
|
|
if __name__ == "__main__":
    # Wire the module-level configuration into the server and block forever.
    inference_server = LerobotInferenceServer(
        checkpoint=checkpoint,
        policy_type=model,
        host="0.0.0.0",
        port=server_port,
    )
    inference_server.run()
|