train pointnet++
This commit is contained in:
@@ -75,11 +75,10 @@ class PointNet2Encoder(nn.Module):
|
||||
def __init__(self, config:dict):
|
||||
super().__init__()
|
||||
|
||||
input_channels = config.get("in_dim", 3) - 3
|
||||
channel_in = config.get("in_dim", 3) - 3
|
||||
params_name = config.get("params_name", "light")
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
channel_in = input_channels
|
||||
selected_params = select_params(params_name)
|
||||
for k in range(selected_params['NPOINTS'].__len__()):
|
||||
mlps = selected_params['MLPS'][k].copy()
|
||||
|
Reference in New Issue
Block a user