"""Source code for ``HALF.Utils.GlobalIndexDataset``.

Provides a dataset wrapper that tags every sample with its global index.
"""
from torch.utils.data import Dataset
from HALF.Utils.sample import Sample, SampleList
class GlobalIndexDataset(Dataset):
    """Wrap a dataset so that every sample also carries its global index.

    On construction the wrapped dataset is iterated once. Each element is
    converted into a ``Sample`` built from a dict with the keys ``"input"``,
    ``"target"`` and ``"global_i"``:

    * an element that is a 2-element ``list``/``tuple`` is unpacked as
      ``(input, target)``;
    * any other element is stored as ``input`` with ``target`` set to ``None``.

    ``global_i`` is the element's position in the wrapped dataset's iteration
    order.

    Note:
        All samples are materialized eagerly in ``__init__``. The whole
        wrapped dataset is therefore read once up front and kept in
        ``self.data``.
    """

    def __init__(self, dataset):
        """Materialize ``dataset`` into a list of index-tagged samples.

        Args:
            dataset: The dataset to wrap. It is iterated with ``enumerate``.
                For a map-style dataset without ``__iter__``, Python falls back
                to calling ``__getitem__`` with 0, 1, 2, ... until
                ``IndexError`` is raised.
        """
        self.dataset = dataset
        self.data = []
        for global_i, element in enumerate(self.dataset):
            # NOTE(review): any 2-element list/tuple is treated as an
            # (input, target) pair. A genuine single input that happens to be
            # a 2-element list/tuple would be split too. Confirm this is
            # acceptable for the datasets wrapped here.
            if isinstance(element, (list, tuple)) and len(element) == 2:
                sample_input, target = element
            else:
                sample_input, target = element, None
            self.data.append(
                Sample({
                    "input": sample_input,
                    "target": target,
                    "global_i": global_i,
                })
            )

    def __len__(self) -> int:
        """Return the number of materialized samples."""
        # Measure the materialized list rather than the wrapped dataset, so
        # that __len__ always agrees with __getitem__ even if iterating the
        # dataset yields a different count than len(dataset) reports.
        return len(self.data)

    def __getitem__(self, index):
        """Return the pre-built ``Sample`` stored at ``index``."""
        return self.data[index]