# Source code for HALF.Interfaces.AbstractModel
from abc import ABC, abstractmethod
from torch import nn as nn
import torch as T
import pytorch_lightning as pl
class AbstractModel(ABC, pl.LightningModule):
    """
    Abstract class for the model used in the framework by the Active Learner.

    Subclasses must implement :meth:`forward` and :meth:`get_embedding_dim`;
    :meth:`reset` is provided and re-initializes the model's submodules.
    """

    @abstractmethod
    def forward(self, x: T.Tensor, last: bool = False):
        """Forward pass through model.

        Args:
            x (T.Tensor): input
            last (bool, optional): If True, return an (output, embedding)
                tuple. Defaults to False.

        Returns:
            The model output, or an ``(output, embedding)`` tuple when
            ``last`` is True. The concrete types are defined by the subclass.
        """

    @abstractmethod
    def get_embedding_dim(self) -> int:
        """Get dimensionality (number of features) of the embedding.

        Returns:
            int: The dimensionality
        """

    def reset(self):
        """Reset model parameters.

        Every submodule at any depth (not only direct children) that exposes
        a callable ``reset_parameters`` has it called. Containers such as
        ``nn.Sequential`` have no ``reset_parameters`` of their own, so
        walking only ``self.children()`` would leave the layers inside them
        with their previous (e.g. trained) weights.

        Submodules are visited deepest-first: ``self.modules()`` yields a
        parent before its descendants, so reversing that order lets a
        parent's own ``reset_parameters`` (which may apply a custom init to
        its children) run after, and therefore override, the children's
        default re-initialization.
        """
        for module in reversed(list(self.modules())):
            # Skip the model itself: it was never reset before, and a subclass
            # whose reset_parameters delegates to reset() must not recurse.
            if module is self:
                continue
            reset_parameters = getattr(module, 'reset_parameters', None)
            if callable(reset_parameters):
                reset_parameters()

    # TODO need attribute num_classes