"""HALF.Delegates.DatasetTestingDelegate: evaluation of a model on the test set."""

from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT
from torch.utils.data.dataloader import DataLoader

from HALF.Interfaces.IDelegate import IDelegate
from HALF.Utils.ALDatasetManager import ALDatasetManager

class DatasetTestingDelegate(IDelegate):
    """Delegate that evaluates a model's performance on the test set held by a dataset manager."""

    def __init__(
        self,
        model: pl.LightningModule,
        dataset_manager: ALDatasetManager,
        config_dl: dict,
        trainer: pl.Trainer,
    ):
        """Store everything needed to evaluate ``model`` on the test set.

        Args:
            model (pl.LightningModule): model to evaluate.
            dataset_manager (ALDatasetManager): object holding the test set
                (exposed as ``dataset_test``).
            config_dl (dict[str, Union[int, str, float]]): keyword arguments
                forwarded verbatim to ``torch.utils.data.DataLoader`` when
                the test loader is built.
            trainer (pl.Trainer): wrapper used to run the evaluation.
        """
        self.model = model
        self.dataset_manager = dataset_manager
        self.trainer = trainer
        self.config_dl = config_dl

    def run(self) -> Optional[_EVALUATE_OUTPUT]:
        """Evaluate the model on the test set.

        Switches the dataset manager to test mode, wraps its test set in a
        ``DataLoader`` built from ``config_dl`` and runs ``trainer.test``.

        Returns:
            Optional[_EVALUATE_OUTPUT]: the value returned by
            ``trainer.test``, or ``None`` when the dataset manager holds no
            test set.
        """
        # Switch modes first: the test set is read only afterwards, in the
        # same order as before, in case set_test_mode() affects it.
        self.dataset_manager.set_test_mode()
        test_set = self.dataset_manager.dataset_test
        if test_set is None:
            # Nothing to evaluate; report that explicitly instead of
            # reaching a return of a never-assigned result.
            return None
        test_loader = DataLoader(test_set, **self.config_dl)
        return self.trainer.test(self.model, dataloaders=test_loader)