# Source code for dynabench.dataset._equation

import h5py
import numpy as np
import pathlib

from ._base import BaseListMovingWindowIterator, BaseListSimulationIterator
from .transforms import BaseTransform
from typing import List, Optional


class EquationMovingWindowIterator(BaseListMovingWindowIterator):
    """
    Iterator for arbitrary equations generated using the dynabench solver.

    Each sample returned by the __getitem__ method is a tuple of
    (data_input, data_target, points), where data_input is the input data of
    shape (L, F, H, W), data_target is the target data of shape (R, F, H, W),
    and points are the points in the grid of shape (H, W, 2). In this context
    L corresponds to the lookback parameter and R corresponds to the rollout
    parameter. H and W are the height and width of the grid, respectively.
    F is the number of variables in the equation system.

    Parameters
    ----------
    eq_dir : str
        Path to the directory where the generated simulations are stored.
    lookback : int
        Number of time steps to look back. This corresponds to the L parameter.
    rollout : int
        Number of time steps to predict. This corresponds to the R parameter.
    selected_simulations : List[str], optional
        File names of the simulations to load, exactly as they appear in
        ``eq_dir`` (i.e. including the extension, e.g. ``"sim_0.h5"``).
        A single string is accepted as a one-element list. Names that are
        not present in ``eq_dir`` are skipped without error. If None, every
        entry of ``eq_dir`` whose name ends in ``.h5`` is loaded.
    squeeze_lookback_dim : bool
        Whether to squeeze the lookback dimension. Defaults to True.
        If lookback > 1 has no effect.
    is_batched : bool
        Whether the data is batched. Defaults to False. If True, the data is
        expected to be of shape (B, L, F, H, W), where B is the batch size.
    transforms : BaseTransform, optional
        Transform passed through to the base iterator. Defaults to None.
    dtype : np.dtype
        Data type of the input data. Defaults to np.float32.
    """

    def __init__(
        self,
        eq_dir: str,
        lookback: int,
        rollout: int,
        selected_simulations: Optional[List[str]] = None,
        squeeze_lookback_dim: bool = True,
        is_batched: bool = False,
        transforms: Optional[BaseTransform] = None,
        dtype: np.dtype = np.float32,
    ) -> None:
        eq_dir = pathlib.Path(eq_dir)

        # read the directory and get the list of files
        if selected_simulations is not None:
            # A bare string would otherwise be used as a container of
            # characters, turning the membership test into a substring match.
            if isinstance(selected_simulations, str):
                selected_simulations = [selected_simulations]
            wanted = set(selected_simulations)  # O(1) lookups per entry
            data_paths = [path for path in eq_dir.iterdir() if path.name in wanted]
        else:
            data_paths = [path for path in eq_dir.iterdir() if path.name.endswith(".h5")]

        # Path.iterdir() yields entries in arbitrary, filesystem-dependent
        # order; sort so that sample indices are reproducible across runs.
        data_paths.sort()

        super().__init__(
            data_paths=data_paths,
            lookback=lookback,
            rollout=rollout,
            squeeze_lookback_dim=squeeze_lookback_dim,
            # Forward the caller's choice instead of a hard-coded False.
            is_batched=is_batched,
            transforms=transforms,
            dtype=dtype,
        )
class EquationSimulationIterator(BaseListSimulationIterator):
    """
    Iterator for full equation simulations generated using the dynabench solver.

    Each sample returned by the __getitem__ method is a tuple of
    (data_input, points), where data_input holds the data of a simulation and
    points are the points in the grid of shape (H, W, 2). H and W are the
    height and width of the grid, respectively.

    NOTE(review): data_input is presumably the whole trajectory of shape
    (T, F, H, W), with T the number of stored time steps and F the number of
    variables in the equation system -- confirm against
    BaseListSimulationIterator.

    Parameters
    ----------
    eq_dir : str
        Path to the directory where the generated simulations are stored.
    selected_simulations : List[str], optional
        File names of the simulations to load, exactly as they appear in
        ``eq_dir`` (i.e. including the extension, e.g. ``"sim_0.h5"``).
        A single string is accepted as a one-element list. Names that are
        not present in ``eq_dir`` are skipped without error. If None, every
        entry of ``eq_dir`` whose name ends in ``.h5`` is loaded.
    is_batched : bool
        Whether the data is batched. Defaults to False. If True, the data is
        expected to carry a leading batch dimension B.
    transforms : BaseTransform, optional
        Transform passed through to the base iterator. Defaults to None.
    dtype : np.dtype
        Data type of the input data. Defaults to np.float32.
    """

    def __init__(
        self,
        eq_dir: str,
        selected_simulations: Optional[List[str]] = None,
        is_batched: bool = False,
        transforms: Optional[BaseTransform] = None,
        dtype: np.dtype = np.float32,
    ) -> None:
        eq_dir = pathlib.Path(eq_dir)

        # read the directory and get the list of files
        if selected_simulations is not None:
            # A bare string would otherwise be used as a container of
            # characters, turning the membership test into a substring match.
            if isinstance(selected_simulations, str):
                selected_simulations = [selected_simulations]
            wanted = set(selected_simulations)  # O(1) lookups per entry
            data_paths = [path for path in eq_dir.iterdir() if path.name in wanted]
        else:
            data_paths = [path for path in eq_dir.iterdir() if path.name.endswith(".h5")]

        # Path.iterdir() yields entries in arbitrary, filesystem-dependent
        # order; sort so that sample indices are reproducible across runs.
        data_paths.sort()

        super().__init__(
            data_paths=data_paths,
            # Forward the caller's choice instead of a hard-coded False.
            is_batched=is_batched,
            transforms=transforms,
            dtype=dtype,
        )