# Source code for HALF.Utils.ALDatasetManager
from torch.utils.data.dataset import Dataset, ConcatDataset, Subset
from HALF.Utils.ALDataModule import ALDataModule
from distil.utils.utils import LabeledToUnlabeledDataset
import numpy as np
class ALDatasetManager:
    """Bookkeeping for the labeled / unlabeled pools of an active-learning run.

    The manager concatenates the data module's labeled and unlabeled datasets
    into one dataset (``full_dataset_labeled``) and tracks, as "global
    indexes" into that concatenation, which samples are currently labeled
    (``gi_ds_lb``) and which are still unlabeled (``gi_ds_ulb``).
    """

    def __init__(self, al_data_module: "ALDataModule"):
        """Build the concatenated dataset and the initial global indexes.

        Args:
            al_data_module (ALDataModule): provides ``dataset_labeled``,
                ``dataset_unlabeled`` and ``dataset_test`` as well as
                ``set_train_mode()`` / ``set_test_mode()``.
        Returns:
            None
        """
        self.al_data_module = al_data_module
        # Initial (seed) labeled pool; later acquisitions are kept separately.
        self.seed_dataset_labeled = self.al_data_module.dataset_labeled
        # One dataset per call to ``update``, in acquisition order.
        self.list_late_lb_dataset = []
        # Labeled samples first, then unlabeled ones: the global indexes
        # generated below rely on exactly this ordering.
        self.full_dataset_labeled = ConcatDataset(
            [
                self.al_data_module.dataset_labeled,
                self.al_data_module.dataset_unlabeled,
            ]
        )
        self._gen_global_indexes(
            self.al_data_module.dataset_labeled,
            self.al_data_module.dataset_unlabeled,
        )

    def set_train_mode(self) -> None:
        """Switch the underlying data module to train mode.

        Only delegates to ``al_data_module.set_train_mode()``; presumably this
        selects the training transformations (see ALDataModule to confirm).
        """
        self.al_data_module.set_train_mode()

    def set_test_mode(self) -> None:
        """Switch the underlying data module to test mode.

        Only delegates to ``al_data_module.set_test_mode()``; presumably this
        selects the test transformations (see ALDataModule to confirm).
        """
        self.al_data_module.set_test_mode()

    @property
    def dataset_test(self):
        """Get the test dataset.

        Returns:
            Dataset: ``al_data_module.dataset_test``, returned as is.
        """
        return self.al_data_module.dataset_test

    @property
    def dataset_labeled(self) -> ConcatDataset:
        """Get the labeled dataset.

        Returns:
            ConcatDataset: the seed labeled dataset followed by every dataset
            added through ``update``. A new object is built on each access.
        """
        return ConcatDataset([self.seed_dataset_labeled, *self.list_late_lb_dataset])

    # TODO delete, for testing only
    # NOTE: still used by ``subset_unlabeled(label=True)``; remove both together.
    @property
    def dataset_unlabeled_lb(self) -> Subset:
        """Get the still-unlabeled samples exactly as stored in the full dataset.

        Returns:
            Subset: view on ``full_dataset_labeled`` restricted to
            ``gi_ds_ulb``, without the ``LabeledToUnlabeledDataset`` wrapper.
        """
        return Subset(self.full_dataset_labeled, self.gi_ds_ulb)

    @property
    def dataset_unlabeled(self):
        """Get the still-unlabeled samples for the query strategy.

        Returns:
            LabeledToUnlabeledDataset: the same subset as
            ``dataset_unlabeled_lb`` wrapped in distil's
            ``LabeledToUnlabeledDataset`` (which, going by its name, hides
            the labels; see the distil documentation).
        """
        return LabeledToUnlabeledDataset(Subset(self.full_dataset_labeled, self.gi_ds_ulb))

    def _gen_global_indexes(self, dataset_labeled: Dataset, dataset_unlabeled: Dataset) -> None:
        """Create the initial global index arrays.

        Global indexes are positions in ``full_dataset_labeled``: labeled
        samples occupy ``[0, len(dataset_labeled))`` and unlabeled samples the
        range right after it.

        Args:
            dataset_labeled (Dataset): initial labeled dataset
            dataset_unlabeled (Dataset): initial unlabeled dataset
        Returns:
            None
        """
        len_ds_lb = len(dataset_labeled)
        len_ds_ulb = len(dataset_unlabeled)
        # Offset of the first unlabeled sample inside the concatenation.
        self.start_idx_ulb = len_ds_lb
        self.gi_ds_lb = np.arange(len_ds_lb)
        self.gi_ds_ulb = np.arange(len_ds_lb, len_ds_ulb + len_ds_lb)

    def update(self, ulb_idx_selected: np.ndarray, subds_labeled: Dataset) -> None:
        """Move the selected samples from the unlabeled to the labeled pool.

        Args:
            ulb_idx_selected (np.ndarray): indices selected by the active
                learning strategy, relative to the current unlabeled dataset
                (i.e. positions in ``gi_ds_ulb``). May be empty.
            subds_labeled (Dataset): subset of the unlabeled dataset
                corresponding to the selected indices; appended to the list
                of late labeled datasets.
        Returns:
            None
        Raises:
            IndexError: if an index is out of bounds for the unlabeled pool.
        """
        selected = np.asarray(ulb_idx_selected)
        if selected.size == 0:
            # ``np.array([])`` is float64, which NumPy refuses as an index;
            # cast it so that an empty selection leaves the indexes unchanged.
            selected = selected.astype(np.intp)
        # Translate positions in the unlabeled pool into global indexes.
        gi_ulb_selected = self.gi_ds_ulb[selected]
        self.gi_ds_lb = np.concatenate((self.gi_ds_lb, gi_ulb_selected))
        self.gi_ds_ulb = np.delete(self.gi_ds_ulb, selected)
        self.list_late_lb_dataset.append(subds_labeled)

    def subset_unlabeled(self, ulb_idx_selected: np.ndarray, label: bool = False) -> Subset:
        """Get a subset of the unlabeled dataset.

        Args:
            ulb_idx_selected (np.ndarray): indices, relative to the current
                unlabeled dataset, of the samples wanted in the subset
            label (bool): if True the subset is taken from
                ``dataset_unlabeled_lb`` (samples as stored, labels
                included); otherwise from ``dataset_unlabeled``.
        Returns:
            Subset: subset of the unlabeled dataset with the given indices
        """
        if label:
            return Subset(self.dataset_unlabeled_lb, ulb_idx_selected)
        return Subset(self.dataset_unlabeled, ulb_idx_selected)