Source code for dynabench.model._grid.cnn

from typing import Tuple

import torch
import torch.nn as nn

class CNN(nn.Module):
    """
    Simple 2D CNN model for grid data.

    Parameters
    ----------
    input_size : int
        Number of input channels.
    output_size : int
        Number of output channels.
    hidden_layers : int
        Number of hidden layers. Default is 1.
    hidden_channels : int
        Number of channels in each hidden layer. Default is 64.
    padding : int | str | Tuple[int]
        Padding size. If 'same', padding is calculated to keep the output
        size the same as the input size. Default is 'same'.
    padding_mode : str
        What value to pad with. Can be 'zeros', 'reflect', 'replicate' or
        'circular'. Default is 'circular'.
    kernel_size : int
        Size of the kernel. Default is 3.
    activation : str
        Name of the activation function class to use. Can be one of the
        `torch.nn activation functions <https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity>`_.
        Default is 'ReLU'.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 hidden_layers: int = 1,
                 hidden_channels: int = 64,
                 padding: int | str | Tuple[int] = 'same',
                 padding_mode: str = 'circular',
                 kernel_size: int = 3,
                 activation: str = 'ReLU'):
        super().__init__()

        self.input_layer = nn.Conv2d(input_size, hidden_channels, kernel_size,
                                     padding=padding, padding_mode=padding_mode)
        self.hidden_layers = nn.ModuleList([
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size,
                      padding=padding, padding_mode=padding_mode)
            for _ in range(hidden_layers)
        ])
        self.output_layer = nn.Conv2d(hidden_channels, output_size, kernel_size,
                                      padding=padding, padding_mode=padding_mode)
        # Instantiate the activation by class name, e.g. nn.ReLU().
        self.activation = getattr(nn, activation)()
    def forward(self, x: torch.Tensor):
        """
        Forward pass of the model. Should not be called directly; call the
        model instance instead.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, input_size, height, width).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, output_size, height, width).
        """
        x = self.activation(self.input_layer(x))
        for layer in self.hidden_layers:
            x = self.activation(layer(x))
        return self.output_layer(x)
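
A minimal usage sketch. The import path below mirrors the module name at the top of this page, though `_grid` is a private module and the package may re-export CNN under a public path; the channel counts and grid size are illustrative:

import torch

from dynabench.model._grid.cnn import CNN  # assumed import path

# Two input/output channels, two hidden layers of 32 channels each.
model = CNN(input_size=2, output_size=2, hidden_layers=2, hidden_channels=32)

x = torch.randn(8, 2, 64, 64)  # (batch_size, input_size, height, width)
y = model(x)                   # invokes forward() via nn.Module.__call__
print(y.shape)                 # torch.Size([8, 2, 64, 64]); 'same' padding preserves H and W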