#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 15 10:09:06 2019

@author: aurelien
"""

import numpy as np
import matplotlib.pyplot as plt

import glob
import pickle

from scipy.ndimage import map_coordinates

from tifffile import tifffile
from scipy.optimize import curve_fit

# Upper bound on the fit keys processed by extract_resolutions: the loop over
# data["fits_results"] stops at the first key >= this value.
MAX_NR_FITS = 10

def rotate(origin,angle,data):
    """Rotate xy points about a pivot by an angle given in degrees.

    Parameters:
        origin: numpy array of shape (2,), the pivot point.
        angle: rotation angle in degrees.
        data: array of shape (2, npts), one point per column.
    Returns:
        np.matrix of shape (2, npts) with the rotated points. Callers unpack
        its rows, and each row stays 2D because the result is a matrix.
    """
    phi = angle*np.pi/180
    c, s = np.cos(phi), np.sin(phi)
    rotation = np.matrix([[c, s],
                          [-s, c]])
    pivot = origin.reshape(-1, 1)
    # shift to the pivot, rotate, then shift back
    return rotation.dot(data - pivot) + pivot

def change_psize(file,pixel_size):
    """Overwrite the pixel size stored in a pickled results file.

    Use this to correct a value that was entered incorrectly. It also works
    on files that have no pixel size recorded yet; the previous value is
    then reported as None.

    Parameters:
        file: string, path of the pickled results file. It is read and then
            rewritten in place.
        pixel_size: new value stored under the "pixel_size" key.
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    with open(file, "rb") as fp:
        data = pickle.load(fp)
    # .get: some files carry no "pixel_size" key (see get_resolutions' default)
    print("Changing pixel size: previous "+str(data.get("pixel_size")))
    data["pixel_size"] = pixel_size

    print("New: "+str(data["pixel_size"]))
    with open(file, "wb") as fp:
        pickle.dump(data, fp)
    
def get_psize(folder):
    """Print the pixel size stored in every "*.json" pickle in *folder*.

    NOTE: files are read with pickle, so only use this on trusted folders.
    """
    pattern = folder + "/*.json"
    for path in glob.glob(pattern):
        with open(path, "rb") as handle:
            contents = pickle.load(handle)
        print("pixel size:", contents["pixel_size"])

def get_resolutions(file):
    """Convert the stored 1D Gaussian sigmas of a results file to FWHMs.

    Parameters:
        file: string, path of a pickled results file containing
            "fits_results" and optionally "pixel_size".
    Returns:
        (xres, zres): numpy arrays of FWHM values from the "fit_xy" and
        "fit_z" fits, scaled by the pixel size.
    """
    with open(file, "rb") as handle:
        contents = pickle.load(handle)

    # files without a recorded pixel size fall back to 20
    psize = contents["pixel_size"] if "pixel_size" in contents else 20

    # FWHM = 2*sqrt(2 ln 2) * sigma, then scaled by the pixel size
    sigma_to_fwhm = np.sqrt(2*np.log(2)) *2 * psize

    fits = contents["fits_results"]
    lateral, axial = [], []
    for idx in range(len(fits)):
        entry = fits[idx]
        # the last element of each fit result is sigma
        lateral.append(entry['fit_xy'][-1] * sigma_to_fwhm)
        axial.append(entry['fit_z'][-1] * sigma_to_fwhm)

    return np.asarray(lateral), np.asarray(axial)

def gaussian2D(xvals,x0,z0,sigmax,sigmaz,amplitude,delta):
    """Evaluate an axis-aligned 2D Gaussian plus a constant offset.

    Parameters:
        xvals: pair (x, z) of coordinate arrays (or scalars).
        x0, z0: centre of the Gaussian.
        sigmax, sigmaz: standard deviations along x and z.
        amplitude: peak height above the offset.
        delta: constant background offset.
    """
    x, z = xvals
    exponent = (x-x0)**2/(2*sigmax**2) + (z-z0)**2/(2*sigmaz**2)
    return amplitude*np.exp(-exponent) + delta



def fit_slice(slc):
    """Fit a 2D Gaussian plus constant offset to a 2D image.

    Parameters:
        slc: 2D array indexed as (z, x).
    Returns:
        popt: fitted parameters (x0, z0, sigmax, sigmaz, amplitude, delta).
        [xc, zc]: the pixel-coordinate grids used for the fit.
    """
    cols = np.arange(slc.shape[1])
    rows = np.arange(slc.shape[0])
    xc, zc = np.meshgrid(cols, rows)
    grid = [xc, zc]

    peak = slc.max()
    offset_limit = float(peak)/5
    # centre and sigmas are limited to the image extent, the amplitude to
    # [0, max] and the offset to +/- one fifth of the maximum
    lower = (0, 0, 0, 0, 0, -offset_limit)
    upper = (cols.max(), rows.max(), cols.max(), rows.max(), peak, offset_limit)

    def model(coords, x0, z0, sigmax, sigmaz, amplitude, delta):
        # curve_fit compares against a 1D vector, so flatten the 2D model
        return gaussian2D(coords, x0, z0, sigmax, sigmaz, amplitude, delta).ravel()

    popt, _ = curve_fit(model, grid, slc.reshape(-1), bounds=(lower, upper))
    return popt, grid

def plot_slices(slices,fits=None):
    """Display up to ten slices in a 2x5 grid with the "hot" colormap.

    Parameters:
        slices: sequence of 2D images.
        fits: optional sequence of 2D images. fits[j] is drawn as contour
            lines over slices[j].
    """
    n_panels = min(len(slices), 10)
    plt.figure()
    for idx in range(n_panels):
        plt.subplot(2, 5, idx + 1)
        plt.imshow(slices[idx], cmap="hot")
        plt.axis("off")
        if fits is None:
            continue
        plt.contour(fits[idx])

def _fit_field(fit_dict, key):
    """Return fit_dict[key], falling back to the key + ":" spelling.

    Some files store keys such as "linesize:" with a trailing colon.
    """
    try:
        return fit_dict[key]
    except KeyError:
        return fit_dict[key + ":"]


def _sample_line_profile(stack, fit_dict):
    """Resample *stack* on the plane through one picked bead.

    A line of "linesize" points, centred on the bead and rotated by
    "xy_angle" degrees in the image plane, is sampled at every z index of
    *stack*. Returns an array of shape (stack.shape[0], linesize).
    """
    linesize = int(_fit_field(fit_dict, "linesize"))
    position = fit_dict["position"]
    xy_angle = _fit_field(fit_dict, "xy_angle")
    # np.int was removed in NumPy 1.24; the builtin int replaces it
    pos = np.round(position).astype(int)
    nz = stack.shape[0]

    # presumably position is (z, y, x): index 2 runs along the line and the
    # rotation origin is taken as (x, y) -- confirm against the picker output
    xx = np.linspace(position[2] - linesize/2,
                     position[2] + linesize/2,
                     linesize)
    # NOTE(review): y runs from the exact to the rounded position, which tilts
    # the line by up to half a pixel; a constant position[1] may have been
    # intended -- confirm.
    yy = np.linspace(position[1], pos[1], linesize)

    # rotate() yields rows (x, y); slicing keeps each row 2D, shape (1, linesize)
    rotated = np.asarray(rotate(pos[1:][::-1], xy_angle, np.array([xx, yy])))
    xx = np.repeat(rotated[0:1], nz, axis=0)
    yy = np.repeat(rotated[1:2], nz, axis=0)
    zz = np.repeat(np.arange(nz).reshape(-1, 1), linesize, axis=1)

    # map_coordinates takes one leading coordinate axis per stack dimension
    return map_coordinates(stack, np.stack((zz, yy, xx)))


def extract_resolutions(file,plot=True,datapath=None):
    """Wrapper function used to extract the resolutions in a dataPicker file.

    Parameters:
        file: string, location of .json file (a pickle; load trusted files only)
        plot: bool, if True displays the fitted profiles in a separate plot
        datapath: optional string prefix replacing everything up to "Data/" in
            the stored image filename. When None, a module-level ``datapath``
            global is used if one has been set. Otherwise the stored filename
            is used unchanged.
    Returns:
        x_fwhm: numpy array, lateral resolution estimated from 2D gaussian fit
        z_fwhm: numpy array, axial resolution estimated from 2D gaussian fit
        x_gaussians: numpy array, lateral resolution estimated from 1D gaussian fit
        z_gaussians: numpy array, axial resolution estimated from 1D gaussian fit
        All four arrays are empty when no fits are processed. A missing
        "pixel_size" defaults to 20, as in get_resolutions.
    """
    with open(file, "rb") as fp:
        data = pickle.load(fp)

    if datapath is None:
        # The original code read a ``datapath`` global that this file never
        # defines; keep honouring it if a caller has set it.
        datapath = globals().get("datapath")
    fname = data["filename"]
    if datapath is not None:
        fname = datapath + fname.split("Data/")[-1]
    print(fname)
    # Convert before removing the 2**15 offset: integer subtraction would wrap
    # (uint16) or overflow (int16). float32 is exact for 16-bit integers.
    stack = tifffile.imread(fname).astype(np.float32) - 2**15

    fits_results = data["fits_results"]
    all_slices = []
    all_fits = []        # fit result parameters
    all_fits_models = [] # 2D images corresponding to each fit result
    used_keys = []       # keys of fits_results actually processed

    for j, fit_dict in fits_results.items():
        if j >= MAX_NR_FITS:
            break
        sl1 = _sample_line_profile(stack, fit_dict)
        out, coords = fit_slice(sl1)
        all_slices.append(sl1)
        all_fits.append(out)
        all_fits_models.append(gaussian2D(coords, *out))
        used_keys.append(j)

    if plot:
        plot_slices(all_slices, all_fits_models)

    if not all_fits:
        return np.empty(0), np.empty(0), np.empty(0), np.empty(0)

    psize = data.get("pixel_size", 20)
    all_fits = np.asarray(all_fits)
    x_sigmas = all_fits[:, 2]
    z_sigmas = all_fits[:, 3]
    x_fwhm = 2*np.sqrt(2*np.log(2))*x_sigmas*psize
    z_fwhm = 2*np.sqrt(2*np.log(2))*z_sigmas*psize
    factor = np.sqrt(2*np.log(2)) *2*psize

    # Index the 1D fits by the same keys as the 2D fits so both stay aligned.
    x_gaussians = np.array([fits_results[j]["fit_xy"][-1]*factor
                            for j in used_keys])
    z_gaussians = np.array([fits_results[j]["fit_z"][-1]*factor
                            for j in used_keys])

    return x_fwhm, z_fwhm, x_gaussians, z_gaussians
