import numpy as np

# These functions may be customized to specific needs, such as accommodating
# the expectations of Neural Networks trained outside our environment

def get_model_information():
    """Describe the model to the end user.
    DO NOT MODIFY.

    Returns
    -------
        str HTML-formatted information message that will appear in the
        Deep-Learning models information port.
    """
    # Adjacent string literals are joined at compile time, so the returned
    # message is a single HTML string.
    return (
        "<html>This network is a back scattered electrons denoiser<br>"
        "- The model is based on a "
        "<a href=\"https://arxiv.org/abs/1505.04597\">U-Net architecture</a><br>"
        "- It is working on grayscale image only (1 channel)</html>"
    )


def preprocess_input(x):
    """Rescale the input array to the [0, 255] range prior to the prediction.

    This preprocessing must be the same as that which was applied during the
    training of the model.

    If the normalization is not set to "None" in
    get_normalization_configuration(), a normalization is performed prior to
    this function call.

    A single minimum and maximum, computed over the whole array, are used:
    the minimum is mapped to 0 and the maximum to 255. A constant array
    (maximum equal to minimum) is mapped to zeros.

    Parameters
    ----------
        x : numpy array
            array to process. Shape (number, depth, width, height, channels)

    Returns
    -------
        processed float32 array with same shape
        (number, depth, width, height, channels)

    Raises
    ------
        TypeError
            if x is not a numpy array
        ValueError
            if x is not a 5d array
    """
    # Check the type first: a non-array object may not even expose `.shape`.
    if not isinstance(x, np.ndarray):
        raise TypeError("'x' must be a numpy array")
    if x.ndim != 5:
        raise ValueError("incorrect 'x' shape, must be a 5d array with shape=(number, depth, width, height, channels)")

    # Integer arithmetic wraps around for signed types (e.g. int8 spanning
    # -128..127), so integer/bool data is promoted to float64 before any
    # subtraction. Floating-point input keeps its own precision.
    data = x.astype(np.float64) if x.dtype.kind in "biu" else x

    lowest = data.min()
    value_range = data.max() - lowest

    # A constant array has no dynamic range: return zeros rather than
    # dividing by zero and propagating NaN to the network.
    if value_range == 0:
        return np.zeros(x.shape, dtype=np.float32)

    upper_bound = float(np.iinfo(np.uint8).max)
    return ((data - lowest) * upper_bound / value_range).astype(dtype=np.float32)


def postprocess_local_output(predicted_tile):
    """Local post-processing of the prediction output, applied to each tile
    generated when the input data is large.
    It is intended for pixel-wise operations, which do not require a global context.

    Here the post-processing keeps only the first channel of the tile.

    Parameters
    ----------
        predicted_tile : numpy array
            tile predicted by the model, a 5d array with the channels on the
            last axis. Its shape is
                - (number, width, height, depth=1, nb_channels) in 2D
                - (number=1, width, height, depth, nb_channels) in 3D

    Returns
    -------
        post-processed tile (note: It is currently limited to output a single channel.)
        Its shape is
            - (number, width, height, depth=1, 1) in 2D
            - (number=1, width, height, depth, 1) in 3D

    Raises
    ------
        ValueError
            if predicted_tile is not a 5d array
    """
    rank = len(predicted_tile.shape)
    if rank != 5:
        raise ValueError("incorrect 'predicted_tile' shape, must be a 5d array with shape=(number, tile_depth, tile_width, tile_height, input_channels)")

    # The denoising model was originally trained to generate RGB outputs.
    # For grayscale images, all channels are the same, so only the first one
    # is kept. Slicing with `:1` (rather than indexing with 0) preserves the
    # trailing channel axis.
    first_channel = predicted_tile[..., :1]
    return first_channel
    

def _per_slice_mean_std(volume, depth):
    """Return the mean and standard deviation of each depth slice of a
    (width, height, depth) volume, both reshaped to (1, 1, depth)."""
    flat = np.reshape(volume, (volume.shape[0] * volume.shape[1], volume.shape[2]))
    mean = np.mean(flat, axis=0).reshape((1, 1, depth))
    std = np.std(flat, axis=0).reshape((1, 1, depth))
    return mean, std


def postprocess_global_output(list_predicted_arrays, input_array):
    """This function performs a global post-processing of the prediction output,
    after the tiles have been re-assembled. It is intended to be used for global
    normalization operations.

    The first predicted array is processed in two steps:
        1/ each depth slice is standardized to the same mean and standard
           deviation as the matching slice of the input
        2/ the result is cast to the same pixel type as the input (for
           integer types, values are first clipped to the range of the type)

    Parameters
    ----------
        list_predicted_arrays : numpy array list
            model prediction output; only the first element is used.
            Each element of the list has the shape (width, height, depth)
        input_array : numpy array
            origin data input
            Has a shape (width, height, depth)

    Returns
    -------
        A list holding the processed array, with shape (width, height, depth),
        or False when the prediction or the input is not a numpy array.
    """
    # N.B. Former shape of pred/input_array were (depth, width, height, 1)
    # since 2022.1 (archi v2), it is (width, height, depth)
    pred = list_predicted_arrays[0]

    if not (isinstance(pred, np.ndarray) and isinstance(input_array, np.ndarray)):
        return False

    _, _, n = pred.shape
    voxel_type = input_array.dtype.type

    # Statistics are computed slice by slice along the depth axis.
    target_mean, target_std = _per_slice_mean_std(input_array, n)
    pred_mean, pred_std = _per_slice_mean_std(pred, n)

    # Mean/std normalization slice by slice; the lower bound on the divisor
    # protects against constant predicted slices (std == 0).
    pred_norm = (pred - pred_mean) * target_std / np.maximum(pred_std, 1e-7) + target_mean

    # Clip to the representable range of integer pixel types only: np.iinfo
    # rejects floating-point types, and signed types may legitimately hold
    # negative values, so the lower bound comes from the type itself.
    if np.issubdtype(voxel_type, np.integer):
        limits = np.iinfo(voxel_type)
        pred_norm = np.clip(pred_norm, limits.min, limits.max)

    # Cast and return result within a list as expected
    return [pred_norm.astype(dtype=voxel_type)]


def postprocess_output_type():
    """Specify the output object type of the Deep Learning Prediction module.

    Returns
    -------
        List holding the object type name
        (i.e. HxUniformLabelField3, HxUniformScalarField3, etc.)
    """
    # This implementation selects a scalar field type
    # (HxUniformScalarField3), not a label field type.
    output_types = ['HxUniformScalarField3']
    return output_types


def get_normalization_configuration():
    """Defines the normalization to apply on the input before the prediction.

    This implementation requests no normalization (algorithm type "None").

    Returns
    --------
    A list of normalization configurations: each configuration is defined using a
    CanalConfig(item_or_collection, algo_type, per_channel, parameters=None) function.

        item_or_collection: str
            - "PerSample": the normalization is computed for each item of the collection independently
            - "Global": the normalization is computed globally for the whole collection

        algo_type: str
            - "Standardization": applies the normalization (x-mean)/std_dev to the input
            - "MinMax": applies the normalization (x-min)/(max-min) to the input
            - "None": no normalization is applied

        per_channel: bool
            The normalization is performed for each channel or over all channels.

        parameters: DO NOT MODIFY

    If this function is not in the custom file, or the custom file itself is missing, no normalization is performed.
    """
    # CanalConfig is not defined in this file; presumably it is provided by
    # the host application when this module is loaded - TODO confirm.
    # NOTE(review): per_channel is documented as a bool but is passed as the
    # string 'True'; kept unchanged since the expected type cannot be
    # verified from this file.
    no_normalization = CanalConfig('PerSample', 'None', 'True', parameters=None)
    return [no_normalization]

def get_product_name():
    """Name of the product this file targets.
    DO NOT MODIFY.

    To get the product name using the Python console, type the command
    >>>hx_application.name

    Returns
    -------
        The product name ("Avizo", "Amira" or "PerGeos").
    """
    product_name = "Avizo"
    return product_name

def get_product_version():
    """Version of the product this file targets.
    DO NOT MODIFY.

    To get the product version using the Python console, type the command
    >>>hx_application.version

    Returns
    -------
        The product version line number including patch number in format yyyynp.
    """
    product_version = 202210
    return product_version
