fix bug for training

This commit is contained in:
2024-09-12 15:11:09 +08:00
parent a79ca7749d
commit 4c69ed777b
15 changed files with 201 additions and 120 deletions

View File

@@ -33,19 +33,22 @@ class GradientFieldViewFinder(nn.Module):
pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
self.sampling_steps = config["sampling_steps"]
self.t_feat_dim = config["t_feat_dim"]
self.pose_feat_dim = config["pose_feat_dim"]
self.main_feat_dim = config["main_feat_dim"]
''' encode pose '''
self.pose_encoder = nn.Sequential(
nn.Linear(pose_dim, 256),
nn.Linear(pose_dim, self.pose_feat_dim ),
self.act,
nn.Linear(256, 256),
nn.Linear(self.pose_feat_dim , self.pose_feat_dim ),
self.act,
)
''' encode t '''
self.t_encoder = nn.Sequential(
mlib.GaussianFourierProjection(embed_dim=128),
nn.Linear(128, 128),
mlib.GaussianFourierProjection(embed_dim=self.t_feat_dim ),
nn.Linear(self.t_feat_dim , self.t_feat_dim ),
self.act,
)
@@ -56,18 +59,18 @@ class GradientFieldViewFinder(nn.Module):
if not self.per_point_feature:
''' rotation_x_axis regress head '''
self.fusion_tail_rot_x = nn.Sequential(
nn.Linear(128 + 256 + 2048, 256),
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
self.fusion_tail_rot_y = nn.Sequential(
nn.Linear(128 + 256 + 2048, 256),
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
''' tranalation regress head '''
self.fusion_tail_trans = nn.Sequential(
nn.Linear(128 + 256 + 2048, 256),
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)