Source code for HALF.Hooks.SetupTensorboardHook
from pytorch_lightning.loggers import TensorBoardLogger
import logging
from HALF.Interfaces.IHook import IHook
import torch as T
from HALF.Utils.utils import now_str
import os
from HALF.Utils.ActiveLearner import ActiveLearner
from HALF.Utils.Registry import REGISTRY
@REGISTRY.register()
class SetupTensorboardHook(IHook):
    """Hook that installs a TensorBoard logger into the trainer configuration.

    When applied, the hook builds a ``TensorBoardLogger`` and stores it under
    ``al.dynamic_config["trainer"]["logger"]``.
    """

    def __init__(self, experiment_name: str = "experiment_name",
                 dir_path: str = None, with_time: bool = False):
        """Store the logger settings; no logger is created until ``apply``.

        Args:
            experiment_name: Name passed to the logger as its ``name``.
            dir_path: Directory handed to ``TensorBoardLogger`` as the save
                directory. ``None`` means "use the current working directory
                at the time ``apply`` runs".
            with_time: If true, the value returned by ``now_str()`` is
                appended to the experiment name (separated by ``_``).
        """
        self.experiment_name = experiment_name
        self.dir_path = dir_path
        self.with_time = with_time

    def get_experiment_name(self) -> str:
        """Return the experiment name, suffixed with ``now_str()`` if requested."""
        if self.with_time:
            return f"{self.experiment_name}_{now_str()}"
        return f"{self.experiment_name}"

    def apply(self, al: ActiveLearner):
        """Create the TensorBoard logger and register it on ``al``.

        Args:
            al: Active learner whose ``dynamic_config["trainer"]`` mapping
                receives the logger under the ``"logger"`` key.
        """
        # TensorBoardLogger expects a real path for its save directory; the
        # constructor default of None would otherwise be forwarded as-is.
        # NOTE(review): the cwd fallback is meant to mirror Lightning's own
        # default root directory -- confirm this is the desired location.
        save_dir = os.getcwd() if self.dir_path is None else self.dir_path
        logger = TensorBoardLogger(save_dir, name=self.get_experiment_name())
        al.dynamic_config["trainer"]["logger"] = logger
        logging.info("Adding tensorboard to logger configuration")