import logging
from collections import defaultdict
from typing import DefaultDict, List

import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset

from HALF import Strategies
from HALF.Commons.const import ModelStrategy
from HALF.Delegates.DatasetIncreaseDelegate import DatasetIncreaseDelegate
from HALF.Delegates.DatasetTestingDelegate import DatasetTestingDelegate
from HALF.Delegates.ModelTrainingDelegate import ModelTrainingDelegate
from HALF.Interfaces import IOracle
from HALF.Interfaces.AbstractModel import AbstractModel
from HALF.Interfaces.IHook import IHook
from HALF.Utils.ALDatasetManager import ALDatasetManager
from HALF.Utils.Config import Config
from HALF.Utils.ConfigActiveLearner import ConfigActiveLearner

CONFIG = Config.CONFIG

class ActiveLearner:
    """Class following a Mediator pattern, handling the communication between all the
    components of an active learning loop: datasets, model and oracle

    Attributes:
        oracle (IOracle): The oracle, in charge of labelling points
        dataset_manager (ALDatasetManager): Manages train, test, (un)labeled datasets
        model (AbstractModel): The predictive model, None by default
        dict_hooks (DefaultDict[str, List[IHook]]): Maps each hook stage identifier to its list of hooks
        ds_hook (Dict): Contains data manipulated by hooks
        dynamic_config (DefaultDict[str, Dict]): Per-component configuration overrides merged into the static config, e.g. when building the trainer
        i_round (int): Active learning round index
        best_test_accuracy (float): Best test accuracy observed during the current active learning process
        nb_oracle_labeled (int): Number of datapoints labeled by the oracle
        configAL (ConfigActiveLearner): Configuration for the active learning process
        trainer (pytorch_lightning.Trainer): The model trainer, built from the static trainer config and any dynamic_config overrides
        nb_label_increased (int): Number of new labels obtained in the latest iteration
        current_train_results: Contains the trainer's latest training results
        current_test_results: Contains the trainer's latest test results
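
    Example:
        A minimal usage sketch; ``MyOracle`` and ``MyModel`` are hypothetical
        user-provided implementations of IOracle and AbstractModel::

            learner = ActiveLearner(oracle=MyOracle(),
                                    dataset_manager=ALDatasetManager(...),
                                    model=MyModel())
            learner.run(configAL)  # configAL: a ConfigActiveLearner instance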
"""
    def __init__(self,
                 oracle: IOracle,
                 dataset_manager: ALDatasetManager,
                 model: AbstractModel = None):
"""
Args:
oracle (IOracle): The oracle, in charge of labelling points
dataset_manager (ALDatasetManager): Manages train, test, (un)labeled datasets
model (AbstractModel, optional): The predictive model. Defaults to None.
"""
self.model = model
self.dataset_manager = dataset_manager
self.oracle = oracle
self.dict_hooks: DefaultDict[str, List[IHook]] = defaultdict(list)
self.ds_hook = dict()
self.dynamic_config = defaultdict(dict)
self.i_round = None
self.best_test_accuracy = None
self.nb_oracle_labeled = None
self.configAL = None
self.trainer = None
self.nb_label_increased = None
        self.current_test_results = None
        self.current_train_results = None
    @property
    def dataset_test(self) -> Dataset:
        """The test dataset held by the dataset manager."""
        return self.dataset_manager.dataset_test
    def register_hook(self, stage: str, hook: IHook):
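        """Registers a hook to be applied at the given stage.

        Args:
            stage (str): Stage identifier, e.g. "on_train_begin"
            hook (IHook): The hook to register

        Example:
            A hook is any object implementing ``IHook.apply``, which receives
            this ActiveLearner; ``BannerHook`` is a hypothetical sketch::

                class BannerHook(IHook):
                    def apply(self, learner):
                        logging.info(f"Starting round {learner.i_round}")

                learner.register_hook("on_AL_iteration_begin", BannerHook())
        """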
self.dict_hooks[stage].append(hook)
    def set_list_hook(self, stage: str, list_hooks: List[IHook]):
        """Replaces all hooks registered for the given stage.

        Args:
            stage (str): Stage identifier
            list_hooks (List[IHook]): The hooks to set for this stage
        """
        self.dict_hooks[stage] = list_hooks
    def _apply_hooks(self, stage: str):
        """Applies every hook registered for the given stage, in registration order."""
        for hook in self.dict_hooks[stage]:
            hook.apply(self)
    def apply_model_strategy(self, strategy: ModelStrategy = ModelStrategy.UPDATE):
        """Applies a weight-handling strategy to the model

        Args:
            strategy (ModelStrategy): The strategy to apply. Defaults to ModelStrategy.UPDATE.
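
        Example:
            Forcing a cold restart before the next training round::

                learner.apply_model_strategy(ModelStrategy.RESET)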
"""
        if strategy == ModelStrategy.RESET:
            # Reinitialize the model weights before the next training round
            self.model.reset()
        elif strategy == ModelStrategy.UPDATE:
            # Keep the current weights: the next round fine-tunes the model
            pass
        else:
            # Unknown strategies are treated as a no-op
            pass
    def test(self):
"""Collects test results for the current model
"""
test_retval = DatasetTestingDelegate(
model=self.model,
dataset_manager=self.dataset_manager,
trainer=self.trainer,
config_dl=CONFIG.test.dataloader
).run()
self.current_test_results = test_retval
def _build_trainer(self):
"""Builds pytorch_lightning trainer
Returns:
pytorch_lightning.Trainer:
The model trainer
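
        Example:
            Hooks can stage per-round trainer overrides through ``dynamic_config``;
            a hypothetical sketch (the key must not also appear in ``CONFIG.trainer``,
            otherwise ``pl.Trainer`` would receive the same keyword twice and raise
            a TypeError)::

                learner.dynamic_config["trainer"]["max_epochs"] = 5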
"""
return pl.Trainer(**CONFIG.trainer, **self.dynamic_config.get("trainer", {}))
    def train(self):
"""Fits model to current data
"""
retval = ModelTrainingDelegate(
model=self.model,
trainer=self.trainer,
config_dl=CONFIG.train.dataloader,
dataset_manager=self.dataset_manager
).run()
self.current_train_results = retval
    def increase_label(self, configAL: ConfigActiveLearner):
        """Samples new points to be labeled and updates the training data accordingly

        Args:
            configAL (ConfigActiveLearner): Configuration for the active learning process
"""
ulb_idx_selected = DatasetIncreaseDelegate(
configAL=configAL,
dataset_manager=self.dataset_manager,
model=self.model,
oracle=self.oracle).run()
self.nb_label_increased = len(ulb_idx_selected)
logging.info(f"Adding {self.nb_label_increased} new labels")
    def stop_condition(self) -> bool:
        """Checks whether a stopping condition (budget or targeted accuracy) has been reached.

        Returns:
            bool: True if the AL process should continue, i.e. no stopping
            condition has been reached yet or no configuration is set
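
        Example:
            With ``budget=1000``, ``AL_batch_size=64`` and 960 points already
            labeled, the next query would bring the total to 1024 > 1000, so
            the method returns False and the loop in :meth:`run` stops before
            the budget is exceeded.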
"""
if self.configAL is not None:
is_over_budget = self.nb_oracle_labeled + self.configAL.AL_batch_size > self.configAL.budget
is_enough_acc = self.best_test_accuracy >= self.configAL.targeted_accuracy
if is_over_budget:
logging.info("Budget has been reached, stopping AL process")
return False
elif is_enough_acc:
logging.info(f"Targeted accuracy [{self.configAL.targeted_accuracy}] has been reached : {self.best_test_accuracy}, stopping AL process")
return False
return True
    def run(self, configAL: ConfigActiveLearner):
        """Runs the active learning process until a stopping condition is reached

        Args:
            configAL (ConfigActiveLearner): Configuration for the active learning process
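
        Example:
            A hypothetical call; the ConfigActiveLearner constructor signature is
            assumed, but budget, AL_batch_size and targeted_accuracy are the
            fields this class actually reads::

                configAL = ConfigActiveLearner(budget=1000,
                                               AL_batch_size=64,
                                               targeted_accuracy=0.95)
                learner.run(configAL)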
"""
self.configAL = configAL
self.on_AL_loop_begin()
        while self.stop_condition():
self.on_AL_iteration_begin()
self.on_increase_dataset_begin()
self.increase_label(configAL=configAL)
self.on_increase_dataset_end()
self.on_train_begin()
self.train()
self.on_train_end()
self.on_test_begin()
self.test()
self.on_test_end()
self.on_AL_iteration_end()
self.on_AL_loop_end()
    def on_increase_dataset_begin(self):
self._apply_hooks(stage="on_increase_dataset_begin")
    def on_increase_dataset_end(self):
self._apply_hooks(stage="on_increase_dataset_end")
    def on_train_begin(self):
self.apply_model_strategy(CONFIG.active_learner.model.reset_strategy)
self._apply_hooks(stage="on_train_begin")
    def on_train_end(self):
        # Release cached GPU memory between rounds
        torch.cuda.empty_cache()
self._apply_hooks(stage="on_train_end")
    def on_test_begin(self):
self._apply_hooks(stage="on_test_begin")
    def on_test_end(self):
self._apply_hooks(stage="on_test_end")
    def on_AL_loop_begin(self):
        """Sets up the AL process and runs an initial train/test pass on the seed labeled data"""
        self.i_round = 1
        self.nb_oracle_labeled = 0
        self.best_test_accuracy = 0
        self.trainer = self._build_trainer()
        self._apply_hooks(stage="on_AL_loop_begin")
        # Rebuild the trainer so that any dynamic_config overrides staged by
        # the hooks above are taken into account
        self.trainer = self._build_trainer()
        self.on_train_begin()
        self.train()
        self.on_train_end()
        self.on_test_begin()
        self.test()
        self.on_test_end()
        self.trainer = None
    def on_AL_loop_end(self):
"""Terminate AL process
"""
self.i_round = None
self.configAL = None
self.ds_hook = dict()
self._apply_hooks(stage="on_AL_loop_end")
    def on_AL_iteration_begin(self):
"""Set up new AL iteration
"""
print("="*30)
print(f"Round {self.i_round}")
print("="*30)
self._apply_hooks(stage="on_AL_iteration_begin")
self.trainer = self._build_trainer()
    def on_AL_iteration_end(self):
"""Terminate current AL iteration
"""
self.i_round += 1
self.nb_oracle_labeled += self.nb_label_increased
self.nb_label_increased = None
self._apply_hooks(stage="on_AL_iteration_end")