from _hx_core import *
import numpy as np
import re
from LM._PythonModuleHelpers2 import selected_button_menu, is_toggle_on, set_module_result, tcl_module_command, create_module
from LM._LabelsMatching import matching, precision, recall, f1, accuracy, _safe_divide

# Labels of the toggles shown by the module's toggle-list ports. The same
# strings are passed to is_toggle_on() to query the toggles, so they must stay
# byte-identical to the labels used when the ports are built.
console_option = "console"
discard_option = "replace"
name_after_pred_option = "name as prediction"
analysis_option = "analysis"
image_option = "diff image"
sheet_option = "table"
confusion_matrix_option = "confusion matrix"
gt_option = "ground truth"
pred_option = "prediction"
pre_relabel_sequentially_option = "dense labels"

# Matching-criterion menu entry -> short criterion code. The code is
# lower-cased before being handed to LabelsMatching.matching(criterion=...).
# Insertion order is the order of the entries in the menu.
matching_criteria_dict = {
    "Intersection Over Union (Jaccard index)": "IoU",
    "Intersection Over ground Truth (recall)": "IoT",
    "Intersection Over Prediction (precision)": "IoP",
}


# Scalar (non-tuple) field of a matching result -> (spreadsheet column title,
# spreadsheet column type name). Column titles are also looked up verbatim
# when aggregating per-slice results, so renaming one here requires updating
# those lookups as well.
globals_names_types = {
    "criterion": ("Criterion", "string"),
    "thresh": ("Threshold", "float"),
    "fp": ("False positives", "int"),
    "tp": ("True positives", "int"),
    "fn": ("False negatives", "int"),
    "precision": ("Precision", "float"),
    "recall": ("Recall", "float"),
    # the object-level `accuracy` metric is reported under the "IoU" title
    "accuracy": ("IoU", "float"),
    "f1": ("F1 score", "float"),
    "n_true": ("Used ground truth labels", "int"),
    "n_pred": ("Used prediction labels", "int"),
    "mean_true_score": ("Mean true score", "float"),
    "mean_matched_score": ("Mean matched score", "float"),
    "panoptic_quality": ("Panoptic quality", "float"),
}

## TODO: in addition to the pixel-wise 'pixeldiff' image, the module could output a per-instance diff image, although the instances' shapes would not necessarily be easy to display.


class LabelsMatchingImpl:

    def __init__(self):
        self.result_output = None
        self.result_by = None
        self.selected_labels = None
        self.ignored_labels = None
        self.summaryResult_sheet = None
        self.summaryResult_sheet_slot = 0
        self.gt_analysis_sheet = None
        self.gt_analysis_slot = 1
        self.pred_analysis_sheet = None
        self.pred_analysis_slot = 2
        self.confusion_sheet = None
        self.confusion_sheet_slot = 3
        self.output_image = None
        self.image_slot = 4
        self.labeledGT_slot = 5 # if the input GT is binary, the module could output its labeled version
        self.labeledPred_slot = 6 # if the input prediction is binary, the module could output its labeled version
        self.do_it = None
        self.advanced_options = None
        self.result_options = None
        self.options = None
        self.overlap_threshold = None
        self.matching_threshold = None
        self.matching_criterion = None
        self.prediction = None
        self.data = None
        self.mode = None

    def init(self, pyscro):
        """Create and configure the ports of the hosting script object.

        Args:
            pyscro: the script object this implementation is attached to. Its
                default ``data`` port is reused for the ground-truth input, and
                every other port is created on it.

        The ports are created in a fixed order. That order presumably drives
        their on-screen order, so keep it stable.
        """
        # The generic start/stop port is hidden; work is triggered by "Apply".
        pyscro.ports.startStop.visible = False

        # Input 1: the ground truth, on the script object's default data port.
        self.data = pyscro.data
        self.data.label = "Image1 (Ground Truth)"
        self.data.valid_types = ['HxUniformLabelField3']

        # Input 2: the prediction, on an additional connection port.
        self.prediction = HxConnection(pyscro, "prediction", "Image2 (Prediction)")
        self.prediction.valid_types = ['HxUniformLabelField3']

        # Whole-volume 3D processing vs. slice-by-slice 2D (XY) processing.
        self.mode = HxPortRadioBox(pyscro, 'mode', 'Mode')
        self.mode.radio_boxes = [HxPortRadioBox.RadioBox(label=text) for text in ("3D", "2D (XY)")]
        # 2D is the default for now (originally chosen for the NPC project).
        self.mode.selected = 1

        # Matching criterion; the menu entries are the keys of matching_criteria_dict.
        self.matching_criterion = HxPortMultiMenu(pyscro, 'matching_criterion', 'Matching Criterion')
        self.matching_criterion.tooltip = (
            "Metric for labels matching: Intersection Over Union, Intersection Over True, Intersection Over Prediction"
        )
        self.matching_criterion.menus = [HxPortButtonMenu.Menu(options=list(matching_criteria_dict))]

        def unit_float_port(name, label, default):
            """Create a single-value float port clamped to [0, 1]."""
            port = HxPortFloatTextN(pyscro, name, label)
            port.texts = [HxPortFloatTextN.FloatText(label='', value=default, clamp_range=(0, 1))]
            return port

        self.matching_threshold = unit_float_port('matching_threshold', 'Matching Threshold', 0.5)
        self.overlap_threshold = unit_float_port('overlap_threshold', 'Overlap Threshold', 0.0)
        self.overlap_threshold.visible = False

        def hidden_text_port(name, label):
            """Create an empty text port that is not shown."""
            port = HxPortText(pyscro, name, label)
            port.text = ""
            port.visible = False
            return port

        # TODO: support for ignored/selected labels still needs to be revised/completed,
        # so both ports stay hidden for now.
        self.ignored_labels = hidden_text_port('ignored_labels', 'Ignore labels')
        self.selected_labels = hidden_text_port('selected_labels', 'Selected labels')

        Toggle = HxPortToggleList.Toggle
        on, off = Toggle.CHECKED, Toggle.UNCHECKED

        def toggle_port(name, label, entries):
            """Create a toggle-list port from (toggle label, initial state) pairs."""
            port = HxPortToggleList(pyscro, name, label)
            port.toggles = [Toggle(label=text, checked=state) for text, state in entries]
            return port

        self.result_output = toggle_port('result_output', 'Results output', [
            (sheet_option, on),
            (analysis_option, on),
            (confusion_matrix_option, on),
            (image_option, on),
            (console_option, off),
        ])
        self.result_by = toggle_port('results_by', 'Analysis Results by', [
            (gt_option, on),
            (pred_option, off),
        ])

        self.result_options = toggle_port('result_options', 'Result options', [
            (discard_option, on),
            (name_after_pred_option, on),
        ])
        self.result_options.visible = False
        self.advanced_options = toggle_port('advanced_options', 'Advanced options', [
            (pre_relabel_sequentially_option, off),
        ])
        self.advanced_options.visible = False

        self.do_it = HxPortDoIt(pyscro, 'apply', 'Apply')

    def __del__(self, pyscro):
        self.data.visible = True

    def update(self, pyscro):
        pass

    def compute(self, pyscro):
        # Check if module is applied
        if not self.do_it.was_hit:
            return

        # Check if input data is connected
        if self.data.source() is None:
            return

        # Check if input data is connected
        if self.prediction.source() is None:
            return

        # If result was disconnected or not created yet, make sure forget ancillaries
        #if pyscro.results[0] is None:  
        if True:  
            self.summaryResult_sheet = None
            self.gt_analysis_sheet = None
            self.pred_analysis_sheet = None
            self.confusion_sheet = None
            self.output_image = None

        # Get input images
        input_gt_image = self.data.source()

        input_pred_image = self.prediction.source()

        if input_gt_image.get_array().shape != input_pred_image.get_array().shape:
            hx_message.error("Input images are not the same size. ")
            return
        # Unneeded here for nows, label arrays are promoted in the confusion matrix calculation
        # if gt_array.dtype != pred_array.dtype:
        #     print("Input images are not the same data type. "
        #           "Data types will be internally promoted to best common data type.")
        #     gt_array_best_type = np.result_type(np.min_scalar_type(gt_array.min()), gt_array.max())
        #     pred_array_best_type = np.result_type(np.min_scalar_type(gt_array.min()), gt_array.max())
        #     gt_array = gt_array.astype(np.promote_types(gt_array_best_type, pred_array_best_type))
        #     pred_array = pred_array.astype(np.promote_types(gt_array_best_type, pred_array_best_type))

        matching_threshold = self.matching_threshold.texts[0].value
        overlap_threshold = self.overlap_threshold.texts[0].value
        match_criterion = matching_criteria_dict[selected_button_menu(self.matching_criterion)]

        # various functions to fill spreadsheet tables...
        def labels_string_to_index_list(image, labels_string):
            label_index_list = []
            materials = image.parameters['Materials'] if 'Materials' in image.parameters else None
            for s in self.ignored_labels.text.split():
                if s.isdigit():
                    label_index_list.append(int(s))
                elif materials is not None:
                    match_indices = [i for i, n in enumerate(materials) if n.name == s]
                    if len(match_indices) > 0:
                        label_index_list.append(match_indices[0])  # take any first matching
            return label_index_list


        def get_scalar_metrics_RB(match_res):
            match_res_dict = match_res._asdict()
            scalar_metrics_dict = {k: v for (k, v) in match_res_dict.items() if not isinstance(v, tuple)}
            names = [f'{globals_names_types[k][0]}' for k in scalar_metrics_dict.keys()]
            types = [f'{globals_names_types[k][1]}' for k in scalar_metrics_dict.keys()]
            values = [f"{val:0.3f}" if isinstance(val, float) else str(val) for val in scalar_metrics_dict.values()]
            return names, types, values
        
        def get_spreadsheet_RB(spreadsheet, visible=True):
            # create or clear spreadsheet (assuming #0):
            if spreadsheet is None:
                spreadsheet = create_module('HxSpreadSheet', visible = False)
                #spreadsheet.icon_visible = visible
            else:
                tcl_module_command(spreadsheet, " clearTable 0")
                # Remove added tabs if any
                num_tables = len(spreadsheet.all_interfaces.HxSpreadSheetInterface.tables)
                for tab_index in range(num_tables - 1, 0, -1):
                    tcl_module_command(spreadsheet, f" removeTable {tab_index}")
            return spreadsheet

        def get_spreadsheet_col_index_by_name(spreadsheet, tab_id, col_name):
            col_index = -1
            for col in range(len(spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[tab_id].columns)):
                if spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[tab_id].columns[col].name == col_name:
                    col_index = col
                    break
            return col_index


        def set_table_num_rows(spreadsheet, table, rows_number):
            tcl_module_command(spreadsheet, f' setNumRows {table} {rows_number}')


        def add_column_RB(spreadsheet, tab_id, col_name, typename):
            s_interface = spreadsheet.all_interfaces.HxSpreadSheetInterface
            s_interface.tables[tab_id].columns.append(HxSpreadSheetInterface.Column(name=col_name, typename=typename))
            return

        def set_table_value(spreadsheet, table, column, row, val):
            spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[table].items[row,column] = val

        def set_table_value_withTypeCheck(spreadsheet, table, column, row, val):
            coltype = spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[table].columns[column].typename
            if coltype == 'float':
                spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[table].items[row,column] = float(val)
            if coltype == 'int':
                spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[table].items[row,column] = int(val)
            if coltype == 'string':
                spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[table].items[row,column] = str(val)
        
        def add_table_RB(spreadsheet, table_name):
            # nonlocal last_column_index
            tcl_module_command(spreadsheet, f' addTable "{table_name}"')
            tcl_module_command(spreadsheet, f' fire')
            table_index = len(spreadsheet.all_interfaces.HxSpreadSheetInterface.tables) - 1
            return table_index        

        # set values in a whole column, or a subset of it
        def set_column_values(spreadsheet, tab_id, col_name, col_type, rows, values):
            if len(rows) != len(values):
                print("Warning: set_column_values: rows and values have different lengths")

            col_index = get_spreadsheet_col_index_by_name(spreadsheet, tab_id, col_name)
            if col_index == -1:
                print(f"Error: set_column_values: cannot find requested column {col_name}")
                return

            if col_type == 'float':
                for row, val in zip(rows, values):
                    set_table_value(spreadsheet, tab_id, col_index, row, float(val))
            if col_type == 'int':
                for row, val in zip(rows, values):
                    set_table_value(spreadsheet, tab_id, col_index, row, int(val))
            if col_type == 'string':
                for row, val in zip(rows, values):
                    set_table_value(spreadsheet, tab_id, col_index, row, str(val))

        def fill_individual_table_RB(spreadsheet, tab_id, match_res, max_row_label, col_axis, start_label=1):
            row_axis = 1 - col_axis
            set_table_num_rows(spreadsheet, tab_id, max_row_label + 1 - start_label)  # start at 1 to skip background
            (rows_name, cols_name) = ("GT", "P") if col_axis == 1 else ("P", "GT")
            matched_pairs_array = np.array(match_res.matched_pairs)

            best_match_col_name = 'BestMatch ' + cols_name
            add_column_RB(spreadsheet, tab_id, best_match_col_name, 'int')
            if matched_pairs_array.shape != (0,):
                matched_rows = matched_pairs_array[:, row_axis] - start_label  # spreadsheet numbers start at zero
                set_column_values(spreadsheet, tab_id, best_match_col_name, 'int', matched_rows, matched_pairs_array[:, col_axis])

            add_column_RB(spreadsheet, tab_id, f'{match_criterion} score', 'float')
            if matched_pairs_array.shape != (0,):
                set_column_values(spreadsheet, tab_id, f'{match_criterion} score', 'float', matched_rows, match_res.matched_scores)

            
            add_column_RB(spreadsheet, tab_id, 'Accept', 'string')
            if matched_pairs_array.shape != (0,):
                accept_col = matched_pairs_array[match_res.matched_tps, row_axis] - start_label
                set_column_values(spreadsheet, tab_id, 'Accept', 'string', accept_col, len(accept_col) * ['OK'])

            ignored = ignored_gt_labels if col_axis == 1 else ignored_pred_labels
            for row in ignored:
                col_id = get_spreadsheet_col_index_by_name(spreadsheet, tab_id, 'Accept')
                set_table_value(spreadsheet, tab_id, col_id, row-start_label, 'IGN')
            add_column_RB(spreadsheet, tab_id, 'Jaccard (IoU)', 'float')
            add_column_RB(spreadsheet, tab_id, 'Precision (IoP)', 'float')
            add_column_RB(spreadsheet, tab_id, 'Recall (IoT)', 'float')
            add_column_RB(spreadsheet, tab_id, 'F1 score', 'float')
            if matched_pairs_array.shape != (0,):
                set_column_values(spreadsheet, tab_id, 'Jaccard (IoU)', 'float', matched_rows, match_res.matched_scores_iou)
                set_column_values(spreadsheet, tab_id, 'Precision (IoP)', 'float', matched_rows, match_res.matched_scores_iop)
                set_column_values(spreadsheet, tab_id, 'Recall (IoT)', 'float', matched_rows, match_res.matched_scores_iot)
                set_column_values(spreadsheet, tab_id, 'F1 score', 'float', matched_rows, match_res.f1_pix)

        def fill_matrix_table_RB(spreadsheet, tab_id, matrix, max_row_label, col_axis, col_type, start_label=1):
            (rows_name, cols_name) = ("GT", "Pred") if col_axis == 1 else ("Pred", "GT")
            set_table_num_rows(spreadsheet, tab_id, max_row_label + 1 - start_label)
            if start_label == 0:
                # Add a column for labels id if different from table row numbering
                rows_range = range(0, max_row_label + 1)
                #set_column(spreadsheet, tab_id, rows_name + " label", 'int', rows_range, rows_range)
                add_column_RB(spreadsheet, tab_id, rows_name + " label", 'int')
                set_column_values(spreadsheet, tab_id, rows_name + " label", 'int', rows_range, rows_range)
            # Adding columns
            (map_rev_col, map_rev_row) = (map_rev_pred, map_rev_gt) if col_axis == 1 else (map_rev_gt, map_rev_pred)

            for label_pred in map_rev_col:
                add_column_RB(spreadsheet, tab_id, cols_name + "#" + str(label_pred), col_type)

            for row_label in map_rev_row[start_label:]:
                ss_row_number = row_label - start_label  # spreadsheet rows/columns start at 0
                for col_label in map_rev_col:
                    row = row_label if col_axis == 1 else col_label
                    col = col_label if col_axis == 1 else row_label
                    col_id = get_spreadsheet_col_index_by_name(spreadsheet, tab_id, cols_name + "#" + str(col_label))
                    set_table_value(spreadsheet, tab_id, col_id,
                                    ss_row_number, matrix[map_fwd_gt[row], map_fwd_pred[col]])
                    # ... scores has 1 less row/column, ignoring background

        def fill_global_metrics_table_RB(spreadsheet, tab_id, match_res, row_id=0):
            names, types, values = get_scalar_metrics_RB(match_res)           
            for col in range(len(names)):
                col_id = get_spreadsheet_col_index_by_name(spreadsheet, tab_id, names[col])
                set_table_value_withTypeCheck(spreadsheet, tab_id, col_id, row_id, values[col])
            
            # add the indeximage information if needed
            col_id = get_spreadsheet_col_index_by_name(spreadsheet, tab_id, 'indeximage')
            if col_id != -1:
                set_table_value_withTypeCheck(spreadsheet, tab_id, col_id, row_id, row_id+1) # indeximage starts at 1, not 0 !
            return

        def global_metrics_prepare_table(spreadsheet, tab_id, match_res, number_rows=1):
            names, types, values = get_scalar_metrics_RB(match_res)
            
            set_table_num_rows(spreadsheet, tab_id, number_rows)
            for col in range(len(names)):
                # print(f"adding column {col}, with name {names[col]}, and type: {types[col]}")
                add_column_RB(spreadsheet, tab_id, names[col], types[col])
            if number_rows > 1:
                # print(f"also adding column for index image")
                add_column_RB(spreadsheet, tab_id, 'indeximage', 'int') # add a column to store the slice index (usually called indeximage in Label Analysis)


        # Create analysis spreadsheet from label input
        def create_analysis_spreadsheet(input_label, inputIs2DStack):
            analyze_labels = hx_object_factory.create('HxAnalyzeLabels')
            analyze_labels.ports.data.connect(input_label)
            if (input_label.get_array().shape[2] > 1) and (inputIs2DStack == False):
                analyze_labels.ports.interpretation.selected = 0
                tcl_module_command(analyze_labels, " measures setState Volume3d  Volume3d") # sets most basic useful group
            else:
                analyze_labels.ports.interpretation.selected = 1
                analyze_labels.fire()
                tcl_module_command(analyze_labels, " measures setState Area  Area") # sets most basic useful group
            analyze_labels.execute()
            label_analysis = analyze_labels.results[0]
            return label_analysis

        # append a spreadsheet to an analysis spreadsheet
        # if 'row_offset' is 0, create new columns to the right of the existing ones, and insert the available items starting with first row
        # otherwise, the columns should already exist, and the data to be inserted shall start at the indicated row index.
        # in both cases, the 'sheet to append' is likely not covering all rows on the analysis_sheet...
        def append_spreadsheet_to_analysis(analysis_sheet, sheet_to_append, row_offset = 0):
            num_rows_in_sheet = len(sheet_to_append.all_interfaces.HxSpreadSheetInterface.tables[0].rows)

            for col in range(len(sheet_to_append.all_interfaces.HxSpreadSheetInterface.tables[0].columns)):
                col_name = sheet_to_append.all_interfaces.HxSpreadSheetInterface.tables[0].columns[col].name
                typename = sheet_to_append.all_interfaces.HxSpreadSheetInterface.tables[0].columns[col].typename
                col_array = sheet_to_append.all_interfaces.HxSpreadSheetInterface.tables[0].columns[col].asarray()
                if row_offset == 0:
                    add_column_RB(analysis_sheet, 0, col_name, typename)
                set_column_values(analysis_sheet, 0, col_name, typename, range(row_offset, row_offset + num_rows_in_sheet), col_array)
            return

        def fill_summary_sheet_from_slices(spreadsheet):
            t0 = spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[0]
            t1 = spreadsheet.all_interfaces.HxSpreadSheetInterface.tables[1]
            col_Index_Criterion = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'Criterion')
            col_Index_Threshold = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'Threshold')
            col_Index_FP = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'False positives')
            col_Index_TP = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'True positives')
            col_Index_FN = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'False negatives')
            col_Index_precision = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'Precision')
            col_Index_recall = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'Recall')
            col_Index_iou = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'IoU')
            col_Index_f1 = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'F1 score')
            col_Index_Ntrue = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'Used ground truth labels')
            col_Index_Npred = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'Used prediction labels')
            col_Index_MTS = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'Mean true score')
            col_Index_MMS = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'Mean matched score')
            col_Index_panoptic_quality = get_spreadsheet_col_index_by_name(spreadsheet, 1, 'Panoptic quality')

            # Criterion and Threshold are always the same, just copy from the first row
            t0.items[0,col_Index_Criterion] = t1.items[0,col_Index_Criterion] 
            t0.items[0,col_Index_Threshold] = t1.items[0,col_Index_Threshold] 
           
            # other metrics are recomputed from the sum of information from each slice
            sum_fp, sum_tp, sum_fn, sum_ntrue, sum_npred, sum_matched_score = 0, 0, 0, 0, 0, 0
            for sliceNum in range(0, len(t1.rows)):
                sum_fp = sum_fp + t1.items[sliceNum,col_Index_FP]
                sum_tp = sum_tp + t1.items[sliceNum,col_Index_TP]
                sum_fn = sum_fn + t1.items[sliceNum,col_Index_FN]
                sum_ntrue = sum_ntrue + t1.items[sliceNum,col_Index_Ntrue]
                sum_npred = sum_npred + t1.items[sliceNum,col_Index_Npred]
                sum_matched_score = sum_matched_score + ( t1.items[sliceNum,col_Index_MMS] * t1.items[sliceNum,col_Index_TP] )

            t0.items[0,col_Index_FP] = sum_fp
            t0.items[0,col_Index_TP] = sum_tp
            t0.items[0,col_Index_FN] = sum_fn
            t0.items[0,col_Index_Ntrue] = sum_ntrue
            t0.items[0,col_Index_Npred] = sum_npred
            t0.items[0,col_Index_MMS] = _safe_divide(sum_matched_score, sum_tp)
            t0.items[0,col_Index_MTS] = _safe_divide(sum_matched_score, sum_ntrue)
            t0.items[0,col_Index_panoptic_quality] = _safe_divide(sum_matched_score, sum_tp+sum_fp/2+sum_fn/2)
            t0.items[0,col_Index_precision] = precision(sum_tp,sum_fp,sum_fn)
            t0.items[0,col_Index_recall] = recall(sum_tp,sum_fp,sum_fn)
            t0.items[0,col_Index_iou] = accuracy(sum_tp,sum_fp,sum_fn)
            t0.items[0,col_Index_f1] = f1(sum_tp,sum_fp,sum_fn)
            return

        def binarize_image(image):
            thresholdModule = hx_object_factory.create('HxInteractiveThreshold')
            thresholdModule.ports.data.connect(image)
            thresholdModule.fire()
            thresholdModule.ports.intensityRange.range[0] = 1
            thresholdModule.execute()
            return thresholdModule.results[0]
        
        def labelize_image(image):
            labelingModule = hx_object_factory.create('label')
            labelingModule.ports.inputImage.connect(image)
            labelingModule.ports.interpretation.selected = self.mode.selected
            labelingModule.execute()
            return labelingModule.results[0]

        ## Main processing
        ignored_gt_labels = labels_string_to_index_list(input_gt_image, self.ignored_labels)

        # check whether input images are binary, or labeled images, and produce the other image
        if input_gt_image.range[1] == 1:
            input_gt_image_binary = input_gt_image
            input_gt_image_labeled = labelize_image(input_gt_image)
        else:
            input_gt_image_labeled = input_gt_image
            input_gt_image_binary = binarize_image(input_gt_image)

        if input_pred_image.range[1] == 1:
            input_pred_image_binary = input_pred_image
            input_pred_image_labeled = labelize_image(input_pred_image)
        else:
            input_pred_image_labeled = input_pred_image
            input_pred_image_binary = binarize_image(input_pred_image)

        # Prepare and set module results
        basename_obj = self.prediction if is_toggle_on(self.result_options, name_after_pred_option) else self.data
        basename = re.sub(".am$", "", basename_obj.source().name)
        discard = is_toggle_on(self.result_options, discard_option)

        # if an analysis output is requested, make sure the labeled image corresponding to that analysis is exposed
        if is_toggle_on(self.result_output, analysis_option):
            if is_toggle_on(self.result_by, gt_option) and input_gt_image_binary == input_gt_image:
                name = re.sub(".am$", "", input_gt_image.name)
                set_module_result(pyscro, input_gt_image_labeled, name + f".labels", discard, self.labeledGT_slot)
            else:
                set_module_result(pyscro, None, "", True, self.labeledGT_slot)
            if is_toggle_on(self.result_by, pred_option) and input_pred_image_binary == input_pred_image:
                name = re.sub(".am$", "", input_pred_image.name)
                set_module_result(pyscro, input_pred_image_labeled, name + f".labels", discard, self.labeledPred_slot)
            else:
                set_module_result(pyscro, None, "", True, self.labeledPred_slot)
        else:
            set_module_result(pyscro, None, "", True, self.labeledGT_slot)
            set_module_result(pyscro, None, "", True, self.labeledPred_slot)


        # check whether the inputs are a 'single image to be analyzed', or a stack of slices
        inputisStack = False
        if self.mode.selected == 1 and input_gt_image_labeled.get_array().shape[2] > 1:
            inputisStack = True

        #create the output spreadsheets as requested by module options
        self.summaryResult_sheet = get_spreadsheet_RB( self.summaryResult_sheet) if is_toggle_on(self.result_output, sheet_option) else None            
        self.confusion_sheet = get_spreadsheet_RB( self.confusion_sheet) if is_toggle_on(self.result_output, confusion_matrix_option) else None

        # analysis spreadsheets are ready for appending the useful results.
        number_rows_gt = 0
        number_rows_pred = 0
        self.gt_analysis_sheet = None
        self.pred_analysis_sheet = None            
        if is_toggle_on(self.result_output, analysis_option):
            if is_toggle_on(self.result_by, gt_option):
                self.gt_analysis_sheet = create_analysis_spreadsheet(input_gt_image_labeled, inputisStack)
                number_rows_gt = len(self.gt_analysis_sheet.all_interfaces.HxSpreadSheetInterface.tables[0].rows)

            if is_toggle_on(self.result_by, gt_option):
                self.pred_analysis_sheet = create_analysis_spreadsheet(input_pred_image_labeled, inputisStack)
                number_rows_pred = len(self.pred_analysis_sheet.all_interfaces.HxSpreadSheetInterface.tables[0].rows)



        # if the input is a single dataset, process it as such ; output the confusion matrix 
        if inputisStack == False:
            gt_array = input_gt_image_labeled.get_array()
            pred_array = input_pred_image_labeled.get_array()

            # selected_labels = [int(s) for s in self.selected_labels.text.split()]
            map_fwd_gt, map_fwd_pred, map_rev_gt, map_rev_pred, ignored_pred_labels,\
                confusion_matrix, scores, scores_with_bg, \
                matching_result = matching(gt_array, pred_array, thresh=matching_threshold,
                                        criterion=match_criterion.lower(),
                                        ignored_labels=ignored_gt_labels, overlap_threshold=overlap_threshold,
                                        report_matches=True, details=True,
                                        relabel=is_toggle_on(self.advanced_options, pre_relabel_sequentially_option))
            # ... note that scores shape has 1 fewer rows and columns than overlap

            if is_toggle_on(self.result_output, console_option):
                metrics_names, metrics_values = get_scalar_metrics_RB(matching_result)
                for name, value in zip(metrics_names, metrics_values):
                    print(f"{name}: {value}")
                print(f"Matching pairs: {matching_result.matched_pairs}")
                print(f"Scores: {matching_result.matched_scores}")
                print(f"Accepted: {matching_result.matched_tps}")
                pyscro.ports.showConsole.buttons[0].hit = True
                pyscro.fire()
                print(f"({match_criterion} threshold = {matching_threshold})")

            # Fill tables tabs for results
            # summary table
            if is_toggle_on(self.result_output, sheet_option):
                tcl_module_command(self.summaryResult_sheet, f' setTableName 0 "Matching measures"')
                global_metrics_prepare_table(self.summaryResult_sheet, 0, matching_result, 1)                
                fill_global_metrics_table_RB(self.summaryResult_sheet, 0, matching_result)
            # confusion matrix
            if is_toggle_on(self.result_output, confusion_matrix_option):
                tcl_module_command(self.confusion_sheet, f' setTableName 0 "Confusion matrix"')
                fill_matrix_table_RB(self.confusion_sheet, 0, confusion_matrix, pred_array.max(), 0, 'int', start_label=0)

            # analysis tables, for prediction
            if is_toggle_on(self.result_output, analysis_option) and is_toggle_on(self.result_by, pred_option):
                tmp_sheet = get_spreadsheet_RB(None, visible=False)
                fill_individual_table_RB(tmp_sheet, 0, matching_result, input_pred_image_labeled.range[1], 0, start_label=1)
                append_spreadsheet_to_analysis(self.pred_analysis_sheet, tmp_sheet, 0)
                hx_project.remove(tmp_sheet)
            # analysis tables, for ground truth
            if is_toggle_on(self.result_output, analysis_option) and is_toggle_on(self.result_by, gt_option):
                tmp_sheet2 = get_spreadsheet_RB(None, visible=False)
                fill_individual_table_RB(tmp_sheet2, 0, matching_result, input_gt_image_labeled.range[1], 1, start_label=1)
                append_spreadsheet_to_analysis(self.gt_analysis_sheet, tmp_sheet2, 0)
                hx_project.remove(tmp_sheet2)
      

        else:
            # we are processing a stack of slices, so process them one by one and fill the results accordingly
            numSlices = input_gt_image_labeled.get_array().shape[2]
            row_offset_pred = 0
            row_offset_gt = 0
            print(f"Processing stack of {numSlices} slices")

            for sliceIndex in range(numSlices):
                print(f"processing slice {sliceIndex}")
                gt_array = input_gt_image_labeled.get_array()[:,:,sliceIndex]
                pred_array = input_pred_image_labeled.get_array()[:,:,sliceIndex]

                num_objects_gt = gt_array.max()
                num_objects_pred = pred_array.max()
                print(f"number of objects in GT: {num_objects_gt}, in prediction: {num_objects_pred}")

                ##TODO: warning: if the number of objects is 0... I should bypass this call, and handle the outputs separately
                
                map_fwd_gt, map_fwd_pred, map_rev_gt, map_rev_pred, ignored_pred_labels,\
                    confusion_matrix, scores, scores_with_bg, \
                    matching_result = matching(gt_array, pred_array, thresh=matching_threshold,
                                            criterion=match_criterion.lower(),
                                            ignored_labels=ignored_gt_labels, overlap_threshold=overlap_threshold,
                                            report_matches=True, details=True,
                                            relabel=is_toggle_on(self.advanced_options, pre_relabel_sequentially_option))

                # Fill tables tabs for results
                # summary table: a tab indexed 1 should have the summary for each slice (one per row, ideally, including an 'indeximage' column)
                # and the main tab, to be filled at the end of the loop, would have global summary.
                if is_toggle_on(self.result_output, sheet_option):
                    tcl_module_command(self.summaryResult_sheet, f' setTableName 0 "Matching measures"')
                    
                    if sliceIndex == 0:
                        newtab_index = add_table_RB(self.summaryResult_sheet, f"per Slices") 
                        global_metrics_prepare_table(self.summaryResult_sheet, 0, matching_result, 1)
                        global_metrics_prepare_table(self.summaryResult_sheet, 1, matching_result, numSlices)                    
                    fill_global_metrics_table_RB(self.summaryResult_sheet, 1, matching_result, sliceIndex)
                    if sliceIndex == numSlices - 1:
                        fill_summary_sheet_from_slices(self.summaryResult_sheet)


                # confusion matrix, should be one tab per slice, with the confusion matrix for each slice |= OK!
                if is_toggle_on(self.result_output, confusion_matrix_option):
                    if sliceIndex > 0:
                        add_table_RB(self.confusion_sheet, f"Slice {sliceIndex}")
                    tcl_module_command(self.confusion_sheet, f' setTableName {sliceIndex} "Slice {sliceIndex}"')
                    fill_matrix_table_RB(self.confusion_sheet, sliceIndex, confusion_matrix, num_objects_pred, 0, 'int', start_label=0)

                # analysis tables: these will remain single spreadsheets, but indeximage column will allow separate analysis for each slice
                if self.pred_analysis_sheet is not None:
                    tmp_sheet = get_spreadsheet_RB(None, visible=False)
                    fill_individual_table_RB(tmp_sheet, 0, matching_result, num_objects_pred, 0, start_label=1)
                    append_spreadsheet_to_analysis(self.pred_analysis_sheet, tmp_sheet, row_offset_pred)
                    row_offset_pred = (int)(row_offset_pred + num_objects_pred)
                    hx_project.remove(tmp_sheet)
                # analysis tables, for ground truth
                if self.gt_analysis_sheet is not None:
                    tmp_sheet2 = get_spreadsheet_RB(None, visible=False)
                    fill_individual_table_RB(tmp_sheet2, 0, matching_result, num_objects_gt, 1, start_label=1)
                    append_spreadsheet_to_analysis(self.gt_analysis_sheet, tmp_sheet2, row_offset_gt)
                    row_offset_gt = (int)(row_offset_gt + num_objects_gt)
                    hx_project.remove(tmp_sheet2)

                

        if is_toggle_on(self.result_output, image_option):
            #binarize the ground truth & prediction inputs
            # with these values, and the default label colormap, this gives the following color code:
            # green are true positives, red are false positives, blue are false negatives
            # TODO: add 'materials' and explicit colors to the output image...
            # I could also output some images, where entire objects (from GT and/or from Prediction, so possibly 2 more outputs)
            # are color-coded, whether they have an accepted match; or not (and whether it is a false positive or a false negative)

            arithmModule = hx_object_factory.create('HxArithmetic')
            arithmModule.ports.inputA.connect(input_gt_image_binary)
            arithmModule.ports.inputB.connect(input_pred_image_binary)
            arithmModule.fire()
            arithmModule.ports.expr0.text = "4*(A==1)*(A==B)+3*(A==0)*(B==1)+2*(A==1)*(B==0)"
            arithmModule.execute()
            self.output_image = arithmModule.results[0]
            #expose the output. (to try, always make this call, and check whether this clear the previous output if it is no longer requested...)
            set_module_result(pyscro, self.output_image, basename + f".pixelDiff", discard, self.image_slot)
        else:
            set_module_result(pyscro, None, "", True, self.image_slot)

        #  Expose outputs
        # summary table
        if is_toggle_on(self.result_output, sheet_option):
            set_module_result(pyscro, self.summaryResult_sheet, basename + f".matchingGT", discard, self.summaryResult_sheet_slot)
        else:
            set_module_result(pyscro, None, "", True, self.summaryResult_sheet_slot)

        # confusion matrix
        if is_toggle_on(self.result_output, confusion_matrix_option):
            set_module_result(pyscro, self.confusion_sheet, basename + f".confusion", discard, self.confusion_sheet_slot)
        else:
            set_module_result(pyscro, None, "", True, self.confusion_sheet_slot)

        # analysis tables, for prediction
        if is_toggle_on(self.result_output, analysis_option) and is_toggle_on(self.result_by, pred_option):
            set_module_result(pyscro, self.pred_analysis_sheet, basename + f".analysisPred", discard, self.pred_analysis_slot)
        else:
            set_module_result(pyscro, None, "", True, self.pred_analysis_slot)
        # analysis tables, for ground truth
        if is_toggle_on(self.result_output, analysis_option) and is_toggle_on(self.result_by, gt_option):
            set_module_result(pyscro, self.gt_analysis_sheet, basename + f".analysisGT", discard, self.gt_analysis_slot)
        else:
            set_module_result(pyscro, None, "", True, self.gt_analysis_slot)

        
        return
