update modules and pipeline

This commit is contained in:
hofee
2024-08-21 17:57:52 +08:00
parent 913d4e521d
commit 837e1c870a
8 changed files with 129 additions and 57 deletions

View File

@@ -8,5 +8,5 @@ class ViewFinder(nn.Module):
super(ViewFinder, self).__init__()
@abstractmethod
def next_best_view(self, scene_pts_feat, target_pts_feat):
def next_best_view(self, seq_feat):
pass

View File

@@ -19,18 +19,18 @@ def zero_module(module):
@stereotype.module("gf_view_finder")
class GradientFieldViewFinder(ViewFinder):
def __init__(self, pose_mode='rot_matrix', regression_head='Rx_Ry', per_point_feature=False,
sample_mode="ode", sampling_steps=None, sde_mode="ve"):
def __init__(self, config):
super(GradientFieldViewFinder, self).__init__()
self.regression_head = regression_head
self.per_point_feature = per_point_feature
self.regression_head = config["regression_head"]
self.per_point_feature = config["per_point_feature"]
self.act = nn.ReLU(True)
self.sample_mode = sample_mode
self.pose_mode = pose_mode
pose_dim = PoseUtil.get_pose_dim(pose_mode)
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(sde_mode)
self.sampling_steps = sampling_steps
self.sample_mode = config["sample_mode"]
self.pose_mode = config["pose_mode"]
pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
self.sampling_steps = config["sampling_steps"]
''' encode pose '''
self.pose_encoder = nn.Sequential(
@@ -49,9 +49,9 @@ class GradientFieldViewFinder(ViewFinder):
''' fusion tail '''
if self.regression_head == 'Rx_Ry':
if pose_mode != 'rot_matrix':
if self.pose_mode != 'rot_matrix':
raise NotImplementedError
if not per_point_feature:
if not self.per_point_feature:
''' rotation_x_axis regress head '''
self.fusion_tail_rot_x = nn.Sequential(
nn.Linear(128 + 256 + 1024 + 1024, 256),