RoboTwin_image/envs/utils/farthest_point_sampler.py
2025-07-02 03:13:07 +00:00

37 lines
938 B
Python

import os
import zarr
import pickle
import tqdm
import numpy as np
import torch
import pytorch3d.ops as torch3d_ops
import torchvision
from termcolor import cprint
import re
import time
import numpy as np
import torch
import pytorch3d.ops as torch3d_ops
import torchvision
import socket
import pickle
def fps(points, num_points=1024, use_cuda=True):
    """Downsample a point cloud with farthest point sampling (pytorch3d).

    Args:
        points: Array-like point cloud of shape (P, D). It is presumably
            (P, 3) xyz coordinates; TODO confirm with callers. It is
            converted with ``np.ascontiguousarray``, so plain lists and
            negatively-strided numpy views (which ``torch.from_numpy``
            rejects) are accepted. The dtype is preserved.
        num_points: Number of points K to select.
        use_cuda: Run the sampling on the GPU when True. If CUDA is not
            available, sampling silently falls back to the CPU instead of
            raising.

    Returns:
        A tuple ``(sampled_points, indices)``:
        - ``sampled_points``: numpy array of shape (K, D) on the host.
        - ``indices``: the raw torch tensor returned by pytorch3d, of
          shape (1, K) (batch dimension kept). It stays on the device
          used for sampling. It is returned unchanged for backward
          compatibility with existing callers.
        See the pytorch3d ``sample_farthest_points`` documentation for
        behavior when K exceeds the number of input points.
    """
    # pytorch3d accepts K as an int or a per-batch sequence; a one-element
    # list matches the single cloud we pass as a batch of size 1.
    K = [num_points]

    # Choose the device once instead of duplicating the whole pipeline per
    # branch. Guard against missing CUDA so use_cuda=True does not crash on
    # CPU-only machines.
    device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"

    # from_numpy shares memory with the (contiguous) array; .to() then
    # copies to the GPU when needed and is a no-op for "cpu".
    points_tensor = torch.from_numpy(np.ascontiguousarray(points)).to(device)

    # Add a batch dimension: pytorch3d expects (N, P, D).
    sampled_points, indices = torch3d_ops.sample_farthest_points(
        points=points_tensor.unsqueeze(0), K=K
    )

    # Drop the batch dimension and bring the result back to host memory;
    # .cpu() is a no-op when sampling already ran on the CPU.
    sampled_points = sampled_points.squeeze(0).cpu().numpy()
    return sampled_points, indices