Source code for HALF.Delegates.ModelTrainingDelegate

import pytorch_lightning as pl
from HALF.Utils.ALDataModule import ALDataModule
from HALF.Utils.ALDataModule import ALDataModule
from HALF.Interfaces.IDelegate import IDelegate
from torch.utils.data.dataloader import DataLoader
from HALF.Utils.ALDatasetManager import ALDatasetManager

class ModelTrainingDelegate(IDelegate):
    """Delegate that fits a model on the labelled set of a dataset manager."""

    def __init__(self, model: pl.LightningModule, dataset_manager: ALDatasetManager, config_dl: dict, trainer: pl.Trainer):
        """Store the collaborators needed to train the model.

        Args:
            model (pl.LightningModule): model handed to ``trainer.fit``.
            dataset_manager (ALDatasetManager): object exposing the labelled
                set through its ``dataset_labeled`` attribute.
            config_dl (dict): keyword arguments unpacked into the
                ``DataLoader`` that wraps the labelled set.
            trainer (pl.Trainer): trainer whose ``fit`` method runs the
                training.
        """
        self.al_dataset_manager = dataset_manager
        self.config_dl = config_dl
        self.trainer = trainer
        self.model = model

    def run(self):
        """Fit the model on the labelled set.

        Returns:
            Whatever ``self.trainer.fit`` returns.
            NOTE(review): this may be ``None`` depending on the installed
            Lightning version -- confirm before relying on it.
        """
        manager = self.al_dataset_manager
        # set_train_mode() must run before the labelled set is read, as in
        # the original call order.
        manager.set_train_mode()
        labelled_loader = DataLoader(manager.dataset_labeled, **self.config_dl)
        return self.trainer.fit(self.model, train_dataloaders=labelled_loader)