diff --git a/core/pipeline.py b/core/pipeline.py
new file mode 100644
index 0000000..b7256f5
--- /dev/null
+++ b/core/pipeline.py
@@ -0,0 +1,29 @@
+
+from torch import nn
+import PytorchBoot.namespace as namespace
+import PytorchBoot.stereotype as stereotype
+from PytorchBoot.factory.component_factory import ComponentFactory
+
+@stereotype.pipeline("nbv_reconstruction_pipeline")
+class ViewFinderPipeline(nn.Module):
+    def __init__(self, config):
+        super(ViewFinderPipeline, self).__init__()
+        self.config = config
+        self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pts_encoder"])
+        self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pose_encoder"])
+        self.seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["seq_encoder"])
+        self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, config["view_finder"])
+
+    def forward(self, data):
+        output = {}
+        pts_list = data['pts_list']
+        pose_list = data['pose_list']
+        pts_feat_list = []
+        pose_feat_list = []
+        for pts, pose in zip(pts_list, pose_list):
+            pts_feat_list.append(self.pts_encoder.encode_points(pts))
+            pose_feat_list.append(self.pose_encoder.encode_pose(pose))
+        seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
+        output['next_best_view'] = self.view_finder.next_best_view(seq_feat)
+
+        return output
\ No newline at end of file
diff --git a/modules/module_lib/__init__.py b/modules/module_lib/__init__.py
new file mode 100644
index 0000000..0d9f7bf
--- /dev/null
+++ b/modules/module_lib/__init__.py
@@ -0,0 +1,2 @@
+from modules.module_lib.gaussian_fourier_projection import GaussianFourierProjection
+from modules.module_lib.linear import Linear
\ No newline at end of file
diff --git a/modules/pose_encoder/pose_encoder.py b/modules/pose_encoder/pose_encoder.py
new file mode 100644
index 0000000..aeeeb85
--- /dev/null
+++ b/modules/pose_encoder/pose_encoder.py
@@ -0,0 +1,20 @@
+from torch import nn
+import PytorchBoot.stereotype as stereotype
+
+@stereotype.module("pose_encoder")
+class PoseEncoder(nn.Module):
+    def __init__(self, config):
+        super(PoseEncoder, self).__init__()
+        self.config = config
+        pose_dim = config["pose_dim"]
+        self.act = nn.ReLU(True)
+
+        self.pose_encoder = nn.Sequential(
+            nn.Linear(pose_dim, 256),
+            self.act,
+            nn.Linear(256, 256),
+            self.act,
+        )
+
+    def encode_pose(self, pose):
+        return self.pose_encoder(pose)
diff --git a/modules/pts_encoder/pointnet_encoder.py b/modules/pts_encoder/pointnet_encoder.py
index 8e30261..e4e5110 100644
--- a/modules/pts_encoder/pointnet_encoder.py
+++ b/modules/pts_encoder/pointnet_encoder.py
@@ -9,57 +9,21 @@ import torch.nn.functional as F
 from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
 import PytorchBoot.stereotype as stereotype
 
-
-class STNkd(nn.Module):
-    def __init__(self, k=64):
-        super(STNkd, self).__init__()
-        self.conv1 = torch.nn.Conv1d(k, 64, 1)
-        self.conv2 = torch.nn.Conv1d(64, 128, 1)
-        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
-        self.fc1 = nn.Linear(1024, 512)
-        self.fc2 = nn.Linear(512, 256)
-        self.fc3 = nn.Linear(256, k * k)
-        self.relu = nn.ReLU()
-
-        self.k = k
-
-    def forward(self, x):
-        batchsize = x.size()[0]
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-        x = torch.max(x, 2, keepdim=True)[0]
-        x = x.view(-1, 1024)
-
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        x = self.fc3(x)
-
-        iden = (
-            Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
-            .view(1, self.k * self.k)
-            .repeat(batchsize, 1)
-        )
-        if x.is_cuda:
-            iden = iden.to(x.get_device())
-        x = x + iden
-        x = x.view(-1, self.k, self.k)
-        return x
-
-
 @stereotype.module("pointnet_encoder")
 class PointNetEncoder(PointsEncoder):
-    def __init__(self, global_feat=True, in_dim=3, out_dim=1024, feature_transform=False):
+    def __init__(self, config: dict):
         super(PointNetEncoder, self).__init__()
-        self.out_dim = out_dim
-        self.feature_transform = feature_transform
-        self.stn = STNkd(k=in_dim)
-        self.conv1 = torch.nn.Conv1d(in_dim, 64, 1)
+
+        self.out_dim = config["out_dim"]
+        self.in_dim = config["in_dim"]
+        self.feature_transform = config.get("feature_transform", False)
+        self.stn = STNkd(k=self.in_dim)
+        self.conv1 = torch.nn.Conv1d(self.in_dim, 64, 1)
         self.conv2 = torch.nn.Conv1d(64, 128, 1)
         self.conv3 = torch.nn.Conv1d(128, 512, 1)
-        self.conv4 = torch.nn.Conv1d(512, out_dim, 1)
-        self.global_feat = global_feat
+        self.conv4 = torch.nn.Conv1d(512, self.out_dim, 1)
+        self.global_feat = config["global_feat"]
         if self.feature_transform:
             self.f_stn = STNkd(k=64)
@@ -97,6 +61,41 @@ class PointNetEncoder(PointsEncoder):
         pts_feature = self(pts)
         return pts_feature
 
+class STNkd(nn.Module):
+    def __init__(self, k=64):
+        super(STNkd, self).__init__()
+        self.conv1 = torch.nn.Conv1d(k, 64, 1)
+        self.conv2 = torch.nn.Conv1d(64, 128, 1)
+        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
+        self.fc1 = nn.Linear(1024, 512)
+        self.fc2 = nn.Linear(512, 256)
+        self.fc3 = nn.Linear(256, k * k)
+        self.relu = nn.ReLU()
+
+        self.k = k
+
+    def forward(self, x):
+        batchsize = x.size()[0]
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = F.relu(self.conv3(x))
+        x = torch.max(x, 2, keepdim=True)[0]
+        x = x.view(-1, 1024)
+
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+
+        iden = (
+            Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
+            .view(1, self.k * self.k)
+            .repeat(batchsize, 1)
+        )
+        if x.is_cuda:
+            iden = iden.to(x.get_device())
+        x = x + iden
+        x = x.view(-1, self.k, self.k)
+        return x
 
 if __name__ == "__main__":
     sim_data = Variable(torch.rand(32, 2500, 3))
diff --git a/modules/seq_encoder/abstract_seq_encoder.py b/modules/seq_encoder/abstract_seq_encoder.py
new file mode 100644
index 0000000..5c4a8ba
--- /dev/null
+++ b/modules/seq_encoder/abstract_seq_encoder.py
@@ -0,0 +1,12 @@
+from abc import abstractmethod
+
+from torch import nn
+
+
+class SequenceEncoder(nn.Module):
+    def __init__(self):
+        super(SequenceEncoder, self).__init__()
+
+    @abstractmethod
+    def encode_sequence(self, pts_embedding_list, pose_embedding_list):
+        pass
diff --git a/modules/seq_encoder/transformer_seq_encoder.py b/modules/seq_encoder/transformer_seq_encoder.py
new file mode 100644
index 0000000..42af3b4
--- /dev/null
+++ b/modules/seq_encoder/transformer_seq_encoder.py
@@ -0,0 +1,12 @@
+import PytorchBoot.stereotype as stereotype
+
+from modules.seq_encoder.abstract_seq_encoder import SequenceEncoder
+
+@stereotype.module("transformer_seq_encoder")
+class TransformerSequenceEncoder(SequenceEncoder):
+    def __init__(self, config):
+        super(TransformerSequenceEncoder, self).__init__()
+        self.config = config
+
+    def encode_sequence(self, pts_embedding_list, pose_embedding_list):
+        pass
diff --git a/modules/view_finder/abstract_view_finder.py b/modules/view_finder/abstract_view_finder.py
index b688c16..1516aad 100644
--- a/modules/view_finder/abstract_view_finder.py
+++ b/modules/view_finder/abstract_view_finder.py
@@ -8,5 +8,5 @@ class ViewFinder(nn.Module):
         super(ViewFinder, self).__init__()
 
     @abstractmethod
-    def next_best_view(self, scene_pts_feat, target_pts_feat):
+    def next_best_view(self, seq_feat):
         pass
diff --git a/modules/view_finder/gf_view_finder.py b/modules/view_finder/gf_view_finder.py
index 030760d..47b2b6c 100644
--- a/modules/view_finder/gf_view_finder.py
+++ b/modules/view_finder/gf_view_finder.py
@@ -19,18 +19,18 @@ def zero_module(module):
 
 @stereotype.module("gf_view_finder")
 class GradientFieldViewFinder(ViewFinder):
-    def __init__(self, pose_mode='rot_matrix', regression_head='Rx_Ry', per_point_feature=False,
-                 sample_mode="ode", sampling_steps=None, sde_mode="ve"):
+    def __init__(self, config):
         super(GradientFieldViewFinder, self).__init__()
-        self.regression_head = regression_head
-        self.per_point_feature = per_point_feature
+
+        self.regression_head = config["regression_head"]
+        self.per_point_feature = config["per_point_feature"]
         self.act = nn.ReLU(True)
-        self.sample_mode = sample_mode
-        self.pose_mode = pose_mode
-        pose_dim = PoseUtil.get_pose_dim(pose_mode)
-        self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(sde_mode)
-        self.sampling_steps = sampling_steps
+        self.sample_mode = config["sample_mode"]
+        self.pose_mode = config["pose_mode"]
+        pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
+        self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
+        self.sampling_steps = config["sampling_steps"]
 
         ''' encode pose '''
         self.pose_encoder = nn.Sequential(
@@ -49,9 +49,9 @@ class GradientFieldViewFinder(ViewFinder):
 
         ''' fusion tail '''
         if self.regression_head == 'Rx_Ry':
-            if pose_mode != 'rot_matrix':
+            if self.pose_mode != 'rot_matrix':
                 raise NotImplementedError
-            if not per_point_feature:
+            if not self.per_point_feature:
                 ''' rotation_x_axis regress head '''
                 self.fusion_tail_rot_x = nn.Sequential(
                     nn.Linear(128 + 256 + 1024 + 1024, 256),
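
For reference, a minimal smoke test of the two concrete encoders this change introduces. It is a sketch, not part of the patch: it instantiates the modules directly rather than through PytorchBoot's ComponentFactory, the config values (3-D points, a 9-D pose assumed to be a flattened 3x3 rotation) are illustrative guesses, and it assumes the point encoder accepts (batch, n_points, 3) input as the existing __main__ block in pointnet_encoder.py suggests. Only the config keys are taken from the constructors above.

    import torch

    from modules.pose_encoder.pose_encoder import PoseEncoder
    from modules.pts_encoder.pointnet_encoder import PointNetEncoder

    # Keys mirror the ones read in each __init__ above; values are assumptions.
    pts_encoder = PointNetEncoder({
        "in_dim": 3,        # x, y, z per point
        "out_dim": 1024,    # global feature size (conv4 output channels)
        "global_feat": True,
        "feature_transform": False,
    })
    pose_encoder = PoseEncoder({"pose_dim": 9})  # assumed: flattened 3x3 rotation

    pts = torch.rand(4, 2500, 3)   # same layout as the __main__ block
    pose = torch.rand(4, 9)

    pts_feat = pts_encoder.encode_points(pts)   # expected: (4, 1024) global feature
    pose_feat = pose_encoder.encode_pose(pose)  # (4, 256), from the 256-unit MLP
    print(pts_feat.shape, pose_feat.shape)

The pipeline's forward would then collect one such pair per view into pts_feat_list / pose_feat_list before calling encode_sequence; the TransformerSequenceEncoder stub above does not implement that step yet.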