#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 10 16:46:31 2019

@author: aurelien
"""

import numpy as np
import matplotlib.pyplot as plt

from shutil import rmtree

from pyfcs.fitting import (Gc, Gt, Curve, fit_dictionaries, fitters_dict, 
                           FcsExperiment, FcsFitter)
import os

# Close every open matplotlib figure (e.g. ones left over from an earlier run
# in the same interactive session) before the tests below create new ones.
plt.close("all")

def test_curve(plot=True):
    """Generate a synthetic FCS curve, fit it with ``Curve`` and report errors.

    A confocal correlation curve with a triplet term is simulated from random
    ground-truth values. ``Curve`` then fits the number of molecules ``N``,
    the transit time ``tauD`` and the triplet fraction ``T``. ``K``, ``alpha``
    and ``tauT`` are held fixed at their true values. Only ``N`` and ``tauD``
    are compared against the ground truth. A success/failure message is
    printed, but no exception is raised on failure.

    Parameters
    ----------
    plot : bool
        If True, plot the fitted curve (non-normalised).

    Returns
    -------
    list of float
        Relative errors ``[err_N, err_tauD]``, each computed as
        ``|fit - truth| / max(fit, truth)``.
    """
    # These never change
    x0 = np.logspace(-3, 3)
    tauT0 = 5 * 10**-3

    def randgen(minv, maxv):
        # Uniform random value in [minv, maxv)
        return minv + (maxv - minv) * np.random.rand(1)[0]

    # Parameters not fitted
    K0 = randgen(0.8, 8)
    alpha0 = randgen(0.2, 2)

    # Fitted parameters (N, tauD, T): bounds are ((mins...), (maxs...)).
    # Ground truth is drawn from the bounds scaled inwards (x1.2 / x0.8).
    bounds = ((0.1, 10**-2, 0), (1000, 100, 0.8))
    N0 = randgen(bounds[0][0] * 1.2, bounds[1][0] * 0.8)
    tauD0 = randgen(bounds[0][1] * 1.2, bounds[1][1] * 0.8)
    T0 = randgen(bounds[0][2] * 1.2, bounds[1][2] * 0.8)
    y0 = Gc(x0, N0, tauD0, K0, alpha0) * Gt(x0, T0, tauT0)

    # Two columns: lag time, correlation value
    corr = np.column_stack((x0, y0))

    def fitter(tau, N, tauD, T):
        return Gc(tau, N, tauD, K0, alpha0) * Gt(tau, T, tauT0)

    curve = Curve(corr=corr, power=1, fitter=fitter, bounds=bounds,
                  first_n=0)
    curve.fit()
    if plot:
        curve.plot(normalise=False)
    popt = curve.popt

    def errval(x, y):
        # Relative error, normalised by the larger of the two values
        return np.abs(x - y) / max(x, y)

    err_N = errval(popt[0], N0)
    err_tauD = errval(popt[1], tauD0)

    # Maximum accepted relative error
    tol = 10**-3
    # Explicit comparison instead of assert-in-try: a bare ``except`` would
    # also hide unrelated errors (and asserts vanish under ``python -O``).
    if err_N < tol and err_tauD < tol:
        print("Test succeeded")
    else:
        print("Test failed")
        print("Comparison between values:")
        print("N ", popt[0], N0)
        print("tau ", popt[1], tauD0)
    return [err_N, err_tauD]


def multi_test_curve(n=50):
    """Run ``test_curve`` ``n`` times and plot the relative fitting errors.

    Parameters
    ----------
    n : int
        Number of independent random curves to generate and fit.

    Returns
    -------
    numpy.ndarray
        One row per run. Column 0 is the relative error on ``N`` and
        column 1 the relative error on ``tauD``, as returned by
        ``test_curve``.
    """
    errvals = np.asarray([test_curve(plot=False) for _ in range(n)])
    print(errvals.shape)
    print(errvals)
    plt.figure()
    plt.plot(errvals[:, 0], label="N")
    plt.plot(errvals[:, 1], label="tau")
    plt.xlabel("test nr")
    plt.ylabel("Relative error")

    plt.legend()
    return errvals
    
# test_curve()
# multi_test_curve()
    
# Set-up and tear-down helpers for the 3D STED end-to-end tests
def generate_3DSTED_experiment():
    """Simulate a 3D STED FCS experiment and save its curves to ``test3D/``.

    One confocal curve and ``nsted`` STED curves are generated with a shared
    random ``alpha`` and a fixed triplet term. Each curve is saved as a
    two-column ``.npy`` array (lag time, correlation) in the ``test3D/``
    folder, which is created if needed.

    Returns
    -------
    out_dict : dict
        Maps each power index to ``[lateral size ratio, file name, K]``.
        Index 0 is confocal (ratio 1); indices ``1..nsted`` are the STED
        curves.
    folder : str
        Folder containing the saved files, with a trailing slash.
    pars0 : dict
        Ground truth. ``alpha``, ``K``, ``tauD`` and ``tauT`` describe the
        confocal curve. ``Kdict``, ``tausdict`` (transit times) and
        ``volsdict`` (volumes relative to confocal) are keyed by power index.
    """
    exptype = "3D"
    x0 = np.logspace(-3, 3)
    tauT0 = 5 * 10**-3
    T0 = 0.2
    K0 = 4

    # Random values generator: uniform in [minv, maxv)
    randgen = lambda minv, maxv: minv + (maxv - minv) * np.random.rand(1)[0]

    # Parameters not fitted
    alpha0 = randgen(0.4, 1.5)

    # Confocal parameters, drawn from the fitter bounds scaled inwards
    bounds = ((0.1, 10**-2, 0), (1000, 100, 0.8))
    N0 = randgen(bounds[0][0] * 1.2, bounds[1][0] * 0.8)
    tauD0 = randgen(bounds[0][1] * 1.2, bounds[1][1] * 0.8)

    # fk(sr): K of a STED curve with lateral size ratio sr, computed from the
    # fit stored in fit_dictionaries["cat2"] for this experiment type.
    fd = fit_dictionaries["cat2"]
    pars, fn = fd[exptype]
    f1 = lambda x: fitters_dict[fn](x, *pars)
    fk = lambda x: f1(x) * fd["Kconf"] / x

    nsted = 5

    # Lateral size ratios of the STED curves, in decreasing order
    stedresmin = 0.7
    stedresmax = 1
    stedres = stedresmin + (stedresmax - stedresmin) * np.random.rand(nsted)
    stedres = np.sort(stedres)[::-1]

    # One power index per STED curve: 1..nsted (0 is the confocal curve).
    # Previously nsted+1 indices were created and the last was silently unused.
    sted_powers = np.arange(nsted) + 1

    # confocal
    y0 = Gc(x0, N0, tauD0, K0, alpha0) * Gt(x0, T0, tauT0)
    corr0 = np.column_stack((x0, y0))

    out_dict = {power: [sr] for power, sr in zip(sted_powers, stedres)}
    out_dict[0] = [1, "confocal.npy", K0]

    folder = "test3D/"
    os.makedirs(folder, exist_ok=True)

    Ks_dict = {0: K0}
    taus_dict = {0: tauD0}
    vols_dict = {0: 1}

    for j, (power, sr) in enumerate(zip(sted_powers, stedres)):
        tauDs = tauD0 * sr**(2 / alpha0)
        Ks = fk(sr)
        Ks_dict[power] = Ks
        taus_dict[power] = tauDs
        # Volume relative to the confocal one
        vols_dict[power] = sr**3 * Ks / K0

        Ns = randgen(bounds[0][0] * 1.2, bounds[1][0] * 0.8)
        ys = Gc(x0, Ns, tauDs, Ks, alpha0) * Gt(x0, T0, tauT0)
        corrS = np.column_stack((x0, ys))

        name = "sted_" + str(j) + ".npy"
        np.save(folder + name, corrS)
        out_dict[power].extend([name, Ks])
    np.save(folder + "confocal.npy", corr0)

    pars0 = {"alpha": alpha0,
             "K": K0,
             "tauD": tauD0,
             "tauT": tauT0,
             "Kdict": Ks_dict,
             "tausdict": taus_dict,
             "volsdict": vols_dict}
    return out_dict, folder, pars0

def tearDown_3DSTED():
    """Recursively delete the ``test3D`` folder written by generate_3DSTED_experiment."""
    target_dir = "test3D"
    rmtree(target_dir)
     

def test_sted_exp(plot=True):
    """Generate a 3D STED experiment, analyse it and compare with ground truth.

    The simulated experiment from ``generate_3DSTED_experiment`` is processed
    with ``FcsExperiment``. The lateral sizes derived from the fitted transit
    times must match the simulated ones within 1%, relative to the largest
    size. Relative volume errors are only printed, not asserted. The
    generated files are left on disk; ``tearDown_3DSTED`` removes them.

    Parameters
    ----------
    plot : bool
        If True, plot theory vs measurement for lateral sizes, transit times
        and volumes, plus the experiment's own plots.
    """
    pars_dict, folder, pars0 = generate_3DSTED_experiment()
    alpha0 = pars0["alpha"]
    K0 = pars0["K"]
    tauT0 = pars0["tauT"]

    def confocal_fitter(tau, N, tauD, T):
        return Gc(tau, N, tauD, K0, alpha0) * Gt(tau, T, tauT0)

    bounds = ((0.1, 10**-2, 0), (1000, 100, 0.8))

    # One file per power index
    fdict = {k: [folder + v[1]] for k, v in pars_dict.items()}

    # Theory values in increasing power-index order. NOTE(review): assumes
    # FcsExperiment reports its results in the same order; confirm.
    lateral_sizes_theory = [pars_dict[k][0] for k in sorted(fdict)]
    taus_theory = [pars0["tausdict"][k] for k in sorted(pars0["tausdict"])]
    vols_theory = [pars0["volsdict"][k] for k in sorted(pars0["volsdict"])]

    fitter = FcsFitter("3D", confocal_fitter, bounds, bounds)
    experiment = FcsExperiment(fcsFitter=fitter, files_dictionary=fdict,
                               alpha=alpha0, loadfun=np.load, tau_t=tauT0,
                               tau_confocal=None, sted_maker=None)
    experiment.process()

    # Verify transit times, through the lateral sizes derived from them
    # (transit_times() is queried once and reused for both purposes).
    _, taus, _ = experiment.transit_times()
    latsizes = experiment.xy_fun(taus)

    error_rel = (np.abs(np.asarray(latsizes) - np.asarray(lateral_sizes_theory))
                 / np.max(lateral_sizes_theory))
    assert np.max(error_rel) < 10**-2

    # Verify volumes: reported only, not asserted
    _, vols, _ = experiment.volumes()
    error_rel = (np.abs(np.asarray(vols) - np.asarray(vols_theory))
                 / np.max(vols_theory))
    print(error_rel)
    if plot:
        plt.figure()
        plt.subplot(131)
        plt.plot(lateral_sizes_theory)
        plt.plot(latsizes)
        plt.legend(["Theory", "Measured"])
        plt.ylabel("lateral size")

        plt.subplot(132)
        plt.plot(taus_theory)
        plt.plot(taus)
        plt.ylabel("taus")
        plt.legend(["Theory", "Measured"])

        plt.subplot(133)
        plt.plot(vols_theory)
        plt.plot(vols)
        plt.ylabel("volumes")
        plt.legend(["Theory", "Measured"])
        experiment.plot()
        experiment.plot_acfs()
    
def test_3Dsted_exp(nexp=50):
    """Repeat the end-to-end 3D STED test ``nexp`` times, plotting only the last run.

    The ``test3D`` data folder is removed by ``tearDown_3DSTED`` afterwards,
    even when one of the runs fails. Previously a failing run skipped the
    cleanup and left the folder behind.

    Parameters
    ----------
    nexp : int
        Number of independent random experiments to run. It has a default
        value, so pytest does not treat it as a fixture.
    """
    try:
        for j in range(nexp):
            test_sted_exp(plot=(j == nexp - 1))
    finally:
        tearDown_3DSTED()
# Only run the full end-to-end test when executed as a script; importing the
# module (e.g. during pytest collection) must not launch 50 experiments.
if __name__ == "__main__":
    test_3Dsted_exp()