# -*- coding: utf-8 -*-
"""
Created on Tue Mar  3 10:26:44 2020

@author: univ4208
"""

import matplotlib.pyplot as plt

import glob

from accessories.io import generate_names_list_dict

from pyfcs.fitting import FcsExperiment, FcsFitter,Gc,Gt, export_experiments

import numpy as np
import sys
sys.path.append("../..")
from constants import FIG_UNIT,PLOT_DEFAULT,PLOT_FONT_SIZE
import matplotlib
from pyfcs.io import open_SIN

#Definition of functions
def Gconf(tau,N,tauD,T):
    """Model function handed to the fitters: product of ``Gc`` and ``Gt``.

    ``tau``, ``N`` and ``tauD`` are forwarded to ``Gc`` together with the
    module-level ``K_CONFOCAL`` and ``ALPHA``; ``tau`` and ``T`` are forwarded
    to ``Gt`` together with the module-level ``TAU_T``.  Those three constants
    are therefore held fixed while ``N``, ``tauD`` and ``T`` are free.

    NOTE(review): presumably ``Gc`` is the diffusion term and ``Gt`` the
    triplet term -- confirm in ``pyfcs.fitting``.
    """
    gc_term = Gc(tau, N, tauD, K_CONFOCAL, ALPHA)
    gt_term = Gt(tau, T, TAU_T)
    return gc_term * gt_term

def get_residuals(exp,x0 = 0,x1 = 5*10**-2):
    """Compute one scaled RMS fit residual per curve of an experiment.

    Parameters
    ----------
    exp : object exposing ``fcs_curves``, a dict mapping a key to a list of
        curves.  Each curve must provide ``x``, ``y``, ``fitter``, ``popt``
        and ``N``.
    x0, x1 : only points with ``x0 < x < x1`` (strict) enter the residual.

    Returns
    -------
    dict with the same keys as ``exp.fcs_curves``; each value is a list
    holding, for every curve, ``sqrt(mean(((y - fit) * N)**2))`` over the
    selected window.
    """
    residuals = {}
    for key, curve_list in exp.fcs_curves.items():
        per_curve = []
        for curve in curve_list:
            lags, measured = curve.x, curve.y
            fitted = curve.fitter(lags, *curve.popt)
            window = np.logical_and(lags > x0, lags < x1)
            # Residuals are multiplied by the curve's N before averaging
            scaled = (measured - fitted)[window] * curve.N
            per_curve.append(np.sqrt(np.mean(scaled ** 2)))
        residuals[key] = per_curve
    return residuals
import os

def sep_by_rep(input_dict):
    """Split a dictionary of file lists into one dictionary per repetition.

    The repetition numbers are read from the file names stored under key
    ``0`` of ``input_dict``: the integer between the last ``"rep"`` and the
    following ``"."`` of each file name.  For every such number, a new
    dictionary is built that has the same keys as ``input_dict`` and keeps
    only the entries containing ``"rep<number>."``.

    Returns the list of these dictionaries, in the order of ``input_dict[0]``.
    A key with no file for a given repetition maps to an empty list.
    (Original author's note: tested on the CH dataset.)
    """
    def rep_number(path):
        filename = path.split(os.sep)[-1]
        return int(filename.split("rep")[-1].split(".")[0])

    separated = []
    for reference_path in input_dict[0]:
        tag = "rep" + str(rep_number(reference_path)) + "."
        separated.append({
            key: [path for path in paths if tag in path]
            for key, paths in input_dict.items()
        })
    return separated

def merge_exp_results(exp_list,fname):
    """Stack the result of one method call across several experiments.

    Parameters
    ----------
    exp_list : non-empty sequence of experiment objects.
    fname : name of a no-argument method available on every experiment and
        returning a 3-tuple ``(powers, values, values_std)``.  The third
        element is ignored.

    Returns
    -------
    (powers, values) where ``powers`` is the power axis of the first
    experiment and ``values`` is an array with one row per experiment.

    Raises
    ------
    AssertionError if ``exp_list`` is empty or if the experiments do not all
    share the same power axis.
    """
    assert len(exp_list) > 0

    power_rows, value_rows = [], []
    for experiment in exp_list:
        powers, values, _ = getattr(experiment, fname)()
        power_rows.append(powers)
        value_rows.append(values)

    power_table = np.asarray(power_rows)
    # Every column must be constant, i.e. all experiments use the same powers
    for col in range(power_table.shape[1]):
        assert np.all(power_table[:, col] == power_table[0, col])
    return power_table[0], np.asarray(value_rows)

def merge_dicts(dicts):
    """Merge several dictionaries of lists into one, concatenating per key.

    Parameters
    ----------
    dicts : iterable of dictionaries whose values are lists (or any other
        iterable of items).

    Returns
    -------
    A new dictionary containing every key found in ``dicts``.  Each value is
    a new list with the items of all inputs for that key, in input order.

    The inputs are never modified: values are accumulated into fresh lists
    rather than extending the first list encountered, which would silently
    grow the caller's data.
    """
    merged = {}
    for dic in dicts:
        for key, values in dic.items():
            merged.setdefault(key, []).extend(values)
    return merged

def mean_std_fromdicts(dic):
    """Summarise a dictionary of value lists by mean and standard deviation.

    Parameters
    ----------
    dic : dictionary mapping a sortable key to a collection of numbers.

    Returns
    -------
    (keys, means, stds) as three arrays ordered by ascending key, where
    ``means[i]`` and ``stds[i]`` are ``np.mean`` and ``np.std`` of the values
    stored under ``keys[i]``.
    """
    pows = np.array(sorted(dic))
    means = np.array([np.mean(dic[p]) for p in pows])
    stds = np.array([np.std(dic[p]) for p in pows])
    return pows, means, stds

def plot_experiment_series(exp_list):
    """Plot the first curve of each experiment, one figure column per power.

    A new figure with two rows and one column per key of
    ``exp_list[0].fcs_curves`` (keys sorted ascending) is created.  For every
    experiment that has at least one curve at a given key, its first curve is
    drawn in that column through ``curve.plot``, with an offset of ``0.2``
    times the position of the experiment in ``exp_list``.  The top axes of
    each column is titled ``"<key> mW"``.

    Parameters
    ----------
    exp_list : non-empty sequence of experiments exposing ``fcs_curves``
        (dict: key -> list of curves).  Every experiment must contain all the
        keys of the first one.

    Raises
    ------
    AssertionError if ``exp_list`` is empty.
    """
    assert len(exp_list) > 0
    powers = sorted(exp_list[0].fcs_curves.keys())
    # squeeze=False keeps `axes` two-dimensional even when there is a single
    # power; otherwise the `axes[:, col]` indexing below fails.
    _, axes = plt.subplots(2, len(powers), squeeze=False)

    for col, power in enumerate(powers):
        for position, experiment in enumerate(exp_list):
            curves = experiment.fcs_curves[power]
            if len(curves) > 0:
                curves[0].plot(axes=axes[:, col], offset=0.2 * position)
        axes[0, col].set_title(str(power) + " mW")
# Global plot layout: use the project-wide font size for every figure
font = {'size'   : PLOT_FONT_SIZE}

matplotlib.rc('font', **font)

# Close any figure that is still open before creating new ones
plt.close("all")

# All sub-folders of the 2018_10_19 dataset (path is relative to the working directory)
folders = glob.glob("../../data_analysis/FCS/2018_10_19/*/")

# First folder whose path contains each tag.
# Raises IndexError if no folder matches (e.g. when run from another directory).
fz = [f for f in folders if "3D STED" in f][0]
fch = [f for f in folders if "CH" in f][0]
f3d = [f for f in folders if "3D_50 2D_50" in f][0]

print(folders)

# Parameters held fixed in the model: read by Gconf and also passed to
# FcsFitter (k_confocal) and FcsExperiment (alpha, tau_t).
# NOTE(review): units of TAU_T are not stated here -- confirm against pyfcs.
K_CONFOCAL = 4
ALPHA = 0.8
TAU_T = 5*10**-3

# (lower, upper) limits with three entries each -- presumably for the free
# parameters (N, tauD, T) of Gconf, in that order; TODO confirm.
bounds = ((10**-3,10**-3,0),(1000,1,1-1e-3))

# NOTE(review): out_ch is not used anywhere else in this file.
out_ch = generate_names_list_dict(fch)

# One fitter per dataset. They differ only in their first argument
# (0.85, "3D", "double_angle20"), whose meaning is defined by FcsFitter.
# The same bounds are passed in both bound slots.
fitter_ch = FcsFitter(0.85,Gconf,bounds,bounds,
                  k_confocal = K_CONFOCAL)

fitter_z = FcsFitter("3D",Gconf,bounds,bounds,
                  k_confocal = K_CONFOCAL)

fitter_3d = FcsFitter("double_angle20",Gconf,bounds,bounds,
                  k_confocal = K_CONFOCAL)

# Parallel lists: entry i of each list describes the same dataset
new_folders = [f3d, fz, fch]
fitters = [fitter_3d, fitter_z, fitter_ch]
names = ["3D", "z", "CH"]

def get_mean_std(arr):
    """Column-wise mean and standard deviation of a 2-D array, ignoring NaNs.

    Parameters
    ----------
    arr : two-dimensional numeric array (unpacking ``arr.shape`` fails
        otherwise).

    Returns
    -------
    (means, stds): two 1-D arrays with one entry per column of ``arr``,
    computed with ``np.mean`` / ``np.std`` over the non-NaN entries of that
    column.  A column containing only NaNs yields NaN.
    """
    _, n_cols = arr.shape
    columns = (arr[:, j] for j in range(n_cols))
    valid = [column[~np.isnan(column)] for column in columns]
    means = np.asarray([np.mean(values) for values in valid])
    stds = np.asarray([np.std(values) for values in valid])
    return means, stds

# Per-repetition analysis: every repetition of each dataset is fitted as its
# own experiment, then mean +/- std across repetitions is plotted per dataset.
# !!! Needs checking
all_experiments = list()


def loadfun(path):
    """Load one file with open_SIN and keep only the first returned element."""
    return open_SIN(path)[0]


plt.figure(figsize = (FIG_UNIT*2.3, FIG_UNIT))
for name, folder, fitter in zip(names, new_folders, fitters):
    # One files dictionary (and hence one experiment) per repetition
    output_dicts = sep_by_rep(generate_names_list_dict(folder))
    experiments = [
        FcsExperiment(fcsFitter=fitter, files_dictionary=fd,
                      alpha=ALPHA, loadfun=loadfun, tau_t=TAU_T)
        for fd in output_dicts
    ]

    # Kept for the example-curve figure drawn later in the script
    all_experiments.append(experiments)
    export_experiments(experiments, name + "_cellbycell")

    # Rows: repetitions, columns: STED powers; NaNs are ignored in the stats
    pw, vols = merge_exp_results(experiments, "volumes")
    vols_mean, vols_std = get_mean_std(vols)

    _, ns = merge_exp_results(experiments, "number_of_molecules")
    ns_mean, ns_std = get_mean_std(ns)

    # Residuals of all repetitions pooled per power
    resids = merge_dicts([get_residuals(w) for w in experiments])
    pwr, mr, sr = mean_std_fromdicts(resids)

    plt.subplot(131)
    plt.errorbar(pw, vols_mean, yerr=vols_std, **PLOT_DEFAULT[name], capsize=5)
    plt.xlabel("STED laser power (mW)")
    plt.ylabel(r"$\rm V/V_{confocal}$")
    plt.ylim(bottom=0)

    plt.subplot(132)
    plt.errorbar(pw, ns_mean, yerr=ns_std, **PLOT_DEFAULT[name], capsize=5)
    plt.xlabel("STED laser power (mW)")
    plt.ylabel(r"$\rm N/N_{confocal}$")
    plt.ylim(bottom=0)

    # Only this panel carries labels, so the legend below is drawn here
    plt.subplot(133)
    plt.errorbar(pwr, mr, yerr=sr, **PLOT_DEFAULT[name], capsize=5, label=name)
    plt.xlabel("STED laser power (mW)")
    plt.ylabel(r"$\rm nRMSD$")
    plt.ylim(bottom=0)

# plt.legend() acts on the current axes, i.e. the last subplot selected (133)
plt.legend()
plt.tight_layout()
plt.savefig("cell_normalised_results.svg")

# Example figure: for each dataset, one measured curve and its fit, both
# multiplied by the N of the matching confocal curve.
fig1, ax1 = plt.subplots(figsize = (FIG_UNIT*2.3/3, FIG_UNIT))

sted_types = ["3D","z","CH"]

for sted_type, experiments in zip(sted_types, all_experiments):
    # The experiment at index 1 is used as the example for every dataset.
    # Original author's note: "6 is ok" -- presumably index 6 is another
    # usable example; confirm before changing.
    experiment = experiments[1]
    # Keys 0 and 32 of fcs_curves; the first curve of each is used.
    # NOTE(review): presumably STED powers in mW (0 = confocal) -- confirm.
    curve_confocal = experiment.fcs_curves[0][0]
    curve_sted = experiment.fcs_curves[32][0]

    Nc = curve_confocal.N
    x = curve_sted.x
    # Multiplication creates new arrays, so the stored curve is left untouched
    y = curve_sted.y * Nc
    yh = curve_sted.fitter(x, *curve_sted.popt) * Nc

    color = PLOT_DEFAULT[sted_type]["color"]
    ax1.semilogx(x, y, label=sted_type, color=color)
    ax1.semilogx(x, yh, color="k", linestyle="--")

ax1.set_xlabel(r"$\rm \tau \ (ms)$")
ax1.set_ylabel(r"$\rm G(\tau) *N_{confocal}$")
ax1.legend()
fig1.tight_layout()

fig1.savefig("cells_FCScurves.svg")