successfully inferenced

This commit is contained in:
0nhc 2024-10-11 23:40:34 -05:00
parent 16bfc22fe7
commit 26a2af0c16
2 changed files with 6 additions and 4 deletions

View File

@ -23,6 +23,7 @@ class InferenceEngine():
def __init__(self, config_path):
''' Config Manager '''
ConfigManager.load_config_with(config_path)
# ConfigManager.print_config()
''' Pytorch Seed '''
seed = ConfigManager.get("settings", "general", "seed")
@ -31,9 +32,9 @@ class InferenceEngine():
''' Pipeline '''
# self.pipeline_config = {'pts_encoder': 'pointnet', 'view_finder': 'gradient_field'}
self.pipeline_config = ConfigManager.get("settings", "pipeline")
# self.pipeline_config = ConfigManager.get("settings", "pipeline")
self.device = ConfigManager.get("settings", "general", "device")
self.pipeline = Pipeline(self.pipeline_config)
self.pipeline = Pipeline(config_path)
self.parallel = ConfigManager.get("settings","general","parallel")
if self.parallel and self.device == "cuda":
self.pipeline = torch.nn.DataParallel(self.pipeline)

View File

@ -18,9 +18,10 @@ class Pipeline(nn.Module):
TRAIN_MODE: str = "train"
TEST_MODE: str = "test"
def __init__(self, pipeline_config):
def __init__(self, config_path):
super(Pipeline, self).__init__()
ConfigManager.load_config_with(config_path)
pipeline_config = ConfigManager.get("settings", "pipeline")
self.modules_config = ConfigManager.get("modules")
self.device = ConfigManager.get("settings", "general", "device")
self.rgb_feat_cache = ConfigManager.get("datasets", "general", "rgb_feat_cache")