diff --git a/core/loss.py b/core/loss.py index 1de8b8d..18656e2 100644 --- a/core/loss.py +++ b/core/loss.py @@ -3,7 +3,7 @@ import PytorchBoot.stereotype as stereotype @stereotype.loss_function("gf_loss") class GFLoss: - def __init__(self, config): + def __init__(self, _): pass def compute(self, output, _): diff --git a/modules/view_finder/gf_view_finder.py b/modules/gf_view_finder.py similarity index 94% rename from modules/view_finder/gf_view_finder.py rename to modules/gf_view_finder.py index ea7b037..21ada32 100644 --- a/modules/view_finder/gf_view_finder.py +++ b/modules/gf_view_finder.py @@ -3,7 +3,6 @@ import torch.nn as nn import PytorchBoot.stereotype as stereotype from utils.pose import PoseUtil -from modules.view_finder.abstract_view_finder import ViewFinder import modules.module_lib as mlib import modules.func_lib as flib @@ -18,7 +17,7 @@ def zero_module(module): @stereotype.module("gf_view_finder") -class GradientFieldViewFinder(ViewFinder): +class GradientFieldViewFinder(nn.Module): def __init__(self, config): super(GradientFieldViewFinder, self).__init__() diff --git a/modules/pts_encoder/pointnet_encoder.py b/modules/pointnet_encoder.py similarity index 93% rename from modules/pts_encoder/pointnet_encoder.py rename to modules/pointnet_encoder.py index e4e5110..b669a4c 100644 --- a/modules/pts_encoder/pointnet_encoder.py +++ b/modules/pointnet_encoder.py @@ -7,10 +7,9 @@ from torch.autograd import Variable import numpy as np import torch.nn.functional as F -from modules.pts_encoder.abstract_pts_encoder import PointsEncoder import PytorchBoot.stereotype as stereotype @stereotype.module("pointnet_encoder") -class PointNetEncoder(PointsEncoder): +class PointNetEncoder(nn.Module): def __init__(self, config:dict): super(PointNetEncoder, self).__init__() diff --git a/modules/pose_encoder/pose_encoder.py b/modules/pose_encoder.py similarity index 100% rename from modules/pose_encoder/pose_encoder.py rename to modules/pose_encoder.py diff --git a/modules/pts_encoder/abstract_pts_encoder.py b/modules/pts_encoder/abstract_pts_encoder.py deleted file mode 100644 index a7e33ab..0000000 --- a/modules/pts_encoder/abstract_pts_encoder.py +++ /dev/null @@ -1,12 +0,0 @@ -from abc import abstractmethod - -from torch import nn - - -class PointsEncoder(nn.Module): - def __init__(self): - super(PointsEncoder, self).__init__() - - @abstractmethod - def encode_points(self, pts): - pass diff --git a/modules/seq_encoder/abstract_seq_encoder.py b/modules/seq_encoder/abstract_seq_encoder.py deleted file mode 100644 index 5c4a8ba..0000000 --- a/modules/seq_encoder/abstract_seq_encoder.py +++ /dev/null @@ -1,12 +0,0 @@ -from abc import abstractmethod - -from torch import nn - - -class SequenceEncoder(nn.Module): - def __init__(self): - super(SequenceEncoder, self).__init__() - - @abstractmethod - def encode_sequence(self, pts_embedding_list, pose_embedding_list): - pass diff --git a/modules/seq_encoder/transformer_seq_encoder.py b/modules/transformer_seq_encoder.py similarity index 93% rename from modules/seq_encoder/transformer_seq_encoder.py rename to modules/transformer_seq_encoder.py index 47f0736..ba50f7d 100644 --- a/modules/seq_encoder/transformer_seq_encoder.py +++ b/modules/transformer_seq_encoder.py @@ -2,11 +2,9 @@ import torch from torch import nn import PytorchBoot.stereotype as stereotype -import sys; sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction") -from modules.seq_encoder.abstract_seq_encoder import SequenceEncoder @stereotype.module("transformer_seq_encoder") -class TransformerSequenceEncoder(SequenceEncoder): +class TransformerSequenceEncoder(nn.Module): def __init__(self, config): super(TransformerSequenceEncoder, self).__init__() self.config = config diff --git a/modules/view_finder/abstract_view_finder.py b/modules/view_finder/abstract_view_finder.py deleted file mode 100644 index 1516aad..0000000 --- a/modules/view_finder/abstract_view_finder.py +++ /dev/null @@ -1,12 +0,0 @@ -from abc import abstractmethod - -from torch import nn - - -class ViewFinder(nn.Module): - def __init__(self): - super(ViewFinder, self).__init__() - - @abstractmethod - def next_best_view(self, seq_feat): - pass