"""Source code for HALF.Hooks.LoadModelHook."""

import logging
from HALF.Interfaces.IHook import IHook
from HALF.Utils.utils import now_str
import os
from HALF.Utils.ActiveLearner import ActiveLearner
from HALF.Utils.Registry import REGISTRY
from HALF.Utils.Config import Config
import torch

@REGISTRY.register()
class LoadModelHook(IHook):
    """Hook that loads model parameters from disk into an ActiveLearner's model."""

    def __init__(self, name: str = None, path: str = None, map_location=None):
        """Hook to load model parameters from disk.

        Exactly one of ``name`` or ``path`` must be provided.

        Args:
            name (str, optional): name of the model file inside the config
                ``dir.model_path`` directory. Defaults to None.
            path (str, optional): explicit path to the model file.
                Defaults to None.
            map_location (optional): forwarded unchanged to ``torch.load``
                (e.g. ``"cpu"`` to load a GPU-saved checkpoint on a CPU-only
                machine). Defaults to None, which keeps torch's default
                device placement.

        Raises:
            RuntimeError: if neither or both of ``name`` and ``path`` are set.
        """
        self.name = name
        self.path = path
        self.map_location = map_location
        self._validate()

    def _validate(self):
        """Ensure exactly one of ``name`` / ``path`` is set.

        Raises:
            RuntimeError: if both are None, or if both are not None.
        """
        if self.name is None and self.path is None:
            raise RuntimeError("name or path should be set")
        if self.name is not None and self.path is not None:
            raise RuntimeError("name and path cannot be set simultaneously")

    def get_loading_path(self) -> str:
        """Return the path from which the model will be loaded.

        ``path`` is used when it was given. Otherwise the path is built by
        joining the config ``dir.model_path`` directory with ``name``.
        """
        # Use the same "is None" notion of "set" as _validate, so that an
        # explicit (even empty) path is never silently replaced by name=None.
        if self.path is not None:
            return self.path
        return os.path.join(Config.CONFIG.dir.model_path, self.name)

    def apply(self, al: ActiveLearner):
        """Load model parameters from disk and assign them to `al`'s model.

        If the resolved path does not exist, a warning is logged and the model
        is left unchanged.

        Args:
            al (ActiveLearner): The ActiveLearner whose model weights should be
                loaded from disk.
        """
        model_path = self.get_loading_path()
        if not os.path.exists(model_path):
            logging.warning(f"Path {model_path} doesn't exist, ignoring model loading")
            return

        # Log before loading so a failing load still reports which file it was.
        logging.info(f"Loading model {model_path}")
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints.
        loaded_data = torch.load(model_path, map_location=self.map_location)

        # Checkpoints may be a raw state dict or a dict wrapping it under
        # "state_dict". The isinstance guard avoids a TypeError from `in` when
        # the file holds a non-container object.
        if isinstance(loaded_data, dict) and "state_dict" in loaded_data:
            state_dict = loaded_data["state_dict"]
        else:
            state_dict = loaded_data
        al.model.load_state_dict(state_dict)