#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jun  3 13:41:14 2019

@author: aurelien
"""

import glob
import aotools
from aotools.ext.paths import datapath

from skimage.external import tifffile
import matplotlib.pyplot as plt
from skimage.transform import resize

# Start from a clean slate: close any figures left over from a previous run.
plt.close("all")

# Folder holding the gold-bead acquisitions of 2019-04-30 (relative to datapath).
path = datapath + "/2019_04_30/Gold beads/"

# All TIFF files in that folder; the loops further down sort them into the
# two lookup tables below.
files = glob.glob(path + "*.tif")

# Per-plane lookup tables: mode -> {sub-key: filename}. The sub-keys
# ("CH", "z", "2D") are filled in by fill_dict().
fdict_xy = {mode: {} for mode in ("Astigmatism", "Coma", "Sph")}
fdict_xz = {mode: {} for mode in ("Astigmatism", "Coma", "Sph")}
def fill_dict(keyw, fdict, fname=None):
    """Record a filename under ``fdict[keyw]`` if the name contains *keyw*.

    The sub-key is chosen from substrings of the filename, tested in this
    order (first match wins): "CH" -> "CH", "3D" -> "z", "2D" -> "2D".

    Parameters
    ----------
    keyw : str
        Mode keyword to look for in the filename; must be a key of *fdict*.
    fdict : dict
        Mapping ``keyw -> {sub-key: filename}``; updated in place.
    fname : str, optional
        Filename to classify. When omitted, the module-level variable
        ``file`` is used instead. This keeps the original two-argument
        calls working; those rely on the loop variable ``file``.

    Raises
    ------
    KeyError
        If the filename contains *keyw* but none of "CH", "3D" or "2D".
    """
    if fname is None:
        # Legacy behaviour: fall back to the implicit module-level ``file``.
        fname = file
    if keyw not in fname:
        return
    # Ordered (marker, sub-key) pairs; order matters when a name contains
    # several markers.
    for marker, subkey in (("CH", "CH"), ("3D", "z"), ("2D", "2D")):
        if marker in fname:
            fdict[keyw][subkey] = fname
            return
    raise KeyError(fname, "does not fit any category")
            
# Sort the xy-plane files into fdict_xy, one pass per mode keyword.
# NOTE: fill_dict() reads the loop variable ``file`` implicitly when called
# with two arguments, so the loop variable must keep this exact name.
for file in files:
    if "xy" not in file:
        continue
    for key in fdict_xy:
        fill_dict(key, fdict_xy)

# Same for the xz-plane files into fdict_xz. Iterate the dict actually being
# filled, so the two tables can never drift apart.
for file in files:
    if "xz" not in file:
        continue
    for key in fdict_xz:
        fill_dict(key, fdict_xz)
import numpy as np
PSIZE = 30  # nm per pixel; converts the scale-bar length from nm to pixels


def _crop(img, crop_rows, crop_cols):
    """Return *img* with *crop_rows*/*crop_cols* pixels removed from each edge.

    Unlike ``img[c:-c]``, a crop of 0 keeps the whole axis instead of
    producing an empty array.
    """
    nrows, ncols = img.shape[0], img.shape[1]
    return img[crop_rows:nrows - crop_rows, crop_cols:ncols - crop_cols]


def _load_offset_tif(fname):
    """Read a TIFF file and subtract the constant 2**15 offset.

    The data is cast to float first. Subtracting 2**15 directly in an
    unsigned/16-bit integer dtype wraps around, or raises under NumPy 2
    promotion rules. NOTE(review): presumably the files store signal shifted
    by 2**15 — confirm.
    """
    return tifffile.imread(fname).astype(np.float64) - 2**15


def plot_aberration_mode(mode, cropx=10, cropz=20, cmap="magma", save=True,
                         scalebar=True, outdir="depletion_patterns"):
    """Plot the xy and xz images of one aberration mode side by side.

    Builds a 2x3 figure. The top row holds xy images and the bottom row xz
    images. Columns follow the sub-keys "z", "2D", "CH" of
    ``fdict_xy[mode]`` / ``fdict_xz[mode]``.

    Parameters
    ----------
    mode : str
        Key of ``fdict_xy`` / ``fdict_xz`` ("Astigmatism", "Coma" or "Sph").
    cropx : int
        Pixels removed from each edge along x (and along y for xy images).
    cropz : int
        Pixels removed from each edge of the rescaled first axis of xz images.
    cmap : str
        Matplotlib colormap name.
    save : bool
        Write the figure to ``<outdir>/<mode>_wrap.svg``.
    scalebar : bool
        Draw a white 1000 nm bar on the bottom-right panel.
    outdir : str
        Output directory for the SVG; created if it does not exist.
    """
    keys = ["z", "2D", "CH"]

    # Nominal image sizes in pixels (hard-coded, not read from the files).
    # They are only used to set the relative row heights of the figure.
    nx = 100
    nz = 250 * 2 / 3
    ratio = (nx - cropx * 2) / (nz - 2 * cropz)
    print("ratio", ratio)
    fig, axes = plt.subplots(2, 3, sharex="col",
                             gridspec_kw={"height_ratios": [ratio, 1]})

    for j, key in enumerate(keys):
        imgxy = _crop(_load_offset_tif(fdict_xy[mode][key]), cropx, cropx)

        imgxz = _load_offset_tif(fdict_xz[mode][key])
        u, v = imgxz.shape
        # Shrink the first axis to 2/3 of its length; the second (x) axis
        # is unchanged, so PSIZE still applies along x.
        imgxz = resize(imgxz, (u * 2 // 3, v))
        imgxz = _crop(imgxz, cropz, cropx)
        print("xz shape", imgxz.shape)

        for ax, img in ((axes[0, j], imgxy), (axes[1, j], imgxz)):
            ax.imshow(img, cmap=cmap)
            ax.axis("off")

    if scalebar:
        sbarlen = 1000  # nm
        nsc = 25
        # Starts at a hard-coded x offset of 40 px. The row index is
        # 0.9 * height of the last xz image plotted.
        xsc = 40 + np.linspace(0, sbarlen / PSIZE, nsc)
        ysc = np.ones_like(xsc) * 0.9 * imgxz.shape[0]
        # Draw explicitly on the bottom-right panel. That is the axes the
        # original implicit plt.plot() targeted.
        axes[1, -1].plot(xsc, ysc, color="w")
    if save:
        import os
        os.makedirs(outdir, exist_ok=True)
        fig.savefig(os.path.join(outdir, mode + "_wrap.svg"),
                    dpi=600, transparent=True)
        
# Produce (and by default save) one figure per aberration mode.
for mode in ("Astigmatism", "Coma", "Sph"):
    plot_aberration_mode(mode)
