#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 13 17:29:17 2019

@author: aurelien
"""

import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import sys
import multipletau
import h5py
import matplotlib

from scipy.interpolate import griddata
from scipy.optimize import curve_fit
from pyfcs.fitting import Gc2D,Gt,Curve
from pyfcs.io import open_SIN
sys.path.append("../..")

from constants import FIG_UNIT, PLOT_FONT_SIZE

def comparison_file_extractor(file,open_stacks=True):
    """Load the contents of an experiment stored in an h5 file into a dict.

    Parameters:
        file (str): path to the h5 file.
        open_stacks (bool): if False, the "stacks" entry is not loaded
            (for less memory consumption).

    Returns:
        dict: one entry per top-level key of the file. The "filenames" group
            is returned as a list ordered by the number embedded in each of
            its keys.
    """
    out = {}
    # The context manager guarantees the file is closed even if reading fails.
    with h5py.File(file, 'r') as h5f:
        for k in h5f.keys():
            if k == "stacks" and not open_stacks:
                continue
            if k != "filenames":
                # dataset[()] reads the whole dataset; it replaces the
                # `.value` attribute that was removed in h5py 3.0.
                # NOTE(review): h5py >= 3 returns variable-length strings as
                # bytes rather than str - confirm callers cope with that.
                out[k] = h5f[k][()]
                continue

            # Keys of the group carry the file number after a 4-character
            # prefix; that number is used to restore the acquisition order.
            numbered = {}
            for kk in h5f["filenames/"]:
                nr = int(kk[4:])
                numbered[nr] = h5f["filenames"][kk][()]
            ordered = sorted(numbered.items())
            nrs = np.array([x[0] for x in ordered])
            # Numbers are expected to be exactly 1..n with no gap.
            assert(np.all(nrs==np.arange(1,nrs.size+1)) )
            out["filenames"] = [x[1] for x in ordered]
    return out


# Global matplotlib font size for every figure produced by this script.
font = {'size'   : PLOT_FONT_SIZE}

matplotlib.rc('font', **font)

# Start from a clean slate: close figures left over from a previous run.
plt.close('all')



# Third argument handed to pyfcs' Gt in the fit models below
# (presumably the triplet time - TODO confirm against pyfcs).
TAU_T = 5*10**-3
# 225 divided by the FWHM-to-sigma factor of a Gaussian, 2*sqrt(2*ln(2)).
CONFOCAL_LATERAL = 225/(2*np.sqrt(2*np.log(2))) #nm

# Half-width of the lateral axis and height of the axial axis in the
# profile plots.
X_SIZE = 3*CONFOCAL_LATERAL
Z_SIZE = 1100

def gaussian2D(x,sigma,amplitude):
    """Evaluate a zero-centred Gaussian of width sigma, scaled by amplitude."""
    exponent = -(x**2)/(2*sigma**2)
    return np.exp(exponent)*amplitude

def get_parameters(path, show_acfs = False, normalise = True, 
                   normalise_int = True):
    """Gets the parameters from one experiment.
    
    Every ``*.npy`` file found in ``path`` is loaded, fitted, and grouped by
    the depth encoded at the end of its file name (after the last "_").
    
    Parameters:
        path (str): path to experiment to analyse
        show_acfs (bool): if True, displays the FCS curves, one panel per depth
        normalise (bool): currently unused, kept for backward compatibility
        normalise_int (bool): if True, returns intensity normalised to highest 
            count value
    Returns:
        dict: containing the relevant information, with keys "maxind",
            "depths", "intensities", "taus", "xy" and "Ns"
    Raises:
        ValueError: if no ``.npy`` file is found in ``path``
        """
    files = sorted(glob.glob(path+"/*.npy"))
    if not files:
        raise ValueError("No .npy file found in "+str(path))
    
    # Depth of each file, in the same order as `files`
    all_depths = [int(os.path.basename(w).split(".")[0].split("_")[-1])
                  for w in files]
    
    def Gfit(tau,N,tauD,T):
        return Gc2D(tau,N,tauD,1)*Gt(tau,T,TAU_T)
    
    # (lower, upper) bounds for N, tauD and T
    bounds = ((10**-1,10**-3,0),(10**4,100,0.9))
    
    du = sorted(set(all_depths))
    all_intensities_d = {d: [] for d in du}
    all_curves_d = {d: [] for d in du}
    
    for file, depth in zip(files, all_depths):
        # NOTE: allow_pickle=True can execute arbitrary code stored in the
        # file; only load files produced by a trusted source.
        corr,intensity = np.load(file,allow_pickle = True)
        curve = Curve(corr=corr,power=2000,fitter=Gfit,bounds=bounds,first_n = 3)
        curve.fit()
        all_curves_d[depth].append(curve)
        all_intensities_d[depth].append(intensity)
        
    if show_acfs:
        # One panel per unique depth (not per file), showing every curve
        # acquired at that depth.
        n_rows = (len(du)+1)//2
        fig,axes = plt.subplots(n_rows,2,squeeze=False)
        axes = axes.ravel()
        for ax, d in zip(axes, du):
            for curve in all_curves_d[d]:
                curve.plot(axes = ax,normalise = False)
            ax.set_title(str(d)+" nm")
        # normpath removes doubled or trailing separators so that the title is
        # the experiment folder name rather than an empty string.
        experiment_dir = os.path.dirname(os.path.normpath(files[0]))
        fig.suptitle(os.path.basename(experiment_dir))
        
    depths = np.array(du)
    all_intensities = np.asarray([np.mean(all_intensities_d[d]) for d in du])
    # Index of the (first) depth with the highest mean intensity
    maxind = int(np.argmax(all_intensities))
    
    all_taus = [np.mean([c.tau for c in all_curves_d[d]]) for d in du]
    # Averaging per depth also works when depths hold different numbers of
    # curves.
    Ns = np.array([np.mean([c.N for c in all_curves_d[d]]) for d in du])
    all_xy = np.array([np.mean([np.sqrt(c.tau) for c in all_curves_d[d]])
                       for d in du])
    if normalise_int:
        all_intensities = all_intensities/all_intensities.max()
    out_dict = {"maxind":maxind,
                "depths":depths,
                "intensities":all_intensities,
                "taus":all_taus,
                "xy":all_xy,
                "Ns": Ns
                }
    return out_dict

def acf_plot(path,save=True,savename="SLB_ACF"):
    """Plots the fitted correlation curves measured at depths 0, 100 and 400.
    
    Every ``*.SIN`` file in ``path`` is opened and fitted; the depth is read
    from the end of the file name (after the last "_"). Plotted curves are
    multiplied by the N obtained from the fit of the depth-0 curve.
    
    Parameters:
        path (str): folder containing the SIN files
        save (bool): if True, saves the figure as ``savename + ".svg"``
        savename (str): output file name, without extension
    """
    files = sorted(glob.glob(path+"/*.SIN"))
    
    depths = [int(os.path.basename(w).split(".")[-2].split("_")[-1])
              for w in files]
    
    def Gfit(tau,N,tauD,T,d):
        return Gc2D(tau,N,tauD,1)*Gt(tau,T,TAU_T)+d
    
    # (lower, upper) bounds for N, tauD, T and the offset d
    bounds = ((10**-3,10**-3,0,-0.1),(10**3,100,0.99,0.1))
    
    # Fit everything first so that the normalisation does not depend on the
    # order in which the files are listed.
    all_curves = []
    for file in files:
        corr,_ = open_SIN(file,channel=1)
        curve = Curve(corr=corr,power=2000,fitter=Gfit,bounds=bounds)
        curve.fit()
        all_curves.append(curve)
    
    # Seed the normalisation with the first depth-0 curve; without any such
    # curve the data are plotted unscaled instead of being multiplied by 0.
    norm = next((c.N for c,d in zip(all_curves,depths) if d==0), 1.0)
    
    fig,ax = plt.subplots(1,figsize = (FIG_UNIT/2,FIG_UNIT))
    for curve,depth in zip(all_curves,depths):
        if depth==0:
            norm = curve.N
        if depth in (0,100,400):
            ax.semilogx(curve.x,curve.y*norm)
            ax.semilogx(curve.x,curve.fitter(curve.x,*curve.popt)*norm,linestyle="--",color="black")
    
    ax.set_xlabel(r"$\rm \tau \ (ms)$")
    ax.get_yaxis().set_visible(False)
    fig.tight_layout()
    if save:
        fig.savefig(savename+".svg",transparent=True)
    
def export_file(file, target_dirs):
    """Exports as a series of npy files a h5 file.
    
    A folder named after the h5 file is created inside one of
    ``target_dirs``, chosen from the content of the file's "notes" entry.
    For each stack, the autocorrelation and the raw stack are saved together
    in ``0_<depth>.npy``, depth being the z position relative to the middle
    one, scaled by 10**9 and truncated to an integer.
    
    Parameters:
        file (str): path to the h5 file
        target_dirs (sequence of str): destination folders, in the order
            zSTED, 2D, CH, confocal, others
    Returns:
        str: path to the folder that was created
    Raises:
        OSError: if the destination folder already exists (os.mkdir)
    """
    ext = comparison_file_extractor(file)
    notes = ext["notes"]
    # Recent h5py versions hand strings back as bytes
    if isinstance(notes, bytes):
        notes = notes.decode("utf-8")
    
    if "zSTED" in notes or "z-STED" in notes:
        target_dir = target_dirs[0]
    elif "2D" in notes:
        target_dir = target_dirs[1]
    elif "CH" in notes:
        target_dir = target_dirs[2]
    elif "confocal" in notes:
        target_dir = target_dirs[3]
    else:
        target_dir = target_dirs[4]
    
    # basename handles the path separators of the running platform
    name = os.path.basename(file).split(".")[0]
    name = target_dir + "/" + name
    
    os.mkdir(name)
    zs = ext['zs']
    zs = zs-zs[zs.size//2]
    # The builtin int replaces np.int, which was removed in NumPy 1.24
    zs = (zs * 10**9).astype(int)
    stacks = ext['stacks']
    
    for j in range(stacks.shape[0]):
        corr = multipletau.autocorrelate(stacks[j],deltat=4*10**-3,normalize=True)
        # corr and the stack have different shapes: store them in an explicit
        # 2-element object array, since building a ragged array implicitly
        # raises in recent NumPy versions.
        pair = np.empty(2, dtype=object)
        pair[0] = corr
        pair[1] = stacks[j]
        np.save(name+"/0_"+str(zs[j]),pair)
    
    # The notes are kept next to the data, both as file name and as content
    with open(name+"/"+notes+".txt","w") as f:
        f.write(notes+" "+notes)

    return name

def export_all(subfiles):
    """Exports every h5 file of subfiles, sorted into one folder per type.

    The destination folders are created in the current working directory if
    they are missing. A file that fails to export is reported on stdout and
    does not interrupt the processing of the remaining ones.
    """
    dirs = ["zSTED","2D","CH","confocal","others"]
    for folder in dirs:
        if os.path.isdir(folder):
            continue
        os.mkdir(folder)

    for file in subfiles:
        try:
            export_file(file,dirs)
        except Exception as e:
            # Best effort: report and move on to the next file
            print(file,"raised the following exception:")
            print(e)

def plot_results_dict(od, normalise = False):
    """Plots the results of analysis of an experiment in a 3-panel figure.

    Left: intensity relative to its value at od["maxind"]. Middle: N and tau
    against depth, divided by their value at od["maxind"] only if normalise
    is True. Right: N and tau always divided by their value at od["maxind"],
    with the y axis limited to [0.8, 5].
    """
    ref = od["maxind"]
    z = od["depths"]
    counts = od["intensities"]
    xy = od["xy"]
    numbers = od["Ns"]
    times = od["taus"]

    def draw_n_and_tau(axis, n_values, tau_values):
        axis.plot(z, n_values, marker="o", label="N")
        axis.plot(z, tau_values, marker="o", label="tau")
        axis.set_xlabel("Depths (nm)")
        axis.set_ylabel("N/N0, tau/tau0")

    plt.figure(figsize = (8,3))

    left = plt.subplot(131)
    left.plot(z, counts/counts[ref])
    left.set_xlabel("Depths (nm)")
    left.set_ylabel("I (norm.)")

    middle = plt.subplot(132)
    if not normalise:
        draw_n_and_tau(middle, numbers, times)
    else:
        draw_n_and_tau(middle, numbers/numbers[ref], times/times[ref])
    middle.legend()

    right = plt.subplot(133)
    draw_n_and_tau(right, numbers/numbers[ref], times/times[ref])
    right.set_ylim(bottom = 0.8, top = 5)
    plt.tight_layout()

def merge_depths(d, th_large = 100, th_low = 15, central_halfwidth = 200):
    """Merges depths values with close proximity.
    
    Walks through d and opens a new group each time the step from the
    previous value exceeds a threshold. The threshold is th_low for values
    within central_halfwidth of the midpoint between the first and last
    element of d, and th_large further away.
    
    Parameters:
        d (array-like): depth values, in the order they should be grouped
        th_large (float): step that opens a new group far from the centre
        th_low (float): step that opens a new group near the centre
        central_halfwidth (float): distance from the centre below which
            th_low applies
    Returns:
        tuple: (new_indices, depths_new). new_indices is a list with, for each
            element of d, the index of the first element of its group.
            depths_new is an array holding that first element for every group.
    """
    d = np.asarray(d)
    if d.size == 0:
        return [], d
    # Magnitude of each step: the comparison must not depend on whether d is
    # ascending or descending.
    steps = np.abs(np.diff(d))
    center = (d[0]+d[-1])/2
    new_indices = [0]
    currind = 0
    for j in range(1,len(d)):
        if np.abs(d[j]-center)>central_halfwidth:
            threshold = th_large
        else:
            threshold = th_low
        if steps[j-1]>threshold:
            currind = j
        new_indices.append(currind)
    depths_new = d[np.unique(new_indices)]
    return new_indices, depths_new

def make_merge_list(l1):
    """Relabels group indices so that consecutive groups differ by at most 1.

    Scanning from left to right, a value exceeding its predecessor by more
    than 1 is replaced, everywhere it occurs in the list, by the predecessor
    plus 1. The input is left untouched; a new list is returned.
    """
    relabelled = list(l1)
    for pos in range(1, len(relabelled)):
        current = relabelled[pos]
        target = relabelled[pos-1]+1
        if not current>target:
            continue
        relabelled = [target if value==current else value
                      for value in relabelled]
    return relabelled

def merge_experiments(dict_list, pool = True):
    """Merges the results of several experiments (results stored in a dict)
    and aligns the different results to the position of maximum intensity.
    
    The depths of each experiment are shifted so that the depth at "maxind"
    becomes 0, then "intensities", "taus", "xy" and "Ns" are averaged over
    all data points falling on the same depth. The input dictionaries are
    not modified.
    
    Parameters:
        dict_list (list of dict): results to merge, each with keys "maxind",
            "depths", "intensities", "taus", "xy" and "Ns"
        pool (bool): if True, nearby depths are additionally pooled together
            using merge_depths and make_merge_list
    Returns:
        dict: merged results, with the same keys as the inputs
    """
    assert(len(dict_list)>0)
    
    # Depths relative to the maximum-intensity position, computed on copies so
    # that the caller's arrays are left untouched.
    shifted = []
    for dic in dict_list:
        depths = np.asarray(dic["depths"])
        shifted.append(depths - depths[dic["maxind"]])
    depths_unique = np.unique(np.concatenate(shifted))
    
    if pool:
        merge_list,depths_merged = merge_depths(depths_unique)
        # corresp_list[i]: bin receiving the data measured at depths_unique[i]
        corresp_list = np.asarray(make_merge_list(merge_list), dtype=int)
        depth_to_save = depths_merged
    else:
        corresp_list = np.arange(depths_unique.size)
        depth_to_save = depths_unique
    n_bins = len(depth_to_save)
    
    # Destination bin of every data point of every experiment. Each shifted
    # depth is an element of the sorted depths_unique, so searchsorted returns
    # its exact position.
    targets = [corresp_list[np.searchsorted(depths_unique, s)] for s in shifted]
    n_points = sum(len(s) for s in shifted)
    
    def merge_quantity(quantity_name):
        total = np.zeros(n_bins, dtype=float)
        counter = np.zeros(n_bins, dtype=int)
        for od,bins in zip(dict_list,targets):
            values = np.asarray(od[quantity_name], dtype=float)[:len(bins)]
            total += np.bincount(bins, weights=values, minlength=n_bins)
            counter += np.bincount(bins, minlength=n_bins)
        # Every bin must receive data, and every data point must be counted
        assert(np.count_nonzero(counter)==counter.size)
        assert(np.sum(counter)==n_points)
        return total/counter
    
    # The reference position is the bin that received depth 0. When pooling,
    # that bin may be represented by a nearby non-zero depth, so it cannot be
    # found by searching depth_to_save for 0.
    zero_position = np.where(depths_unique==0)[0][0]
    new_maxind = int(corresp_list[zero_position])
    new_dict = {"maxind":new_maxind,
        "depths":depth_to_save,
        "intensities": merge_quantity("intensities"),
        "taus":merge_quantity("taus"),
        "xy":merge_quantity("xy"),
        "Ns": merge_quantity("Ns")
        }
    return new_dict


def plot_profile_fromdict(pardict, ax0, ax1, debug = False, xy0_confocal = 1.41):
    """Plots the results of an experiment stored in a dict.
    
    Only the depths located on one side of the intensity maximum are kept
    (those that are >= 0 once the depth axis is re-centred on the maximum and
    flipped). ax0 receives an interpolated image made of one Gaussian profile
    per depth; ax1 receives the total, "signal" and "background" intensities
    against depth.
    
    Parameters:
        pardict (dict): results, with keys "Ns", "depths", "intensities",
            "xy" and "maxind"
        ax0 (matplotlib axes): axes receiving the image
        ax1 (matplotlib axes): axes receiving the intensity curves
        debug (bool): currently unused, kept for backward compatibility
        xy0_confocal (float): value of the "xy" quantity that is mapped onto
            CONFOCAL_LATERAL
    Returns:
        the image object returned by ax0.imshow for the interpolated profile
    """
    Ns = np.asarray(pardict["Ns"], dtype=float)
    depths = np.asarray(pardict["depths"])
    intensities = np.asarray(pardict["intensities"], dtype=float)
    all_xy = np.asarray(pardict["xy"], dtype=float)
    maxind = pardict["maxind"]
    assert(intensities[maxind]==intensities.max())
    
    # Re-centre on the maximum, flip the axis and keep one side only
    depths = -1* (depths - depths[maxind])
    mk = depths>=0
    # Position of the maximum inside the masked arrays: the number of kept
    # points located before it (the maximum itself is always kept).
    maxind = int(np.count_nonzero(mk[:maxind]))
    intensities = intensities[mk]
    all_xy = all_xy[mk]
    depths = depths[mk]
    Ns = Ns[mk]
    
    intensities = intensities/intensities[maxind]
    Ns = Ns/Ns[maxind]
    
    # Rescale so that xy0_confocal corresponds to CONFOCAL_LATERAL
    all_xy = all_xy/xy0_confocal
    all_xy = all_xy*CONFOCAL_LATERAL
    
    # N that each depth would have if it only scaled with the squared width
    Nexpected = Ns[maxind] * all_xy**2/all_xy[maxind]**2
    snb_inv = np.sqrt(np.abs((Ns - Nexpected ))/Nexpected)
    snb_inv[(Ns-Nexpected)<0]=0
    snb_inv[Ns>15]=np.inf
    print()
    print("Background/signal",snb_inv)
    
    x1 = np.linspace(-X_SIZE,X_SIZE,25)
    
    # Split the intensity into a "background" (ncorr) and a "signal" (corr)
    # part. snb_inv == 0 makes 1/snb_inv infinite, which yields the intended
    # zero background, hence the silenced division warning.
    with np.errstate(divide="ignore"):
        int_ncorr = intensities/(1+1/snb_inv)
    int_corr = intensities/(1+snb_inv)
    
    profiles_correlating = np.asarray(
        [gaussian2D(x1,dx,np.sum(ints+ncints))
         for (dx,ints,ncints) in zip(all_xy,int_corr,int_ncorr)])
    
    xx,yy=np.meshgrid(x1,depths)
    xi,yi = np.meshgrid(np.linspace(x1.min(),x1.max(),100),
                       np.linspace(depths.min(),depths.max(),100))
    
    zcorr_i = griddata((xx.reshape(-1), yy.reshape(-1)),
                       profiles_correlating.reshape(-1), (xi, yi),
                       method='linear')
    zcorr_i = np.flipud(zcorr_i)
    
    # Black background covering the full axial range, with the profile on top
    ax0.imshow(np.zeros((20,20)),extent=[-X_SIZE,X_SIZE,0,Z_SIZE],cmap='hot')
    cs = ax0.imshow(zcorr_i,extent=[-X_SIZE,X_SIZE,0,depths.max()],cmap='afmhot')
    
    # Calculation of total SBR
    def mk_weights(zd):
        """Weights each point based on spatial sampling: half the distance
        to each of its neighbours."""
        half_steps = np.diff(np.asarray(zd, dtype=float))/2
        weights = np.zeros(len(zd))
        weights[:-1] += half_steps
        weights[1:] += half_steps
        return weights
    ws = mk_weights(depths)
    sbr = round(np.sum(int_corr*ws)/np.sum(int_ncorr*ws),1)
    print("SBR total:",sbr)
    textstr = "SBR "+str(round(sbr,1))
    props = dict(boxstyle='round', facecolor='white', alpha=1)
    ax0.text(0.7, 0.95, textstr, transform=ax0.transAxes, fontsize=8,
             verticalalignment='top', bbox=props)

    ax0.set_xlabel("x [nm]")
    ax0.set_ylabel('z [nm]')
    ms = 2
    ax1.plot(intensities,depths,"-s",label=r"total", markersize=ms)
    ax1.plot(int_corr,depths,'-o',color="green",label="signal", markersize=ms)
    ax1.plot(int_ncorr,depths,'-v',color="magenta",label="background", markersize=ms)
    ax1.set_xlabel(r"$\rm I/I_{0}$")
    ax1.set_ylabel("z [nm]")
    ax1.set_ylim(top = Z_SIZE,bottom=0)
    ax1.set_xlim(left=0,right=1)
    
    # Optional Gaussian fit of intensity vs depth, switched off by default
    plot_fit = False
    if plot_fit:
        def g1(x,sigma,amplitude,z0):
            return np.exp(-(x-z0)**2/(2*sigma**2) )*amplitude
        popt,_ = curve_fit(g1,depths,
                           intensities, bounds = 
                           ((20,0.2,depths.min()),(depths.max(),2,depths.max())))
        ax1.plot(g1(depths,*popt),depths,color="black",linestyle="--")
    
    return cs

# Experiment folders used when a single acquisition per condition is plotted
# (merge = False below).
fold_ch = 'CH/16_10_19_20h22m53/'
fold_2d = '2D/16_10_19_20h09m01/'
fold_conf = 'confocal/16_10_19_19h32m48/'
fold_z = 'z/16_10_19_20h13m32/'

od_conf = get_parameters(fold_conf,show_acfs=False)
# NOTE(review): xy0_conf is computed but never handed to
# plot_profile_fromdict, which therefore uses its default xy0_confocal.
xy0_conf = np.sqrt(od_conf["taus"][od_conf["maxind"]])

folders_selected = [fold_conf,fold_2d,fold_z,fold_ch]
names = ["confocal","2D","z","CH"]
figsize = (FIG_UNIT*1.5,FIG_UNIT*0.8)

# Pooling of nearby depths when merging experiments. Defined here because
# both the merge section and the confocal fit below rely on it.
pool = True

  
# Optional inspection of every experiment of one condition, one by one
show_onetype = False
if show_onetype:
    folder = "2D"
    folds = glob.glob(folder+"/*/")
    folds = list(filter(lambda x: "trash" not in x,folds))
    
    ods = [get_parameters(w, show_acfs=True) for w in folds]
    mdict1 = merge_experiments(ods, pool = False)

    for od in ods:
        fig, (ax0, ax1) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [3, 1]},
           figsize=figsize,sharey = True)
        plot_profile_fromdict(od,ax0,ax1)
        plot_results_dict(od)

# Raw merge is ok for 2D
# ok-ish for zSTED
# ok-ish for CH
# ok-ish for confocal
merge = True
gridspec_kw = {'width_ratios': [1,1.2]}
if merge:
    # One figure per condition, merging every experiment found in the folder
    # named after the condition (folders containing "trash" are ignored).
    for j,fold in enumerate(names):
        folds = glob.glob(fold+"/*/")
        folds = list(filter(lambda x: "trash" not in x,folds))
        ods = [get_parameters(w) for w in folds]
        mdict = merge_experiments(ods, pool = pool)
        
        fig, (ax0, ax1) = plt.subplots(1, 2, gridspec_kw = gridspec_kw,
               figsize=figsize,sharey = True)
        plot_profile_fromdict(mdict,ax0,ax1)
        ax0.set_title(fold)
        fig.tight_layout()
        if fold=="CH":
            ax1.legend()
        fig.savefig(fold+"_SLB.svg",transparent = True)


else:
    # One figure per condition, from the single selected experiment
    for j,fold in enumerate(folders_selected):
        od = get_parameters(fold,show_acfs=False)
        fig, (ax0, ax1) = plt.subplots(1, 2, gridspec_kw = gridspec_kw,
           figsize=figsize,sharey = True)
        plot_profile_fromdict(od,ax0,ax1)
        ax0.set_title(names[j])
        fig.tight_layout()
        fig.savefig(names[j]+"_SLB.svg",transparent = True)

# The last figure drawn above is saved once more, this time with a legend
ax1.legend()
fig.savefig(names[j]+"_legend.svg",transparent = True)
# intensity in confocal vs 2D STED
compare_absolute_intensities = False

if compare_absolute_intensities:
    plt.figure()
    folders = ["confocal", "2D"]
    colors=["gray","blue"]
    for i,folder in enumerate(folders):
        folds = glob.glob(folder+"/*/")
        # Remove bad curves, i.e that don't converge towards 0 at long lag times
        folds = list(filter(lambda x: "trash" not in x,folds))
        ods = [get_parameters(w, normalise_int=False) for w in folds]
        intensities = [w["intensities"] for w in ods]
        depths = [w["depths"] for w in ods]
        for j in range(len(intensities)):
            plt.plot(depths[j], intensities[j], color=colors[i])
        
# Fit of the merged confocal intensity profile against depth
FIT_CONFOCAL = True
if FIT_CONFOCAL:
    def g1(x,x0,sigma,amplitude):
        return np.exp(-(x-x0)**2/(2*sigma**2) )*amplitude
    def g2(x,x0,sigma,amplitude):
        return amplitude/(1+(x-x0)**2/sigma**2)
    
    folds = glob.glob("confocal"+"/*/")
    folds = list(filter(lambda x: "trash" not in x,folds))
    
    ods = [get_parameters(w) for w in folds]
    mdict = merge_experiments(ods, pool = pool)
    depths = mdict["depths"]
    intensities = mdict["intensities"]
    # (lower, upper) bounds for x0, sigma and amplitude
    bounds = ((depths.min(),0,0),(depths.max(),depths.max(),intensities.max()*2))
    # Model used for the fit: g2 here, g1 being the Gaussian alternative
    gg = g2
    popt,_ = curve_fit(gg,depths,intensities, bounds = bounds)
    
    plt.figure()
    plt.plot(depths,intensities)
    plt.plot(depths,gg(depths,*popt),color="black",linestyle="--")

    """
    fig, (ax0, ax1) = plt.subplots(1, 2, gridspec_kw = gridspec_kw,
           figsize=figsize,sharey = True)
    plot_profile_fromdict(mdict,ax0,ax1)
    ax0.set_title(fold)
    fig.tight_layout()"""