import torch
import torch.nn as nn
import torchvision.models as models
from .utils import FeatureExtractorBackbone

class SLDAResNetModel(nn.Module):
    """
    This is a model wrapper to reproduce experiments from the original
    paper of Deep Streaming Linear Discriminant Analysis by using
    a pretrained ResNet model.

    The backbone's activations at ``output_layer_name`` are spatially
    average-pooled into one feature vector per sample (see ``pool_feat``).
    """

    def __init__(
        self,
        arch="resnet18",
        output_layer_name="layer4.1",
        imagenet_pretrained=True,
        device="cpu",
    ):
        """Init.

        :param arch: backbone architecture. Default is resnet-18, but others
            can be used by modifying layer for feature extraction in
            ``self.feature_extraction_wrapper``.
        :param output_layer_name: name of the layer from feature extractor
        :param imagenet_pretrained: True if initializing backbone with
            imagenet pre-trained weights else False
        :param device: cpu, gpu or other device
        """
        super().__init__()

        # ``arch`` is looked up by name among the torchvision model builders,
        # so an unknown architecture name raises KeyError here.
        # NOTE(review): newer torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; kept for compatibility — confirm against
        # the torchvision version the project pins.
        feat_extractor = (
            models.__dict__[arch](pretrained=imagenet_pretrained)
            .to(device)
            .eval()
        )
        self.feature_extraction_wrapper = FeatureExtractorBackbone(
            feat_extractor, output_layer_name
        ).eval()

    @staticmethod
    def pool_feat(features):
        """Global-average-pool a batch of feature maps.

        :param features: 4-D tensor laid out as
            ``(mb, num_channels, height, width)``.
        :return: tensor of shape ``(mb, num_channels)`` holding the mean of
            each channel over all spatial positions.
        """
        # Average over both spatial dims explicitly. Flattening to
        # ``width * width`` positions (using only the last dim) only works
        # for square maps and fails in reshape when height != width.
        return features.mean(dim=(2, 3))

    def forward(self, x):
        """
        :param x: raw x data
        :return: pooled features of shape ``(mb, num_channels)``
        """
        feat = self.feature_extraction_wrapper(x)
        feat = SLDAResNetModel.pool_feat(feat)
        return feat