fix bug for training
This commit is contained in:
@@ -33,19 +33,22 @@ class GradientFieldViewFinder(nn.Module):
|
||||
pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
|
||||
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
|
||||
self.sampling_steps = config["sampling_steps"]
|
||||
self.t_feat_dim = config["t_feat_dim"]
|
||||
self.pose_feat_dim = config["pose_feat_dim"]
|
||||
self.main_feat_dim = config["main_feat_dim"]
|
||||
|
||||
''' encode pose '''
|
||||
self.pose_encoder = nn.Sequential(
|
||||
nn.Linear(pose_dim, 256),
|
||||
nn.Linear(pose_dim, self.pose_feat_dim ),
|
||||
self.act,
|
||||
nn.Linear(256, 256),
|
||||
nn.Linear(self.pose_feat_dim , self.pose_feat_dim ),
|
||||
self.act,
|
||||
)
|
||||
|
||||
''' encode t '''
|
||||
self.t_encoder = nn.Sequential(
|
||||
mlib.GaussianFourierProjection(embed_dim=128),
|
||||
nn.Linear(128, 128),
|
||||
mlib.GaussianFourierProjection(embed_dim=self.t_feat_dim ),
|
||||
nn.Linear(self.t_feat_dim , self.t_feat_dim ),
|
||||
self.act,
|
||||
)
|
||||
|
||||
@@ -56,18 +59,18 @@ class GradientFieldViewFinder(nn.Module):
|
||||
if not self.per_point_feature:
|
||||
''' rotation_x_axis regress head '''
|
||||
self.fusion_tail_rot_x = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 2048, 256),
|
||||
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
self.fusion_tail_rot_y = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 2048, 256),
|
||||
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
''' tranalation regress head '''
|
||||
self.fusion_tail_trans = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 2048, 256),
|
||||
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
|
Reference in New Issue
Block a user