This commit is contained in:
2024-10-09 16:13:22 +00:00
commit 0ea3f048dc
437 changed files with 44406 additions and 0 deletions

View File

@@ -0,0 +1,224 @@
import MinkowskiEngine as ME
from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
from models.resnet import ResNetBase
class MinkUNetBase(ResNetBase):
BLOCK = None
PLANES = None
DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1)
LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
PLANES = (32, 64, 128, 256, 256, 128, 96, 96)
INIT_DIM = 32
OUT_TENSOR_STRIDE = 1
# To use the model, must call initialize_coords before forward pass.
# Once data is processed, call clear to reset the model before calling
# initialize_coords
def __init__(self, in_channels, out_channels, D=3):
ResNetBase.__init__(self, in_channels, out_channels, D)
def network_initialization(self, in_channels, out_channels, D):
# Output of the first conv concated to conv6
self.inplanes = self.INIT_DIM
self.conv0p1s1 = ME.MinkowskiConvolution(
in_channels, self.inplanes, kernel_size=5, dimension=D)
self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)
self.conv1p1s2 = ME.MinkowskiConvolution(
self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
self.block1 = self._make_layer(self.BLOCK, self.PLANES[0],
self.LAYERS[0])
self.conv2p2s2 = ME.MinkowskiConvolution(
self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
self.bn2 = ME.MinkowskiBatchNorm(self.inplanes)
self.block2 = self._make_layer(self.BLOCK, self.PLANES[1],
self.LAYERS[1])
self.conv3p4s2 = ME.MinkowskiConvolution(
self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
self.bn3 = ME.MinkowskiBatchNorm(self.inplanes)
self.block3 = self._make_layer(self.BLOCK, self.PLANES[2],
self.LAYERS[2])
self.conv4p8s2 = ME.MinkowskiConvolution(
self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
self.bn4 = ME.MinkowskiBatchNorm(self.inplanes)
self.block4 = self._make_layer(self.BLOCK, self.PLANES[3],
self.LAYERS[3])
self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(
self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D)
self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4])
self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion
self.block5 = self._make_layer(self.BLOCK, self.PLANES[4],
self.LAYERS[4])
self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(
self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D)
self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5])
self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion
self.block6 = self._make_layer(self.BLOCK, self.PLANES[5],
self.LAYERS[5])
self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(
self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D)
self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6])
self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion
self.block7 = self._make_layer(self.BLOCK, self.PLANES[6],
self.LAYERS[6])
self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(
self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D)
self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7])
self.inplanes = self.PLANES[7] + self.INIT_DIM
self.block8 = self._make_layer(self.BLOCK, self.PLANES[7],
self.LAYERS[7])
self.final = ME.MinkowskiConvolution(
self.PLANES[7] * self.BLOCK.expansion,
out_channels,
kernel_size=1,
bias=True,
dimension=D)
self.relu = ME.MinkowskiReLU(inplace=True)
def forward(self, x):
out = self.conv0p1s1(x)
out = self.bn0(out)
out_p1 = self.relu(out)
out = self.conv1p1s2(out_p1)
out = self.bn1(out)
out = self.relu(out)
out_b1p2 = self.block1(out)
out = self.conv2p2s2(out_b1p2)
out = self.bn2(out)
out = self.relu(out)
out_b2p4 = self.block2(out)
out = self.conv3p4s2(out_b2p4)
out = self.bn3(out)
out = self.relu(out)
out_b3p8 = self.block3(out)
# tensor_stride=16
out = self.conv4p8s2(out_b3p8)
out = self.bn4(out)
out = self.relu(out)
out = self.block4(out)
# tensor_stride=8
out = self.convtr4p16s2(out)
out = self.bntr4(out)
out = self.relu(out)
out = ME.cat(out, out_b3p8)
out = self.block5(out)
# tensor_stride=4
out = self.convtr5p8s2(out)
out = self.bntr5(out)
out = self.relu(out)
out = ME.cat(out, out_b2p4)
out = self.block6(out)
# tensor_stride=2
out = self.convtr6p4s2(out)
out = self.bntr6(out)
out = self.relu(out)
out = ME.cat(out, out_b1p2)
out = self.block7(out)
# tensor_stride=1
out = self.convtr7p2s2(out)
out = self.bntr7(out)
out = self.relu(out)
out = ME.cat(out, out_p1)
out = self.block8(out)
return self.final(out)
class MinkUNet14(MinkUNetBase):
BLOCK = BasicBlock
LAYERS = (1, 1, 1, 1, 1, 1, 1, 1)
class MinkUNet18(MinkUNetBase):
BLOCK = BasicBlock
LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
class MinkUNet34(MinkUNetBase):
BLOCK = BasicBlock
LAYERS = (2, 3, 4, 6, 2, 2, 2, 2)
class MinkUNet50(MinkUNetBase):
BLOCK = Bottleneck
LAYERS = (2, 3, 4, 6, 2, 2, 2, 2)
class MinkUNet101(MinkUNetBase):
BLOCK = Bottleneck
LAYERS = (2, 3, 4, 23, 2, 2, 2, 2)
class MinkUNet14A(MinkUNet14):
PLANES = (32, 64, 128, 256, 128, 128, 96, 96)
class MinkUNet14B(MinkUNet14):
PLANES = (32, 64, 128, 256, 128, 128, 128, 128)
class MinkUNet14C(MinkUNet14):
PLANES = (32, 64, 128, 256, 192, 192, 128, 128)
class MinkUNet14Dori(MinkUNet14):
PLANES = (32, 64, 128, 256, 384, 384, 384, 384)
class MinkUNet14E(MinkUNet14):
PLANES = (32, 64, 128, 256, 384, 384, 384, 384)
class MinkUNet14D(MinkUNet14):
PLANES = (32, 64, 128, 256, 192, 192, 192, 192)
class MinkUNet18A(MinkUNet18):
PLANES = (32, 64, 128, 256, 128, 128, 96, 96)
class MinkUNet18B(MinkUNet18):
PLANES = (32, 64, 128, 256, 128, 128, 128, 128)
class MinkUNet18D(MinkUNet18):
PLANES = (32, 64, 128, 256, 384, 384, 384, 384)
class MinkUNet34A(MinkUNet34):
PLANES = (32, 64, 128, 256, 256, 128, 64, 64)
class MinkUNet34B(MinkUNet34):
PLANES = (32, 64, 128, 256, 256, 128, 64, 32)
class MinkUNet34C(MinkUNet34):
PLANES = (32, 64, 128, 256, 256, 128, 96, 96)

View File

@@ -0,0 +1,126 @@
""" GraspNet baseline model definition.
Author: chenxi-wang
"""
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import MinkowskiEngine as ME
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)
from models.backbone_resunet14 import MinkUNet14D
from models.modules import ApproachNet, GraspableNet, CloudCrop, SWADNet
from loss_utils import GRASP_MAX_WIDTH, NUM_VIEW, NUM_ANGLE, NUM_DEPTH, GRASPNESS_THRESHOLD, M_POINT
from label_generation import process_grasp_labels, match_grasp_view_and_label, batch_viewpoint_params_to_matrix
from pointnet2.pointnet2_utils import furthest_point_sample, gather_operation
class GraspNet(nn.Module):
def __init__(self, cylinder_radius=0.05, seed_feat_dim=512, is_training=True):
super().__init__()
self.is_training = is_training
self.seed_feature_dim = seed_feat_dim
self.num_depth = NUM_DEPTH
self.num_angle = NUM_ANGLE
self.M_points = M_POINT
self.num_view = NUM_VIEW
self.backbone = MinkUNet14D(in_channels=3, out_channels=self.seed_feature_dim, D=3)
self.graspable = GraspableNet(seed_feature_dim=self.seed_feature_dim)
self.rotation = ApproachNet(self.num_view, seed_feature_dim=self.seed_feature_dim, is_training=self.is_training)
self.crop = CloudCrop(nsample=16, cylinder_radius=cylinder_radius, seed_feature_dim=self.seed_feature_dim)
self.swad = SWADNet(num_angle=self.num_angle, num_depth=self.num_depth)
def forward(self, end_points):
seed_xyz = end_points['point_clouds'] # use all sampled point cloud, B*Ns*3
B, point_num, _ = seed_xyz.shape # batch _size
# point-wise features
coordinates_batch = end_points['coors']
features_batch = end_points['feats']
mink_input = ME.SparseTensor(features_batch, coordinates=coordinates_batch)
seed_features = self.backbone(mink_input).F
seed_features = seed_features[end_points['quantize2original']].view(B, point_num, -1).transpose(1, 2)
end_points = self.graspable(seed_features, end_points)
seed_features_flipped = seed_features.transpose(1, 2) # B*Ns*feat_dim
objectness_score = end_points['objectness_score']
graspness_score = end_points['graspness_score'].squeeze(1)
objectness_pred = torch.argmax(objectness_score, 1)
objectness_mask = (objectness_pred == 1)
graspness_mask = graspness_score > GRASPNESS_THRESHOLD
graspable_mask = objectness_mask & graspness_mask
seed_features_graspable = []
seed_xyz_graspable = []
graspable_num_batch = 0.
for i in range(B):
cur_mask = graspable_mask[i]
graspable_num_batch += cur_mask.sum()
if graspable_num_batch == 0:
return None
cur_feat = seed_features_flipped[i][cur_mask] # Ns*feat_dim
cur_seed_xyz = seed_xyz[i][cur_mask] # Ns*3
cur_seed_xyz = cur_seed_xyz.unsqueeze(0) # 1*Ns*3
fps_idxs = furthest_point_sample(cur_seed_xyz, self.M_points)
cur_seed_xyz_flipped = cur_seed_xyz.transpose(1, 2).contiguous() # 1*3*Ns
cur_seed_xyz = gather_operation(cur_seed_xyz_flipped, fps_idxs).transpose(1, 2).squeeze(0).contiguous() # Ns*3
cur_feat_flipped = cur_feat.unsqueeze(0).transpose(1, 2).contiguous() # 1*feat_dim*Ns
cur_feat = gather_operation(cur_feat_flipped, fps_idxs).squeeze(0).contiguous() # feat_dim*Ns
seed_features_graspable.append(cur_feat)
seed_xyz_graspable.append(cur_seed_xyz)
seed_xyz_graspable = torch.stack(seed_xyz_graspable, 0) # B*Ns*3
seed_features_graspable = torch.stack(seed_features_graspable) # B*feat_dim*Ns
end_points['xyz_graspable'] = seed_xyz_graspable
end_points['graspable_count_stage1'] = graspable_num_batch / B
end_points, res_feat = self.rotation(seed_features_graspable, end_points)
seed_features_graspable = seed_features_graspable + res_feat
if self.is_training:
end_points = process_grasp_labels(end_points)
grasp_top_views_rot, end_points = match_grasp_view_and_label(end_points)
else:
grasp_top_views_rot = end_points['grasp_top_view_rot']
group_features = self.crop(seed_xyz_graspable.contiguous(), seed_features_graspable.contiguous(), grasp_top_views_rot)
end_points = self.swad(group_features, end_points)
return end_points
def pred_decode(end_points):
batch_size = len(end_points['point_clouds'])
grasp_preds = []
for i in range(batch_size):
grasp_center = end_points['xyz_graspable'][i].float()
grasp_score = end_points['grasp_score_pred'][i].float()
grasp_score = grasp_score.view(M_POINT, NUM_ANGLE*NUM_DEPTH)
grasp_score, grasp_score_inds = torch.max(grasp_score, -1) # [M_POINT]
grasp_score = grasp_score.view(-1, 1)
grasp_angle = (grasp_score_inds // NUM_DEPTH) * np.pi / 12
grasp_depth = (grasp_score_inds % NUM_DEPTH + 1) * 0.01
grasp_depth = grasp_depth.view(-1, 1)
grasp_width = 1.2 * end_points['grasp_width_pred'][i] / 10.
grasp_width = grasp_width.view(M_POINT, NUM_ANGLE*NUM_DEPTH)
grasp_width = torch.gather(grasp_width, 1, grasp_score_inds.view(-1, 1))
grasp_width = torch.clamp(grasp_width, min=0., max=GRASP_MAX_WIDTH)
approaching = -end_points['grasp_top_view_xyz'][i].float()
grasp_rot = batch_viewpoint_params_to_matrix(approaching, grasp_angle)
grasp_rot = grasp_rot.view(M_POINT, 9)
# merge preds
grasp_height = 0.02 * torch.ones_like(grasp_score)
obj_ids = -1 * torch.ones_like(grasp_score)
grasp_preds.append(
torch.cat([grasp_score, grasp_width, grasp_height, grasp_depth, grasp_rot, grasp_center, obj_ids], axis=-1))
return grasp_preds

View File

@@ -0,0 +1,80 @@
import torch.nn as nn
import torch
def get_loss(end_points):
objectness_loss, end_points = compute_objectness_loss(end_points)
graspness_loss, end_points = compute_graspness_loss(end_points)
view_loss, end_points = compute_view_graspness_loss(end_points)
score_loss, end_points = compute_score_loss(end_points)
width_loss, end_points = compute_width_loss(end_points)
loss = objectness_loss + 10 * graspness_loss + 100 * view_loss + 15 * score_loss + 10 * width_loss
end_points['loss/overall_loss'] = loss
return loss, end_points
def compute_objectness_loss(end_points):
criterion = nn.CrossEntropyLoss(reduction='mean')
objectness_score = end_points['objectness_score']
objectness_label = end_points['objectness_label']
loss = criterion(objectness_score, objectness_label)
end_points['loss/stage1_objectness_loss'] = loss
objectness_pred = torch.argmax(objectness_score, 1)
end_points['stage1_objectness_acc'] = (objectness_pred == objectness_label.long()).float().mean()
end_points['stage1_objectness_prec'] = (objectness_pred == objectness_label.long())[
objectness_pred == 1].float().mean()
end_points['stage1_objectness_recall'] = (objectness_pred == objectness_label.long())[
objectness_label == 1].float().mean()
return loss, end_points
def compute_graspness_loss(end_points):
criterion = nn.SmoothL1Loss(reduction='none')
graspness_score = end_points['graspness_score'].squeeze(1)
graspness_label = end_points['graspness_label'].squeeze(-1)
loss_mask = end_points['objectness_label'].bool()
loss = criterion(graspness_score, graspness_label)
loss = loss[loss_mask]
loss = loss.mean()
graspness_score_c = graspness_score.detach().clone()[loss_mask]
graspness_label_c = graspness_label.detach().clone()[loss_mask]
graspness_score_c = torch.clamp(graspness_score_c, 0., 0.99)
graspness_label_c = torch.clamp(graspness_label_c, 0., 0.99)
rank_error = (torch.abs(torch.trunc(graspness_score_c * 20) - torch.trunc(graspness_label_c * 20)) / 20.).mean()
end_points['stage1_graspness_acc_rank_error'] = rank_error
end_points['loss/stage1_graspness_loss'] = loss
return loss, end_points
def compute_view_graspness_loss(end_points):
criterion = nn.SmoothL1Loss(reduction='mean')
view_score = end_points['view_score']
view_label = end_points['batch_grasp_view_graspness']
loss = criterion(view_score, view_label)
end_points['loss/stage2_view_loss'] = loss
return loss, end_points
def compute_score_loss(end_points):
criterion = nn.SmoothL1Loss(reduction='mean')
grasp_score_pred = end_points['grasp_score_pred']
grasp_score_label = end_points['batch_grasp_score']
loss = criterion(grasp_score_pred, grasp_score_label)
end_points['loss/stage3_score_loss'] = loss
return loss, end_points
def compute_width_loss(end_points):
criterion = nn.SmoothL1Loss(reduction='none')
grasp_width_pred = end_points['grasp_width_pred']
grasp_width_label = end_points['batch_grasp_width'] * 10
loss = criterion(grasp_width_pred, grasp_width_label)
grasp_score_label = end_points['batch_grasp_score']
loss_mask = grasp_score_label > 0
loss = loss[loss_mask].mean()
end_points['loss/stage3_width_loss'] = loss
return loss, end_points

View File

@@ -0,0 +1,116 @@
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)
import pointnet2.pytorch_utils as pt_utils
from pointnet2.pointnet2_utils import CylinderQueryAndGroup
from loss_utils import generate_grasp_views, batch_viewpoint_params_to_matrix
class GraspableNet(nn.Module):
def __init__(self, seed_feature_dim):
super().__init__()
self.in_dim = seed_feature_dim
self.conv_graspable = nn.Conv1d(self.in_dim, 3, 1)
def forward(self, seed_features, end_points):
graspable_score = self.conv_graspable(seed_features) # (B, 3, num_seed)
end_points['objectness_score'] = graspable_score[:, :2]
end_points['graspness_score'] = graspable_score[:, 2]
return end_points
class ApproachNet(nn.Module):
def __init__(self, num_view, seed_feature_dim, is_training=True):
super().__init__()
self.num_view = num_view
self.in_dim = seed_feature_dim
self.is_training = is_training
self.conv1 = nn.Conv1d(self.in_dim, self.in_dim, 1)
self.conv2 = nn.Conv1d(self.in_dim, self.num_view, 1)
def forward(self, seed_features, end_points):
B, _, num_seed = seed_features.size()
res_features = F.relu(self.conv1(seed_features), inplace=True)
features = self.conv2(res_features)
view_score = features.transpose(1, 2).contiguous() # (B, num_seed, num_view)
end_points['view_score'] = view_score
if self.is_training:
# normalize view graspness score to 0~1
view_score_ = view_score.clone().detach()
view_score_max, _ = torch.max(view_score_, dim=2)
view_score_min, _ = torch.min(view_score_, dim=2)
view_score_max = view_score_max.unsqueeze(-1).expand(-1, -1, self.num_view)
view_score_min = view_score_min.unsqueeze(-1).expand(-1, -1, self.num_view)
view_score_ = (view_score_ - view_score_min) / (view_score_max - view_score_min + 1e-8)
top_view_inds = []
for i in range(B):
top_view_inds_batch = torch.multinomial(view_score_[i], 1, replacement=False)
top_view_inds.append(top_view_inds_batch)
top_view_inds = torch.stack(top_view_inds, dim=0).squeeze(-1) # B, num_seed
else:
_, top_view_inds = torch.max(view_score, dim=2) # (B, num_seed)
top_view_inds_ = top_view_inds.view(B, num_seed, 1, 1).expand(-1, -1, -1, 3).contiguous()
template_views = generate_grasp_views(self.num_view).to(features.device) # (num_view, 3)
template_views = template_views.view(1, 1, self.num_view, 3).expand(B, num_seed, -1, -1).contiguous()
vp_xyz = torch.gather(template_views, 2, top_view_inds_).squeeze(2) # (B, num_seed, 3)
vp_xyz_ = vp_xyz.view(-1, 3)
batch_angle = torch.zeros(vp_xyz_.size(0), dtype=vp_xyz.dtype, device=vp_xyz.device)
vp_rot = batch_viewpoint_params_to_matrix(-vp_xyz_, batch_angle).view(B, num_seed, 3, 3)
end_points['grasp_top_view_xyz'] = vp_xyz
end_points['grasp_top_view_rot'] = vp_rot
end_points['grasp_top_view_inds'] = top_view_inds
return end_points, res_features
class CloudCrop(nn.Module):
def __init__(self, nsample, seed_feature_dim, cylinder_radius=0.05, hmin=-0.02, hmax=0.04):
super().__init__()
self.nsample = nsample
self.in_dim = seed_feature_dim
self.cylinder_radius = cylinder_radius
mlps = [3 + self.in_dim, 256, 256] # use xyz, so plus 3
self.grouper = CylinderQueryAndGroup(radius=cylinder_radius, hmin=hmin, hmax=hmax, nsample=nsample,
use_xyz=True, normalize_xyz=True)
self.mlps = pt_utils.SharedMLP(mlps, bn=True)
def forward(self, seed_xyz_graspable, seed_features_graspable, vp_rot):
grouped_feature = self.grouper(seed_xyz_graspable, seed_xyz_graspable, vp_rot,
seed_features_graspable) # B*3 + feat_dim*M*K
new_features = self.mlps(grouped_feature) # (batch_size, mlps[-1], M, K)
new_features = F.max_pool2d(new_features, kernel_size=[1, new_features.size(3)]) # (batch_size, mlps[-1], M, 1)
new_features = new_features.squeeze(-1) # (batch_size, mlps[-1], M)
return new_features
class SWADNet(nn.Module):
def __init__(self, num_angle, num_depth):
super().__init__()
self.num_angle = num_angle
self.num_depth = num_depth
self.conv1 = nn.Conv1d(256, 256, 1) # input feat dim need to be consistent with CloudCrop module
self.conv_swad = nn.Conv1d(256, 2*num_angle*num_depth, 1)
def forward(self, vp_features, end_points):
B, _, num_seed = vp_features.size()
vp_features = F.relu(self.conv1(vp_features), inplace=True)
vp_features = self.conv_swad(vp_features)
vp_features = vp_features.view(B, 2, self.num_angle, self.num_depth, num_seed)
vp_features = vp_features.permute(0, 1, 4, 2, 3)
# split prediction
end_points['grasp_score_pred'] = vp_features[:, 0] # B * num_seed * num angle * num_depth
end_points['grasp_width_pred'] = vp_features[:, 1]
return end_points

View File

@@ -0,0 +1,196 @@
import torch.nn as nn
try:
import open3d as o3d
except ImportError:
raise ImportError("Please install open3d with `pip install open3d`.")
import MinkowskiEngine as ME
from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
class ResNetBase(nn.Module):
BLOCK = None
LAYERS = ()
INIT_DIM = 64
PLANES = (64, 128, 256, 512)
def __init__(self, in_channels, out_channels, D=3):
nn.Module.__init__(self)
self.D = D
assert self.BLOCK is not None
self.network_initialization(in_channels, out_channels, D)
self.weight_initialization()
def network_initialization(self, in_channels, out_channels, D):
self.inplanes = self.INIT_DIM
self.conv1 = nn.Sequential(
ME.MinkowskiConvolution(
in_channels, self.inplanes, kernel_size=3, stride=2, dimension=D
),
ME.MinkowskiInstanceNorm(self.inplanes),
ME.MinkowskiReLU(inplace=True),
ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=D),
)
self.layer1 = self._make_layer(
self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2
)
self.layer2 = self._make_layer(
self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2
)
self.layer3 = self._make_layer(
self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2
)
self.layer4 = self._make_layer(
self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2
)
self.conv5 = nn.Sequential(
ME.MinkowskiDropout(),
ME.MinkowskiConvolution(
self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D
),
ME.MinkowskiInstanceNorm(self.inplanes),
ME.MinkowskiGELU(),
)
self.glob_pool = ME.MinkowskiGlobalMaxPooling()
self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True)
def weight_initialization(self):
for m in self.modules():
if isinstance(m, ME.MinkowskiConvolution):
ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu")
if isinstance(m, ME.MinkowskiBatchNorm):
nn.init.constant_(m.bn.weight, 1)
nn.init.constant_(m.bn.bias, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
ME.MinkowskiConvolution(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
dimension=self.D,
),
ME.MinkowskiBatchNorm(planes * block.expansion),
)
layers = []
layers.append(
block(
self.inplanes,
planes,
stride=stride,
dilation=dilation,
downsample=downsample,
dimension=self.D,
)
)
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
block(
self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D
)
)
return nn.Sequential(*layers)
def forward(self, x: ME.SparseTensor):
x = self.conv1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.conv5(x)
x = self.glob_pool(x)
return self.final(x)
class ResNet14(ResNetBase):
BLOCK = BasicBlock
LAYERS = (1, 1, 1, 1)
class ResNet18(ResNetBase):
BLOCK = BasicBlock
LAYERS = (2, 2, 2, 2)
class ResNet34(ResNetBase):
BLOCK = BasicBlock
LAYERS = (3, 4, 6, 3)
class ResNet50(ResNetBase):
BLOCK = Bottleneck
LAYERS = (3, 4, 6, 3)
class ResNet101(ResNetBase):
BLOCK = Bottleneck
LAYERS = (3, 4, 23, 3)
class ResFieldNetBase(ResNetBase):
def network_initialization(self, in_channels, out_channels, D):
field_ch = 32
field_ch2 = 64
self.field_network = nn.Sequential(
ME.MinkowskiSinusoidal(in_channels, field_ch),
ME.MinkowskiBatchNorm(field_ch),
ME.MinkowskiReLU(inplace=True),
ME.MinkowskiLinear(field_ch, field_ch),
ME.MinkowskiBatchNorm(field_ch),
ME.MinkowskiReLU(inplace=True),
ME.MinkowskiToSparseTensor(),
)
self.field_network2 = nn.Sequential(
ME.MinkowskiSinusoidal(field_ch + in_channels, field_ch2),
ME.MinkowskiBatchNorm(field_ch2),
ME.MinkowskiReLU(inplace=True),
ME.MinkowskiLinear(field_ch2, field_ch2),
ME.MinkowskiBatchNorm(field_ch2),
ME.MinkowskiReLU(inplace=True),
ME.MinkowskiToSparseTensor(),
)
ResNetBase.network_initialization(self, field_ch2, out_channels, D)
def forward(self, x: ME.TensorField):
otensor = self.field_network(x)
otensor2 = self.field_network2(otensor.cat_slice(x))
return ResNetBase.forward(self, otensor2)
class ResFieldNet14(ResFieldNetBase):
BLOCK = BasicBlock
LAYERS = (1, 1, 1, 1)
class ResFieldNet18(ResFieldNetBase):
BLOCK = BasicBlock
LAYERS = (2, 2, 2, 2)
class ResFieldNet34(ResFieldNetBase):
BLOCK = BasicBlock
LAYERS = (3, 4, 6, 3)
class ResFieldNet50(ResFieldNetBase):
BLOCK = Bottleneck
LAYERS = (3, 4, 6, 3)
class ResFieldNet101(ResFieldNetBase):
BLOCK = Bottleneck
LAYERS = (3, 4, 23, 3)