# Source code for dynabench.model._grid.neuralpde

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchdiffeq import odeint_adjoint, odeint
from .cnn import CNN

from typing import List, Optional

class NeuralPDE(nn.Module):
    """
    Neural PDE model for grid data.

    The model combines a CNN with a differentiable ODE solver to learn the
    dynamics of the data using the method of lines. The CNN is used to
    approximate the spatial derivatives of the data, while the ODE solver is
    used to approximate the temporal evolution of the data.

    The model has been taken from `NeuralPDE: Modelling Dynamical Systems from
    Data <https://arxiv.org/abs/2111.07671>`_ by Dulny et al.

    Parameters
    ----------
    input_dim : int
        Number of input channels. The CNN maps ``input_dim`` channels back to
        ``input_dim`` channels, so the integrated state keeps its channel count.
    hidden_channels : int
        Number of channels in each hidden layer of the CNN. Default is 64.
    hidden_layers : int
        Number of hidden layers in the CNN. Default is 1.
    solver : dict, optional
        Keyword arguments forwarded unchanged to the ``torchdiffeq`` solver call
        (for example ``method``, ``rtol``, ``atol``, ``options``). If ``None``
        (the default), ``{"method": "dopri5"}`` is used. The given dict is
        copied, so later changes to the caller's dict do not affect the model.
    use_adjoint : bool
        Whether to use the adjoint method for backpropagation. Default is True.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_channels: int = 64,
        hidden_layers: int = 1,
        solver: Optional[dict] = None,
        use_adjoint: bool = True,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.input_dim = input_dim
        self.cnn = CNN(
            input_size=input_dim,
            output_size=input_dim,
            hidden_layers=hidden_layers,
            hidden_channels=hidden_channels,
        )
        # Build a fresh dict per instance. A dict literal as the default value
        # would be one object shared by every model created with the default,
        # so mutating one model's solver settings would change all of them.
        self.solver = {"method": "dopri5"} if solver is None else dict(solver)
        self.use_adjoint = use_adjoint

    def _ode(self, t, x):
        # Right-hand side of dx/dt = CNN(x). The time argument is required by
        # the torchdiffeq callback signature but unused: the learned dynamics
        # do not depend explicitly on time.
        return self.cnn(x)

    def forward(self, x: torch.Tensor, t_eval: Optional[List[float]] = None) -> torch.Tensor:
        """
        Forward pass of the model. Should not be called directly, instead call the model instance.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, input_size, height, width).
        t_eval : List[float] or torch.Tensor, optional
            Times at which to evaluate the ODE solver; the first entry is the
            time of the initial state ``x``. A 1-D tensor is also accepted.
            If ``None`` (the default), ``[0.0, 1.0]`` is used.

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, rollout, output_size, height, width),
            where ``rollout = len(t_eval) - 1``: the solution at ``t_eval[0]``
            (the initial state) is not included.
        """
        if t_eval is None:
            t_eval = [0.0, 1.0]
        # as_tensor accepts lists, tuples and tensors alike; torch.tensor would
        # warn and copy when handed an existing tensor.
        t = torch.as_tensor(t_eval, dtype=x.dtype, device=x.device)

        if self.use_adjoint:
            # self._ode is a bound method rather than an nn.Module, so
            # torchdiffeq needs the trainable parameters passed explicitly.
            solution = odeint_adjoint(
                self._ode,
                x,
                t,
                **self.solver,
                adjoint_params=tuple(self.cnn.parameters()),
            )
        else:
            solution = odeint(self._ode, x, t, **self.solver)

        # The solution is time-major with the initial state at index 0; drop it
        # and move the batch dimension to the front.
        return torch.swapaxes(solution[1:], 0, 1)