Source code for HALF.Hooks.TestAccuracyLoggerHook

from HALF.Interfaces.IHook import IHook, ActiveLearner
from dataclasses import dataclass
from typing import List
from HALF.Utils.Registry import REGISTRY
from HALF.Commons.data import DataLogAccuracy


@REGISTRY.register()
class TestAccuracyLoggerHook(IHook):
    """Hook that appends the learner's current test accuracy to ``al.ds_hook["test_logs"]``."""

    def __init__(self):
        """Create the hook; all initialisation is delegated to ``IHook``."""
        IHook.__init__(self)

    def apply(self, al: ActiveLearner):
        """Record one test-accuracy entry for the given active learner.

        A ``DataLogAccuracy`` pairing the current size of the labeled
        dataset with the latest test accuracy is appended to the list stored
        under ``"test_logs"`` in ``al.ds_hook``; the list is created on first
        use.

        Args:
            al (ActiveLearner): The learner whose test accuracy is logged.
                Its ``current_test_results`` is read as ``[0]["test_acc"]``
                (presumably the first evaluation result dict — confirm
                against the caller).
        """
        print("TestAccuracyLoggerHook")

        hook_store = al.ds_hook
        # Lazily create the log list the first time this hook runs.
        if "test_logs" not in hook_store:
            hook_store["test_logs"] = list()

        # Read accuracy before dataset size, matching the original order.
        accuracy = al.current_test_results[0]["test_acc"]
        labeled_count = len(al.dataset_manager.dataset_labeled)

        hook_store["test_logs"].append(
            DataLogAccuracy(lb_dataset_size=labeled_count, test_acc=accuracy)
        )