# Source code for HALF.Utils.ALDatasetManager

from torch.utils.data.dataset import Dataset, ConcatDataset, Subset
from HALF.Utils.ALDataModule import ALDataModule
from distil.utils.utils import LabeledToUnlabeledDataset 
import numpy as np

class ALDatasetManager:
    """Track which samples are labeled and which are still unlabeled.

    The manager concatenates the labeled and unlabeled datasets of an
    ``ALDataModule`` into one dataset (``full_dataset_labeled``) and keeps
    two arrays of *global* indices, i.e. positions inside that concatenation:

    * ``gi_ds_lb``  -- global indices of the samples currently labeled,
    * ``gi_ds_ulb`` -- global indices of the samples still unlabeled.

    Indices passed to :meth:`update` and :meth:`subset_unlabeled` are
    *local*: they are positions inside the current unlabeled pool.
    """

    def __init__(self, al_data_module: ALDataModule):
        """Build the concatenated dataset and the initial index arrays.

        Args:
            al_data_module (ALDataModule): provides ``dataset_labeled``,
                ``dataset_unlabeled`` and ``dataset_test``.

        Returns:
            None
        """
        seed_labeled = al_data_module.dataset_labeled
        pool_unlabeled = al_data_module.dataset_unlabeled

        self.al_data_module = al_data_module
        self.seed_dataset_labeled = seed_labeled
        # Datasets handed to `update`, in the order they were labeled.
        self.list_late_lb_dataset = []
        # Labeled samples come first, so their global indices start at 0.
        self.full_dataset_labeled = ConcatDataset([seed_labeled, pool_unlabeled])
        self._gen_global_indexes(seed_labeled, pool_unlabeled)

    def set_train_mode(self):
        """Delegate to ``al_data_module.set_train_mode()``.

        Presumably switches the datasets to their training transformations
        (behaviour is defined by ``ALDataModule``).
        """
        self.al_data_module.set_train_mode()

    def set_test_mode(self):
        """Delegate to ``al_data_module.set_test_mode()``.

        Presumably switches the datasets to their test transformations
        (behaviour is defined by ``ALDataModule``).
        """
        self.al_data_module.set_test_mode()

    @property
    def dataset_test(self):
        """Get the test dataset.

        Returns:
            Dataset: ``dataset_test`` of the underlying data module.
        """
        return self.al_data_module.dataset_test

    @property
    def dataset_labeled(self):
        """Get the labeled dataset.

        Returns:
            Dataset: the seed labeled dataset followed by every dataset
            added through :meth:`update`, in insertion order.
        """
        return ConcatDataset([self.seed_dataset_labeled, *self.list_late_lb_dataset])

    # NOTE: originally marked "TODO delete, for testing only", but
    # `subset_unlabeled(label=True)` relies on this property.
    @property
    def dataset_unlabeled_lb(self):
        """Get the current unlabeled pool as a plain ``Subset``.

        Returns:
            Subset: view on ``full_dataset_labeled`` restricted to
            ``gi_ds_ulb``, without the ``LabeledToUnlabeledDataset`` wrapper.
        """
        return Subset(self.full_dataset_labeled, self.gi_ds_ulb)

    @property
    def dataset_unlabeled(self):
        """Get the current unlabeled pool.

        Returns:
            Dataset: the same samples as :attr:`dataset_unlabeled_lb`, wrapped
            in ``LabeledToUnlabeledDataset``.
        """
        return LabeledToUnlabeledDataset(Subset(self.full_dataset_labeled, self.gi_ds_ulb))

    def _gen_global_indexes(self, dataset_labeled: Dataset, dataset_unlabeled: Dataset):
        """Initialise the global index arrays from the two dataset lengths.

        Args:
            dataset_labeled (Dataset): initial labeled dataset.
            dataset_unlabeled (Dataset): initial unlabeled dataset.

        Returns:
            None
        """
        n_labeled = len(dataset_labeled)
        n_unlabeled = len(dataset_unlabeled)
        # Global index of the first unlabeled sample in the concatenation.
        self.start_idx_ulb = n_labeled
        self.gi_ds_lb = np.arange(n_labeled)
        self.gi_ds_ulb = np.arange(n_labeled, n_labeled + n_unlabeled)

    def update(self, ulb_idx_selected: np.ndarray, subds_labeled: Dataset):
        """Move the selected samples from the unlabeled to the labeled pool.

        Args:
            ulb_idx_selected (np.ndarray): indices selected by the active
                learning strategy, relative to the current unlabeled dataset.
                An empty selection is accepted and leaves the indices
                unchanged.
            subds_labeled (Dataset): subset of the unlabeled dataset
                corresponding to the selected indices; it is appended to the
                datasets returned by :attr:`dataset_labeled`.

        Returns:
            None
        """
        local_idx = np.asarray(ulb_idx_selected)
        if local_idx.size == 0:
            # `np.array([])` is float64, which NumPy rejects as an index;
            # an empty selection must be integer-typed to be a valid no-op.
            local_idx = local_idx.astype(np.intp)

        newly_labeled = self.gi_ds_ulb[local_idx]
        self.gi_ds_lb = np.concatenate((self.gi_ds_lb, newly_labeled))
        self.gi_ds_ulb = np.delete(self.gi_ds_ulb, local_idx)
        self.list_late_lb_dataset.append(subds_labeled)

    def subset_unlabeled(self, ulb_idx_selected: np.ndarray, label: bool = False) -> Subset:
        """Get a subset of the current unlabeled dataset.

        Args:
            ulb_idx_selected (np.ndarray): indices wanted in the subset,
                relative to the current unlabeled dataset.
            label (bool): True to take the subset from
                :attr:`dataset_unlabeled_lb` (labels kept), False to take it
                from :attr:`dataset_unlabeled`.

        Returns:
            Subset: subset of the unlabeled dataset with the given indices.
        """
        if label:
            return Subset(self.dataset_unlabeled_lb, ulb_idx_selected)
        return Subset(self.dataset_unlabeled, ulb_idx_selected)