From 26a2af0c16020d9f8ecc0d5b31abd5917c1955ee Mon Sep 17 00:00:00 2001 From: 0nhc Date: Fri, 11 Oct 2024 23:40:34 -0500 Subject: [PATCH] successfully inferenced --- src/active_grasp/active_perception.py | 5 +++-- src/active_grasp/active_perception/modules/pipeline.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/active_grasp/active_perception.py b/src/active_grasp/active_perception.py index cc11be1..0a9f35a 100644 --- a/src/active_grasp/active_perception.py +++ b/src/active_grasp/active_perception.py @@ -23,6 +23,7 @@ class InferenceEngine(): def __init__(self, config_path): ''' Config Manager ''' ConfigManager.load_config_with(config_path) + # ConfigManager.print_config() ''' Pytorch Seed ''' seed = ConfigManager.get("settings", "general", "seed") @@ -31,9 +32,9 @@ class InferenceEngine(): ''' Pipeline ''' # self.pipeline_config = {'pts_encoder': 'pointnet', 'view_finder': 'gradient_field'} - self.pipeline_config = ConfigManager.get("settings", "pipeline") + # self.pipeline_config = ConfigManager.get("settings", "pipeline") self.device = ConfigManager.get("settings", "general", "device") - self.pipeline = Pipeline(self.pipeline_config) + self.pipeline = Pipeline(config_path) self.parallel = ConfigManager.get("settings","general","parallel") if self.parallel and self.device == "cuda": self.pipeline = torch.nn.DataParallel(self.pipeline) diff --git a/src/active_grasp/active_perception/modules/pipeline.py b/src/active_grasp/active_perception/modules/pipeline.py index 20c9858..0255e02 100755 --- a/src/active_grasp/active_perception/modules/pipeline.py +++ b/src/active_grasp/active_perception/modules/pipeline.py @@ -18,9 +18,10 @@ class Pipeline(nn.Module): TRAIN_MODE: str = "train" TEST_MODE: str = "test" - def __init__(self, pipeline_config): + def __init__(self, config_path): super(Pipeline, self).__init__() - + ConfigManager.load_config_with(config_path) + pipeline_config = ConfigManager.get("settings", "pipeline") self.modules_config = ConfigManager.get("modules") self.device = ConfigManager.get("settings", "general", "device") self.rgb_feat_cache = ConfigManager.get("datasets", "general", "rgb_feat_cache")