update transformer_seq_encoder's config
This commit is contained in:
@@ -12,21 +12,24 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
super(NBVReconstructionGlobalPointsPipeline, self).__init__()
|
||||
self.config = config
|
||||
self.module_config = config["modules"]
|
||||
|
||||
self.pts_encoder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["pts_encoder"]
|
||||
)
|
||||
self.pose_encoder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["pose_encoder"]
|
||||
)
|
||||
self.pose_n_num_seq_encoder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["pose_n_num_seq_encoder"]
|
||||
self.pts_num_encoder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]
|
||||
)
|
||||
|
||||
self.transformer_seq_encoder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["transformer_seq_encoder"]
|
||||
)
|
||||
self.view_finder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["view_finder"]
|
||||
)
|
||||
self.pts_num_encoder = ComponentFactory.create(
|
||||
namespace.Stereotype.MODULE, self.module_config["pts_num_encoder"]
|
||||
)
|
||||
|
||||
|
||||
self.eps = float(self.config["eps"])
|
||||
self.enable_global_scanned_feat = self.config["global_scanned_feat"]
|
||||
@@ -128,7 +131,7 @@ class NBVReconstructionGlobalPointsPipeline(nn.Module):
|
||||
seq_embedding = torch.cat([pose_feat_seq, pts_num_feat_seq, partial_feat_seq], dim=-1) # Tensor(S x (Dp+Dn+Dl))
|
||||
embedding_list_batch.append(seq_embedding) # List(B): Tensor(S x (Dp+Dn+Dl))
|
||||
|
||||
seq_feat = self.pose_n_num_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds)
|
||||
seq_feat = self.transformer_seq_encoder.encode_sequence(embedding_list_batch) # Tensor(B x Ds)
|
||||
|
||||
main_feat = torch.cat([seq_feat, global_scanned_feat], dim=-1) # Tensor(B x (Ds+Dg))
|
||||
|
||||
|
Reference in New Issue
Block a user