update modules and pipeline

hofee 2024-08-21 17:57:52 +08:00
parent 913d4e521d
commit 837e1c870a
8 changed files with 129 additions and 57 deletions

core/pipeline.py Normal file

@@ -0,0 +1,29 @@
from torch import nn
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory.component_factory import ComponentFactory


@stereotype.pipeline("nbv_reconstruction_pipeline")
class ViewFinderPipeline(nn.Module):
    def __init__(self, config):
        super(ViewFinderPipeline, self).__init__()
        self.config = config
        self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pts_encoder"])
        self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["pose_encoder"])
        self.seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, config["seq_encoder"])
        self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, config["view_finder"])

    def forward(self, data):
        output = {}
        pts_list = data['pts_list']
        pose_list = data['pose_list']
        pts_feat_list = []
        pose_feat_list = []
        for pts,pose in zip(pts_list,pose_list):
            pts_feat_list.append(self.pts_encoder.encode_points(pts))
            pose_feat_list.append(self.pose_encoder.encode_pose(pose))
        seq_feat = self.seq_encoder.encode_sequence(pts_feat_list, pose_feat_list)
        output['next_best_view'] = self.view_finder.next_best_view(seq_feat)
        return output
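
For orientation, a minimal usage sketch (not part of the commit): the stereotype names below match the modules registered in this changeset, but the exact config schema, tensor shapes, and pose dimension are assumptions.

# Hypothetical driver code -- config schema and shapes are assumptions.
import torch

pipeline_config = {
    "pts_encoder": "pointnet_encoder",
    "pose_encoder": "pose_encoder",
    "seq_encoder": "transformer_seq_encoder",
    "view_finder": "gf_view_finder",
}
pipeline = ViewFinderPipeline(pipeline_config)

data = {
    "pts_list": [torch.rand(1, 1024, 3) for _ in range(4)],  # assumed (B, N, 3) point clouds
    "pose_list": [torch.rand(1, 9) for _ in range(4)],       # assumed flattened poses
}
output = pipeline(data)  # output["next_best_view"] holds the predicted view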

modules/module_lib/__init__.py Normal file

@@ -0,0 +1,2 @@
from modules.module_lib.gaussian_fourier_projection import GaussianFourierProjection
from modules.module_lib.linear import Linear


@@ -0,0 +1,20 @@
from torch import nn
import PytorchBoot.stereotype as stereotype


@stereotype.module("pose_encoder")
class PoseEncoder(nn.Module):
    def __init__(self, config):
        super(PoseEncoder, self).__init__()
        self.config = config
        pose_dim = config["pose_dim"]
        self.act = nn.ReLU(True)
        self.pose_encoder = nn.Sequential(
            nn.Linear(pose_dim, 256),
            self.act,
            nn.Linear(256, 256),
            self.act,
        )

    def encode_pose(self, pose):
        return self.pose_encoder(pose)
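
A quick shape check (not part of the commit); pose_dim=9, i.e. a flattened 3x3 rotation, is an assumption:

import torch

encoder = PoseEncoder({"pose_dim": 9})  # assumed pose_dim
pose = torch.rand(4, 9)                 # batch of 4 flattened poses
feat = encoder.encode_pose(pose)        # -> torch.Size([4, 256])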

modules/pts_encoder/pointnet_encoder.py

@@ -9,57 +9,21 @@ import torch.nn.functional as F
 from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
 import PytorchBoot.stereotype as stereotype
 
-class STNkd(nn.Module):
-    def __init__(self, k=64):
-        super(STNkd, self).__init__()
-        self.conv1 = torch.nn.Conv1d(k, 64, 1)
-        self.conv2 = torch.nn.Conv1d(64, 128, 1)
-        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
-        self.fc1 = nn.Linear(1024, 512)
-        self.fc2 = nn.Linear(512, 256)
-        self.fc3 = nn.Linear(256, k * k)
-        self.relu = nn.ReLU()
-        self.k = k
-
-    def forward(self, x):
-        batchsize = x.size()[0]
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-        x = torch.max(x, 2, keepdim=True)[0]
-        x = x.view(-1, 1024)
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        x = self.fc3(x)
-        iden = (
-            Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
-            .view(1, self.k * self.k)
-            .repeat(batchsize, 1)
-        )
-        if x.is_cuda:
-            iden = iden.to(x.get_device())
-        x = x + iden
-        x = x.view(-1, self.k, self.k)
-        return x
-
 @stereotype.module("pointnet_encoder")
 class PointNetEncoder(PointsEncoder):
-    def __init__(self, global_feat=True, in_dim=3, out_dim=1024, feature_transform=False):
+    def __init__(self, config:dict):
         super(PointNetEncoder, self).__init__()
-        self.out_dim = out_dim
-        self.feature_transform = feature_transform
-        self.stn = STNkd(k=in_dim)
-        self.conv1 = torch.nn.Conv1d(in_dim, 64, 1)
+        self.out_dim = config["out_dim"]
+        self.in_dim = config["in_dim"]
+        self.feature_transform = config.get("feature_transform", False)
+        self.stn = STNkd(k=self.in_dim)
+        self.conv1 = torch.nn.Conv1d(self.in_dim , 64, 1)
         self.conv2 = torch.nn.Conv1d(64, 128, 1)
         self.conv3 = torch.nn.Conv1d(128, 512, 1)
-        self.conv4 = torch.nn.Conv1d(512, out_dim, 1)
-        self.global_feat = global_feat
+        self.conv4 = torch.nn.Conv1d(512, self.out_dim , 1)
+        self.global_feat = config["global_feat"]
         if self.feature_transform:
             self.f_stn = STNkd(k=64)
@@ -97,6 +61,41 @@ class PointNetEncoder(PointsEncoder):
         pts_feature = self(pts)
         return pts_feature
 
+class STNkd(nn.Module):
+    def __init__(self, k=64):
+        super(STNkd, self).__init__()
+        self.conv1 = torch.nn.Conv1d(k, 64, 1)
+        self.conv2 = torch.nn.Conv1d(64, 128, 1)
+        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
+        self.fc1 = nn.Linear(1024, 512)
+        self.fc2 = nn.Linear(512, 256)
+        self.fc3 = nn.Linear(256, k * k)
+        self.relu = nn.ReLU()
+        self.k = k
+
+    def forward(self, x):
+        batchsize = x.size()[0]
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = F.relu(self.conv3(x))
+        x = torch.max(x, 2, keepdim=True)[0]
+        x = x.view(-1, 1024)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        iden = (
+            Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
+            .view(1, self.k * self.k)
+            .repeat(batchsize, 1)
+        )
+        if x.is_cuda:
+            iden = iden.to(x.get_device())
+        x = x + iden
+        x = x.view(-1, self.k, self.k)
+        return x
+
 if __name__ == "__main__":
     sim_data = Variable(torch.rand(32, 2500, 3))
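
With the constructor now config-driven, instantiation might look like the sketch below. The key names (in_dim, out_dim, global_feat, feature_transform) come from the new __init__; the values and the encode_points call are illustrative assumptions.

import torch

pointnet_config = {
    "in_dim": 3,          # xyz input channels (assumed value)
    "out_dim": 1024,      # global feature size (assumed value)
    "global_feat": True,
    "feature_transform": False,  # optional; defaults to False via config.get
}
encoder = PointNetEncoder(pointnet_config)
pts = torch.rand(8, 1024, 3)       # assumed (B, N, 3), matching the __main__ smoke test
feat = encoder.encode_points(pts)  # expected (B, out_dim) when global_feat is True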


@@ -0,0 +1,12 @@
from abc import abstractmethod
from torch import nn


class SequenceEncoder(nn.Module):
    def __init__(self):
        super(SequenceEncoder, self).__init__()

    @abstractmethod
    def encode_sequence(self, pts_embedding_list, pose_embedding_list):
        pass


@@ -0,0 +1,10 @@
from torch import nn
import PytorchBoot.stereotype as stereotype


@stereotype.module("transformer_seq_encoder")
class TransformerSequenceEncoder(nn.Module):
    def __init__(self, config):
        super(TransformerSequenceEncoder, self).__init__()
        self.config = config

    def encode_sequence(self, pts_embedding_list, pose_embedding_list):
        pass
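
encode_sequence is left as a stub in this commit. One plausible way to fill it in, sketched below, is to concatenate each frame's point and pose features, run the stacked sequence through an nn.TransformerEncoder, and mean-pool over frames; every dimension and the pooling choice are assumptions, not part of this commit.

import torch
from torch import nn

class TransformerSequenceEncoderSketch(nn.Module):
    """Hypothetical implementation of the stub above; all dims are assumptions."""
    def __init__(self, pts_dim=1024, pose_dim=256, d_model=256, num_layers=2):
        super().__init__()
        self.proj = nn.Linear(pts_dim + pose_dim, d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)

    def encode_sequence(self, pts_embedding_list, pose_embedding_list):
        # Fuse per-frame point and pose features: each frame is (B, pts_dim + pose_dim)
        frames = [torch.cat([pts, pose], dim=-1)
                  for pts, pose in zip(pts_embedding_list, pose_embedding_list)]
        seq = self.proj(torch.stack(frames, dim=1))  # (B, seq_len, d_model)
        seq = self.transformer(seq)
        return seq.mean(dim=1)                       # mean-pool -> (B, d_model)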


@@ -8,5 +8,5 @@ class ViewFinder(nn.Module):
         super(ViewFinder, self).__init__()
 
     @abstractmethod
-    def next_best_view(self, scene_pts_feat, target_pts_feat):
+    def next_best_view(self, seq_feat):
         pass


@@ -19,18 +19,18 @@ def zero_module(module):
 @stereotype.module("gf_view_finder")
 class GradientFieldViewFinder(ViewFinder):
-    def __init__(self, pose_mode='rot_matrix', regression_head='Rx_Ry', per_point_feature=False,
-                 sample_mode="ode", sampling_steps=None, sde_mode="ve"):
+    def __init__(self, config):
         super(GradientFieldViewFinder, self).__init__()
-        self.regression_head = regression_head
-        self.per_point_feature = per_point_feature
+        self.regression_head = config["regression_head"]
+        self.per_point_feature = config["per_point_feature"]
         self.act = nn.ReLU(True)
-        self.sample_mode = sample_mode
-        self.pose_mode = pose_mode
-        pose_dim = PoseUtil.get_pose_dim(pose_mode)
-        self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(sde_mode)
-        self.sampling_steps = sampling_steps
+        self.sample_mode = config["sample_mode"]
+        self.pose_mode = config["pose_mode"]
+        pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
+        self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
+        self.sampling_steps = config["sampling_steps"]
 
         ''' encode pose '''
         self.pose_encoder = nn.Sequential(
@@ -49,9 +49,9 @@ class GradientFieldViewFinder(ViewFinder):
         ''' fusion tail '''
         if self.regression_head == 'Rx_Ry':
-            if pose_mode != 'rot_matrix':
+            if self.pose_mode != 'rot_matrix':
                 raise NotImplementedError
-            if not per_point_feature:
+            if not self.per_point_feature:
                 ''' rotation_x_axis regress head '''
                 self.fusion_tail_rot_x = nn.Sequential(
                     nn.Linear(128 + 256 + 1024 + 1024, 256),
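
Since the constructor's keyword arguments moved into config, a config dict reproducing the old defaults reads directly off this diff:

# Old keyword defaults, expressed in the new config-driven form.
gf_view_finder_config = {
    "pose_mode": "rot_matrix",
    "regression_head": "Rx_Ry",
    "per_point_feature": False,
    "sample_mode": "ode",
    "sampling_steps": None,
    "sde_mode": "ve",
}
view_finder = GradientFieldViewFinder(gf_view_finder_config)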