# Source code for HALF.Delegates.ModelTrainingDelegate
import pytorch_lightning as pl
from HALF.Utils.ALDataModule import ALDataModule
from HALF.Utils.ALDataModule import ALDataModule
from HALF.Interfaces.IDelegate import IDelegate
from torch.utils.data.dataloader import DataLoader
from HALF.Utils.ALDatasetManager import ALDatasetManager
class ModelTrainingDelegate(IDelegate):
    """Delegate that fits a Lightning model on the labelled set of an AL dataset manager.

    The delegate only stores the model, the trainer, the DataLoader keyword
    arguments and the dataset manager it is given. Each call to :meth:`run`
    builds a fresh DataLoader over the labelled set and passes it to
    ``trainer.fit``.
    """

    def __init__(self, model: pl.LightningModule,
                 dataset_manager: ALDatasetManager,
                 config_dl: dict,
                 trainer: pl.Trainer):
        """Store the collaborators needed to train ``model``.

        Args:
            model (pl.LightningModule): model to train.
            dataset_manager (ALDatasetManager): object holding the labelled set,
                read through its ``dataset_labeled`` attribute.
            config_dl (dict[str, Union[int, str, float]]): keyword arguments
                forwarded unchanged to ``DataLoader`` when training.
            trainer (pl.Trainer): wrapper that runs the model training.
        """
        self.al_dataset_manager = dataset_manager
        self.config_dl = config_dl
        self.trainer = trainer
        self.model = model

    def run(self):
        """Train ``self.model`` on the current labelled set.

        Calls ``set_train_mode()`` on the dataset manager, wraps its
        ``dataset_labeled`` in a ``DataLoader`` built from ``self.config_dl``,
        and passes that loader to ``self.trainer.fit`` as ``train_dataloaders``.

        Returns:
            Whatever ``self.trainer.fit`` returns.
            NOTE(review): recent PyTorch Lightning releases document
            ``Trainer.fit`` as returning ``None``. Confirm against the pinned
            version before relying on this value.
        """
        manager = self.al_dataset_manager
        # Switch mode before reading the labelled set. The order is kept in
        # case ``dataset_labeled`` depends on the current mode (not visible here).
        manager.set_train_mode()
        loader = DataLoader(manager.dataset_labeled, **self.config_dl)
        return self.trainer.fit(self.model, train_dataloaders=loader)