Source code for HALF.Utils.ActiveLearnerBuilder

import logging
from dataclasses import dataclass, field
from typing import Dict

import pytorch_lightning as pl

from HALF.Utils.ActiveLearner import ActiveLearner
from HALF.Utils.ALDataModule import ALDataModule
from HALF.Utils.ALDatasetManager import ALDatasetManager
from HALF.Interfaces.IOracle import IOracle
from HALF.Interfaces.IHook import IHook
from HALF.Utils.Registry import REGISTRY
from HALF.Utils.Config import Config
from HALF import Hooks

@dataclass
class HookAttachData:
    """Description of a hook to instantiate and attach to an ``ActiveLearner``.

    Instances are created either directly by user code or from each entry of
    the ``active_learner.hooks`` configuration section via
    ``HookAttachData(**kwargs)``.
    """

    # name of the ActiveLearner stage the hook is registered on (e.g. "on_test_end")
    attach_point: str
    # key used to look up the hook class in REGISTRY
    class_name: str
    # keyword arguments forwarded to the hook class constructor; defaults to an
    # empty dict so configuration entries without an ``args`` key are accepted
    args: dict = field(default_factory=dict)
class ActiveLearnerBuilder:
    """Builder class for ``ActiveLearner``.

    Collects the model, data module, oracle and hook descriptions, then
    assembles them into an ``ActiveLearner`` in :meth:`build`.
    """

    def __init__(self):
        """Builder class for ``ActiveLearner``"""
        self.reset()

    def add_hook(self, uid: str, hook_attach_data: HookAttachData):
        """Attach a hook to the ActiveLearner.

        A hook already stored under ``uid`` is replaced.

        Args:
            uid (str): unique identifier for the hook
            hook_attach_data (HookAttachData): information about the hook

        Returns:
            None
        """
        self.hook_dict[uid] = hook_attach_data

    def remove_hook(self, uid: str):
        """Remove the hook attached to the given uid.

        If no hook is attached to ``uid``, a warning is logged and nothing
        else happens.

        Args:
            uid (str): unique identifier of the hook you wish to remove

        Returns:
            None
        """
        # Only warn when the key is actually missing; previously the warning
        # was emitted even after a successful removal.
        if uid not in self.hook_dict:
            logging.warning(f"key {uid} was not linked to any hook, ignoring")
            return
        del self.hook_dict[uid]

    def reset(self):
        """Clean the Builder and initialize the necessary attributes."""
        self.model = None  # set through set_model
        self.data_module = None  # set through set_data_module
        self.oracle = None  # set through set_oracle
        # uid -> hook description; instantiated and registered in build()
        self.hook_dict: Dict[str, HookAttachData] = dict()
        self.active_learner = None  # last instance produced by build()

    def hydrate_hook_dict(self):
        """Add the hooks declared in the configuration to ``hook_dict``.

        Reads ``Config.CONFIG["active_learner"]["hooks"]``, a mapping of
        uid -> keyword arguments for ``HookAttachData``. Entries with a uid
        already present in ``hook_dict`` overwrite it.

        Returns:
            None
        """
        # ``or {}`` also covers sections that are present but empty in the
        # YAML file (parsed as None), which ``.get(key, {})`` alone does not.
        active_learner_cfg = Config.CONFIG.get("active_learner") or {}
        config_hook: Dict = active_learner_cfg.get("hooks") or {}
        for uid, kwargs in config_hook.items():
            self.hook_dict[uid] = HookAttachData(**kwargs)

    def set_model(self, model: pl.LightningModule):
        """Set the model to be used in the active learning loop

        Args:
            model (pl.LightningModule) : model to be used
        """
        self.model = model

    def set_data_module(self, al_data_module: ALDataModule):
        """Set the data module containing the datasets

        Args:
            al_data_module (ALDataModule): data module to be used
        """
        self.data_module = al_data_module

    def set_oracle(self, oracle: IOracle):
        """Set the oracle

        Args:
            oracle (IOracle) : oracle to be used
        """
        self.oracle = oracle

    def _setting_registered_hooks(self):
        """Instantiate every hook of ``hook_dict`` and register it on ``self.active_learner``."""
        for hook_attach_data in self.hook_dict.values():
            hook_cls = REGISTRY[hook_attach_data.class_name]
            # ``args`` can be None when a config entry has an empty ``args:`` key
            hook = hook_cls(**(hook_attach_data.args or {}))
            self.active_learner.register_hook(stage=hook_attach_data.attach_point, hook=hook)

    def add_defaults_hooks(self):
        """Add the default hooks for the ``ActiveLearner``."""
        self.hook_dict["TestAccuracyLoggerHook"] = HookAttachData(
            attach_point="on_test_end",
            class_name="TestAccuracyLoggerHook",
            args={})
        # self.hook_dict["WriteAccuracyLogsHook"] = HookAttachData(
        #     attach_point="on_test_end",
        #     class_name="WriteAccuracyLogsHook",
        #     args={})

    def build(self, same_instance: bool = False, use_configured_hooks: bool = True) -> ActiveLearner:
        """Finalise the building of the ActiveLearner instance.

        Args:
            same_instance (bool): if True and this builder has already built an
                ``ActiveLearner``, return that existing instance unchanged
                instead of building a new one
            use_configured_hooks (bool): add the hooks declared in the YAML
                configuration to ``hook_dict`` before the hooks are registered

        Returns:
            ActiveLearner : Active learner instance
        """
        if same_instance and self.active_learner is not None:
            return self.active_learner
        self.active_learner = ActiveLearner(
            oracle=self.oracle,
            model=self.model,
            dataset_manager=ALDatasetManager(self.data_module),
        )
        if use_configured_hooks:
            self.hydrate_hook_dict()
        # Registers both manually added hooks and (optionally) configured ones.
        self._setting_registered_hooks()
        return self.active_learner