# -*- coding: utf-8 -*-
"""
Created on Thu Mar 21 11:53:34 2019

@author: Matthijs de Buck (matthijs.debuck@ndcn.ox.ac.uk)
"""

import os
import tkinter as tk
from tkinter import filedialog

import numpy as np

#%% Get filepath from a pop-up window (note: can appear behind current window)
def get_path():
    """Ask the user to pick a file in a pop-up dialog and return its path.

    A hidden Tk root window is created only to host the dialog and is
    destroyed afterwards, so repeated calls do not accumulate Tk instances.
    Note: the dialog can appear behind the current window.

    Returns:
        The selected path passed through ``os.path.normpath``. On Windows
        this converts ``/`` separators to ``\\``; on other platforms the
        path is left usable instead of being mangled. An empty string is
        returned when the dialog is cancelled.
    """
    root = tk.Tk()
    root.withdraw()
    root.focus_force()
    try:
        filename = filedialog.askopenfilename()
    finally:
        # Always release the hidden root window, even if the dialog fails.
        root.destroy()

    # Cancelling yields '' (or an empty tuple on some Tk builds). Guard it
    # explicitly because os.path.normpath('') would return '.'.
    if not filename:
        return ''
    return os.path.normpath(filename)

#%% Compute the conjugate (Hermitian) transpose
def compl_conj(array):
    """Return the conjugate (Hermitian) transpose of *array*.

    Every element is complex-conjugated and the axes are reversed (as a
    plain transpose does), so a column vector of shape (n, 1) becomes a
    conjugated row vector of shape (1, n).
    """
    return np.conjugate(array).T

#%% Importing all Q-matrices and corresponding voxel-locations
file_path_Qs = get_path()
file_path_voxels = get_path()

# ndmin=2 keeps both arrays 2D even when a file holds a single voxel (one
# column / one row); otherwise loadtxt squeezes it to 1D and the
# shape-based indexing below fails.
Qs_2D = np.loadtxt(file_path_Qs, dtype=complex, ndmin=2)  # 2D representation of Q-matrices
voxel_locs = np.loadtxt(file_path_voxels, ndmin=2)        # corresponding voxel locations (in meters from the center, (x,y,z))
n_voxels = Qs_2D.shape[1]                                  # total number of voxels

# Each column stacks one n_channels x n_channels Q-matrix row after row,
# so the number of rows must be a perfect square.
n_rows = Qs_2D.shape[0]
n_channels = int(round(np.sqrt(n_rows)))                   # number of transmit channels
if n_channels * n_channels != n_rows:
    raise ValueError(
        'Q-matrix file has {} rows per voxel, which is not a perfect square; '
        'cannot infer the number of transmit channels'.format(n_rows)
    )

# 3D representation of Q-matrices, stored as complex64:
#   Qs[row, col, voxel] == Qs_2D[row*n_channels + col, voxel]
Qs = Qs_2D.reshape(n_channels, n_channels, n_voxels).astype(np.complex64)

#%% Determine eigenvalues/eigenvectors of all Q-matrices
# Eigen-decompose the Q-matrix of every voxel in one batched call.
# np.linalg.eig operates on the last two axes of a stacked array, so the
# voxel axis is moved to the front for the call and back to the end after:
#   eigenvalues[i, v]     -> i-th eigenvalue of voxel v
#   eigenvectors[:, i, v] -> the matching eigenvector (a column, as eig returns)
# Both are stored as complex64.
eig_vals, eig_vecs = np.linalg.eig(np.moveaxis(Qs, -1, 0))
eigenvalues = np.transpose(eig_vals).astype(np.complex64)
eigenvectors = np.moveaxis(eig_vecs, 0, -1).astype(np.complex64)
del eig_vals, eig_vecs

# Imaginary parts of the eigenvalues are only residuals of rounding off,
# so keep just the real part from here on.
eigenvalues = np.real(eigenvalues)

# (eigen-index, voxel) position of the overall largest eigenvalue, i.e. the
# worst-case 10gSAR per 1 W and the excitation vector that produces it.
peak_index_vec = np.unravel_index(np.argmax(eigenvalues, axis=None), eigenvalues.shape)
peakloc = voxel_locs[peak_index_vec[1], :]   # voxel-location of the highest eigenvalue
peakconfig = eigenvectors[:, peak_index_vec[0], peak_index_vec[1]]

# Express each eigenvector element as a "power" value: magnitude |v|**2
# combined with the phase of v.
eigenvectors_power = eigenvectors * np.abs(eigenvectors)
peak_shim = eigenvectors_power[:, peak_index_vec[0], peak_index_vec[1]]

print('Overall max 10gSAR =',np.max(eigenvalues),'W/kg/1W at voxel-location',peakloc)

print('Peak configuration powers (W), angles =')
# Row 0: per-channel power; row 1: per-channel phase in degrees.
print(np.vstack((np.abs(peak_shim), np.angle(peak_shim, deg=True))))

#%% Computation of CP-mode SAR values
# CP-mode drive as a column of per-channel powers: equal power on every
# channel (1 W in total), with the phase advancing by 2*pi/n_channels from
# one channel to the next. For 8 channels this is 0.125 W per channel in
# 45-degree steps.
# NOTE(review): assumes the transmit elements are evenly spaced around the
# coil so that this phase progression is the CP mode - confirm for other arrays.
n_tx = int(n_channels)
p_configs_CP = (np.exp(1j * np.arange(n_tx) * 2 * np.pi / n_tx) / n_tx)[:, np.newaxis]
# Convert power to drive amplitude: |a|**2 == |p|, with the same phase.
a_configs_CP = p_configs_CP / np.sqrt(np.abs(p_configs_CP))
a_configs_CP_H = compl_conj(a_configs_CP)

# SAR in every voxel is the quadratic form a^H Q a, evaluated for all
# voxels at once. "real" removes minor imaginary components which arise
# due to discretization.
CP_SARs = np.real(np.einsum('i,ijv,j->v', a_configs_CP_H[0], Qs, a_configs_CP[:, 0]))

print('Max 10gSAR when operating CP-mode = ',np.max(CP_SARs),'W/kg')