first commit
This commit is contained in:
57
runners/data_spliter.py
Normal file
57
runners/data_spliter.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import os
|
||||
import random
|
||||
from PytorchBoot.runners.runner import Runner
|
||||
from PytorchBoot.config import ConfigManager
|
||||
from PytorchBoot.utils import Log
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
from PytorchBoot.status import status_manager
|
||||
|
||||
@stereotype.runner("data_spliter")
|
||||
class DataSpliter(Runner):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.load_experiment("data_split")
|
||||
self.root_dir = ConfigManager.get("runner", "split", "root_dir")
|
||||
self.type = ConfigManager.get("runner", "split", "type")
|
||||
self.datasets = ConfigManager.get("runner", "split", "datasets")
|
||||
self.datapath_list = self.load_all_datapath()
|
||||
|
||||
def run(self):
|
||||
self.split_dataset()
|
||||
|
||||
def split_dataset(self):
|
||||
|
||||
random.shuffle(self.datapath_list)
|
||||
start_idx = 0
|
||||
for dataset_idx in range(len(self.datasets)):
|
||||
dataset = list(self.datasets.keys())[dataset_idx]
|
||||
ratio = self.datasets[dataset]["ratio"]
|
||||
path = self.datasets[dataset]["path"]
|
||||
split_size = int(len(self.datapath_list) * ratio)
|
||||
split_files = self.datapath_list[start_idx:start_idx + split_size]
|
||||
start_idx += split_size
|
||||
self.save_split_files(path, split_files)
|
||||
status_manager.set_progress("split", "data_splitor", "split dataset", dataset_idx, len(self.datasets))
|
||||
Log.success(f"save {dataset} split files to {path}")
|
||||
status_manager.set_progress("split", "data_splitor", "split dataset", len(self.datasets), len(self.datasets))
|
||||
def save_split_files(self, path, split_files):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, "w") as f:
|
||||
f.write("\n".join(split_files))
|
||||
|
||||
|
||||
def load_all_datapath(self):
|
||||
return os.listdir(self.root_dir)
|
||||
|
||||
def create_experiment(self, backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
|
||||
def load_experiment(self, backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user