#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 23 08:34:21 2019

@author: aurelien
"""

import numpy as np
import matplotlib.pyplot as plt
import glob

from accessories.io import generate_names_list_npy
from accessories.calibration import sted_power
from pyfcs.fitting import FcsExperiment,FcsFitter, Gc,Gt, Gc_z
import sys
sys.path.append("../..")
from constants import FIG_UNIT, PLOT_FONT_SIZE
import matplotlib
# Matplotlib: apply the shared font size globally, then close any figures
# left over from a previous run so this script starts from a clean slate.
font = dict(size=PLOT_FONT_SIZE)
matplotlib.rc('font', **font)
plt.close("all")

# Model constants for the fits below (data recorded with 20 uW excitation).
TAU_T = 11*10**-3   # passed to Gt as its third argument; NOTE(review): presumably seconds — confirm
T_AMPL = 0.25       # passed to Gt as its second argument
ALPHA = 1           # exponent given to Gc / Gc_z and used in the STED radius ratio
K_CONFOCAL = 4      # confocal value passed to Gc and to FcsFitter(k_confocal=...)

# When True, fit_exp and get_v_n switch to the Gc_z-based path for radii >= 0.95.
CUSTOM_2D_FITTER = False

def read_notes(folder):
    """Return the full text of the ``notes0.txt`` file inside *folder*.

    *folder* may be given with or without a trailing path separator; the
    file name is joined with ``os.path.join`` rather than by string
    concatenation, which previously required the trailing separator.
    """
    import os  # local import keeps this helper self-contained

    with open(os.path.join(folder, "notes0.txt"), "r") as f:
        return f.read()

def Gconf(tau, N, tauD):
    """Confocal model ``f(tau, N, tauD)`` given to FcsFitter in ``fit_exp``.

    Product of ``Gc`` (evaluated with the module constants K_CONFOCAL and
    ALPHA) and ``Gt`` (evaluated with T_AMPL and TAU_T).
    """
    g_c = Gc(tau, N, tauD, K_CONFOCAL, ALPHA)
    g_t = Gt(tau, T_AMPL, TAU_T)
    return g_c * g_t

def default_sted_maker(tau_confocal, fk):
    """Build the default STED model ``f(tau, N, tauD)``.

    On every call the returned function computes
    ``r = (tauD / tau_confocal) ** (ALPHA / 2)`` and passes ``fk(r)`` to
    ``Gc`` in the position where ``Gconf`` passes ``K_CONFOCAL``; the result
    is multiplied by the same ``Gt`` term as the confocal model.

    Parameters
    ----------
    tau_confocal : float
        Reference value that ``tauD`` is divided by.
    fk : callable
        Maps the ratio ``r`` to the value handed to ``Gc``.
    """
    def f(tau, N, tauD):
        ratio = (tauD / tau_confocal) ** (ALPHA / 2)
        k_value = fk(ratio)
        return Gc(tau, N, tauD, k_value, ALPHA) * Gt(tau, T_AMPL, TAU_T)

    return f

def sted_maker2D(tau_confocal, fk):
    """Build a ``Gc_z``-based model ``f(tau, N, tauD)``.

    ``fit_exp`` selects this maker for radii >= 0.95 when CUSTOM_2D_FITTER is
    enabled. The ``tau_z`` argument of ``Gc_z`` is ``K_CONFOCAL**2 *
    tau_confocal``. ``fk`` is not used; the parameter only mirrors
    ``default_sted_maker``'s signature.
    """
    def f(tau, N, tauD):
        # Computed per call (not hoisted) so the current K_CONFOCAL is read.
        tau_z = K_CONFOCAL ** 2 * tau_confocal
        g_z = Gc_z(tau, N, tauD, tau_z, ALPHA)
        return g_z * Gt(tau, T_AMPL, TAU_T)

    return f

def fit_exp(nexp, plot=False):
    """Build the FcsExperiment for ``folders[nexp]`` at radius ``radii[nexp]``.

    Reads the module-level ``folders`` and ``radii`` lists (defined further
    down in the script).

    Parameters
    ----------
    nexp : int
        Index into ``folders`` / ``radii``.
    plot : bool
        If True, call ``process()``, ``plot_acfs()`` and ``plot()`` on the
        experiment before returning it.

    Returns
    -------
    FcsExperiment
    """
    folder = folders[nexp]
    files_dict = generate_names_list_npy(folder, output_dict=True, sted_keyword="")
    rad = radii[nexp]

    # Default model unless the custom 2D fitter is on and the radius is large.
    sted_maker = default_sted_maker
    if CUSTOM_2D_FITTER and not rad < 0.95:
        sted_maker = sted_maker2D
        print("Radius {}, sdsds".format(rad))

    # NOTE(review): presumably (lower, upper) bounds for the (N, tauD)
    # parameters of the models — confirm against FcsFitter.
    bounds_confocal = ((10**-3, 10**-3), (10**5, 100))
    bounds_sted = bounds_confocal
    fitter = FcsFitter(rad, Gconf, bounds_confocal, bounds_sted,
                       k_confocal=K_CONFOCAL)
    experiment = FcsExperiment(
        fcsFitter=fitter,
        files_dictionary=files_dict,
        alpha=ALPHA,
        loadfun=np.load,
        tau_t=TAU_T,
        sted_maker=sted_maker,
    )

    if plot:
        experiment.process()
        experiment.plot_acfs()
        experiment.plot()
    return experiment


def get_residuals(exp, x0=0, x1=5*10**-2):
    """Summarise fit residuals per key of ``exp.fcs_curves``.

    For every curve the residual ``y - fitter(x, *popt)`` is restricted to
    ``x0 < x < x1``, multiplied by ``curve.N`` and reduced to a root mean
    square. The per-curve values are then averaged per key.

    Returns
    -------
    pws_res : numpy.ndarray
        Sorted keys of ``exp.fcs_curves``.
    res_mean, res_std : numpy.ndarray
        Mean and standard deviation of the per-curve values for each key.
    """
    curves_by_key = exp.fcs_curves
    keys = sorted(curves_by_key.keys())

    def _curve_nrmsd(curve):
        # RMS of the N-scaled residual inside the (x0, x1) window.
        x, y = curve.x, curve.y
        model = curve.fitter(x, *curve.popt)
        in_window = np.logical_and(x > x0, x < x1)
        return np.sqrt(np.mean(((y - model)[in_window] * curve.N) ** 2))

    per_key = {key: [_curve_nrmsd(c) for c in curves_by_key[key]] for key in keys}

    res_mean = np.array([np.mean(per_key[key]) for key in keys])
    res_std = np.array([np.std(per_key[key]) for key in keys])
    return np.array(keys), res_mean, res_std

# Locate one acquisition folder per radius.
path = "../../data_analysis/FCS/2018_11_28/10s/"
folders = sorted(glob.glob(path + '*/'))
# Move the first folder (after sorting) to the end of the list.
# NOTE(review): presumably so the folder order lines up with ``radii`` below,
# whose last entry is rho = 1 — confirm against the data layout.
f1 = folders.pop(0)
folders.append(f1)

notes = [read_notes(folder) for folder in folders]
print()

# radii[i] = 1 - round(0.05 * (n - 1 - i), 2): steps of 0.05 ending at 1.
radii = 1.0 - np.round(np.arange(len(folders))[::-1] * 0.05, 2)
nexp = 0

all_experiments = []
for j in range(len(radii)):
    exp = fit_exp(j, plot=False)
    all_experiments.append(exp)
    # Return values are not used here.
    # NOTE(review): fit_exp only calls process() when plot=True, so this call
    # may be what triggers processing — confirm before removing it.
    p1, m1, s1 = exp.volumes()

target_pw1 = 1000
target_pw2 = 2000
# Power keys analysed in the plots below (converted to mW via sted_power).
target_pws = [1000, 1500]

def get_v_n(experiments, pw, radii):
    """Collect volume and molecule-number statistics at one power key.

    Parameters
    ----------
    experiments : list
        Experiments as returned by ``fit_exp``, one per radius.
    pw : number
        Power key to extract.
    radii : sequence of float
        Radius of each experiment, in the same order as *experiments*.

    Returns
    -------
    all_vs, all_ns : numpy.ndarray
        Per experiment, the ``[mean, std]`` pair for the volume and for the
        number of molecules. Each entry is a 1-D array (the values selected
        where the power equals *pw*).

    Raises
    ------
    ValueError
        If *pw* is not among the powers reported by ``exp.volumes()`` or
        ``exp.number_of_molecules()``. (Explicit raise instead of ``assert``,
        which is stripped under ``python -O``.)
    """
    all_vs, all_ns = [], []
    for j, exp in enumerate(experiments):
        rad = radii[j]
        if CUSTOM_2D_FITTER and rad >= 0.95:
            # Volume from the ratio of STED to confocal (key 0) transit times.
            print("Volume rad>0.95")
            taus = exp.transit_times(raw=True)
            tau_confocal = np.mean(taus[0])
            ratios = taus[pw] / tau_confocal
            # Wrap in 1-element arrays so the shape matches the mask-indexed
            # values of the other branch; otherwise np.array(all_vs) is ragged.
            v, v_s = np.array([np.mean(ratios)]), np.array([np.std(ratios)])
            print("tau confocal", tau_confocal, "STED", np.mean(taus[pw]))
        else:
            pws, mv, sv = exp.volumes()
            pws = np.array(pws)
            print(pws, pw)
            # Checked here, against this experiment's own powers. The old code
            # checked after the branch, where the transit-time path never set pws.
            if not np.count_nonzero(pws == pw):
                raise ValueError("power {} not found in volumes {}".format(pw, pws))
            v, v_s = mv[pws == pw], sv[pws == pw]

        pws, mn, sn = exp.number_of_molecules()
        pws = np.array(pws)
        if not np.count_nonzero(pws == pw):
            raise ValueError(
                "power {} not found in number_of_molecules {}".format(pw, pws))
        n, n_s = mn[pws == pw], sn[pws == pw]

        all_vs.append([v, v_s])
        all_ns.append([n, n_s])

    return np.array(all_vs), np.array(all_ns)

def get_res_allexps(experiments, pw):
    """Extract residual statistics at a target power for a list of experiments.

    Parameters
    ----------
    experiments : list
        Experiments as returned by ``fit_exp``.
    pw : number
        Power key to select from each ``get_residuals(exp)`` result.

    Returns
    -------
    numpy.ndarray
        Per experiment, the ``[mean, std]`` residual pair selected where the
        power equals *pw*.

    Raises
    ------
    ValueError
        If *pw* is missing for an experiment. (Explicit raise instead of
        ``assert``, which is stripped under ``python -O`` and would then let
        empty selections through silently.)
    """
    all_res = []
    for exp in experiments:
        pws, m_res, s_res = get_residuals(exp)
        pws = np.array(pws)
        print(pws, pw)
        selected = pws == pw
        if not np.count_nonzero(selected):
            raise ValueError("power {} not found in residuals {}".format(pw, pws))
        all_res.append([m_res[selected], s_res[selected]])

    return np.array(all_res)

def plot_acfs(experiments, pw, indices, names=None, colors=None):
    """Overlay autocorrelation curves at one power with a confocal reference.

    For each index in *indices*, draws the first curve stored under key *pw*
    of that experiment. It then adds the first curve under key 0 of
    ``experiments[0]`` labelled "Confocal". The figure is saved to
    ``ACFS.svg``.

    Parameters
    ----------
    experiments : list
        Experiments as returned by ``fit_exp``.
    pw : number
        Key into ``exp.fcs_curves`` selecting the curves to plot.
    indices : sequence of int
        Positions in *experiments* to plot.
    names, colors : sequence of str, optional
        Legend label and colour for each entry of *indices*. The defaults
        reproduce the original hard-coded values, which assume the selected
        experiments have rho = 0.75 and rho = 1. They are not derived from
        *indices*.

    Returns
    -------
    fig, ax
        The created matplotlib figure and axes.

    Raises
    ------
    ValueError
        If fewer names or colours than indices are supplied. This is checked
        before any figure is created.
    """
    if names is None:
        names = [r"$\rm \rho\ =\ 0.75$", r"$\rm \rho\ =\ 1$"]
    if colors is None:
        colors = ["C0", "C5"]
    if len(names) < len(indices) or len(colors) < len(indices):
        raise ValueError(
            "need one name and one colour per index: got {} indices, "
            "{} names, {} colours".format(len(indices), len(names), len(colors)))

    fig, ax = plt.subplots(1, 1, figsize=(FIG_UNIT, FIG_UNIT))
    for idx, label, color in zip(indices, names, colors):
        curve = experiments[idx].fcs_curves[pw][0]
        curve.plot(axes=ax, label=label, color=color)

    # Confocal reference: key 0 of the first experiment.
    confocal_curve = experiments[0].fcs_curves[0][0]
    confocal_curve.plot(axes=ax, label="Confocal", color="C2")

    ax.set_xlabel(r"$\rm \tau$")
    ax.set_ylabel(r"$\rm G(\tau)$")
    ax.legend()
    fig.tight_layout()
    fig.savefig("ACFS.svg")
    return fig, ax
plot_acfs(all_experiments, 1500, [1, 6])

# Figures (created in this order): N vs V (fig2), then N, V and residuals
# against the radius rho (fig1, fig3, fig4).
fig2, ax2 = plt.subplots(1, 1, figsize=(FIG_UNIT, FIG_UNIT))

fig1, axes1 = plt.subplots(1, 1, figsize=(FIG_UNIT, FIG_UNIT))
fig3, axes3 = plt.subplots(1, 1, figsize=(FIG_UNIT, FIG_UNIT))
fig4, axes4 = plt.subplots(1, 1, figsize=(FIG_UNIT, FIG_UNIT))

axes = [axes1, axes3, axes4]

colors = ["#b20de5", "#35c535"]

for j, target_pw in enumerate(target_pws):
    vv, nn = get_v_n(all_experiments, target_pw, radii)
    rr = get_res_allexps(all_experiments, target_pw)

    # Split the stacked [mean, std] pairs into separate sequences.
    rrm, rrs = [w[0] for w in rr], [w[1] for w in rr]
    vvm, vvs = [w[0] for w in vv], [w[1] for w in vv]
    nnm, nns = [w[0] for w in nn], [w[1] for w in nn]

    pw_mw = int(sted_power(target_pw))
    series_label = str(pw_mw) + " mW"
    series_style = dict(color=colors[j], marker="o", capsize=5)

    # Number of molecules, volume and residuals against rho.
    for ax, (means, errs) in zip(axes, [(nnm, nns), (vvm, vvs), (rrm, rrs)]):
        ax.errorbar(radii, means, yerr=errs, label=series_label, **series_style)

    ax2.errorbar(vvm, nnm, xerr=vvs, yerr=nns, label=series_label, **series_style)

y_labels = [r"$\rm N/N_{confocal}$", r"$\rm V/V_{confocal}$", r"$\rm nRMSD$"]
for ax, y_label in zip(axes, y_labels):
    ax.set_xlabel(r"$\rm \rho$")
    ax.set_ylabel(y_label)
    ax.set_ylim(bottom=0)
ax2.set_ylabel(r"$\rm N/N_{confocal}$")
ax2.set_xlabel(r"$\rm V/V_{confocal}$")

# Dashed y = x reference line on the N-vs-V plot.
xr = np.linspace(0, 1, 10)
ax2.plot(xr, xr, "--", color="gray")
ax2.set_xlim(left=0)
ax2.set_ylim(bottom=0)
fig2.tight_layout()

ax2.legend()
outputs = [
    (fig1, "ch_radius_n.svg"),
    (fig2, "ch_snr.svg"),
    (fig3, "ch_radius_v.svg"),
    (fig4, "ch_radius_residuals.svg"),
]
for fig, file_name in outputs:
    fig.tight_layout()
    fig.savefig(file_name, transparent=True)
