from HALF.Utils.ConfigActiveLearner import ConfigActiveLearner
import pytorch_lightning as pl
from HALF.Utils.ALDatasetManager import ALDatasetManager
from HALF.Interfaces.IOracle import IOracle
from HALF.Utils.Registry import REGISTRY
from HALF.Interfaces.IDelegate import IDelegate
from distil.active_learning_strategies.strategy import Strategy as Distil_Strategy
from HALF.Interfaces.IUnifiedStrategy import IUnifiedStrategy
from typing import Union, List
class DatasetIncreaseDelegate(IDelegate):
    """
    Class managing the selection of queries, handing them to the Oracle and the subsequent update of the
    labelled and unlabelled datasets
    """

    def __init__(self, configAL: ConfigActiveLearner,
                 model: pl.LightningModule,
                 dataset_manager: ALDatasetManager,
                 oracle: IOracle):
        """Class managing the selection of queries, handing them to the Oracle and the subsequent update of the labelled and unlabelled datasets

        The strategy class is looked up in ``REGISTRY`` under ``configAL.strategy_name`` and
        instantiated with the constructor signature matching its base class.

        Args:
            configAL (ConfigActiveLearner) : config containing all parameters; ``strategy_name``,
                ``strategy_args`` and ``AL_batch_size`` are read here
            model (pl.LightningModule) : model to train in the active learning loop; its
                ``num_classes`` attribute is read when the strategy is a Distil strategy
            dataset_manager (ALDatasetManager) : object manipulating the datasets
            oracle (IOracle) : object giving the labels at each iteration of the active learning loop

        Raises:
            TypeError: if the registered strategy class derives from neither ``Distil_Strategy``
                nor ``IUnifiedStrategy``
        """
        strategy_cls = REGISTRY[configAL.strategy_name]
        if issubclass(strategy_cls, Distil_Strategy):
            self.strategy = strategy_cls(labeled_dataset=dataset_manager.dataset_labeled,
                                         unlabeled_dataset=dataset_manager.dataset_unlabeled,
                                         net=model,
                                         nclasses=model.num_classes,
                                         args=configAL.strategy_args)
        elif issubclass(strategy_cls, IUnifiedStrategy):
            self.strategy = strategy_cls(model=model,
                                         args=configAL.strategy_args,
                                         dataset_manager=dataset_manager)
        else:
            # Fail here with a clear message instead of leaving `self.strategy` unset,
            # which would only surface later as an AttributeError inside `run()`.
            raise TypeError(
                f"Strategy '{configAL.strategy_name}' ({strategy_cls!r}) must derive from "
                "Distil_Strategy or IUnifiedStrategy"
            )
        self.budget = configAL.AL_batch_size
        self.dataset_manager = dataset_manager
        self.oracle = oracle

    def _select_idx_distil(self, strategy: Union[Distil_Strategy, IUnifiedStrategy], budget: int) -> List[int]:
        """Select the indices of the samples to label with respect to their indices in the unlabelled dataset

        The dataset manager is switched with ``set_test_mode()`` before the selection and switched
        back with ``set_train_mode()`` afterwards, even when the selection raises.

        Args:
            strategy (Union[Distil_Strategy, IUnifiedStrategy]): strategy used to select the samples
            budget (int): Number of samples to select

        Returns:
            List[int]: list of indices of samples to label with respect to their order in the unlabelled dataset
        """
        self.dataset_manager.set_test_mode()
        try:
            return strategy.select(budget)
        finally:
            # Always restore train mode so a failing strategy cannot leave the
            # dataset manager stuck in test mode for the rest of the loop.
            self.dataset_manager.set_train_mode()

    def run(self) -> List[int]:
        """Select the indices of the samples to label with respect to their indices in the unlabelled dataset and update the datasets to move the samples
        from the unlabelled dataset to the labelled dataset

        Returns:
            List[int]: list of indices of samples to label with respect to their order in the unlabelled dataset
        """
        # Both supported strategy families are queried through the same `select(budget)` call,
        # so no branching on the strategy type is needed here.
        ulb_idx_selected = self._select_idx_distil(self.strategy, self.budget)
        # Extract the selected samples from the unlabelled pool and hand them to the oracle.
        subds_tb_labeled = self.dataset_manager.subset_unlabeled(ulb_idx_selected, label=True)
        subds_labeled = self.oracle.query(subds_tb_labeled)
        # Let the dataset manager apply the newly labelled subset for the selected indices.
        self.dataset_manager.update(ulb_idx_selected=ulb_idx_selected,
                                    subds_labeled=subds_labeled)
        return ulb_idx_selected