"""
    Module containing different initial condition generators.
"""
import itertools
from typing import List, Optional, Sequence, Union
from warnings import warn

import numpy as np

import dynabench.grid
class InitialCondition(object):
    """
        Base class for all initial conditions.

        Parameters
        ----------
        parameters : dict, optional
            Dictionary of parameters for the initial condition. When omitted,
            a new empty dict is created for this instance.
        **kwargs
            Accepted and ignored, so subclasses can forward their extra
            keyword arguments unchanged.
    """

    def __init__(self,
                 parameters: Optional[dict] = None,
                 **kwargs):
        # Number of spatial dimensions; fixed at 2 in this base class.
        self.spatial_dim = 2
        # A ``{}`` default would be one dict shared by every instance created
        # without ``parameters``; build a fresh one per instance instead.
        self.parameters = {} if parameters is None else parameters

    def __str__(self):
        return "Initial condition base class"

    @property
    def num_variables(self):
        """
            Get the number of variables.
        """
        return 1

    def generate(self, grid: dynabench.grid.Grid, random_state: int = 42):
        """
            Generate the initial condition.

            Parameters
            ----------
            grid : dynabench.grid.Grid
                The grid on which the initial condition is to be generated.
            random_state : int, default 42
                Seed for subclasses that sample randomly (they pass it to
                ``np.random.seed``).

            Returns
            -------
            np.ndarray
                The initial condition.

            Raises
            ------
            NotImplementedError
                Always; subclasses must override this method.
        """
        raise NotImplementedError("The generate method must be implemented in the subclass.")

    def __call__(self, grid: dynabench.grid.Grid, *args, **kwargs):
        """
            Shorthand for ``self.generate(grid, *args, **kwargs)``.
        """
        return self.generate(grid, *args, **kwargs)
    
class Composite(InitialCondition):
    """
        Composite initial condition generator consisting of multiple initial conditions for the same grid.
        Convenience class to generate multiple initial conditions for different variables.

        Parameters
        ----------
        *components : InitialCondition
            The single initial conditions, one per variable, passed as
            positional arguments.
        **kwargs
            Forwarded to :class:`InitialCondition` (e.g. ``parameters``).
    """

    def __init__(self, *components: InitialCondition, **kwargs):
        # Run the base initializer so ``parameters`` and ``spatial_dim`` exist
        # here just as on every other initial condition.
        super(Composite, self).__init__(**kwargs)
        self.components = components

    @property
    def num_variables(self):
        """
            Get the number of variables, i.e. the number of components.
        """
        return len(self.components)

    def generate(self, grid: dynabench.grid.Grid, random_state: int = 42):
        """
            Generate one initial condition per component.

            Parameters
            ----------
            grid : dynabench.grid.Grid
                The grid on which the initial conditions are generated.
            random_state : int, default 42
                Master seed; a separate seed is drawn from it for each component.

            Returns
            -------
            list
                The outputs of the components, in the order they were given.
        """
        np.random.seed(random_state)
        # Derive one seed per component so the components do not all reuse
        # the same random stream. Integer bound (same draws as the former 1e6).
        seeds = np.random.randint(0, 1_000_000, len(self.components))
        return [component(grid, random_state=seed)
                for component, seed in zip(self.components, seeds)]
 
    
class Constant(InitialCondition):
    """
        Initial condition with a constant value.

        Parameters
        ----------
        value : float, default 0.0
            The value of the constant.
        **kwargs
            Forwarded to :class:`InitialCondition` (e.g. ``parameters``).
    """

    def __init__(self, value: float = 0.0, **kwargs):
        super(Constant, self).__init__(**kwargs)
        self.value = value

    def __str__(self):
        return f"I(x, y) = {self.value}"

    def generate(self, grid: dynabench.grid.Grid, random_state: int = 42):
        """
            Generate an array of shape ``grid.shape`` filled with ``value``.

            Parameters
            ----------
            grid : dynabench.grid.Grid
                The grid whose ``shape`` determines the output shape.
            random_state : int, default 42
                Accepted for interface compatibility; not used.

            Returns
            -------
            np.ndarray
                ``value`` broadcast over a zeros array of shape ``grid.shape``.
        """
        # Adding to a float zeros array yields a float array even for an int value.
        return self.value + np.zeros(grid.shape)
 
    
    
class SumOfGaussians(InitialCondition):
    """
        Initial condition generator for the sum of gaussians.

        Deprecated: use :class:`WrappedGaussians` instead (a DeprecationWarning
        is emitted on construction).

        Parameters
        ----------
        components : int, default 1
            The number of gaussian components.
        zero_level : float, default 0.0
            The zero level of the initial condition.
        **kwargs
            Forwarded to :class:`InitialCondition` (e.g. ``parameters``).
    """

    def __init__(self,
                 components: int = 1,
                 zero_level: float = 0.0,
                 **kwargs):
        super(SumOfGaussians, self).__init__(**kwargs)
        self.components = components
        self.zero_level = zero_level
        warn(f'{self.__class__.__name__} is deprecated and will be removed in a future version. Use WrappedGaussians instead.', DeprecationWarning, stacklevel=2)

    def __str__(self):
        # The factor -40 multiplies the whole squared distance (see generate).
        return "I(x, y) = sum_i A_i exp(-40((x-x_i)^2 + (y-y_i)^2))"

    def generate(self, grid: dynabench.grid.Grid, random_state: int = 42):
        """
            Generate the sum of gaussians.

            A single gaussian ``exp(-40 r^2)`` centred at (0.5, 0.5) is shifted
            cyclically (``np.roll``) by a random number of grid cells for each
            component and scaled by a random amplitude drawn from [-1, 1).

            Parameters
            ----------
            grid : dynabench.grid.Grid
                The grid on which the initial condition is generated.
            random_state : int, default 42
                Seed passed to ``np.random.seed``.

            Returns
            -------
            np.ndarray
                ``zero_level`` plus the sum of the shifted, scaled gaussians.
        """
        np.random.seed(random_state)
        # NOTE(review): meshgrid's default 'xy' indexing gives arrays of shape
        # (len(grid.y), len(grid.x)); confirm this matches grid.shape, which the
        # shifts below are drawn from, on non-square grids.
        x, y = np.meshgrid(grid.x, grid.y)
        # All shifts are drawn before any amplitude; keep this order so a given
        # seed keeps producing the same field.
        mx = [np.random.choice(grid.shape[0]) for _ in range(self.components)]
        my = [np.random.choice(grid.shape[1]) for _ in range(self.components)]
        squared_distance_to_center = (x - 0.5)**2 + (y - 0.5)**2
        gaussian = np.exp(-40 * squared_distance_to_center)
        u = self.zero_level + np.zeros_like(x)
        for shift_x, shift_y in zip(mx, my):
            component = np.roll(gaussian, (shift_x, shift_y), axis=(0, 1))
            u = u + np.random.uniform(-1, 1) * component
        return u
 
    
class WrappedGaussians(InitialCondition):
    """
        Initial condition generator for the sum of wrapped gaussians.

        Each component is a gaussian wrapped periodically over the grid domain,
        centred at a point from ``grid.get_random_point_within_domain`` and
        given a random width and a random amplitude.

        Parameters
        ----------
        components : int, default 1
            The number of gaussian components.
        zero_level : float, default 0.0
            The zero level of the initial condition.
        periodic_levels : int or sequence of two ints, default 10
            The number of periodic levels to calculate the wrapped distribution. :math:`p_w(\\theta)=\\sum_{k=-\\infty}^\\infty {p(\\theta+2\\pi k)}`
            The infinite sum is truncated to :math:`k = -n, \\dots, n`. An int uses
            the same :math:`n` along both axes; a pair ``(n_x, n_y)`` sets it per
            axis. Levels must be non-negative.
        **kwargs
            Forwarded to :class:`InitialCondition` (e.g. ``parameters``).
    """

    def __init__(self,
                 components: int = 1,
                 zero_level: float = 0.0,
                 periodic_levels: Union[int, Sequence[int]] = 10,
                 **kwargs):
        super(WrappedGaussians, self).__init__(**kwargs)
        self.components = components
        self.zero_level = zero_level
        self.periodic_levels = periodic_levels

    def __str__(self):
        return "I(x, y) = sum_i A_i sum_k exp(-((x-x_i-k_x L_x)^2 + (y-y_i-k_y L_y)^2) / (2 s_i^2))"

    def _level_counts(self):
        """
            Return the truncation levels ``(n_x, n_y)`` of the periodic sum.

            Raises
            ------
            ValueError
                If either level is negative.
        """
        levels = self.periodic_levels
        if np.ndim(levels) == 0:
            n_x = n_y = int(levels)
        else:
            n_x, n_y = (int(level) for level in levels)
        if n_x < 0 or n_y < 0:
            raise ValueError("periodic_levels must be non-negative")
        return n_x, n_y

    def _wrapped_gaussian_2d(self, x, y, mu, sigma, limits_x=(0, 1), limits_y=(0, 1)):
        """
            Evaluate a gaussian wrapped periodically on a rectangular domain.

            Sums ``exp(-((x - mu_x - k_x L_x)^2 + (y - mu_y - k_y L_y)^2) / (2 sigma^2))``
            over ``k_x = -n_x..n_x`` and ``k_y = -n_y..n_y``, where ``L_x`` and
            ``L_y`` are the widths of ``limits_x`` and ``limits_y``.

            Parameters
            ----------
            x, y : array_like
                Coordinates at which to evaluate; broadcast against each other.
            mu : array_like of length 2
                Centre ``(mu_x, mu_y)``; lists and tuples are accepted.
            sigma : float
                Standard deviation.
            limits_x, limits_y : tuple
                ``(low, high)`` bounds of the domain along each axis.

            Returns
            -------
            np.ndarray
                The wrapped gaussian, with the broadcast shape of ``x`` and ``y``.
        """
        n_x, n_y = self._level_counts()
        dLx = limits_x[1] - limits_x[0]
        dLy = limits_y[1] - limits_y[0]
        # Coerce so ``mu + shift`` is element-wise addition, never list concatenation.
        mu = np.asarray(mu, dtype=float)
        two_sigma_sq = 2 * sigma**2
        total = np.zeros(np.broadcast(x, y).shape)
        # Accumulate image by image rather than stacking all (2n+1)^2 full grids.
        for k_x, k_y in itertools.product(range(-n_x, n_x + 1), range(-n_y, n_y + 1)):
            center_x = mu[0] + dLx * k_x
            center_y = mu[1] + dLy * k_y
            total += np.exp(-((x - center_x)**2 + (y - center_y)**2) / two_sigma_sq)
        return total

    def generate(self, grid: dynabench.grid.Grid, random_state: int = 42):
        """
            Generate the sum of wrapped gaussians.

            For each component a standard deviation is drawn uniformly from
            [0.01, 0.1) times the smaller side length of the domain, then an
            amplitude uniformly from [-1, 1).

            Parameters
            ----------
            grid : dynabench.grid.Grid
                The grid on which the initial condition is generated. Its ``x``,
                ``y``, ``grid_limits`` and ``get_random_point_within_domain`` are used.
            random_state : int, default 42
                Seed passed to ``np.random.seed``.

            Returns
            -------
            np.ndarray
                ``zero_level`` plus the amplitude-weighted sum of wrapped gaussians.
        """
        np.random.seed(random_state)
        # NOTE(review): meshgrid's default 'xy' indexing gives arrays of shape
        # (len(grid.y), len(grid.x)); confirm this matches grid.shape (the shape
        # Constant returns) on non-square grids.
        x, y = np.meshgrid(grid.x, grid.y)
        limits_x = grid.grid_limits[0]
        limits_y = grid.grid_limits[1]
        std_scale = min(limits_x[1] - limits_x[0], limits_y[1] - limits_y[0])
        centers = grid.get_random_point_within_domain(self.components)
        u = self.zero_level + np.zeros_like(x)
        for i in range(self.components):
            # Width is drawn before amplitude; this order fixes the output for a seed.
            std = np.random.uniform(0.01, 0.1) * std_scale
            component = self._wrapped_gaussian_2d(x, y, centers[i], std, limits_x, limits_y)
            u = u + np.random.uniform(-1, 1) * component
        return u