success
This commit is contained in:
71
runners/preprocessor.py
Executable file
71
runners/preprocessor.py
Executable file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
import shutil
|
||||
|
||||
from configs.config import ConfigManager
|
||||
from runners.runner import Runner
|
||||
|
||||
|
||||
class Preprocessor(Runner, ABC):
|
||||
DATA = "data"
|
||||
|
||||
def __init__(self, config_path):
|
||||
super().__init__(config_path)
|
||||
|
||||
self.preprocess_config = ConfigManager.get("settings", "preprocess")
|
||||
|
||||
def load_experiment(self,backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
exists_ok = self.experiments_config["keep_exists"]
|
||||
if not exists_ok:
|
||||
data_dir = os.path.join(str(self.experiment_path), Preprocessor.DATA)
|
||||
shutil.rmtree(data_dir, ignore_errors=True)
|
||||
os.makedirs(data_dir)
|
||||
self.create_dataset_list()
|
||||
|
||||
def create_experiment(self,backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
data_dir = os.path.join(str(self.experiment_path), Preprocessor.DATA)
|
||||
os.makedirs(data_dir)
|
||||
self.create_dataset_list()
|
||||
|
||||
def create_dataset_list(self):
|
||||
dataset_list = self.preprocess_config["dataset_list"]
|
||||
exists_ok = self.experiments_config["keep_exists"]
|
||||
for dataset in dataset_list:
|
||||
source = dataset["source"]
|
||||
source_dir = os.path.join(str(self.experiment_path), Preprocessor.DATA, source)
|
||||
if not os.path.exists(source_dir):
|
||||
os.makedirs(source_dir,exist_ok=exists_ok)
|
||||
dataset_name = dataset["data_type"]
|
||||
dataset_dir = os.path.join(source_dir, dataset_name)
|
||||
if not os.path.exists(dataset_dir):
|
||||
os.makedirs(dataset_dir,exist_ok=exists_ok)
|
||||
|
||||
@abstractmethod
|
||||
def get_dataloader(self, dataset_config):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self, model_config):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prediction(self, model, dataloader):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self, predicted_data):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_processed_data(self, processed_data, data_config=None):
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="../configs/local_gsnet_preprocess_config.yaml")
|
||||
args = parser.parse_args()
|
||||
preproc = Preprocessor(args.config)
|
Reference in New Issue
Block a user