# Source code for HALF.Utils.ActiveLearner

import logging
from collections import defaultdict
from typing import DefaultDict, List

import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset

from HALF import Strategies
from HALF.Commons.const import ModelStrategy
from HALF.Delegates.DatasetIncreaseDelegate import DatasetIncreaseDelegate
from HALF.Delegates.DatasetTestingDelegate import DatasetTestingDelegate
from HALF.Delegates.ModelTrainingDelegate import ModelTrainingDelegate
from HALF.Interfaces import IOracle
from HALF.Interfaces.AbstractModel import AbstractModel
from HALF.Interfaces.IHook import IHook
from HALF.Utils.ALDatasetManager import ALDatasetManager
from HALF.Utils.Config import Config
from HALF.Utils.ConfigActiveLearner import ConfigActiveLearner
CONFIG = Config.CONFIG


class ActiveLearner:
    """Class following a Mediator pattern, handling the communication between
    all the components of an active learning loop: datasets, model and oracle.

    Attributes:
        oracle (IOracle): The oracle, in charge of labelling points
        dataset_manager (ALDatasetManager): Manages train, test, (un)labeled datasets
        model (AbstractModel): The predictive model, by default None
        dict_hooks (DefaultDict[str, List[IHook]]): Dictionary of hooks to be
            applied for each hook identifier
        ds_hook (Dict): Contains data manipulated by hooks
        dynamic_config (DefaultDict[str, dict]): Per-component keyword-argument
            overrides merged on top of the static CONFIG (currently used when
            building the trainer)
        i_round (int): Active learning round index
        best_test_accuracy (float): Best test accuracy observed during the
            current active learning process
        nb_oracle_labeled (int): Number of datapoints labeled by the oracle
        configAL (ConfigActiveLearner): Configuration for the active learning process
        trainer (pytorch_lightning.Trainer): The model trainer, rebuilt each
            round from CONFIG.trainer and dynamic_config
        nb_label_increased (int): Number of new labels obtained in the latest iteration
        current_test_results: Contains the trainer's latest test results
        current_train_results: Contains the trainer's latest training results
    """

    def __init__(self, oracle: IOracle, dataset_manager: ALDatasetManager,
                 model: AbstractModel = None):
        """
        Args:
            oracle (IOracle): The oracle, in charge of labelling points
            dataset_manager (ALDatasetManager): Manages train, test, (un)labeled datasets
            model (AbstractModel, optional): The predictive model. Defaults to None.
        """
        self.model = model
        self.dataset_manager = dataset_manager
        self.oracle = oracle
        self.dict_hooks: DefaultDict[str, List[IHook]] = defaultdict(list)
        self.ds_hook = dict()
        self.dynamic_config = defaultdict(dict)
        self.i_round = None
        self.best_test_accuracy = None
        self.nb_oracle_labeled = None
        self.configAL = None
        self.trainer = None
        self.nb_label_increased = None
        self.current_test_results = None
        self.current_train_results = None

    @property
    def dataset_test(self) -> Dataset:
        return self.dataset_manager.dataset_test

    def register_hook(self, stage: str, hook: IHook):
        """Appends ``hook`` to the hooks registered for ``stage``."""
        self.dict_hooks[stage].append(hook)

    def set_list_hook(self, stage: str, list_hooks: List[IHook]):
        """Replaces all hooks registered for ``stage`` with ``list_hooks``."""
        self.dict_hooks[stage] = list_hooks

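    # Example (illustrative sketch, not part of the class): a hypothetical
    # hook that logs the round index at the start of each iteration.
    # ``IHook.apply`` receives the ActiveLearner itself, as ``_apply_hooks``
    # below shows; ``learner`` stands for an ActiveLearner instance.
    #
    #     class RoundLoggerHook(IHook):
    #         def apply(self, learner):
    #             logging.info(f"Starting AL round {learner.i_round}")
    #
    #     learner.register_hook("on_AL_iteration_begin", RoundLoggerHook())
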
    def _apply_hooks(self, stage: str):
        for hook in self.dict_hooks[stage]:
            hook.apply(self)

    def apply_model_strategy(self, strategy: ModelStrategy = ModelStrategy.UPDATE):
        """Applies the given strategy to the model.

        Args:
            strategy (ModelStrategy): The strategy to apply, by default
                ModelStrategy.UPDATE, which keeps the current weights.
        """
        if strategy == ModelStrategy.RESET:
            self.model.reset()
        else:
            # ModelStrategy.UPDATE (and any unrecognized strategy) leaves the
            # model as-is, so training continues from the current weights.
            pass

    def test(self):
        """Collects test results for the current model."""
        test_retval = DatasetTestingDelegate(
            model=self.model,
            dataset_manager=self.dataset_manager,
            trainer=self.trainer,
            config_dl=CONFIG.test.dataloader,
        ).run()
        self.current_test_results = test_retval

    def _build_trainer(self):
        """Builds the pytorch_lightning trainer from the static CONFIG.trainer
        options, with any dynamic overrides set by hooks taking precedence.

        Returns:
            pytorch_lightning.Trainer: The model trainer
        """
        # Merge the two dicts before unpacking: passing both directly as
        # ``**kwargs`` would raise a TypeError on any duplicate key.
        trainer_kwargs = {**CONFIG.trainer, **self.dynamic_config.get("trainer", {})}
        return pl.Trainer(**trainer_kwargs)

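    # Example (illustrative sketch): a hook can override trainer options for
    # subsequent rounds through ``dynamic_config``; the keys are assumed to be
    # valid pytorch_lightning.Trainer arguments (e.g. ``max_epochs``).
    #
    #     class ShortenTrainingHook(IHook):
    #         def apply(self, learner):
    #             learner.dynamic_config["trainer"]["max_epochs"] = 5
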
    def train(self):
        """Fits the model to the current training data."""
        retval = ModelTrainingDelegate(
            model=self.model,
            trainer=self.trainer,
            config_dl=CONFIG.train.dataloader,
            dataset_manager=self.dataset_manager,
        ).run()
        self.current_train_results = retval

    def increase_label(self, configAL: ConfigActiveLearner):
        """Samples new points to be labeled and updates the training data accordingly.

        Args:
            configAL (ConfigActiveLearner): configuration
        """
        ulb_idx_selected = DatasetIncreaseDelegate(
            configAL=configAL,
            dataset_manager=self.dataset_manager,
            model=self.model,
            oracle=self.oracle,
        ).run()
        self.nb_label_increased = len(ulb_idx_selected)
        logging.info(f"Adding {self.nb_label_increased} new labels")

    def stop_condition(self) -> bool:
        """Checks whether a stopping condition (budget or targeted accuracy)
        has been reached.

        Returns:
            bool: False if the next query would exceed the budget or the
            targeted accuracy has been reached; True otherwise, including when
            no configuration is set. True means the AL loop continues.
        """
        if self.configAL is not None:
            is_over_budget = (self.nb_oracle_labeled + self.configAL.AL_batch_size
                              > self.configAL.budget)
            is_enough_acc = self.best_test_accuracy >= self.configAL.targeted_accuracy
            if is_over_budget:
                logging.info("Budget has been reached, stopping AL process")
                return False
            elif is_enough_acc:
                logging.info(
                    f"Targeted accuracy [{self.configAL.targeted_accuracy}] has been "
                    f"reached: {self.best_test_accuracy}, stopping AL process"
                )
                return False
        return True

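    # Worked example: with budget=1000, AL_batch_size=100 and
    # nb_oracle_labeled=950, the next query would require 950 + 100 = 1050
    # labels, exceeding the budget, so stop_condition() returns False and the
    # loop in run() terminates.
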
    def run(self, configAL: ConfigActiveLearner):
        """Runs the active learning process until the stop condition is attained.

        Args:
            configAL (ConfigActiveLearner): configuration
        """
        self.configAL = configAL
        self.on_AL_loop_begin()
        while self.stop_condition():
            self.on_AL_iteration_begin()

            self.on_increase_dataset_begin()
            self.increase_label(configAL=configAL)
            self.on_increase_dataset_end()

            self.on_train_begin()
            self.train()
            self.on_train_end()

            self.on_test_begin()
            self.test()
            self.on_test_end()

            self.on_AL_iteration_end()
        self.on_AL_loop_end()

    def on_increase_dataset_begin(self):
        self._apply_hooks(stage="on_increase_dataset_begin")

    def on_increase_dataset_end(self):
        self._apply_hooks(stage="on_increase_dataset_end")

    def on_train_begin(self):
        # Optionally reset the model weights before (re)training, per config.
        self.apply_model_strategy(CONFIG.active_learner.model.reset_strategy)
        self._apply_hooks(stage="on_train_begin")

    def on_train_end(self):
        # Release cached GPU memory between the training and testing phases.
        torch.cuda.empty_cache()
        self._apply_hooks(stage="on_train_end")

    def on_test_begin(self):
        self._apply_hooks(stage="on_test_begin")

    def on_test_end(self):
        self._apply_hooks(stage="on_test_end")

    def on_AL_loop_begin(self):
        """Sets up the AL process and runs an initial train/test pass on the
        seed labeled set, before any oracle queries."""
        self.i_round = 1
        self.nb_oracle_labeled = 0
        self.best_test_accuracy = 0
        self.trainer = self._build_trainer()
        self._apply_hooks(stage="on_AL_loop_begin")
        # Rebuild the trainer in case the hooks above changed dynamic_config.
        self.trainer = self._build_trainer()
        self.on_train_begin()
        self.train()
        self.on_train_end()
        self.on_test_begin()
        self.test()
        self.on_test_end()
        self.trainer = None

    def on_AL_loop_end(self):
        """Terminates the AL process."""
        self.i_round = None
        self.configAL = None
        self.ds_hook = dict()
        self._apply_hooks(stage="on_AL_loop_end")

    def on_AL_iteration_begin(self):
        """Sets up a new AL iteration."""
        print("=" * 30)
        print(f"Round {self.i_round}")
        print("=" * 30)
        self._apply_hooks(stage="on_AL_iteration_begin")
        # Build a fresh trainer for this round, picking up any dynamic_config
        # changes made by the hooks above.
        self.trainer = self._build_trainer()

    def on_AL_iteration_end(self):
        """Terminates the current AL iteration."""
        self.i_round += 1
        self.nb_oracle_labeled += self.nb_label_increased
        self.nb_label_increased = None
        self._apply_hooks(stage="on_AL_iteration_end")
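
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): wiring an ActiveLearner together.
# ``my_oracle``, ``my_dataset_manager``, ``my_model`` and ``my_config_al`` are
# hypothetical IOracle / ALDatasetManager / AbstractModel /
# ConfigActiveLearner instances; ``RoundLoggerHook`` is the hypothetical hook
# sketched above.
#
#     learner = ActiveLearner(oracle=my_oracle,
#                             dataset_manager=my_dataset_manager,
#                             model=my_model)
#     learner.register_hook("on_AL_iteration_begin", RoundLoggerHook())
#     learner.run(configAL=my_config_al)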