# Source code for HALF.Interfaces.AbstractModel
from abc import ABC, abstractmethod
from torch import nn as nn
import torch as T
import pytorch_lightning as pl
class AbstractModel(ABC, pl.LightningModule):
    """
    Abstract class for the model used in the framework by the Active Learner

    Concrete models must implement :meth:`forward` and
    :meth:`get_embedding_dim`; :meth:`reset` is provided here.
    """

    @abstractmethod
    def forward(self, x: T.Tensor, last: bool = False):
        """Forward pass through model

        Args:
            x (T.Tensor): input
            last (bool, optional): If True, return an (output, embedding) tuple. Defaults to False.
        """

    @abstractmethod
    def get_embedding_dim(self) -> int:
        """Get dimensionality (number of features) of the embedding

        Returns:
            int: The dimensionality
        """

    def reset(self) -> None:
        """Resets model parameters

        Calls ``reset_parameters()`` on every submodule that provides it.
        The whole module tree is walked, so layers wrapped in containers
        that define no ``reset_parameters`` of their own (for example
        ``nn.Sequential``) are re-initialised as well.
        """
        for module in self.modules():
            # modules() yields the model itself first; only submodules are
            # reset here, so a reset_parameters defined on the model is not
            # invoked implicitly.
            if module is self:
                continue
            reset_fn = getattr(module, 'reset_parameters', None)
            if callable(reset_fn):
                reset_fn()

    # TODO need attribute num_classes