fix nan
This commit is contained in:
parent
5e8684d149
commit
1123e69bff
@ -108,7 +108,10 @@ class NBVReconstructionPipeline(nn.Module):
|
|||||||
partial_point_feat_seq = []
|
partial_point_feat_seq = []
|
||||||
for j in range(seq_len):
|
for j in range(seq_len):
|
||||||
partial_per_point_feat = per_point_feat[scanned_pts_mask[j]]
|
partial_per_point_feat = per_point_feat[scanned_pts_mask[j]]
|
||||||
partial_point_feat = torch.mean(partial_per_point_feat, dim=0) # Tensor(Dp)
|
if partial_per_point_feat.shape[0] == 0:
|
||||||
|
partial_point_feat = torch.zeros(per_point_feat.shape[1], device=device)
|
||||||
|
else:
|
||||||
|
partial_point_feat = torch.mean(partial_per_point_feat, dim=0) # Tensor(Dp)
|
||||||
partial_point_feat_seq.append(partial_point_feat)
|
partial_point_feat_seq.append(partial_point_feat)
|
||||||
partial_point_feat_seq = torch.stack(partial_point_feat_seq, dim=0) # Tensor(S x Dp)
|
partial_point_feat_seq = torch.stack(partial_point_feat_seq, dim=0) # Tensor(S x Dp)
|
||||||
|
|
||||||
@ -122,6 +125,13 @@ class NBVReconstructionPipeline(nn.Module):
|
|||||||
main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))
|
main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))
|
||||||
|
|
||||||
if torch.isnan(main_feat).any():
|
if torch.isnan(main_feat).any():
|
||||||
|
for i in range(len(main_feat)):
|
||||||
|
if torch.isnan(main_feat[i]).any():
|
||||||
|
scanned_pts_mask = scanned_pts_mask_batch[i]
|
||||||
|
Log.info(f"scanned_pts_mask shape: {scanned_pts_mask.shape}")
|
||||||
|
Log.info(f"scanned_pts_mask sum: {scanned_pts_mask.sum()}")
|
||||||
|
import ipdb
|
||||||
|
ipdb.set_trace()
|
||||||
Log.error("nan in main_feat", True)
|
Log.error("nan in main_feat", True)
|
||||||
|
|
||||||
return main_feat
|
return main_feat
|
||||||
|
Loading…
x
Reference in New Issue
Block a user