update gf_view_finder
This commit is contained in:
@@ -88,8 +88,8 @@ def cond_ode_sampler(
|
||||
x = mean_x
|
||||
|
||||
num_steps = xs.shape[0]
|
||||
xs = xs.reshape(batch_size * num_steps, -1)
|
||||
xs = PoseUtil.normalize_rotation(xs, pose_mode)
|
||||
xs = xs.reshape(batch_size*num_steps, -1)
|
||||
xs[:, :-3] = PoseUtil.normalize_rotation(xs[:, :-3], pose_mode)
|
||||
xs = xs.reshape(num_steps, batch_size, -1)
|
||||
x = PoseUtil.normalize_rotation(x, pose_mode)
|
||||
x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
|
||||
return xs.permute(1, 0, 2), x
|
||||
|
@@ -2,6 +2,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
import sys
|
||||
sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")
|
||||
|
||||
from utils.pose import PoseUtil
|
||||
import modules.module_lib as mlib
|
||||
import modules.func_lib as flib
|
||||
@@ -47,7 +50,7 @@ class GradientFieldViewFinder(nn.Module):
|
||||
)
|
||||
|
||||
''' fusion tail '''
|
||||
if self.regression_head == 'Rx_Ry':
|
||||
if self.regression_head == 'Rx_Ry_and_T':
|
||||
if self.pose_mode != 'rot_matrix':
|
||||
raise NotImplementedError
|
||||
if not self.per_point_feature:
|
||||
@@ -62,6 +65,12 @@ class GradientFieldViewFinder(nn.Module):
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
''' tranalation regress head '''
|
||||
self.fusion_tail_trans = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 2048, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
@@ -89,10 +98,11 @@ class GradientFieldViewFinder(nn.Module):
|
||||
total_feat = torch.cat([seq_feat, t_feat, pose_feat], dim=-1)
|
||||
_, std = self.marginal_prob_fn(total_feat, t)
|
||||
|
||||
if self.regression_head == 'Rx_Ry':
|
||||
if self.regression_head == 'Rx_Ry_and_T':
|
||||
rot_x = self.fusion_tail_rot_x(total_feat)
|
||||
rot_y = self.fusion_tail_rot_y(total_feat)
|
||||
out_score = torch.cat([rot_x, rot_y], dim=-1) / (std + 1e-7) # normalisation
|
||||
trans = self.fusion_tail_trans(total_feat)
|
||||
out_score = torch.cat([rot_x, rot_y, trans], dim=-1) / (std+1e-7) # normalisation
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -134,18 +144,24 @@ class GradientFieldViewFinder(nn.Module):
|
||||
|
||||
''' ----------- DEBUG -----------'''
|
||||
if __name__ == "__main__":
|
||||
test_scene_feat = torch.rand(32, 1024).to("cuda:0")
|
||||
test_target_feat = torch.rand(32, 1024).to("cuda:0")
|
||||
test_pose = torch.rand(32, 6).to("cuda:0")
|
||||
config = {
|
||||
"regression_head": "Rx_Ry_and_T",
|
||||
"per_point_feature": False,
|
||||
"pose_mode": "rot_matrix",
|
||||
"sde_mode": "ve",
|
||||
"sampling_steps": 500,
|
||||
"sample_mode": "ode"
|
||||
}
|
||||
test_seq_feat = torch.rand(32, 2048).to("cuda:0")
|
||||
test_pose = torch.rand(32, 9).to("cuda:0")
|
||||
test_t = torch.rand(32, 1).to("cuda:0")
|
||||
view_finder = GradientFieldViewFinder().to("cuda:0")
|
||||
view_finder = GradientFieldViewFinder(config).to("cuda:0")
|
||||
test_data = {
|
||||
'target_feat': test_target_feat,
|
||||
'scene_feat': test_scene_feat,
|
||||
'seq_feat': test_seq_feat,
|
||||
'sampled_pose': test_pose,
|
||||
't': test_t
|
||||
}
|
||||
score = view_finder(test_data)
|
||||
|
||||
result = view_finder.next_best_view(test_scene_feat, test_target_feat)
|
||||
print(result)
|
||||
print(score.shape)
|
||||
res, inprocess = view_finder.next_best_view(test_seq_feat)
|
||||
print(res.shape, inprocess.shape)
|
||||
|
Reference in New Issue
Block a user