# Source code for HALF.Interfaces.AbstractModel

from abc import ABC, abstractmethod
from torch import nn as nn
import torch as T
import pytorch_lightning as pl


class AbstractModel(ABC, pl.LightningModule):
    """Abstract class for the model used in the framework by the Active Learner.

    Concrete models must implement :meth:`forward` and
    :meth:`get_embedding_dim`. :meth:`reset` is provided and re-initialises
    the parameters of the contained layers.
    """

    @abstractmethod
    def forward(self, x: T.Tensor, last: bool = False):
        """Forward pass through the model.

        Args:
            x (T.Tensor): Input.
            last (bool, optional): If True, return an (output, embedding)
                tuple. Defaults to False.
        """
        pass

    @abstractmethod
    def get_embedding_dim(self) -> int:
        """Get the dimensionality (number of features) of the embedding.

        Returns:
            int: The dimensionality.
        """
        pass

    def reset(self) -> None:
        """Reset the model parameters.

        Calls ``reset_parameters()`` on every submodule that defines it, at
        any nesting depth. Submodules without that method are left untouched.
        """

        def _reset_if_supported(module):
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        # Recurse with Module.apply so that layers wrapped in containers
        # (e.g. nn.Sequential, which has no reset_parameters of its own) are
        # reset too. Starting from the children rather than from self keeps
        # this model's own reset_parameters (if a subclass defines one) out
        # of the traversal, so such a method can safely delegate to reset().
        for layer in self.children():
            layer.apply(_reset_if_supported)
# TODO need attribute num_classes