import torch
from torch import nn
class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier.

    Three convolution stages (5x5 kernels, stride 1, Tanh activations, with
    2x2 average pooling after the first two) are followed by a 120 -> 84
    hidden layer with Tanh and a final linear layer that produces raw logits.

    The first fully connected layer expects exactly 120 features after
    flattening. The last convolution must therefore produce a 1x1 spatial
    map, which holds for example with 32x32 inputs.
    """

    def __init__(self, n_classes, input_channels):
        """LeNet5 architecture.

        :param n_classes: number of output classes, i.e. the size of the last
            dimension of the logits returned by :meth:`forward`.
        :param input_channels: number of channels in the input images; used
            as ``in_channels`` of the first convolution.
        """
        super().__init__()

        # Channel widths of the three conv stages: in -> 6 -> 16 -> 120.
        # Layers are created in the same order as the classic hand-written
        # Sequential (conv, tanh, pool, conv, tanh, pool, conv, tanh), so
        # module indices 0/3/6 and state_dict keys stay stable.
        widths = (input_channels, 6, 16, 120)
        n_stages = len(widths) - 1
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            layers += [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=5, stride=1),
                nn.Tanh(),
            ]
            # Every stage except the last is followed by 2x2 average pooling.
            if stage < n_stages:
                layers.append(nn.AvgPool2d(kernel_size=2))
        self.feature_extractor = nn.Sequential(*layers)

        hidden = nn.Linear(in_features=120, out_features=84)
        self.ff = nn.Sequential(hidden, nn.Tanh())

        head = nn.Linear(in_features=84, out_features=n_classes)
        self.classifier = nn.Sequential(head)

    def forward(self, x):
        """Map a batch of images to class logits.

        :param x: image batch; dimension 0 is treated as the batch dimension
            and the channel count must equal ``input_channels``. Presumably
            ``(batch, input_channels, H, W)``; confirm against callers.
        :return: unnormalised logits of shape ``(batch, n_classes)``.
        """
        # Keep dim 0 (batch) and collapse channels and spatial dims into features.
        features = torch.flatten(self.feature_extractor(x), 1)
        return self.classifier(self.ff(features))