37 lines
938 B
Python
37 lines
938 B
Python
import os
|
|
import zarr
|
|
import pickle
|
|
import tqdm
|
|
import numpy as np
|
|
import torch
|
|
import pytorch3d.ops as torch3d_ops
|
|
import torchvision
|
|
from termcolor import cprint
|
|
import re
|
|
import time
|
|
|
|
|
|
import numpy as np
|
|
import torch
|
|
import pytorch3d.ops as torch3d_ops
|
|
import torchvision
|
|
import socket
|
|
import pickle
|
|
|
|
|
|
def fps(points, num_points=1024, use_cuda=True):
    """Downsample a single point cloud with farthest point sampling (FPS).

    Thin wrapper around ``pytorch3d.ops.sample_farthest_points`` for one,
    unbatched cloud: a batch axis is added before the call and removed from
    the sampled points afterwards.

    Args:
        points: Point cloud as a NumPy array (anything ``torch.as_tensor``
            accepts, e.g. an existing tensor, also works). Presumably shaped
            ``(P, D)`` -- pytorch3d expects ``(N, P, D)`` and only a leading
            batch axis is added here; confirm against callers.
        num_points: Number of points ``K`` to select.
        use_cuda: If True, sampling runs on the current CUDA device (fails if
            CUDA is unavailable); otherwise it runs on the CPU.

    Returns:
        A tuple ``(sampled_points, indices)``:

        * ``sampled_points`` -- NumPy array of the selected points with the
          batch axis removed.
        * ``indices`` -- the raw index tensor returned by pytorch3d. NOTE: it
          still carries the batch axis (``(1, K)``) and stays on the
          computation device (GPU when ``use_cuda``); callers must move or
          squeeze it themselves.

    NOTE(review): when ``num_points`` exceeds the number of input points,
    pytorch3d pads the output rather than raising -- verify callers handle it.
    """
    # pytorch3d accepts one K per cloud in the batch; the batch size here is 1.
    K = [num_points]

    device = "cuda" if use_cuda else "cpu"
    # For a CPU NumPy array, as_tensor shares memory just like torch.from_numpy
    # and preserves dtype; it copies to the GPU only when device is "cuda".
    tensor = torch.as_tensor(points, device=device)

    sampled_points, indices = torch3d_ops.sample_farthest_points(
        points=tensor.unsqueeze(0), K=K
    )

    # Drop the batch axis and bring the result back to host memory. .cpu() is a
    # no-op for CPU tensors, so one path serves both devices.
    sampled_points = sampled_points.squeeze(0).cpu().numpy()

    return sampled_points, indices
|