#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 15 16:53:19 2019

@author: aurelien
"""
import numpy as np
import matplotlib.pyplot as plt

from resolutions_comparisons import resolutions_vs_powers,extract_resolutions
import glob

from scipy.optimize import curve_fit
import sys
sys.path.append("../..")
sys.path.append("../../misc python/")
from constants import FIG_UNIT, PLOT_DEFAULT, PLOT_FONT_SIZE
from methods import exp_fitter,linear_fitter,make_res_functions
import matplotlib
from scipy.stats import linregress
import pickle

# Use the project-wide font size for every figure drawn by this script.
font = dict(size=PLOT_FONT_SIZE)
matplotlib.rc('font', **font)


# Confocal reference measurement: the bead resolutions analysed below are
# divided by its mean lateral (x1) and axial (z1) values. x2/z2 are the second
# pair returned by extract_resolutions (only printed here).
confocal_file = "/home/aurelien/Documents/phd/manuscripts/02_CH_STED/data_analysis/Beads/raw_data/29_01_2019/confocal_2901.json"
x1,z1,x2,z2 = extract_resolutions(confocal_file)
# Average each series once and reuse the values for printing and normalisation.
mean_x1, mean_z1, mean_x2, mean_z2 = (np.mean(v) for v in (x1, z1, x2, z2))
print("Resolutions confocal:",mean_x1,mean_z1,mean_x2,mean_z2)
print("Aspect ratios:",mean_z1/mean_x1,mean_z2/mean_x2)
CONFOCAL_X = mean_x1
CONFOCAL_Z = mean_z1

# Bead datasets to analyse: every sub-folder of the 28/11/2018 acquisition,
# sorted by name, plus the 29/01/2019 "z" folder inserted at position 1.
# Later lists (fitter_type, kwargs_keys, names) are matched to this by position.
folders = glob.glob("../../data_analysis/Beads/raw_data/28_11_2018/*/")
folders.sort()
folders.insert(1, "../../data_analysis/Beads/raw_data/29_01_2019/z/")
# One (folder name, fit parameters, fitter type) tuple per folder, filled below.
all_res_fits = []

# Shared axes for the normalised-resolution comparison figure.
fig,ax = plt.subplots(figsize= (FIG_UNIT*2,FIG_UNIT*2))


# Fitter type ('lin' or 'exp') for each entry of `folders`, by position.
fitter_type=["lin","exp","lin","lin","lin","lin","lin","exp"]


# Print the folder/fitter pairing so it can be checked before fitting.
# NOTE(review): zip stops at the shorter list, so a folder count different
# from len(fitter_type) goes unnoticed here — confirm the glob finds 7 folders.
for fd,ft in zip(folders,fitter_type):
    print(fd,":",ft)

def analyse_folder(folder,fitter_type,plot=False,plot_results=True,ax=None,
                   plot_kwargs = None,name = None):
    """Fit normalised axial vs lateral bead resolutions for one folder.

    The resolutions returned by ``resolutions_vs_powers`` are divided by the
    confocal reference (``CONFOCAL_X`` for column 0, ``CONFOCAL_Z`` for
    column 1). The chosen fitter is then fitted to the flattened axial values
    as a function of the flattened lateral values.

    Parameters:
        folder: string, folder containing data to analyse
        fitter_type: string, 'lin' or 'exp'
        plot: bool, if True plots the beads fits
        plot_results: bool, if True plots the fitting results
        ax: matplotlib axis, handle to the axis where results shall be
            plotted; a new figure is created when None
        plot_kwargs: dict or None, extra keyword arguments for
            ``ax.errorbar``; its optional "color" entry is also used for the
            fitted curve
        name: string or None, legend label; defaults to the last path
            component of ``folder``
    Returns:
        x: 2D numpy array, shape (npowers,n points), normalised lateral res.
        y: 2D numpy array, shape (npowers,n points), normalised axial res.
        out: optimal fitting parameters (all zeros if the fit failed)
        fitter: fitting function
    Raises:
        ValueError: if fitter_type is neither 'lin' nor 'exp'."""
    if plot_kwargs is None:
        plot_kwargs = {}
    if name is None:
        name = folder.split("/")[-2]

    # Select the model before doing any I/O so a typo fails fast.
    if fitter_type == 'lin':
        bounds = (0, np.inf)
        fitter = linear_fitter
        # Single free parameter (the slope): the script stores 'lin' fits as
        # one-element parameter lists and regresses one slope per CH folder.
        n_params = 1
    elif fitter_type == 'exp':
        bounds = ((-10, -10, -10), (10, 10, 10))
        fitter = exp_fitter
        n_params = 3
    else:
        raise ValueError(
            "fitter_type must be 'lin' or 'exp', got {!r}".format(fitter_type))

    _, resolutions = resolutions_vs_powers(folder,confocal_file,plot=plot)
    # Out-of-place division: does not mutate `resolutions` through a view.
    x = resolutions[:,0] / CONFOCAL_X
    y = resolutions[:,1] / CONFOCAL_Z
    print("x shape",x.shape)
    print("z shape",y.shape)
    x1 = x.reshape(-1)
    y1 = y.reshape(-1)

    try:
        out,_ = curve_fit(fitter,x1,y1,bounds=bounds)
    except (RuntimeError, ValueError) as err:
        # Best effort: report the failure and continue with a null fit whose
        # length matches the fitter's parameter count.
        print("fitter fail:", err)
        out = np.zeros(n_params)
    print('residuals:',np.sum((fitter(x1,*out)-y1)**2))

    if plot_results:
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)

        # Mean +/- std over axis 1 for each row of x and y.
        ax.errorbar(np.mean(x,axis=1),np.mean(y,axis=1),
                    xerr=np.std(x,axis=1),yerr=np.std(y,axis=1),label = name,
                    linestyle="",capsize=5, **plot_kwargs, markersize = 8)

        # Fitted curve from just below the smallest lateral value up to the
        # confocal level (1).
        xx = np.linspace(x.min()*0.9,1,20)
        yy = fitter(xx,*out)
        # color=None lets matplotlib pick the next colour of its cycle.
        ax.plot(xx,yy,color=plot_kwargs.get("color"),linestyle='--',
                markersize = 3)
        ax.legend()
    return x,y,out,fitter

# Extra marker/colour styles for the CH variants (blue shades) and the 3D
# variants (orange shades), added to the shared PLOT_DEFAULT mapping.
PLOT_DEFAULT.update({
    "CH07": dict(marker='P', color='#52a8e4ff'),
    "CH09": dict(marker='D', color='#165f91ff'),
    "3D_10": dict(marker='P', color='#c25b00ff'),
    "3D_30": dict(marker='D', color='#f7a45aff'),
})

# PLOT_DEFAULT key and legend label for each entry of `folders`, by position.
kwargs_keys = [
    "2D", "z", "CH07", "CH", "CH09", "3D_10", "3D", "3D_30",
]

names = [
    "2D",
    "z",
    r"$ \rm CH, \rho = 0.7$",
    r"$ \rm CH, \rho = 0.8$",
    r"$ \rm CH, \rho = 0.9$",
    "3D, 88/12",
    "3D, 60/40",
    "3D, 25/75",
]

# Fit every folder on the shared axes; plot style and legend label are
# looked up by the folder's position in `folders`.
for jn,(folder,ft) in enumerate(zip(folders,fitter_type)):
    name = folder.split("/")[-2]
    print('analysing folder',name)
    x,y,out,fitter = analyse_folder(folder,ft,ax=ax,
                                    plot_kwargs = PLOT_DEFAULT[kwargs_keys[jn]],
                                    name=names[jn])

    # Stored under the raw folder name (not the legend label); the dict2
    # construction below filters on it with 'CH' in name.
    all_res_fits.append((name,out,ft))

# Axes are resolutions normalised by the confocal reference.
ax.set_xlabel(r"$ \rm \omega_{xy}/\omega_{xy,confocal}$")
ax.set_ylabel(r"$ \rm \omega_{z}/\omega_{z,confocal}$")
ax.set_xlim(right=1.1)
fig.tight_layout()
ax.legend()
fig.savefig('fig_beads_fitting.svg',transparent = True)

# Linear regression of the CH linear-fit slope against the CH radius. The
# resulting slope/intercept are used below to extrapolate to other radii.
radii = [0.7,0.8,0.9]
# Positions of the CH folders in `folders` (rho = 0.7, 0.8, 0.9 in `names`).
ch_indices = [2,3,4]
slpes = [all_res_fits[j][1] for j in ch_indices]
# One value per CH folder, since each 'lin' fit has a single parameter.
slpes = np.array(slpes).reshape(-1)
plt.figure()
plt.plot(radii,slpes,"o")
out_linregress = linregress(radii,slpes)
slope,intercept = out_linregress[0], out_linregress[1]
plt.plot(np.linspace(0,1,10),np.linspace(0,1,10)*slope+intercept)
plt.legend(["Data","fit"])
plt.xlabel("CH radius")
plt.ylabel("slope of linear fitter")

# CH radii 0.55, 0.60, ..., 1.00 for which an extrapolated slope is stored.
wanted_radii = np.round(np.linspace(0.55,1,10),2)

# dict2 collects all fit results as key -> (parameters, fitter type):
#   - rounded CH radius (numpy float) -> extrapolated 'lin' slope, replaced by
#     0 when the regression predicts a non-positive value;
#   - raw folder name -> fitted parameters, for every non-CH folder;
#   - "Kconf" -> confocal aspect ratio (a bare float, not a tuple).
dict2 = {}
for w in wanted_radii:
    if w*slope+intercept >0:
        dict2[w] = ([w*slope+intercept],'lin')
    else:
        dict2[w] = ([0],'lin')
        
for nm,par,fitn in all_res_fits:
    if 'CH' not in nm:
        dict2[nm] = (par,fitn)

#dict2[1] = ([0],'lin')
dict2["Kconf"] = CONFOCAL_Z/CONFOCAL_X
# NOTE(review): despite the .json extension (and the "dictionay" typo) this
# is a pickle file. The name is kept because the test section below loads a
# file with this exact basename; other consumers may rely on it too.
with open("fit_dictionay_extrapolated.json","wb") as fd:
    pickle.dump(dict2,fd)

    
# Plot the stored (extrapolated, clipped) slope for each wanted radius.
slopes = np.array([dict2[r][0] for r in wanted_radii]).reshape(-1)
plt.figure()
plt.plot(wanted_radii,slopes,"-o")
plt.xlabel("CH radius")
plt.ylabel("slope")

#-----------Test--------------
# Load a stored fit dictionary to compare against dict2. fitters_dict is
# indexed with the fitter-type strings stored in the dictionary (see
# test_folder).
# NOTE(review): this reads the copy in the pyfcs package data folder, not
# the file written just above (relative path, current working directory).
# d1 may therefore be stale relative to dict2.
from pyfcs.fitting import fitters_dict
fits = "/home/aurelien/Documents/phd/Python/pyfcs/pyfcs/data/fit_dictionay_extrapolated.json"
with open(fits,"rb") as f:
    d1 = pickle.load(f)
    
def test_folder(folder,fittype=None):
    """Plot a visual check of the stored resolution fit for one bead folder.

    Creates a 1x3 figure titled with the folder name:
      1. normalised axial vs lateral resolution re-fitted by
         ``analyse_folder``, plus the ``fz`` curve from
         ``make_res_functions`` (label 'verification');
      2. aspect ratio CONFOCAL_Z/CONFOCAL_X * y/x against the ``fk`` curve;
      3. relative volume x**2 * y against the ``fv`` curve.

    Parameters:
        folder: string, bead folder path ending with '/'. Its last component
            selects the ``dict2`` entry. For CH folders the digits after
            "CH" are turned into the radius key, e.g. "CH07" -> 0.7.
        fittype: optional key into ``d1``; when given, that stored fit is
            overlaid in black on the first panel.
    """
    name = folder.split("/")[-2]
    fig,axes = plt.subplots(1,3)
    fig.suptitle(name)

    if 'CH' in name:
        # Assumes the name starts with "CH": "07" -> 7.0 * 10**-1 -> 0.7.
        # Rounded to 2 decimals, like the radius keys of dict2.
        digits = name[2:]
        name = round(float(digits) * 10**-(len(digits)-1),2)
        print("CH radius:",name)
    _, ftype = dict2[name]
    ax1, ax2, ax3 = axes

    # analyse_folder colours its fitted curve with plot_kwargs["color"], so
    # pass one explicitly. C0 is matplotlib's first default colour.
    x,y,_,_ = analyse_folder(folder,ftype,ax=ax1,plot_kwargs={"color":"C0"})
    x,y = x.reshape(-1), y.reshape(-1)

    fz,fk,fv = make_res_functions(dict2,name)

    x1 = np.linspace(0.1,1,10)
    dv = fv(x1)

    if fittype is not None:
        print(fittype)
        fitter_name = d1[fittype][1]
        fitpars = d1[fittype][0]
        current_fit = fitters_dict[fitter_name](x1, *fitpars)
        ax1.plot(x1,current_fit,"-o",color="black")

    # C1 keeps the verification curve distinct from the C0 data and fit.
    ax1.plot(x1,fz(x1),'-x',color="C1",label='verification')
    ax1.set_xlabel("Lateral resolution")
    ax1.set_ylabel("Axial resolution")
    ax1.legend()

    # y and x are normalised, so rescale by the confocal aspect ratio.
    ax2.plot(x,CONFOCAL_Z/CONFOCAL_X*y/x,"o")
    ax2.plot(x1,fk(x1))
    ax2.set_title("Aspect ratio")
    ax2.set_xlabel("Lateral resolution")
    ax2.set_ylabel("K")

    ax3.plot(x,x**2*y,'o')
    ax3.plot(x1,dv,'-x')
    ax3.set_title("Volume")
    ax3.set_xlabel("Lateral resolution")
    ax3.set_ylabel("V/V0")
# Deliberate stop when the whole file is run. The verification calls below
# are presumably meant to be executed by hand (e.g. cell by cell in an IDE).
# An explicit SystemExit states that intent instead of a fake ZeroDivisionError.
raise SystemExit("Stopped before the fit verification section.")
plt.close('all')
folder = folders[2]
test_folder(folder)

# d1 key for each entry of `folders`, matched by position.
folders_keys = ["2D","3D",0.7,0.8,0.9,"double_angle10","double_angle20",
                "double_angle30"]
for j,fold in enumerate(folders):
    test_folder(fold, fittype=folders_keys[j])




# OK: the test is good.

# CH07: the aspect ratio is mediocre, but there is no apparent good solution.
# CH08: the slopes don't match, probably because of the fit.
# double_angle10: OK
# double_angle20: OK