# -*- coding: utf-8 -*-

"""
Convenience tools for sorting and parsing Thermo Fisher DIB files in Amira/Avizo.
For more information see: https://en.wikipedia.org/wiki/BMP_file_format#DIB_header_(bitmap_information_header).
"""

import sys
import os
import re
import struct
import numpy as np
import _hx_core 
from _hx_core import _tcl_interp, hx_project, hx_object_factory

def parseDibZStackFolder(dibDir):
    """
    Parse a folder of Thermo Fisher DIB files and group them by well, field,
    channel and z slice.

    A file is used only if its name ends with ".DIB" and contains a well
    descriptor ``_<1-2 capital letters><2 digits>f``. The descriptor must be
    followed by the field digits (1-4), then a ``d<digit>`` channel descriptor
    and, for z-stack slices, a ``z<1-3 digits>`` slice descriptor. The field,
    channel and slice descriptors are only searched *after* the well
    descriptor, so letters and digits in the file-name prefix (plate name,
    timestamp, ...) cannot be mistaken for them. Files without a field or
    channel descriptor are skipped. A file without a ``z`` descriptor is
    registered under the key ``'mip'``, presumably a maximum-intensity
    projection.

    Header metadata is not read here.

    TODO: kinetic acquisitions, whose names follow the pattern
    <UPD>i3t<TimePoint><WellName>f<FieldIndex>d<ChannelIndex>.DIB, are not
    handled; time points would need an extra level in the dict.

    :param dibDir: folder to scan (not recursive).
    :return: nested dict
        ``{well: {field: {channel: {slice_or_'mip': {'filename': path}}}}}``,
        e.g. ``plateDict['A01']['f00']['d0']['z1']['filename']``.
    """
    wellPattern = re.compile(r'_([A-Z]{1,2}[0-9]{2})f')
    fieldDigitsPattern = re.compile(r'[0-9]{1,4}')
    channelPattern = re.compile(r'd[0-9]')
    slicePattern = re.compile(r'z[0-9]{1,3}')

    plateDict = {}
    for fileName in sorted(os.listdir(dibDir)):
        # Only DIB files are handled for now; BMP exports or C01 files could be added later.
        if not fileName.endswith(".DIB"):
            continue

        wellMatch = wellPattern.search(fileName)
        if wellMatch is None:
            continue  # not a DIB file with a well descriptor
        well = wellMatch.group(1)

        # The field index digits directly follow the 'f' that closes the well descriptor.
        fieldMatch = fieldDigitsPattern.match(fileName, wellMatch.end())
        if fieldMatch is None:
            continue
        field = 'f' + fieldMatch.group()

        channelMatch = channelPattern.search(fileName, fieldMatch.end())
        if channelMatch is None:
            continue
        channel = channelMatch.group()

        channelDict = plateDict.setdefault(well, {}).setdefault(field, {}).setdefault(channel, {})
        filePath = os.path.join(dibDir, fileName)

        sliceMatch = slicePattern.search(fileName, channelMatch.end())
        if sliceMatch is None:
            # No 'z' descriptor: most likely a maximum-intensity projection.
            channelDict['mip'] = {'filename': filePath}
        else:
            channelDict.setdefault(sliceMatch.group(), {})['filename'] = filePath

    return plateDict


def padForSorting(stringToPad):
    """
    Build a sort key from the z-slice descriptor of a DIB file name or path.

    The key is the text between the last 'z' and the last '.', left-padded
    with zeros to at least 3 characters. Compared as strings, these keys put
    'z2' before 'z10', so slices load in order as a 3D dataset.
    """
    sliceStart = stringToPad.rfind('z') + 1
    extensionStart = stringToPad.rfind('.')
    sliceDigits = stringToPad[sliceStart:extensionStart]
    return sliceDigits.zfill(3)

def readDibHeader(dibPath):
    """
    Read the DIB header and the Thermo Fisher extra header of a DIB file.

    Layout read (all fields little-endian, as in the BMP/DIB format):

    * bytes 0-51: BITMAPINFOHEADER (40 bytes) followed by three 32-bit
      values, returned as ``palette1``..``palette3``;
    * then ``biWidth * biHeight`` pixels of 16 bits each, which are skipped;
    * then the Thermo Fisher extra header (64 bytes): magicNumber,
      formatVersion, headerSize, biZPelsPerMeter, centerX/Y/Z,
      acquisitionTime, effectiveExposureTime and maxReferenceValue.

    Floating-point fields (centerX/Y/Z, effectiveExposureTime,
    maxReferenceValue) are returned as floats, without truncation. All other
    fields are ints.

    :param dibPath: path of the DIB file.
    :return: dict mapping field name to value, in file order.
    :raises struct.error: if the file is too short, e.g. an old DIB file
        without the extra header.
    """
    # Standard 52-byte header: BITMAPINFOHEADER + three 32-bit entries.
    standardHeader = struct.Struct('<IiiHHIIiiIIIII')
    standardNames = ('biSize', 'biWidth', 'biHeight', 'biPlanes', 'biBitCount',
                     'biCompression', 'biSizeImage', 'biXPelsPerMeter',
                     'biYPelsPerMeter', 'biClrUsed', 'biClrImportant',
                     'palette1', 'palette2', 'palette3')
    # Thermo Fisher extra header stored after the pixel data.
    extraHeader = struct.Struct('<IiIidddqdd')
    extraNames = ('magicNumber', 'formatVersion', 'headerSize', 'biZPelsPerMeter',
                  'centerX', 'centerY', 'centerZ', 'acquisitionTime',
                  'effectiveExposureTime', 'maxReferenceValue')

    with open(dibPath, "rb") as mdf:  # "My DIB file"
        mdf_metadata = dict(zip(standardNames, standardHeader.unpack(mdf.read(standardHeader.size))))
        # Skip the image data (assumed 16 bits per pixel, with no row padding).
        mdf.seek(standardHeader.size + mdf_metadata['biWidth'] * mdf_metadata['biHeight'] * 2)
        # TODO: old DIB files may lack the extra header; this currently raises struct.error.
        mdf_metadata.update(zip(extraNames, extraHeader.unpack(mdf.read(extraHeader.size))))
    return mdf_metadata

def readDibImage(DibFile,xsize,ysize):
    """
    Read only the pixel data of a DIB file into a numpy array.

    ``xsize * ysize`` unsigned 16-bit little-endian values are read, starting
    right after the 52-byte DIB header, and reshaped (C order) to
    ``(xsize, ysize)``.

    NOTE(review): DIB rows hold ``xsize`` pixels each, so ``(ysize, xsize)``
    would be the natural shape. The two are identical only for square images.
    Confirm for non-square sensors before changing, because callers
    swap/flip the axes afterwards.

    If the file holds fewer than ``xsize * ysize`` pixels (truncated or
    corrupt slice), a message is printed and an all-zero uint16 image of the
    same shape is returned instead.

    :param DibFile: path of the DIB file.
    :param xsize: image width in pixels.
    :param ysize: image height in pixels.
    :return: uint16 array of shape ``(xsize, ysize)``.
    """
    num_pixels = xsize * ysize
    with open(DibFile, "rb") as mdf:  # "My DIB file"
        mdf.seek(52)  # skip the 52-byte DIB header
        # np.fromfile silently returns fewer items on a short file; the
        # resulting reshape ValueError is what signals a truncated slice.
        pixels = np.fromfile(mdf, dtype=np.dtype('<u2'), count=num_pixels)
    try:
        return pixels.reshape(xsize, ysize)
    except ValueError:
        # Kept as best-effort: one slice of the test data was corrupt and could not be read.
        print('Incomplete pixel data in readDibImage for {}; returning a zero image'.format(DibFile))
        return np.zeros((xsize, ysize), dtype=pixels.dtype)

def readDibFileAsAmiraImage(DibFile):
    """
    Read a single DIB image file as a one-slice Amira HxUniformScalarField3.

    The new object is named 'img'. Its bounding box is centred on the
    centerX/Y/Z values of the Thermo Fisher extra header. Voxel sizes are
    ``1e6 / biX/Y/ZPelsPerMeter`` (presumably micrometres). The single slice
    is given a z extent of one voxel size. All header fields are stored in
    the object's "DIBHeader" parameter bundle, and the object is added to
    the project.

    :param DibFile: path of the DIB file.
    :return: the created HxUniformScalarField3.
    """
    dib_metadata = readDibHeader(DibFile)
    width = dib_metadata['biWidth']
    height = dib_metadata['biHeight']

    # Bounding box from DIB data: x/y span the (n - 1) voxels between the first and last voxel centres.
    x_length = (width - 1) * (1000000.0 / dib_metadata['biXPelsPerMeter'])
    y_length = (height - 1) * (1000000.0 / dib_metadata['biYPelsPerMeter'])
    # One slice only: use one voxel size in z, presumably so the box is not flat.
    z_length = 1000000.0 / dib_metadata['biZPelsPerMeter']

    x_min = dib_metadata['centerX'] - x_length / 2.0
    y_min = dib_metadata['centerY'] - y_length / 2.0
    z_min = dib_metadata['centerZ'] - z_length / 2.0

    channelBBox = ((x_min, y_min, z_min),
                   (x_min + x_length, y_min + y_length, z_min + z_length))

    # Construct the data object, apply the bounding box and store the header as parameters.
    amiraImage = hx_object_factory.create('HxUniformScalarField3')
    amiraImage.name = 'img'
    amiraImage.bounding_box = channelBBox
    amiraImage.fire()
    dictToParameters(dib_metadata, amiraImage)

    # Pixels are unsigned 16-bit; a uint16 buffer keeps values above 32767
    # (an int16 buffer would wrap them to negative numbers).
    dibImage = np.empty((width, height, 1), dtype=np.uint16)
    dibImage[:, :, 0] = readDibImage(DibFile, width, height)

    amiraImage.set_array(np.flip(np.swapaxes(dibImage, 0, 1), 1))
    amiraImage.fire()
    hx_project.add(amiraImage)
    return amiraImage
    
        
    
def loadDibChannel(channelDict, channelName = 'channel'):
    """
    Load one channel of a well field as a 3D HxUniformScalarField3.

    :param channelDict: dict ``{slice_name: {'filename': path[, 'metadata': header]}}``
        as built by parseDibZStackFolder. A 'mip' entry is ignored when z
        slices are present; when it is the only entry, it is loaded as a
        single-slice volume. Missing 'metadata' entries are read with
        readDibHeader and cached back into channelDict.
    :param channelName: name given to the created object.
    :return: the new HxUniformScalarField3. It is not added to the project
        here. The header of the first slice (sorted by slice name) is stored
        in its "DIBHeader" parameter bundle.
    :raises ValueError: if channelDict contains no image at all.
    """
    sliceNames = sorted(name for name in channelDict if name != 'mip')
    if not sliceNames and 'mip' in channelDict:
        # No z descriptor at all (presumably a non-z-stack acquisition): load the lone image as one slice.
        sliceNames = ['mip']
    if not sliceNames:
        raise ValueError('loadDibChannel: no DIB image found for {}'.format(channelName))

    # Collect the slice files and the range of their z centres; headers are read only when missing.
    channelSliceFiles = []
    z_min = sys.float_info.max
    z_max = -sys.float_info.max
    for sliceName in sliceNames:
        sliceDict = channelDict[sliceName]
        if sliceDict.get('metadata') is None:
            sliceDict['metadata'] = readDibHeader(sliceDict['filename'])
        channelSliceFiles.append(sliceDict['filename'])
        z_min = min(z_min, sliceDict['metadata']['centerZ'])
        z_max = max(z_max, sliceDict['metadata']['centerZ'])

    # Order by slice number ('z1' < 'z2' < 'z10').
    # TODO: ideally, files should be sorted by 'centerZ' value to ensure proper ordering.
    channelSliceFiles.sort(key=padForSorting)
    sliceCount = len(channelSliceFiles)

    # x/y geometry comes from the first slice; all slices are assumed to share it.
    dib_metadata = channelDict[sliceNames[0]]['metadata']
    width = dib_metadata['biWidth']
    height = dib_metadata['biHeight']

    x_length = (width - 1) * (1000000.0 / dib_metadata['biXPelsPerMeter'])
    y_length = (height - 1) * (1000000.0 / dib_metadata['biYPelsPerMeter'])
    x_min = dib_metadata['centerX'] - x_length / 2.0
    y_min = dib_metadata['centerY'] - y_length / 2.0
    x_max = x_min + x_length
    y_max = y_min + y_length

    if sliceCount == 1 or z_max <= z_min:
        # Single slice, or all slices reporting the same centerZ: derive the z
        # extent from the header pitch, centred on centerZ, so the box is not flat.
        z_length = max(sliceCount - 1, 1) * (1000000.0 / dib_metadata['biZPelsPerMeter'])
        z_min = dib_metadata['centerZ'] - z_length / 2.0
        z_max = dib_metadata['centerZ'] + z_length / 2.0
    # Otherwise z_min/z_max already span the first and last slice centres.

    channelBBox = ((x_min, y_min, z_min), (x_max, y_max, z_max))

    # Construct the data object, apply the bounding box and store the header as parameters.
    channelObject = hx_object_factory.create('HxUniformScalarField3')
    channelObject.name = channelName
    channelObject.bounding_box = channelBBox
    channelObject.fire()
    dictToParameters(dib_metadata, channelObject)

    # Pixels are unsigned 16-bit; a uint16 buffer keeps values above 32767
    # (an int16 buffer would wrap them to negative numbers).
    dibImage = np.empty((width, height, sliceCount), dtype=np.uint16)
    for sliceIndex, dibFile in enumerate(channelSliceFiles):
        # NOTE(review): slices are stored in reverse file order; the reason is
        # undocumented — verify against the centerZ ordering.
        dibImage[:, :, -1 - sliceIndex] = readDibImage(dibFile, width, height)
    # NOTE(review): the swapaxes/flip orientation was flagged by the original
    # author as needing re-investigation.
    channelObject.set_array(np.flip(np.swapaxes(dibImage, 0, 1), 1))

    channelObject.fire()
    return channelObject

    
def loadDibField(fieldDict, fieldName = 'field', exposeInPool = True):
    """
    Load all channels of a field and group them in an HxMultiChannelField3,
    assuming all channels share the same lattice.

    Channels are processed in sorted key order (e.g. 'd0', 'd1', ...). Each
    is loaded with loadDibChannel, named '<multi-channel name>_<channel key>',
    and connected to the ports channel1, channel2, ... in that order. If no
    port of the required name exists, a warning is printed and that channel
    is left unconnected.

    :param fieldDict: dict ``{channel_key: channel_dict}`` as built by parseDibZStackFolder.
    :param fieldName: name of the created HxMultiChannelField3.
    :param exposeInPool: if true, the multi-channel object and every channel
        object are added to the project.
    :return: the HxMultiChannelField3.
    """
    multiChannelObject = hx_object_factory.create('HxMultiChannelField3')
    multiChannelObject.name = fieldName
    if exposeInPool:
        hx_project.add(multiChannelObject)

    for portIndex, channelKey in enumerate(sorted(fieldDict), start=1):
        channelName = '{}_{}'.format(multiChannelObject.name, channelKey)
        channelObject = loadDibChannel(fieldDict[channelKey], channelName)
        if exposeInPool:
            hx_project.add(channelObject)

        # Ports are named channel1, channel2, ...; look them up by name
        # rather than through a fixed-length if/elif chain.
        port = getattr(multiChannelObject.ports, 'channel{}'.format(portIndex), None)
        if port is None:
            print('loadDibField: {} has no port channel{}; {} is not connected'.format(
                multiChannelObject.name, portIndex, channelName))
        else:
            port.connect(channelObject)

        multiChannelObject.fire()

    return multiChannelObject


# Function to write a Python dictionary to HxField parameters
def dictToParameters(parameterDict,hxfield_object):
    """
    Store every entry of parameterDict in the "DIBHeader" parameter bundle of
    an Amira/Avizo data object.

    The bundle is created first if the object does not have one yet. Each
    key/value pair is then written with the Tcl command
    ``<object name> parameters DIBHeader setValue <key> <value>``. Keys and
    values are inserted verbatim into the command string.
    NOTE(review): nothing is quoted, so this assumes keys/values (and the
    object name) contain no spaces or Tcl special characters.
    """
    objectName = hxfield_object.name
    bundleFlag = _tcl_interp('{} parameters hasBundle DIBHeader'.format(objectName))
    if bundleFlag == '0':
        _tcl_interp('{} parameters newBundle DIBHeader'.format(objectName))
    for key, value in parameterDict.items():
        _tcl_interp('{} parameters DIBHeader setValue {} {}'.format(objectName, key, value))

