# Source code for HALF.Delegates.DatasetIncreaseDelegate

from HALF.Utils.ConfigActiveLearner import ConfigActiveLearner
import pytorch_lightning as pl
from HALF.Utils.ALDatasetManager import ALDatasetManager
from HALF.Interfaces.IOracle import IOracle
from HALF.Utils.Registry import REGISTRY
from HALF.Interfaces.IDelegate import IDelegate
from distil.active_learning_strategies.strategy import Strategy as Distil_Strategy
from HALF.Interfaces.IUnifiedStrategy import IUnifiedStrategy

from typing import Union, List

class DatasetIncreaseDelegate(IDelegate):
    """Class managing the selection of queries, handing them to the Oracle and the
    subsequent update of the labelled and unlabelled datasets.
    """

    def __init__(
        self,
        configAL: ConfigActiveLearner,
        model: pl.LightningModule,
        dataset_manager: ALDatasetManager,
        oracle: IOracle,
    ):
        """Build the query strategy and keep the objects used at each iteration.

        Args:
            configAL (ConfigActiveLearner): config containing all parameters; this
                method reads ``strategy_name``, ``strategy_args`` and ``AL_batch_size``.
            model (pl.LightningModule): model to train in the active learning loop.
            dataset_manager (ALDatasetManager): object manipulating the datasets.
            oracle (IOracle): object giving the labels at each iteration of the
                active learning loop.

        Raises:
            KeyError: presumably, if ``configAL.strategy_name`` is not registered
                (depends on REGISTRY's lookup behaviour -- TODO confirm).
            TypeError: if the registered strategy is neither a distil ``Strategy``
                nor an ``IUnifiedStrategy`` subclass.
        """
        strategy_cls = REGISTRY[configAL.strategy_name]
        if issubclass(strategy_cls, Distil_Strategy):
            # distil strategies are constructed from the raw datasets and the network.
            self.strategy = strategy_cls(
                labeled_dataset=dataset_manager.dataset_labeled,
                unlabeled_dataset=dataset_manager.dataset_unlabeled,
                net=model,
                nclasses=model.num_classes,
                args=configAL.strategy_args,
            )
        elif issubclass(strategy_cls, IUnifiedStrategy):
            self.strategy = strategy_cls(
                model=model,
                args=configAL.strategy_args,
                dataset_manager=dataset_manager,
            )
        else:
            # Fail fast: otherwise self.strategy would be left unset and run() would
            # only fail later with an unrelated AttributeError.
            raise TypeError(
                f"Strategy '{configAL.strategy_name}' ({strategy_cls!r}) must be a "
                "subclass of distil's Strategy or of IUnifiedStrategy"
            )
        self.budget = configAL.AL_batch_size
        self.dataset_manager = dataset_manager
        self.oracle = oracle

    def _select_idx_distil(
        self, strategy: Union[Distil_Strategy, IUnifiedStrategy], budget: int
    ) -> List[int]:
        """Select the indices of the samples to label, relative to the unlabelled dataset.

        The dataset manager is switched to test mode for the duration of the
        selection and is always switched back to train mode afterwards, even if
        the selection raises.

        Args:
            strategy (Union[Distil_Strategy, IUnifiedStrategy]): strategy used to
                select the samples (only its ``select(budget)`` method is called).
            budget (int): number of samples to select.

        Returns:
            List[int]: indices of the samples to label, with respect to their order
            in the unlabelled dataset.
        """
        self.dataset_manager.set_test_mode()
        try:
            ulb_idx_selected = strategy.select(budget)
        finally:
            # Restore train mode even on failure so later training is not silently
            # run with the dataset manager still in test mode.
            self.dataset_manager.set_train_mode()
        return ulb_idx_selected

    def run(self) -> List[int]:
        """Select samples to label, query the oracle and update the datasets.

        The selected samples are moved from the unlabelled dataset to the
        labelled dataset through ``dataset_manager.update``.

        Returns:
            List[int]: indices of the selected samples, with respect to their order
            in the unlabelled dataset as it was before this call's update.
        """
        # Both strategy families are driven through the same select(budget) call,
        # so no per-type dispatch is needed here.
        ulb_idx_selected = self._select_idx_distil(self.strategy, self.budget)
        subds_tb_labeled = self.dataset_manager.subset_unlabeled(ulb_idx_selected, label=True)
        subds_labeled = self.oracle.query(subds_tb_labeled)
        self.dataset_manager.update(
            ulb_idx_selected=ulb_idx_selected, subds_labeled=subds_labeled
        )
        return ulb_idx_selected