train pointnet++

This commit is contained in:
2024-12-30 14:00:53 +00:00
parent 34548c64a3
commit 88d44f020e
3 changed files with 20 additions and 17 deletions

View File

@@ -75,11 +75,10 @@ class PointNet2Encoder(nn.Module):
def __init__(self, config:dict):
super().__init__()
input_channels = config.get("in_dim", 3) - 3
channel_in = config.get("in_dim", 3) - 3
params_name = config.get("params_name", "light")
self.SA_modules = nn.ModuleList()
channel_in = input_channels
selected_params = select_params(params_name)
for k in range(selected_params['NPOINTS'].__len__()):
mlps = selected_params['MLPS'][k].copy()