Source code for dynabench.model.utils

from typing import List

import torch

class PointIterativeWrapper(torch.nn.Module):
    """
    Wrapper class for iterative point-based model evaluation.

    Parameters
    ----------
    model : torch.nn.Module
        The model to be wrapped and iteratively evaluated.
    batch_first : bool, default True
        If True, the first dimension of the input tensor is considered as the
        batch dimension.

    Attributes
    ----------
    model : torch.nn.Module
        The wrapped model.
    batch_first : bool
        Indicates if the first dimension of the input tensor is the batch
        dimension.

    Methods
    -------
    forward(x: torch.Tensor, p: torch.Tensor, t_eval: List[float] = [1]) -> torch.Tensor
        Perform iterative evaluation of the model at specified time points.
    """

    def __init__(self, model, batch_first: bool = True):
        super().__init__()
        self.model = model
        self.batch_first = batch_first
    def forward(self,
                x: torch.Tensor,            # features
                p: torch.Tensor,            # point coordinates
                t_eval: List[float] = [1]):
        # Roll the model out autoregressively: each entry in t_eval triggers
        # one model step, and the output is fed back in as the next input.
        # Only the length of t_eval is used; its values are ignored.
        rollout = []
        for _ in t_eval:
            x = self.model(x, p)
            rollout.append(x)
        dim = 1 if self.batch_first else 0
        return torch.stack(rollout, dim=dim)
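
A minimal usage sketch, assuming a toy point model whose forward maps
(features, coordinates) to features of the same shape; the toy model, feature
count, and tensor shapes are illustrative assumptions, not part of the module:

# Usage sketch: ToyPointModel and all shapes below are assumptions.
class ToyPointModel(torch.nn.Module):
    def __init__(self, n_features: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_features, n_features)

    def forward(self, x, p):
        # The coordinates p are accepted but unused in this toy model.
        return self.linear(x)

wrapper = PointIterativeWrapper(ToyPointModel(n_features=4))
x = torch.randn(2, 100, 4)        # (batch, points, features)
p = torch.rand(2, 100, 2)         # (batch, points, spatial dims)
out = wrapper(x, p, t_eval=[1, 2, 3])
print(out.shape)                  # torch.Size([2, 3, 100, 4]) with batch_first=True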
class GridIterativeWrapper(torch.nn.Module):
    """
    Wrapper class for iterative grid-based model evaluation.

    Parameters
    ----------
    model : torch.nn.Module
        The model to be wrapped and iteratively evaluated.
    batch_first : bool, default True
        If True, the first dimension of the input tensor is considered as the
        batch dimension.

    Attributes
    ----------
    model : torch.nn.Module
        The wrapped model.
    batch_first : bool
        Indicates if the first dimension of the input tensor is the batch
        dimension.

    Methods
    -------
    forward(x: torch.Tensor, t_eval: List[float] = [1]) -> torch.Tensor
        Perform iterative evaluation of the model at specified time points.
    """

    def __init__(self, model, batch_first: bool = True):
        super().__init__()
        self.model = model
        self.batch_first = batch_first
    def forward(self,
                x: torch.Tensor,            # features
                t_eval: List[float] = [1]):
        # Same autoregressive rollout as PointIterativeWrapper, but the
        # wrapped model operates on grid data and takes no coordinates.
        rollout = []
        for _ in t_eval:
            x = self.model(x)
            rollout.append(x)
        dim = 1 if self.batch_first else 0
        return torch.stack(rollout, dim=dim)
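
A similarly hedged usage sketch, assuming any shape-preserving grid model
(here a plain Conv2d; the channel count and grid size are illustrative):

# Usage sketch: the wrapped Conv2d and all shapes below are assumptions.
grid_model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
wrapper = GridIterativeWrapper(grid_model)
x = torch.randn(2, 3, 64, 64)     # (batch, channels, height, width)
out = wrapper(x, t_eval=[1, 2])
print(out.shape)                  # torch.Size([2, 2, 3, 64, 64]) with batch_first=True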