#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 24 10:36:25 2019

@author: aurelien

'Win to win to yes need the no to win again the no' J. P Raffarin (P. M. of 
                                                    France for a while)
"""

import glob

from psfsim.calc import attenuation,pinhole1D
#from aotools.ext.highNAPSF2 import attenuation,pinhole1D

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from resolution_analysis import fit_slice
import sys
sys.path.append("../..")
from constants import PLOT_FONT_SIZE, FIG_UNIT,PLOT_DEFAULT

import matplotlib.colors as colors
from accessories.plotting import plot_scalebar

font = {'size'   : PLOT_FONT_SIZE}

matplotlib.rc('font', **font)

# Optical parameters of the simulations.
NA = 1.4  # numerical aperture
lbd = 640  # wavelength (presumably nm, like the simulation pixel size)
# Airy radius (first zero of the Airy pattern): 0.61 * lbd / NA.
airy_radius = 1.22*lbd/(2*NA)

# Pixel size of the simulations; they are read from a "10nm" folder, so this
# is presumably in nm.
PSIZE = 10


max_power = 3
DELTA=0  # forwarded as ``delta=`` to psfsim's attenuation()


spath = "./"  # output folder of the saved figures
files = glob.glob("../../simulations/10nm/*.npy")
files.sort()

simulations = [np.load(w) for w in files]

# In the sorted file list, index 2 holds the excitation PSF and the last file
# the detection PSF; the remaining files are depletion patterns. The second
# pop indexes the list left over after the first one.
excitation = simulations.pop(2)
detection = simulations.pop(-1)

# NOTE(review): presumably applies a pinhole of one Airy radius to the
# detection PSF -- see psfsim.calc.pinhole1D.
detection = pinhole1D(airy_radius,PSIZE,detection)

# Shallow copy: appending the blended pattern below leaves ``simulations``
# holding only the simulated depletion patterns.
depletions = simulations[:]
r = 0.5  #Doughnut to top-hat ratio
depletions.append( r*depletions[0]+(1-r)*depletions[-1] )

def dual_generator(r):
    """Blend the first and last simulated depletion patterns.

    Same construction as the blended pattern appended to ``depletions`` at
    import time, but for an arbitrary ratio ``r``.

    Parameters:
        r: float, weight of the first depletion pattern (the "doughnut"
            according to the "Doughnut to top-hat ratio" comment); the last
            simulated pattern gets weight ``1 - r``
    Returns:
        array, ``r*simulations[0] + (1-r)*simulations[-1]``"""
    # Read ``simulations`` rather than ``depletions``: the blended pattern is
    # appended to ``depletions`` at import time, so ``depletions[-1]`` is no
    # longer the last simulated pattern and would be mixed with itself.
    dual = r*simulations[0]+(1-r)*simulations[-1]
    return dual

def dmap(img):
    """Return the map of horizontal distances to the central column of img.

    Every row holds the same values: the absolute offset of each column from
    column ``ncols//2``, pushed outwards by a sub-pixel correction ``fc`` so
    that the central column does not get a zero weight. It is the radial
    weight used by ``integrate``.

    Parameters:
        img: 2D array; only its shape is used
    Returns:
        dmap: 2D float64 array with the same shape as img"""
    nrows, ncols = img.shape
    # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the
    # same dtype (float64).
    offsets = (np.arange(ncols) - ncols//2).reshape(1,-1).astype(float)

    # Sub-pixel correction: 1/2 for an odd number of columns, 1/4 for an even one.
    fc = (1+ncols%2)/4
    offsets[offsets==0] += fc  # central column gets a non-zero offset first
    offsets[offsets>=0] += fc  # non-negative side (central column included) moves outwards
    offsets[offsets<=0] -= fc  # negative side moves outwards
    return np.repeat(np.abs(offsets),nrows,axis=0)

# Integration of PSFs, weighting pixels with the distance map from dmap()
def integrate(img):
    """Integral in cylindrical coordinates with rotational invariance.

    Each pixel of ``img`` is weighted by its distance to the central column
    (``dmap``) and the weighted sum is multiplied by pi. The radius runs
    from -rmax to rmax across the image, so the azimuth only needs to be
    integrated over [0, pi].

    Parameters:
        img: 2D array
    Returns:
        integral: float, in pixel units"""
    return np.pi*np.sum(img*dmap(img))

def get_snb(psf_data, plot=False, thr = (1/np.e**2)):
    """From a STED PSF, returns the signal and background values

    Pixels below ``thr`` times the PSF maximum are background, all others
    are signal. Each part is integrated with ``integrate`` (the other part
    set to zero).

    Parameters:
        psf_data: 2D array, PSF values
        plot: bool, if True, show background, signal and PSF in a new figure
        thr: float, signal/background threshold relative to the PSF maximum
            (default 1/e^2)
    Returns:
        signal: float, integral of the signal pixels
        background: float, integral of the background pixels
        sbr: float, signal to background ratio (inf, with a NumPy warning,
            when no pixel is background)"""
    bg = psf_data/psf_data.max() < thr
    background = psf_data.copy()
    signal = psf_data.copy()

    background[~bg] = 0
    signal[bg] = 0

    if plot:
        plt.figure()
        panels = ((background, "Background"), (signal, "Signal"),
                  (psf_data, "PSF"))
        for k, (image, title) in enumerate(panels):
            plt.subplot(1, 3, k+1)
            plt.imshow(image)
            plt.title(title)

    # Integrate each part once and reuse the values for the ratio.
    signal_integral = integrate(signal)
    background_integral = integrate(background)
    sbr = signal_integral/background_integral
    return signal_integral, background_integral, sbr

def test_integral(nz = 5,plot = False):
    """Print how integrate() compares with two closed-form integrals.

    Two images of ``nz`` identical rows are built on 201 samples spanning
    [-100.5, 100.5]: one holds x**2, the other ones. Their integrals are
    compared with 2*pi*R**4/4*nz and 2*pi*R**2/2*nz (R = 100.5). integrate()
    covers radii from -rmax to rmax, hence its angular integral only spans
    [0, pi].

    Parameters:
        nz: int, number of identical rows in the test images
        plot: bool, if True, display both test images"""
    samples = 201
    half_width = samples/2
    axis = np.linspace(-half_width, half_width, samples).reshape(1,-1)
    parabola = np.repeat(axis**2, nz, axis=0)
    flat = np.repeat(np.ones_like(axis), nz, axis=0)

    if plot:
        plt.figure()
        for position, image in ((121, parabola), (122, flat)):
            plt.subplot(position)
            plt.imshow(image)

    measured_sq = integrate(parabola)
    expected_sq = 2*np.pi*half_width**4/4*nz
    measured_flat = integrate(flat)
    expected_flat = 2*np.pi*half_width**2/2*nz

    print(measured_sq,"theoretical",expected_sq)
    print("relative error:",(measured_sq-expected_sq)/expected_sq,"ratio theory/experimental:",expected_sq/measured_sq)
    print("Integral cst",measured_flat,"theory",expected_flat)
    print("relative error for 2nd test:",(measured_flat-expected_flat)/expected_flat)
    print("ratio theory/experimental for 2nd test:",expected_flat/measured_flat)
#test_integral(plot=False)

# integrate(excitation)

def compute_snbs(power,plot=False,titles = ("z","CH","2D","3D")):
    """Compute signal-to-background ratios based on 1/e^2 thresholding.

    For every pattern in the module-level ``depletions`` list, a STED PSF
    ``attenuation(excitation, depletion, power, delta=DELTA) * detection``
    is built and normalised to its maximum. It is then fitted with
    ``fit_slice`` to get its FWHM and split into signal and background with
    ``get_snb``.

    Parameters:
        power: float, depletion power passed to ``attenuation``
        plot: bool, if True, show the log of each PSF in a 2x2 grid
        titles: sequence of str or None, one subplot title per depletion
    Returns:
        snbs: 1D array, signal-to-background ratio of each PSF
        resolutions: 2D array, one [xsize, zsize] FWHM pair per PSF, in the
            units of PSIZE
        signal: list of float, integrated signal of each PSF
        background: list of float, integrated background of each PSF"""
    # Conversion factor from a Gaussian standard deviation to its FWHM.
    fwhm_factor = 2*np.sqrt(2*np.log(2))
    psfs = [attenuation(excitation,depletion,power,delta=DELTA)*detection
            for depletion in depletions]
    resolutions = []
    if plot:
        plt.figure()
        plt.suptitle(power)
    for j,psf in enumerate(psfs):
        psf/=np.max(psf)
        popt, _ = fit_slice(psf)
        if plot:
            plt.subplot(2,2,j+1)
            plt.imshow(np.log(psf))
            if titles is not None:
                plt.title(titles[j])
        # popt[2] and popt[3] are presumably the fitted Gaussian sigmas (in
        # pixels) along x and z -- converted here to FWHM in PSIZE units.
        xsize = popt[2]*fwhm_factor*PSIZE
        zsize = popt[3]*fwhm_factor*PSIZE
        resolutions.append([xsize,zsize])

    results = [get_snb(p) for p in psfs]
    signal = [res[0] for res in results]
    background = [res[1] for res in results]
    snbs = np.asarray([res[2] for res in results])
    return snbs,np.asarray(resolutions),signal,background
    
class LogDivergingNorm(colors.LogNorm):
    """Logarithmic norm with an adjustable centre.

    Data are mapped piecewise-linearly in log space so that vmin -> 0,
    vcenter -> 0.5 and vmax -> 1, with a different log slope on each side of
    vcenter (a log-scale analogue of matplotlib's TwoSlopeNorm).
    Non-positive values are masked.

    Parameters:
        vcenter: float, data value mapped to the middle of the colormap
        vmin, vmax: float or None, data values mapped to the ends of the
            colormap; autoscaled from the data when None
        clip: bool, clip data to [vmin, vmax] before mapping"""

    def __init__(self, vcenter, vmin=None, vmax=None, clip=False):
        # Let Normalize/LogNorm build their internal state (private
        # _vmin/_vmax/_clip, callbacks, log transform): recent matplotlib
        # versions need it in the vmin/vmax/clip setters and autoscale_None.
        super().__init__(vmin=vmin, vmax=vmax, clip=clip)
        self.vcenter = vcenter
        if vcenter is not None and vmax is not None and vcenter >= vmax:
            raise ValueError('vmin, vcenter, and vmax must be in '
                             'ascending order')
        if vcenter is not None and vmin is not None and vcenter <= vmin:
            raise ValueError('vmin, vcenter, and vmax must be in '
                             'ascending order')

    def __call__(self, value, clip=None):
        """Map value to [0, 1]; non-positive values are masked."""
        if clip is None:
            clip = self.clip

        result, is_scalar = self.process_value(value)

        # The logarithm is undefined for non-positive data: mask it.
        result = np.ma.masked_less_equal(result, 0, copy=False)

        self.autoscale_None(result)
        vmin, vmax = self.vmin, self.vmax
        if vmin > vmax:
            raise ValueError("minvalue must be less than or equal to maxvalue")
        elif vmin <= 0:
            raise ValueError("values must all be positive")
        elif vmin == vmax:
            result.fill(0)
        else:
            if clip:
                mask = np.ma.getmask(result)
                result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
                                     mask=mask)
            # in-place equivalent of above can be much faster
            resdat = result.data
            mask = result.mask
            if mask is np.ma.nomask:
                mask = (resdat <= 0)
            else:
                mask |= resdat <= 0
            # Masked / non-positive entries become 1 so the in-place log stays finite.
            np.copyto(resdat, 1, where=mask)
            np.log(resdat, resdat)
            result = np.ma.array(resdat, mask=mask, copy=False)

            # Piecewise-linear in log space: vmin -> 0, vcenter -> 0.5, vmax -> 1.
            result = np.ma.masked_array(
                np.interp(result, [np.log(self.vmin), np.log(self.vcenter),
                                   np.log(self.vmax)],
                          [0, 0.5, 1.]), mask=np.ma.getmask(result))

        if is_scalar:
            result = result[0]
        return result

    def inverse(self, value):
        """Map normalised values back to data values (inverse of __call__)."""
        if not self.scaled():
            raise ValueError("Not invertible until scaled")
        vmin, vmax = self.vmin, self.vmax

        if self.vcenter is None:
            # No centre: plain log-norm inverse.
            if np.iterable(value):
                val = np.ma.asarray(value)
                return vmin * np.ma.power((vmax / vmin), val)
            return vmin * pow((vmax / vmin), value)

        log_min = np.log(vmin)
        log_center = np.log(self.vcenter)
        log_max = np.log(vmax)
        # [0, 0.5] maps back log-linearly onto [vmin, vcenter] and [0.5, 1]
        # onto [vcenter, vmax]; values outside [0, 1] are extrapolated along
        # the nearest segment.
        val = np.ma.asarray(value, dtype=float)
        lower = log_min + 2 * val * (log_center - log_min)
        upper = log_center + (2 * val - 1) * (log_max - log_center)
        result = np.ma.exp(np.ma.where(val < 0.5, lower, upper))
        if np.iterable(value):
            return result
        return float(result)

    def autoscale(self, A):
        # docstring inherited.
        super().autoscale(np.ma.masked_less_equal(A, 0, copy=False))

    def autoscale_None(self, A):
        # docstring inherited.
        super().autoscale_None(np.ma.masked_less_equal(A, 0, copy=False))



def plot_psf_bg_double(im,ax=None,ax2=None,crop = (None,None),
                           show_cbar = False, vmin = None,fig=None,
                           depletion = None,show1only = False, 
                           show2only = False,name =None):
    """Plots the background contributions in a PSF.

    ``ax`` shows the max-normalised PSF. Its right half is overlaid with the
    PSF weighted by the distance map of ``dmap`` (each pixel's contribution
    to the cylindrical integral), normalised to its own maximum. Both images
    use a LogDivergingNorm centred on 0.13 between 1e-2 and 1.

    Parameters:
        im: 2D numpy array, the PSF
        ax: mpl axis, axis where the image will be plotted (defaults to the
            current axis)
        ax2: mpl axis, axis where the background intensity profile is
            plotted; None to skip it
        crop: pair, number of pixels removed from each side along axis 0
            and axis 1; None or 0 means no crop on that axis
        show_cbar: bool, draw colorbars
        vmin: unused, kept for backward compatibility
        fig: mpl figure in which the colorbar axes are created
        depletion: unused, kept for backward compatibility
        show1only: bool, only draw the normalised PSF
        show2only: bool, only draw the weighted PSF, over the full width
        name: unused, kept for backward compatibility"""
    if ax is None:
        ax = plt.gca()
    thr = 1/np.e**2  # signal / background threshold on the normalised PSF

    psfc = im/im.max()
    crop_rows, crop_cols = crop[0], crop[1]
    # Explicit stop index: ``a[c:-c]`` would return an empty array for c == 0.
    if crop_rows:
        psfc = psfc[crop_rows:psfc.shape[0]-crop_rows]
    if crop_cols:
        psfc = psfc[:, crop_cols:psfc.shape[1]-crop_cols]

    nrows, ncols = psfc.shape
    height = nrows*PSIZE
    width = ncols*PSIZE  # in the units of PSIZE (presumably nm)
    extent = [-width/2, width-width/2, -height/2, height-height/2]

    imsg = None
    if not show2only:
        imsg = ax.imshow(psfc, cmap='PRGn', extent=extent,
                         norm=LogDivergingNorm(0.13, vmin=10**-2, vmax=1))
    cbaxes = None
    cbaxes1 = None
    if show_cbar:
        if fig is not None:
            cbaxes = fig.add_axes([0.1, 0.1, 0.03, 0.8])
            cbaxes1 = fig.add_axes([0.8, 0.05, 0.03, 0.8])
        if imsg is not None:
            plt.colorbar(imsg, anchor=(0,0.5), cax=cbaxes)

    radial = dmap(psfc)
    psf_i = psfc*radial
    psf_i = psf_i/psf_i.max()
    if not show2only:
        # Blank (NaN) the left half so the plain PSF remains visible there.
        psf_i[:, :psf_i.shape[1]//2] = np.nan

    if not show1only:
        imsg = ax.imshow(psf_i, cmap='PRGn', extent=extent,
                         norm=LogDivergingNorm(0.13, vmin=10**-2, vmax=1))
    if not show1only and not show2only:
        # White vertical divider between the two halves.
        ax.plot(np.zeros(20), np.linspace(extent[2], extent[3], 20),
                color="white")

    if ax2 is not None:
        weighted = psfc*radial
        imax = weighted.sum(axis=1).max()
        signal = weighted.copy()
        signal[psfc<thr] = 0
        signal = signal.sum(axis=1)/imax
        bgr = weighted.copy()
        bgr[psfc>=thr] = 0
        bgr = bgr.sum(axis=1)/imax

        zbgr = (np.arange(bgr.size)-bgr.size//2)*PSIZE
        ax2.plot(bgr/np.sum(signal),zbgr,label="background",color="m")
        ax2.set_xlim(right = 0.012)
        ax2.set_xticks([0,0.012])
        ax2.set_xticklabels([0,0.012])
        ax2.set_xlabel(r"$I/I_{signal}$")
    if show_cbar and imsg is not None:
        plt.colorbar(imsg,cax = cbaxes1)
    ax.set_ylim(bottom = extent[2],top=extent[3])
    ax.axis("off")
    res = [PSIZE,PSIZE]
    # Use the PSF passed in, not a module-level variable.
    xsb,ysb = plot_scalebar(im,res,axis=1,pos = [0.3,-0.3],size=200)
    ax.plot(xsb,ysb,color="white", solid_capstyle = "butt")
    ax.plot(xsb,ysb,color="white", solid_capstyle = "butt")


power = 1
powers_special = np.ones(4) * power
# One STED PSF per depletion pattern, all at the same depletion power.
psfs = [attenuation(excitation,depletion,pw,delta=DELTA)*detection for pw,depletion in zip(powers_special,depletions)]


psfs.append(excitation*detection)
names = ["z-STED","CH-STED","2D-STED","3D-STED","confocal"]


# Boundaries for psf representation
vmin = 5*10**-3
croparea = [40,20]
a1 = None  # no profile axis: plot_psf_bg_double skips the intensity profile

figsize = (FIG_UNIT*0.6,FIG_UNIT*0.85)
for j,psf in enumerate(psfs):
    f, a0 = plt.subplots(1, 1, figsize=figsize)
    plot_psf_bg_double(psf,ax=a0,ax2=a1,crop=croparea,show_cbar=False, 
                       vmin = vmin,fig = f,name=names[j])
    
    a0.set_title(names[j])
    f.tight_layout()
    f.savefig(spath+names[j]+".svg",dpi=600,transparent = True)
    if names[j]=="3D-STED":
        # Extra exports for the 3D-STED PSF: a "_legend" copy and a
        # separate figure with colorbars.
        # a1.legend()
        f.tight_layout()
        f.savefig(spath+names[j]+"_legend.svg",dpi=600,transparent = True)

        f, a0 = plt.subplots(1, 1, figsize=figsize)
        plot_psf_bg_double(psf,ax=a0,ax2=a1,crop=croparea,show_cbar=True,fig=f, vmin = vmin,name=names[j])
        f.savefig(spath+names[j]+"_colorbars.svg",dpi=600,transparent = True)

# For supplementary material
SHOW_POWER_SNB = True
if SHOW_POWER_SNB:
    pws = np.linspace(0,3,20)
    all_snbs = list()
    all_resolutions = list()
    all_signals = list()
    all_backgrounds = list()
    for p in pws:
        snbs,resolutions,signals,bgds = compute_snbs(p)
        all_snbs.append(snbs)
        all_resolutions.append(resolutions)
        all_signals.append(signals)
        all_backgrounds.append(bgds)

    all_resolutions = np.asarray(all_resolutions)

    # Find what divides 2D STED resolution by 2: P_sat is the sampled power
    # whose 2D-STED (index 2) lateral FWHM is closest to half its value at
    # zero STED power.
    res1 = all_resolutions[:,2,0]
    arm = np.argmin( np.abs(res1-res1[0]/2) )
    psat = pws[arm]

    all_snbs = np.array(all_snbs)
    plt.figure(figsize = (FIG_UNIT*0.8,FIG_UNIT*0.8))
    types = ["z","CH","2D","3D"]
    for j, sted_type in enumerate(types):
        c = PLOT_DEFAULT[sted_type]["color"]
        plt.plot(pws/psat, all_snbs[:,j],label=sted_type,color = c)
    plt.xlabel(r"$ \rm STED\ laser\ power\ (P/P_{sat})$")
    plt.ylabel("SBR")
    plt.legend()
    plt.tight_layout()
    # Save next to the other figures, in spath.
    plt.savefig(spath+"SBR.svg",transparent=True)