#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr  6 16:59:43 2020

@author: aurelien
"""

import numpy as np
import matplotlib.pyplot as plt
import glob

from accessories.misc import comparison_file_extractor
from scipy.optimize import curve_fit
from pyfcs.fitting import Gc,Gt
from pyfcs.io import open_SIN


def fitter(tau,N,tauD,T,tau_T):
    """Correlation model with a free triplet time, used by ``get_taut``.

    The diffusion term ``Gc`` is called with the module-level constant ``K``
    and a trailing ``1`` (NOTE(review): presumably the structure parameter
    and a model option of ``pyfcs.fitting.Gc`` -- confirm there). It is
    multiplied by the triplet term ``Gt``.

    The signature must keep exactly these five arguments: ``curve_fit``
    infers the number of fit parameters (N, tauD, T, tau_T) from it.
    """
    diffusion_term = Gc(tau,N,tauD,K,1)
    triplet_term = Gt(tau,T,tau_T)
    return diffusion_term * triplet_term

def get_taut(corr, plot=True, return_all=False, first_index=2, bounds=None):
    """Fit ``fitter`` to a correlation curve and return the triplet time.

    Parameters
    ----------
    corr : ndarray
        Correlation curve; column 0 is used as the lag time ``tau`` and
        column 1 as the correlation value.
    plot : bool
        If True, open a new figure with the data and the fitted curve.
    return_all : bool
        If True, return every fitted parameter (N, tauD, T, tau_T) instead
        of only tau_T.
    first_index : int
        Number of leading points (shortest lags) left out of the fit.
    bounds : tuple of two sequences, optional
        ``(lower, upper)`` limits for (N, tauD, T, tau_T), forwarded to
        ``curve_fit``. ``None`` (default) uses the historical limits below.

    Returns
    -------
    float or ndarray
        The fitted triplet time ``popt[3]``, or the whole ``popt`` array
        when ``return_all`` is True.
    """
    if bounds is None:
        # Lower/upper limits, in fitter's argument order (N, tauD, T, tau_T).
        bounds = ((0.01, 0.01, 0, 10**-3),
                  (1000, 1, 1, 2*10**-2))
    x, y = corr[first_index:, 0], corr[first_index:, 1]
    popt, _ = curve_fit(fitter, x, y, bounds=bounds)
    if plot:
        # The fitted curve is only needed for display.
        yh = fitter(x, *popt)
        plt.figure()
        plt.semilogx(x, y)
        plt.semilogx(x, yh, color="k", linestyle="--")
        plt.title("start at "+str(first_index)+" us")
    if return_all:
        return popt
    return popt[3]


PLOT_ALL = False
plt.close("all")
MSIZE = 3
m_quality = 30
first_index = 2
K = 4  # forwarded to Gc by the fit models (fitter, f1, f2)

# glob returns entries in arbitrary (filesystem-dependent) order. Sort them so
# the measurement picked by position later on (files_2[ind]) is reproducible.
files_2 = sorted(glob.glob("../../data_analysis/FCS/2019_02_04/60s/*/"))
files_3 = sorted(glob.glob("../../data_analysis/FCS/2019_02_13/60s/*/"))
files_2.extend(files_3)
if not files_2:
    raise FileNotFoundError(
        "No measurement directories matched ../../data_analysis/FCS/*/60s/*/")


#corr = multipletau.autocorrelate(confocal[0],deltat=10**-3,normalize = True)
#corr2 = multipletau.autocorrelate(confocal[1],deltat=10**-3,normalize = True)


all_tt3s = list()
all_amplitudes = list()

# Fit every measurement with a free triplet time and collect tau_T and T.
for directory in files_2:
    npy_files = glob.glob(directory+"/*.npy")
    if not npy_files:
        raise FileNotFoundError("No .npy correlation file in " + directory)
    corr3 = np.load(npy_files[0])
    # msk = corr3[:,0]<10**-1
    # corr3 = corr3[msk]
    popt = get_taut(corr3, plot=False, return_all=True)
    tt3 = popt[3]
    triplet_amplitude = popt[2]
    all_tt3s.append(tt3)
    all_amplitudes.append(triplet_amplitude)
    # tau_T is in the units of the lag-time column (plotted as ms below).
    print("Triplet correlation time (ms):", tt3)

plt.figure()
plt.scatter(all_amplitudes, all_tt3s)
plt.xlabel("Triplet amplitude")
plt.ylabel("Triplet correlation time")

# Convert the triplet times from ms to microseconds for the histogram.
all_tt3s = np.array(all_tt3s)*10**3

plt.figure(figsize=(5, 2.5))
plt.subplot(1, 2, 1)
plt.hist(all_tt3s, bins=5)
plt.axvline(np.mean(all_tt3s), color="red")
plt.xlabel(r"$ \rm Triplet\ correlation\ time\ (\mu s)$")
plt.ylabel("Occurrence")

plt.subplot(1, 2, 2)
plt.hist(all_amplitudes, bins=5)
plt.axvline(np.mean(all_amplitudes), color="red")
plt.xlabel("Triplet amplitude")
plt.ylabel("Occurrence")
plt.tight_layout()
plt.savefig("triplet_parameters.svg")

# ----- Comparison between two fixed triplet times ------
# Position, in files_2 and in all_tt3s, of the measurement used for the comparison.
ind = -2
file = glob.glob("%s/*.npy" % files_2[ind])[0]
corr = np.load(file)

# Lower/upper limits for (N, tauD, T, tau_T). The 3-parameter fits below use
# bounds1 instead, so this tuple is not referenced afterwards.
bounds = (
    (0.01, 0.01, 0, 10**-3),
    (1000, 1, 1, 2*10**-2),
)

def f1(tau,N,tauD,T):
    """Fit model whose triplet time is fixed to the value fitted for files_2[ind].

    ``all_tt3s`` was scaled by 10**3 (to microseconds) for plotting; the factor
    10**-3 converts it back to the units of ``tau``. Only N, tauD and T are
    free, matching the three-entry ``bounds1`` passed to ``curve_fit``.
    """
    fixed_tau_T = all_tt3s[ind]*10**-3
    return Gc(tau,N,tauD,K,1) * Gt(tau,T,fixed_tau_T)
def f2(tau,N,tauD,T):
    """Fit model whose triplet time is fixed at 5*10**-3 (5 us with tau in ms).

    Only N, tauD and T are free parameters for ``curve_fit``.
    """
    reference_tau_T = 5*10**-3
    diffusion_term = Gc(tau,N,tauD,K,1)
    return diffusion_term * Gt(tau,T,reference_tau_T)

# Fit the selected curve twice with the triplet time held fixed: f1 uses the
# value fitted for this measurement above, f2 uses 5 us.
x, y = corr[first_index:, 0], corr[first_index:, 1]


# Bounds for (N, tauD, T); tau_T is not a free parameter of f1/f2.
bounds1 = ((0.01, 0.01, 0),
           (1000, 1, 1))
popt1, _ = curve_fit(f1, x, y, bounds=bounds1)
popt2, _ = curve_fit(f2, x, y, bounds=bounds1)
yh1 = f1(x, *popt1)
yh2 = f2(x, *popt2)

# Build the f1 legend from the triplet time it actually uses (all_tt3s is in
# microseconds at this point) rather than a hard-coded number, so the label
# stays correct when the data set or `ind` changes.
label_fitted_tt = r"$\tau_{T}=%.0f\mu s$" % all_tt3s[ind]
label_fixed_tt = r"$\tau_{T}=5\mu s$"  # matches the 5*10**-3 used in f2

plt.figure(figsize=(5, 2.5))
plt.subplot(121)
plt.semilogx(x, y, color="black")
plt.semilogx(x, yh1, color="green", linestyle="--")
plt.semilogx(x, yh2, color="m", linestyle="-.")
plt.xlabel(r"$\rm \tau\ (ms)$")
plt.ylabel(r"$\rm G(\tau)$")
plt.gca().ticklabel_format(axis="y", style="sci")
plt.locator_params(axis='y', nbins=3)

plt.subplot(122)
plt.semilogx(x, y-yh2, color="m", linestyle="-.", label=label_fixed_tt)
plt.semilogx(x, y-yh1, color="green", linestyle="--", label=label_fitted_tt)
plt.locator_params(axis='y', nbins=3)
plt.xlabel(r"$\rm \tau\ (ms)$")
plt.ylabel(r"$\rm Residuals$")
plt.legend()
plt.tight_layout()
plt.savefig("curves_triplet.svg", transparent=True)

# Zoom on the short lags (tau < 3*10**-2), where the triplet term matters.
msk_zoom = x < 3*10**-2
plt.figure(figsize=(1.5, 1.5))

plt.semilogx(x[msk_zoom], y[msk_zoom], color="black")
plt.semilogx(x[msk_zoom], yh1[msk_zoom], color="green", linestyle="--")
plt.semilogx(x[msk_zoom], yh2[msk_zoom], color="m", linestyle="-.")
plt.tight_layout()
plt.savefig("zoom_curve.svg", transparent=False)
