baselines/grasping/GSNet/utils/collision_detector.py (new executable file, 129 lines)
@@ -0,0 +1,129 @@
""" Collision detection to remove collided grasp pose predictions.
Author: chenxi-wang
"""

import os
import sys
import numpy as np
import open3d as o3d


class ModelFreeCollisionDetector():
    """ Collision detection in scenes without object labels. Current finger width and length are fixed.

        Input:
            scene_points: [numpy.ndarray, (N,3), numpy.float32]
                the scene points to detect
            voxel_size: [float]
                used for downsampling

        Example usage:
            mfcdetector = ModelFreeCollisionDetector(scene_points, voxel_size=0.005)
            collision_mask = mfcdetector.detect(grasp_group, approach_dist=0.03)
            collision_mask, iou_list = mfcdetector.detect(grasp_group, approach_dist=0.03, collision_thresh=0.05, return_ious=True)
            collision_mask, empty_mask = mfcdetector.detect(grasp_group, approach_dist=0.03, collision_thresh=0.05,
                                            return_empty_grasp=True, empty_thresh=0.01)
            collision_mask, empty_mask, iou_list = mfcdetector.detect(grasp_group, approach_dist=0.03, collision_thresh=0.05,
                                            return_empty_grasp=True, empty_thresh=0.01, return_ious=True)
    """
    def __init__(self, scene_points, voxel_size=0.005):
        self.finger_width = 0.01
        self.finger_length = 0.06
        self.voxel_size = voxel_size
        scene_cloud = o3d.geometry.PointCloud()
        scene_cloud.points = o3d.utility.Vector3dVector(scene_points)
        scene_cloud = scene_cloud.voxel_down_sample(voxel_size)
        self.scene_points = np.array(scene_cloud.points)

    def detect(self, grasp_group, approach_dist=0.03, collision_thresh=0.05, return_empty_grasp=False, empty_thresh=0.01, return_ious=False):
        """ Detect collision of grasps.

            Input:
                grasp_group: [GraspGroup, M grasps]
                    the grasps to check
                approach_dist: [float]
                    the distance the gripper moves along the approach direction before closing;
                    this shifting space must also be free of scene points
                collision_thresh: [float]
                    a collision is detected if the global collision iou is greater than this threshold
                return_empty_grasp: [bool]
                    if True, also return a mask indicating grasps that enclose no scene points
                empty_thresh: [float]
                    a grasp is marked empty if its inner-space iou is smaller than this threshold;
                    only used when [return_empty_grasp] is True
                return_ious: [bool]
                    if True, return the global collision iou and the per-part collision ious

            Output:
                collision_mask: [numpy.ndarray, (M,), numpy.bool]
                    True implies collision
                [optional] empty_mask: [numpy.ndarray, (M,), numpy.bool]
                    True implies an empty grasp;
                    only returned when [return_empty_grasp] is True
                [optional] iou_list: list of [numpy.ndarray, (M,), numpy.float32]
                    global and part collision ious, containing
                    [global_iou, left_iou, right_iou, bottom_iou, shifting_iou];
                    only returned when [return_ious] is True
        """
        approach_dist = max(approach_dist, self.finger_width)
        T = grasp_group.translations
        R = grasp_group.rotation_matrices
        heights = grasp_group.heights[:, np.newaxis]
        depths = grasp_group.depths[:, np.newaxis]
        widths = grasp_group.widths[:, np.newaxis]
        targets = self.scene_points[np.newaxis, :, :] - T[:, np.newaxis, :]
        targets = np.matmul(targets, R)

        ## collision detection
        # height mask
        mask1 = ((targets[:, :, 2] > -heights / 2) & (targets[:, :, 2] < heights / 2))
        # left finger mask
        mask2 = ((targets[:, :, 0] > depths - self.finger_length) & (targets[:, :, 0] < depths))
        mask3 = (targets[:, :, 1] > -(widths / 2 + self.finger_width))
        mask4 = (targets[:, :, 1] < -widths / 2)
        # right finger mask
        mask5 = (targets[:, :, 1] < (widths / 2 + self.finger_width))
        mask6 = (targets[:, :, 1] > widths / 2)
        # bottom mask
        mask7 = ((targets[:, :, 0] <= depths - self.finger_length)
                 & (targets[:, :, 0] > depths - self.finger_length - self.finger_width))
        # shifting mask
        mask8 = ((targets[:, :, 0] <= depths - self.finger_length - self.finger_width)
                 & (targets[:, :, 0] > depths - self.finger_length - self.finger_width - approach_dist))

        # get collision mask of each point
        left_mask = (mask1 & mask2 & mask3 & mask4)
        right_mask = (mask1 & mask2 & mask5 & mask6)
        bottom_mask = (mask1 & mask3 & mask5 & mask7)
        shifting_mask = (mask1 & mask3 & mask5 & mask8)
        global_mask = (left_mask | right_mask | bottom_mask | shifting_mask)

        # calculate equivalent volume of each part
        left_right_volume = (heights * self.finger_length * self.finger_width / (self.voxel_size**3)).reshape(-1)
        bottom_volume = (heights * (widths + 2 * self.finger_width) * self.finger_width / (self.voxel_size**3)).reshape(-1)
        shifting_volume = (heights * (widths + 2 * self.finger_width) * approach_dist / (self.voxel_size**3)).reshape(-1)
        volume = left_right_volume * 2 + bottom_volume + shifting_volume

        # get collision iou of each part
        global_iou = global_mask.sum(axis=1) / (volume + 1e-6)

        # get collision mask
        collision_mask = (global_iou > collision_thresh)

        if not (return_empty_grasp or return_ious):
            return collision_mask

        ret_value = [collision_mask, ]
        if return_empty_grasp:
            inner_mask = (mask1 & mask2 & (~mask4) & (~mask6))
            inner_volume = (heights * self.finger_length * widths / (self.voxel_size**3)).reshape(-1)
            empty_mask = (inner_mask.sum(axis=-1) / inner_volume < empty_thresh)
            ret_value.append(empty_mask)
        if return_ious:
            left_iou = left_mask.sum(axis=1) / (left_right_volume + 1e-6)
            right_iou = right_mask.sum(axis=1) / (left_right_volume + 1e-6)
            bottom_iou = bottom_mask.sum(axis=1) / (bottom_volume + 1e-6)
            shifting_iou = shifting_mask.sum(axis=1) / (shifting_volume + 1e-6)
            ret_value.append([global_iou, left_iou, right_iou, bottom_iou, shifting_iou])
        return ret_value
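For orientation, a minimal usage sketch (not part of the committed file), assuming the utils directory is on sys.path and grasp predictions are available as a graspnetAPI GraspGroup; the .npy file names are placeholders:

# Illustrative only: filter predicted grasps against the scene cloud.
import numpy as np
from graspnetAPI import GraspGroup            # assumed dependency, as in the GSNet demo
from collision_detector import ModelFreeCollisionDetector

scene_points = np.load('scene_points.npy')    # (N, 3) float32, hypothetical file
gg = GraspGroup()
gg.from_npy('grasps.npy')                     # hypothetical predictions file
mfcdetector = ModelFreeCollisionDetector(scene_points, voxel_size=0.005)
collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=0.01)
gg = gg[~collision_mask]                      # keep only collision-free grasps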
baselines/grasping/GSNet/utils/data_utils.py (new executable file, 156 lines)
@@ -0,0 +1,156 @@
""" Tools for data processing.
Author: chenxi-wang
"""

import numpy as np


class CameraInfo():
    """ Camera intrinsics for point cloud creation. """

    def __init__(self, width, height, fx, fy, cx, cy, scale):
        self.width = width
        self.height = height
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        self.scale = scale

def create_point_cloud_from_depth_image(depth, camera, organized=True):
    """ Generate point cloud using depth image only.

        Input:
            depth: [numpy.ndarray, (H,W), numpy.float32]
                depth image
            camera: [CameraInfo]
                camera intrinsics
            organized: [bool]
                whether to keep the cloud in image shape (H,W,3)

        Output:
            cloud: [numpy.ndarray, (H,W,3)/(H*W,3), numpy.float32]
                generated cloud, (H,W,3) for organized=True, (H*W,3) for organized=False
    """
    assert (depth.shape[0] == camera.height and depth.shape[1] == camera.width)
    xmap = np.arange(camera.width)
    ymap = np.arange(camera.height)
    xmap, ymap = np.meshgrid(xmap, ymap)
    points_z = depth / camera.scale
    points_x = (xmap - camera.cx) * points_z / camera.fx
    points_y = (ymap - camera.cy) * points_z / camera.fy
    cloud = np.stack([points_x, points_y, points_z], axis=-1)
    if not organized:
        cloud = cloud.reshape([-1, 3])
    return cloud
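A quick illustration of the two pieces above (not part of the committed file); the intrinsics are hypothetical and the depth map is random, only the shapes and units matter:

# Illustrative only: depth in millimetres (scale=1000) becomes an organized cloud in metres.
import numpy as np
camera = CameraInfo(width=1280, height=720, fx=900.0, fy=900.0, cx=640.0, cy=360.0, scale=1000.0)
depth = np.random.uniform(400, 800, size=(720, 1280)).astype(np.float32)
cloud = create_point_cloud_from_depth_image(depth, camera, organized=True)
print(cloud.shape)  # (720, 1280, 3)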

def transform_point_cloud(cloud, transform, format='4x4'):
    """ Transform points to new coordinates with transformation matrix.

        Input:
            cloud: [np.ndarray, (N,3), np.float32]
                points in original coordinates
            transform: [np.ndarray, (3,3)/(3,4)/(4,4), np.float32]
                transformation matrix, could be rotation only or rotation+translation
            format: [string, '3x3'/'3x4'/'4x4']
                the shape of the transformation matrix
                '3x3' --> rotation matrix
                '3x4'/'4x4' --> rotation matrix + translation matrix

        Output:
            cloud_transformed: [np.ndarray, (N,3), np.float32]
                points in new coordinates
    """
    if not (format == '3x3' or format == '4x4' or format == '3x4'):
        raise ValueError('Unknown transformation format, only support \'3x3\' or \'4x4\' or \'3x4\'.')
    if format == '3x3':
        cloud_transformed = np.dot(transform, cloud.T).T
    elif format == '4x4' or format == '3x4':
        ones = np.ones(cloud.shape[0])[:, np.newaxis]
        cloud_ = np.concatenate([cloud, ones], axis=1)
        cloud_transformed = np.dot(transform, cloud_.T).T
        cloud_transformed = cloud_transformed[:, :3]
    return cloud_transformed
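A tiny sanity check of the homogeneous branch (illustrative, not part of the committed file):

# A 4x4 pose that translates 10 cm along x moves a point accordingly.
import numpy as np
pose = np.eye(4, dtype=np.float32)
pose[:3, 3] = [0.1, 0.0, 0.0]
pts = np.array([[0.0, 0.0, 0.5]], dtype=np.float32)
print(transform_point_cloud(pts, pose, '4x4'))  # [[0.1, 0.0, 0.5]]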

def compute_point_dists(A, B):
    """ Compute pair-wise point distances in two matrices.

        Input:
            A: [np.ndarray, (N,3), np.float32]
                point cloud A
            B: [np.ndarray, (M,3), np.float32]
                point cloud B

        Output:
            dists: [np.ndarray, (N,M), np.float32]
                distance matrix
    """
    A = A[:, np.newaxis, :]
    B = B[np.newaxis, :, :]
    dists = np.linalg.norm(A - B, axis=-1)
    return dists

def remove_invisible_grasp_points(cloud, grasp_points, pose, th=0.01):
    """ Remove invisible part of object model according to scene point cloud.

        Input:
            cloud: [np.ndarray, (N,3), np.float32]
                scene point cloud
            grasp_points: [np.ndarray, (M,3), np.float32]
                grasp point label in object coordinates
            pose: [np.ndarray, (4,4), np.float32]
                transformation matrix from object coordinates to world coordinates
            th: [float]
                a grasp point is removed if its minimum distance to the scene points is greater than th

        Output:
            visible_mask: [np.ndarray, (M,), np.bool]
                mask to show the visible part of grasp points
    """
    grasp_points_trans = transform_point_cloud(grasp_points, pose)
    dists = compute_point_dists(grasp_points_trans, cloud)
    min_dists = dists.min(axis=1)
    visible_mask = (min_dists < th)
    return visible_mask

def get_workspace_mask(cloud, seg, trans=None, organized=True, outlier=0):
    """ Keep points in workspace as input.

        Input:
            cloud: [np.ndarray, (H,W,3), np.float32]
                scene point cloud
            seg: [np.ndarray, (H,W,), np.uint8]
                segmentation label of scene points
            trans: [np.ndarray, (4,4), np.float32]
                transformation matrix for scene points, default: None
            organized: [bool]
                whether to keep the cloud in image shape (H,W,3)
            outlier: [float]
                a point is removed if its distance to the workspace is greater than outlier

        Output:
            workspace_mask: [np.ndarray, (H,W)/(H*W,), np.bool]
                mask to indicate whether scene points are in workspace
    """
    if organized:
        h, w, _ = cloud.shape
        cloud = cloud.reshape([h * w, 3])
        seg = seg.reshape(h * w)
    if trans is not None:
        cloud = transform_point_cloud(cloud, trans)
    foreground = cloud[seg > 0]
    xmin, ymin, zmin = foreground.min(axis=0)
    xmax, ymax, zmax = foreground.max(axis=0)
    mask_x = ((cloud[:, 0] > xmin - outlier) & (cloud[:, 0] < xmax + outlier))
    mask_y = ((cloud[:, 1] > ymin - outlier) & (cloud[:, 1] < ymax + outlier))
    mask_z = ((cloud[:, 2] > zmin - outlier) & (cloud[:, 2] < zmax + outlier))
    workspace_mask = (mask_x & mask_y & mask_z)
    if organized:
        workspace_mask = workspace_mask.reshape([h, w])

    return workspace_mask
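Taken together, these utilities cover the usual preprocessing path for network input; a hedged sketch follows (the file names and the segmentation source are placeholders, not part of this commit):

# Illustrative preprocessing sketch: organized cloud -> workspace mask -> flattened input points.
import numpy as np
depth = np.load('depth.npy')   # (H, W) float32, hypothetical
seg = np.load('seg.npy')       # (H, W) uint8, hypothetical segmentation labels
camera = CameraInfo(1280, 720, 900.0, 900.0, 640.0, 360.0, 1000.0)
cloud = create_point_cloud_from_depth_image(depth, camera, organized=True)
workspace_mask = get_workspace_mask(cloud, seg, trans=None, organized=True, outlier=0.02)
valid = (workspace_mask & (depth > 0))
cloud_masked = cloud[valid]    # (N, 3) candidate points for the network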
baselines/grasping/GSNet/utils/label_generation.py (new executable file, 143 lines)
@@ -0,0 +1,143 @@
""" Dynamically generate grasp labels during training.
Author: chenxi-wang
"""

import os
import sys
import torch

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)
# sys.path.append(os.path.join(ROOT_DIR, 'knn'))

from knn.knn_modules import knn
from loss_utils import GRASP_MAX_WIDTH, batch_viewpoint_params_to_matrix, \
    transform_point_cloud, generate_grasp_views


def process_grasp_labels(end_points):
    """ Process labels according to scene points and object poses. """
    seed_xyzs = end_points['xyz_graspable']  # (B, M_point, 3)
    batch_size, num_samples, _ = seed_xyzs.size()

    batch_grasp_points = []
    batch_grasp_views_rot = []
    batch_grasp_scores = []
    batch_grasp_widths = []
    for i in range(batch_size):
        seed_xyz = seed_xyzs[i]  # (Ns, 3)
        poses = end_points['object_poses_list'][i]  # [(3, 4),]

        # get merged grasp points for label computation
        grasp_points_merged = []
        grasp_views_rot_merged = []
        grasp_scores_merged = []
        grasp_widths_merged = []
        for obj_idx, pose in enumerate(poses):
            grasp_points = end_points['grasp_points_list'][i][obj_idx]  # (Np, 3)
            grasp_scores = end_points['grasp_scores_list'][i][obj_idx]  # (Np, V, A, D)
            grasp_widths = end_points['grasp_widths_list'][i][obj_idx]  # (Np, V, A, D)
            _, V, A, D = grasp_scores.size()
            num_grasp_points = grasp_points.size(0)
            # generate and transform template grasp views
            grasp_views = generate_grasp_views(V).to(pose.device)  # (V, 3)
            grasp_points_trans = transform_point_cloud(grasp_points, pose, '3x4')
            grasp_views_trans = transform_point_cloud(grasp_views, pose[:3, :3], '3x3')
            # generate and transform template grasp view rotation
            angles = torch.zeros(grasp_views.size(0), dtype=grasp_views.dtype, device=grasp_views.device)
            grasp_views_rot = batch_viewpoint_params_to_matrix(-grasp_views, angles)  # (V, 3, 3)
            grasp_views_rot_trans = torch.matmul(pose[:3, :3], grasp_views_rot)  # (V, 3, 3)

            # assign views
            grasp_views_ = grasp_views.transpose(0, 1).contiguous().unsqueeze(0)
            grasp_views_trans_ = grasp_views_trans.transpose(0, 1).contiguous().unsqueeze(0)
            view_inds = knn(grasp_views_trans_, grasp_views_, k=1).squeeze() - 1
            grasp_views_rot_trans = torch.index_select(grasp_views_rot_trans, 0, view_inds)  # (V, 3, 3)
            grasp_views_rot_trans = grasp_views_rot_trans.unsqueeze(0).expand(num_grasp_points, -1, -1, -1)  # (Np, V, 3, 3)
            grasp_scores = torch.index_select(grasp_scores, 1, view_inds)  # (Np, V, A, D)
            grasp_widths = torch.index_select(grasp_widths, 1, view_inds)  # (Np, V, A, D)
            # add to list
            grasp_points_merged.append(grasp_points_trans)
            grasp_views_rot_merged.append(grasp_views_rot_trans)
            grasp_scores_merged.append(grasp_scores)
            grasp_widths_merged.append(grasp_widths)

        grasp_points_merged = torch.cat(grasp_points_merged, dim=0)  # (Np', 3)
        grasp_views_rot_merged = torch.cat(grasp_views_rot_merged, dim=0)  # (Np', V, 3, 3)
        grasp_scores_merged = torch.cat(grasp_scores_merged, dim=0)  # (Np', V, A, D)
        grasp_widths_merged = torch.cat(grasp_widths_merged, dim=0)  # (Np', V, A, D)

        # compute nearest neighbors
        seed_xyz_ = seed_xyz.transpose(0, 1).contiguous().unsqueeze(0)  # (1, 3, Ns)
        grasp_points_merged_ = grasp_points_merged.transpose(0, 1).contiguous().unsqueeze(0)  # (1, 3, Np')
        nn_inds = knn(grasp_points_merged_, seed_xyz_, k=1).squeeze() - 1  # (Ns)

        # assign anchor points to real points
        grasp_points_merged = torch.index_select(grasp_points_merged, 0, nn_inds)  # (Ns, 3)
        grasp_views_rot_merged = torch.index_select(grasp_views_rot_merged, 0, nn_inds)  # (Ns, V, 3, 3)
        grasp_scores_merged = torch.index_select(grasp_scores_merged, 0, nn_inds)  # (Ns, V, A, D)
        grasp_widths_merged = torch.index_select(grasp_widths_merged, 0, nn_inds)  # (Ns, V, A, D)

        # add to batch
        batch_grasp_points.append(grasp_points_merged)
        batch_grasp_views_rot.append(grasp_views_rot_merged)
        batch_grasp_scores.append(grasp_scores_merged)
        batch_grasp_widths.append(grasp_widths_merged)

    batch_grasp_points = torch.stack(batch_grasp_points, 0)  # (B, Ns, 3)
    batch_grasp_views_rot = torch.stack(batch_grasp_views_rot, 0)  # (B, Ns, V, 3, 3)
    batch_grasp_scores = torch.stack(batch_grasp_scores, 0)  # (B, Ns, V, A, D)
    batch_grasp_widths = torch.stack(batch_grasp_widths, 0)  # (B, Ns, V, A, D)

    # compute view graspness
    view_u_threshold = 0.6
    view_grasp_num = 48
    batch_grasp_view_valid_mask = (batch_grasp_scores <= view_u_threshold) & (batch_grasp_scores > 0)  # (B, Ns, V, A, D)
    batch_grasp_view_valid = batch_grasp_view_valid_mask.float()
    batch_grasp_view_graspness = torch.sum(torch.sum(batch_grasp_view_valid, dim=-1), dim=-1) / view_grasp_num  # (B, Ns, V)
    view_graspness_min, _ = torch.min(batch_grasp_view_graspness, dim=-1)  # (B, Ns)
    view_graspness_max, _ = torch.max(batch_grasp_view_graspness, dim=-1)
    view_graspness_max = view_graspness_max.unsqueeze(-1).expand(-1, -1, 300)  # (B, Ns, V)
    view_graspness_min = view_graspness_min.unsqueeze(-1).expand(-1, -1, 300)  # same shape as batch_grasp_view_graspness
    batch_grasp_view_graspness = (batch_grasp_view_graspness - view_graspness_min) / (view_graspness_max - view_graspness_min + 1e-5)

    # process scores
    label_mask = (batch_grasp_scores > 0) & (batch_grasp_widths <= GRASP_MAX_WIDTH)  # (B, Ns, V, A, D)
    batch_grasp_scores[~label_mask] = 0

    end_points['batch_grasp_point'] = batch_grasp_points
    end_points['batch_grasp_view_rot'] = batch_grasp_views_rot
    end_points['batch_grasp_score'] = batch_grasp_scores
    end_points['batch_grasp_width'] = batch_grasp_widths
    end_points['batch_grasp_view_graspness'] = batch_grasp_view_graspness

    return end_points


def match_grasp_view_and_label(end_points):
    """ Slice grasp labels according to predicted views. """
    top_view_inds = end_points['grasp_top_view_inds']  # (B, Ns)
    template_views_rot = end_points['batch_grasp_view_rot']  # (B, Ns, V, 3, 3)
    grasp_scores = end_points['batch_grasp_score']  # (B, Ns, V, A, D)
    grasp_widths = end_points['batch_grasp_width']  # (B, Ns, V, A, D)

    B, Ns, V, A, D = grasp_scores.size()
    top_view_inds_ = top_view_inds.view(B, Ns, 1, 1, 1).expand(-1, -1, -1, 3, 3)
    top_template_views_rot = torch.gather(template_views_rot, 2, top_view_inds_).squeeze(2)
    top_view_inds_ = top_view_inds.view(B, Ns, 1, 1, 1).expand(-1, -1, -1, A, D)
    top_view_grasp_scores = torch.gather(grasp_scores, 2, top_view_inds_).squeeze(2)
    top_view_grasp_widths = torch.gather(grasp_widths, 2, top_view_inds_).squeeze(2)

    u_max = top_view_grasp_scores.max()
    po_mask = top_view_grasp_scores > 0
    po_mask_num = torch.sum(po_mask)
    if po_mask_num > 0:
        u_min = top_view_grasp_scores[po_mask].min()
        top_view_grasp_scores[po_mask] = torch.log(u_max / top_view_grasp_scores[po_mask]) / (torch.log(u_max / u_min) + 1e-6)

    end_points['batch_grasp_score'] = top_view_grasp_scores  # (B, Ns, A, D)
    end_points['batch_grasp_width'] = top_view_grasp_widths  # (B, Ns, A, D)

    return top_template_views_rot, end_points
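The log mapping in match_grasp_view_and_label rescales positive scores so that the best (smallest) value maps to roughly 1 and the worst (largest) to roughly 0; a toy check follows (illustrative, not part of the committed file), assuming the scores behave like GraspNet friction coefficients where lower is better:

# Toy check of the score normalization used above.
import torch
scores = torch.tensor([0.1, 0.4, 0.8])   # hypothetical positive scores, lower = better
u_max, u_min = scores.max(), scores.min()
mapped = torch.log(u_max / scores) / (torch.log(u_max / u_min) + 1e-6)
print(mapped)                             # ~[1.000, 0.333, 0.000]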
baselines/grasping/GSNet/utils/loss_utils.py (new executable file, 121 lines)
@@ -0,0 +1,121 @@
""" Tools for loss computation.
Author: chenxi-wang
"""

import torch
import numpy as np

GRASP_MAX_WIDTH = 0.1
GRASPNESS_THRESHOLD = 0.1
NUM_VIEW = 300
NUM_ANGLE = 12
NUM_DEPTH = 4
M_POINT = 1024


def transform_point_cloud(cloud, transform, format='4x4'):
    """ Transform points to new coordinates with transformation matrix.

        Input:
            cloud: [torch.FloatTensor, (N,3)]
                points in original coordinates
            transform: [torch.FloatTensor, (3,3)/(3,4)/(4,4)]
                transformation matrix, could be rotation only or rotation+translation
            format: [string, '3x3'/'3x4'/'4x4']
                the shape of the transformation matrix
                '3x3' --> rotation matrix
                '3x4'/'4x4' --> rotation matrix + translation matrix

        Output:
            cloud_transformed: [torch.FloatTensor, (N,3)]
                points in new coordinates
    """
    if not (format == '3x3' or format == '4x4' or format == '3x4'):
        raise ValueError('Unknown transformation format, only support \'3x3\' or \'4x4\' or \'3x4\'.')
    if format == '3x3':
        cloud_transformed = torch.matmul(transform, cloud.T).T
    elif format == '4x4' or format == '3x4':
        ones = cloud.new_ones(cloud.size(0), device=cloud.device).unsqueeze(-1)
        cloud_ = torch.cat([cloud, ones], dim=1)
        cloud_transformed = torch.matmul(transform, cloud_.T).T
        cloud_transformed = cloud_transformed[:, :3]
    return cloud_transformed


def generate_grasp_views(N=300, phi=(np.sqrt(5) - 1) / 2, center=np.zeros(3), r=1):
    """ View sampling on a unit sphere using Fibonacci lattices.
        Ref: https://arxiv.org/abs/0912.4540

        Input:
            N: [int]
                number of sampled views
            phi: [float]
                constant for view coordinate calculation, different phi's bring different distributions, default: (sqrt(5)-1)/2
            center: [np.ndarray, (3,), np.float32]
                sphere center
            r: [float]
                sphere radius

        Output:
            views: [torch.FloatTensor, (N,3)]
                sampled view coordinates
    """
    views = []
    for i in range(N):
        zi = (2 * i + 1) / N - 1
        xi = np.sqrt(1 - zi ** 2) * np.cos(2 * i * np.pi * phi)
        yi = np.sqrt(1 - zi ** 2) * np.sin(2 * i * np.pi * phi)
        views.append([xi, yi, zi])
    views = r * np.array(views) + center
    return torch.from_numpy(views.astype(np.float32))
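A quick sanity check (illustrative, not part of the committed file): all sampled views should lie on the unit sphere when r=1 and the center is the origin:

# Fibonacci-lattice views have unit norm for the default center and radius.
import torch
views = generate_grasp_views(N=300)
norms = torch.norm(views, dim=1)
print(views.shape, norms.min().item(), norms.max().item())  # torch.Size([300, 3]), ~1.0, ~1.0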

def batch_viewpoint_params_to_matrix(batch_towards, batch_angle):
    """ Transform approach vectors and in-plane rotation angles to rotation matrices.

        Input:
            batch_towards: [torch.FloatTensor, (N,3)]
                approach vectors in batch
            batch_angle: [torch.FloatTensor, (N,)]
                in-plane rotation angles in batch

        Output:
            batch_matrix: [torch.FloatTensor, (N,3,3)]
                rotation matrices in batch
    """
    axis_x = batch_towards
    ones = torch.ones(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device)
    zeros = torch.zeros(axis_x.shape[0], dtype=axis_x.dtype, device=axis_x.device)
    axis_y = torch.stack([-axis_x[:, 1], axis_x[:, 0], zeros], dim=-1)
    mask_y = (torch.norm(axis_y, dim=-1) == 0)
    axis_y[mask_y, 1] = 1
    axis_x = axis_x / torch.norm(axis_x, dim=-1, keepdim=True)
    axis_y = axis_y / torch.norm(axis_y, dim=-1, keepdim=True)
    axis_z = torch.cross(axis_x, axis_y)
    sin = torch.sin(batch_angle)
    cos = torch.cos(batch_angle)
    R1 = torch.stack([ones, zeros, zeros, zeros, cos, -sin, zeros, sin, cos], dim=-1)
    R1 = R1.reshape([-1, 3, 3])
    R2 = torch.stack([axis_x, axis_y, axis_z], dim=-1)
    batch_matrix = torch.matmul(R2, R1)
    return batch_matrix
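Another small check (illustrative, not part of the committed file): the output should be a proper rotation matrix, i.e. orthonormal with determinant +1:

# Verify the constructed matrix is a valid rotation for a hypothetical approach vector.
import torch
towards = torch.tensor([[0.0, 0.0, 1.0]])
angle = torch.tensor([0.3])
R = batch_viewpoint_params_to_matrix(towards, angle)
print(torch.det(R))                         # ~1.0
print(torch.matmul(R.transpose(1, 2), R))   # ~identity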

def huber_loss(error, delta=1.0):
    """
    Args:
        error: Torch tensor (d1,d2,...,dk)
        delta: [float]
            threshold at which the loss switches from quadratic to linear
    Returns:
        loss: Torch tensor (d1,d2,...,dk)

    x = error = pred - gt or dist(pred,gt)
    0.5 * |x|^2                 if |x|<=d
    0.5 * d^2 + d * (|x|-d)     if |x|>d
    Author: Charles R. Qi
    Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py
    """
    abs_error = torch.abs(error)
    quadratic = torch.clamp(abs_error, max=delta)
    linear = (abs_error - quadratic)
    loss = 0.5 * quadratic ** 2 + delta * linear
    return loss
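A small numeric check of the piecewise behaviour stated in the docstring (illustrative, not part of the committed file):

# Quadratic below delta, linear above it.
import torch
err = torch.tensor([0.5, 2.0])
print(huber_loss(err, delta=1.0))  # tensor([0.1250, 1.5000])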