import logging
from dataclasses import dataclass, field
from typing import Dict

import pytorch_lightning as pl

from HALF.Utils.ActiveLearner import ActiveLearner
from HALF.Utils.ALDataModule import ALDataModule
from HALF.Utils.ALDatasetManager import ALDatasetManager
from HALF.Interfaces.IOracle import IOracle
from HALF.Interfaces.IHook import IHook
from HALF.Utils.Registry import REGISTRY
from HALF.Utils.Config import Config
from HALF import Hooks
@dataclass
class HookAttachData:
    """Description of a hook to attach to an ``ActiveLearner``.

    Attributes:
        attach_point (str): stage name passed as ``stage`` to
            ``ActiveLearner.register_hook`` (e.g. ``"on_test_end"``).
        class_name (str): name under which the hook class is looked up in
            ``REGISTRY``.
        args (dict): keyword arguments used to instantiate the hook class.
            Defaults to an empty dict so hooks without constructor arguments
            (and configuration entries that omit ``args``) need not spell it out.
    """

    attach_point: str
    class_name: str
    # default_factory gives every instance its own dict (no shared mutable default)
    args: dict = field(default_factory=dict)
class ActiveLearnerBuilder:
    """Builder that assembles an ``ActiveLearner`` from a model, a data module,
    an oracle and a set of hooks.

    Typical usage::

        builder = ActiveLearnerBuilder()
        builder.set_model(model)
        builder.set_data_module(data_module)
        builder.set_oracle(oracle)
        builder.add_defaults_hooks()
        active_learner = builder.build()
    """

    def __init__(self):
        """Builder class for ``ActiveLearner``.

        Starts from an empty state (see :meth:`reset`).
        """
        self.reset()

    def add_hook(self, uid: str, hook_attach_data: HookAttachData) -> None:
        """Attach a hook to the ActiveLearner.

        The hook is only instantiated and registered when :meth:`build` runs.
        Adding a hook under an existing ``uid`` replaces the previous one.

        Args:
            uid (str): unique identifier for the hook
            hook_attach_data (HookAttachData): information about the hook

        Returns:
            None
        """
        self.hook_dict[uid] = hook_attach_data

    def remove_hook(self, uid: str) -> None:
        """Remove the hook attached to the given uid.

        If no hook is attached to ``uid``, a warning is logged and nothing
        else happens.

        Args:
            uid (str): unique identifier of the hook you wish to remove

        Returns:
            None
        """
        try:
            del self.hook_dict[uid]
        except KeyError:
            # Only warn when there was actually nothing to remove.
            logging.warning("key %s was not linked to any hook, ignoring", uid)

    def reset(self) -> None:
        """Clean the Builder and initialize the necessary attributes.

        Drops the model, data module, oracle, every registered hook and the
        last built ``ActiveLearner``.
        """
        self.model = None
        self.data_module = None
        self.oracle = None
        self.hook_dict: Dict[str, HookAttachData] = {}
        self.active_learner = None

    def hydrate_hook_dict(self) -> None:
        """Attach the hooks from the config to the ``ActiveLearner``.

        Reads ``Config.CONFIG["active_learner"]["hooks"]``, a mapping of
        ``uid -> HookAttachData keyword arguments``. Each entry overrides any
        hook already stored under the same uid. Missing or null sections are
        treated as "no hooks".

        Returns:
            None
        """
        # `or {}` also covers sections present in the YAML but left empty (None).
        active_learner_config = Config.CONFIG.get("active_learner") or {}
        config_hooks: Dict = active_learner_config.get("hooks") or {}
        for uid, kwargs in config_hooks.items():
            self.hook_dict[uid] = HookAttachData(**kwargs)

    def set_model(self, model: pl.LightningModule) -> None:
        """Set the model to be used in the active learning loop.

        Args:
            model (pl.LightningModule): model to be used
        """
        self.model = model

    def set_data_module(self, al_data_module: ALDataModule) -> None:
        """Set the data module containing the datasets.

        Args:
            al_data_module (ALDataModule): data module to be used
        """
        self.data_module = al_data_module

    def set_oracle(self, oracle: IOracle) -> None:
        """Set the oracle.

        Args:
            oracle (IOracle): oracle to be used
        """
        self.oracle = oracle

    def _setting_registered_hooks(self) -> None:
        """Instantiate every hook in ``hook_dict`` and register it on
        ``self.active_learner``.

        Hook classes are looked up by ``class_name`` in ``REGISTRY`` and
        constructed with ``args`` as keyword arguments.
        """
        for hook_attach_data in self.hook_dict.values():
            hook_cls = REGISTRY[hook_attach_data.class_name]
            hook = hook_cls(**hook_attach_data.args)
            self.active_learner.register_hook(
                stage=hook_attach_data.attach_point,
                hook=hook,
            )

    def add_defaults_hooks(self) -> None:
        """Add the default hooks for the ``ActiveLearner``.

        Currently registers ``TestAccuracyLoggerHook`` on ``on_test_end``.
        """
        self.hook_dict["TestAccuracyLoggerHook"] = HookAttachData(
            attach_point="on_test_end",
            class_name="TestAccuracyLoggerHook",
            args={},
        )
        # NOTE: a "WriteAccuracyLogsHook" default on "on_test_end" used to be
        # registered here and is currently disabled.

    def build(self, same_instance: bool = False, use_configured_hooks: bool = True) -> ActiveLearner:
        """Finalise the building of the ActiveLearner instance.

        Args:
            same_instance (bool): if True and an ActiveLearner was already
                built, return that instance unchanged instead of building a
                new one.
            use_configured_hooks (bool): merge the hooks given in the YAML
                configuration into the builder's hooks before registering them
                (configured entries override same-uid hooks).

        Returns:
            ActiveLearner: Active learner instance
        """
        if same_instance and self.active_learner is not None:
            return self.active_learner
        self.active_learner = ActiveLearner(
            oracle=self.oracle,
            model=self.model,
            dataset_manager=ALDatasetManager(self.data_module),
        )
        if use_configured_hooks:
            self.hydrate_hook_dict()
        self._setting_registered_hooks()
        return self.active_learner