"""Utilities for ``dynabench.model.utils``: rollout wrappers for iterative (autoregressive) model evaluation."""

import torch

from typing import List, Optional

import einops

class RolloutWrapper(torch.nn.Module):
    """
    Wrapper class for iterative model evaluation.

    The wrapped model is called once per entry of ``t_eval``. After each call the
    prediction is appended as the newest lookback step of the input (the oldest
    step is dropped), and the updated input is used for the next call. It can be
    used for both point-based ('cloud') and grid-based ('grid') models.

    Parameters
    ----------
    model : torch.nn.Module
        The model to be wrapped and iteratively evaluated. It is called as
        ``model(x)``, or as ``model(x, p)`` when point coordinates ``p`` are given.
    structure : str, default 'grid'
        The structure of the input data. Can be either 'grid' or 'cloud'.
    batch_first : bool, default True
        Layout of the returned rollout. If True, the rollout steps are stacked
        along dimension 1 (after the batch dimension); if False, along
        dimension 0. The input tensor is always expected to be batch-first.
    lookback_dim : int, default 1
        The index of the lookback dimension of the input tensor. Ignored if
        ``is_lookback_squeezed`` is True.
    is_lookback_squeezed : bool, default False
        If True, the input has no lookback dimension. It is passed to the model
        unchanged, and each prediction replaces the whole input for the next step.

    Attributes
    ----------
    feature_dim : int
        Derived from ``structure`` (2 for 'grid', -1 for 'cloud'); this is not a
        constructor argument.

    Raises
    ------
    ValueError
        If ``structure`` is neither 'grid' nor 'cloud'.
    """

    def __init__(self, model: torch.nn.Module, structure: str = 'grid', batch_first: bool = True,
                 lookback_dim: int = 1, is_lookback_squeezed: bool = False):
        super().__init__()
        if structure not in ['grid', 'cloud']:
            raise ValueError("Structure must be either 'grid' or 'cloud'")
        self.structure = structure
        self.model = model
        self.batch_first = batch_first
        # Feature axis position once the lookback axis sits at dim 1. Not read by
        # this class; kept as a public attribute.
        self.feature_dim = 2 if structure == 'grid' else -1
        self.lookback_dim = lookback_dim
        self.is_lookback_squeezed = is_lookback_squeezed
        # Not used by this class; retained so code that reads the attribute keeps working.
        self.alphabet = 'abcdefghijklmnopqrstuvwxyz'

    def forward(self, x: torch.Tensor,  # features
                p: Optional[torch.Tensor] = None,  # point coordinates
                t_eval: Optional[List[float]] = None) -> torch.Tensor:
        """
        Run an autoregressive rollout of the wrapped model.

        Parameters
        ----------
        x : torch.Tensor
            Initial input. Unless ``is_lookback_squeezed`` is set, it has a
            lookback dimension at ``lookback_dim``. With that axis at dim 1, the
            expected layouts are ``(batch, lookback, feature, *spatial)`` for
            'grid' and ``(batch, lookback, points, feature)`` for 'cloud'.
        p : torch.Tensor, optional
            Point coordinates; if given, the model is called as ``model(x, p)``.
        t_eval : list of float, optional
            Evaluation times; defaults to ``[1]`` and must be non-empty. Only the
            number of entries matters. The model is called once per entry, and
            the values themselves are not passed to the model.

        Returns
        -------
        torch.Tensor
            The predictions, one per entry of ``t_eval``, stacked along dim 1 if
            ``batch_first`` else dim 0. Each prediction must have the shape of
            ``x`` without its lookback dimension, because it is concatenated back
            into the input for the next step.
        """
        if t_eval is None:
            t_eval = [1]  # avoid a shared mutable default argument
        rollout = []
        for _ in t_eval:
            x_stacked_lookback = self._stack_lookback(x)  # Merge lookback with the feature dimension
            x_single = self._single_step(x_stacked_lookback, p)  # Call the model once
            x = self._wrap_input_with_lookback(x, x_single)  # Wrap the input with the new prediction
            rollout.append(x_single)
        rollout_dim = 1 if self.batch_first else 0
        return torch.stack(rollout, dim=rollout_dim)

    def _stack_lookback(self, x: torch.Tensor) -> torch.Tensor:
        """Fold the lookback axis into the feature axis (no-op if the lookback is squeezed)."""
        if self.is_lookback_squeezed:
            return x
        # Bring the lookback axis to dim 1 so both layouts can be handled uniformly.
        x = torch.movedim(x, self.lookback_dim, 1)
        if self.structure == "grid":
            if x.dim() < 3:
                raise ValueError("Grid input must have at least 3 dimensions (batch, lookback, feature, ...)")
            # (batch, lookback, feature, ...) -> (batch, lookback*feature, ...)
            return x.flatten(1, 2)
        if self.structure == "cloud":
            if x.dim() != 4:
                raise ValueError("Cloud input must have 4 dimensions (batch, lookback, points, feature)")
            # (batch, lookback, points, feature) -> (batch, points, lookback*feature)
            return x.movedim(1, 2).flatten(2, 3)
        raise ValueError("Structure must be either 'grid' or 'cloud'")

    def _single_step(self, x: torch.Tensor, p: Optional[torch.Tensor]) -> torch.Tensor:
        """Call the wrapped model once, passing point coordinates only when provided."""
        if p is not None:
            return self.model(x, p)
        return self.model(x)

    def _wrap_input_with_lookback(self, x_previous: torch.Tensor, x_pred_single: torch.Tensor) -> torch.Tensor:
        """Build the next input: drop the oldest lookback step and append the prediction."""
        if self.is_lookback_squeezed:
            return x_pred_single
        dim = self.lookback_dim
        # Equivalent of x_previous[:, 1:] but along the configured lookback axis.
        keep = [slice(None)] * x_previous.dim()
        keep[dim] = slice(1, None)
        x_kept = x_previous[tuple(keep)]
        # Insert a length-1 lookback axis into the prediction at the same position.
        x_single_unstacked_lookback = x_pred_single.unsqueeze(dim)
        return torch.cat([x_kept, x_single_unstacked_lookback], dim=dim)
class CloudRolloutWrapper(RolloutWrapper):
    """
    Convenience subclass of `dynabench.model.utils.RolloutWrapper` for point-cloud data.

    Equivalent to constructing ``RolloutWrapper`` with ``structure="cloud"``; all
    remaining arguments are forwarded unchanged.
    """

    def __init__(self, model, batch_first: bool = True, lookback_dim: int = 1, is_lookback_squeezed: bool = False):
        forwarded = {
            "batch_first": batch_first,
            "lookback_dim": lookback_dim,
            "is_lookback_squeezed": is_lookback_squeezed,
        }
        super().__init__(model, structure="cloud", **forwarded)
class GridRolloutWrapper(RolloutWrapper):
    """
    Convenience subclass of `dynabench.model.utils.RolloutWrapper` for grid data.

    Equivalent to constructing ``RolloutWrapper`` with ``structure="grid"``; all
    remaining arguments are forwarded unchanged.
    """

    def __init__(self, model, batch_first: bool = True, lookback_dim: int = 1, is_lookback_squeezed: bool = False):
        forwarded = {
            "batch_first": batch_first,
            "lookback_dim": lookback_dim,
            "is_lookback_squeezed": is_lookback_squeezed,
        }
        super().__init__(model, structure="grid", **forwarded)