fix bug for training

This commit is contained in:
2024-09-12 15:11:09 +08:00
parent a79ca7749d
commit 4c69ed777b
15 changed files with 201 additions and 120 deletions

View File

@@ -33,19 +33,22 @@ class GradientFieldViewFinder(nn.Module):
pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
self.sampling_steps = config["sampling_steps"]
self.t_feat_dim = config["t_feat_dim"]
self.pose_feat_dim = config["pose_feat_dim"]
self.main_feat_dim = config["main_feat_dim"]
''' encode pose '''
self.pose_encoder = nn.Sequential(
nn.Linear(pose_dim, 256),
nn.Linear(pose_dim, self.pose_feat_dim ),
self.act,
nn.Linear(256, 256),
nn.Linear(self.pose_feat_dim , self.pose_feat_dim ),
self.act,
)
''' encode t '''
self.t_encoder = nn.Sequential(
mlib.GaussianFourierProjection(embed_dim=128),
nn.Linear(128, 128),
mlib.GaussianFourierProjection(embed_dim=self.t_feat_dim ),
nn.Linear(self.t_feat_dim , self.t_feat_dim ),
self.act,
)
@@ -56,18 +59,18 @@ class GradientFieldViewFinder(nn.Module):
if not self.per_point_feature:
''' rotation_x_axis regress head '''
self.fusion_tail_rot_x = nn.Sequential(
nn.Linear(128 + 256 + 2048, 256),
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
self.fusion_tail_rot_y = nn.Sequential(
nn.Linear(128 + 256 + 2048, 256),
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
''' tranalation regress head '''
self.fusion_tail_trans = nn.Sequential(
nn.Linear(128 + 256 + 2048, 256),
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)

View File

@@ -54,6 +54,7 @@ class PointNetEncoder(nn.Module):
def encode_points(self, pts):
pts = pts.transpose(2, 1)
if not self.global_feat:
pts_feature = self(pts).transpose(2, 1)
else:
@@ -98,11 +99,24 @@ class STNkd(nn.Module):
if __name__ == "__main__":
sim_data = Variable(torch.rand(32, 2500, 3))
pointnet_global = PointNetEncoder(global_feat=True)
config = {
"in_dim": 3,
"out_dim": 1024,
"global_feat": True,
"feature_transform": False
}
pointnet_global = PointNetEncoder(config)
out = pointnet_global.encode_points(sim_data)
print("global feat", out.size())
pointnet = PointNetEncoder(global_feat=False)
config = {
"in_dim": 3,
"out_dim": 1024,
"global_feat": False,
"feature_transform": False
}
pointnet = PointNetEncoder(config)
out = pointnet.encode_points(sim_data)
print("point feat", out.size())

View File

@@ -38,7 +38,7 @@ class TransformerSequenceEncoder(nn.Module):
# Prepare mask for padding
max_len = max(lengths)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool)
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
# Transformer encoding
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)