success
This commit is contained in:
121
baselines/grasping/GSNet/utils/loss_utils.py
Executable file
121
baselines/grasping/GSNet/utils/loss_utils.py
Executable file
@@ -0,0 +1,121 @@
|
||||
""" Tools for loss computation.
|
||||
Author: chenxi-wang
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
GRASP_MAX_WIDTH = 0.1
|
||||
GRASPNESS_THRESHOLD = 0.1
|
||||
NUM_VIEW = 300
|
||||
NUM_ANGLE = 12
|
||||
NUM_DEPTH = 4
|
||||
M_POINT = 1024
|
||||
|
||||
|
||||
def transform_point_cloud(cloud, transform, format='4x4'):
|
||||
""" Transform points to new coordinates with transformation matrix.
|
||||
|
||||
Input:
|
||||
cloud: [torch.FloatTensor, (N,3)]
|
||||
points in original coordinates
|
||||
transform: [torch.FloatTensor, (3,3)/(3,4)/(4,4)]
|
||||
transformation matrix, could be rotation only or rotation+translation
|
||||
format: [string, '3x3'/'3x4'/'4x4']
|
||||
the shape of transformation matrix
|
||||
'3x3' --> rotation matrix
|
||||
'3x4'/'4x4' --> rotation matrix + translation matrix
|
||||
|
||||
Output:
|
||||
cloud_transformed: [torch.FloatTensor, (N,3)]
|
||||
points in new coordinates
|
||||
"""
|
||||
if not (format == '3x3' or format == '4x4' or format == '3x4'):
|
||||
raise ValueError('Unknown transformation format, only support \'3x3\' or \'4x4\' or \'3x4\'.')
|
||||
if format == '3x3':
|
||||
cloud_transformed = torch.matmul(transform, cloud.T).T
|
||||
elif format == '4x4' or format == '3x4':
|
||||
ones = cloud.new_ones(cloud.size(0), device=cloud.device).unsqueeze(-1)
|
||||
cloud_ = torch.cat([cloud, ones], dim=1)
|
||||
cloud_transformed = torch.matmul(transform, cloud_.T).T
|
||||
cloud_transformed = cloud_transformed[:, :3]
|
||||
return cloud_transformed
|
||||
|
||||
|
||||
def generate_grasp_views(N=300, phi=(np.sqrt(5) - 1) / 2, center=np.zeros(3), r=1):
|
||||
""" View sampling on a unit sphere using Fibonacci lattices.
|
||||
Ref: https://arxiv.org/abs/0912.4540
|
||||
|
||||
Input:
|
||||
N: [int]
|
||||
number of sampled views
|
||||
phi: [float]
|
||||
constant for view coordinate calculation, different phi's bring different distributions, default: (sqrt(5)-1)/2
|
||||
center: [np.ndarray, (3,), np.float32]
|
||||
sphere center
|
||||
r: [float]
|
||||
sphere radius
|
||||
|
||||
Output:
|
||||
views: [torch.FloatTensor, (N,3)]
|
||||
sampled view coordinates
|
||||
"""
|
||||
views = []
|
||||
for i in range(N):
|
||||
zi = (2 * i + 1) / N - 1
|
||||
xi = np.sqrt(1 - zi ** 2) * np.cos(2 * i * np.pi * phi)
|
||||
yi = np.sqrt(1 - zi ** 2) * np.sin(2 * i * np.pi * phi)
|
||||
views.append([xi, yi, zi])
|
||||
views = r * np.array(views) + center
|
||||
return torch.from_numpy(views.astype(np.float32))
|
||||
|
||||
|
||||
def batch_viewpoint_params_to_matrix(batch_towards, batch_angle):
|
||||
""" Transform approach vectors and in-plane rotation angles to rotation matrices.
|
||||
|
||||
Input:
|
||||
batch_towards: [torch.FloatTensor, (N,3)]
|
||||
approach vectors in batch
|
||||
batch_angle: [torch.floatTensor, (N,)]
|
||||
in-plane rotation angles in batch
|
||||
|
||||
Output:
|
||||
batch_matrix: [torch.floatTensor, (N,3,3)]
|
||||
rotation matrices in batch
|
||||
"""
|
||||
axis_x = batch_towards
|
||||
ones = torch.ones(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device)
|
||||
zeros = torch.zeros(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device)
|
||||
axis_y = torch.stack([-axis_x[:, 1], axis_x[:, 0], zeros], dim=-1)
|
||||
mask_y = (torch.norm(axis_y, dim=-1) == 0)
|
||||
axis_y[mask_y, 1] = 1
|
||||
axis_x = axis_x / torch.norm(axis_x, dim=-1, keepdim=True)
|
||||
axis_y = axis_y / torch.norm(axis_y, dim=-1, keepdim=True)
|
||||
axis_z = torch.cross(axis_x, axis_y)
|
||||
sin = torch.sin(batch_angle)
|
||||
cos = torch.cos(batch_angle)
|
||||
R1 = torch.stack([ones, zeros, zeros, zeros, cos, -sin, zeros, sin, cos], dim=-1)
|
||||
R1 = R1.reshape([-1, 3, 3])
|
||||
R2 = torch.stack([axis_x, axis_y, axis_z], dim=-1)
|
||||
batch_matrix = torch.matmul(R2, R1)
|
||||
return batch_matrix
|
||||
|
||||
|
||||
def huber_loss(error, delta=1.0):
|
||||
"""
|
||||
Args:
|
||||
error: Torch tensor (d1,d2,...,dk)
|
||||
Returns:
|
||||
loss: Torch tensor (d1,d2,...,dk)
|
||||
|
||||
x = error = pred - gt or dist(pred,gt)
|
||||
0.5 * |x|^2 if |x|<=d
|
||||
0.5 * d^2 + d * (|x|-d) if |x|>d
|
||||
Author: Charles R. Qi
|
||||
Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py
|
||||
"""
|
||||
abs_error = torch.abs(error)
|
||||
quadratic = torch.clamp(abs_error, max=delta)
|
||||
linear = (abs_error - quadratic)
|
||||
loss = 0.5 * quadratic ** 2 + delta * linear
|
||||
return loss
|
Reference in New Issue
Block a user