update modules and pipeline
This commit is contained in:
@@ -8,5 +8,5 @@ class ViewFinder(nn.Module):
|
||||
super(ViewFinder, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def next_best_view(self, scene_pts_feat, target_pts_feat):
|
||||
def next_best_view(self, seq_feat):
|
||||
pass
|
||||
|
@@ -19,18 +19,18 @@ def zero_module(module):
|
||||
|
||||
@stereotype.module("gf_view_finder")
|
||||
class GradientFieldViewFinder(ViewFinder):
|
||||
def __init__(self, pose_mode='rot_matrix', regression_head='Rx_Ry', per_point_feature=False,
|
||||
sample_mode="ode", sampling_steps=None, sde_mode="ve"):
|
||||
def __init__(self, config):
|
||||
|
||||
super(GradientFieldViewFinder, self).__init__()
|
||||
self.regression_head = regression_head
|
||||
self.per_point_feature = per_point_feature
|
||||
|
||||
self.regression_head = config["regression_head"]
|
||||
self.per_point_feature = config["per_point_feature"]
|
||||
self.act = nn.ReLU(True)
|
||||
self.sample_mode = sample_mode
|
||||
self.pose_mode = pose_mode
|
||||
pose_dim = PoseUtil.get_pose_dim(pose_mode)
|
||||
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(sde_mode)
|
||||
self.sampling_steps = sampling_steps
|
||||
self.sample_mode = config["sample_mode"]
|
||||
self.pose_mode = config["pose_mode"]
|
||||
pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
|
||||
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
|
||||
self.sampling_steps = config["sampling_steps"]
|
||||
|
||||
''' encode pose '''
|
||||
self.pose_encoder = nn.Sequential(
|
||||
@@ -49,9 +49,9 @@ class GradientFieldViewFinder(ViewFinder):
|
||||
|
||||
''' fusion tail '''
|
||||
if self.regression_head == 'Rx_Ry':
|
||||
if pose_mode != 'rot_matrix':
|
||||
if self.pose_mode != 'rot_matrix':
|
||||
raise NotImplementedError
|
||||
if not per_point_feature:
|
||||
if not self.per_point_feature:
|
||||
''' rotation_x_axis regress head '''
|
||||
self.fusion_tail_rot_x = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 1024 + 1024, 256),
|
||||
|
Reference in New Issue
Block a user