from torch.utils.data.dataset import Dataset, Subset
class ALDataModule:
    """Hold the datasets used by an active-learning loop and switch their transforms.

    The module keeps a labeled pool, an unlabeled pool and an optional test set,
    and can put all of them in "train" or "test" mode by assigning the matching
    transform to each dataset's ``transform`` attribute.
    """

    def __init__(self,
                 dataset_labeled: Dataset,
                 dataset_unlabeled: Dataset,
                 dataset_test: Dataset = None,
                 train_transform=None,
                 test_transform=None):
        """Store the datasets and the transforms used for each mode.

        Args:
            dataset_labeled (Dataset): labeled pool of samples used for training.
            dataset_unlabeled (Dataset): unlabeled pool of samples to query from.
            dataset_test (Dataset, optional): held-out evaluation set; may be None.
            train_transform: transform assigned to the datasets' ``transform``
                attribute in train mode (described upstream as a list of
                transformation objects).
            test_transform: transform assigned to the datasets' ``transform``
                attribute in test mode.
        """
        self.dataset_labeled = dataset_labeled
        self.dataset_unlabeled = dataset_unlabeled
        self.dataset_test = dataset_test
        self.train_transform = train_transform
        self.test_transform = test_transform

    @staticmethod
    def _set_transform(dataset: Dataset, transform):
        """Assign ``transform`` to the dataset that actually loads the samples.

        ``Subset`` objects only index into another dataset, so setting an
        attribute on them has no effect on loading. This unwraps every level of
        ``Subset`` nesting (e.g. a ``random_split`` result that was re-indexed
        into labeled/unlabeled pools) and sets ``transform`` on the innermost
        non-``Subset`` dataset.

        Args:
            dataset (Dataset): the dataset (possibly a nested ``Subset``) to update.
            transform: the transform to assign.
        """
        base = dataset
        while isinstance(base, Subset):
            base = base.dataset
        setattr(base, "transform", transform)

    def _apply_transform(self, transform):
        """Assign ``transform`` to the labeled, unlabeled and (if present) test datasets."""
        ALDataModule._set_transform(self.dataset_labeled, transform)
        ALDataModule._set_transform(self.dataset_unlabeled, transform)
        if self.dataset_test is not None:
            ALDataModule._set_transform(self.dataset_test, transform)

    def set_train_mode(self):
        """Put all datasets in training mode by applying ``train_transform``."""
        # NOTE(review): the test set also receives train_transform here, matching
        # the original behavior; confirm this is intended.
        self._apply_transform(self.train_transform)

    def set_test_mode(self):
        """Put all datasets in testing mode by applying ``test_transform``."""
        self._apply_transform(self.test_transform)