train pointnet++

This commit is contained in:
2024-12-30 14:00:53 +00:00
parent 34548c64a3
commit 88d44f020e
3 changed files with 20 additions and 17 deletions

View File

@@ -4,6 +4,7 @@ import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.config import ConfigManager
from PytorchBoot.utils.log_util import Log
import torch
import os
import sys
@@ -50,7 +51,7 @@ class NBVReconstructionDataset(BaseDataset):
scene_name_list.append(scene_name)
return scene_name_list
def get_datalist(self):
def get_datalist(self, bias=False):
datalist = []
for scene_name in self.scene_name_list:
seq_num = DataLoadUtil.get_label_num(self.root_dir, scene_name)
@@ -79,16 +80,18 @@ class NBVReconstructionDataset(BaseDataset):
for data_pair in label_data["data_pairs"]:
scanned_views = data_pair[0]
next_best_view = data_pair[1]
datalist.append(
{
"scanned_views": scanned_views,
"next_best_view": next_best_view,
"seq_max_coverage_rate": max_coverage_rate,
"scene_name": scene_name,
"label_idx": seq_idx,
"scene_max_coverage_rate": scene_max_coverage_rate,
}
)
accept_probability = scanned_views[-1][1]
if accept_probability > np.random.rand():
datalist.append(
{
"scanned_views": scanned_views,
"next_best_view": next_best_view,
"seq_max_coverage_rate": max_coverage_rate,
"scene_name": scene_name,
"label_idx": seq_idx,
"scene_max_coverage_rate": scene_max_coverage_rate,
}
)
return datalist
def preprocess_cache(self):
@@ -227,9 +230,10 @@ if __name__ == "__main__":
torch.manual_seed(seed)
np.random.seed(seed)
config = {
"root_dir": "/data/hofee/data/packed_preprocessed_data",
"root_dir": "/data/hofee/data/new_full_data",
"model_dir": "../data/scaled_object_meshes",
"source": "nbv_reconstruction_dataset",
"split_file": "/data/hofee/data/OmniObject3d_train.txt",
"split_file": "/data/hofee/data/new_full_data_list/OmniObject3d_train.txt",
"load_from_preprocess": True,
"ratio": 0.5,
"batch_size": 2,