#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 15 16:53:19 2019

@author: aurelien
"""
import numpy as np
import matplotlib.pyplot as plt

import glob

from scipy.optimize import curve_fit
import sys
sys.path.append("../..")
import pickle
import os
from tifffile import tifffile
from scipy.ndimage import map_coordinates

from accessories.calibration import sted_power

from pyfcs.fitting import exp_fitter, linear_fitter

# Upper limit on the number of bead fits analysed in each file
MAX_NR_FITS = 10

def rotate(origin,angle,data):
    """Rotate a set of 2D points about ``origin`` by ``angle`` degrees.

    Parameters:
        origin: numpy array of shape (2,), centre of rotation
        angle: float, rotation angle in degrees
        data: numpy array of shape (2, npts), one point per column
    Returns:
        numpy.matrix of shape (2, npts) holding the rotated points. A matrix
        (not an ndarray) is returned on purpose: unpacking it yields two
        (1, npts) rows, which callers then repeat along axis 0.
    """
    theta = angle*np.pi/180
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation = np.matrix([[cos_t, sin_t],
                          [-sin_t, cos_t]])
    # Move the origin to (0, 0), rotate, then move it back
    centre = origin.reshape(-1, 1)
    centred = data - centre
    return np.dot(rotation, centred) + centre

def gaussian2D(xvals,x0,z0,sigmax,sigmaz,amplitude,delta):
    """Separable 2D gaussian plus a constant offset.

    ``xvals`` unpacks into the x and z coordinate arrays (e.g. a meshgrid).
    Returns amplitude*exp(-(x-x0)^2/(2 sigmax^2) - (z-z0)^2/(2 sigmaz^2)) + delta.
    """
    x, z = xvals
    arg_x = (x - x0)**2 / (2*sigmax**2)
    arg_z = (z - z0)**2 / (2*sigmaz**2)
    return amplitude*np.exp(-(arg_x + arg_z)) + delta

def lorentzian2D(xvals,x0,z0,sigmax,sigmaz,amplitude,delta):
    """Product of two 1D lorentzians (along x and z) plus a constant offset.

    ``sigmax``/``sigmaz`` are the distances from the centre at which each
    lorentzian factor drops to one half.
    """
    x, z = xvals
    denom_x = 1 + ((x - x0)/sigmax)**2
    denom_z = 1 + ((z - z0)/sigmaz)**2
    return amplitude/denom_x/denom_z + delta

def gaussx_lorentzz(xvals,x0,z0,sigmax,sigmaz,amplitude,delta):
    """Gaussian profile along x times a lorentzian profile along z, plus offset."""
    x, z = xvals
    gauss_x = np.exp(-((x - x0)**2 / (2*sigmax**2)))
    lor_denom_z = 1 + ((z - z0)/sigmaz)**2
    return gauss_x*amplitude/lor_denom_z + delta

def lorentzx_gaussz(xvals,x0,z0,sigmax,sigmaz,amplitude,delta):
    """Lorentzian profile along x times a gaussian profile along z, plus offset."""
    x, z = xvals
    gauss_z = np.exp(-((z - z0)**2 / (2*sigmaz**2)))
    lor_denom_x = 1 + ((x - x0)/sigmax)**2
    return gauss_z*amplitude/lor_denom_x + delta

# Available 2D peak models, looked up by name in fit_slice
func_dict = dict(
    gaussian=gaussian2D,
    lorentzian=lorentzian2D,
    gauss_lor=gaussx_lorentzz,
    lor_gauss=lorentzx_gaussz,
)

def fit_slice(slc, model="gaussian"):
    """Fit a 2D peak model to an x-z slice.

    Parameters:
        slc: 2D numpy array indexed as (z, x)
        model: str, key of ``func_dict`` selecting the model function
    Returns:
        popt: numpy array, optimal (x0, z0, sigmax, sigmaz, amplitude, delta)
        coords: list [xc, zc] of the meshgrid coordinates used for the fit
    Raises:
        ValueError: if ``model`` is not a key of ``func_dict``
        AssertionError: if any fitted parameter other than ``delta`` ends
            within 1% of its fitting bounds (the fit is then unreliable)
    """
    # Explicit checks instead of ``assert``: asserts vanish under python -O,
    # which would silently accept fits that hit the bounds.
    if model not in func_dict:
        raise ValueError("unknown model {!r}; expected one of {}".format(
            model, sorted(func_dict)))
    model_func = func_dict[model]

    x = np.arange(slc.shape[1])
    z = np.arange(slc.shape[0])
    xc, zc = np.meshgrid(x, z)
    bounds = ((0, 0, 0, 0, 0, 0),
              (x.max(), z.max(), x.max(), z.max(), slc.max()*1.5, 5))

    def f(xvals, x0, z0, sigmax, sigmaz, amplitude, delta):
        # curve_fit compares against a flat vector, so flatten the model image
        return model_func(xvals, x0, z0, sigmax, sigmaz, amplitude, delta).reshape(-1)

    popt, _ = curve_fit(f, [xc, zc], slc.reshape(-1), bounds=bounds)

    # Check that fitting bounds are never reached (delta is not checked).
    # AssertionError is kept so callers that already catch it still work.
    names = ("x0", "z0", "sigmax", "sigmaz", "amplitude")
    for name, value, low, high in zip(names, popt, bounds[0], bounds[1]):
        if not (1.01*low < value < 0.99*high):
            raise AssertionError(
                "fit parameter {} = {} reached its bounds ({}, {})".format(
                    name, value, low, high))
    return popt, [xc, zc]

def register_slices(sl1,sl2, verbose=True):
    """Align ``sl2`` onto ``sl1`` using the centres of 2D gaussian fits.

    Both slices are fitted with fit_slice; the centre offset is rounded to
    whole pixels and ``sl2`` is shifted circularly (np.roll) by its opposite.
    Prints the applied ``[z_shift, x_shift]`` list when ``verbose`` is True.
    Returns the shifted copy of ``sl2``.
    """
    ref_x, ref_z = fit_slice(sl1)[0][:2]
    mov_x, mov_z = fit_slice(sl2)[0][:2]

    dz = -int(np.round(mov_z - ref_z, 0))
    dx = -int(np.round(mov_x - ref_x, 0))
    shifts = [dz, dx]
    if verbose:
        print(shifts)
    return np.roll(sl2, shifts, axis=(0, 1))

def register_stack(slc_stack):
    """Register every slice of ``slc_stack`` onto its first slice.

    Returns a list whose first element is the unchanged reference slice,
    followed by each remaining slice passed through register_slices (which
    re-fits the reference for every slice).
    """
    reference = slc_stack[0]
    return [reference] + [register_slices(reference, other)
                          for other in slc_stack[1:]]

def test_stack_register(slc_stack):
    """Visually compare a slice stack before and after registration.

    Registers ``slc_stack`` with register_stack, sums both versions along the
    stack axis (as floats) and shows the two sums side by side with colorbars.

    Parameters:
        slc_stack: sequence of 2D arrays of identical shape
    Returns:
        matplotlib Figure holding the two panels
    """
    registered = register_stack(slc_stack)

    sum_raw = np.array(slc_stack).astype(float).sum(axis=0)
    sum_registered = np.array(registered).astype(float).sum(axis=0)

    fig = plt.figure()
    panels = ((121, sum_raw, "Sum not registered"),
              (122, sum_registered, "Sum registered"))
    for subplot_code, image, title in panels:
        plt.subplot(subplot_code)
        plt.imshow(image)
        plt.title(title)
        plt.colorbar()
    return fig
    
#test_stack_register(extract_profiles("/home/aurelien/Documents/phd/manuscripts/02_CH_STED/data_analysis/Beads/raw_data/28_11_2018/CH08/CH08_2500.json"))

def plot_slices(slices,fits=None):
    """Show up to the first ten slices on a 2x5 grid in a new figure.

    Parameters:
        slices: sequence of 2D images, drawn with the "hot" colormap
        fits: optional sequence of 2D model images; when given, fits[k] is
            drawn as a contour over slices[k] (it must cover every shown slice)
    """
    plt.figure()
    for panel in range(min(10, len(slices))):
        plt.subplot(2, 5, panel + 1)
        plt.imshow(slices[panel], cmap="hot")
        plt.axis("off")
        if fits is None:
            continue
        plt.contour(fits[panel])
            
def extract_resolutions(file,plot=True):
    """Wrapper function used to extract the resolutions in a dataPicker file.

    The file is unpickled (despite its .json extension). The image stack it
    names is loaded from the same directory as ``file``. For every entry of
    ``data["fits_results"]`` whose key is below MAX_NR_FITS, an x-z slice is
    resampled along the recorded (rotated) line through the bead and fitted
    with a 2D gaussian.

    Parameters:
        file: string, location of .json file
        plot: bool, if True displays the fitted profiles in a separate plot
    Returns:
        x_fwhm: numpy array, lateral resolution estimated from 2D gaussian fit
        z_fwhm: numpy array, axial resolution estimated from 2D gaussian fit
        x_gaussians: numpy array, lateral resolution estimated from 1D gaussian fit
        z_gaussians: numpy array, axial resolution estimated from 1D gaussian fit
        All four are multiplied by data["pixel_size"] and list the beads in
        the same order. They are empty when no bead is analysed.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only open files
    # produced by a trusted source.
    with open(file,"rb") as fp:
        data = pickle.load(fp)

    # The stack is looked up next to the .json file, not at its stored path.
    fn = os.path.join(os.path.dirname(file), os.path.basename(data["filename"]))
    # NOTE(review): presumably removes a 2**15 detector offset - confirm. If
    # imread returns an unsigned dtype, pixels below the offset wrap around.
    stack = tifffile.imread(fn)-2**15
    n_planes = stack.shape[0]

    used_keys = []        # fits_results keys actually analysed, in order
    all_slices = []       # resampled x-z slices
    all_fits = []         # fit parameters (x0, z0, sigmax, sigmaz, amplitude, delta)
    all_fits_models = []  # 2D images of the fitted model, for plotting

    for j, fit_dict in data["fits_results"].items():
        # continue (not break): keys below the limit that come after a larger
        # key are still analysed.
        if j >= MAX_NR_FITS:
            continue
        # Older files stored these keys with a trailing colon.
        linesize = int(fit_dict["linesize"] if "linesize" in fit_dict
                       else fit_dict["linesize:"])
        xy_angle = (fit_dict["xy_angle"] if "xy_angle" in fit_dict
                    else fit_dict["xy_angle:"])
        position = fit_dict["position"]
        pos = np.round(position).astype(int)  # np.int was removed in NumPy 1.24

        # Line of `linesize` points across the bead along x, rotated by
        # xy_angle degrees about the rounded (x, y) bead position.
        xx1 = np.linspace(position[2] - linesize/2,
                          position[2] + linesize/2,
                          linesize)
        yy1 = np.linspace(position[1], pos[1], linesize)
        # rotate() returns a matrix; asarray turns its rows into 1D arrays.
        xx1, yy1 = np.asarray(rotate(pos[1:][::-1], xy_angle, np.array([xx1, yy1])))

        # Sample that same line in every plane: coordinate grids in (z, y, x)
        # order, each of shape (n_planes, linesize).
        xx = np.tile(xx1, (n_planes, 1))
        yy = np.tile(yy1, (n_planes, 1))
        zz = np.repeat(np.arange(n_planes).reshape(-1, 1), linesize, axis=1)
        sl1 = map_coordinates(stack, np.stack((zz, yy, xx)))
        all_slices.append(sl1)

        out, coords = fit_slice(sl1)
        all_fits.append(out)
        all_fits_models.append(gaussian2D(coords, *out))
        used_keys.append(j)

    if plot:
        plot_slices(all_slices,all_fits_models)

    # reshape keeps a (0, 6) table when no bead was fitted instead of failing
    # on the column indexing below; 6 = number of model parameters.
    all_fits = np.asarray(all_fits, dtype=float).reshape(-1, 6)
    pixel_size = data["pixel_size"]
    sigma_to_fwhm = 2*np.sqrt(2*np.log(2))
    x_fwhm = sigma_to_fwhm*all_fits[:, 2]*pixel_size
    z_fwhm = sigma_to_fwhm*all_fits[:, 3]*pixel_size

    # 1D fits stored in the file, read for exactly the beads fitted above.
    # NOTE(review): their last parameter is presumably the gaussian sigma in
    # pixels, since it gets the same sigma->FWHM scaling - confirm.
    factor = sigma_to_fwhm*pixel_size
    fits_results = data["fits_results"]
    x_gaussians = np.array([fits_results[k]["fit_xy"][-1]*factor for k in used_keys])
    z_gaussians = np.array([fits_results[k]["fit_z"][-1]*factor for k in used_keys])

    return x_fwhm,z_fwhm,x_gaussians,z_gaussians

            
def extract_profiles(file,plot=True):
    """Wrapper function used to extract the axial profiles from a tifffile and the
    corresponding json files.

    The .json file is unpickled. The image stack it names is loaded from the
    same directory as ``file``. For every entry of ``data["fits_results"]``
    whose key is below MAX_NR_FITS, an x-z slice is resampled along the
    recorded (rotated) line through the bead.

    Parameters:
        file: string, location of .json file
        plot: bool, if True displays the profiles with 2D gaussian fit contours
    Returns:
        list: 2D slices, each of shape (number of planes, linesize)
    """
    # NOTE(review): pickle.load can execute arbitrary code; only open files
    # produced by a trusted source.
    with open(file,"rb") as fp:
        data = pickle.load(fp)

    # The stack is looked up next to the .json file, not at its stored path.
    fn = os.path.join(os.path.dirname(file), os.path.basename(data["filename"]))
    # NOTE(review): presumably removes a 2**15 detector offset - confirm. If
    # imread returns an unsigned dtype, pixels below the offset wrap around.
    stack = tifffile.imread(fn)-2**15
    n_planes = stack.shape[0]

    all_slices = []       # resampled x-z slices
    all_fits_models = []  # 2D images of the fitted model, for plotting

    for j, fit_dict in data["fits_results"].items():
        # continue (not break): keys below the limit that come after a larger
        # key are still processed.
        if j >= MAX_NR_FITS:
            continue
        # Older files stored these keys with a trailing colon.
        linesize = int(fit_dict["linesize"] if "linesize" in fit_dict
                       else fit_dict["linesize:"])
        xy_angle = (fit_dict["xy_angle"] if "xy_angle" in fit_dict
                    else fit_dict["xy_angle:"])
        position = fit_dict["position"]
        pos = np.round(position).astype(int)  # np.int was removed in NumPy 1.24

        # Line of `linesize` points across the bead along x, rotated by
        # xy_angle degrees about the rounded (x, y) bead position.
        xx1 = np.linspace(position[2] - linesize/2,
                          position[2] + linesize/2,
                          linesize)
        yy1 = np.linspace(position[1], pos[1], linesize)
        # rotate() returns a matrix; asarray turns its rows into 1D arrays.
        xx1, yy1 = np.asarray(rotate(pos[1:][::-1], xy_angle, np.array([xx1, yy1])))

        # Sample that same line in every plane: coordinate grids in (z, y, x)
        # order, each of shape (n_planes, linesize).
        xx = np.tile(xx1, (n_planes, 1))
        yy = np.tile(yy1, (n_planes, 1))
        zz = np.repeat(np.arange(n_planes).reshape(-1, 1), linesize, axis=1)
        sl1 = map_coordinates(stack, np.stack((zz, yy, xx)))
        all_slices.append(sl1)

        # Fitted even when plot is False, as before: fit_slice may raise on a
        # poor fit, so callers keep seeing the same failures.
        out, coords = fit_slice(sl1)
        all_fits_models.append(gaussian2D(coords, *out))

    if plot:
        plot_slices(all_slices,all_fits_models)
    return all_slices

def resolutions_vs_powers(folder,confocal_file,plot=False):
    """Takes all results within a folder, extract the corresponding resolutions.

    The confocal reference is processed first and gets power 0. Then every
    ``*.json`` file of ``folder`` is processed in sorted order. Its STED power
    is the last "_"-separated token of its file name without the extension
    (e.g. ``CH08_2500.json`` -> 2500, ``CH08_2.5.json`` -> 2.5). It is then
    converted with ``sted_power``.

    Parameters:
        folder: string, directory containing the .json files
        confocal_file: str, path to confocal data
        plot: bool, if True plots the bead fits of every file
    Returns:
        powers: numpy array, STED powers in mW (first entry forced to 0)
        resolutions: numpy array of shape (n_files + 1, 4, n_beads). Axis 0
            follows ``powers``; axis 1 is (x 2D fit, z 2D fit, x 1D fit,
            z 1D fit) as returned by extract_resolutions. Every file is
            assumed to yield the same number of beads.
    """
    files = sorted(glob.glob(os.path.join(folder, "*.json")))

    # Confocal reference first; its power entry is forced to 0 below
    resolutions = [tuple(extract_resolutions(confocal_file, plot=plot))]
    powers = [0]

    for file in files:
        resolutions.append(tuple(extract_resolutions(file, plot=plot)))
        # splitext strips only the extension, so decimal powers survive
        stem = os.path.splitext(os.path.basename(file))[0]
        pw = stem.split("_")[-1]
        powers.append(float(pw))
        if plot:
            plt.suptitle("power:"+str(pw))

    powers = sted_power(np.asarray(powers))
    powers[0] = 0
    resolutions = np.asarray(resolutions)
    return powers,resolutions

def analyse_folder(folder,fitter_type,confocal_file,plot=False,plot_results=True,ax=None,
                   plot_kwargs = None,name = None,new_path = None):
    """Fit normalised axial vs lateral resolution for a folder of acquisitions.

    Resolutions from resolutions_vs_powers are divided by the mean confocal
    resolution. Axial (y) vs lateral (x) normalised resolution is then fitted
    with ``linear_fitter`` or ``exp_fitter``.

    Parameters:
        folder: string, folder containing data to analyse
        fitter_type: string, 'lin' or 'exp'
        confocal_file: string, path to the confocal .json file
        plot: bool, if True plots the beads fits
        plot_results: bool, if True plots the fitting results
        ax: matplotlib axis, handle to the axis where results shall be plotted
        plot_kwargs: dict or None, extra keyword arguments for ``ax.errorbar``.
            They override its defaults (linestyle "", capsize 5, markersize 8).
            Its "color" entry, if any, is also used for the fitted curve;
            otherwise the curve reuses the colour chosen for the data.
        name: string or None, legend label. Defaults to the second-to-last
            "/"-separated component of ``folder``.
        new_path: unused; accepted for backward compatibility
    Returns:
        x: 2D numpy array, shape (npowers,n points), lateral 2D-fit FWHM / confocal mean
        y: 2D numpy array, shape (npowers,n points), axial 2D-fit FWHM / confocal mean
        out: optimal fitting parameters
        fitter: fitting function
    Raises:
        ValueError: if ``fitter_type`` is neither 'lin' nor 'exp'
    """
    if plot_kwargs is None:
        plot_kwargs = {}

    # Validate before the (slow) extraction instead of failing afterwards
    if fitter_type == 'lin':
        bounds = (0, np.inf)
        fitter = linear_fitter
    elif fitter_type == 'exp':
        bounds = ((-10, -10, -10), (10, 10, 10))
        fitter = exp_fitter
    else:
        raise ValueError("fitter_type must be 'lin' or 'exp', got {!r}".format(fitter_type))

    if name is None:
        name = folder.split("/")[-2]

    pw, resolutions = resolutions_vs_powers(folder, confocal_file, plot=plot)
    assert pw[0] == 0  # the first must be confocal

    # Normalise by the mean confocal resolution (row 0)
    confocal_x = np.mean(resolutions[0, 0])
    confocal_z = np.mean(resolutions[0, 1])
    x = resolutions[:, 0]/confocal_x
    y = resolutions[:, 1]/confocal_z

    out, _ = curve_fit(fitter, x.reshape(-1), y.reshape(-1), bounds=bounds)

    if plot_results:
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)

        errorbar_kwargs = dict(linestyle="", capsize=5, markersize=8)
        errorbar_kwargs.update(plot_kwargs)
        container = ax.errorbar(np.mean(x, axis=1), np.mean(y, axis=1),
                                xerr=np.std(x, axis=1), yerr=np.std(y, axis=1),
                                label=name, **errorbar_kwargs)

        xx = np.linspace(x.min()*0.9, 1, 20)
        yy = fitter(xx, *out)
        # Without an explicit colour, match the curve to the data markers
        color = plot_kwargs.get("color")
        data_line = container.lines[0]
        if color is None and data_line is not None:
            color = data_line.get_color()
        ax.plot(xx, yy, color=color, linestyle='--', markersize=3)
        ax.legend()
    return x,y,out,fitter

def gres(original,fit):
    """Root-mean-square of the pixel-wise difference between two images."""
    residual = original - fit
    mean_square = np.mean(residual**2)
    return np.sqrt(mean_square)