from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import List, Tuple
from copy import copy
import einops
import numpy as np
from ._dataitems import GridDataItem, CloudDataItem, DataItem
from scipy.spatial import KDTree
        
[docs]
class Compose(BaseTransform):
    """
    Compose function for combining multiple transforms.
    Iterates over transformations and applies them, in order, to a shallow
    copy of the data item.
    Parameters
    ----------
    transforms : List[BaseTransform]
        Transforms to be applied to the data. Any iterable is accepted
        (including generators); it is materialised into a list once.
    Raises
    ------
    ValueError
        If ``transforms`` is not iterable, is empty, or contains an element
        that is None or not a BaseTransform.
    """
    def __init__(self, transforms: List[BaseTransform]):
        super().__init__()
        if not isinstance(transforms, Iterable):
            raise ValueError("Transforms should be an iterable")
        # Materialise once: a generator/iterator would otherwise be exhausted by
        # the emptiness check and leave nothing to validate or apply.
        transforms = list(transforms)
        if len(transforms) == 0:
            raise ValueError("No transforms were given")
        for transform in transforms:
            if transform is None:
                raise ValueError("Transform can not be None")
            if not isinstance(transform, BaseTransform):
                raise ValueError(f"Transform should be an instance of BaseTransform, got {type(transform)}")
        self.transforms = transforms

    def __call__(self, data_item: DataItem) -> DataItem:
        self._check_data(data_item)
        # Shallow copy: attribute rebinding by the transforms happens on the
        # copy rather than the caller's object (array contents are still shared).
        result = copy(data_item)
        for aug in self.transforms:
            result = aug(result)
        return result

    def __repr__(self):
        return self.__class__.__name__ + str(self.transforms)
    
[docs]
class Grid2Cloud(BaseTransform):
    """
    Create a Cloud item from a grid data item.
    The two trailing spatial axes of ``x`` and ``y`` are flattened into a point
    axis placed before the channel axis (``... c w h -> ... (w h) c``), and the
    two leading axes of ``pos`` are flattened likewise (``w h d -> (w h) d``).
    ``y`` and ``pos`` are passed through as None when they are absent.
    Parameters
    ----------
    data_item : GridItem
    Returns
    -------
    CloudItem
        data_item with cloud shape
    """
    def __init__(self):
        super().__init__()

    def __call__(self, data_item: DataItem) -> CloudDataItem:
        self._check_data(data_item)
        return CloudDataItem(
            x=einops.rearrange(data_item.x, '... c w h -> ... (w h) c'),
            y=self._flatten_optional(getattr(data_item, 'y', None), '... c w h -> ... (w h) c'),
            pos=self._flatten_optional(getattr(data_item, 'pos', None), 'w h d -> (w h) d'),
        )

    @staticmethod
    def _flatten_optional(values, pattern):
        """Apply ``einops.rearrange(values, pattern)``; None is returned unchanged."""
        if values is None:
            return None
        return einops.rearrange(values, pattern)
    
[docs]
class ToDict(BaseTransform):
    """
    Convert the data item to a dictionary.
    Every instance attribute of the item becomes a key; attributes whose value
    is None are left out.
    Parameters
    ----------
    data_item : DataItem
    Returns
    -------
    dict
        data_item as a dictionary
    """
    def __init__(self):
        super().__init__()

    def __call__(self, data_item: DataItem) -> dict:
        self._check_data(data_item)
        as_dict = {}
        for attr_name, attr_value in data_item.__dict__.items():
            if attr_value is None:
                # Optional fields that were never filled are omitted.
                continue
            as_dict[attr_name] = attr_value
        return as_dict
[docs]
class KNNGraph(BaseTransform):
    """
    Create a KNN graph from the cloud data on a periodic 2-D domain.
    The positions are replicated into the eight surrounding periodic images
    (shifted by -1/0/+1 times ``grid_limits`` along each axis) so that points
    close to one edge can find neighbours across the opposite edge.
    Parameters
    ----------
    k : int
        Number of neighbours kept per point.
    grid_limits : Tuple[float, float]
        Period of the domain along each of the two position axes; cast to
        float32 before use.
    data_item : CloudItem
        Its ``pos`` is expected to be (N, 2) — the shifts are 2-vectors.
    Returns
    -------
    CloudItem
        data_item with ``neighbors`` (N, k) holding indices into the original
        points and ``distances`` (N, k, 2) holding the displacement vector from
        each point to each (possibly wrapped) neighbour.
    """
    # Periodic image shifts in units of grid_limits. The unshifted points come
    # first in the padded array, followed by these images in this order.
    _PERIODIC_SHIFTS = (
        (0, 1), (1, 0), (1, 1), (0, -1),
        (-1, 0), (-1, -1), (1, -1), (-1, 1),
    )

    def __init__(self, k: int, grid_limits: Tuple[float, float] = (1.0, 1.0)):
        super().__init__()
        self.k = k
        self.grid_limits = grid_limits

    def __call__(self, data_item: CloudDataItem) -> CloudDataItem:
        self._check_data(data_item)
        points = data_item.pos
        num_points = points.shape[0]
        # Local variable: do not overwrite the configured grid_limits with an array.
        limits = np.array(self.grid_limits, dtype=np.float32)
        # Each block of num_points rows is one image, so padded index i maps
        # back to original point i % num_points.
        points_padded = np.concatenate(
            [points] + [points + np.array(shift) * limits for shift in self._PERIODIC_SHIFTS],
            axis=0,
        )
        tree = KDTree(points_padded)
        _, neighbors = tree.query(points, k=self.k + 1)
        # Drop the nearest hit: the query point itself (distance 0), or a
        # coincident point if positions are duplicated.
        neighbors = neighbors[:, 1:]

        # Displacements must use the padded (unwrapped) coordinates, so they
        # are computed before folding indices back onto the original points.
        distances = points_padded[neighbors] - np.expand_dims(points, axis=1)

        neighbors = neighbors % num_points
        return CloudDataItem(
            x=data_item.x,
            y=data_item.y,
            pos=data_item.pos,
            neighbors=neighbors,
            distances=distances,
        )

    def check_if_valid(self):
        return True
 
    
[docs]
class EdgeListFromKNN(BaseTransform):
    """
    Create an edge list from the KNN graph.
    Parameters
    ----------
    data_item : CloudItem
        Must carry ``neighbors``, presumably (num_points, k) as produced by a
        KNN-graph transform — TODO confirm for other producers.
    Returns
    -------
    CloudItem
        data_item with ``edgelist`` of shape (2, num_points * k): row 0 holds
        the query point index, row 1 the neighbour index.
    """
    def __init__(self):
        super().__init__()

    def __call__(self, data_item: CloudDataItem) -> CloudDataItem:
        """
        Build the (src, dst) edge list from ``data_item.neighbors``.
        Parameters
        ----------
        data_item : CloudDataItem
        Returns
        -------
        CloudDataItem
            New item with the original fields plus ``edgelist``.
        Raises
        ------
        ValueError
            If ``data_item`` has no ``neighbors``.
        """
        self._check_data(data_item)
        neighbors = getattr(data_item, 'neighbors', None)
        if neighbors is None:
            raise ValueError("EdgeListFromKNN requires data_item.neighbors (apply a KNN graph transform first)")
        num_points = neighbors.shape[0]
        k = neighbors.shape[-1]
        # Row-major flatten lists the k neighbours of point 0, then point 1, ...,
        # which matches repeating each source index k times.
        src = np.repeat(np.arange(num_points), k)
        dst = neighbors.flatten()
        edge_list = np.stack((src, dst), axis=0)

        return CloudDataItem(
            x=data_item.x,
            y=data_item.y,
            pos=data_item.pos,
            neighbors=data_item.neighbors,
            distances=data_item.distances,
            edgelist=edge_list,
        )

    def check_if_valid(self):
        return True
 
    
[docs]
class EdgeList(Compose):
    """
    Create an edge list graph (src, dst) to use with PyG.
    Chains ``KNNGraph(k, grid_limits)`` and ``EdgeListFromKNN()``.
    Parameters
    ----------
    k : int
        Number of neighbours per point.
    grid_limits : Tuple[float, float]
        Period of the domain forwarded to KNNGraph (default ``(1.0, 1.0)``).
    Returns
    -------
    CloudItem
        data_item with edge_list as knn_graph
    """
    def __init__(self, k: int, grid_limits: Tuple[float, float] = (1.0, 1.0)):
        super().__init__(transforms=[KNNGraph(k=k, grid_limits=grid_limits), EdgeListFromKNN()])
[docs]
class TypeCaster(BaseTransform):
    """
    Cast the data item to the correct type. (In place!!!)
    ``x`` is always cast to ``dtype``; ``y``, ``pos`` and ``distances`` are
    cast when present and not None. Other attributes (e.g. ``neighbors``,
    ``edgelist``) are left untouched.
    Parameters
    ----------
    dtype : np.dtype
        Target dtype, float32 by default.
    """
    def __init__(self, dtype: np.dtype = np.float32):
        super().__init__()
        self.dtype = dtype

    def __call__(self, data_item: DataItem) -> DataItem:
        self._check_data(data_item)
        data_item.x = data_item.x.astype(self.dtype)
        # Optional fields may be missing or None on some items; skip those.
        for field_name in ('y', 'pos', 'distances'):
            value = getattr(data_item, field_name, None)
            if value is not None:
                setattr(data_item, field_name, value.astype(self.dtype))
        return data_item
    
[docs]
class GridDownsampleFactor(BaseTransform):
    """
        Downsample the grid by a factor.
        Every ``factor``-th sample is kept along the two trailing axes of
        ``x`` and ``y`` and along the two leading axes of ``pos``. ``y`` and
        ``pos`` stay None when they are absent.
        Parameters
        ----------
        factor : int
            Factor by which to downsample the grid (must be >= 1).
        Raises
        ------
        ValueError
            If ``factor`` is smaller than 1.
    """
    def __init__(self, factor: int = 2):
        super().__init__()
        if factor < 1:
            # A zero step is invalid and a negative one would flip the grid.
            raise ValueError(f"factor must be a positive integer, got {factor}")
        self.factor = factor

    def __call__(self, data_item: DataItem) -> DataItem:
        self._check_data(data_item)
        step = self.factor
        # Fields are assumed to end with their two spatial axes ('... c w h'),
        # so the same slice works for 3-D and 4-D arrays alike.
        downsampled_x = data_item.x[..., ::step, ::step]

        y = getattr(data_item, 'y', None)
        downsampled_y = None if y is None else y[..., ::step, ::step]

        # pos is assumed to start with its two spatial axes ('w h d').
        pos = getattr(data_item, 'pos', None)
        downsampled_pos = None if pos is None else pos[::step, ::step]

        # NOTE(review): returns a base DataItem rather than the input's concrete
        # type — confirm downstream transforms accept it.
        return DataItem(
            x=downsampled_x,
            y=downsampled_y,
            pos=downsampled_pos,
        )
    
[docs]
class GridDownsampleFFT(BaseTransform):
    """
        Downsample the grid to a smaller size using FFT.
        ``x`` and ``y`` are resampled by spectral truncation over their two
        trailing axes: the 2-D real FFT is taken, only the lowest frequencies
        that fit ``target_size`` are kept, and the inverse FFT is evaluated at
        ``target_size``. The result is rescaled so the grid mean is preserved.
        This treats the fields as periodic.
        ``pos`` (assumed laid out as ``w h d``) is not periodic, so it is
        linearly interpolated over its two leading axes at the sample
        locations of the resampled grid. For an integer size ratio r, this is
        exactly ``pos[::r0, ::r1]``.
        Parameters
        ----------
        target_size : Tuple[int, int]
            Target size of the grid; each entry a positive integer no larger
            than the corresponding input size.
        Raises
        ------
        ValueError
            If ``target_size`` is not two positive integers, or exceeds the
            size of a field at call time.
    """
    def __init__(self, target_size: Tuple[int, int] = (1.0, 1.0)):
        super().__init__()
        if len(target_size) != 2 or any(s < 1 or int(s) != s for s in target_size):
            raise ValueError(f"target_size must be two positive integers, got {target_size}")
        # numpy's FFT routines reject float sizes, so store plain ints.
        self.target_size = tuple(int(s) for s in target_size)

    def __call__(self, data_item: DataItem) -> DataItem:
        self._check_data(data_item)
        downsampled_x = self._spectral_resample(data_item.x, self.target_size)

        y = getattr(data_item, 'y', None)
        downsampled_y = None if y is None else self._spectral_resample(y, self.target_size)

        pos = getattr(data_item, 'pos', None)
        downsampled_pos = None if pos is None else self._interpolate_pos(pos, self.target_size)

        # NOTE(review): returns a base DataItem rather than the input's concrete
        # type — confirm downstream transforms accept it.
        return DataItem(
            x=downsampled_x,
            y=downsampled_y,
            pos=downsampled_pos,
        )

    @staticmethod
    def _spectral_resample(values, target_size):
        """Resample the two trailing axes of a real array by truncating its 2-D spectrum."""
        rows, cols = values.shape[-2:]
        target_rows, target_cols = target_size
        if target_rows > rows or target_cols > cols:
            raise ValueError(f"target_size {tuple(target_size)} exceeds grid size {(rows, cols)}")
        spectrum = np.fft.rfft2(values)
        # Axis -2 is two-sided: keep the lowest non-negative frequencies from
        # the front and the lowest negative frequencies from the back.
        keep_rows = np.r_[0:(target_rows + 1) // 2, rows - target_rows // 2:rows]
        # Axis -1 holds only non-negative frequencies (rfft output).
        truncated = spectrum[..., keep_rows, :target_cols // 2 + 1]
        resampled = np.fft.irfft2(truncated, s=(target_rows, target_cols))
        # The inverse normalises by the new sample count; rescale to keep the mean.
        return resampled * (target_rows * target_cols) / (rows * cols)

    @staticmethod
    def _interpolate_pos(pos, target_size):
        """Linearly interpolate the two leading axes of ``pos`` to ``target_size``."""
        resampled = pos
        for axis, new_len in enumerate(target_size):
            old_len = resampled.shape[axis]
            if new_len > old_len:
                raise ValueError(f"target_size {tuple(target_size)} exceeds pos grid size {pos.shape[:2]}")
            # Spectral resampling samples the grid at multiples of old/new in
            # source-index units; interpolate pos at those same locations.
            coords = np.arange(new_len) * (old_len / new_len)
            lower = np.floor(coords).astype(np.intp)
            upper = np.minimum(lower + 1, old_len - 1)
            weight_shape = [1] * resampled.ndim
            weight_shape[axis] = new_len
            weight = (coords - lower).reshape(weight_shape)
            resampled = (np.take(resampled, lower, axis=axis) * (1 - weight)
                         + np.take(resampled, upper, axis=axis) * weight)
        return resampled
    
[docs]
class PointSampling(BaseTransform):
    """
        Point sampling transform for the dataset.
        Draws ``num_points`` distinct point indices with the global
        ``np.random.choice`` (without replacement) and keeps the matching
        entries of ``pos``, ``x`` and ``y``. Other attributes (e.g. graph
        fields) are not carried over to the returned item.
        Parameters
        ----------
        num_points : int
            Number of points to sample; must not exceed the number of points
            in the cloud (``np.random.choice`` raises ValueError otherwise).
    """
    def __init__(self, num_points: int = 900):
        super().__init__()
        self.num_points = num_points

    def __call__(self, data_item: CloudDataItem) -> CloudDataItem:
        self._check_data(data_item)

        total_points = data_item.pos.shape[0]
        indices = np.random.choice(total_points, self.num_points, replace=False)
        return CloudDataItem(
            x=self._select_points(data_item.x, indices),
            y=self._select_points(getattr(data_item, 'y', None), indices),
            pos=data_item.pos[indices],
        )

    @staticmethod
    def _select_points(values, indices):
        """Pick ``indices`` along the point axis: axis 1 for 3-D arrays, axis 0 otherwise.

        Assumes the point axis precedes the channel axis, i.e. (N, c) or
        (T, N, c) — TODO confirm for other layouts. None is passed through.
        """
        if values is None:
            return None
        if values.ndim == 3:
            return values[:, indices]
        return values[indices]