global_only: debug

This commit is contained in:
2024-10-29 16:21:30 +00:00
parent 2487039445
commit e23697eb87
3 changed files with 18 additions and 23 deletions

View File

@@ -20,8 +20,8 @@ class NBVReconstructionPipeline(nn.Module):
self.pose_encoder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["pose_encoder"]
)
self.transformer_seq_encoder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["transformer_seq_encoder"]
self.seq_encoder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["seq_encoder"]
)
self.view_finder = ComponentFactory.create(
namespace.Stereotype.MODULE, self.module_config["view_finder"]
@@ -107,7 +107,7 @@ class NBVReconstructionPipeline(nn.Module):
seq_embedding = pose_feat_seq
embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp))
seq_feat = self.transformer_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds)
seq_feat = self.seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds)
main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))
if torch.isnan(main_feat).any():