train pointnet++
This commit is contained in:
parent
34548c64a3
commit
88d44f020e
@ -7,7 +7,7 @@ runner:
|
||||
parallel: False
|
||||
|
||||
experiment:
|
||||
name: train_ab_global_only_pointnet++
|
||||
name: train_ab_global_only_with_accept_probability
|
||||
root_dir: "experiments"
|
||||
use_checkpoint: False
|
||||
epoch: -1 # -1 stands for last epoch
|
||||
@ -80,7 +80,7 @@ dataset:
|
||||
pipeline:
|
||||
nbv_reconstruction_pipeline:
|
||||
modules:
|
||||
pts_encoder: pointnet++_encoder
|
||||
pts_encoder: pointnet_encoder
|
||||
seq_encoder: transformer_seq_encoder
|
||||
pose_encoder: pose_encoder
|
||||
view_finder: gf_view_finder
|
||||
|
@ -4,6 +4,7 @@ import PytorchBoot.namespace as namespace
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
from PytorchBoot.config import ConfigManager
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
@ -50,7 +51,7 @@ class NBVReconstructionDataset(BaseDataset):
|
||||
scene_name_list.append(scene_name)
|
||||
return scene_name_list
|
||||
|
||||
def get_datalist(self):
|
||||
def get_datalist(self, bias=False):
|
||||
datalist = []
|
||||
for scene_name in self.scene_name_list:
|
||||
seq_num = DataLoadUtil.get_label_num(self.root_dir, scene_name)
|
||||
@ -79,16 +80,18 @@ class NBVReconstructionDataset(BaseDataset):
|
||||
for data_pair in label_data["data_pairs"]:
|
||||
scanned_views = data_pair[0]
|
||||
next_best_view = data_pair[1]
|
||||
datalist.append(
|
||||
{
|
||||
"scanned_views": scanned_views,
|
||||
"next_best_view": next_best_view,
|
||||
"seq_max_coverage_rate": max_coverage_rate,
|
||||
"scene_name": scene_name,
|
||||
"label_idx": seq_idx,
|
||||
"scene_max_coverage_rate": scene_max_coverage_rate,
|
||||
}
|
||||
)
|
||||
accept_probability = scanned_views[-1][1]
|
||||
if accept_probability > np.random.rand():
|
||||
datalist.append(
|
||||
{
|
||||
"scanned_views": scanned_views,
|
||||
"next_best_view": next_best_view,
|
||||
"seq_max_coverage_rate": max_coverage_rate,
|
||||
"scene_name": scene_name,
|
||||
"label_idx": seq_idx,
|
||||
"scene_max_coverage_rate": scene_max_coverage_rate,
|
||||
}
|
||||
)
|
||||
return datalist
|
||||
|
||||
def preprocess_cache(self):
|
||||
@ -227,9 +230,10 @@ if __name__ == "__main__":
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
config = {
|
||||
"root_dir": "/data/hofee/data/packed_preprocessed_data",
|
||||
"root_dir": "/data/hofee/data/new_full_data",
|
||||
"model_dir": "../data/scaled_object_meshes",
|
||||
"source": "nbv_reconstruction_dataset",
|
||||
"split_file": "/data/hofee/data/OmniObject3d_train.txt",
|
||||
"split_file": "/data/hofee/data/new_full_data_list/OmniObject3d_train.txt",
|
||||
"load_from_preprocess": True,
|
||||
"ratio": 0.5,
|
||||
"batch_size": 2,
|
||||
|
@ -75,11 +75,10 @@ class PointNet2Encoder(nn.Module):
|
||||
def __init__(self, config:dict):
|
||||
super().__init__()
|
||||
|
||||
input_channels = config.get("in_dim", 3) - 3
|
||||
channel_in = config.get("in_dim", 3) - 3
|
||||
params_name = config.get("params_name", "light")
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
channel_in = input_channels
|
||||
selected_params = select_params(params_name)
|
||||
for k in range(selected_params['NPOINTS'].__len__()):
|
||||
mlps = selected_params['MLPS'][k].copy()
|
||||
|
Loading…
x
Reference in New Issue
Block a user