#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Apr  3 11:32:34 2020

@author: aurelien

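Generates plenty of plots only for you, brave scientist going through the raw
data of the supplementary material: relative residuals of the different 2D fit
models for each STED mode, plus measured vs. fitted bead profiles and images.
Why are you here? Is it worth it? Do you remember coronavirus times? Not so
great, was it?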
"""
import numpy as np
import matplotlib.pyplot as plt

import glob
import sys
sys.path.append("../..")
from constants import FIG_UNIT, PLOT_DEFAULT

import pickle
import os
from resolutions_comparisons import (extract_resolutions, extract_profiles, 
                                     func_dict, fit_slice, 
                                     gaussian2D, lorentzian2D,gaussx_lorentzz,
                                     lorentzx_gaussz, gres, register_slices)

from accessories.plotting import plot_scalebar,get_extent

# -------------- Plotting -------------------
plt.close("all")
# Register marker/colour styles for four extra stack keys in the shared
# PLOT_DEFAULT mapping (insertion order kept: CH07, CH09, 3D_10, 3D_30).
PLOT_DEFAULT.update({
    "CH07": {"marker": "P", "color": "#52a8e4ff"},
    "CH09": {"marker": "D", "color": "#165f91ff"},
    "3D_10": {"marker": "P", "color": "#c25b00ff"},
    "3D_30": {"marker": "D", "color": "#f7a45aff"},
})
# -------------- End Plotting -------------------

def test_one_stack_comparison():
    """Compare the four 2D fit models on a single slice and show the fits.

    Loads the third ``.json`` file (sorted order) of ``stacks_paths[7]``,
    takes its second extracted slice, fits it with the Lorentzian, Gaussian,
    Gaussian-Lorentzian and Lorentzian-Gaussian models, prints the ``gres``
    residual of each fit and draws the data next to the four fitted images.

    This is an interactive sanity check, not an automated test: it asserts
    nothing and relies on the module-level ``stacks_paths``.
    NOTE(review): the ``test_`` prefix means pytest would collect it if this
    file were ever picked up by test discovery.

    Raises:
        IndexError: if the stack folder holds fewer than three ``.json``
            files or the chosen file yields fewer than two slices.
    """
    sp = stacks_paths[7]
    files_1 = sorted(glob.glob(sp + "/*.json"))
    file = files_1[2]
    slices = extract_profiles(file, plot=False)
    sl1 = slices[1]

    # (fit_slice keyword, model function rendering the fit, display name)
    models = (
        ("lorentzian", lorentzian2D, "Lorentzian"),
        ("gaussian", gaussian2D, "Gaussian"),
        ("gauss_lor", gaussx_lorentzz, "Gaussian-Lorentzian"),
        ("lor_gauss", lorentzx_gaussz, "Lorentzian-Gaussian"),
    )
    fits = []
    for keyword, func, label in models:
        popt, xvals = fit_slice(sl1, model=keyword)
        fits.append((label, func(xvals, *popt)))

    # NOTE(review): gres receives (fit, data) here but (data, fit) in
    # compare_fitters_oneslice; order kept as-is - confirm gres is symmetric.
    residuals = [gres(fit_im, sl1) for _, fit_im in fits]
    print("Residuals: Lorentzian {},\n Gaussian: {},\n Gaussian-Lorentzian {},\n Lorentzian-Gaussian {}\n"
          .format(*residuals))

    n_panels = len(fits) + 1
    plt.figure()
    plt.subplot(1, n_panels, 1)
    plt.imshow(sl1)
    plt.title("Data")
    plt.colorbar()
    for position, (label, fit_im) in enumerate(fits, start=2):
        plt.subplot(1, n_panels, position)
        plt.imshow(fit_im)
        plt.title(label + " fit")
        plt.colorbar()
    # Lay out once, after every panel has been created.
    plt.tight_layout()
    
# Root folder of all raw bead acquisitions
raw_data_path = "/home/aurelien/Documents/phd/manuscripts/02_CH_STED/data_analysis/Beads/raw_data/"

# Confocal reference acquisition. CONFOCAL_X / CONFOCAL_Z are the means of the
# first and second arrays returned by extract_resolutions (presumably the x and
# z resolutions - confirm against resolutions_comparisons).
confocal_file = raw_data_path + "29_01_2019/confocal_2901.json"
x1, z1, x2, z2 = extract_resolutions(confocal_file)
CONFOCAL_X, CONFOCAL_Z = np.mean(x1), np.mean(z1)

# One entry per stack; these appear to line up by position with folders_names
# (presumably PLOT_DEFAULT keys and display labels - not used in this file).
kwargs_keys = ["2D", "z", "CH07", "CH", "CH09", "3D_10", "3D", "3D_30"]
names = [
    "2D",
    "z",
    r"$ \rm CH, \rho = 0.7$",
    r"$ \rm CH, \rho = 0.8$",
    r"$ \rm CH, \rho = 0.9$",
    "3D, 88/12",
    "3D, 60/40",
    "3D, 25/75",
]

# Paths to Image stacks: every folder lives under raw_data_path/28_11_2018/
# except "z", which lives under raw_data_path/29_01_2019/. Each path keeps a
# trailing "/" because later code takes split("/")[-2] as the stack name.
folders_names = ["2D", "z", "CH07", "CH08", "CH09", "double_angle10",
                 "double_angle20", "double_angle30"]
stacks_paths = [
    raw_data_path + ("29_01_2019" if w == "z" else "28_11_2018") + "/" + w + "/"
    for w in folders_names
]



def compare_fitters_oneslice(sl, modes):
    """Compares fitting modes on one image slice
    Parameters:
        sl (ndarray): 2D fluorescence image slice
        modes (iterable): strings, keywords for func_dict. Any iterable works
            (list, tuple, ``func_dict.keys()``...), not only indexable ones.
    Returns:
        ndarray: residuals (as computed by gres) for each fitting modality,
        in the order in which ``modes`` yields them"""
    residuals = []
    for mode in modes:
        popt, xvals = fit_slice(sl, model=mode)
        # Evaluate the model at the fitted parameters on the xvals returned
        # by fit_slice, then score it against the data.
        fit_im = func_dict[mode](xvals, *popt)
        residuals.append(gres(sl, fit_im))
    return np.array(residuals)

def _power_from_path(path):
    """Return the integer power token of an analysis file path.

    The token is the last ``_``-separated field of the text just before the
    final extension, e.g. ``".../stack_1500.json" -> 1500``.
    """
    return int(path.split(".")[-2].split("_")[-1])

def filter_paths(files, max_power=2500):
    """Filters json files and removes those referring to the highest STED laser powers

    Parameters:
        files (iterable): paths of the analysed ``.json`` files.
        max_power (int): largest power token (see ``_power_from_path``) kept.
    Returns:
        list: the paths whose power token is ``<= max_power``, sorted by
        increasing power (ties broken by path), so the last element is the
        highest power kept. A plain string sort would place e.g. ``_500``
        after ``_2500`` and scramble power-ordered plots.
    Raises:
        ValueError, IndexError: if a path does not end in ``_<int>.<ext>``.
    """
    kept = [path for path in files if _power_from_path(path) <= max_power]
    return sorted(kept, key=lambda path: (_power_from_path(path), path))

def compare_one_stedmode(stack_path, modes, max_power=2500):
    """Gets all data for one sted mode and compares the different fitting modes
    Parameters:
        stack_path (string): path to folder containing tiff stacks and .json analysed files
        modes (list): strings, keywords for func_dict
        max_power (int): files whose power token (see filter_paths) exceeds
            this value are ignored. Defaults to the previously hard-coded 2500.
    Returns:
        tuple: sted_powers, residuals. ``sted_powers`` holds the power token
        of each file name as a *string* (callers convert with
        ``astype(float)``), in the order returned by filter_paths.
        Residuals are stored in a 3D array
        (sted_power, n_slices_per_power, fit_mode); NumPy can only build it
        as a regular array when every file yields the same number of slices.
        Both arrays are empty when no file is kept."""
    # filter_paths already returns a sorted list, no need to pre-sort.
    files = filter_paths(glob.glob(stack_path + "/*.json"), max_power=max_power)

    pws = []
    residuals = []
    for file in files:
        slices = extract_profiles(file, plot=False)
        residuals.append([compare_fitters_oneslice(sl, modes) for sl in slices])
        # "<prefix>_<power>.json" -> "<power>"
        pws.append(os.path.basename(file).split(".")[0].split("_")[-1])
    return np.array(pws), np.array(residuals)

def extractfits(sumslice, model, pixel_size=None):
    """Fit a 2D slice and extract 1D data/fit profiles through the fitted centre.

    Parameters:
        sumslice (ndarray): 2D image slice to fit.
        model (string): fit-model keyword, passed to fit_slice and used as a
            key of func_dict.
        pixel_size (float, optional): spacing used for the returned distance
            axes. When omitted, the module-level ``pixel_size`` assigned by
            the script is used (the historical behaviour).
    Returns:
        tuple: ``(xx, profx_fit, zz, profz_fit, profx, profz)`` where
        ``profx``/``profx_fit`` are the data/fit column ``[:, x0]`` and
        ``profz``/``profz_fit`` the data/fit row ``[z0]``, with ``(x0, z0)``
        the first two fitted parameters rounded to integers. ``xx``/``zz``
        are matching distance axes centred on the middle sample of each
        profile (not on the fitted centre).
    Raises:
        ValueError: if the rounded fitted centre lies outside the slice;
            a negative index would otherwise silently wrap to the far edge.
    """
    if pixel_size is None:
        # The parameter shadows the script-level global of the same name.
        pixel_size = globals()["pixel_size"]

    popt, xvals = fit_slice(sumslice, model=model)
    x0, z0 = np.round(popt[:2]).astype(int)
    n_rows, n_cols = sumslice.shape
    if not (0 <= x0 < n_cols and 0 <= z0 < n_rows):
        raise ValueError(
            "Fitted centre (x0={}, z0={}) lies outside the {}x{} slice"
            .format(x0, z0, n_rows, n_cols))
    slice_fit = func_dict[model](xvals, *popt)

    profx = sumslice[:, x0]
    profx_fit = slice_fit[:, x0]

    profz = sumslice[z0]
    profz_fit = slice_fit[z0]

    def centred_axis(profile):
        # Distance of every sample from the profile's central sample
        return (np.arange(profile.size) - profile.size // 2) * pixel_size

    return (centred_axis(profx), profx_fit, centred_axis(profz), profz_fit,
            profx, profz)
    
from aotools.imspector.calibration import sted_power

# Fit models compared in the residual plots. The first entry ("gaussian") is
# the reference every residual is divided by, hence the y-axis label below.
fitmodes = ["gaussian", "gauss_lor", "lor_gauss", "lorentzian"]
# NOTE(review): not referenced anywhere in the visible code.
ANALYSE_ALL = False

for sp in stacks_paths:
    name = sp.split("/")[-2]
    pw, residuals = compare_one_stedmode(sp, fitmodes)
    if residuals.ndim != 3:
        # No file (or no slice) kept for this stack: nothing to normalise.
        print("No usable residuals for {}, skipping".format(sp))
        continue
    pw = sted_power(pw.astype(float))
    # Express every mode's residual relative to fitmodes[0] on the same slice
    residuals = residuals / residuals[:, :, :1]
    plt.figure(figsize=(2.5, 2.5))
    for j, mode in enumerate(fitmodes):
        res = residuals[:, :, j]
        # mean +/- std over the slices acquired at each STED power
        plt.errorbar(pw, np.mean(res, axis=1), yerr=np.std(res, axis=1),
                     label=mode, capsize=5)
    plt.xlabel("STED power (mW)")
    plt.ylabel("Residuals/ Residuals Gaussian")
    plt.legend()
    # stacks_paths are built with "/" separators, so split on "/" (not
    # os.sep) to get the stack name on every platform.
    plt.title(name)
    plt.tight_layout()
    plt.savefig(name + "_residuals.svg")
# Confocal file
# Reference confocal data: fit-mode residual comparison, then the average of
# the registered slices and its data/fit line profiles.
with open(confocal_file, "rb") as fp:
    # NOTE(review): a ".json"-named file is read with pickle - confirm format
    data = pickle.load(fp)
pixel_size = data["pixel_size"]

slices_confocal = extract_profiles(confocal_file, plot=False)
# Residual of every fit mode for each slice, divided by the residual of the
# first mode (fitmodes[0]) for that same slice.
resolutions_confocal = np.array(
    [compare_fitters_oneslice(one_slice, fitmodes) for one_slice in slices_confocal])
resolutions_confocal /= resolutions_confocal[:, :1]

# Register every slice onto the first one and average them
conf_registered = np.mean(
    [register_slices(slices_confocal[0], w, verbose=False) for w in slices_confocal],
    axis=0)
extent = get_extent(conf_registered, [pixel_size] * 2)
xsb, ysb = plot_scalebar(conf_registered, [pixel_size] * 2, size=500, pos=[0.4, 0.92])

model, model2 = "gaussian", "lorentzian"
# The data profiles kept (profx, profz) are those from the second call.
xx, profx_fit, zz, profz_fit, profx, profz = extractfits(conf_registered, model)
xx, profx_fit2, zz, profz_fit2, profx, profz = extractfits(conf_registered, model2)

# Averaged image with scale bar
plt.figure(figsize=(1.2, 2.5))
plt.imshow(conf_registered, extent=extent, cmap="hot")
plt.plot(xsb, ysb, color="white", solid_capstyle="butt")
plt.axis("off")
plt.savefig("confocal_image.svg", dpi=600)

# Data profiles with both fits overlaid, one panel per direction
plt.figure(figsize=(5, 2.5))
for _pos, _dist, _prof, _fit, _fit2, _title in (
        (121, xx, profx, profx_fit, profx_fit2, "z profile"),
        (122, zz, profz, profz_fit, profz_fit2, "x profile")):
    plt.subplot(_pos)
    plt.plot(_dist, _prof, "-o")
    plt.plot(_dist, _fit, color="k", linestyle="--", label=model)
    plt.plot(_dist, _fit2, color="red", linestyle="-.", label=model2)
    plt.xlabel("distance (nm)")
    plt.ylabel("Counts")
    plt.title(_title)
    plt.legend()
plt.tight_layout()
plt.savefig("confocal_profiles.svg", dpi=600)

for sp in stacks_paths:
    name = sp.split("/")[-2]
    # Fit registered slice
    files_1 = filter_paths(glob.glob(sp + "/*.json"), max_power=2500)
    if not files_1:
        print("No analysis file found in {}, skipping".format(sp))
        continue
    # Use the last file in the order returned by filter_paths
    file = files_1[-1]
    print(os.path.basename(file))
    with open(file, "rb") as fp:
        data = pickle.load(fp)
    # extractfits reads this module-level pixel_size
    pixel_size = data["pixel_size"]

    slices = extract_profiles(file, plot=False)
    sl = slices[0]
    # Register every slice onto the first one, then average them
    slices_registered = np.array(
        [register_slices(sl, w, verbose=False) for w in slices])
    sumslice = slices_registered.mean(axis=0)

    model = "gaussian"
    model2 = "lorentzian"
    xx, profx_fit, zz, profz_fit, profx, profz = extractfits(sumslice, model)
    xx, profx_fit2, zz, profz_fit2, profx, profz = extractfits(sumslice, model2)

    # Data profiles with both fits overlaid. profx is a column of the slice
    # and is titled "z profile"; profz is a row and is titled "x profile".
    plt.figure(figsize=(5, 2.5))
    plt.subplot(121)
    plt.plot(xx, profx, "-o")
    plt.plot(xx, profx_fit, color="k", linestyle="--", label=model)
    plt.plot(xx, profx_fit2, color="red", linestyle="-.", label=model2)
    plt.xlabel("distance (nm)")
    plt.ylabel("Counts")
    plt.title("z profile")
    plt.legend()

    plt.subplot(122)
    plt.plot(zz, profz, "-o")
    plt.plot(zz, profz_fit, color="k", linestyle="--", label=model)
    plt.plot(zz, profz_fit2, color="red", linestyle="-.", label=model2)
    plt.xlabel("distance (nm)")
    plt.ylabel("Counts")
    plt.title("x profile")
    plt.legend()
    plt.tight_layout()
    plt.savefig(name + "_profiles.svg", dpi=600)

    # Averaged registered image with scale bar
    extent = get_extent(sumslice, [pixel_size, pixel_size])
    xsb, ysb = plot_scalebar(sumslice, [pixel_size, pixel_size], size=500, pos=[0.4, 0.92])

    plt.figure(figsize=(1.2, 2.5))
    plt.imshow(sumslice, extent=extent, cmap="hot")
    plt.plot(xsb, ysb, color="white", solid_capstyle="butt")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(name + "_picture.svg", dpi=600)
# plot_n(slices_registered)
