"""
Visualise results from various methods
"""

from __future__ import division, print_function, absolute_import, unicode_literals
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import metrics
from scipy import stats


# Directories the result pickles are read from. File names are appended by
# plain string concatenation, so both values keep their trailing slash.
folder1 = folder2 = "data/"


def confidence_interval95(list1, z=1.96, ddof=0):
    """
    Half-width of the normal-approximation confidence interval of the mean.

    Returns ``z * sqrt(var(list1) / N)``, i.e. ``z`` standard errors of the
    mean; with the default ``z=1.96`` this is the 95% interval half-width.

    @param list1: sequence of numeric samples.
    @param z: standard-normal critical value (1.96 gives a 95% interval).
    @param ddof: delta degrees of freedom of the variance, as in ``np.var``
        (0 = population variance, the previous behaviour; 1 = sample variance).
    @return: interval half-width; ``nan`` for empty input.
    """
    values = np.asarray(list1, dtype=float)
    n_samples = values.size
    if n_samples == 0:
        # Nothing to estimate from; return nan instead of dividing by zero.
        return float("nan")
    # The variance (not the standard deviation) must be divided by N to get
    # the squared standard error; using np.std here gave z*sqrt(std/N).
    variance = np.var(values, ddof=ddof)
    standard_error = np.sqrt(variance / n_samples)
    return z * standard_error


def results2():
    """
    Compare one model run on both sexes with separate per-sex models (Figure 21b).

    Loads the combined-model results (``ri2_parts_t015_g0_rsr1.pkl``) and the
    separate-model results (``ri2_parts_t015_g1_rsr1.pkl``) from ``folder1``,
    prints the number of male and female cases in the former, and draws a box
    plot of the ``dice_original`` DSC per sex for both set-ups.

    @return: None
    """
    df_joint = pd.read_pickle(folder1 + "ri2_parts_t015_g0_rsr1.pkl")
    df_separate = pd.read_pickle(folder1 + "ri2_parts_t015_g1_rsr1.pkl")

    # Print cases:
    print("Male: %i and Female: %i cases " % (np.sum(df_joint["sex"] == 'M'), np.sum(df_joint["sex"] == 'F')))

    def _dice(frame, sex):
        # DSC values of one sex as a float array. np.float was removed in
        # NumPy 1.24; the builtin float is the same dtype.
        return frame.loc[frame['sex'] == sex, 'dice_original'].values.astype(float)

    # Running both male and female together vs separately.
    plt.figure()
    plt.title('Figure 21b')
    box1 = plt.boxplot([_dice(df_joint, 'F'),
                        _dice(df_separate, 'F'),
                        _dice(df_joint, 'M'),
                        _dice(df_separate, 'M')],
                       labels=['Female', 'Female \n (separate model)', 'Male', 'Male \n (separate model)'])

    for box in box1['boxes']:
        box.set(linewidth=2)
    for whisker in box1['whiskers']:
        whisker.set(linewidth=2)
    for cap in box1['caps']:
        cap.set(linewidth=2, color='blue')
    for median in box1['medians']:
        median.set(linewidth=2)
    for flier in box1['fliers']:
        flier.set(marker='o', alpha=0.5)
    plt.ylabel('DSC', fontsize=25)
    plt.ylim([0, 1])
    plt.xticks(fontsize=20)


def results1():
    """
    Compare perfusion-supervoxels with pieces-of-parts (Figures 19b, 18, 19a).

    Draws:
      - Figure 19b: box plot of each method's DSC, restricted to the cases
        where that method reaches DSC >= 0.1;
      - 'Figure 18 without intermediary step': per-case DSC bar chart;
      - Figure 19a: number of cases with DSC > 0.2 per method.
    It also prints median/std summaries and the mean per-case DSC improvement
    of parts over supervoxels.

    @return: None
    """
    # Input data
    df1_parts = pd.read_pickle(folder1 + "ri2_parts_t015_g0_rsr1.pkl")
    df1_supervoxel = pd.read_pickle(folder1 + "ri1_lda.pkl")

    # Figure 19b: DSC of supervoxels vs parts. np.float was removed in
    # NumPy 1.24; the builtin float is the same dtype.
    plt.figure()
    plt.title('Figure 19b')
    box1 = plt.boxplot([df1_supervoxel[df1_supervoxel['dice_original'] >= 0.1]['dice_original'].values.astype(float),
                        df1_parts[df1_parts['dice_original'] >= 0.1]['dice_original'].values.astype(float)],
                       labels=['Supervoxels', 'Parts'],
                       patch_artist=True)
    plt.ylabel('DSC', fontsize=25)
    plt.xticks(fontsize=25)
    for box in box1['boxes']:
        box.set(linewidth=2)
    for whisker in box1['whiskers']:
        whisker.set(linewidth=2)
    for cap in box1['caps']:
        cap.set(linewidth=2, color='blue')
    for median in box1['medians']:
        median.set(linewidth=2)
    for flier in box1['fliers']:
        flier.set(marker='o', alpha=0.5)
    plt.ylim([0, 1])

    parts_detected = df1_parts['dice_original'] > 0.1

    print("Pieces of parts (detected)")
    print("Median: ", df1_parts[parts_detected]['dice_original'].median())

    print("Pieces of parts (all)")
    print("Median: ", df1_parts['dice_original'].median())

    print("var: ", df1_parts[parts_detected]['dice_original'].std())

    # NOTE(review): the supervoxel "detected" subset is selected with the
    # *parts* mask (cases detected by parts), not with the supervoxels' own
    # DSC as in the box plot above; confirm this is intended.
    print("Perfusion-supervoxels (detected)")
    print("Median ", df1_supervoxel[parts_detected]['dice_original'].median())

    print("var ", df1_supervoxel[parts_detected]['dice_original'].std())

    print("Perfusion-supervoxels (all)")
    print("Median: ", df1_supervoxel['dice_original'].median())

    # Figure 18 (without intermediary step): per-case DSC of both methods.
    df2 = pd.concat([df1_supervoxel['dice_original'], df1_parts['dice_original']], axis=1)
    df2.columns = ['supervoxel', 'parts']
    ax2 = df2.plot(kind='bar',
                   color=['lightgreen', 'blue'],
                   fontsize=20,
                   title='Figure 18 without intermediary step')
    ax2.set_ylabel('DSC', fontsize=25)
    ax2.set_xlabel('Cases', fontsize=25)

    # Figure 19a: number of cases with DSC > 0.2 per method.
    bar_width = 0.3
    plt.figure()
    plt.title('Figure 19a')
    # align='center' is explicit so the ticks below sit under the bar centres
    # on every matplotlib version. The former ticks at x + width/2 assumed
    # edge-aligned bars and were off-centre since matplotlib 2.0.
    plt.bar([1], np.sum(df1_supervoxel["dice_original"] > 0.2), color='lightgreen',
            width=bar_width, align='center')
    plt.bar([2], np.sum(df1_parts["dice_original"] > 0.2), color='blue',
            width=bar_width, align='center')
    plt.ylabel('Cases', fontsize=20)
    plt.xticks([1, 2], ('Perfusion \n supervoxels', 'Pieces-of-parts'), fontsize=20)
    plt.ylim([0, 23])
    plt.xlim([0, 3])

    # Mean per-case DSC difference (parts - supervoxels), aligned on the index.
    mean1 = np.mean(df1_parts["dice_original"] - df1_supervoxel["dice_original"])
    print("Improvement: ", mean1)


def results1b():
    """
    Statistically compare pieces-of-parts with perfusion-supervoxels.

    Loads the parts results and the supervoxel results with
    (``ri1_lda.pkl``) and without (``ri1_lda_norem.pkl``) post-processing.
    It then prints:
      - the means;
      - paired t-test and Wilcoxon signed-rank results of parts against
        both supervoxel variants;
      - median/std summaries.
    Finally it draws a per-case DSC bar chart ('Figure 18').

    @return: None
    """
    parts_df = pd.read_pickle(folder1 + "ri2_parts_t015_g0_rsr1.pkl")
    proc_df = pd.read_pickle(folder1 + "ri1_lda.pkl")
    noproc_df = pd.read_pickle(folder1 + "ri1_lda_norem.pkl")

    # Sort by case index so the paired tests match cases position by position.
    parts_dsc, proc_dsc, noproc_dsc = (
        frame['dice_original'].sort_index()
        for frame in (parts_df, proc_df, noproc_df)
    )

    print("Means")
    for series in (proc_dsc, parts_dsc):
        print(np.mean(series))

    # Every paired test compares parts first against the unprocessed, then
    # against the processed supervoxels.
    paired_tests = (("Paired students t-test", stats.ttest_rel),
                    ("Wilcoxon signed-rank test", stats.wilcoxon))
    for test_name, test_fn in paired_tests:
        for tag, other in (("noproc", noproc_dsc), ("proc", proc_dsc)):
            statistic, p_value = test_fn(parts_dsc, other)
            print(test_name + " (parts vs " + tag + "): ", statistic, ", p value: ", p_value)

    parts_all = parts_df['dice_original']
    proc_all = proc_df['dice_original']
    detected = parts_all > 0.1

    print("Pieces of parts (detected)")
    print("Median: ", parts_all[detected].median())
    print("Pieces of parts (all)")
    print("Median: ", parts_all.median())
    print("var: ", parts_all[detected].std())

    # NOTE(review): supervoxel "detected" uses the parts detection mask;
    # confirm this is intended rather than the supervoxels' own DSC.
    print("Perfusion-supervoxels (detected)")
    print("Median ", proc_all[detected].median())
    print("var ", proc_all[detected].std())
    print("Perfusion-supervoxels (all)")
    print("Median: ", proc_all.median())

    # Per-case bar chart through the pandas plotting interface.
    dsc_table = pd.concat([proc_all, noproc_df['dice_original'], parts_all], axis=1)
    dsc_table.columns = ['supervoxel', 'supervoxel(no post)', 'parts']
    axes = dsc_table.plot(kind='bar',
                          color=['lightgreen', 'lightgrey', 'blue'],
                          fontsize=20,
                          width=0.6,
                          title='Figure 18')
    axes.set_ylabel('DSC', fontsize=25)
    axes.set_xlabel('Cases', fontsize=25)


def results3():
    """
    Compare the parts DSC with inter-rater DSC overlap scores (Figure 20).

    Draws a box plot of the pieces-of-parts DSC (cases with DSC > 0.1) next
    to the ``dice_JF_MA`` and ``dice_MB_MA`` inter-rater scores read from
    ``annotations.pkl``. The original description says these cover 10 cases.
    Also prints mean +/- std of both inter-rater scores.

    @return: None
    """
    df1 = pd.read_pickle(folder1 + "ri2_parts_t015_g0_rsr1.pkl")

    # Inter-rater annotation overlap scores to compare against.
    df2 = pd.read_pickle(folder1 + "annotations.pkl")

    plt.figure()
    plt.title('Figure 20')
    # np.float was removed in NumPy 1.24; the builtin float is the same dtype.
    box1 = plt.boxplot([df1['dice_original'][df1['dice_original'] > 0.1].values.astype(float),
                        df2['dice_JF_MA'].values.astype(float),
                        df2['dice_MB_MA'].values.astype(float)],
                       labels=['Parts', 'Inter-rater 1', 'Inter-rater 2'],
                       patch_artist=True)
    plt.ylabel('DSC', fontsize=22)
    plt.xticks(fontsize=22)

    print(df2['dice_JF_MA'].mean(), '+/-', df2['dice_JF_MA'].std())
    print(df2['dice_MB_MA'].mean(), '+/-', df2['dice_MB_MA'].std())

    for box in box1['boxes']:
        box.set(linewidth=2)
    for whisker in box1['whiskers']:
        whisker.set(linewidth=2)
    for cap in box1['caps']:
        cap.set(linewidth=2, color='blue')
    for median in box1['medians']:
        median.set(linewidth=2)
    for flier in box1['fliers']:
        flier.set(marker='o', alpha=0.5)

    # Fill the method's box grey to set it apart from the inter-rater boxes.
    box1['boxes'][0].set(facecolor='grey')

    plt.ylim([0, 1])
    # NOTE(review): this shows (and blocks on) every figure created so far,
    # before the caller's own plt.show(); kept so results3() works standalone.
    plt.show()


def mean_roc(df_roc, color='b', plot_name='Plot 1'):
    """
    Plot the mean ROC curve over several cases with a 95% confidence band.

    Called by roc_mult.

    @param df_roc: DataFrame with "specificity" and "sensitivity" columns and
        an (at least) two-level row index. Rows are averaged per distinct
        level-1 value. Presumably level 0 is the case and level 1 the
        threshold -- TODO confirm.
    @param color: matplotlib colour for the curve and its band.
    @param plot_name: legend label prefix; "(AUC = x.xx)" is appended.
    @return: None (the AUC is printed and shown in the legend).
    """
    # One groupby drives both the means and the confidence intervals, so they
    # are computed over the same groups and come out in the same sorted order.
    # (Replaces Series.mean(level=...), sortlevel() and .ix, which were removed
    # from pandas.)
    grouped = df_roc[["specificity", "sensitivity"]].astype(float).groupby(level=1)
    mean_fpr = 1 - grouped["specificity"].mean().values
    mean_tpr = grouped["sensitivity"].mean().values
    # Only the sensitivity interval is drawn (as a vertical band).
    ci_tpr = grouped["sensitivity"].agg(lambda s: confidence_interval95(s.values)).values

    # Anchor the curve at (1, 1) in front and (0, 0) at the end; the band has
    # zero width at both anchors.
    fpr = np.concatenate(([1.0], mean_fpr, [0.0]))
    tpr = np.concatenate(([1.0], mean_tpr, [0.0]))
    ci = np.concatenate(([0.0], ci_tpr, [0.0]))

    # auc() requires monotonic x (what reorder=False enforced); the 'reorder'
    # keyword no longer exists in scikit-learn.
    mean_auc = metrics.auc(fpr, tpr)

    plt.plot(fpr, tpr,
             label=plot_name + '(AUC = %0.2f)' % mean_auc,
             linewidth=3, color=color, linestyle='-')
    plt.fill_between(fpr, tpr - ci, tpr + ci, alpha=0.3, color=color)
    print(mean_auc)


def roc_mult():
    """
    Draw Figure 17: per-case ROC curves of pieces-of-parts plus mean ROC curves.

    Each level-0 index value of the parts ROC table is drawn as a dashed grey
    line, and only the first of these gets a legend entry. The mean ROC
    curves with confidence bands (via ``mean_roc``) are then added on top:
    perfusion-supervoxels in blue, pieces-of-parts in red.

    @return: None
    """
    parts_roc = pd.read_pickle(folder1 + "ri2_parts_t015_g0_rsr1_roc.pkl")
    supervoxel_roc = pd.read_pickle(folder1 + "ri2_parts_t015_g0_rsr1_roc_super.pkl")

    plt.figure()
    plt.title('Figure 17')

    grey = [0.5, 0.5, 0.5]
    for position, case_key in enumerate(parts_roc.index.levels[0]):
        case_rows = parts_roc.loc[case_key]
        (case_line,) = plt.plot(1 - case_rows["specificity"], case_rows["sensitivity"],
                                color=grey,
                                linestyle='--')
        # A single legend entry stands for all individual cases.
        case_line.set_label("Individual cases (parts)" if position == 0 else None)

    plt.xlabel("1-Specificity", fontsize=20)
    plt.ylabel("Sensitivity", fontsize=20)

    mean_roc(supervoxel_roc, color='b', plot_name='perfusion-supervoxels ')
    mean_roc(parts_roc, color='r', plot_name='pieces-of-parts ')

    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.legend(loc="lower right", prop={'size': 17})


def results4():
    """
    Per-case DSC bar chart of supervoxels vs parts for the ``rh`` files (Figure 22a).

    Loads ``rh2_parts_t015_g0_rsr1.pkl`` and ``rh1_lda.pkl`` from ``folder2``.
    It puts their ``dice_original`` columns side by side, shifts the case
    index down by one, and draws them as a bar chart with the legend in the
    lower right. (The previous docstring described box plots this function
    does not draw.)

    @return: None
    """

    df1_parts = pd.read_pickle(folder2 + "rh2_parts_t015_g0_rsr1.pkl")
    df1_supervoxel = pd.read_pickle(folder2 + "rh1_lda.pkl")

    # Shifted case number by 1
    df2 = pd.concat([df1_supervoxel['dice_original'], df1_parts['dice_original']], axis=1)
    df2 = df2.set_index(df2.index - 1)
    df2.columns = ['supervoxel', 'parts']
    ax2 = df2.plot(kind='bar',
                   color=['lightgreen', 'blue'],
                   fontsize=20,
                   title='Figure 22a')
    ax2.set_ylabel('DSC', fontsize=25)
    ax2.set_xlabel('Cases', fontsize=25)
    ax2.legend(loc='lower right')


def print_some_stats():
    """
    Print median and standard-deviation summaries of both methods' DSC.

    "Detected" means the pieces-of-parts DSC exceeds 0.1; the same mask is
    applied to the perfusion-supervoxel results.

    @return: None
    """
    parts_dsc = pd.read_pickle(folder1 + "ri2_parts_t015_g0_rsr1.pkl")['dice_original']
    supervoxel_dsc = pd.read_pickle(folder1 + "ri1_lda.pkl")['dice_original']

    detected = parts_dsc > 0.1
    # Each tuple is the argument list of one print() call, in output order.
    # NOTE(review): the supervoxel "detected" rows use the parts mask; confirm
    # this is intended rather than the supervoxels' own DSC.
    report_lines = (
        ("Pieces of parts (detected)",),
        ("Median: ", parts_dsc[detected].median()),
        ("Pieces of parts (all)",),
        ("Median: ", parts_dsc.median()),
        ("var: ", parts_dsc[detected].std()),
        ("Perfusion-supervoxels (detected)",),
        ("Median ", supervoxel_dsc[detected].median()),
        ("var ", supervoxel_dsc[detected].std()),
        ("Perfusion-supervoxels (all)",),
        ("Median: ", supervoxel_dsc.median()),
    )
    for fields in report_lines:
        print(*fields)

if __name__ == "__main__":
    # Run every report in order, then display all figures together
    # (results3 additionally calls plt.show() itself).
    for report in (print_some_stats, results1, results1b, results2,
                   roc_mult, results3, results4):
        report()
    plt.show()

