# Source code for HALF.Hooks.SaveModelHook
import logging
import os
from typing import Optional

import torch as T

from HALF.Interfaces.IHook import IHook
from HALF.Utils.ActiveLearner import ActiveLearner
from HALF.Utils.Registry import REGISTRY
from HALF.Utils.utils import now_str
@REGISTRY.register()
class SaveModelHook(IHook):
    """Hook that writes the ActiveLearner's model ``state_dict`` to disk."""

    def __init__(self, saving_dir: str, name: Optional[str] = None, with_time: bool = False):
        """Saves model parameters to disk.

        Args:
            saving_dir (str): output directory.
            name (str, optional): output name for the model. Defaults to None,
                in which case ``"model"`` is used.
            with_time (bool, optional): if True, prepend the current time to the
                output filename. Defaults to False.
        """
        self.saving_dir = saving_dir
        self.with_time = with_time
        self.name = name if name is not None else "model"

    def get_saving_path(self) -> str:
        """Build the path the model will be saved to.

        Each call with ``with_time=True`` calls ``now_str()`` again, so repeated
        calls may yield different paths.

        Returns:
            str: ``<saving_dir>/<now_str()>_<name>.ckpt`` when ``with_time`` is
            True, otherwise ``<saving_dir>/<name>``.
        """
        # NOTE(review): only the timestamped variant gets a ".ckpt" suffix. Kept
        # as-is because existing runs/loaders may rely on the plain "<name>" path.
        if self.with_time:
            file_name = f"{now_str()}_{self.name}.ckpt"
        else:
            file_name = f"{self.name}"
        return os.path.join(self.saving_dir, file_name)

    def apply(self, al: ActiveLearner):
        """Save model parameters to disk.

        Args:
            al (ActiveLearner): The ActiveLearner whose model parameters should
                be saved to disk.
        """
        # Build the path exactly once: with ``with_time`` every call to
        # get_saving_path() takes a fresh timestamp, so calling it twice could
        # log a different path from the one actually written.
        path = self.get_saving_path()
        # Create the target directory if it is missing so the save does not
        # fail; skip when the path has no directory component.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        T.save(al.model.state_dict(), path)
        logging.info("Saving model to %s", path)