fix bug for training
This commit is contained in:
@@ -54,6 +54,7 @@ class PointNetEncoder(nn.Module):
|
||||
|
||||
def encode_points(self, pts):
|
||||
pts = pts.transpose(2, 1)
|
||||
|
||||
if not self.global_feat:
|
||||
pts_feature = self(pts).transpose(2, 1)
|
||||
else:
|
||||
@@ -98,11 +99,24 @@ class STNkd(nn.Module):
|
||||
|
||||
if __name__ == "__main__":
|
||||
sim_data = Variable(torch.rand(32, 2500, 3))
|
||||
|
||||
pointnet_global = PointNetEncoder(global_feat=True)
|
||||
config = {
|
||||
"in_dim": 3,
|
||||
"out_dim": 1024,
|
||||
"global_feat": True,
|
||||
"feature_transform": False
|
||||
}
|
||||
pointnet_global = PointNetEncoder(config)
|
||||
out = pointnet_global.encode_points(sim_data)
|
||||
|
||||
print("global feat", out.size())
|
||||
|
||||
pointnet = PointNetEncoder(global_feat=False)
|
||||
config = {
|
||||
"in_dim": 3,
|
||||
"out_dim": 1024,
|
||||
"global_feat": False,
|
||||
"feature_transform": False
|
||||
}
|
||||
|
||||
pointnet = PointNetEncoder(config)
|
||||
out = pointnet.encode_points(sim_data)
|
||||
print("point feat", out.size())
|
||||
|
Reference in New Issue
Block a user