# -*- coding: utf-8 -*-

"""
Helpers for Avizo Object classification modules
"""
import _hx_core
from _hx_core import hx_object_factory, hx_message, HxConnection, PyScriptObject

import _PythonModuleHelpers as pmh
import numpy as np
import scipy.ndimage as sn


class ML_ObjectClassifier:
    """
    Holder for the state of an object classifier.

    A new instance starts with the following values:
    - the feature-group settings `featureGroupNeedsIntensityImage` (True) and `ObjectDimensionality` (2);
    - empty lists for `model`, `featureGroup` and `featureScaler`;
    - an empty class description: `num_classes` = 0, no `class_names`, no `class_colors`.

    Each list attribute is a distinct, per-instance list.
    """

    def __init__(self):
        # feature extraction settings
        self.featureGroupNeedsIntensityImage = True
        self.ObjectDimensionality = 2

        # model-related placeholders (presumably populated by the caller after construction)
        self.model, self.featureGroup, self.featureScaler = [], [], []

        # description of the output classes
        self.num_classes = 0
        self.class_names, self.class_colors = [], []



def labelAnalysisFor2DObjectLocation(labelImage, type = "binary", classID = None):
    """
        Run a 2D label analysis on `labelImage` and return the analyzed label image,
        the result spreadsheet and the indices of its location-related columns.
        Parameters:
        -----------
        labelImage : object
            Image to analyze. It is connected to the data port of an `HxAnalyzeLabels` module.
        type : str, optional
            How `labelImage` is interpreted. Must be one of:
            - "binary": a binary image whose maximum value must not exceed 1.
            - "label": a labeled image where each object already has a unique label.
            - "class": an image of class IDs. Only the voxels equal to `classID` are analyzed.
            Default is "binary".
        classID : int, optional
            Class ID to analyze. Required when `type` is "class", ignored otherwise.
            If it lies outside the value range of `labelImage`, a warning is printed and
            the analysis still runs. The mask is then empty and the spreadsheet is
            expected to have no rows.
        Returns:
        --------
        outputLabelImage : object
            For "binary" and "class": the label image produced by the analysis module
            (its second result). For "label": the input `labelImage` itself.
        spreadSheet : object
            First result of the analysis module. It holds the measures Area, BaryCenterX,
            BaryCenterY, BoundingBoxOx, BoundingBoxOy, BoundingBoxDx and BoundingBoxDy.
        BBox_Ox_ColIndex : int
            Column index of the X-coordinate of the bounding box origin, in physical coordinates.
        BBox_Oy_ColIndex : int
            Column index of the Y-coordinate of the bounding box origin, in physical coordinates.
        BBox_Dx_ColIndex : int
            Column index of the X-size of the bounding box, in physical coordinates.
        BBox_Dy_ColIndex : int
            Column index of the Y-size of the bounding box, in physical coordinates.
        indexImage_ColIndex : int
            Column index of the slice number ('indeximage', starting at 1).
            Equals -1 if only a single slice is analyzed.
        index_ColIndex : int
            Column index of the object index ('index'). If a single slice is analyzed,
            this is also the label value of the object in `outputLabelImage`.
        id_ColIndex : int
            Column index of the object ID ('id'), i.e. the label value of the object in
            `outputLabelImage`. Equals -1 if only a single slice is analyzed.
        Raises:
        -------
        ValueError
            If `type` is not one of "binary", "label" or "class".
            If `type` is "binary" and the maximum value of `labelImage` is greater than 1.
            If `type` is "class" and `classID` is None.
        Notes:
        ------
        - For "class", the image is first binarized with `binarizeLabelImageForClassID`.
        - The analysis module's interpretation port is set to 1. This is presumably the
          per-slice (2D) interpretation; the 3D variant uses 0.
    """

    if type not in ("binary", "label", "class"):
        raise ValueError("type must be 'binary' or 'label' or 'class'")

    if type == "binary":
        if labelImage.range[1] > 1:
            raise ValueError("binary image must have a range of 0-1")
        labelImage_forLA = labelImage
    elif type == "label":
        labelImage_forLA = labelImage
    else:  # type == "class"
        if classID is None:
            raise ValueError("classID must be specified for class type")
        minValue, maxValue = labelImage.range[0], labelImage.range[1]
        if classID < minValue or classID > maxValue:
            # Deliberately only a warning: the binarized image will be empty and the
            # spreadsheet will have no rows, which downstream code tolerates.
            print(f"Warning: classID ({classID}) should be in the range of the label image [{minValue} - {maxValue}]")
        # binarize the label image for the specified class ID
        labelImage_forLA = binarizeLabelImageForClassID(labelImage, classID)

    labelAnalysisModule = hx_object_factory.create('HxAnalyzeLabels')
    labelAnalysisModule.ports.data.connect(labelImage_forLA)
    labelAnalysisModule.ports.interpretation.selected = 1
    labelAnalysisModule._tcl_interp(""" measures setState {"basic2D_Location" Area BaryCenterX BaryCenterY BoundingBoxOx BoundingBoxOy BoundingBoxDx BoundingBoxDy}""")
    labelAnalysisModule.execute()
    spreadSheet = labelAnalysisModule.results[0]

    # For "binary"/"class" the module had to compute the labels itself, so its label output
    # is what the spreadsheet rows refer to. For "label" the input already is that image.
    outputLabelImage = labelImage if type == "label" else labelAnalysisModule.results[1]

    # indices of the bounding box columns
    BBox_Ox_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "BoundingBoxOx")
    BBox_Oy_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "BoundingBoxOy")
    BBox_Dx_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "BoundingBoxDx")
    BBox_Dy_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "BoundingBoxDy")

    # indices of the columns 'indeximage', 'index' and 'id'
    indexImage_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "indeximage")
    index_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "index")
    id_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "id")

    return outputLabelImage, spreadSheet, BBox_Ox_ColIndex, BBox_Oy_ColIndex, BBox_Dx_ColIndex, BBox_Dy_ColIndex, indexImage_ColIndex, index_ColIndex, id_ColIndex


# Note: this could be made more generic, to also propose selection by class name (looking into the material bundle)
def binarizeLabelImageForClassID(labelImage, classID):
    """
    Build a mask of the voxels of a label image that equal a given class ID.

    An Avizo `HxArithmetic` module is created, `labelImage` is connected to its
    `inputA` port, and the expression ``A==<classID>`` is evaluated.

    Parameters:
        labelImage (HxImage): Image connected to the arithmetic module's `inputA` port.
        classID (int): Value to select. It is formatted directly into the expression text.

    Returns:
        HxImage: The first result of the arithmetic module (the binarized image).
    """
    module = hx_object_factory.create('HxArithmetic')
    module.ports.inputA.connect(labelImage)
    # fire() is called after connecting the input and before editing the expression,
    # presumably so the module refreshes its ports for the new input.
    module.fire()

    expressionPort = module.ports.expr0
    expressionPort.text = f"A=={classID}"

    module.execute()
    binarized = module.results[0]
    return binarized


def labelAnalysisFor3DObjectLocation(labelImage, type = "binary", classID = None):
    """
        Run a 3D label analysis on `labelImage` and return the analyzed label image,
        the result spreadsheet and the indices of its location-related columns.
        Parameters:
        -----------
        labelImage : object
            Image to analyze. It is connected to the data port of an `HxAnalyzeLabels` module.
        type : str, optional
            How `labelImage` is interpreted. Must be one of:
            - "binary": a binary image whose maximum value must not exceed 1.
            - "label": a labeled image where each object already has a unique label.
            - "class": an image of class IDs. Only the voxels equal to `classID` are analyzed.
            Default is "binary".
        classID : int, optional
            Class ID to analyze. Required when `type` is "class", ignored otherwise.
            If it lies outside the value range of `labelImage`, a warning is printed and
            the analysis still runs. The mask is then empty and the spreadsheet is
            expected to have no rows.
        Returns:
        --------
        outputLabelImage : object
            For "binary" and "class": the label image produced by the analysis module
            (its second result). For "label": the input `labelImage` itself.
        spreadSheet : object
            First result of the analysis module. It holds the measures Area, BaryCenterX/Y/Z,
            BoundingBoxOx/Oy/Oz and BoundingBoxDx/Dy/Dz.
        BBox_Ox_ColIndex : int
            Column index of the X-coordinate of the bounding box origin, in physical coordinates.
        BBox_Oy_ColIndex : int
            Column index of the Y-coordinate of the bounding box origin, in physical coordinates.
        BBox_Oz_ColIndex : int
            Column index of the Z-coordinate of the bounding box origin, in physical coordinates.
        BBox_Dx_ColIndex : int
            Column index of the X-size of the bounding box, in physical coordinates.
        BBox_Dy_ColIndex : int
            Column index of the Y-size of the bounding box, in physical coordinates.
        BBox_Dz_ColIndex : int
            Column index of the Z-size of the bounding box, in physical coordinates.
        index_ColIndex : int
            Column index of the object index ('index').
        Raises:
        -------
        ValueError
            If `type` is not one of "binary", "label" or "class".
            If `type` is "binary" and the maximum value of `labelImage` is greater than 1.
            If `type` is "class" and `classID` is None.
        Notes:
        ------
        - For "class", the image is first binarized with `binarizeLabelImageForClassID`.
        - The analysis uses the "basic3D_Location" measure group. The interpretation port is
          set to 0, presumably the 3D interpretation; the 2D variant uses 1.
    """

    if type not in ("binary", "label", "class"):
        raise ValueError("type must be 'binary' or 'label' or 'class'")

    if type == "binary":
        if labelImage.range[1] > 1:
            raise ValueError("binary image must have a range of 0-1")
        labelImage_forLA = labelImage
    elif type == "label":
        labelImage_forLA = labelImage
    else:  # type == "class"
        if classID is None:
            raise ValueError("classID must be specified for class type")
        minValue, maxValue = labelImage.range[0], labelImage.range[1]
        if classID < minValue or classID > maxValue:
            # Deliberately only a warning: the binarized image will be empty and the
            # spreadsheet will have no rows, which downstream code tolerates.
            print(f"Warning: classID ({classID}) should be in the range of the label image [{minValue} - {maxValue}]")
        # binarize the label image for the specified class ID
        labelImage_forLA = binarizeLabelImageForClassID(labelImage, classID)

    labelAnalysisModule = hx_object_factory.create('HxAnalyzeLabels')
    labelAnalysisModule.ports.data.connect(labelImage_forLA)
    labelAnalysisModule.ports.interpretation.selected = 0
    labelAnalysisModule._tcl_interp(""" measures setState {"basic3D_Location" Area BaryCenterX BaryCenterY BaryCenterZ BoundingBoxOx BoundingBoxOy BoundingBoxOz BoundingBoxDx BoundingBoxDy BoundingBoxDz}""")
    labelAnalysisModule.execute()
    spreadSheet = labelAnalysisModule.results[0]

    # For "binary"/"class" the module had to compute the labels itself, so its label output
    # is what the spreadsheet rows refer to. For "label" the input already is that image.
    outputLabelImage = labelImage if type == "label" else labelAnalysisModule.results[1]

    # indices of the bounding box columns
    BBox_Ox_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "BoundingBoxOx")
    BBox_Oy_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "BoundingBoxOy")
    BBox_Oz_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "BoundingBoxOz")
    BBox_Dx_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "BoundingBoxDx")
    BBox_Dy_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "BoundingBoxDy")
    BBox_Dz_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "BoundingBoxDz")

    # index of the column 'index'
    index_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadSheet, "index")

    return outputLabelImage, spreadSheet, BBox_Ox_ColIndex, BBox_Oy_ColIndex, BBox_Oz_ColIndex, BBox_Dx_ColIndex, BBox_Dy_ColIndex, BBox_Dz_ColIndex, index_ColIndex




def getInstanceIDandSlice(spreadsheet, label_index, indexImage_ColIndex = None, id_ColIndex = None, index_ColIndex = None, table_index=0):
    """
    Get the slice number and label ID of one row of a spreadsheet produced by a 2D label analysis.

    Parameters:
        spreadsheet: Spreadsheet object exposing `all_interfaces.HxSpreadSheetInterface`.
        label_index (int): 0-based row index in table `table_index`.
        indexImage_ColIndex, id_ColIndex, index_ColIndex (int, optional): Column indices of the
            'indeximage', 'id' and 'index' columns. Any index left as None is looked up by name
            in the spreadsheet. A value of -1 means the column is absent.
        table_index (int, optional): Table of the spreadsheet to read from. Default is 0.

    Returns:
        tuple (num_slice, label_id):
            num_slice (int): 0-based slice index ('indeximage' value minus 1), or 0 when there is
                no 'indeximage' column.
            label_id (int): Value of the 'id' column, or of the 'index' column when there is no
                'id' column.

    Raises:
        ValueError: If `label_index` is not a valid row index of the selected table.
    """
    table = spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[table_index]

    # label_index is used as a 0-based row index, so valid values are 0 .. n_rows - 1.
    # The check is done on the table that is actually read.
    if label_index < 0 or label_index >= len(table.rows):
        raise ValueError("label_index is out of range")

    # Look up only the column indices the caller did not supply.
    if indexImage_ColIndex is None:
        indexImage_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadsheet, "indeximage")
    if index_ColIndex is None:
        index_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadsheet, "index")
    if id_ColIndex is None:
        id_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadsheet, "id")

    # 'indeximage' starts at 1. Without it (single slice analyzed) the slice is 0.
    if indexImage_ColIndex == -1:
        num_slice = 0
    else:
        num_slice = int(table.items[label_index, indexImage_ColIndex]) - 1

    # Without an 'id' column (single slice analyzed) the 'index' column holds the label value.
    labelColumn = index_ColIndex if id_ColIndex == -1 else id_ColIndex
    label_id = int(table.items[label_index, labelColumn])

    return num_slice, label_id




def getPaddedPatchAroundInstance(labelImage, grayImage, spreadsheet, label_index, 
                                 minPatchSize, marginAroundInstance, paddingValue=0, getFadedGrayPatch = False,
                                 num_slice=None, label_id=None, 
                                 BBox_Ox_ColIndex=None, BBox_Oy_ColIndex=None, BBox_Dx_ColIndex=None, BBox_Dy_ColIndex=None,
                                 ):
    """
    Extract a square, padded 2D patch centered on one labeled instance of one slice.

    Parameters:
        labelImage (HxUniformScalarField3): Labeled image containing the instance.
        grayImage (HxUniformScalarField3): Grayscale image corresponding to `labelImage`. It is
            assumed to have the same dimensions, because the crop and padding are computed
            from `labelImage`'s array shape.
        spreadsheet: Spreadsheet derived from `labelImage` by a 2D label analysis. Its first table
            must contain the BoundingBoxOx/Oy/Dx/Dy columns (physical coordinates).
        label_index (int): 0-based row of the instance in the spreadsheet's first table.
        minPatchSize (int): Minimum side length of the square patch, in voxels. The patch is larger
            when the instance plus its margins does not fit.
        marginAroundInstance (int): Margin in voxels added on every side of the instance bounding box.
            It also sets the fading blur (Gaussian sigma = marginAroundInstance / 2).
        paddingValue (int, optional): Gray value used outside the image and as the fade-out target. Default is 0.
        getFadedGrayPatch (bool, optional): If True, also return a gray patch faded towards `paddingValue`
            outside the instance mask. Default is False.
        num_slice, label_id (int, optional): Slice index and label value of the instance. They are
            looked up with `getInstanceIDandSlice` when either is None.
        BBox_Ox_ColIndex, BBox_Oy_ColIndex, BBox_Dx_ColIndex, BBox_Dy_ColIndex (int, optional):
            Bounding-box column indices. Any index left as None is looked up by name.

    Returns:
        tuple:
            - patchBBox: ((minx, miny, minz), (maxx, maxy, maxz)) of the patch in physical coordinates
              (image origin + voxel index * voxel size). The max corner is one voxel past the last
              voxel of the patch.
              NOTE(review): Avizo lattices usually span voxel centers; confirm this
              convention matches what callers expect.
            - cropped_gray: Gray patch of shape (box_size, box_size, 1) with the gray image's dtype.
            - cropped_mask: float64 array of the same shape. It is 1.0 where the label equals
              `label_id` and 0.0 elsewhere, including padding.
            - fadedGrayArray (only when `getFadedGrayPatch` is True): the gray patch blended with
              `paddingValue` through a Gaussian-blurred mask, cast to the gray image's dtype.
    """
    if num_slice is None or label_id is None:
        num_slice, label_id = getInstanceIDandSlice(spreadsheet, label_index)

    # Look up only the bounding-box column indices the caller did not supply.
    if BBox_Ox_ColIndex is None:
        BBox_Ox_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadsheet, "BoundingBoxOx")
    if BBox_Oy_ColIndex is None:
        BBox_Oy_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadsheet, "BoundingBoxOy")
    if BBox_Dx_ColIndex is None:
        BBox_Dx_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadsheet, "BoundingBoxDx")
    if BBox_Dy_ColIndex is None:
        BBox_Dy_ColIndex = pmh.get_spreadsheet_col_index_by_name(spreadsheet, "BoundingBoxDy")

    label_array = labelImage.get_array()
    gray_array = grayImage.get_array()

    voxelSize = pmh.getVoxelSizeFromImage(labelImage)
    origin = labelImage.bounding_box[0]
    table = spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[0]

    # Convert the instance bounding box from physical to voxel coordinates.
    BBox_Ox_index = int((table.items[label_index, BBox_Ox_ColIndex] - origin[0]) / voxelSize[0])
    BBox_Oy_index = int((table.items[label_index, BBox_Oy_ColIndex] - origin[1]) / voxelSize[1])
    BBox_Dx = int(table.items[label_index, BBox_Dx_ColIndex] / voxelSize[0])
    BBox_Dy = int(table.items[label_index, BBox_Dy_ColIndex] / voxelSize[1])

    # center of the instance bounding box, in voxels
    centerX = BBox_Ox_index + BBox_Dx // 2
    centerY = BBox_Oy_index + BBox_Dy // 2

    # Side of the square patch: large enough for the instance plus margins, and at least minPatchSize.
    box_size = max(BBox_Dx + 2 * marginAroundInstance, BBox_Dy + 2 * marginAroundInstance, minPatchSize)

    # Patch extent in voxels, half-open [min, max). Deriving max as min + box_size (rather than
    # center + box_size // 2) keeps the patch exactly box_size wide when box_size is odd.
    patch_minx = centerX - box_size // 2
    patch_miny = centerY - box_size // 2
    patch_maxx = patch_minx + box_size
    patch_maxy = patch_miny + box_size

    # The same extent in physical coordinates. This inverts the conversion above, so the
    # image origin is added on every axis.
    patchBBox = ((origin[0] + patch_minx * voxelSize[0], origin[1] + patch_miny * voxelSize[1], origin[2] + num_slice * voxelSize[2]),
                 (origin[0] + patch_maxx * voxelSize[0], origin[1] + patch_maxy * voxelSize[1], origin[2] + (num_slice + 1) * voxelSize[2]))

    # Part of the patch lying inside the image, plus the padding needed where the patch
    # crosses the image border.
    nx, ny = label_array.shape[0], label_array.shape[1]
    cropX = slice(max(0, patch_minx), min(nx, patch_maxx))
    cropY = slice(max(0, patch_miny), min(ny, patch_maxy))
    padWidths = ((max(0, -patch_minx), max(0, patch_maxx - nx)),
                 (max(0, -patch_miny), max(0, patch_maxy - ny)),
                 (0, 0))
    needsPadding = any(width > 0 for pair in padWidths for width in pair)

    # Keep a singleton third axis (one slice) so that the patches stay 3D.
    zSlice = slice(num_slice, num_slice + 1)
    cropped_gray = gray_array[cropX, cropY, zSlice]
    # The mask is always computed and stored as float64, so it can be blurred for fading.
    cropped_mask = (label_array[cropX, cropY, zSlice] == label_id).astype(np.float64)
    if needsPadding:
        cropped_gray = np.pad(cropped_gray, padWidths, constant_values=paddingValue)
        cropped_mask = np.pad(cropped_mask, padWidths, constant_values=0)

    if not getFadedGrayPatch:
        return patchBBox, cropped_gray, cropped_mask

    # Blur the mask in x and y only. A non-zero sigma on the singleton slice axis would mix in
    # the zero-valued constant border above and below the slice and attenuate the whole profile.
    sigma = marginAroundInstance / 2.0
    gaussProfileArray = sn.gaussian_filter(cropped_mask, sigma=(sigma, sigma, 0.0), mode='constant', cval=0.0)

    # blend: gray value inside the instance, fading to paddingValue away from it
    fadedGrayArray = cropped_gray * gaussProfileArray + paddingValue * (1.0 - gaussProfileArray)
    return patchBBox, cropped_gray, cropped_mask, fadedGrayArray.astype(gray_array.dtype)
