baselines/grasping/GSNet/models/graspnet.py (new executable file, 126 lines)
@@ -0,0 +1,126 @@
""" GraspNet baseline model definition.

Author: chenxi-wang
"""

import os
import sys

import numpy as np
import torch
import torch.nn as nn
import MinkowskiEngine as ME

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)

from models.backbone_resunet14 import MinkUNet14D
from models.modules import ApproachNet, GraspableNet, CloudCrop, SWADNet
from loss_utils import GRASP_MAX_WIDTH, NUM_VIEW, NUM_ANGLE, NUM_DEPTH, GRASPNESS_THRESHOLD, M_POINT
from label_generation import process_grasp_labels, match_grasp_view_and_label, batch_viewpoint_params_to_matrix
from pointnet2.pointnet2_utils import furthest_point_sample, gather_operation


class GraspNet(nn.Module):
    def __init__(self, cylinder_radius=0.05, seed_feat_dim=512, is_training=True):
        super().__init__()
        self.is_training = is_training
        self.seed_feature_dim = seed_feat_dim
        self.num_depth = NUM_DEPTH
        self.num_angle = NUM_ANGLE
        self.M_points = M_POINT
        self.num_view = NUM_VIEW

        # Pipeline: sparse 3D backbone -> per-point graspability -> approach
        # view selection -> cylinder grouping -> angle/depth (SWAD) head.
        self.backbone = MinkUNet14D(in_channels=3, out_channels=self.seed_feature_dim, D=3)
        self.graspable = GraspableNet(seed_feature_dim=self.seed_feature_dim)
        self.rotation = ApproachNet(self.num_view, seed_feature_dim=self.seed_feature_dim, is_training=self.is_training)
        self.crop = CloudCrop(nsample=16, cylinder_radius=cylinder_radius, seed_feature_dim=self.seed_feature_dim)
        self.swad = SWADNet(num_angle=self.num_angle, num_depth=self.num_depth)

    def forward(self, end_points):
        seed_xyz = end_points['point_clouds']  # all sampled points, B*Ns*3
        B, point_num, _ = seed_xyz.shape  # batch size, number of points

        # Point-wise features from the sparse backbone, mapped back to the
        # original dense point ordering via the quantization inverse map.
        coordinates_batch = end_points['coors']
        features_batch = end_points['feats']
        mink_input = ME.SparseTensor(features_batch, coordinates=coordinates_batch)
        seed_features = self.backbone(mink_input).F
        seed_features = seed_features[end_points['quantize2original']].view(B, point_num, -1).transpose(1, 2)

        # A point is graspable if it is classified as on-object and its
        # predicted graspness exceeds the threshold.
        end_points = self.graspable(seed_features, end_points)
        seed_features_flipped = seed_features.transpose(1, 2)  # B*Ns*feat_dim
        objectness_score = end_points['objectness_score']
        graspness_score = end_points['graspness_score'].squeeze(1)
        objectness_pred = torch.argmax(objectness_score, 1)
        objectness_mask = (objectness_pred == 1)
        graspness_mask = graspness_score > GRASPNESS_THRESHOLD
        graspable_mask = objectness_mask & graspness_mask

        # Keep only graspable points, then furthest-point-sample M_points seeds
        # per scene so every batch entry keeps the same number of candidates.
        seed_features_graspable = []
        seed_xyz_graspable = []
        graspable_num_batch = 0.
        for i in range(B):
            cur_mask = graspable_mask[i]
            if cur_mask.sum() == 0:
                # No graspable points in this scene; sampling below would fail.
                return None
            graspable_num_batch += cur_mask.sum()
            cur_feat = seed_features_flipped[i][cur_mask]  # Ns*feat_dim
            cur_seed_xyz = seed_xyz[i][cur_mask]  # Ns*3

            cur_seed_xyz = cur_seed_xyz.unsqueeze(0)  # 1*Ns*3
            fps_idxs = furthest_point_sample(cur_seed_xyz, self.M_points)
            cur_seed_xyz_flipped = cur_seed_xyz.transpose(1, 2).contiguous()  # 1*3*Ns
            cur_seed_xyz = gather_operation(cur_seed_xyz_flipped, fps_idxs).transpose(1, 2).squeeze(0).contiguous()  # M_points*3
            cur_feat_flipped = cur_feat.unsqueeze(0).transpose(1, 2).contiguous()  # 1*feat_dim*Ns
            cur_feat = gather_operation(cur_feat_flipped, fps_idxs).squeeze(0).contiguous()  # feat_dim*M_points

            seed_features_graspable.append(cur_feat)
            seed_xyz_graspable.append(cur_seed_xyz)
        seed_xyz_graspable = torch.stack(seed_xyz_graspable, 0)  # B*M_points*3
        seed_features_graspable = torch.stack(seed_features_graspable)  # B*feat_dim*M_points

        end_points['xyz_graspable'] = seed_xyz_graspable
        end_points['graspable_count_stage1'] = graspable_num_batch / B

        # Predict an approach view per seed; the residual features from the
        # view head are added back to the seed features.
        end_points, res_feat = self.rotation(seed_features_graspable, end_points)
        seed_features_graspable = seed_features_graspable + res_feat

        if self.is_training:
            # During training, generate grasp labels and match them to the predicted views.
            end_points = process_grasp_labels(end_points)
            grasp_top_views_rot, end_points = match_grasp_view_and_label(end_points)
        else:
            grasp_top_views_rot = end_points['grasp_top_view_rot']

        group_features = self.crop(seed_xyz_graspable.contiguous(), seed_features_graspable.contiguous(), grasp_top_views_rot)
        end_points = self.swad(group_features, end_points)

        return end_points
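

# Training-mode sketch: with is_training=True, forward() additionally consumes
# the ground-truth tensors read by process_grasp_labels() (see
# label_generation.py) and returns matched labels alongside predictions for
# the loss. A rough training step, assuming a loss helper named get_loss
# (an assumed name; no loss function is defined in this file):
#
#     net = GraspNet(is_training=True).cuda()
#     optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
#     end_points = net(batch)          # batch built by the dataset collate_fn
#     loss, end_points = get_loss(end_points)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()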


def pred_decode(end_points):
    batch_size = len(end_points['point_clouds'])
    grasp_preds = []
    for i in range(batch_size):
        grasp_center = end_points['xyz_graspable'][i].float()

        # For each seed point, pick the (angle, depth) bin with the highest score.
        grasp_score = end_points['grasp_score_pred'][i].float()
        grasp_score = grasp_score.view(M_POINT, NUM_ANGLE * NUM_DEPTH)
        grasp_score, grasp_score_inds = torch.max(grasp_score, -1)  # [M_POINT]
        grasp_score = grasp_score.view(-1, 1)
        grasp_angle = (grasp_score_inds // NUM_DEPTH) * np.pi / NUM_ANGLE  # angle bins evenly divide pi
        grasp_depth = (grasp_score_inds % NUM_DEPTH + 1) * 0.01  # depth bins are 1 cm apart
        grasp_depth = grasp_depth.view(-1, 1)
        grasp_width = 1.2 * end_points['grasp_width_pred'][i] / 10.  # rescale the width regression
        grasp_width = grasp_width.view(M_POINT, NUM_ANGLE * NUM_DEPTH)
        grasp_width = torch.gather(grasp_width, 1, grasp_score_inds.view(-1, 1))
        grasp_width = torch.clamp(grasp_width, min=0., max=GRASP_MAX_WIDTH)

        # Build the full rotation from the approach direction and in-plane angle.
        approaching = -end_points['grasp_top_view_xyz'][i].float()
        grasp_rot = batch_viewpoint_params_to_matrix(approaching, grasp_angle)
        grasp_rot = grasp_rot.view(M_POINT, 9)

        # Merge predictions into the 17-dim GraspNet grasp format:
        # score, width, height, depth, rotation (9), center (3), object id.
        grasp_height = 0.02 * torch.ones_like(grasp_score)
        obj_ids = -1 * torch.ones_like(grasp_score)
        grasp_preds.append(
            torch.cat([grasp_score, grasp_width, grasp_height, grasp_depth, grasp_rot, grasp_center, obj_ids], dim=-1))
    return grasp_preds
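

# Minimal inference sketch. The collate pattern mirrors typical MinkowskiEngine
# usage; the 0.005 m voxel size and the constant 3-dim features (matching the
# backbone's in_channels=3) are assumptions, not values fixed by this file.
# The pointnet2 CUDA ops require a GPU.
if __name__ == '__main__':
    voxel_size = 0.005  # assumed voxel size; the real value comes from the config
    cloud = np.random.rand(15000, 3).astype(np.float32)  # synthetic Ns*3 scene

    # Batch, then quantize coordinates for the sparse backbone, keeping the
    # inverse map so dense per-point features can be recovered.
    coordinates_batch, features_batch = ME.utils.sparse_collate(
        [cloud / voxel_size], [np.ones_like(cloud)])
    coordinates_batch, features_batch, _, quantize2original = ME.utils.sparse_quantize(
        coordinates_batch, features_batch, return_index=True, return_inverse=True)

    end_points = {
        'point_clouds': torch.from_numpy(cloud).unsqueeze(0).cuda(),  # 1*Ns*3
        'coors': coordinates_batch.cuda(),
        'feats': features_batch.cuda(),
        'quantize2original': quantize2original.cuda(),
    }

    net = GraspNet(is_training=False).cuda().eval()
    with torch.no_grad():
        end_points = net(end_points)
    if end_points is not None:  # None means no graspable points were found
        grasp_preds = pred_decode(end_points)  # list of M_POINT*17 tensors
        print(grasp_preds[0].shape)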