# Source code for HALF.Utils.ALDataModule

from torch.utils.data.dataset import Dataset, Subset

class ALDataModule:
    """Container for the datasets of an active learning loop.

    Holds a labelled pool, an unlabelled pool and an optional test set, and
    switches the ``transform`` attribute of all of them between a training
    transform and an evaluation transform.
    """

    def __init__(
        self,
        dataset_labeled: Dataset,
        dataset_unlabeled: Dataset,
        dataset_test: Dataset = None,
        train_transform=None,
        test_transform=None,
    ):
        """Store the datasets and the two transforms used to switch modes.

        Args:
            dataset_labeled (Dataset): labelled pool of samples.
            dataset_unlabeled (Dataset): unlabelled pool of samples.
            dataset_test (Dataset, optional): test set used for evaluation;
                may be ``None``, in which case mode switches skip it.
            train_transform: transform assigned to the datasets by
                :meth:`set_train_mode` (the original docs describe it as a
                list of transformation objects).
            test_transform: transform assigned to the datasets by
                :meth:`set_test_mode` (the original docs describe it as a
                list of transformation objects).

        Returns:
            None
        """
        self.dataset_labeled = dataset_labeled
        self.dataset_unlabeled = dataset_unlabeled
        self.dataset_test = dataset_test
        self.train_transform = train_transform
        self.test_transform = test_transform

    @staticmethod
    def _set_transform(dataset: Dataset, transform) -> None:
        """Assign ``transform`` to the dataset that actually holds the samples.

        ``Subset`` wrappers are unwrapped at any nesting depth, and the
        ``transform`` attribute is set on the innermost non-``Subset``
        dataset. Because the attribute lives on the shared base dataset,
        several ``Subset`` objects built over the same base all see the
        most recently assigned transform.

        Args:
            dataset (Dataset): the dataset (or ``Subset`` chain) to update.
            transform: the transform to store as ``transform`` on the base
                dataset. Presumably the base dataset's ``__getitem__`` reads
                ``self.transform``; this method only sets the attribute.

        Returns:
            None
        """
        # torch's Subset forwards indexing to its wrapped ``dataset``, so an
        # attribute set on a Subset wrapper itself is never read. Walk down
        # through every level of wrapping, not just the first one.
        while isinstance(dataset, Subset):
            dataset = dataset.dataset
        setattr(dataset, "transform", transform)

    def _apply_transform(self, transform) -> None:
        """Set ``transform`` on the labelled, unlabelled and (if present) test datasets."""
        ALDataModule._set_transform(self.dataset_labeled, transform)
        ALDataModule._set_transform(self.dataset_unlabeled, transform)
        # The test set is optional; the labelled/unlabelled pools are not.
        if self.dataset_test is not None:
            ALDataModule._set_transform(self.dataset_test, transform)

    def set_train_mode(self) -> None:
        """Put every dataset in training mode by assigning ``train_transform``.

        NOTE(review): the training transform is also applied to the
        unlabelled pool and the test set; confirm that this is intended.
        """
        self._apply_transform(self.train_transform)

    def set_test_mode(self) -> None:
        """Put every dataset in evaluation mode by assigning ``test_transform``."""
        self._apply_transform(self.test_transform)