Source code for HALF.Delegates.DatasetTestingDelegate
from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT
from torch.utils.data.dataloader import DataLoader

from HALF.Interfaces.IDelegate import IDelegate
from HALF.Utils.ALDatasetManager import ALDatasetManager
class DatasetTestingDelegate(IDelegate):
    """Delegate that evaluates a model on the test set held by a dataset manager."""

    def __init__(self, model: pl.LightningModule,
                 dataset_manager: ALDatasetManager,
                 config_dl: dict,
                 trainer: pl.Trainer):
        """Store the collaborators needed to evaluate the model on the test set.

        Args:
            model (pl.LightningModule): model to evaluate
            dataset_manager (ALDatasetManager): object where the test set is contained
            config_dl (dict): keyword arguments forwarded unchanged to the
                ``DataLoader`` that wraps the test set
            trainer (pl.Trainer): trainer whose ``test`` method runs the evaluation
        """
        self.model = model
        self.dataset_manager = dataset_manager
        self.trainer = trainer
        self.config_dl = config_dl

    def run(self) -> Optional[_EVALUATE_OUTPUT]:
        """Evaluate the model on the dataset manager's test set.

        The dataset manager is switched to test mode first; this happens even
        when no test set is available.

        Returns:
            Optional[_EVALUATE_OUTPUT]: whatever ``trainer.test`` returns for the
            test dataloader, or ``None`` when the dataset manager has no test set.
        """
        self.dataset_manager.set_test_mode()

        test_set = self.dataset_manager.dataset_test
        if test_set is None:
            # Nothing to evaluate: make the "no test set" outcome explicit rather
            # than relying on an implicit fall-through / unbound local.
            return None

        test_loader = DataLoader(test_set, **self.config_dl)
        return self.trainer.test(self.model, dataloaders=test_loader)