'''
###############################################################################
Code written by Jacob Blackmore and William Hughes for arXiv:2112.05795

This code generates figure 9 showing the effect of different cooperativity
on the rate at which photons can be extracted from the cavity for
a sin^2 driving function

outputs are shown to the user using matplotlib and saved in png and pdf formats

requires the cavity module from /dist or DOI:10.5281/zenodo.7020047

requires the Ox_plot plotting module (imported from the jqc package), which is included.
###############################################################################
'''

import numpy as np
import matplotlib.pyplot as plt
import json
from jqc import Ox_plot


def rescale_an_axis_within_a_figure(ax, x_rescale_factor=1.0, y_rescale_factor=1.0):
    """
    Grow or shrink an axes about its own centre, in figure coordinates.

    The width and height of the axes are multiplied by the given factors and the
    lower-left corner is moved by half of the size change, so the centre of the
    axes stays where it was.
    :param ax: The axes to rescale (anything exposing get_position().bounds and set_position)
    :param x_rescale_factor: Multiplicative factor applied to the width
    :param y_rescale_factor: Multiplicative factor applied to the height
    """
    left, bottom, width, height = ax.get_position().bounds
    # half of the growth in each direction is taken from the left / bottom edge
    shift_x = 0.5 * (x_rescale_factor - 1.0) * width
    shift_y = 0.5 * (y_rescale_factor - 1.0) * height
    ax.set_position([left - shift_x,
                     bottom - shift_y,
                     x_rescale_factor * width,
                     y_rescale_factor * height])

def convert_number_to_string(number, maximum_significant_figures=3):
    """
    Format a number as a compact string with at most the given number of significant figures.

    Uses Python's general number presentation with a precision ("{:.3}"), e.g.
    0.1 -> '0.1', 10.0 -> '10.0', 0.123456 -> '0.123', 1000.0 -> '1e+03'.
    Integers (Python or numpy) are converted to float first, because str.format
    raises ValueError when a precision is given for an integer value.
    :param number: The number to format
    :param maximum_significant_figures: The maximum number of significant figures of the string
    :return: The number in string format
    """
    if isinstance(number, (int, np.integer)):
        number = float(number)
    return "{:.{sig_figs}}".format(number, sig_figs=maximum_significant_figures)

def json_numpy_obj_hook(dct):
    """
    object_hook for json.load that rebuilds numpy data from tagged JSON objects.

    A dict containing the key '__ndarray__' is treated as an encoded array:
      * if it also contains '__complex__', the value is rebuilt as real + 1j*imag
        from its 'real' and 'imag' entries; scalar parts give a Python complex,
        list/array parts give a complex np.ndarray;
      * otherwise np.array(dct['values']) is returned.
    Any other object is returned unchanged.
    :param dct: A decoded JSON object
    :return: The reconstructed value, or dct itself
    """
    if not (isinstance(dct, dict) and '__ndarray__' in dct):
        return dct
    if '__complex__' not in dct:
        return np.array(dct['values'])
    # np.asarray lets 'real'/'imag' be scalars, plain JSON lists or already-decoded
    # arrays; combining plain lists directly (list + 1j*list) would raise TypeError.
    complex_values = np.asarray(dct['real']) + 1j * np.asarray(dct['imag'])
    if complex_values.ndim == 0:
        # scalar parts: hand back a plain Python complex
        return complex_values.item()
    return complex_values

def load(*args, **kwargs):
    """
    Drop-in replacement for json.load that decodes numpy-tagged objects by default.

    All arguments are forwarded to json.load. Unless the caller supplies its own
    'object_hook', json_numpy_obj_hook is installed so tagged '__ndarray__'
    objects are decoded by it.
    :return: The decoded JSON document
    """
    if 'object_hook' not in kwargs:
        kwargs['object_hook'] = json_numpy_obj_hook
    return json.load(*args, **kwargs)

def json_numpy_read_from_file(file_name):
    """
    Read a JSON file whose arrays were written with the numpy tagging scheme.

    Tagged objects are decoded by json_numpy_obj_hook.
    :param file_name: Path of the JSON file to read (a .json file)
    :return: The decoded document (expected to be a dictionary)
    """
    with open(file_name, 'r') as json_file:
        contents = load(json_file, object_hook=json_numpy_obj_hook)
    return contents

def _group_datasets_by_cooperativity(datasets, colour_list):
    """
    Group datasets by cooperativity, then by gamma, and assign one colour per distinct gamma.

    :param datasets: Iterable of entries, each holding a 'data' dict with the keys
        't_range', 'max_outputs', 'gamma' and 'cooperativity'
    :param colour_list: Colours handed out, in order of first appearance, to each distinct
        gamma (reused cyclically if there are more gamma values than colours)
    :return: (sorted_data, colour_dictionary), where
        sorted_data[cooperativity_string][gamma_string] == {'max_outputs': ..., 't_range': ...}
        and colour_dictionary[gamma_string] is the colour used for that gamma
    """
    sorted_data = {}
    colour_dictionary = {}
    for dataset in datasets:
        data = dataset['data']
        cooperativity_string = convert_number_to_string(data['cooperativity'])
        gamma_string = convert_number_to_string(data['gamma'])
        if gamma_string not in colour_dictionary:
            # the n-th distinct gamma gets the n-th colour
            colour_dictionary[gamma_string] = colour_list[len(colour_dictionary) % len(colour_list)]
        # a later dataset with the same (cooperativity, gamma) pair replaces an earlier one
        sorted_data.setdefault(cooperativity_string, {})[gamma_string] = {
            'max_outputs': data['max_outputs'],
            't_range': data['t_range'],
        }
    return sorted_data, colour_dictionary


def read_study_finite_time_probability_drop_fixed_kappa(file_path, fontsize_plot=12, labelsize_plot=12,
                                                        tick_size_plot=5, x_ticks=None, y_ticks=None,
                                                        axis_scaling_x=1.0, axis_scaling_y=1.0,
                                                        colour_list=('r', 'b', 'y'), show=True):
    """
    Plot output probability against output time for a set of (cooperativity, gamma) datasets.

    The JSON file is read with json_numpy_read_from_file and its 'data' entry is iterated;
    each element is expected to carry a 'data' dict with 't_range', 'max_outputs', 'gamma'
    and 'cooperativity'. Curves are coloured by gamma (one colour per distinct gamma, with a
    single legend entry each), and every cooperativity C gets a dashed black line at the
    infinite-time output probability 2C/(2C+1).

    :param file_path: Path of the JSON data file
    :param fontsize_plot: Currently unused; kept for interface compatibility
    :param labelsize_plot: Currently unused; kept for interface compatibility
    :param tick_size_plot: Currently unused; kept for interface compatibility
    :param x_ticks: Optional explicit x tick positions
    :param y_ticks: Optional explicit y tick positions
    :param axis_scaling_x: Factor by which the axes width is rescaled about its centre
    :param axis_scaling_y: Factor by which the axes height is rescaled about its centre
    :param colour_list: Colours for successive distinct gamma values (reused cyclically)
    :param show: If True, call plt.show() before returning
    :return: The (fig, ax) pair holding the plot
    """
    datasets = json_numpy_read_from_file(file_path)['data']
    sorted_data, colour_dictionary = _group_datasets_by_cooperativity(datasets, colour_list)

    fig, ax = plt.subplots()
    # NOTE(review): the style is applied after the figure is created; confirm this order is intended.
    Ox_plot.plot_style()

    # One curve per (cooperativity, gamma); only the first curve of each gamma gets a legend label.
    labelled_gammas = set()
    t_min, t_max = np.inf, -np.inf
    for gamma_curves in sorted_data.values():
        for gamma_string, curve in gamma_curves.items():
            t_range = curve['t_range']
            if gamma_string in labelled_gammas:
                line_label = ''
            else:
                labelled_gammas.add(gamma_string)
                line_label = r'$\gamma$ = ' + gamma_string + r' $\kappa$'
            ax.plot(t_range, curve['max_outputs'], label=line_label, color=colour_dictionary[gamma_string])
            t_min = min(t_min, np.amin(t_range))
            t_max = max(t_max, np.amax(t_range))

    # Dashed guide at the infinite-time output probability of each cooperativity class,
    # drawn once the full time span of all curves is known.
    for cooperativity_string in sorted_data:
        cooperativity = float(cooperativity_string)
        infinite_time_output_probability = np.divide(2 * cooperativity, (2 * cooperativity + 1))
        ax.plot([0.0, t_max], [infinite_time_output_probability, infinite_time_output_probability],
                '--', color='k')

    if sorted_data:
        # Decorate the axes exactly once: repeated calls to rescale_an_axis_within_a_figure
        # would compound the requested scaling.
        ax.legend(frameon=False)
        ax.set_xlabel(r'Time for Output ($\kappa^{-1}$)')
        ax.set_ylabel('Output Probability')
        ax.set_xlim([t_min, t_max])
        ax.set_ylim([0.0, 1.0])
        # Hand-placed labels in axes coordinates; they assume the data hold C = 0.1, 1 and 10.
        for y_position, text in ((0.2, '$C$ = 0.1'), (0.57, '$C$ = 1'), (0.89, '$C$ = 10')):
            ax.text(0.5, y_position, text, horizontalalignment='center',
                    verticalalignment='center', transform=ax.transAxes)
        if x_ticks is not None:
            ax.set_xticks(x_ticks)
        if y_ticks is not None:
            ax.set_yticks(y_ticks)
        rescale_an_axis_within_a_figure(ax, x_rescale_factor=axis_scaling_x, y_rescale_factor=axis_scaling_y)
    if show:
        plt.show()
    return fig, ax

if __name__ == '__main__':
    from pathlib import Path
    import os

    # Parent of the directory containing this script (not the process's working directory).
    project_root = Path(os.path.abspath(__file__)).parent.parent
    # edit this path to the desired path
    file_path = project_root / 'data' / 'driving' / 'sine_square_pulse_evaluation_multi_dataset_2.json'
    fig, ax = read_study_finite_time_probability_drop_fixed_kappa(str(file_path),
                                        colour_list=[Ox_plot.colours['red'],
                                                     Ox_plot.colours['ox blue'],
                                                     Ox_plot.colours['orange']],
                                        x_ticks=[0, 5.0, 10.0, 15.0, 20.0],
                                        y_ticks=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
                                        show=False)

    plt.tight_layout()

    # Output folders ("pdf", "png") are relative to the current working directory;
    # pathlib builds portable paths and the folders are created if missing.
    for extension in ('pdf', 'png'):
        output_dir = Path(extension)
        output_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_dir / ('sine_squared_comparison.' + extension))

    plt.show()
