fix bug for training
This commit is contained in:
@@ -33,19 +33,22 @@ class GradientFieldViewFinder(nn.Module):
|
||||
pose_dim = PoseUtil.get_pose_dim(self.pose_mode)
|
||||
self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"])
|
||||
self.sampling_steps = config["sampling_steps"]
|
||||
self.t_feat_dim = config["t_feat_dim"]
|
||||
self.pose_feat_dim = config["pose_feat_dim"]
|
||||
self.main_feat_dim = config["main_feat_dim"]
|
||||
|
||||
''' encode pose '''
|
||||
self.pose_encoder = nn.Sequential(
|
||||
nn.Linear(pose_dim, 256),
|
||||
nn.Linear(pose_dim, self.pose_feat_dim ),
|
||||
self.act,
|
||||
nn.Linear(256, 256),
|
||||
nn.Linear(self.pose_feat_dim , self.pose_feat_dim ),
|
||||
self.act,
|
||||
)
|
||||
|
||||
''' encode t '''
|
||||
self.t_encoder = nn.Sequential(
|
||||
mlib.GaussianFourierProjection(embed_dim=128),
|
||||
nn.Linear(128, 128),
|
||||
mlib.GaussianFourierProjection(embed_dim=self.t_feat_dim ),
|
||||
nn.Linear(self.t_feat_dim , self.t_feat_dim ),
|
||||
self.act,
|
||||
)
|
||||
|
||||
@@ -56,18 +59,18 @@ class GradientFieldViewFinder(nn.Module):
|
||||
if not self.per_point_feature:
|
||||
''' rotation_x_axis regress head '''
|
||||
self.fusion_tail_rot_x = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 2048, 256),
|
||||
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
self.fusion_tail_rot_y = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 2048, 256),
|
||||
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
''' tranalation regress head '''
|
||||
self.fusion_tail_trans = nn.Sequential(
|
||||
nn.Linear(128 + 256 + 2048, 256),
|
||||
nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256),
|
||||
self.act,
|
||||
zero_module(nn.Linear(256, 3)),
|
||||
)
|
||||
|
@@ -54,6 +54,7 @@ class PointNetEncoder(nn.Module):
|
||||
|
||||
def encode_points(self, pts):
|
||||
pts = pts.transpose(2, 1)
|
||||
|
||||
if not self.global_feat:
|
||||
pts_feature = self(pts).transpose(2, 1)
|
||||
else:
|
||||
@@ -98,11 +99,24 @@ class STNkd(nn.Module):
|
||||
|
||||
if __name__ == "__main__":
|
||||
sim_data = Variable(torch.rand(32, 2500, 3))
|
||||
|
||||
pointnet_global = PointNetEncoder(global_feat=True)
|
||||
config = {
|
||||
"in_dim": 3,
|
||||
"out_dim": 1024,
|
||||
"global_feat": True,
|
||||
"feature_transform": False
|
||||
}
|
||||
pointnet_global = PointNetEncoder(config)
|
||||
out = pointnet_global.encode_points(sim_data)
|
||||
|
||||
print("global feat", out.size())
|
||||
|
||||
pointnet = PointNetEncoder(global_feat=False)
|
||||
config = {
|
||||
"in_dim": 3,
|
||||
"out_dim": 1024,
|
||||
"global_feat": False,
|
||||
"feature_transform": False
|
||||
}
|
||||
|
||||
pointnet = PointNetEncoder(config)
|
||||
out = pointnet.encode_points(sim_data)
|
||||
print("point feat", out.size())
|
||||
|
@@ -38,7 +38,7 @@ class TransformerSequenceEncoder(nn.Module):
|
||||
|
||||
# Prepare mask for padding
|
||||
max_len = max(lengths)
|
||||
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool)
|
||||
padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
|
||||
# Transformer encoding
|
||||
transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
|
||||
|
||||
|
Reference in New Issue
Block a user