#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 11 11:53:47 2019

@author: aurelien
"""

import numpy as np
import sys
sys.path.append("..")
from methods import FcsFitter, FcsExperiment
from methods import fit_dictionary
from constants import K_CONFOCAL

import aotools
from aotools.ext.fcs import G_corr,Gt
import matplotlib.pyplot as plt
# ALPHA: default anomalous-diffusion exponent for the fit models in this file.
# TAU_T: default triplet relaxation time (same units as tau; presumably seconds).
ALPHA, TAU_T = 0.75, 5*10**-3
# Data directory handed to FcsExperiment by the driver code below.
folder = "/home/aurelien/Documents/phd/fcs_paper/data/2018_11_16_cell2_all_5s_LA5/"


def G_STED_maker(tau_confocal, fk, tauT=TAU_T):
    """Build a STED autocorrelation model from a measured confocal diffusion time.

    The returned model evaluates the module-level 3D correlation function
    ``G`` (anomalous diffusion plus triplet term) and adds a constant offset.
    The axial structure parameter ``K`` is not a free fit parameter. It is
    derived from the diffusion-time ratio as
    ``fk((tauD / tau_confocal) ** (alpha / 2))``.

    Parameters
    ----------
    tau_confocal : float
        Diffusion time measured in the confocal configuration.
    fk : callable
        Maps ``(tauD / tau_confocal) ** (alpha / 2)`` to the structure
        parameter ``K``. Presumably that ratio is the STED/confocal lateral
        size ratio; confirm against the caller that builds ``fk``.
    tauT : float, optional
        Triplet relaxation time forwarded to ``G`` (defaults to ``TAU_T``).

    Returns
    -------
    callable
        ``f(tau, N, tauD, T, offset, alpha=ALPHA)`` returning
        ``G(tau, N, tauD, T, alpha, K, tauT) + offset``.
    """
    def f(tau, N, tauD, T, offset, alpha=ALPHA):
        # K follows from how much the STED focus shrinks relative to confocal,
        # so it is computed from tauD rather than fitted independently.
        K = fk((tauD / tau_confocal) ** (alpha / 2))
        # Reuse the single module-level model so confocal and STED fits cannot
        # drift apart; keywords guard against G's (T, alpha, K, tauT) ordering.
        return G(tau, N, tauD, T, alpha=alpha, K=K, tauT=tauT) + offset
    return f

def G(tau, N, tauD, T, alpha=ALPHA, K=K_CONFOCAL, tauT=TAU_T):
    """Evaluate the 3D anomalous-diffusion FCS autocorrelation with a triplet term.

    With ``x = (tau / tauD) ** alpha`` this returns::

        (1 + T / (1 - T) * exp(-tau / tauT))
        / (N * (1 + x) * sqrt(1 + x / K**2))

    ``K`` is the axial structure parameter, S = 2wz/(wx+wy) (see e.g.
    Kastrup et al. 2005, PRL). ``tau`` may be a scalar or a NumPy array,
    because every operation is elementwise.
    """
    reduced_time = (tau/tauD)**alpha
    lateral_term = 1+reduced_time
    axial_term = np.sqrt(1+(1/K**2)*reduced_time)
    # Dark-state (triplet) blinking raises the correlation at short lag times.
    blinking = 1+T/(1-T)*np.exp(-tau/tauT)
    return blinking/(N*(lateral_term*axial_term))

# Fit bounds as ((lower, ...), (upper, ...)). They presumably map onto the free
# model parameters: G -> (N, tauD, T), STED model -> (N, tauD, T, offset).
# The T upper bound stays below 1 so that T/(1-T) in the triplet term is finite.
bounds_confocal = ((10**-3, 10**-3, 0), (1000, 1, 1-1e-3))
bounds_sted = ((10**-3, 10**-3, 0, -0.1), (1000, 1, 1-1e-3, 0.1))
fitter = FcsFitter('3D', G, G_STED_maker, bounds_confocal, bounds_sted)


def _process_experiment(sted_keyword):
    """Build an FcsExperiment on ``folder`` with ``fitter`` and process it.

    ``sted_keyword`` is forwarded to ``FcsExperiment.process`` together with
    ``alpha=ALPHA``. Returns ``(experiment, experiment.axial_resolutions())``.
    """
    exp = FcsExperiment(folder, sted_type=None, fcsFitter=fitter)
    exp.process(sted_keyword=sted_keyword, alpha=ALPHA)
    return exp, exp.axial_resolutions()


# "correction" data is plotted as AO on, "reference" data as AO off.
experiment, (p1, mz1, sz1) = _process_experiment("correction")
experiment2, (p2, mz2, sz2) = _process_experiment("reference")

plt.figure()
plt.errorbar(p1, mz1, yerr=sz1, label="AO on")
plt.errorbar(p2, mz2, yerr=sz2, label="AO off")
plt.xlabel("STED power (mW)")
plt.ylabel("Axial resolution (normalised)")
plt.legend()

experiment.plot()
experiment2.plot()

# Without this, running the file as a plain script (non-interactive backend)
# exits before any of the figures above is displayed.
plt.show()
