"""
File: qumphy/models/dumbnet.py
Project: 22HLT01 QUMPHY
Contact: oskar.pfeffer@ptb.de
Gitlab: https://gitlab.com/qumphy
Description: TinyVGG-style convolutional neural network.
"""
import torch.nn as nn
class DumbNet(nn.Module):
    """TinyVGG-style CNN: two convolutional blocks followed by a linear classifier.

    The layout follows the TinyVGG model from the CNN Explainer website, except
    that each convolutional block ends in a non-overlapping 4x4 max pool (rather
    than 2x2). Each block therefore preserves the channel count
    ``hidden_units`` and divides the spatial resolution by 4 (floored).

    Parameters
    ----------
    input_shape : int
        Number of input channels.
    hidden_units : int
        Number of hidden channels used in the convolutional layers.
    output_shape : int
        Number of output classes or output values.
    image_size : int or tuple of (int, int), optional, keyword-only
        Spatial size of the input images, either ``H`` for square images or
        ``(H, W)``. It determines the number of input features of the linear
        classifier. The default of 224 yields ``hidden_units * 14 * 14``
        features, matching the original hard-coded value.

    Raises
    ------
    ValueError
        If ``image_size`` is too small to pass through both pooling stages
        (a side shorter than 16 pixels).
    """

    # Downsampling factor of each MaxPool2d stage; its stride defaults to the
    # kernel size, so the pools do not overlap.
    _POOL_SIZE = 4

    def __init__(
        self,
        input_shape: int,
        hidden_units: int,
        output_shape: int,
        *,
        image_size: "int | tuple[int, int]" = 224,
    ):
        super().__init__()
        # The 3x3 convolutions with stride 1 and padding 1 keep H and W
        # unchanged; only the max pools change the spatial size.
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(
                in_channels=input_shape,
                out_channels=hidden_units,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=hidden_units,
                out_channels=hidden_units,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
            # Keeps the maximum of each non-overlapping 4x4 window.
            nn.MaxPool2d(kernel_size=self._POOL_SIZE),
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=hidden_units,
                out_channels=hidden_units,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=hidden_units,
                out_channels=hidden_units,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=self._POOL_SIZE),
        )

        if isinstance(image_size, int):
            height, width = image_size, image_size
        else:
            height, width = image_size
        # Each MaxPool2d(k) maps a side of length n to n // k, so after the
        # two blocks the feature map is (H // k // k) x (W // k // k).
        out_height = height // self._POOL_SIZE // self._POOL_SIZE
        out_width = width // self._POOL_SIZE // self._POOL_SIZE
        if out_height < 1 or out_width < 1:
            raise ValueError(
                f"image_size {image_size!r} is too small: each side must be at "
                f"least {self._POOL_SIZE ** 2} pixels to pass both pooling stages"
            )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(
                in_features=hidden_units * out_height * out_width,
                out_features=output_shape,
            ),
        )

    def forward(self, x):
        """Run a forward pass through the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, input_shape, height, width).
            ``height // 16`` and ``width // 16`` must match the values implied
            by the ``image_size`` given at construction. Otherwise the
            flattened features do not fit the classifier.

        Returns
        -------
        torch.Tensor
            Model output tensor of shape (batch_size, output_shape).
        """
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.classifier(x)
        return x