"""HALF.Hooks.LoadModelHook: hook that loads saved model parameters from disk into an ActiveLearner's model."""
import logging
import os
from typing import Optional

import torch

from HALF.Interfaces.IHook import IHook
from HALF.Utils.ActiveLearner import ActiveLearner
from HALF.Utils.Config import Config
from HALF.Utils.Registry import REGISTRY
from HALF.Utils.utils import now_str
@REGISTRY.register()
class LoadModelHook(IHook):
    """Hook that loads model parameters from disk into an ActiveLearner's model.

    Exactly one of ``name`` or ``path`` must be supplied at construction time.
    """

    def __init__(self, name: Optional[str] = None, path: Optional[str] = None):
        """Hook to load model parameters from disk.

        Exactly one of ``name`` and ``path`` must be provided; passing both or
        neither raises ``RuntimeError``.

        Args:
            name (str, optional): file name of the model inside the configured
                ``Config.CONFIG.dir.model_path`` directory. Defaults to None.
            path (str, optional): explicit path to the model file. Defaults to None.

        Raises:
            RuntimeError: if neither or both of ``name`` and ``path`` are set.
        """
        self.name = name
        self.path = path
        self._validate()

    def _validate(self):
        """Ensure exactly one of ``self.name`` / ``self.path`` is set.

        Raises:
            RuntimeError: if neither or both are set.
        """
        if self.name is None and self.path is None:
            raise RuntimeError("name or path should be set")
        if self.name is not None and self.path is not None:
            raise RuntimeError("name and path cannot be set simultaneously")

    def get_loading_path(self) -> str:
        """Return the path from which the model will be loaded.

        ``self.path`` is returned unchanged when it was given; otherwise
        ``self.name`` is joined onto ``Config.CONFIG.dir.model_path``.
        """
        # Compare against None instead of relying on truthiness: an empty-string
        # ``path`` passes validation, and falling through to the ``name`` branch
        # would call ``os.path.join(..., None)`` and raise TypeError.
        if self.path is not None:
            return self.path
        return os.path.join(Config.CONFIG.dir.model_path, self.name)

    def apply(self, al: ActiveLearner):
        """Load model parameters from disk and assign them to `al`'s model.

        If the loading path does not exist, a warning is logged and the model
        is left untouched.

        Args:
            al (ActiveLearner): The ActiveLearner whose model weights should be
                loaded from disk.
        """
        model_path = self.get_loading_path()
        if not os.path.exists(model_path):
            logging.warning("Path %s doesn't exist, ignoring model loading", model_path)
            return

        # Log before loading so a failing load still reports which file it was.
        logging.info("Loading model %s", model_path)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted locations.
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors into the model's existing
        # parameters, wherever those live.
        loaded_data = torch.load(model_path, map_location="cpu")
        al.model.load_state_dict(self._extract_state_dict(loaded_data))

    @staticmethod
    def _extract_state_dict(loaded_data):
        """Return the state dict contained in an object produced by ``torch.load``.

        Handles a whole pickled ``nn.Module``, a checkpoint dict that stores the
        weights under a ``"state_dict"`` key, and a bare state dict.
        """
        if isinstance(loaded_data, torch.nn.Module):
            return loaded_data.state_dict()
        if isinstance(loaded_data, dict) and "state_dict" in loaded_data:
            return loaded_data["state_dict"]
        return loaded_data