#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 29 18:41:46 2019

@author: aurelien
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib

from pyfcs.fitting import FcsExperiment, FcsFitter,Gc,Gt,Gc_z

from accessories.calibration import sted_power
from accessories.io import generate_names_list_npy

import sys
sys.path.append("../..")
from constants import FIG_UNIT, PLOT_DEFAULT, PLOT_FONT_SIZE
# Close every figure that is still open before this script creates new ones.
plt.close("all")

# Apply the project-wide font size to all matplotlib text.
font = {'size'   : PLOT_FONT_SIZE}

matplotlib.rc('font', **font)
# Excitation 9 uW in pellicle

# Fixed model parameters shared by the fit functions defined below.
# K_CONFOCAL is passed as the 4th argument of Gc in Gconf, squared to build
# tau_z in sted_maker2D, and forwarded to FcsFitter as k_confocal.
K_CONFOCAL = 4
# ALPHA is the last argument of Gc / Gc_z and the `alpha` of FcsExperiment.
ALPHA = 1
# TAU_T and T_AMPL are the two fixed arguments of Gt; TAU_T is also given to
# FcsExperiment as tau_t.
# NOTE(review): presumably triplet time (in ms) and triplet amplitude -
# confirm against pyfcs.fitting.Gt.
TAU_T = 11*10**-3
T_AMPL = 0.16
#plt.close("all")

#Definition of functions
def Gconf(tau,N,tauD):
    """Return the confocal model: the ``Gc`` term multiplied by the ``Gt`` term.

    ``Gc`` is evaluated with the module constants ``K_CONFOCAL`` and
    ``ALPHA``; ``Gt`` is evaluated with ``T_AMPL`` and ``TAU_T``. Only
    ``tau``, ``N`` and ``tauD`` are exposed to the caller.
    """
    gc_term = Gc(tau, N, tauD, K_CONFOCAL, ALPHA)
    gt_term = Gt(tau, T_AMPL, TAU_T)
    return gc_term * gt_term

#from aotools.ext.fcs import k_dependency

def sted_maker2D(tau_confocal,fk):
    """Build the model function selected for the "2D" STED type.

    ``fk`` is not used here; it is accepted so that this factory can be
    called with the same arguments as ``default_sted_maker``.

    The returned function takes ``(tau, N, tauD)`` and evaluates ``Gc_z``
    with ``tau_z = K_CONFOCAL**2 * tau_confocal``, multiplied by the
    ``Gt`` term built from ``T_AMPL`` and ``TAU_T``.
    """
    def sted_model(tau, N, tauD):
        # tau_z is derived from the confocal value only, not from tauD.
        tau_z = K_CONFOCAL**2 * tau_confocal
        gc_term = Gc_z(tau, N, tauD, tau_z, ALPHA)
        return gc_term * Gt(tau, T_AMPL, TAU_T)

    return sted_model
def default_sted_maker(tau_confocal,fk):
    """Build the model function used for every STED type other than "2D".

    The returned function takes ``(tau, N, tauD)``. It computes
    ``r = (tauD / tau_confocal) ** (ALPHA / 2)``, passes ``fk(r)`` as the
    fourth argument of ``Gc`` and multiplies the result by the ``Gt`` term
    built from ``T_AMPL`` and ``TAU_T``.
    """
    def sted_model(tau, N, tauD):
        ratio = (tauD / tau_confocal) ** (ALPHA / 2)
        gc_term = Gc(tau, N, tauD, fk(ratio), ALPHA)
        return gc_term * Gt(tau, T_AMPL, TAU_T)

    return sted_model

def get_residuals(exp,x0 = 0,x1 = 5*10**-2):
    """Summarise the fit residuals of an experiment, one value per power key.

    For each curve, the difference between ``curve.y`` and the fitted model
    ``curve.fitter(curve.x, *curve.popt)`` is restricted to the samples with
    ``x0 < x < x1``, multiplied by ``curve.N`` and reduced to its
    root-mean-square. The per-curve values are then averaged for every key
    of ``exp.fcs_curves``.

    Parameters
    ----------
    exp : object
        Must expose ``fcs_curves``, a mapping from power key to a list of
        curves having ``x``, ``y``, ``N``, ``popt`` and ``fitter``.
    x0, x1 : float
        Exclusive lower and upper limits of the ``x`` window.

    Returns
    -------
    pws_res : numpy.ndarray
        Rounded ``sted_power`` of the sorted keys; a key equal to 0 is
        reported as exactly 0.
    res_mean, res_std : numpy.ndarray
        Mean and standard deviation of the per-curve values, in the same
        order as ``pws_res``.
    """
    curves = exp.fcs_curves
    ks = sorted(curves)

    res_mean = []
    res_std = []
    for k in ks:
        per_curve = []
        for curve in curves[k]:
            x = curve.x
            yhat = curve.fitter(x, *curve.popt)
            mask = np.logical_and(x > x0, x < x1)
            scaled = (curve.y - yhat)[mask] * curve.N
            per_curve.append(np.sqrt(np.mean(scaled**2)))
        res_mean.append(np.mean(per_curve))
        res_std.append(np.std(per_curve))

    power_keys = np.array(ks)
    # The rest of this script displays the key 0 as exactly 0 mW instead of
    # sending it through the calibration; do the same here so the residual
    # plot shares its x positions with the other figures.
    pws_res = np.where(power_keys == 0, 0, np.round(sted_power(power_keys)))
    return pws_res, np.array(res_mean), np.array(res_std)

def plot_acfs_subset(self,title=None, indices = ()):
    """Plot the ACFs recorded at a subset of the powers of one experiment.

    A figure with two rows and ``len(indices) + 1`` columns is created.
    Column ``j`` is handed to ``curve.plot`` for every curve stored under
    ``self.sted_powers[indices[j]]``, each curve being shifted by a further
    0.6 relative to the previous one. The last column overlays, for every
    selected power, the median over curves of ``y * N`` (top) and of ``y``
    (bottom).

    Parameters
    ----------
    self : object
        Must expose ``sted_powers`` (indexable) and ``fcs_curves`` (mapping
        from power to a list of curves with ``x``, ``y``, ``N``, ``plot``).
    title : str, optional
        Passed to ``fig.suptitle``.
    indices : sequence of int, optional
        Positions in ``self.sted_powers`` to display. An empty sequence (the
        default) or ``None`` yields a figure holding only the median column.
    """
    indices = [] if indices is None else list(indices)
    n1 = len(indices)
    # squeeze=False keeps `axes` two-dimensional even with a single column,
    # so the [column, row] indexing below also works when n1 == 0.
    fig, axes = plt.subplots(2, n1 + 1, figsize=(8, 4), squeeze=False)
    axes = axes.T  # index as axes[column, row]

    lag_label = r"$\rm \tau\ (ms)$"
    acf_label = r"$\rm G(\tau$)"
    offset_iter = 0.6

    for j, ind in enumerate(indices):
        pw = self.sted_powers[ind]
        curves = self.fcs_curves[pw]

        # Shift successive curves so that they do not overlap.
        offset = 0
        for curve in curves:
            curve.plot(axes=axes[j], normalise=True, offset=offset)
            offset += offset_iter

        # Power 0 is labelled "0" directly rather than through sted_power.
        spw = str(int(sted_power(pw))) if pw > 0 else "0"
        axes[j, 0].set_title(spw + " mW")
        for ax in axes[j]:
            ax.set_xlabel(lag_label)
            ax.set_ylabel(acf_label)

        # The medians are only meaningful if all curves share one x axis.
        xvals = np.array([c.x for c in curves])
        x0 = xvals[0]
        assert np.all([xv == x0 for xv in xvals]), \
            "curves recorded at the same power must share the same x values"
        yvals = np.median(np.array([c.y for c in curves]), axis=0)
        yvals_norm = np.median(np.array([c.y * c.N for c in curves]), axis=0)
        axes[n1, 0].semilogx(x0, yvals_norm, label=str(pw))
        axes[n1, 1].semilogx(x0, yvals, label=str(pw))

    median_titles = ("Median of normalised ACFs", "Median of ACFs")
    for ax, median_title in zip(axes[n1], median_titles):
        ax.set_title(median_title)
        ax.legend()
        ax.set_xlabel(lag_label)
        ax.set_ylabel(acf_label)

    fig.suptitle(title)
    fig.tight_layout()

# Bounds tuple handed twice to FcsFitter below.
# NOTE(review): presumably (lower, upper) limits for (N, tauD) - confirm
# against FcsFitter.
bounds = (
    (10**-3, 10**-3),
    (1000, 1),
)

# One STED type per acquisition folder, listed in the same order as `files`.
sted_types = ["2D", "double_angle30", "3D", 0.85]

files = [
    "../../data_analysis/FCS/2019_02_04/10s/04_02_19_16h46m34/",
    "../../data_analysis/FCS/2019_02_04/10s/04_02_19_18h17m40",
    "../../data_analysis/FCS/2019_02_04/10s/04_02_19_16h28m01",
    "../../data_analysis/FCS/2019_02_04/10s/04_02_19_17h24m59",
]

# Different modes of focal confinement by STED

all_experiments = []

for folder, st in zip(files, sted_types):
    # Only the "2D" type has a dedicated model factory.
    sted_maker = sted_maker2D if st == "2D" else default_sted_maker
    fitter = FcsFitter(st, Gconf, bounds, bounds, k_confocal=K_CONFOCAL)
    oo = generate_names_list_npy(folder, output_dict=True, sted_keyword="")

    experiment = FcsExperiment(
        fcsFitter=fitter,
        files_dictionary=oo,
        alpha=ALPHA,
        loadfun=np.load,
        tau_t=TAU_T,
        sted_maker=sted_maker,
    )
    experiment.process()
    all_experiments.append(experiment)

    # Positions of the keys 1500 and 3000 among the keys of the file
    # dictionary; they are shown next to the first entry.
    pws = np.array(list(oo.keys()))
    p0 = np.where(pws == 1500)[0][0]
    p1 = np.where(pws == 3000)[0][0]
    plot_acfs_subset(experiment, indices=[0, p0, p1])
    # for supplement S5
    # plt.* acts on the current figure, i.e. the one just created above.
    plt.suptitle(st)
    plt.savefig("fcs_curves_" + str(st) + ".svg")
# Every summary figure is a square of side FIG_UNIT.
_square_figsize = (FIG_UNIT, FIG_UNIT)

# fig2: residual summary (saved as residuals.svg).
fig2, ax2 = plt.subplots(figsize=_square_figsize)

# fig3: two stacked panels sharing their x axis, the lower one twice as tall
# (residual traces on top, curves and fits below; saved as curves.svg).
fig3, axes3 = plt.subplots(
    nrows=2,
    ncols=1,
    sharex=True,
    figsize=_square_figsize,
    gridspec_kw={"height_ratios": [1, 2]},
)

# fig4 / fig5 / fig6: saved as vs.svg, ns.svg and v_vs_n.svg respectively.
fig4, ax4 = plt.subplots(figsize=_square_figsize)
fig5, ax5 = plt.subplots(figsize=_square_figsize)
fig6, ax6 = plt.subplots(figsize=_square_figsize)

# for SBRs
fig7, ax7 = plt.subplots(figsize=_square_figsize)
ax7.set_xlabel("STED laser power (mW)")
ax7.set_ylabel("SBR")

# Key of fcs_curves whose first curve is drawn on fig3.
curve_power = 2000

# Vertical shift applied between the residual traces of successive experiments.
offset = 0.2
# Order: FCS curves, vols, Ns, n vs V
# Legend labels, also used as keys of PLOT_DEFAULT, one per experiment.
names = ["2D", "3D", "z", "CH"]

# Loop invariants.
# Only samples with x below xmax are drawn on the curve figure.
xmax = 100
# Scale applied to the displayed curve and its fit; kept at 1 (the disabled
# alternative was the mean N of the curves stored under key 0).
n0 = 1

for j, experiment in enumerate(all_experiments):
    name = names[j]
    style = PLOT_DEFAULT[name]

    p1, n1, std_n1 = experiment.number_of_molecules()

    pres, mres, sres = get_residuals(experiment)
    ax2.errorbar(pres, mres, yerr=sres, capsize=5, label=name, **style)

    if sted_types[j] == "2D":
        # For the 2D type the plotted quantity is the raw transit time
        # divided by the confocal one. Build a new mapping instead of
        # overwriting the dict returned by the experiment, which may be its
        # internal state.
        raw_times = experiment.transit_times(raw=True)
        vols = {
            kv: np.array(times) / experiment.tau_confocal
            for kv, times in raw_times.items()
        }
        p2 = sorted(vols)
        v1 = np.array([np.mean(vols[p]) for p in p2])
        std_v1 = np.array([np.std(vols[p]) for p in p2])
    else:
        p2, v1, std_v1 = experiment.volumes()

    # np.array_equal gives a single boolean whether the power axes are lists
    # or arrays; a bare `p1 == p2` is ambiguous for arrays.
    assert np.array_equal(p1, p2), \
        "number of molecules and volumes are not on the same power axis"
    assert p1[0] == 0, "the first power is expected to be 0"

    curves = experiment.fcs_curves
    curve = curves[curve_power][0]
    fit = curve.fitter(curve.x, *curve.popt) * n0
    xmask = curve.x < xmax
    trace_color = style["color"]

    # Same on different plot
    axes3[1].semilogx(curve.x[xmask], curve.y[xmask] * n0,
                      color=trace_color, label=name)
    axes3[1].semilogx(curve.x[xmask], fit[xmask],
                      color='black', linestyle='--')

    # Residual traces are stacked, one step of `offset` per experiment.
    residual = curve.y * n0 - fit
    axes3[0].semilogx(curve.x[xmask], residual[xmask] + offset * j,
                      color=trace_color, label=name)

    # Convert the power axis with sted_power, then pin the first point
    # (asserted to be 0 above) to exactly 0.
    p1 = sted_power(np.asarray(p1))
    p1[0] = 0

    ax4.errorbar(p1, v1, yerr=std_v1, capsize=5, label=name, **style)
    ax5.errorbar(p1, n1, yerr=std_n1, capsize=5, label=name, **style)

    # Only points with n1 below 2 are shown on the n1-versus-v1 figure.
    msk = n1 < 2
    ax6.errorbar(v1[msk], n1[msk], xerr=std_v1[msk], yerr=std_n1[msk],
                 capsize=5, label=name, **style)
    
# --- fig2: residual summary, logarithmic y axis limited to [1e-2, 1] ---
ax2.set_ylabel("nRMSD")
ax2.set_xlabel("STED laser power (mW)")
ax2.set_ylim(bottom=10**-2, top=1)
ax2.legend()
ax2.set_yscale("log")
ax2.grid(True, which="both")

# --- fig4: V/V_confocal versus STED power, logarithmic y axis ---
ax4.set_xlabel("STED power (mW)")
ax4.set_ylabel(r"$\rm V/V_{confocal}$")
ax4.set_yscale("log")
ax4.grid(True, which="both")

# --- fig5: N/N_confocal versus STED power ---
ax5.set_xlabel("STED power (mW)")
ax5.set_ylabel(r"$\rm N/N_{confocal}$")
ax5.set_ylim(bottom=0)

# --- fig6: N against V, with the identity line as a visual reference ---
diagonal = np.linspace(0, 1, 10)
ax6.plot(diagonal, diagonal, linestyle="--", color="gray")

ax6.set_xlabel(r"$\rm V/V_{confocal}$")
ax6.set_ylabel(r"$\rm N/N_{confocal}$")
ax6.legend()
ax6.set_xlim(right=1.2)

# --- fig3: curves with fits (lower panel) and residuals (upper panel) ---
axes3[1].set_xlabel(r"$\rm \tau(ms)$")
axes3[1].set_ylabel(r"$\rm G(\tau)$")
axes3[1].legend()

axes3[0].set_ylabel("Residuals")

# NOTE(review): nothing in this file plots onto ax7 or saves fig7, so this
# legend has no entries - confirm whether the SBR figure is still needed.
ax7.legend()

# Finalise every layout first, then write the files.
summary_outputs = (
    (fig2, "residuals.svg"),
    (fig3, "curves.svg"),
    (fig4, "vs.svg"),
    (fig5, "ns.svg"),
    (fig6, "v_vs_n.svg"),
)
for summary_fig, _ in summary_outputs:
    summary_fig.tight_layout()

for summary_fig, svg_name in summary_outputs:
    summary_fig.savefig(svg_name, transparent=True)
