# -*- coding: utf-8 -*-

"""
Helpers for Avizo modules and interfaces
"""

import sys
import os
import re
import struct
import numpy as np
import _hx_core 
from _hx_core import *
from _hx_core import _tcl_interp, hx_project, hx_object_factory

def set_module_result(module, result_data, name, discard_previous=False, result_slot=0):
    """Store result_data in slot result_slot of module.results and name it.

    Args:
        module: object exposing a `results` sequence and a `fire()` method.
        result_data: data object to store, or None.
        name: name assigned to result_data (when it is not None).
        discard_previous: when True and the slot already holds a non-None
            result, that result is replaced: its downstream connections are
            re-connected to result_data, it is removed from the project, and
            result_data takes over its icon position.
        result_slot: index of the result slot to fill.

    When the replacement branch is not taken and result_data is None, only
    the final refresh of module.results is performed.
    """
    # Replace previous result if discard option is on
    if len(module.results) > result_slot and module.results[result_slot] is not None and discard_previous:
        previous_result = module.results[result_slot]
        icon_position = previous_result.icon_position
        previous_name = previous_result.name  # NOTE(review): read but never used below
        previous_result.name = "tmp"  # avoid name conflict
        previous_result.fire()
        # Move previous connections
        # The loop ends once downstream_connections is empty; presumably
        # connect() detaches the connection from previous_result - TODO confirm
        # (in particular for result_data being None).
        while len(previous_result.downstream_connections) > 0:
            connection = previous_result.downstream_connections[0]
            connection.connect(result_data)
        _hx_core.hx_project.remove(previous_result)
        module.results[result_slot] = result_data
        if result_data is not None:
            result_data.name = name
            result_data.icon_position = icon_position # keep previous position
            result_data.touch(_hx_core.HxData.NEW_DATA)
            result_data.fire()
    elif result_data is not None:  # prepare adding new result unless None
        # Grow module.results up to and including result_slot; no iteration
        # happens when the slot already exists, in which case the slot content
        # is simply overwritten below.
        for slot in range(len(module.results), result_slot + 1):  # prepare missing slots as needed
            module.results.insert(slot, result_data)  # dummy insert requires a data object
            module.results[slot] = None  # set unused slots to None
        module.results[result_slot] = result_data
        result_data.name = name
    # refresh the module results ... seems necessary to ensure saved projects reload properly
    # The results are copied aside, cleared, the module is fired, and the
    # copied list is assigned back.
    resultList = []
    for res in module.results:
        resultList.append(res)
    module.results = []
    module.fire()
    module.results = resultList    

def composeName(basename, suffix):
    """Build a name by replacing the extension of basename with suffix.

    Everything after the last '.' of basename is replaced by suffix. When
    basename contains no '.', suffix is appended after a '.'.
    If either argument is None, the other one is returned unchanged.
    """
    if basename is None:
        return suffix
    if suffix is None:
        return basename
    stem, dot, _ = basename.rpartition(".")
    if not dot:
        # no '.' found: the whole base name is kept
        stem = basename
    return stem + "." + suffix

def placeResult(computeModule, resultModule, targetPort = 0):
    """(Re)place resultModule as a result of computeModule at index targetPort.

    If targetPort is an existing slot of computeModule.results, the slot is
    overwritten; otherwise resultModule is inserted at that index.
    When the overwritten slot held a dataset, resultModule inherits its name
    and icon position, its downstream connections are re-connected to
    resultModule, and the old dataset is removed from the project.
    Finally resultModule is touched (flag value 2) and fired.
    """
    replaced = None

    if targetPort >= len(computeModule.results):
        # no such slot yet: insert the result at the requested index
        computeModule.results.insert(targetPort, resultModule)
    else:
        replaced = computeModule.results[targetPort]
        computeModule.results[targetPort] = resultModule

    if replaced is not None:
        keptName = replaced.name
        replaced.name = "tmp"
        keptPosition = replaced.icon_position

        # hand every downstream connection of the old dataset over to the new one
        while len(replaced.downstream_connections) > 0:
            replaced.downstream_connections[0].connect(resultModule)

        hx_project.remove(replaced)

        resultModule.icon_position = keptPosition
        resultModule.name = keptName

    resultModule.touch(2)
    resultModule.fire()


def numpyArrayFromAvizoSpreadsheet( spreadsheet, n_table = 0, nbColsDiscarded = 0 ):
    """Copy one table of a spreadsheet into a 2D numpy array (rows x columns).

    The table is read through spreadsheet.all_interfaces.HxSpreadSheetInterface.
    The last nbColsDiscarded columns are skipped, e.g. trailing columns that
    hold string data (such as material names), which could not be stored in
    the numeric output array.
    The number of rows is taken from the first column of the table.
    TODO: a nicer implementation would take the names of the columns to discard.
    """
    columns = spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[n_table].columns
    rowCount = columns[0].asarray().shape[0]
    keptCount = len(columns) - nbColsDiscarded

    values = np.empty([rowCount, keptCount])
    for colIndex in range(keptCount):
        values[:, colIndex] = columns[colIndex].asarray()
    return values


def getVoxelSizeFromImage(image):
    """Return the voxel size of image as a numpy array [x, y, z].

    The size is queried with the Tcl command 'getVoxelSize' of the image and
    the first three whitespace-separated tokens of the reply are converted to
    float (the original author assumes an HxUniformScalarField).
    """
    tokens = image._tcl_interp(" getVoxelSize").split()
    return np.array([float(tokens[axis]) for axis in range(3)])



def ArrayFromHxLandmark(hxLandmarks):
    """Return the points of an HxLandmarks data object as an (N, 3) numpy array.

    The number of points is read with the Tcl command 'getNumPoints', then
    each point is fetched with 'getPoint <index>' and its whitespace-separated
    tokens are stored in the corresponding row.
    """
    pointCount = int(hxLandmarks._tcl_interp(" getNumPoints"))
    coords = np.zeros((pointCount, 3))
    for index, row in enumerate(coords):
        reply = hxLandmarks._tcl_interp(" getPoint " + str(index))
        # row is a view on coords: the tokens are cast to float on assignment
        row[:] = np.array(reply.split())
    return coords
    
    
def HxLandmarkFromArray(arrayLdks):
    """Create an 'HxLandmarkSet' object filled with the points of arrayLdks.

    arrayLdks must be a 2D array with 3 columns (one 3D coordinate per row);
    otherwise an error is reported through hx_message and None is returned.
    Each row is appended with the Tcl command 'appendLandmark x y z'.
    """
    shape = arrayLdks.shape
    if len(shape) != 2 or shape[1] != 3:
        hx_message.error("error in HxLandmarkFromArray: expecting a 2D array, containing a list of 3D coordinates")
        return None

    landmarkSet = hx_object_factory.create('HxLandmarkSet')
    for row in range(shape[0]):
        coords = " ".join(str(arrayLdks[row, axis]) for axis in range(3))
        landmarkSet._tcl_interp(" appendLandmark " + coords)

    return landmarkSet
    


def HxClusterFromArray(coordsArray, dataArray = None, listDataNames = None):
    """Create an 'HxCluster' object from point coordinates and optional data.

    coordsArray must be a 2D array with 3 columns (one 3D point per row).
    dataArray, when given, must be a 2D array with one row per point; one data
    column is created per column of dataArray. listDataNames, when given, must
    hold one name per data column.
    On any inconsistency an error is reported through hx_message and None is
    returned.
    TODO: only float data values are handled here; the original author notes
    that HxCluster also allows Label data, which could be worth leveraging.
    """
    coordsShape = coordsArray.shape
    if len(coordsShape) != 2 or coordsShape[1] != 3:
        hx_message.error("error in HxClusterFromArray: expecting a 2D array, containing a list of 3D coordinates")
        return None
    pointCount = coordsShape[0]

    columnCount = 0
    if dataArray is not None:
        if len(dataArray.shape) != 2:
            hx_message.error("error in HxClusterFromArray: expecting a 2D array, containing a list of data")
            return None
        if dataArray.shape[0] != pointCount:
            hx_message.error("error in HxClusterFromArray: expecting the same number of rows in coordsArray and dataArray")
            return None
        columnCount = dataArray.shape[1]

    if listDataNames is not None and columnCount != len(listDataNames):
        hx_message.error("error in HxClusterFromArray: expecting the same number of columns in dataArray and listDataNames")
        return None

    cluster = hx_object_factory.create('HxCluster')

    # one point per row of coordsArray
    for row in range(pointCount):
        xyz = " ".join(str(coordsArray[row, axis]) for axis in range(3))
        cluster._tcl_interp(" addPoint " + xyz)

    # data columns, filled column by column
    cluster._tcl_interp(" setNumDataColumns " + str(columnCount))
    for col in range(columnCount):
        for row in range(pointCount):
            arguments = " ".join(str(item) for item in (col, row, dataArray[row, col]))
            cluster._tcl_interp(" setDataValue  " + arguments)
        if listDataNames is not None:
            cluster._tcl_interp(" setDataColumnName " + str(col) + " " + listDataNames[col])

    return cluster



def HxSpreadsheetFromArray(dataArray, listDataNames = None):
    """Create an 'HxSpreadSheet' object filled with the values of dataArray.

    All columns are created with the 'float' type.

    Args:
        dataArray: 1D or 2D numpy array, or None. A 2D array gives one
            spreadsheet row per array row; a 1D array gives a single row with
            one column per element. None gives a spreadsheet without columns
            and with 0 rows.
        listDataNames: optional column names, one per column; columns are
            named 'data_<index>' when omitted.

    Returns:
        The new spreadsheet object, or None (after reporting through
        hx_message) when dataArray is neither 1D nor 2D or when the number of
        names does not match the number of columns.
    """
    numRows = 0
    numDataItems = 0
    values = None

    if dataArray is not None:
        numDims = len(dataArray.shape)
        if numDims not in (1, 2):
            hx_message.error("error in HxSpreadsheetFromArray: expecting a 1 or 2D array")
            return None
        # view a linear array as a single row so one fill loop serves both cases
        values = dataArray if numDims == 2 else dataArray.reshape(1, -1)
        numRows, numDataItems = values.shape

    if listDataNames is not None:
        if numDataItems != len(listDataNames):
            hx_message.error("error in HxSpreadsheetFromArray: expecting the same number of columns in dataArray and listDataNames")
            return None

    # now, create the spreadsheet data object and populate it with available information
    spreadsheet = hx_object_factory.create('HxSpreadSheet')
    for m in range(numDataItems):
        columnName = "data_" + str(m) if listDataNames is None else listDataNames[m]
        spreadsheet._tcl_interp(" addColumn " + columnName + " float")
    spreadsheet._tcl_interp(" setNumRows " + str(numRows))

    # Fill the values (nothing to do when dataArray is None: numRows is 0)
    for n in range(numRows):
        for m in range(numDataItems):
            spreadsheet._tcl_interp(' setValue ' + str(m) + ' ' + str(n) + ' ' + str(values[n, m]))

    return spreadsheet


def append_spreadsheets_by_rows(s1, s2):
    """Append the rows of spreadsheet s2 to spreadsheet s1 and return s1.

    Both spreadsheets must hold exactly one table, and the two tables must
    have the same number of columns with identical names in the same order.
    Otherwise an error is reported through hx_message and None is returned.
    s1 is modified in place.
    """
    if s1 is None or s2 is None:
        hx_message.error("Cannot append spreadsheets: one of them is None")
        return None

    if not (len(s1.all_interfaces.HxSpreadSheetInterface.tables) == 1 and len(s2.all_interfaces.HxSpreadSheetInterface.tables) == 1):
        hx_message.error("Cannot append spreadsheets: both should have exactly one table")
        return None

    target = s1.all_interfaces.HxSpreadSheetInterface.tables[0]
    source = s2.all_interfaces.HxSpreadSheetInterface.tables[0]

    if len(target.columns) != len(source.columns):
        hx_message.error("Cannot append spreadsheets: they have different number of columns")
        return None

    # the columns must carry the same names, position by position
    for index in range(len(target.columns)):
        if target.columns[index].name != source.columns[index].name:
            hx_message.error(f"Cannot append spreadsheets: column {index} names do not match: '{target.columns[index].name}' vs '{source.columns[index].name}'")
            return None

    # grow the first table so that it can receive the rows of the second one
    firstNewRow = len(target.rows)
    addedRows = len(source.rows)
    s1._tcl_interp(f" setNumRows 0 {firstNewRow + addedRows}")
    # the table is fetched again after it has been resized
    target = s1.all_interfaces.HxSpreadSheetInterface.tables[0]

    # copy the cells of the second table below the existing rows
    for srcRow in range(len(source.rows)):
        for col in range(len(source.columns)):
            target.items[firstNewRow + srcRow, col] = source.items[srcRow, col]

    return s1




def get_spreadsheet_col_index_by_name(spreadsheet, col_name, tab_id = 0):
    """Return the index of the first column named col_name, or -1 if none.

    The columns are looked up in table tab_id of the spreadsheet, through
    spreadsheet.all_interfaces.HxSpreadSheetInterface.
    """
    columns = spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[tab_id].columns
    for index in range(len(columns)):
        if columns[index].name == col_name:
            return index
    return -1


def spreadsheet_add_table(spreadsheet, table_name):
    """Add a table named table_name to the spreadsheet and return its index.

    The table is created with the Tcl command 'addTable', the spreadsheet is
    fired, and the returned index is the position of the last table.
    """
    command = f' addTable "{table_name}"'
    spreadsheet._tcl_interp(command)
    spreadsheet.fire()
    tableCount = len(spreadsheet.all_interfaces.HxSpreadSheetInterface.tables)
    return tableCount - 1

def spreadsheet_add_columns(spreadsheet, column_name, column_type, table_id = 0):
    """Append one column to table table_id of the spreadsheet and return its index.

    The column is built as HxSpreadSheetInterface.Column(name=column_name,
    typename=column_type), the spreadsheet is fired, and the returned index is
    the position of the last column of the table.
    """
    interface = spreadsheet.all_interfaces.HxSpreadSheetInterface
    interface.tables[table_id].columns.append(
        HxSpreadSheetInterface.Column(name=column_name, typename=column_type)
    )
    spreadsheet.fire()
    # the column list is read again after the spreadsheet has been fired
    columnCount = len(spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[table_id].columns)
    return columnCount - 1


def spreadsheet_copy_column(destinationSpreadsheet, sourceSpreadsheet, sourceColumnName, destinationColumnName, destinationTableID = 0, sourceTableId = 0):
    """Copy a source column into a new column of the destination spreadsheet.

    The new column is appended to table destinationTableID with the name
    destinationColumnName and the type name of the source column.
    Returns the index of the new column, or -1 (after reporting through
    hx_message) when the two tables have different numbers of rows, when the
    source column does not exist, or when the destination column name is
    already taken.
    """
    srcInterface = sourceSpreadsheet.all_interfaces.HxSpreadSheetInterface
    dstInterface = destinationSpreadsheet.all_interfaces.HxSpreadSheetInterface

    # both tables must have the same number of rows
    srcRowCount = len(srcInterface.tables[sourceTableId].rows)
    dstRowCount = len(dstInterface.tables[destinationTableID].rows)
    if dstRowCount != srcRowCount:
        hx_message.error(f"Cannot copy column: source spreadsheet has {srcRowCount} rows, but destination spreadsheet has {dstRowCount} rows.")
        return -1

    # the source column must exist
    srcColumn = get_spreadsheet_col_index_by_name(sourceSpreadsheet, sourceColumnName, sourceTableId)
    if srcColumn == -1:
        hx_message.error(f"Cannot copy column: source column '{sourceColumnName}' not found in table {sourceTableId} of the source spreadsheet.")
        return -1

    # the destination column must not exist yet
    if get_spreadsheet_col_index_by_name(destinationSpreadsheet, destinationColumnName, destinationTableID) != -1:
        hx_message.error(f"Cannot copy column: destination column '{destinationColumnName}' already exists in table {destinationTableID} of the destination spreadsheet.")
        return -1

    # create the destination column with the type of the source column
    columnType = srcInterface.tables[sourceTableId].columns[srcColumn].typename
    newColumn = HxSpreadSheetInterface.Column(name=destinationColumnName, typename=columnType)
    dstInterface.tables[destinationTableID].columns.append(newColumn)
    destinationSpreadsheet.fire()
    dstColumn = len(dstInterface.tables[destinationTableID].columns) - 1

    # copy the values row by row
    for row in range(srcRowCount):
        value = srcInterface.tables[sourceTableId].items[row, srcColumn]
        dstInterface.tables[destinationTableID].items[row, dstColumn] = value
    return dstColumn


# some python wrappers for the TCL commands available in Avizo around parameter bundles...
def data_has_parameter_bundle(data, bundle_name, list_sub_bundles = []):
    """Tell whether data has a parameter bundle named bundle_name.

    list_sub_bundles is the path of nested bundle names in which the bundle
    is looked up. Returns False when data is None; otherwise True when the
    Tcl query 'hasBundle' replies '1'.
    """
    if data is None:
        return False
    # TODO: the names of the sub bundles are not checked
    path = "".join(f" {b}" for b in list_sub_bundles)
    reply = data._tcl_interp(f' parameters {path} hasBundle {bundle_name}')
    return reply == '1'

def data_list_parameters_bundles(data, list_sub_bundles = []):
    """Return the names of the parameter bundles found in the given sub-bundle.

    list_sub_bundles is the path of nested bundle names to list. The reply of
    the Tcl 'list' query is split on single spaces, and a name is kept when
    its 'isBundle' query replies '1'. Returns an empty list when data is None.
    """
    if data is None:
        return []
    path = "".join(f" {b}" for b in list_sub_bundles)
    names = data._tcl_interp(f' parameters {path} list').split(' ')
    return [p for p in names if data._tcl_interp(f' parameters {path} {p} isBundle ') == '1']
    
def get_materials_from_parameter_bundle(data):
    """Return the entries listed in the 'Materials' parameter bundle of data.

    Returns an empty list when data is None or when the 'hasBundle Materials'
    query replies '0'. Otherwise the reply of 'parameters Materials list' is
    split on single spaces.
    """
    if data is None:
        return []
    hasMaterials = data._tcl_interp(' parameters hasBundle Materials')
    if hasMaterials == '0':
        return []
    listing = data._tcl_interp(' parameters Materials list')
    return listing.split(' ')


def get_material_color(labelField, MaterialName):
    """Return the color stored for material MaterialName of labelField.

    The value is the reply of the Tcl query
    'parameters Materials <MaterialName> Color getValue', split on single
    spaces, i.e. a list of strings (presumably the three color components in
    the [0,1] range - TODO confirm).
    Returns None when labelField is None, when it has no 'Materials' bundle,
    or when the material does not exist; the last two cases also emit a
    warning through hx_message.
    """
    if labelField is None:
        return None

    query = labelField._tcl_interp
    if query(' parameters hasBundle Materials') == '0':
        hx_message.warning(f"no Material bundle is defined in labelField {labelField.name}.")
        return None

    if query(f' parameters Materials hasBundle {MaterialName}') == '0':
        hx_message.warning(f"Material '{MaterialName}' does not exist in labelField {labelField.name}.")
        return None

    colorReply = query(f' parameters Materials {MaterialName} Color getValue')
    return colorReply.split(' ')



def add_material_to_parameter_bundle(labelField, MaterialName, MaterialColor = None):
    """Add a material to the 'Materials' parameter bundle of labelField.

    Args:
        labelField: data object exposing `_tcl_interp`; nothing is done when
            it is None.
        MaterialName: name of the material bundle to create or update.
        MaterialColor: indexable with 3 color components, expected in the
            [0,1] range. When omitted (None), a new random color is drawn for
            each call.

    The 'Materials' bundle is created when missing. A new material also gets
    'Display2D true' and 'Lock false'; for an existing material only the color
    is changed.
    """
    if labelField is None:
        return
    if MaterialColor is None:
        # drawn per call: a default evaluated in the signature would be
        # computed once at definition time and shared by every call
        MaterialColor = np.random.rand(3)

    if labelField._tcl_interp(' parameters hasBundle Materials') == '0':
        print(f"Adding material '{MaterialName}' to the label field, but no Materials bundle exists. Creating a new one.")
        labelField._tcl_interp(' parameters newBundle Materials')

    setColor = f' parameters Materials {MaterialName} setValue Color {MaterialColor[0]} {MaterialColor[1]} {MaterialColor[2]}'

    if labelField._tcl_interp(f' parameters Materials hasBundle {MaterialName}') == '0':
        labelField._tcl_interp(f' parameters Materials newBundle {MaterialName}')
        labelField._tcl_interp(setColor)

        # also add the usual entries of a material bundle
        labelField._tcl_interp(f' parameters Materials {MaterialName} setValue Display2D true')
        labelField._tcl_interp(f' parameters Materials {MaterialName} setValue Lock false')
    else:
        print(f"Request to add material '{MaterialName}' to the label field, but this Material already exists: just change its color.")
        labelField._tcl_interp(setColor)



def data_list_parameters(data, list_sub_bundles = []):
    """Return the names of the plain parameters found in the given sub-bundle.

    Plain parameters are the entries that are not themselves bundles, i.e.
    those whose Tcl 'isBundle' query replies '0'.

    Args:
        data: data object exposing `_tcl_interp`; an empty list is returned
            when it is None.
        list_sub_bundles: path of nested bundle names to list.

    Returns:
        List of parameter names (strings).
    """
    if data is None:
        return []
    path = "".join(f" {b}" for b in list_sub_bundles)
    # split() without argument drops the empty tokens of an empty or padded reply
    names = data._tcl_interp(f' parameters {path} list').split()
    # the interpreter replies with strings, hence the comparison with '0'
    return [p for p in names if data._tcl_interp(f' parameters {path} {p} isBundle ') == '0']


def data_get_parameter_value(data, parameter_name, list_sub_bundles = []):
    """Return the value of a parameter of data, as replied by the Tcl query.

    list_sub_bundles is the path of nested bundle names holding the
    parameter. Returns None when data is None.
    """
    if data is None:
        return None
    path = "".join(f" {b}" for b in list_sub_bundles)
    command = f' parameters {path} {parameter_name} getValue '
    return data._tcl_interp(command)
        

def data_new_parameter_bundle(data, bundle_name, list_sub_bundles = []):
    """Create a parameter bundle named bundle_name on data.

    list_sub_bundles is the path of nested bundle names in which the new
    bundle is created. The bundles of that path are not created here.
    """
    # to consider: create the missing bundles of the path, to align with 'set_parameter_value'?
    path = "".join(f" {b}" for b in list_sub_bundles)
    command = f" parameters {path} newBundle {bundle_name}"
    data._tcl_interp(command)

def data_set_parameter_value(data, parameter_name, parameter_value, list_sub_bundles = []):
    """Set the value of a parameter of data in the given sub-bundle.

    list_sub_bundles is the path of nested bundle names holding the
    parameter. According to the original author, the parameter is created
    when it did not exist before.
    The value is converted with str() before being sent to the Tcl command.
    """
    path = "".join(f" {b}" for b in list_sub_bundles)
    valueText = str(parameter_value)
    data._tcl_interp(f" parameters {path} setValue {parameter_name} {valueText}")



def addContourToLineSet(lineSet, contour, referenceImage, referenceSlice, label = 1, closed = True):
    """Append a contour to lineSet as one new line.

    contour is indexed as contour[n, 0] / contour[n, 1]: a list of 2D points,
    described by the original author as pixel coordinates in an XY plane of
    referenceImage. Each point is converted with the voxel size of
    referenceImage and the first corner of its bounding box; z is derived
    from referenceSlice in the same way.
    When closed is True, the index of the first new point is repeated at the
    end of the line.
    NOTE(review): the label parameter is not used by this function.
    """
    voxelSize = getVoxelSizeFromImage(referenceImage)
    origin = referenceImage.bounding_box[0]
    pointCount = contour.shape[0]

    # make room for the new points after the existing ones
    firstIndex = int(lineSet._tcl_interp(" getNumPoints"))
    lineSet._tcl_interp(f" setNumPoints {firstIndex + pointCount}")

    z = referenceSlice * voxelSize[2] + origin[2]
    for offset in range(pointCount):
        x = contour[offset, 0] * voxelSize[0] + origin[0]
        y = contour[offset, 1] * voxelSize[1] + origin[1]
        lineSet._tcl_interp(f" setPoint {firstIndex+offset} {x} {y} {z}")

    # the line references the new points, plus the first one again when closed
    indices = " ".join(str(firstIndex + offset) for offset in range(pointCount))
    closing = f" {firstIndex}" if closed else ""
    lineSet._tcl_interp(f" addLine {indices}{closing}")
    
def data_get_source_path(dataset):
    """Return the file path stored in the 'LoadCmd' parameter of dataset.

    Returns None when dataset is None or when it has no 'LoadCmd' parameter.
    The first space-separated word of the command (expected to be 'load') is
    dropped, as well as a leading '-unit xx ' option when present.
    """
    if dataset is None:
        return None

    if dataset._tcl_interp(" parameters hasParameter LoadCmd") == '0':
        print(f"data_get_source_path: no LoadCmd parameter bundle found in {dataset.name}")
        return None

    command = dataset._tcl_interp(" parameters LoadCmd getValue")
    # keep what follows the first space; the whole string when there is none
    _, separator, remainder = command.partition(' ')
    path = remainder if separator else command
    if path.startswith('-unit'):
        # skip the option name and its value
        path = path.split(' ', 2)[-1]
    return path
    
