first commit
This commit is contained in:
91
modules/mlp_view_finder.py
Normal file
91
modules/mlp_view_finder.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
from utils.pose import PoseUtil
|
||||
import modules.module_lib as mlib
|
||||
import modules.func_lib as flib
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
@stereotype.module("mlp_view_finder")
|
||||
class MLPViewFinder(nn.Module):
|
||||
def __init__(self, config):
|
||||
|
||||
super(MLPViewFinder, self).__init__()
|
||||
|
||||
self.regression_head = 'Rx_Ry_and_T'
|
||||
self.per_point_feature = False
|
||||
self.act = nn.ReLU(True)
|
||||
self.main_feat_dim = config["main_feat_dim"]
|
||||
|
||||
''' rotation_x_axis regress head '''
|
||||
self.fusion_tail_rot_x = nn.Sequential(
|
||||
nn.Linear(self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
self.fusion_tail_rot_y = nn.Sequential(
|
||||
nn.Linear(self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
''' tranalation regress head '''
|
||||
self.fusion_tail_trans = nn.Sequential(
|
||||
nn.Linear(self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
|
||||
|
||||
def forward(self, data):
|
||||
"""
|
||||
Args:
|
||||
data, dict {
|
||||
'main_feat': [bs, c]
|
||||
}
|
||||
"""
|
||||
|
||||
total_feat = data['main_feat']
|
||||
rot_x = self.fusion_tail_rot_x(total_feat)
|
||||
rot_y = self.fusion_tail_rot_y(total_feat)
|
||||
trans = self.fusion_tail_trans(total_feat)
|
||||
output = torch.cat([rot_x,rot_y,trans], dim=-1)
|
||||
return output
|
||||
|
||||
def next_best_view(self, main_feat):
|
||||
data = {
|
||||
'main_feat': main_feat,
|
||||
}
|
||||
res = self(data)
|
||||
return res.to(dtype=torch.float32), None
|
||||
|
||||
''' ----------- DEBUG -----------'''
|
||||
if __name__ == "__main__":
|
||||
config = {
|
||||
"regression_head": "Rx_Ry_and_T",
|
||||
"per_point_feature": False,
|
||||
"pose_mode": "rot_matrix",
|
||||
"sde_mode": "ve",
|
||||
"sampling_steps": 500,
|
||||
"sample_mode": "ode"
|
||||
}
|
||||
test_seq_feat = torch.rand(32, 2048).to("cuda:0")
|
||||
test_pose = torch.rand(32, 9).to("cuda:0")
|
||||
test_t = torch.rand(32, 1).to("cuda:0")
|
||||
view_finder = GradientFieldViewFinder(config).to("cuda:0")
|
||||
test_data = {
|
||||
'seq_feat': test_seq_feat,
|
||||
'sampled_pose': test_pose,
|
||||
't': test_t
|
||||
}
|
||||
score = view_finder(test_data)
|
||||
print(score.shape)
|
||||
res, inprocess = view_finder.next_best_view(test_seq_feat)
|
||||
print(res.shape, inprocess.shape)
|
Reference in New Issue
Block a user