Source code for HALF.Utils.GlobalIndexDataset

from torch.utils.data import Dataset
from HALF.Utils.sample import Sample, SampleList

class GlobalIndexDataset(Dataset):
    """Wrap a dataset so every item remembers its position in the original.

    On construction the wrapped ``dataset`` is iterated once, eagerly, and
    each element is converted into a ``Sample`` built from a dict with the
    keys ``"input"``, ``"target"`` and ``"global_i"`` (the element's index in
    iteration order). Elements that are a ``list`` or ``tuple`` of exactly
    two items are unpacked as ``(input, target)``; any other element is
    stored whole as ``input`` with ``target`` set to ``None``.

    Args:
        dataset: Any iterable of elements. Presumably a
            ``torch.utils.data.Dataset`` or a plain sequence — confirm
            against callers. A reference to it is kept on ``self.dataset``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Materialize every sample up front; __getitem__ serves from this list.
        self.data = [
            self._to_sample(element, global_i)
            for global_i, element in enumerate(self.dataset)
        ]

    @staticmethod
    def _to_sample(element, global_i):
        """Build the ``Sample`` for one element found at position ``global_i``."""
        if isinstance(element, (list, tuple)) and len(element) == 2:
            inp, target = element
        else:
            # Not an (input, target) pair: keep the whole element as input.
            inp, target = element, None
        return Sample({"input": inp, "target": target, "global_i": global_i})

    def __len__(self):
        # Count the materialized samples rather than len(self.dataset):
        # __getitem__ indexes self.data, so the two must always agree. This
        # also works for wrapped iterables that do not define __len__.
        return len(self.data)

    def __getitem__(self, index):
        """Return the precomputed ``Sample`` at ``index``."""
        return self.data[index]