success
This commit is contained in:
0
modules/pts_encoder/__init__.py
Executable file
0
modules/pts_encoder/__init__.py
Executable file
12
modules/pts_encoder/abstract_pts_encoder.py
Executable file
12
modules/pts_encoder/abstract_pts_encoder.py
Executable file
@@ -0,0 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PointsEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(PointsEncoder, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def encode_points(self, pts):
|
||||
pass
|
117
modules/pts_encoder/pointnet2_encoder.py
Executable file
117
modules/pts_encoder/pointnet2_encoder.py
Executable file
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import os
|
||||
import sys
|
||||
path = os.path.abspath(__file__)
|
||||
for i in range(3):
|
||||
path = os.path.dirname(path)
|
||||
PROJECT_ROOT = path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
from modules.module_lib.pointnet2_utils.pointnet2.pointnet2_modules import PointnetSAModuleMSG
|
||||
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
|
||||
|
||||
ClsMSG_CFG_Dense = {
|
||||
'NPOINTS': [512, 256, 128, None],
|
||||
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
|
||||
'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]],
|
||||
'MLPS': [[[16, 16, 32], [32, 32, 64]],
|
||||
[[64, 64, 128], [64, 96, 128]],
|
||||
[[128, 196, 256], [128, 196, 256]],
|
||||
[[256, 256, 512], [256, 384, 512]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Light = {
|
||||
'NPOINTS': [512, 256, 128, None],
|
||||
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
|
||||
'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
|
||||
'MLPS': [[[16, 16, 32], [32, 32, 64]],
|
||||
[[64, 64, 128], [64, 96, 128]],
|
||||
[[128, 196, 256], [128, 196, 256]],
|
||||
[[256, 256, 512], [256, 384, 512]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Lighter = {
|
||||
'NPOINTS': [512, 256, 128, 64, None],
|
||||
'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]],
|
||||
'NSAMPLE': [[64], [32], [16], [8], [None]],
|
||||
'MLPS': [[[32, 32, 64]],
|
||||
[[64, 64, 128]],
|
||||
[[128, 196, 256]],
|
||||
[[256, 256, 512]],
|
||||
[[512, 512, 1024]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
|
||||
def select_params(name):
|
||||
if name == 'light':
|
||||
return ClsMSG_CFG_Light
|
||||
elif name == 'lighter':
|
||||
return ClsMSG_CFG_Lighter
|
||||
elif name == 'dense':
|
||||
return ClsMSG_CFG_Dense
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def break_up_pc(pc):
|
||||
xyz = pc[..., 0:3].contiguous()
|
||||
features = (
|
||||
pc[..., 3:].transpose(1, 2).contiguous()
|
||||
if pc.size(-1) > 3 else None
|
||||
)
|
||||
|
||||
return xyz, features
|
||||
|
||||
|
||||
class PointNet2Encoder(PointsEncoder):
|
||||
def encode_points(self, pts):
|
||||
return self.forward(pts)
|
||||
|
||||
def __init__(self, input_channels=6, params_name="light"):
|
||||
super().__init__()
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
channel_in = input_channels
|
||||
selected_params = select_params(params_name)
|
||||
for k in range(selected_params['NPOINTS'].__len__()):
|
||||
mlps = selected_params['MLPS'][k].copy()
|
||||
channel_out = 0
|
||||
for idx in range(mlps.__len__()):
|
||||
mlps[idx] = [channel_in] + mlps[idx]
|
||||
channel_out += mlps[idx][-1]
|
||||
|
||||
self.SA_modules.append(
|
||||
PointnetSAModuleMSG(
|
||||
npoint=selected_params['NPOINTS'][k],
|
||||
radii=selected_params['RADIUS'][k],
|
||||
nsamples=selected_params['NSAMPLE'][k],
|
||||
mlps=mlps,
|
||||
use_xyz=True,
|
||||
bn=True
|
||||
)
|
||||
)
|
||||
channel_in = channel_out
|
||||
|
||||
def forward(self, point_cloud: torch.cuda.FloatTensor):
|
||||
xyz, features = break_up_pc(point_cloud)
|
||||
|
||||
l_xyz, l_features = [xyz], [features]
|
||||
for i in range(len(self.SA_modules)):
|
||||
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
|
||||
l_xyz.append(li_xyz)
|
||||
l_features.append(li_features)
|
||||
return l_features[-1].squeeze(-1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
seed = 100
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
net = PointNet2Encoder(0).cuda()
|
||||
pts = torch.randn(2, 1024, 3).cuda()
|
||||
print(torch.mean(pts, dim=1))
|
||||
pre = net.encode_points(pts)
|
||||
print(pre.shape)
|
117
modules/pts_encoder/pointnet3_encoder.py
Executable file
117
modules/pts_encoder/pointnet3_encoder.py
Executable file
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from modules.module_lib.pointnet2_utils.pointnet2.pointnet2_modules import PointnetSAModuleMSG
|
||||
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
|
||||
|
||||
ClsMSG_CFG_Dense = {
|
||||
'NPOINTS': [512, 256, 128, None],
|
||||
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
|
||||
'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]],
|
||||
'MLPS': [[[16, 16, 32], [32, 32, 64]],
|
||||
[[64, 64, 128], [64, 96, 128]],
|
||||
[[128, 196, 256], [128, 196, 256]],
|
||||
[[256, 256, 512], [256, 384, 512]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Light = {
|
||||
'NPOINTS': [512, 256, 128, None],
|
||||
'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
|
||||
'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
|
||||
'MLPS': [[[16, 16, 32], [32, 32, 64]],
|
||||
[[64, 64, 128], [64, 96, 128]],
|
||||
[[128, 196, 256], [128, 196, 256]],
|
||||
[[256, 256, 512], [256, 384, 512]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
ClsMSG_CFG_Lighter = {
|
||||
'NPOINTS': [512, 256, 128, 64, None],
|
||||
'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]],
|
||||
'NSAMPLE': [[64], [32], [16], [8], [None]],
|
||||
'MLPS': [[[32, 32, 64]],
|
||||
[[64, 64, 128]],
|
||||
[[128, 196, 256]],
|
||||
[[256, 256, 512]],
|
||||
[[512, 512, 1024]]],
|
||||
'DP_RATIO': 0.5,
|
||||
}
|
||||
|
||||
|
||||
def select_params(name):
|
||||
if name == 'light':
|
||||
return ClsMSG_CFG_Light
|
||||
elif name == 'lighter':
|
||||
return ClsMSG_CFG_Lighter
|
||||
elif name == 'dense':
|
||||
return ClsMSG_CFG_Dense
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def break_up_pc(pc):
|
||||
xyz = pc[..., 0:3].contiguous()
|
||||
features = (
|
||||
pc[..., 3:].transpose(1, 2).contiguous()
|
||||
if pc.size(-1) > 3 else None
|
||||
)
|
||||
|
||||
return xyz, features
|
||||
|
||||
|
||||
class PointNet3Encoder(PointsEncoder):
|
||||
def encode_points(self, pts, rgb_feat):
|
||||
return self.forward(pts,rgb_feat)
|
||||
|
||||
def __init__(self, input_channels=6, params_name="light",target_layer=2, rgb_feat_dim=384):
|
||||
super().__init__()
|
||||
self.SA_modules = nn.ModuleList()
|
||||
channel_in = input_channels
|
||||
self.target_layer = target_layer
|
||||
selected_params = select_params(params_name)
|
||||
for k in range(selected_params['NPOINTS'].__len__()):
|
||||
mlps = selected_params['MLPS'][k].copy()
|
||||
channel_out = 0
|
||||
if k==target_layer:
|
||||
channel_in += rgb_feat_dim
|
||||
for idx in range(mlps.__len__()):
|
||||
mlps[idx] = [channel_in] + mlps[idx]
|
||||
channel_out += mlps[idx][-1]
|
||||
|
||||
self.SA_modules.append(
|
||||
PointnetSAModuleMSG(
|
||||
npoint=selected_params['NPOINTS'][k],
|
||||
radii=selected_params['RADIUS'][k],
|
||||
nsamples=selected_params['NSAMPLE'][k],
|
||||
mlps=mlps,
|
||||
use_xyz=True,
|
||||
bn=True
|
||||
)
|
||||
)
|
||||
channel_in = channel_out
|
||||
|
||||
def forward(self, point_cloud: torch.cuda.FloatTensor, rgb_feat):
|
||||
xyz, features = break_up_pc(point_cloud)
|
||||
|
||||
l_xyz, l_features = [xyz], [features]
|
||||
for i in range(len(self.SA_modules)):
|
||||
if i==self.target_layer:
|
||||
rgb_feat = torch.mean(rgb_feat, dim=1)
|
||||
rgb_feat = rgb_feat.unsqueeze(-1).repeat(1,1,l_xyz[i].shape[1])
|
||||
l_features[-1] = torch.cat([l_features[-1], rgb_feat], dim=1)
|
||||
li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
|
||||
l_xyz.append(li_xyz)
|
||||
l_features.append(li_features)
|
||||
return l_features[-1].squeeze(-1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
seed = 100
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
net = PointNet3Encoder(0).cuda()
|
||||
pts = torch.randn(2, 1024, 3).cuda()
|
||||
rgb_feat = torch.randn(2, 384).cuda()
|
||||
print(torch.mean(pts, dim=1))
|
||||
pre = net.encode_points(pts,rgb_feat)
|
||||
print(pre.shape)
|
110
modules/pts_encoder/pointnet_encoder.py
Executable file
110
modules/pts_encoder/pointnet_encoder.py
Executable file
@@ -0,0 +1,110 @@
|
||||
from __future__ import print_function
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.utils.data
|
||||
from torch.autograd import Variable
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
|
||||
|
||||
|
||||
class STNkd(nn.Module):
|
||||
def __init__(self, k=64):
|
||||
super(STNkd, self).__init__()
|
||||
self.conv1 = torch.nn.Conv1d(k, 64, 1)
|
||||
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
||||
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
|
||||
self.fc1 = nn.Linear(1024, 512)
|
||||
self.fc2 = nn.Linear(512, 256)
|
||||
self.fc3 = nn.Linear(256, k * k)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.k = k
|
||||
|
||||
def forward(self, x):
|
||||
batchsize = x.size()[0]
|
||||
x = F.relu(self.conv1(x))
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = torch.max(x, 2, keepdim=True)[0]
|
||||
x = x.view(-1, 1024)
|
||||
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
|
||||
iden = (
|
||||
Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
|
||||
.view(1, self.k * self.k)
|
||||
.repeat(batchsize, 1)
|
||||
)
|
||||
if x.is_cuda:
|
||||
iden = iden.to(x.get_device())
|
||||
x = x + iden
|
||||
x = x.view(-1, self.k, self.k)
|
||||
return x
|
||||
|
||||
|
||||
# NOTE: removed BN
|
||||
class PointNetEncoder(PointsEncoder):
|
||||
|
||||
def __init__(self, global_feat=True, in_dim=3, out_dim=1024, feature_transform=False):
|
||||
super(PointNetEncoder, self).__init__()
|
||||
self.out_dim = out_dim
|
||||
self.feature_transform = feature_transform
|
||||
self.stn = STNkd(k=in_dim)
|
||||
self.conv1 = torch.nn.Conv1d(in_dim, 64, 1)
|
||||
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
||||
self.conv3 = torch.nn.Conv1d(128, 512, 1)
|
||||
self.conv4 = torch.nn.Conv1d(512, out_dim, 1)
|
||||
self.global_feat = global_feat
|
||||
if self.feature_transform:
|
||||
self.f_stn = STNkd(k=64)
|
||||
|
||||
def forward(self, x):
|
||||
n_pts = x.shape[2]
|
||||
trans = self.stn(x)
|
||||
x = x.transpose(2, 1)
|
||||
x = torch.bmm(x, trans)
|
||||
x = x.transpose(2, 1)
|
||||
x = F.relu(self.conv1(x))
|
||||
|
||||
if self.feature_transform:
|
||||
trans_feat = self.f_stn(x)
|
||||
x = x.transpose(2, 1)
|
||||
x = torch.bmm(x, trans_feat)
|
||||
x = x.transpose(2, 1)
|
||||
|
||||
point_feat = x
|
||||
x = F.relu(self.conv2(x))
|
||||
x = F.relu(self.conv3(x))
|
||||
x = self.conv4(x)
|
||||
x = torch.max(x, 2, keepdim=True)[0]
|
||||
x = x.view(-1, self.out_dim)
|
||||
if self.global_feat:
|
||||
return x
|
||||
else:
|
||||
x = x.view(-1, self.out_dim, 1).repeat(1, 1, n_pts)
|
||||
return torch.cat([x, point_feat], 1)
|
||||
|
||||
def encode_points(self, pts):
|
||||
pts = pts.transpose(2, 1)
|
||||
if not self.global_feat:
|
||||
pts_feature = self(pts).transpose(2, 1)
|
||||
else:
|
||||
pts_feature = self(pts)
|
||||
return pts_feature
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sim_data = Variable(torch.rand(32, 2500, 3))
|
||||
|
||||
pointnet_global = PointNetEncoder(global_feat=True)
|
||||
out = pointnet_global.encode_points(sim_data)
|
||||
print("global feat", out.size())
|
||||
|
||||
pointnet = PointNetEncoder(global_feat=False)
|
||||
out = pointnet.encode_points(sim_data)
|
||||
print("point feat", out.size())
|
56
modules/pts_encoder/pts_encoder_factory.py
Executable file
56
modules/pts_encoder/pts_encoder_factory.py
Executable file
@@ -0,0 +1,56 @@
|
||||
import sys
|
||||
import os
|
||||
path = os.path.abspath(__file__)
|
||||
for i in range(3):
|
||||
path = os.path.dirname(path)
|
||||
PROJECT_ROOT = path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
|
||||
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
|
||||
from modules.pts_encoder.pointnet_encoder import PointNetEncoder
|
||||
from modules.pts_encoder.pointnet2_encoder import PointNet2Encoder
|
||||
from modules.pts_encoder.pointnet3_encoder import PointNet3Encoder
|
||||
|
||||
class PointsEncoderFactory:
|
||||
@staticmethod
|
||||
def create(name, config) -> PointsEncoder:
|
||||
general_config = config["general"]
|
||||
pts_encoder_config = config["pts_encoder"][name]
|
||||
if name == "pointnet":
|
||||
return PointNetEncoder(
|
||||
in_dim=general_config["pts_channels"],
|
||||
out_dim=general_config["feature_dim"],
|
||||
global_feat=not general_config["per_point_feature"]
|
||||
)
|
||||
elif name == "pointnet++":
|
||||
return PointNet2Encoder(
|
||||
input_channels=general_config["pts_channels"] - 3,
|
||||
params_name=pts_encoder_config["params_name"]
|
||||
)
|
||||
elif name == "pointnet++rgb":
|
||||
return PointNet3Encoder(
|
||||
input_channels=general_config["pts_channels"] - 3,
|
||||
params_name=pts_encoder_config["params_name"],
|
||||
target_layer=pts_encoder_config["target_layer"],
|
||||
rgb_feat_dim=pts_encoder_config["rgb_feat_dim"]
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder name: {name}")
|
||||
|
||||
|
||||
''' ------------ Debug ------------ '''
|
||||
if __name__ == "__main__":
|
||||
from configs.config import ConfigManager
|
||||
import torch
|
||||
|
||||
pts = torch.rand(32, 1200, 3) # BxNxC
|
||||
ConfigManager.load_config_with('configs/local_train_config.yaml')
|
||||
ConfigManager.print_config()
|
||||
pts_encoder = PointsEncoderFactory.create(name="pointnet++", config=ConfigManager.get("modules"))
|
||||
print(pts_encoder)
|
||||
pts = pts.to("cuda")
|
||||
pts_encoder = pts_encoder.to("cuda")
|
||||
|
||||
pts_feat = pts_encoder.encode_points(pts)
|
||||
|
||||
print(pts_feat.shape)
|
Reference in New Issue
Block a user