
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Aug  2 15:34:02 2019

@author: aurelien
"""

import numpy as np
import matplotlib.pyplot as plt
import aotools
from aotools.ext.paths import datapath
import glob

from skimage.external.tifffile import imread
from skimage.feature import peak_local_max
from skimage import filters
from scipy.ndimage import zoom
# Dismiss figures left over from an earlier run before drawing new ones.
plt.close("all")

# Pixel size; lengths expressed in nm are divided by it to obtain pixels.
psize = 30

# Sizes used for the xy and xz outlines: 240 and 650 scaled by sqrt(2*ln 2).
# NOTE(review): presumably confocal spot dimensions in nm -- confirm.
_sqrt_2ln2 = np.sqrt(2*np.log(2))
conf_xy = 240 * _sqrt_2ln2
conf_xz = 650 * _sqrt_2ln2

path1 = "../patterns3/"

# Every TIFF file present in the pattern folder.
files = glob.glob(f"{path1}*.tif")

# Labels for the three patterns, in the same order as the file lists below.
names = ["z-STED","2D-STED","CH-STED"]

names_xz = [
    f"{path1}3D_xz.tif",
    f"{path1}2D_xz.tif",
    "../patterns2/CH075_xz.msr - C=0.tif",
]
names_xy = [
    "../xy_noaberration/3D_xy.tif",
    "../xy_noaberration/2DSTED.msr - C=0.tif",
    "../xy_noaberration/CH075_xy.msr - C=0.tif",
]

# Number of pixels trimmed from each border along x and z.
cropx = 10
cropz = 20

# Load the xz sections, subtract 2**15 from the raw values and trim the
# borders (z along rows, x along columns).
# NOTE(review): the 2**15 is presumably a storage offset of the acquisition
# software -- confirm.
profiles_xz = []
for fname in names_xz:
    frame = imread(fname) - 2**15
    profiles_xz.append(frame[cropz:-cropz, cropx:-cropx])

# Load the xy sections with the same 2**15 subtraction.
profiles_xy = [imread(fname) - 2**15 for fname in names_xy]

# 2D picture has different pixel size, 20 instead of 30 nm: rescale it by
# 2/3 so that all xy images share the same sampling.
profiles_xy[1] = zoom(profiles_xy[1], 2/3)

# Trim the same margin on all four sides of every xy image.
xy_window = (slice(cropx, -cropx), slice(cropx, -cropx))
profiles_xy = [frame[xy_window] for frame in profiles_xy]

def find_zcentre(img, plot=False, gfilt=False,is2d=False):
    """Locate dark spots (local intensity minima) in an image.

    The image is inverted (``img.max() - img``) so that dark spots become
    local maxima, which are then detected with ``peak_local_max``.  When no
    peak is found, or when ``is2d`` is set, a row-based fallback is used:
    the row with the largest summed intensity is selected, and the column is
    the strongest peak of the inverted image along that row.

    Parameters
    ----------
    img : ndarray
        Image to analyse; the code indexes it as a 2D (row, column) array.
    plot : bool, optional
        If True, open a new figure showing ``img`` with the detected
        coordinates overlaid as red dots.
    gfilt : bool, optional
        If True, smooth the inverted image with a Gaussian filter (sigma=2)
        before the peak detection.
    is2d : bool, optional
        If True, always use the row-based fallback, even when the 2D peak
        detection found something.

    Returns
    -------
    ndarray
        (row, column) coordinates, one detected spot per row of the array.
        The row-based fallback always returns exactly one spot, i.e. an
        array of shape (1, 2).
    """
    # img.max() has the dtype of img, so this matches the explicit
    # ones_like(img) * img.max() - img formulation without the temporary.
    inverted = img.max() - img
    if gfilt:
        inverted = filters.gaussian(inverted, sigma=2)
    coordinates = peak_local_max(inverted, exclude_border=25, min_distance=8)

    if len(coordinates) == 0 or is2d:
        # Row carrying the most signal in the original (non-inverted) image.
        row = np.argmax(np.sum(img, axis=1))
        row_profile = inverted[row]

        candidates = peak_local_max(row_profile, exclude_border=30,
                                    min_distance=12)
        if len(candidates) == 0:
            # No local maximum away from the border: use the darkest pixel
            # of the row instead of failing on an empty result.
            col = int(np.argmax(row_profile))
        else:
            # Several peaks may be returned; select the strongest one
            # explicitly rather than relying on the order of the output
            # (int() on a multi-element array would raise TypeError).
            cols = candidates[:, 0]
            col = int(cols[np.argmax(row_profile[cols])])
        coordinates = np.array([row, col]).reshape(1, -1)

    if plot:
        plt.figure()
        plt.imshow(img)
        plt.plot(coordinates[:, 1], coordinates[:, 0], 'r.')
    return coordinates

# Angular parameter covering one full turn; its cosine and sine trace the
# dashed outlines drawn on every panel.
t = np.linspace(0, 2*np.pi, 100)
cos_t = np.cos(t)
sin_t = np.sin(t)

# Frame sizes in pixels before cropping; the ratio of the cropped sizes sets
# the relative height of the two subplot rows.
nx = 100
nz = 250*2/3
ratio = (nx - 2*cropx) / (nz - 2*cropz)
print("ratio", ratio)
fig, axes = plt.subplots(
    nrows=2,
    ncols=3,
    sharex="col",
    gridspec_kw=dict(height_ratios=[ratio, 1]),
)

# Horizontal shift added to the detected centre, divided by psize to get
# pixels (original note: "1 rad of tip").
offset = 172

# Bottom row: xz sections. The middle pattern (j == 1) is forced through the
# row-based centre search of find_zcentre.
for j, p in enumerate(profiles_xz):
    coords = find_zcentre(p, plot=True, is2d=(j == 1))
    ax = axes[1, j]
    ax.imshow(p, cmap="magma")

    u = coords[0, 1] + offset/psize   # x-position of the centre
    v = coords[0, 0]                  # y-position of the centre
    a = conf_xy/2/psize               # radius on the x-axis
    b = conf_xz/2/psize               # radius on the y-axis

    ax.plot(u + a*cos_t, v + b*sin_t, color="white", linestyle="--")
    ax.axis("off")
    # Per-panel export, currently disabled:
    # plt.savefig(names[j]+"_tip_xz.svg",transparent=True)

# Top row: xy sections, smoothed before the centre search; the outline is a
# circle because both radii use conf_xy.
for j, p in enumerate(profiles_xy):
    coords = find_zcentre(p, plot=True, gfilt=True)
    ax = axes[0, j]
    ax.imshow(p, cmap="magma")

    u = coords[0, 1] + offset/psize   # x-position of the centre
    v = coords[0, 0]                  # y-position of the centre
    a = conf_xy/2/psize               # radius on the x-axis
    b = conf_xy/2/psize               # radius on the y-axis

    ax.plot(u + a*cos_t, v + b*sin_t, color="white", linestyle="--")
    ax.axis("off")
    # Per-panel export, currently disabled:
    # plt.savefig(names[j]+"_tip_xy.svg",transparent=True)

# Scale bar: a horizontal white line in the last panel, starting at x = 40
# pixels and placed at 90 % of the height of the last xz image.
sbarlen = 1000  # nm
nsc = 25
xsc = np.linspace(0, sbarlen/psize, nsc) + 40
ysc = np.full_like(xsc, 0.9) * profiles_xz[-1].shape[0]
axes[-1, -1].plot(xsc, ysc, color="w")

fig.savefig("tip.svg", transparent=True)