Basic Framework

This commit is contained in:
hofee
2024-08-18 00:37:17 +08:00
commit 73dcd592df
14 changed files with 733 additions and 0 deletions

35
datasets/dataset.py Normal file
View File

@@ -0,0 +1,35 @@
from abc import ABC
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Subset
class BaseDataset(ABC, Dataset):
def __init__(self, config):
super(BaseDataset, self).__init__()
self.config = config
@staticmethod
def process_batch(batch, device):
for key in batch.keys():
if isinstance(batch[key], list):
continue
batch[key] = batch[key].to(device)
return batch
def get_loader(self, shuffle=False):
ratio = self.config["ratio"]
if ratio > 1 or ratio <= 0:
raise ValueError(
f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
)
subset_size = int(len(self) * ratio)
indices = np.random.permutation(len(self))[:subset_size]
subset = Subset(self, indices)
return DataLoader(
subset,
batch_size=self.config["batch_size"],
num_workers=self.config["num_workers"],
shuffle=shuffle,
)

View File

@@ -0,0 +1,30 @@
import sys
import os
path = os.path.abspath(__file__)
for i in range(2):
path = os.path.dirname(path)
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)
from datasets.dataset import BaseDataset
class DatasetFactory:
@staticmethod
def create(config) -> BaseDataset:
pass
''' ------------ Debug ------------ '''
if __name__ == "__main__":
from configs.config import ConfigManager
ConfigManager.load_config_with('/home/data/hofee/project/ActivePerception/ActivePerception/configs/server_train_config.yaml')
ConfigManager.print_config()
dataset = DatasetFactory.create(ConfigManager.get("settings", "test", "dataset_list")[1])
print(len(dataset))
data_test = dataset.__getitem__(107000)
print(data_test['src_path'])
import pickle
# with open("data_sample_new.pkl", "wb") as f:
# pickle.dump(data_test, f)