# Source code for HALF.Hooks.TestAccuracyLoggerHook
from HALF.Interfaces.IHook import IHook, ActiveLearner
from dataclasses import dataclass
from typing import List
from HALF.Utils.Registry import REGISTRY
from HALF.Commons.data import DataLogAccuracy
@REGISTRY.register()
class TestAccuracyLoggerHook(IHook):
    """Hook that records an ActiveLearner's test accuracy each time it is applied.

    Each record pairs the current size of the labeled dataset with the test
    accuracy. Records accumulate in ``al.ds_hook["test_logs"]``.
    """

    def apply(self, al: ActiveLearner):
        """Log test accuracy.

        On the first call, an empty list is stored under
        ``al.ds_hook["test_logs"]``. Every call then appends one
        ``DataLogAccuracy`` built from:

        * ``lb_dataset_size``: ``len(al.dataset_manager.dataset_labeled)``
        * ``test_acc``: ``al.current_test_results[0]["test_acc"]``

        Only the first entry of ``current_test_results`` is read. Presumably
        it is the most relevant result; TODO confirm against the caller.

        Args:
            al (ActiveLearner): The ActiveLearner whose test accuracy should be logged.
        """
        print("TestAccuracyLoggerHook")

        hook_store = al.ds_hook
        log_key = "test_logs"
        # Lazily create the log list so repeated calls keep appending to it.
        if log_key not in hook_store:
            hook_store[log_key]: List[DataLogAccuracy] = []

        accuracy = al.current_test_results[0]["test_acc"]
        labeled_count = len(al.dataset_manager.dataset_labeled)

        hook_store[log_key].append(
            DataLogAccuracy(lb_dataset_size=labeled_count, test_acc=accuracy)
        )