#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 25 10:03:49 2019

@author: aurelien
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib

import glob
from pyfcs.fitting import Gc,Gt,Curve,make_res_functions,Gc_z

from accessories.misc import file_extractor, aberration_names
import sys
sys.path.append("../..")

from constants import PLOT_DEFAULT,FIG_UNIT,PLOT_FONT_SIZE
# excitation 10 uW in BFP
def reorder(x,y):
    """Sort two parallel sequences by the values of ``x``.

    Pairs are ordered by ``x`` only. Ties keep their original relative order
    (stable sort), which has two consequences:
    - the ``y`` items never need to be comparable (e.g. Curve objects);
    - two calls with the same ``x`` always apply the same permutation, so
      results of separate calls stay paired element by element.

    Parameters:
        x: sequence of sortable keys (e.g. aberration values)
        y: sequence of values, same length as ``x``
    Returns:
        new_x, new_y: tuples with x and y reordered by ascending x.
        Both are empty tuples when the inputs are empty.
    """
    pairs = sorted(zip(x, y), key=lambda pair: pair[0])
    if not pairs:
        # zip(*[]) would yield nothing to unpack
        return (), ()
    new_x, new_y = zip(*pairs)
    return new_x, new_y


plt.close('all')

# Model constants shared by the fitting functions below
K_CONFOCAL = 4           # k argument of Gc for the confocal model (Gconf)
ALPHA = 1                # exponent argument of Gc / Gc_z
TAU_T = 11*10**-3        # time constant passed to Gt
T_AMPL = 0.16            # used as the narrow [T_AMPL, T_AMPL+1e-3] range in `bounds`

font = dict(size=PLOT_FONT_SIZE)
matplotlib.rc('font', **font)
def Gconf(tau,N,tauD,T):
    """Confocal fitting model: Gc(...) times Gt(...).

    The Gc factor uses the module constants K_CONFOCAL and ALPHA. The Gt factor
    uses the amplitude T and the time constant TAU_T. Gt is presumably a
    triplet/blinking term — confirm in pyfcs.fitting.
    """
    diffusion_part = Gc(tau, N, tauD, K_CONFOCAL, ALPHA)
    blinking_part = Gt(tau, T, TAU_T)
    return diffusion_part * blinking_part

#from aotools.ext.fcs import k_dependency

def sted_maker(tau_confocal,fk):
    """Build a STED fitting function f(tau, N, tauD, T).

    The k argument of Gc is not a constant here. It is computed from the
    fitted tauD as fk(r), with r = (tauD / tau_confocal) ** (ALPHA / 2).
    The result is multiplied by the same Gt factor as Gconf.

    Parameters:
        tau_confocal: float, diffusion time obtained from the confocal fits
        fk: callable mapping r to the k argument of Gc
    Returns:
        callable f(tau, N, tauD, T)
    """
    def f(tau,N,tauD,T):
        ratio = (tauD / tau_confocal) ** (ALPHA / 2)
        k_value = fk(ratio)
        return Gc(tau, N, tauD, k_value, ALPHA) * Gt(tau, T, TAU_T)
    return f

def sted_maker_2D(tau_confocal,fk):
    """Build a 2D-STED fitting function f(tau, N, tauD, T) based on Gc_z.

    The axial time passed to Gc_z is fixed at K_CONFOCAL**2 * tau_confocal.
    The fitted tauD only enters as the lateral time.

    Parameters:
        tau_confocal: float, diffusion time obtained from the confocal fits
        fk: accepted for signature parity with sted_maker; not used
    Returns:
        callable f(tau, N, tauD, T)
    """
    axial_time = tau_confocal * K_CONFOCAL ** 2
    def f(tau,N,tauD,T):
        diffusion_part = Gc_z(tau, N, tauD, axial_time, ALPHA)
        return diffusion_part * Gt(tau, T, TAU_T)
    return f

# Fit bounds (lower, upper) for the three fit parameters, presumably in the
# order (N, tauD, T) of the fitter signatures — confirm in pyfcs Curve.
# The third parameter is confined to [T_AMPL, T_AMPL+1e-3], i.e. effectively fixed.
bounds = ((10**-3,10**-3,T_AMPL),(1000,1,T_AMPL+1e-3))

path = "../../data_analysis/FCS/2019_01_29_aberrations/"

files = glob.glob(path+"/*.h5")
files.sort()

# One-off preprocessing step used to produce the correlations:
# for file in files:
#     correlate_modal_h5(file,erase=True,first=50000)

path_confocal = "../../data_analysis/FCS/2019_01_29_aberrations/confocal/29_01_19_14h46m26/"
confocal_files = glob.glob(path_confocal+"*.npy")
if not confocal_files:
    # np.mean([]) would give NaN and silently corrupt every STED fit below
    raise FileNotFoundError("No confocal .npy files found in " + path_confocal)
confocals = [Curve(corr = np.load(w), fitter = Gconf, power=0,bounds=bounds)
             for w in confocal_files]
# NOTE(review): unlike the STED curves, .fit() is never called on these;
# presumably Curve fits on construction — confirm.
# Confocal reference values used to normalise the STED results
tau_confocal = np.mean([w.tau for w in confocals])
N_confocal = np.mean([w.N for w in confocals])


def plot_aberration_variation(file,sted_type,axN,axV,axcurves1,axcurves2,labl,
                              show_half = True):
    """Extract, fit and plot the data from one aberration file.

    Parameters:
        file: string, path to h5 file containing data for 1 mask
        sted_type: STED type passed to make_res_functions ("2D", "3D" or a
            numeric value such as 0.75)
        axN: mpl axis, where to plot number of molecules vs aberration
        axV: mpl axis, where to plot focal volumes vs aberration
        axcurves1: mpl axis, where to plot the curve(s) at aberration 0
        axcurves2: mpl axis, where to plot the curve(s) at aberration 1
        labl: string, label that appears in legends
        show_half: bool, if True only non-negative aberration values are
            plotted on axN and axV
    """
    ext = file_extractor(file,open_stacks=False)
    _, fk, fv = make_res_functions(sted_type,K_CONFOCAL)
    if sted_type=="2D":
        fk = lambda x: K_CONFOCAL / x
        fv = lambda x: x**2
        Gsted = sted_maker_2D(tau_confocal,fk)
    else:
        Gsted = sted_maker(tau_confocal,fk)
    sted_curves = [Curve(corr=w, fitter=Gsted, power=1000, bounds=bounds)
                   for w in ext["log_autocorrelations"]]
    for curve in sted_curves:
        curve.fit()
    volumes = np.array([fv(np.sqrt(w.tau/tau_confocal)) for w in sted_curves])
    Ns = np.array([w.N for w in sted_curves])

    # Keep xdata aligned with the curves (zip semantics: shortest wins)
    xdata = np.atleast_1d(ext["xdata"].squeeze())[:len(sted_curves)]

    # Aberration values whose curves are shown on axcurves1 / axcurves2
    c_to_show = [0,1]
    for target, ax in zip(c_to_show, (axcurves1, axcurves2)):
        for xd, curve in zip(xdata, sted_curves):
            # tolerant comparison: aberration values are floats read from file
            if np.isclose(xd, target):
                curve.plot(axes=ax, label=labl, normalise=False)

    # One stable permutation for all arrays so volumes and Ns stay paired
    # even when some aberration values are repeated.
    order = np.argsort(xdata, kind="stable")
    x1 = xdata[order]
    v1 = volumes[order]
    n1 = Ns[order]
    if show_half:
        keep = x1 >= 0
        x1, v1, n1 = x1[keep], v1[keep], n1[keep]
    axV.plot(x1,v1,marker="o",label=labl)
    axN.plot(x1,n1,marker="o",label=labl)

def get_aberration_variation(file,sted_type):
    """Extract and fit the data from one aberration file (no plotting).

    Each fitted Curve stores its aberration value in its ``power`` field, so
    it can be identified after sorting.

    Parameters:
        file: string, path to h5 file containing data for 1 mask
        sted_type: STED type passed to make_res_functions ("2D", "3D" or a
            numeric value such as 0.75)
    Returns:
        x1: numpy array, aberration values in rad, sorted ascending
        n1: numpy array, fitted numbers of molecules, same order as x1
        v1: numpy array, observation volumes (fv of sqrt(tau/tau_confocal)),
            same order as x1
        new_curves: list of the fitted Curve objects, same order as x1
    """
    ext = file_extractor(file,open_stacks=False)
    _, fk, fv = make_res_functions(sted_type,K_CONFOCAL)
    if sted_type=="2D":
        fk = lambda x: K_CONFOCAL / x
        fv = lambda x: x**2
        Gsted = sted_maker_2D(tau_confocal,fk)
    else:
        Gsted = sted_maker(tau_confocal,fk)
    xdata = np.atleast_1d(ext["xdata"].squeeze())
    # store aberration value in power field of FCS curve
    sted_curves = [Curve(corr=w, fitter=Gsted, power=xd, bounds=bounds)
                   for w, xd in zip(ext["log_autocorrelations"], xdata)]
    for curve in sted_curves:
        curve.fit()
    volumes = np.array([fv(np.sqrt(w.tau/tau_confocal)) for w in sted_curves])
    Ns = np.array([w.N for w in sted_curves])

    # Keep xdata aligned with the curves (zip semantics: shortest wins)
    xdata = xdata[:len(sted_curves)]
    # A single stable permutation keeps x, N, V and the curves paired, and
    # never compares Curve objects (which may not be orderable) on ties.
    order = np.argsort(xdata, kind="stable")
    x1 = np.asarray(xdata)[order]
    n1 = Ns[order]
    v1 = volumes[order]
    new_curves = [sted_curves[i] for i in order]
    return x1,n1,v1,new_curves

def plot_one_file(file,sted_type,show_half = False):
    """Fit and plot every curve of one aberration file in a new 2x2 figure.

    Obsolete script, can be used for visualisation of more data than currently.
    Top row: volume (left) and number of molecules (right) vs aberration.
    Bottom row: the pair of axes handed to Curve.plot, titled "ACFs" and
    "Residuals".

    Parameters:
        file: string, path to h5 file containing data for 1 mask
        sted_type: STED type passed to make_res_functions ("2D", "3D" or a
            numeric value such as 0.75)
        show_half: bool, if True only non-negative aberration values are
            plotted in the top row
    """
    ext = file_extractor(file,open_stacks=False)
    modes = int(ext["modes"].squeeze())

    fig,axes = plt.subplots(2,2)
    axcurves = axes[1,:]
    axV = axes[0,0]
    axN = axes[0,1]

    # Honour the caller's sted_type; 2D uses the 2D model like the other helpers
    _, fk, fv = make_res_functions(sted_type,K_CONFOCAL)
    if sted_type=="2D":
        fk = lambda x: K_CONFOCAL / x
        fv = lambda x: x**2
        Gsted = sted_maker_2D(tau_confocal,fk)
    else:
        Gsted = sted_maker(tau_confocal,fk)
    sted_curves = [Curve(corr=w, fitter=Gsted, power=1000, bounds=bounds)
                   for w in ext["log_autocorrelations"]]
    for curve in sted_curves:
        curve.fit()
    volumes = np.array([fv(np.sqrt(w.tau/tau_confocal)) for w in sted_curves])
    Ns = np.array([w.N for w in sted_curves])

    # Keep xdata aligned with the curves (zip semantics: shortest wins)
    xdata = np.atleast_1d(ext["xdata"].squeeze())[:len(sted_curves)]
    for k, (xd, curve) in enumerate(zip(xdata, sted_curves)):
        # small vertical offset per curve so residuals remain distinguishable
        curve.plot(axes=axcurves,label=xd,normalise=False,
                   offset_residuals_only=True,offset = k*0.02)

    # One stable permutation so volumes and Ns stay paired on repeated values
    order = np.argsort(xdata, kind="stable")
    x1 = xdata[order]
    v1 = volumes[order]
    n1 = Ns[order]
    if show_half:
        keep = x1 >= 0
        x1, v1, n1 = x1[keep], v1[keep], n1[keep]
    axV.plot(x1,v1,marker="o")
    axN.plot(x1,n1,marker="o")
    axN.set_xlabel("Bias (rad)")
    axN.set_ylabel("N ")

    axV.set_xlabel("Bias (rad)")
    axV.set_ylabel("V/Vconfocal ")

    axplots1,axplots2  = axcurves
    axplots1.set_xlabel(r"$\rm \tau (ms)$")
    axplots1.set_ylabel(r"$\rm G(\tau)$")
    axplots1.set_title("ACFs")
    axplots1.legend()

    axplots2.set_xlabel(r"$\rm \tau (ms)$")
    axplots2.set_ylabel(r"$\rm G(\tau)$")
    axplots2.set_title("Residuals")
    axplots2.legend()
    fig.suptitle(aberration_names[modes])
    fig.tight_layout()
    

def plot_series(fseries,types,names):
    """Compare several STED types for one aberration in a new 2x2 figure.

    Left column: N (top) and V/Vconfocal (bottom) vs aberration. Right column:
    the fitted curves at 0 rad (top) and 1 rad (bottom).

    Parameters:
        fseries: list of strings, one h5 file per STED type
        types: list, STED type of each file (see make_res_functions)
        names: list of strings, legend label of each file
    The three lists are walked in lockstep; any number of entries is
    supported.
    """
    fig,axes = plt.subplots(2,2)

    axN = axes[0,0]
    axV = axes[1,0]
    axplots1 = axes[0,1]
    axplots2 = axes[1,1]

    for ff, tp, name in zip(fseries, types, names):
        plot_aberration_variation(ff,tp,axN,axV,axplots1,axplots2,name)

    axN.set_xlabel("Bias (rad)")
    axN.set_ylabel("N ")

    axV.set_xlabel("Bias (rad)")
    axV.set_ylabel("V/Vconfocal ")

    axplots1.set_xlabel(r"$\rm \tau (ms)$")
    axplots1.set_ylabel(r"$\rm G(\tau)$")
    axplots1.set_title("FCS curves, no aberration")
    axplots1.legend()

    axplots2.set_xlabel(r"$\rm \tau (ms)$")
    axplots2.set_ylabel(r"$\rm G(\tau)$")
    axplots2.set_title("FCS curves, 1 rad")
    axplots2.legend()
    axV.legend()
    # volume ratio 1 = same volume as confocal
    axV.axhline(1,color="gray",linestyle="--")

    axplots1.set_ylim(0,0.200)
    axplots2.set_ylim(0,0.200)
    fig.tight_layout()

def sensitivity(ns,vs,i0,i1,threshold=1.5):
    """Classify the sensitivity to an aberration, print it and return it.

    For index i1 and for the index just before it (i1-1), take the larger of
    the volume ratio vs[i]/vs[i0] and the molecule-number ratio ns[i]/ns[i0].
    In the callers i0 is the 0 rad point and i1 the 1 rad point, so i1-1 is
    presumably the 0.5 rad point — confirm against the data.

    Parameters:
        ns: sequence, numbers of molecules
        vs: sequence, observation volumes
        i0: int, index of the reference (unaberrated) point
        i1: int, index of the aberrated point
        threshold: float, ratio above which a change counts as significant
    Returns:
        "high" if the ratio at i1-1 exceeds threshold, "medium" if only the
        ratio at i1 does, "low" otherwise.
    """
    r1 = max(vs[i1]/vs[i0],ns[i1]/ns[i0])
    r2 = max(vs[i1-1]/vs[i0],ns[i1-1]/ns[i0])
    print(r1,r2)
    if r2>threshold:
        level = "high"
    elif r1>threshold:
        level = "medium"
    else:
        level = "low"
    print(level)
    return level

def plot_series2(fseries,types,names,plot_curves=False,savename = None,
                 legend=False,maxab = 1.5, verbose = True):
    """Plot N/N0 against aberration amplitude for several STED types.

    Obsolete script.

    Parameters:
        fseries: list of strings, one h5 file per STED type
        types: list, STED type of each file (see make_res_functions)
        names: list of strings, labels; also keys of PLOT_DEFAULT
        plot_curves: bool, if True also plot the ACFs at 0 and 1 rad in a
            second figure
        savename: string or None, axis title; if not None also the base name
            of the saved .svg file(s)
        legend: bool, if True add a legend
        maxab: float, largest aberration value (rad) plotted
        verbose: bool, if True print the sensitivity classification
    """
    fig,ax1 = plt.subplots(1,1,figsize = (FIG_UNIT,FIG_UNIT))
    if verbose:
        print("\n\n",savename)
    if plot_curves:
        fig2,axes = plt.subplots(1,2,figsize = (2*FIG_UNIT,FIG_UNIT))
        axplots1,axplots2 = axes

    # Indices of the 0 rad and 1 rad points among the sorted aberration values
    i0 = 4
    i1 = 6
    for ff, tp, name in zip(fseries, types, names):
        xab,ns,vs,curves = get_aberration_variation(ff,tp)
        assert(np.isclose(xab[i0],0.0))
        assert(np.isclose(xab[i1],1.0))
        if verbose:
            print(tp)
            sensitivity(ns,vs,i0,i1)

        msk = np.logical_and(xab>=0,xab<=maxab)
        xab = xab[msk]
        ns = ns[msk]
        vs = vs[msk]
        # Same mask for the curves, so indices into curves match xab
        curves = [c for c, keep in zip(curves, msk) if keep]

        ax1.plot(xab,ns/ns[0],**PLOT_DEFAULT[name],linestyle = "-",label=name)

        # FCS curves
        if plot_curves:
            # tolerant lookup: aberration values are floats read from file
            c1 = curves[np.flatnonzero(np.isclose(xab, 0.0))[0]]
            c2 = curves[np.flatnonzero(np.isclose(xab, 1.0))[0]]
            assert(c1 is not c2)
            assert(np.isclose(c1.power, 0))
            assert(np.isclose(c2.power, 1))
            axplots1.semilogx(c1.x,c1.y,color=PLOT_DEFAULT[name]["color"],label=name)
            axplots2.semilogx(c2.x,c2.y,color=PLOT_DEFAULT[name]["color"],label=name)

    if plot_curves:
        axplots1.set_ylim(0,0.200)
        axplots2.set_ylim(0,0.200)

        axplots1.set_xlabel(r"$\rm \tau (ms)$")
        axplots1.set_ylabel(r"$\rm G(\tau)$")
        axplots1.set_title("ACFs not aberrated")

        axplots2.set_xlabel(r"$\rm \tau (ms)$")
        axplots2.set_ylabel(r"$\rm G(\tau)$")
        axplots2.set_title("ACFs 1 rad")
        axplots2.legend()
        fig2.tight_layout()

    ax1.set_xlabel("Bias (rad)")
    ax1.set_ylabel(r"$\rm N/N_{0}$")
    ax1.set_title(savename)
    ax1.set_yscale("log")
    ax1.set_ylim(bottom = 0.8, top = 15)
    ax1.grid(axis="y",which="both")
    if legend:
        ax1.legend()
    fig.tight_layout()
    if savename is not None:
        fig.savefig(savename+".svg",dpi=600,transparent = True)
        if plot_curves:
            fig2.savefig(savename+"_curves.svg",dpi=600,transparent = True)
            
def plot_series3(fseries,types,names,plot_curves=False,savename = None,
                 legend=False,maxab = 1, verbose = True):
    """Show N/N_confocal against V/V_confocal for several STED types.

    Each file gives one line, plus one marker per non-negative aberration
    value up to maxab. Marker size grows with the aberration index, and each
    marker is labelled with its aberration value. A dashed gray y = x line is
    drawn for reference.

    Parameters:
        fseries: list of strings, one h5 file per STED type
        types: list, STED type of each file (see make_res_functions)
        names: list of strings, labels; also keys of PLOT_DEFAULT
        plot_curves: bool, if True also plot the ACFs at 0 and 1 rad in a
            second figure
        savename: string or None, axis title; if not None also the base name
            of the saved .svg file(s)
        legend: bool, if True draw markers in black and add a legend
        maxab: float, largest aberration value (rad) plotted
        verbose: bool, if True print the sensitivity classification and
            fitted values
    """
    fig,ax1 = plt.subplots(1,1,figsize = (FIG_UNIT,FIG_UNIT))
    if verbose:
        print("\n\n",savename)
    if plot_curves:
        fig2,axes = plt.subplots(1,2,figsize = (2*FIG_UNIT,FIG_UNIT))
        axplots1,axplots2 = axes

    # Marker sizes for successive aberration values; the last size is reused
    # when there are more points than sizes.
    ss = [10,40,80]
    # Indices of the 0 rad and 1 rad points among the sorted aberration values
    i0 = 4
    i1 = 6
    for ff, tp, name in zip(fseries, types, names):
        xab,ns,vs,curves = get_aberration_variation(ff,tp)
        assert(np.isclose(xab[i0],0.0))
        assert(np.isclose(xab[i1],1.0))
        if verbose:
            print(tp)
            sensitivity(ns,vs,i0,i1)

        msk = np.logical_and(xab>=0,xab<=maxab)
        xab = xab[msk]
        ns = ns[msk]
        vs = vs[msk]
        # Same mask for the curves, so indices into curves match xab
        curves = [c for c, keep in zip(curves, msk) if keep]

        line_color = PLOT_DEFAULT[name]["color"]
        n_ratio = ns/N_confocal
        ax1.plot(vs,n_ratio,color = line_color,linestyle = "-",label=name)
        if verbose:
            print("File",ff,"vs",vs,"Ns",n_ratio)
        marker_color = "k" if legend else line_color
        for k, (xk, vk, nk) in enumerate(zip(xab, vs, n_ratio)):
            ax1.scatter(vk,nk,label=f"{xk:g}",color = marker_color,
                        s = ss[min(k, len(ss)-1)],marker="o")

        # FCS curves
        if plot_curves:
            # tolerant lookup: aberration values are floats read from file
            c1 = curves[np.flatnonzero(np.isclose(xab, 0.0))[0]]
            c2 = curves[np.flatnonzero(np.isclose(xab, 1.0))[0]]
            assert(c1 is not c2)
            assert(np.isclose(c1.power, 0))
            assert(np.isclose(c2.power, 1))
            axplots1.semilogx(c1.x,c1.y,color=line_color,label=name)
            axplots2.semilogx(c2.x,c2.y,color=line_color,label=name)

    if plot_curves:
        axplots1.set_ylim(0,0.200)
        axplots2.set_ylim(0,0.200)

        axplots1.set_xlabel(r"$\rm \tau (ms)$")
        axplots1.set_ylabel(r"$\rm G(\tau)$")
        axplots1.set_title("ACFs not aberrated")

        axplots2.set_xlabel(r"$\rm \tau (ms)$")
        axplots2.set_ylabel(r"$\rm G(\tau)$")
        axplots2.set_title("ACFs 1 rad")
        axplots2.legend()
        fig2.tight_layout()

    # Reference diagonal N/N_confocal = V/V_confocal; extent fixed for
    # consistent axes across figures.
    nmax = 8
    ax1.plot(np.linspace(0.1,nmax),np.linspace(0.1,nmax),color="gray",
             linestyle="--")

    ax1.set_xlabel(r"$\rm V/V_{confocal}$")
    ax1.set_ylabel(r"$\rm N/N_{confocal}$")
    ax1.set_title(savename)
    ax1.set_yscale("log")
    ax1.set_xscale("log")
    ax1.grid(axis="y",which="both")
    if legend:
        handles, labels = ax1.get_legend_handles_labels()
        # Marker labels repeat for every file: keep one entry per label
        unique = dict(zip(labels, handles))
        ax1.legend(list(unique.values()), list(unique.keys()))
    fig.tight_layout()
    if savename is not None:
        fig.savefig(savename+".svg",dpi=600,transparent = True)
        if plot_curves:
            fig2.savefig(savename+"_curves.svg",dpi=600,transparent = True)
# Spherical files picked by position so that they line up with `types`/`names`
files_spherical = [files[k] for k in (3, 1, 2)]


def _files_containing(keyword):
    """Return the sorted files whose path contains keyword, minus the first."""
    selected = sorted(w for w in files if keyword in w)
    selected.pop(0) #get rid of CH 0.9
    return selected


files_coma = _files_containing("coma")
files_tip = _files_containing("tip")
files_astigmatism = _files_containing("astigmatism")


types = [0.75,"2D","3D"]
names = ["CH","2D","z"]


plot_series3(files_spherical,types,names,plot_curves = True,savename="Spherical",legend=False)

for series, title in ((files_tip, "Tilt"), (files_coma, "Coma"),
                      (files_astigmatism, "Astigmatism")):
    plot_series3(series,types,names,plot_curves = False,savename=title)

# Same astigmatism data again, drawn with black markers and a legend
plot_series3(files_astigmatism,types,names,plot_curves = False, legend = True,
             savename="Astigmatism_legend")

OLD_STYLE = False
if OLD_STYLE:
    # Legacy 2x2 overview figures, one per aberration type
    for series, title in ((files_spherical, "spherical"), (files_tip, "Tip"),
                          (files_coma, "Coma"), (files_astigmatism, "Astigmatism")):
        names = ["CH","2D","z"]
        types = [0.75,"2D","3D"]
        plot_series(series,types,names)
        plt.suptitle(title)

    # Single-file detail views: (index into types, file list, title suffix)
    for ind, file_list, suffix in ((1, files_tip, "tip"),
                                   (2, files_astigmatism, "Astigmatism"),
                                   (1, files, "Spherical"),
                                   (2, files_coma, "coma")):
        ff = file_list[ind]
        tp = types[ind]
        plot_one_file(ff,tp)
        plt.suptitle(str(tp)+" "+suffix)
# Analysis of astigmatism: in zSTED, values above 1.5 rad cannot be fitted reliably.
# This corresponds to an 8.6-fold increase in the number of molecules.
# CH and coma: 1 rad is the worst case; 2 rad is better.
# Same for 2D: 2 rad is better than 1.5 rad.
# 3D is OK.