from tensorflow.keras import backend as K
import numpy as np

def get_model_information():
    """Describe the model for end users.

    Returns
    -------
        str
            An HTML snippet summarizing the network: a back scattered
            electrons denoiser based on a U-Net architecture, working on
            single-channel (grayscale) images only.
    """
    html_parts = (
        "<html>",
        "This network is a back scattered electrons denoiser",
        "<br>- The model is based on a ",
        '<a href="https://arxiv.org/abs/1505.04597">U-Net architecture</a>',
        "<br>- It is working on grayscale image only (1 channel)",
        "</html>",
    )
    return "".join(html_parts)

def preprocess_input(x):
    """Min-max normalize an input array to the 0-255 range as float32.

    Two steps process input array before model prediction:
    1- min max normalization to 0-255 (the uint8 maximum)
    2- float32 casting

    Parameters
    ----------
        x : numpy array
            array to normalize and cast (any numeric dtype)

    Returns
    -------
        numpy array of dtype float32 with the same shape as ``x``, or
        ``False`` when ``x`` is not a numpy array.
        A constant array (max == min) is mapped to all zeros instead of NaN.
    """

    if not isinstance(x, np.ndarray):
        return False

    # Work in float64: subtracting the minimum in the input's own integer
    # dtype wraps around for signed types (e.g. int8 values in [-128, 127]).
    x = x.astype(np.float64)
    x_min = x.min()
    x_range = x.max() - x_min
    if x_range == 0:
        # Constant input: avoid 0/0, which would produce an all-NaN array.
        return np.zeros(x.shape, dtype=np.float32)

    x = (x - x_min) * float(np.iinfo(np.uint8).max) / x_range
    return x.astype(np.float32)

def postprocess_output(x, target_x):
    """Match the per-sample intensity distribution and dtype of ``target_x``.

    Two steps process output array after model prediction:
    1- mean std normalization on the target_x mean and std (per sample)
    2- clipping and target_x type casting

    Parameters
    ----------
        x : numpy array
            array to cast; its first axis is the batch axis
        target_x : numpy array
            the target array distribution and type to match; must have the
            same batch size (first axis) as ``x``

    Returns
    -------
        numpy array with the shape of ``x`` and the dtype of ``target_x``,
        or ``False`` when either argument is not a numpy array.
        For integer target dtypes the values are clipped to
        ``[0, iinfo(dtype).max]`` and rounded to the nearest integer; for
        floating target dtypes they are only clipped below at 0.
    """

    if not (isinstance(x, np.ndarray) and isinstance(target_x, np.ndarray)):
        return False

    n = x.shape[0]
    # One statistic per sample, broadcastable over every non-batch axis of x.
    stat_shape = (n,) + (1,) * (x.ndim - 1)

    def _per_sample_stats(a):
        # Mean and std over all non-batch axes, accumulated in float64.
        # (Replaces the np.product-based flattening: np.product was removed
        # in NumPy 2.0.)
        axes = tuple(range(1, a.ndim))
        mean = np.mean(a, axis=axes, dtype=np.float64).reshape(stat_shape)
        std = np.std(a, axis=axes, dtype=np.float64).reshape(stat_shape)
        return mean, std

    target_type = target_x.dtype.type
    target_mean, target_std = _per_sample_stats(target_x)
    x_mean, x_std = _per_sample_stats(x)

    # Mean std normalization; the floor on x_std avoids dividing by zero,
    # so a constant sample maps to the target mean.
    x_norm = (x - x_mean) * target_std / np.maximum(x_std, 1e-7) + target_mean

    # Clip and cast
    if np.issubdtype(target_type, np.integer):
        x_norm = np.clip(x_norm, 0, np.iinfo(target_type).max)
        # Round instead of truncating so the cast does not bias the mean down.
        x_norm = np.rint(x_norm)
    else:
        # np.iinfo is undefined for float dtypes; keep only the lower bound.
        x_norm = np.clip(x_norm, 0, None)
    return x_norm.astype(dtype=target_type)
