success
This commit is contained in:
60
runners/runner.py
Executable file
60
runners/runner.py
Executable file
@@ -0,0 +1,60 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from abc import abstractmethod, ABC
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from configs.config import ConfigManager
|
||||
|
||||
class Runner(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config_path):
|
||||
ConfigManager.load_config_with(config_path)
|
||||
ConfigManager.print_config()
|
||||
seed = ConfigManager.get("settings", "general", "seed")
|
||||
self.device = ConfigManager.get("settings", "general", "device")
|
||||
self.cuda_visible_devices = ConfigManager.get("settings","general","cuda_visible_devices")
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_visible_devices
|
||||
self.experiments_config = ConfigManager.get("settings", "experiment")
|
||||
self.experiment_path = os.path.join(self.experiments_config["root_dir"], self.experiments_config["name"])
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
lt = time.localtime()
|
||||
self.file_name = f"{lt.tm_year}_{lt.tm_mon}_{lt.tm_mday}_{lt.tm_hour}h{lt.tm_min}m{lt.tm_sec}s"
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_experiment(self, backup_name=None):
|
||||
if not os.path.exists(self.experiment_path):
|
||||
print(f"experiments environment {self.experiments_config['name']} does not exists.")
|
||||
self.create_experiment(backup_name)
|
||||
else:
|
||||
print(f"experiments environment {self.experiments_config['name']}")
|
||||
backup_config_dir = os.path.join(str(self.experiment_path), "configs")
|
||||
if not os.path.exists(backup_config_dir):
|
||||
os.makedirs(backup_config_dir)
|
||||
ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)
|
||||
|
||||
@abstractmethod
|
||||
def create_experiment(self, backup_name=None):
|
||||
print("creating experiment: " + self.experiments_config["name"])
|
||||
os.makedirs(self.experiment_path)
|
||||
backup_config_dir = os.path.join(str(self.experiment_path), "configs")
|
||||
os.makedirs(backup_config_dir)
|
||||
ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)
|
||||
log_dir = os.path.join(str(self.experiment_path), "log")
|
||||
os.makedirs(log_dir)
|
||||
cache_dir = os.path.join(str(self.experiment_path), "cache")
|
||||
os.makedirs(cache_dir)
|
||||
|
||||
def print_info(self):
|
||||
table_size = 80
|
||||
print("+" + "-" * table_size + "+")
|
||||
print(f"| Experiment <{self.experiments_config['name']}>")
|
||||
print("+" + "-" * table_size + "+")
|
Reference in New Issue
Block a user