update modules and pipeline
This commit is contained in:
parent
913d4e521d
commit
837e1c870a
29
core/pipeline.py
Normal file
29
core/pipeline.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
|
||||||
|
from torch import nn
|
||||||
|
import PytorchBoot.namespace as namespace
|
||||||
|
import PytorchBoot.stereotype as stereotype
|
||||||
|
from PytorchBoot.factory.component_factory import ComponentFactory
|
||||||
|
|
||||||
|
@stereotype.pipeline("nbv_reconstruction_pipeline")
|
||||||
|
class ViewFinderPipeline(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(ViewFinderPipeline, self).__init__()
|
||||||
|
self.config = config
|
||||||
|
self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pts_encoder"])
|
||||||
|
self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pose_encoder"])
|
||||||
|
self.seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["seq_encoder"])
|
||||||
|
self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, config["view_finder"])
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
output = {}
|
||||||
|
pts_list = data['pts_list']
|
||||||
|
pose_list = data['pose_list']
|
||||||
|
pts_feat_list = []
|
||||||
|
pose_feat_list = []
|
||||||
|
for pts,pose in zip(pts_list,pose_list):
|
||||||
|
pts_feat_list.append(self.pts_encoder.encode_points(pts))
|
||||||
|
pose_feat_list.append(self.pose_encoder.encode_pose(pose))
|
||||||
|
seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
|
||||||
|
output['next_best_view'] = self.view_finder.next_best_view(seq_feat)
|
||||||
|
|
||||||
|
return output
|
2
modules/module_lib/__init__.py
Normal file
2
modules/module_lib/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from modules.module_lib.gaussian_fourier_projection import GaussianFourierProjection
|
||||||
|
from modules.module_lib.linear import Linear
|
20
modules/pose_encoder/pose_encoder.py
Normal file
20
modules/pose_encoder/pose_encoder.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from torch import nn
|
||||||
|
import PytorchBoot.stereotype as stereotype
|
||||||
|
|
||||||
|
@stereotype.module("pose_encoder")
|
||||||
|
class PoseEncoder(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(PoseEncoder, self).__init__()
|
||||||
|
self.config = config
|
||||||
|
pose_dim = config["pose_dim"]
|
||||||
|
self.act = nn.ReLU(True)
|
||||||
|
|
||||||
|
self.pose_encoder = nn.Sequential(
|
||||||
|
nn.Linear(pose_dim, 256),
|
||||||
|
self.act,
|
||||||
|
nn.Linear(256, 256),
|
||||||
|
self.act,
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode_pose(self, pose):
|
||||||
|
return self.pose_encoder(pose)
|
@ -9,57 +9,21 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
|
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
|
||||||
import PytorchBoot.stereotype as stereotype
|
import PytorchBoot.stereotype as stereotype
|
||||||
|
|
||||||
class STNkd(nn.Module):
|
|
||||||
def __init__(self, k=64):
|
|
||||||
super(STNkd, self).__init__()
|
|
||||||
self.conv1 = torch.nn.Conv1d(k, 64, 1)
|
|
||||||
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
|
||||||
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
|
|
||||||
self.fc1 = nn.Linear(1024, 512)
|
|
||||||
self.fc2 = nn.Linear(512, 256)
|
|
||||||
self.fc3 = nn.Linear(256, k * k)
|
|
||||||
self.relu = nn.ReLU()
|
|
||||||
|
|
||||||
self.k = k
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
batchsize = x.size()[0]
|
|
||||||
x = F.relu(self.conv1(x))
|
|
||||||
x = F.relu(self.conv2(x))
|
|
||||||
x = F.relu(self.conv3(x))
|
|
||||||
x = torch.max(x, 2, keepdim=True)[0]
|
|
||||||
x = x.view(-1, 1024)
|
|
||||||
|
|
||||||
x = F.relu(self.fc1(x))
|
|
||||||
x = F.relu(self.fc2(x))
|
|
||||||
x = self.fc3(x)
|
|
||||||
|
|
||||||
iden = (
|
|
||||||
Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
|
|
||||||
.view(1, self.k * self.k)
|
|
||||||
.repeat(batchsize, 1)
|
|
||||||
)
|
|
||||||
if x.is_cuda:
|
|
||||||
iden = iden.to(x.get_device())
|
|
||||||
x = x + iden
|
|
||||||
x = x.view(-1, self.k, self.k)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
@stereotype.module("pointnet_encoder")
|
@stereotype.module("pointnet_encoder")
|
||||||
class PointNetEncoder(PointsEncoder):
|
class PointNetEncoder(PointsEncoder):
|
||||||
|
|
||||||
def __init__(self, global_feat=True, in_dim=3, out_dim=1024, feature_transform=False):
|
def __init__(self, config:dict):
|
||||||
super(PointNetEncoder, self).__init__()
|
super(PointNetEncoder, self).__init__()
|
||||||
self.out_dim = out_dim
|
|
||||||
self.feature_transform = feature_transform
|
self.out_dim = config["out_dim"]
|
||||||
self.stn = STNkd(k=in_dim)
|
self.in_dim = config["in_dim"]
|
||||||
self.conv1 = torch.nn.Conv1d(in_dim, 64, 1)
|
self.feature_transform = config.get("feature_transform", False)
|
||||||
|
self.stn = STNkd(k=self.in_dim)
|
||||||
|
self.conv1 = torch.nn.Conv1d(self.in_dim , 64, 1)
|
||||||
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
||||||
self.conv3 = torch.nn.Conv1d(128, 512, 1)
|
self.conv3 = torch.nn.Conv1d(128, 512, 1)
|
||||||
self.conv4 = torch.nn.Conv1d(512, out_dim, 1)
|
self.conv4 = torch.nn.Conv1d(512, self.out_dim , 1)
|
||||||
self.global_feat = global_feat
|
self.global_feat = config["global_feat"]
|
||||||
if self.feature_transform:
|
if self.feature_transform:
|
||||||
self.f_stn = STNkd(k=64)
|
self.f_stn = STNkd(k=64)
|
||||||
|
|
||||||
@ -97,6 +61,41 @@ class PointNetEncoder(PointsEncoder):
|
|||||||
pts_feature = self(pts)
|
pts_feature = self(pts)
|
||||||
return pts_feature
|
return pts_feature
|
||||||
|
|
||||||
|
class STNkd(nn.Module):
|
||||||
|
def __init__(self, k=64):
|
||||||
|
super(STNkd, self).__init__()
|
||||||
|
self.conv1 = torch.nn.Conv1d(k, 64, 1)
|
||||||
|
self.conv2 = torch.nn.Conv1d(64, 128, 1)
|
||||||
|
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
|
||||||
|
self.fc1 = nn.Linear(1024, 512)
|
||||||
|
self.fc2 = nn.Linear(512, 256)
|
||||||
|
self.fc3 = nn.Linear(256, k * k)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
self.k = k
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
batchsize = x.size()[0]
|
||||||
|
x = F.relu(self.conv1(x))
|
||||||
|
x = F.relu(self.conv2(x))
|
||||||
|
x = F.relu(self.conv3(x))
|
||||||
|
x = torch.max(x, 2, keepdim=True)[0]
|
||||||
|
x = x.view(-1, 1024)
|
||||||
|
|
||||||
|
x = F.relu(self.fc1(x))
|
||||||
|
x = F.relu(self.fc2(x))
|
||||||
|
x = self.fc3(x)
|
||||||
|
|
||||||
|
iden = (
|
||||||
|
Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
|
||||||
|
.view(1, self.k * self.k)
|
||||||
|
.repeat(batchsize, 1)
|
||||||
|
)
|
||||||
|
if x.is_cuda:
|
||||||
|
iden = iden.to(x.get_device())
|
||||||
|
x = x + iden
|
||||||
|
x = x.view(-1, self.k, self.k)
|
||||||
|
return x
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
sim_data = Variable(torch.rand(32, 2500, 3))
|
sim_data = Variable(torch.rand(32, 2500, 3))
|
||||||
|
12
modules/seq_encoder/abstract_seq_encoder.py
Normal file
12
modules/seq_encoder/abstract_seq_encoder.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceEncoder(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(SequenceEncoder, self).__init__()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def encode_sequence(self, pts_embedding_list, pose_embedding_list):
|
||||||
|
pass
|
10
modules/seq_encoder/transformer_seq_encoder.py
Normal file
10
modules/seq_encoder/transformer_seq_encoder.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from torch import nn
|
||||||
|
import PytorchBoot.stereotype as stereotype
|
||||||
|
|
||||||
|
@stereotype.module("transformer_seq_encoder")
|
||||||
|
class TransformerSequenceEncoder(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(TransformerSequenceEncoder, self).__init__()
|
||||||
|
self.config = config
|
||||||
|
def encode_sequence(self, pts_embedding_list, pose_embedding_list):
|
||||||
|
pass
|
@ -8,5 +8,5 @@ class ViewFinder(nn.Module):
|
|||||||
super(ViewFinder, self).__init__()
|
super(ViewFinder, self).__init__()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def next_best_view(self, scene_pts_feat, target_pts_feat):
|
def next_best_view(self, seq_feat):
|
||||||
pass
|
pass
|
||||||
|
@ -19,18 +19,18 @@ def zero_module(module):
|
|||||||
|
|
||||||
@stereotype.module("gf_view_finder")
|
@stereotype.module("gf_view_finder")
|
||||||
class GradientFieldViewFinder(ViewFinder):
|
class GradientFieldViewFinder(ViewFinder):
|
||||||
def __init__(self, pose_mode='rot_matrix', regression_head='Rx_Ry', per_point_feature=False,
|
def __init__(self, config):
|
||||||
sample_mode="ode", sampling_steps=None, sde_mode="ve"):
|
|
||||||
|
|
||||||
super(GradientFieldViewFinder, self).__init__()
|
super(GradientFieldViewFinder, self).__init__()
|
||||||
self.regression_head = regression_head
|
|
||||||
self.per_point_feature = per_point_feature
|
self.regression_head = config["regression_head"]
|
||||||
|
self.per_point_feature = config["per_point_feature"]
|
||||||
self.act = nn.ReLU(True)
|
self.act = nn.ReLU(True)
|
||||||
self.sample_mode = sample_mode
|
self.sample_mode = config["sample_mode"]
|
||||||
self.pose_mode = pose_mode
|
self.pose_mode = config["pose_mode"]
|
||||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
|
||||||
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(sde_mode)
|
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
|
||||||
self.sampling_steps = sampling_steps
|
self.sampling_steps = config["sampling_steps"]
|
||||||
|
|
||||||
''' encode pose '''
|
''' encode pose '''
|
||||||
self.pose_encoder = nn.Sequential(
|
self.pose_encoder = nn.Sequential(
|
||||||
@ -49,9 +49,9 @@ class GradientFieldViewFinder(ViewFinder):
|
|||||||
|
|
||||||
''' fusion tail '''
|
''' fusion tail '''
|
||||||
if self.regression_head == 'Rx_Ry':
|
if self.regression_head == 'Rx_Ry':
|
||||||
if pose_mode != 'rot_matrix':
|
if self.pose_mode != 'rot_matrix':
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
if not per_point_feature:
|
if not self.per_point_feature:
|
||||||
''' rotation_x_axis regress head '''
|
''' rotation_x_axis regress head '''
|
||||||
self.fusion_tail_rot_x = nn.Sequential(
|
self.fusion_tail_rot_x = nn.Sequential(
|
||||||
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user