%% Plots forecasts of a future epidemic in advance using a prior
% Forecasts for a 1-group and a 2-group model are generated from parameter
% values sampled from gamma priors; the script also reports each model's DIC.

clear all
close all
%% Parameters %%%%%%%%%%%%%%%%%%%%%%%
S = 1;          % passed to par_func2 (selects its parameter set) and used in saved filenames
fig = 5;        % selects the save path/name scheme below (4 or 5)
sav = 0;        % 1 = save each figure (.fig and .eps)
x = 0;          % forecasts start from week x; data up to week x are "observed"
n = 75;         % passed to Survival(n, ...) to obtain the group split nu
y_top = 2e6;    % upper y-axis limit of the plots
ci = 95;        % passed to shade (presumably the interval percentage -- confirm)
sigma = 120;    % passed to shade and err_func_MCMC

%%%%%%%%% Gamma Prior Parameters %%%%%%%%%%%%%%%%%%%%%%%
sc1 = 5;        % variance multiplier for the 1-group priors
sc2 = 5;        % variance multiplier for the 2-group priors
sc = sc1;       % value written into saved filenames

%%
M = 1e6;        % number of gamma draws per prior
s = 50;         % number of representative samples kept per prior
l = 500;        % number of grid points for evaluating each prior pdf
[N1_m, beta1_m] = par_func1;
[N2_m, beta2_m] = par_func2(S);

% Prior variances; the prior means come from par_func1 / par_func2.
var_N1 = sc1*1e13;
var_beta1 = sc1*5e-3; 
var_N2 = sc2*1e13;
var_beta2 = sc2*5e-3; 

% Grids on which the prior pdfs are evaluated: rows 1 and 3 of Y are the
% N priors on X1, rows 2 and 4 are the beta priors on X2.
X1 = linspace(1e7,7e7,l);
X2 = linspace(1.2,2.4,l);
X = [X1;X2;X1;X2];
Y = zeros(4,l);

[N_vec_1,Y(1,:)] = func(N1_m,var_N1,M,s,X1);
[beta_vec_1,Y(2,:)] = func(beta1_m,var_beta1,M,s,X2);
% The 2-group priors must use the 2-group variances. Previously the
% 1-group variances were passed here, so var_N2/var_beta2 (and sc2) had
% no effect.
[N_vec_2,Y(3,:)] = func(N2_m,var_N2,M,s,X1);
[beta_vec_2,Y(4,:)] = func(beta2_m,var_beta2,M,s,X2);

%% Fixed Parameters
% Reference 2-group parameters (for scenario S) used to generate the data.
[Nt_s, beta_s, Im0_2G] = par_func2(S);
kappa = 7/4;
mu2G = [7/3, 7/7];
mu1G = 7/7;
nu = Survival(n, Im0_2G);

% Weekly time grid. ti_a covers weeks 0..x and ti_b covers weeks x..end;
% both contain week x. tf is the forecast window re-based to start at 0.
ti = 0:1:50;
ti_a = ti(1:x+1);
ti_b = ti(x+1:end);
tf = 0:ti(end)-x;
inc = 1; 
C0 = 1e4;
C_min = 0;
model_name = {'1G', '2G'};

%% Initialise Vectors
% Sample counts on each prior axis, and grid lengths.
a1 = length(beta_vec_1);
b1 = length(N_vec_1);
a2 = length(beta_vec_2);
b2 = length(N_vec_2);
ll = numel(ti);
la = numel(ti_a);
lb = numel(ti_b);
% One row per (beta, N) pair. The "a" arrays are only referenced by the
% commented-out observed-window plotting further down.
C_vec_1a = zeros(a1*b1, la);
C_vec_1b = zeros(a1*b1, lb);
C_vec_2a = zeros(a2*b2, la);
C_vec_2b = zeros(a2*b2, lb);

%% Simulate Data %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Synthetic data from the 2-group model with the reference parameters.
N = [nu*Nt_s, (1-nu)*Nt_s];
Y0 = IC2(C0, N, kappa, mu2G, 0);
[C2, t2] = dis_cases_2(N, beta_s, mu2G, kappa, ti, Y0, C_min, inc, C0);
dat = [t2', C2];          % column 1: times, column 2: cases
d2 = C2(x+2:length(t2));  % held-out data after week x, used for scoring
C0f = C2(x+1);            % case count at week x, the forecast starting value

tic
close all

% Deviance of every (beta_i, N_j) sample pair for each model. These are
% preallocated so the nested loops below do not grow the arrays on every
% iteration.
D1 = zeros(a1, b1);
D2 = zeros(a2, b2);

for z = 1:2
%%%%%%%%%%% Simulate 1-Group Model Forecasts %%%%%%%%%%%%%%%%%%%%%%%
figure
if z == 1

% if x ~= 0    
% for i = 1:a1
%     beta = beta_vec_1(i);
%     for j = 1:b1
%         j;
%     N = N_vec_1(j);
%     Y0 = IC1(C0,N,kappa,mu1G,0); 
%     [C1,~] = dis_cases_1(N,beta,mu1G,kappa,ti_a,Y0,C_min,inc,C0);
%     C_vec_1a((i-1)*b1 + j,:) = C1;
%     end
% end
% 
% [yl,yu,bl,bu,m] = shade(C_vec_1a,sigma,ci);
% plot_shade(ti_a,yl,yu,bl,bu,1,0);
% plot(ti_a,m,'b','linewidth',1.5);
% end

% Forecast from week x for every (beta, N) pair. Each row of C_vec_1b is one
% trajectory on tf. d1 drops the first entry (presumably tf = 0, i.e. week x)
% so that it lines up with d2, which starts at week x+1.
for i = 1:a1
    beta = beta_vec_1(i);
    for j = 1:b1
        N = N_vec_1(j);
        Y0 = IC1(dat(:,2), N, kappa, mu1G, x);
        % NOTE(review): the 1G sample runs pass C_min = -1, while the 1G
        % mean-parameter run below passes 0. Both feed DIC1; confirm against
        % dis_cases_1 that this difference is intended.
        [C1, ~] = dis_cases_1(N, beta, mu1G, kappa, tf, Y0, -1, inc, C0f);
        C_vec_1b((i-1)*b1 + j, :) = C1;
        d1 = C1(2:length(d2)+1);
        D1(i, j) = err_func_MCMC(d1, d2, 3, sigma);
    end
end

%%%%% Mean Posterior Parameters %%%%%%%%%
% Deviance at the mean sampled parameters: the D(theta_bar) term of the DIC.
beta = mean(beta_vec_1);
N = mean(N_vec_1);
Y0 = IC1(dat(:,2), N, kappa, mu1G, x); 
[C, ~] = dis_cases_1(N, beta, mu1G, kappa, tf, Y0, 0, inc, C0f);
d1 = C(2:length(d2)+1);
D1m = err_func_MCMC(d1, d2, 3, sigma); 

%%%%% Plot %%%%%%%%%

[yl, yu, bl, bu, m] = shade(C_vec_1b, sigma, ci);
plot_shade(ti_b, yl, yu, bl, bu, 1, 0);
p1 = plot(ti_b, m, 'b', 'linewidth', 1.5);


%%%%%%%%%%% Simulate 2-Group Model Forecasts %%%%%%%%%%%%%%%%%%%%%%%
elseif z == 2
    
% if x ~= 0        
% for i = 1:a2
%     beta = beta_vec_2(i);
%     for j = 1:b2
%         j;
%     Nt = N_vec_2(j);
%     N = [nu*Nt,(1-nu)*Nt];
%     Y0 = IC2(C0,N,kappa,mu2G,0);
%     [C2,~] = dis_cases_2(N,beta,mu2G,kappa,ti_a,Y0,C_min,inc,C0);
%     C_vec_2a((i-1)*b2 + j,:) = C2;
%     end
% end
% 
% [yl,yu,bl,bu,m] = shade(C_vec_2a,sigma,ci);
% plot_shade(ti_a,yl,yu,bl,bu,2,0);
% plot(ti_a,m,'r','linewidth',1.5);
% end

% Same as the 1-group loop. Each sampled total Nt is split between the two
% groups in proportions nu and 1-nu.
for i = 1:a2
    beta = beta_vec_2(i);
    for j = 1:b2
        Nt = N_vec_2(j);
        N = [nu*Nt, (1-nu)*Nt];
        Y0 = IC2(dat(:,2), N, kappa, mu2G, x);
        [C2, ~] = dis_cases_2(N, beta, mu2G, kappa, tf, Y0, C_min, inc, C0f);
        C_vec_2b((i-1)*b2 + j, :) = C2;
        d1 = C2(2:length(d2)+1);
        D2(i, j) = err_func_MCMC(d1, d2, 3, sigma);
    end
end

%%%%% Mean Posterior Parameters %%%%%%%%%%%%%%%%%%%%%%%%%
% Uses the same C_min/inc as the sample runs above, so D2 and D2m are
% computed under identical settings (previously hard-coded 0 and 1).
beta = mean(beta_vec_2);
N = mean(N_vec_2);
N_2G = N*[nu, 1-nu];
Y0 = IC2(dat(:,2), N_2G, kappa, mu2G, x); 
[C, ~] = dis_cases_2(N_2G, beta, mu2G, kappa, tf, Y0, C_min, inc, C0f);
d1 = C(2:length(d2)+1);
D2m = err_func_MCMC(d1, d2, 3, sigma);

%%%%% Plot %%%%%%%%%

[yl, yu, bl, bu, m] = shade(C_vec_2b, sigma, ci);
plot_shade(ti_b, yl, yu, bl, bu, 2, 0);
p1 = plot(ti_b, m, 'r', 'linewidth', 1.5);

end

%% Figure Details %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Data points up to week x (observed) in black; later points (future) in grey.
if x == 0
    p3 = scatter(ti, dat(:,2), 30, [100,100,100]/256, 'filled');
else
    p2 = scatter(ti_a, dat(1:x+1, 2), 30, 'k', 'filled');
    p3 = scatter(dat(x+2:length(dat), 1), dat(x+2:length(dat), 2), 30, [100,100,100]/256, 'filled');
end

xlabel('Time (Weeks)', 'Fontsize', 22)
ylabel('Number of Cases', 'Fontsize', 22)
xlim([0 ti(end)]);
ylim([0 y_top])
yline(y_top, 'k');

if x ~= 0
    xline(x, 'k--', 'linewidth', 2);
end

% '1-Group Model' or '2-Group Model'
model_label = [num2str(z) '-Group Model'];
if x == 0
    legend([p3 p1], {'Future Data', model_label}, 'FontSize', 18);
else
    legend([p2 p3 p1], {'Observed Data', 'Future Data', model_label}, 'FontSize', 18);
end
box on

if sav == 1
    % Build the output base name for the chosen figure; any other value of
    % fig saves nothing.
    filename = '';
    if fig == 4
        filename = [pwd '/Figures/Fig_4/', char(model_name(z)), '_n', num2str(n), '_sc', num2str(sc)];
    elseif fig == 5
        % filename = [pwd '/Figures/Fig_5/',char(model_name(z)),'_x',num2str(x),'_sc',num2str(sc)];
        filename = [char(model_name(z)), '_x', num2str(x), '_sc', num2str(sc), '_S', num2str(S)];
    end
    if ~isempty(filename)
        saveas(gca, [filename '.fig']);
        print(gcf, filename, '-depsc')
    end
end

shg

end

%%%%% DIC %%%%%%%%%%%%%%%%%%%%%%%%%
% DIC = 2*mean(D) - D(mean parameters), treating err_func_MCMC's output as
% the deviance. Printed to the command window (no semicolon).
DIC1 = 2*mean(mean(D1)) - D1m
DIC2 = 2*mean(mean(D2)) - D2m
toc

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    
function [vec,Y] = func(mu,v,M,s,X)
% FUNC  Gamma prior with mean MU and variance V: evaluate its pdf and draw
% S representative samples from it.
%   [VEC,Y] = FUNC(MU,V,M,S,X)
%   The gamma distribution uses shape k = mu^2/v and scale theta = v/mu,
%   so its mean is k*theta = mu and its variance is k*theta^2 = v.
%   Y   - the pdf evaluated at the points X (same size as X).
%   VEC - S-by-1 column of order statistics of M random draws. Element j is
%         the draw at the midpoint rank of the j-th of S equal-count bins.
%         For M divisible by 2*S this is rank j*M/S - M/(2*S), exactly as
%         the original loop selected. Rounding and clamping make it valid
%         for any M >= 1.
k = (mu^2)/v;
theta = v/mu;
Y = gampdf(X,k,theta);
R = sort(gamrnd(k,theta,M,1));

% Midpoint rank of each bin, computed directly instead of scanning all M
% draws for the bin-boundary indices.
idx = round(((1:s)' - 0.5) * (M/s));
idx = min(max(idx, 1), M);
vec = R(idx);
end
