243 lines
7.4 KiB
Python
243 lines
7.4 KiB
Python
import tensorflow as tf
|
|
import tensorflow_graphics.geometry.transformation.euler as tf_euler
|
|
import tensorflow_graphics.geometry.transformation.quaternion as tf_quat
|
|
import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as tf_rotmat
|
|
|
|
|
|
def dataset_to_path(dataset_name: str, dir_name: str) -> str:
|
|
"""
|
|
Return the path to the dataset.
|
|
"""
|
|
if (dataset_name == "robo_net" or dataset_name == "cmu_playing_with_food" or dataset_name == "droid"):
|
|
version = "1.0.0"
|
|
elif (dataset_name == "language_table" or dataset_name == "fmb" or dataset_name == "dobbe"):
|
|
version = "0.0.1"
|
|
elif dataset_name == "nyu_door_opening_surprising_effectiveness":
|
|
version = ""
|
|
elif dataset_name == "cmu_play_fusion":
|
|
version = ""
|
|
elif dataset_name == "berkeley_gnm_recon":
|
|
version = ""
|
|
else:
|
|
version = "0.1.0"
|
|
return f"{dir_name}/{dataset_name}/{version}"
|
|
|
|
|
|
def clean_task_instruction(task_instruction: tf.Tensor, replacements: dict) -> tf.Tensor:
|
|
"""
|
|
Clean up the natural language task instruction.
|
|
"""
|
|
|
|
# Create a function that applies all replacements
|
|
def apply_replacements(tensor):
|
|
for old, new in replacements.items():
|
|
tensor = tf.strings.regex_replace(tensor, old, new)
|
|
return tensor
|
|
|
|
# Apply the replacements and strip leading and trailing spaces
|
|
cleaned_task_instruction = apply_replacements(task_instruction)
|
|
cleaned_task_instruction = tf.strings.strip(cleaned_task_instruction)
|
|
return cleaned_task_instruction
|
|
|
|
|
|
def quaternion_to_euler(quaternion: tf.Tensor) -> tf.Tensor:
|
|
"""
|
|
Convert a quaternion (x, y, z, w) to Euler angles (roll, pitch, yaw).
|
|
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
|
"""
|
|
# Normalize the quaternion
|
|
quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
|
|
return tf_euler.from_quaternion(quaternion)
|
|
|
|
|
|
def euler_to_quaternion(euler: tf.Tensor) -> tf.Tensor:
|
|
"""
|
|
Convert Euler angles (roll, pitch, yaw) to a quaternion (x, y, z, w).
|
|
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
|
"""
|
|
quaternion = tf_quat.from_euler(euler)
|
|
return tf.nn.l2_normalize(quaternion, axis=-1)
|
|
|
|
|
|
def rotation_matrix_to_euler(matrix: tf.Tensor) -> tf.Tensor:
|
|
"""
|
|
Convert a 3x3 rotation matrix to Euler angles (roll, pitch, yaw).
|
|
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
|
"""
|
|
return tf_euler.from_rotation_matrix(matrix)
|
|
|
|
|
|
def rotation_matrix_to_quaternion(matrix: tf.Tensor) -> tf.Tensor:
|
|
"""
|
|
Convert a 3x3 rotation matrix to a quaternion (x, y, z, w).
|
|
"""
|
|
quaternion = tf_quat.from_rotation_matrix(matrix)
|
|
return tf.nn.l2_normalize(quaternion, axis=-1)
|
|
|
|
|
|
def euler_to_rotation_matrix(euler: tf.Tensor) -> tf.Tensor:
|
|
"""
|
|
Convert Euler angles (roll, pitch, yaw) to a 3x3 rotation matrix.
|
|
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
|
"""
|
|
return tf_rotmat.from_euler(euler)
|
|
|
|
|
|
def quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor:
|
|
"""
|
|
Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
|
|
"""
|
|
# Normalize the quaternion
|
|
quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
|
|
return tf_rotmat.from_quaternion(quaternion)
|
|
|
|
|
|
def quaternion_to_rotation_matrix_wo_static_check(quaternion: tf.Tensor) -> tf.Tensor:
|
|
"""
|
|
Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
|
|
This function is used to make tensorflow happy.
|
|
"""
|
|
# Normalize the quaternion
|
|
quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
|
|
|
|
x = quaternion[..., 0]
|
|
y = quaternion[..., 1]
|
|
z = quaternion[..., 2]
|
|
w = quaternion[..., 3]
|
|
|
|
tx = 2.0 * x
|
|
ty = 2.0 * y
|
|
tz = 2.0 * z
|
|
twx = tx * w
|
|
twy = ty * w
|
|
twz = tz * w
|
|
txx = tx * x
|
|
txy = ty * x
|
|
txz = tz * x
|
|
tyy = ty * y
|
|
tyz = tz * y
|
|
tzz = tz * z
|
|
matrix = tf.stack(
|
|
(
|
|
1.0 - (tyy + tzz),
|
|
txy - twz,
|
|
txz + twy,
|
|
txy + twz,
|
|
1.0 - (txx + tzz),
|
|
tyz - twx,
|
|
txz - twy,
|
|
tyz + twx,
|
|
1.0 - (txx + tyy),
|
|
),
|
|
axis=-1,
|
|
) # pyformat: disable
|
|
output_shape = tf.concat((tf.shape(input=quaternion)[:-1], (3, 3)), axis=-1)
|
|
return tf.reshape(matrix, shape=output_shape)
|
|
|
|
|
|
"""
|
|
Below is a continuous 6D rotation representation adapted from
|
|
On the Continuity of Rotation Representations in Neural Networks
|
|
https://arxiv.org/pdf/1812.07035.pdf
|
|
https://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tools.py
|
|
"""
|
|
|
|
|
|
def rotation_matrix_to_ortho6d(matrix: tf.Tensor) -> tf.Tensor:
|
|
"""
|
|
The orhto6d represents the first two column vectors a1 and a2 of the
|
|
rotation matrix: [ | , |, | ]
|
|
[ a1, a2, a3]
|
|
[ | , |, | ]
|
|
Input: (A1, ..., An, 3, 3)
|
|
Output: (A1, ..., An, 6)
|
|
"""
|
|
ortho6d = matrix[..., :, :2]
|
|
# Transpose the last two dimension
|
|
perm = list(range(len(ortho6d.shape)))
|
|
perm[-2], perm[-1] = perm[-1], perm[-2]
|
|
ortho6d = tf.transpose(ortho6d, perm)
|
|
# Flatten the last two dimension
|
|
ortho6d = tf.reshape(ortho6d, ortho6d.shape[:-2] + [6])
|
|
return ortho6d
|
|
|
|
|
|
def rotation_matrix_to_ortho6d_1d(matrix: tf.Tensor) -> tf.Tensor:
|
|
"""
|
|
The orhto6d represents the first two column vectors a1 and a2 of the
|
|
rotation matrix: [ | , |, | ]
|
|
[ a1, a2, a3]
|
|
[ | , |, | ]
|
|
Input: (3, 3)
|
|
Output: (6,)
|
|
This function is used to make tensorflow happy.
|
|
"""
|
|
ortho6d = matrix[:, :2]
|
|
# Transpose the last two dimension
|
|
ortho6d = tf.transpose(ortho6d)
|
|
# Flatten the last two dimension
|
|
ortho6d = tf.reshape(ortho6d, [6])
|
|
return ortho6d
|
|
|
|
|
|
def normalize_vector(v):
|
|
"""
|
|
v: (..., N)
|
|
"""
|
|
v_mag = tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))
|
|
v_mag = tf.maximum(v_mag, 1e-8)
|
|
v_normalized = v / v_mag
|
|
|
|
return v_normalized
|
|
|
|
|
|
def cross_product(u, v):
|
|
"""
|
|
u: (..., 3)
|
|
v: (..., 3)
|
|
u x v: (..., 3)
|
|
"""
|
|
i = u[..., 1] * v[..., 2] - u[..., 2] * v[..., 1]
|
|
j = u[..., 2] * v[..., 0] - u[..., 0] * v[..., 2]
|
|
k = u[..., 0] * v[..., 1] - u[..., 1] * v[..., 0]
|
|
out = tf.stack([i, j, k], axis=-1)
|
|
return out
|
|
|
|
|
|
def ortho6d_to_rotation_matrix(ortho6d: tf.Tensor) -> tf.Tensor:
|
|
"""
|
|
The orhto6d represents the first two column vectors a1 and a2 of the
|
|
rotation matrix: [ | , |, | ]
|
|
[ a1, a2, a3]
|
|
[ | , |, | ]
|
|
Input: (A1, ..., An, 6)
|
|
Output: (A1, ..., An, 3, 3)
|
|
"""
|
|
x_raw = ortho6d[..., 0:3]
|
|
y_raw = ortho6d[..., 3:6]
|
|
|
|
x = normalize_vector(x_raw)
|
|
z = cross_product(x, y_raw)
|
|
z = normalize_vector(z)
|
|
y = cross_product(z, x)
|
|
|
|
# Stack x, y, z to form the matrix
|
|
matrix = tf.stack([x, y, z], axis=-1)
|
|
return matrix
|
|
|
|
|
|
def capitalize_and_period(instr: str) -> str:
|
|
"""
|
|
Capitalize the first letter of a string and add a period to the end if it's not there.
|
|
"""
|
|
if len(instr) > 0:
|
|
# if the first letter is not capital, make it so
|
|
if not instr[0].isupper():
|
|
# if the first letter is not capital, make it so
|
|
instr = instr[0].upper() + instr[1:]
|
|
# add period to the end if it's not there
|
|
if instr[-1] != ".":
|
|
# add period to the end if it's not there
|
|
instr = instr + "."
|
|
return instr
|