import os
import pickle
import re
import socket
import time

import numpy as np
import torch
import torchvision
import pytorch3d.ops as torch3d_ops
import tqdm
import zarr
from termcolor import cprint


def fps(points, num_points=1024, use_cuda=True):
    """Downsample a point cloud with farthest point sampling (FPS).

    Args:
        points: NumPy array of points. It is converted with
            ``torch.from_numpy``, so it must be a NumPy array. A batch axis is
            prepended before calling pytorch3d, so it is presumably shaped
            ``(P, D)`` (e.g. ``(P, 3)``) -- TODO confirm against callers.
        num_points: Number of points ``K`` to select.
        use_cuda: If True, move the points to the current CUDA device and
            sample there; otherwise sample on the CPU.

    Returns:
        A tuple ``(sampled_points, indices)``:

        - ``sampled_points``: NumPy array of the selected points with the
          batch axis removed, i.e. shape ``(K, D)``.
        - ``indices``: the ``torch.Tensor`` of selected indices exactly as
          returned by ``sample_farthest_points`` -- it keeps the batch axis,
          shape ``(1, K)``, and stays on the device the sampling ran on
          (CUDA when ``use_cuda`` is True).
    """
    tensor = torch.from_numpy(points)
    if use_cuda:
        tensor = tensor.cuda()

    # pytorch3d expects batched (N, P, D) input, so sample a batch of one.
    sampled, indices = torch3d_ops.sample_farthest_points(
        points=tensor.unsqueeze(0), K=[num_points]
    )

    # .cpu() is a no-op for CPU tensors, so one conversion path serves both
    # devices.
    sampled_points = sampled.squeeze(0).cpu().numpy()
    return sampled_points, indices