dynabench.model.grind

Classes

FourierInterpolator([num_ks, spatial_dim])

Fourier Interpolation Layer.

GrIND(prediction_net[, num_ks, ...])

GrIND model for predicting the evolution of PDEs by first interpolating onto a high-resolution grid, predicting the evolution there with the prediction network, and interpolating back to the original space.

class dynabench.model.grind.FourierInterpolator(num_ks=5, spatial_dim: int = 2, *args, **kwargs)

Bases: Module

Fourier Interpolation Layer. Interpolates a function using a truncated Fourier series: given the values of a function at a set of source points, it computes Fourier coefficients that fit those values and then evaluates the resulting series at a different set of target points.

Parameters:
  • num_ks (int, default 5) – The number of Fourier modes to use for the interpolation.

  • spatial_dim (int, default 2) – The spatial dimension of the PDE.

forward(points_source, values_source, points_target)

Approximates the function at points_target using the Fourier coefficients fitted to values_source at points_source.

generate_fourier_basis(points)
generate_fourier_ks(points)
solve_for_fourier_coefficients(points, values)
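The interpolation described above, fitting Fourier coefficients at source points and evaluating the series at target points, can be sketched in NumPy as follows. This is not dynabench's implementation: the free functions fourier_ks, fourier_basis, fit_fourier_coefficients and fourier_interpolate are hypothetical stand-ins that loosely mirror generate_fourier_ks, generate_fourier_basis, solve_for_fourier_coefficients and forward. The sketch assumes a periodic domain [0, 1)^d, an odd num_ks counted per dimension, and a least-squares fit for the coefficients.

```python
import numpy as np


def fourier_ks(num_ks: int, spatial_dim: int) -> np.ndarray:
    """Integer wave vectors in {-m, ..., m}^spatial_dim with m = num_ks // 2.

    Assumption: num_ks counts modes per dimension and is odd (like the defaults 5 and 21).
    """
    m = num_ks // 2
    axes = np.meshgrid(*([np.arange(-m, m + 1)] * spatial_dim), indexing="ij")
    return np.stack([a.ravel() for a in axes], axis=-1)  # shape (K, spatial_dim)


def fourier_basis(points: np.ndarray, ks: np.ndarray) -> np.ndarray:
    """Evaluate exp(2*pi*i * k.x) for every point/wave-vector pair, shape (N, K).

    Assumption: periodic domain [0, 1)^spatial_dim.
    """
    return np.exp(2j * np.pi * (points @ ks.T))


def fit_fourier_coefficients(points, values, ks):
    """Least-squares coefficients c such that fourier_basis(points, ks) @ c ~= values."""
    coeffs, *_ = np.linalg.lstsq(
        fourier_basis(points, ks), values.astype(complex), rcond=None
    )
    return coeffs


def fourier_interpolate(points_source, values_source, points_target, num_ks=5):
    """Fit on the source points, then evaluate the series at the target points."""
    ks = fourier_ks(num_ks, points_source.shape[1])
    coeffs = fit_fourier_coefficients(points_source, values_source, ks)
    # Real-valued input data: keep the real part of the reconstruction.
    return (fourier_basis(points_target, ks) @ coeffs).real


def f(p):
    # Uses only wave numbers |k| <= 1, so it lies in the span for num_ks >= 3.
    return np.sin(2 * np.pi * p[:, 0]) + np.cos(2 * np.pi * p[:, 1])


rng = np.random.default_rng(0)
src, tgt = rng.random((200, 2)), rng.random((50, 2))
approx = fourier_interpolate(src, f(src), tgt, num_ks=5)
print(np.max(np.abs(approx - f(tgt))))  # close to zero
```

Because the test function lies in the span of the 25 basis functions and there are more source points than modes, the least-squares fit recovers it up to floating-point error; for general data the result is only the best fit within that span.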
class dynabench.model.grind.GrIND(prediction_net: Module, num_ks: int = 21, grid_resolution: int = 64, spatial_dim: int = 2, *args, **kwargs)

Bases: Module

GrIND model for predicting the evolution of PDEs by first interpolating onto a high-resolution grid, predicting the evolution there with the prediction network, and interpolating back to the original space.

Parameters:
  • prediction_net (nn.Module) – The neural network that predicts the evolution of the PDE on the high-resolution grid.

  • num_ks (int, default 21) – The number of Fourier modes to use for the interpolation.

  • grid_resolution (int, default 64) – The resolution of the high-resolution grid to interpolate onto.

  • spatial_dim (int, default 2) – The spatial dimension of the PDE.
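The three stages of GrIND (interpolate to a grid, predict on the grid, interpolate back) can be sketched as a single NumPy function. The names fourier_fit_eval, grind_step and halve are hypothetical; a real prediction_net is an nn.Module acting on tensors, and the sketch assumes a periodic domain [0, 1)^d with a least-squares Fourier fit rather than reusing dynabench's own FourierInterpolator.

```python
import numpy as np


def fourier_fit_eval(src, vals, tgt, num_ks):
    """Least-squares Fourier fit on src, evaluated at tgt (periodic [0, 1)^d assumed)."""
    m = num_ks // 2
    axes = np.meshgrid(*([np.arange(-m, m + 1)] * src.shape[1]), indexing="ij")
    ks = np.stack([a.ravel() for a in axes], axis=-1)

    def basis(p):
        return np.exp(2j * np.pi * (p @ ks.T))

    coeffs, *_ = np.linalg.lstsq(basis(src), vals.astype(complex), rcond=None)
    return (basis(tgt) @ coeffs).real


def grind_step(prediction_net, points, values, num_ks=21, grid_resolution=64):
    """Hypothetical GrIND-style step: scattered points -> grid -> net -> scattered points."""
    d = points.shape[1]
    axis = np.arange(grid_resolution) / grid_resolution
    mesh = np.meshgrid(*([axis] * d), indexing="ij")
    grid = np.stack([g.ravel() for g in mesh], axis=-1)
    # 1) Interpolate the scattered input onto the regular high-resolution grid.
    grid_values = fourier_fit_eval(points, values, grid, num_ks)
    # 2) Let the grid-based network predict the next state.
    grid_next = prediction_net(grid_values.reshape((grid_resolution,) * d))
    # 3) Interpolate the prediction back to the original scattered points.
    return fourier_fit_eval(grid, grid_next.ravel(), points, num_ks)


def halve(u):
    # Stand-in for prediction_net: scales the whole field by 0.5.
    return 0.5 * u


rng = np.random.default_rng(1)
pts = rng.random((300, 2))
vals = np.sin(2 * np.pi * pts[:, 0]) * np.cos(2 * np.pi * pts[:, 1])
out = grind_step(halve, pts, vals, num_ks=5, grid_resolution=16)
```

In this sketch the network only ever sees a fixed grid of grid_resolution points per axis, so it does not depend on where the original points lie.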

forward(x, p, t_eval=[0.0, 1.0])

Runs the GrIND pipeline on the input x: interpolates it onto the high-resolution grid, applies prediction_net there, and interpolates the result back to the original space.

Note

Although the forward pass is defined in this method, call the Module instance rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

generate_interpolation_points(grid_resolution)
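One plausible form for the grid produced by generate_interpolation_points is a regular grid with grid_resolution points per axis. The function interpolation_points below is a hypothetical sketch, not the library method; it assumes a periodic unit domain [0, 1)^d and "ij" point ordering.

```python
import numpy as np


def interpolation_points(grid_resolution: int, spatial_dim: int = 2) -> np.ndarray:
    """Regular grid with grid_resolution points per axis on [0, 1)^spatial_dim.

    The right endpoint is excluded so that, on a periodic domain, x = 1 does
    not duplicate x = 0.
    """
    axis = np.arange(grid_resolution) / grid_resolution
    mesh = np.meshgrid(*([axis] * spatial_dim), indexing="ij")
    return np.stack([m.ravel() for m in mesh], axis=-1)


pts = interpolation_points(4)
print(pts.shape)  # (16, 2)
```

Excluding the endpoint (instead of using np.linspace(0, 1, n)) keeps every grid point distinct under periodic wrap-around, which matters when the grid is later used as the source for a Fourier fit.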