success
This commit is contained in:
63
datasets/dataset.py
Executable file
63
datasets/dataset.py
Executable file
@@ -0,0 +1,63 @@
|
||||
from typing import Sized
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
from configs.config import ConfigManager
|
||||
|
||||
class AdvancedDataset(ABC, Dataset, Sized):
    """Abstract dataset base with optional per-item pickle caching on disk.

    Subclasses implement :meth:`getitem`; ``__getitem__`` wraps it with a
    read-through cache under ``<root_dir>/<name>/cache/<dataset-name>/``.
    Subclasses must also provide ``__len__`` (required by ``Sized`` and used
    by :meth:`get_loader`).
    """

    def __init__(self, config):
        """Store *config* and prepare the cache directory if caching is on.

        Args:
            config: mapping with at least ``"name"``, ``"ratio"``,
                ``"batch_size"`` and ``"num_workers"`` keys.
        """
        super().__init__()
        self.config = config
        # Cache layout is driven by the global experiment configuration.
        self.use_cache = ConfigManager.get("settings", "experiment", "use_cache")
        exp_root = ConfigManager.get("settings", "experiment", "root_dir")
        exp_name = ConfigManager.get("settings", "experiment", "name")
        self.cache_path = os.path.join(exp_root, exp_name, "cache", self.config["name"])
        if self.use_cache:
            # exist_ok=True avoids the check-then-create race when several
            # processes/workers construct the dataset concurrently.
            os.makedirs(self.cache_path, exist_ok=True)

    @staticmethod
    def process_batch(batch, device):
        """Move every non-list value of *batch* to *device* in place.

        List values (e.g. raw strings) are left untouched; everything else is
        assumed to support ``.to(device)`` (tensors). Returns the same dict.
        """
        for key, value in batch.items():
            if not isinstance(value, list):
                batch[key] = value.to(device)
        return batch

    @abstractmethod
    def getitem(self, index) -> dict:
        """Build the raw sample at *index*; implemented by subclasses."""
        raise NotImplementedError

    def __getitem__(self, index) -> dict:
        """Return the sample at *index*, reading/writing the pickle cache."""
        cache_data_path = os.path.join(self.cache_path, f"{index}.pkl")
        if self.use_cache and os.path.exists(cache_data_path):
            # NOTE(review): pickle.load is only safe because these files are
            # written by this process below — never point the cache at
            # untrusted data.
            with open(cache_data_path, "rb") as f:
                item = pickle.load(f)
        else:
            item = self.getitem(index)
            if self.use_cache:
                with open(cache_data_path, "wb") as f:
                    pickle.dump(item, f)
        return item

    def get_loader(self, device, shuffle=False):
        """Build a DataLoader over a random subset of ``ratio * len(self)``.

        Args:
            device: currently unused (kept for interface compatibility; see
                the commented-out generator below).
            shuffle: forwarded to :class:`DataLoader`.

        Raises:
            ValueError: if ``config["ratio"]`` is outside (0, 1].
        """
        ratio = self.config["ratio"]
        if not 0 < ratio <= 1:
            raise ValueError(
                f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
            )
        subset_size = int(len(self) * ratio)
        # Random (unseeded) subsampling — a new permutation on every call.
        indices = np.random.permutation(len(self))[:subset_size]
        subset = Subset(self, indices)
        return DataLoader(
            subset,
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=shuffle,
            #generator=torch.Generator(device=device),
        )
|
Reference in New Issue
Block a user