import numpy as np

# These functions may be customized to specific needs, for example to accommodate
# the expectations of neural networks trained outside this environment.
def get_model_information():
    """Describe the model to the end user.

    Returns
    -------
        str HTML snippet displayed in the Deep-Learning models information
        port; it links to the original U-Net paper on arXiv.
    """
    paper_url = "https://arxiv.org/abs/1505.04597"
    return "<html>This network is a <a href=\"" + paper_url + "\"> \
            U-Net model variant.</a></html>"


# This function pre-processes the data by standardizing each depth slice to zero mean and unit variance.
def preprocess_input(x):
    """Standardize the input array x before prediction by the Deep Learning Prediction module.

    Each slice along the first (depth) axis is shifted independently to zero mean
    and scaled to unit standard deviation. The statistics are taken over all of the
    slice's remaining axes and accumulated in float32. This standardization must be
    the same as the one applied when the model was trained.

    Parameters
    ----------
        x: Array to pre-process.
            Numpy array of shape (depth, width, height, channels)

    Returns
    -------
        processed array.
            Same shape as x: (x - mean) / max(std, 1e-7), per depth slice.

    Raises
    ------
        TypeError
            If x is not a numpy array.
        ValueError
            If x is not a 4-dimensional array.
    """
    # Validate the type first so that non-arrays get the intended TypeError rather
    # than an AttributeError from accessing `.shape`.
    if not isinstance(x, np.ndarray):
        raise TypeError("'x' must be a numpy array")
    if x.ndim != 4:
        raise ValueError("incorrect 'x' shape, must be a 4d array with shape=(depth, width, height, channels)")

    n = x.shape[0]

    # Flatten each depth slice into one row so the statistics are computed per slice.
    # np.prod replaces np.product, which was removed in NumPy 2.0. An explicit row size
    # (rather than -1) keeps the reshape valid when depth == 0.
    x_flat = x.reshape(n, int(np.prod(x.shape[1:])))
    x_mean = np.mean(x_flat, dtype=np.float32, axis=1).reshape((n, 1, 1, 1))
    x_std = np.std(x_flat, dtype=np.float32, axis=1).reshape((n, 1, 1, 1))

    # The floor on std prevents division by zero for constant slices.
    return (x - x_mean) / np.maximum(x_std, 1e-7)


# This function assumes that the model generates a probability score for each class.
# For multi-channel output it returns the index of the class (label) with the highest
# probability; for single-channel output it thresholds the score into a binary mask.
# It is applied on the tiles of the input data, and thus avoids allocating the
# N scalar arrays representing the probability of each class.
def postprocess_local_output(predicted_tile, *, threshold=0.5):
    """This function performs a local post-processing of the prediction output,
    operating on each tile generated when the input data is large.
    It is intended for pixel-wise operations, which do not require a global context.

    Parameters
    ----------
        predicted_tile: tile of the model prediction output
            Numpy array of shape (tile_depth, tile_width, tile_height, input_channels)

        threshold: keyword-only, default 0.5
            Cut-off applied when the tile has a single channel: values strictly
            greater than it become True.

    Returns
    -------
        post-processed tile
            Shape (tile_depth, tile_width, tile_height, 1).
            - Single channel: boolean mask (predicted_tile > threshold).
            - Several channels: integer index of the highest-scoring channel.

    Raises
    ------
        ValueError
            If predicted_tile is not 4-dimensional.
    """
    if len(predicted_tile.shape) != 4:
        raise ValueError("incorrect 'predicted_tile' shape, must be a 4d array with shape=(depth, width, height, channels)")

    channels = predicted_tile.shape[3]

    if channels == 1:
        # Single-channel output: binarize the score into a foreground mask.
        return predicted_tile > threshold

    # Multi-channel output: the label is the arg-max over the channel axis. The
    # channel axis is re-added so the result keeps 4 dimensions.
    return np.expand_dims(np.argmax(predicted_tile, axis=3), axis=3)


# This function casts the elements of each predicted array to 8-bit unsigned integers,
# as expected for label images. The resulting images are therefore of this type.
def postprocess_global_output(list_predicted_arrays, input_array):
    """This function performs a global post-processing of the prediction output,
    after the tiles have been re-assembled. It is intended to be used for global
    normalization operations.

    Here it only converts every predicted array to np.uint8 so that it can be
    exposed as a label image.

    Parameters
    ----------
        list_predicted_arrays: list of model prediction output(s), each item of the list is a channel.
            List of numpy arrays, each with shape (depth, width, height, 1)

        input_array: input data array.
            Numpy array of shape (depth, input_width, input_height, channels)
            Not used by this implementation; kept for the module's interface.

    Returns
    -------
        A list of the processed arrays, each will be exposed as a 3D Data object in the application.
            Expected shape is a list of arrays, each with shape (depth, output_width, output_height, 1)
            and dtype np.uint8. Shapes are unchanged by the conversion.
    """
    # NOTE(review): astype(np.uint8) does not preserve values outside [0, 255]
    # (e.g. more than 256 labels). Confirm the model never produces such values.
    return [predicted.astype(np.uint8) for predicted in list_predicted_arrays]


# This implementation makes the Prediction return a label field dataset,
# rather than a grayscale image.
def postprocess_output_type():
    """Declare the object type(s) produced by the Deep Learning Prediction module.

    Returns
    -------
        List of strings, one per output exposed in the application.
            Each entry is either HxUniformLabelField3 or HxUniformScalarField3;
            this implementation exposes a single label field.
    """
    label_field_type = "HxUniformLabelField3"
    return [label_field_type]
