delete all abstract class
This commit is contained in:
parent
eceedd5c15
commit
2503fca572
@ -3,7 +3,7 @@ import PytorchBoot.stereotype as stereotype
|
||||
|
||||
@stereotype.loss_function("gf_loss")
|
||||
class GFLoss:
|
||||
def __init__(self, config):
|
||||
def __init__(self, _):
|
||||
pass
|
||||
|
||||
def compute(self, output, _):
|
||||
|
@ -3,7 +3,6 @@ import torch.nn as nn
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
|
||||
from utils.pose import PoseUtil
|
||||
from modules.view_finder.abstract_view_finder import ViewFinder
|
||||
import modules.module_lib as mlib
|
||||
import modules.func_lib as flib
|
||||
|
||||
@ -18,7 +17,7 @@ def zero_module(module):
|
||||
|
||||
|
||||
@stereotype.module("gf_view_finder")
|
||||
class GradientFieldViewFinder(ViewFinder):
|
||||
class GradientFieldViewFinder(nn.Module):
|
||||
def __init__(self, config):
|
||||
|
||||
super(GradientFieldViewFinder, self).__init__()
|
@ -7,10 +7,9 @@ from torch.autograd import Variable
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
@stereotype.module("pointnet_encoder")
|
||||
class PointNetEncoder(PointsEncoder):
|
||||
class PointNetEncoder(nn.Module):
|
||||
|
||||
def __init__(self, config:dict):
|
||||
super(PointNetEncoder, self).__init__()
|
@ -1,12 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PointsEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(PointsEncoder, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def encode_points(self, pts):
|
||||
pass
|
@ -1,12 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class SequenceEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super(SequenceEncoder, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def encode_sequence(self, pts_embedding_list, pose_embedding_list):
|
||||
pass
|
@ -2,11 +2,9 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
import sys; sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")
|
||||
from modules.seq_encoder.abstract_seq_encoder import SequenceEncoder
|
||||
|
||||
@stereotype.module("transformer_seq_encoder")
|
||||
class TransformerSequenceEncoder(SequenceEncoder):
|
||||
class TransformerSequenceEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(TransformerSequenceEncoder, self).__init__()
|
||||
self.config = config
|
@ -1,12 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ViewFinder(nn.Module):
|
||||
def __init__(self):
|
||||
super(ViewFinder, self).__init__()
|
||||
|
||||
@abstractmethod
|
||||
def next_best_view(self, seq_feat):
|
||||
pass
|
Loading…
x
Reference in New Issue
Block a user