Source code for HALF.Hooks.SaveModelHook

import logging
from HALF.Interfaces.IHook import IHook
import torch as T
from HALF.Utils.utils import now_str
import os
from HALF.Utils.ActiveLearner import ActiveLearner
from HALF.Utils.Registry import REGISTRY

@REGISTRY.register()
class SaveModelHook(IHook):
    """Hook that writes an ActiveLearner's model ``state_dict`` to disk."""

    def __init__(self, saving_dir: str, name: str = None, with_time: bool = False):
        """Configure where and under which file name the model is saved.

        Args:
            saving_dir (str): output directory.
            name (str, optional): output name for the model. Defaults to
                None, in which case ``"model"`` is used.
            with_time (bool, optional): if True, prepend the string returned
                by ``now_str()`` to the output filename. Defaults to False.
        """
        self.saving_dir = saving_dir
        self.with_time = with_time
        self.name = name if name is not None else "model"

    def get_saving_path(self) -> str:
        """Build the path the model is saved to.

        With ``with_time`` enabled the file name is
        ``"<now_str()>_<name>.ckpt"``; otherwise it is ``name`` verbatim
        (no ``.ckpt`` suffix is appended in that case).

        Note:
            When ``with_time`` is True every call re-evaluates ``now_str()``,
            so two calls may return different paths.

        Returns:
            str: ``saving_dir`` joined with the file name.
        """
        if self.with_time:
            file_name = f"{now_str()}_{self.name}.ckpt"
        else:
            file_name = f"{self.name}"
        return os.path.join(self.saving_dir, file_name)

    def apply(self, al: ActiveLearner):
        """Save model parameters to disk.

        Args:
            al (ActiveLearner): the ActiveLearner whose ``model.state_dict()``
                is written to the path built by ``get_saving_path``.
        """
        # Resolve the path exactly once: with ``with_time`` a second call
        # could embed a different timestamp, making the log line point to a
        # file that was never written.
        path_save = self.get_saving_path()

        # Make sure the target directory exists; an empty dirname means the
        # current working directory, which needs no creation.
        parent_dir = os.path.dirname(path_save)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        T.save(al.model.state_dict(), path_save)
        logging.info("Saving model to %s", path_save)