"""Simple LSTM sequence classifiers (module ``avalanche.models.simple_sequence_classifier``)."""

import torch
from avalanche.models.dynamic_modules import (
    MultiTaskModule,
    MultiHeadClassifier,
)


class SimpleSequenceClassifier(torch.nn.Module):
    """LSTM encoder followed by a linear classifier on the last time step.

    The input sequence is encoded by a (possibly stacked) ``torch.nn.LSTM``.
    The LSTM output at the final time step is mapped to ``n_classes``
    outputs by a ``torch.nn.Linear`` layer.
    """

    def __init__(
        self, input_size, hidden_size, n_classes, rnn_layers=1, batch_first=True
    ):
        """
        :param input_size: number of input features per time step.
        :param hidden_size: size of the LSTM hidden state; also the input
            size of the linear classifier.
        :param n_classes: number of outputs of the linear classifier.
        :param rnn_layers: number of stacked LSTM layers.
        :param batch_first: forwarded to the LSTM; when True the batch
            dimension comes before the time dimension.
        """
        super().__init__()
        self.batch_first = batch_first
        # Registration order (rnn, then classifier) fixes the parameter and
        # state_dict ordering, so keep it stable.
        self.rnn = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=rnn_layers,
            batch_first=batch_first,
        )
        self.classifier = torch.nn.Linear(
            in_features=hidden_size, out_features=n_classes
        )

    def forward(self, x):
        """Return classifier outputs computed from the last time step of ``x``."""
        sequence_output, _state = self.rnn(x)
        # The time axis is dim 1 for batch-first outputs and dim 0 otherwise.
        # Take the final step along it.
        time_dim = 1 if self.batch_first else 0
        last_step = sequence_output.select(time_dim, -1)
        return self.classifier(last_step)
class MTSimpleSequenceClassifier(MultiTaskModule):
    """Multi-task LSTM sequence classifier.

    A single ``torch.nn.LSTM`` encodes the input sequence. Its output at the
    final time step is passed, together with ``task_labels``, to a
    ``MultiHeadClassifier``. That head presumably selects a per-task output
    head from the labels; see ``avalanche.models.dynamic_modules`` to confirm.
    """

    def __init__(self, input_size, hidden_size, rnn_layers=1, batch_first=True):
        """
        :param input_size: number of input features per time step.
        :param hidden_size: size of the LSTM hidden state; also the feature
            size given to the multi-head classifier.
        :param rnn_layers: number of stacked LSTM layers.
        :param batch_first: forwarded to the LSTM; when True the batch
            dimension comes before the time dimension.
        """
        super().__init__()
        self.batch_first = batch_first
        # Keep registration order (rnn, then classifier) so parameter and
        # state_dict ordering are unchanged.
        self.rnn = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=rnn_layers,
            batch_first=batch_first,
        )
        # Only the feature size is passed; every other head setting is left
        # to MultiHeadClassifier's own defaults.
        self.classifier = MultiHeadClassifier(hidden_size)

    def forward(self, x, task_labels):
        """Classify the last time step of ``x`` using the heads for ``task_labels``."""
        sequence_output, _state = self.rnn(x)
        # The time axis is dim 1 for batch-first outputs and dim 0 otherwise.
        time_dim = 1 if self.batch_first else 0
        last_step = sequence_output.select(time_dim, -1)
        return self.classifier(last_step, task_labels)
# Public names of this module; these are what ``from ... import *`` exports.
__all__ = ["SimpleSequenceClassifier", "MTSimpleSequenceClassifier"]