%% Calculate the accuracy of the 1G and 2G model forecasts as the amount of initial data used for real-time fitting increases
% load('sim.mat')

% ---- Fit selection ----------------------------------------------------
% x_vec holds the row indices used to look up beta_vec_* / N_vec_* below.
% NOTE(review): per the file header these are presumably the numbers of
% initial data points used in each real-time fit -- confirm.
x_vec = 6:2:30;
nx    = numel(x_vec);

% ---- Model constants (forwarded unchanged to the helper routines) ------
n      = 25;           % first argument of Survival
Im0_2G = 0.7343;       % second argument of Survival
kappa  = 7/4;          % passed to IC1/IC2 and dis_cases_1/dis_cases_2
mu2G   = [7/3, 7/7];   % two-element parameter for dis_cases_2
mu1G   = 7/7;          % scalar parameter for dis_cases_1
ti     = 0:50;         % time grid handed to the simulators
inc    = 1;            % last argument of dis_cases_1/dis_cases_2
I0     = 10000;        % first argument of IC1/IC2
C_min  = -1;           % passed to dis_cases_1/dis_cases_2

% ---- Error storage ------------------------------------------------------
% One column per stored fit. Entries left at zero are dropped as
% "no forecast" by the statistics section further down.
a1    = size(beta_vec_1, 2);
a2    = size(beta_vec_2, 2);
err_1 = zeros(nx, a1);
err_2 = zeros(nx, a2);

% Reference series: second column of the loaded data table.
Cs = dat(:, 2);

tic

% Total of the reference series. It does not change inside the loops, so
% compute it once instead of on every iteration.
Cs_total = sum(Cs);

% Simulate 1-Group Model Forecasts %%%%%%%%%%%%%%%%%%%%%%%
% For every row index x_vec(j) and every stored fit i, simulate the
% 1-group model and store the relative absolute difference between the
% summed simulated series and the summed reference series.
for j = 1:nx
    x = x_vec(j);
    fprintf('1-group model: x = %d (%d of %d)\n', x, j, nx);
    for i = 1:a1
        beta = beta_vec_1(x,i);
        N = N_vec_1(x,i);

        % A zero beta ends this row early.
        % NOTE(review): assumes zeros only appear as trailing padding in
        % beta_vec_1; if they can be interleaved, `continue` is needed
        % here instead of `break` -- confirm.
        if beta == 0
            break
        end

        Y0 = IC1(I0,N,kappa);
        [C1,~] = dis_cases_1(beta,mu1G,kappa,ti,Y0,C_min,inc);
        err_1(j,i) = abs(sum(C1) - Cs_total)/Cs_total;
    end
end

% Simulate 2-Group Model Forecasts %%%%%%%%%%%%%%%%%%%%%%%
% nu splits the total population Nt into the two groups below. It is
% computed here, next to its only use, rather than at the tail of the
% 1-group branch.
nu = Survival(n,Im0_2G);

for j = 1:nx
    x = x_vec(j);
    for i = 1:a2
        beta = beta_vec_2(x,i);
        Nt = N_vec_2(x,i);

        % Same early-exit convention as in the 1-group loop above.
        if beta == 0
            break
        end

        N = [nu*Nt,(1-nu)*Nt];
        Y0 = IC2(I0,N,kappa);
        [C2,~] = dis_cases_2(beta,mu2G,kappa,ti,Y0,C_min,inc);
        err_2(j,i) = abs(sum(C2) - Cs_total)/Cs_total;
    end
end
toc

%% Calculate Mean and CI Errors
% For every row of err_1 / err_2: drop the zero entries (unfilled slots),
% sort the rest and record the mean plus the order statistics at the r1
% and r2 fractions of the sorted sample.
r1 = 0.025;
r2 = 0.975;

% Preallocate every output as a 1-by-nx row of NaN. This fixes the vector
% length (no stale entries survive from an earlier run in the same
% workspace) and leaves NaN for rows that have no usable data.
err_1_m  = nan(1,nx);
err_1_lq = nan(1,nx);
err_1_uq = nan(1,nx);
err_2_m  = nan(1,nx);
err_2_lq = nan(1,nx);
err_2_uq = nan(1,nx);

for j = 1:nx
    % ---- 1-group model ----
    A = err_1(j,:);
    A(A==0) = [];
    A = sort(A);
    la = length(A);
    if la > 0   % an empty row would make A(lq) an out-of-range index
        lq = max(floor(r1*la),1);        % clamp to the first element
        uq = min(ceil(r2*la + 1),la);    % clamp to the last element
        err_1_m(j)  = mean(A);
        err_1_lq(j) = A(lq);
        err_1_uq(j) = A(uq);
    end

    % ---- 2-group model ----
    B = err_2(j,:);
    B(B==0) = [];
    B = sort(B);
    lb = length(B);
    if lb > 0
        lq = max(floor(r1*lb),1);
        uq = min(ceil(r2*lb + 1),lb);
        err_2_m(j)  = mean(B);
        err_2_lq(j) = B(lq);
        err_2_uq(j) = B(uq);
    end
end

%% Plot the mean error of both models against x_vec
close all
figure      % open the target figure explicitly instead of relying on `hold on`
hold on
% Optional shaded bands between the lower and upper quantile curves:
% q1 = patch([x_vec fliplr(x_vec)], [err_1_lq fliplr(err_1_uq)],[0,0,1],'Facecolor',[152, 188, 214]/256);
% alpha(q1,0.5);
% q2 = patch([x_vec fliplr(x_vec)], [err_2_lq fliplr(err_2_uq)],[0,0,1],'Facecolor',[214,152,208]/256);
% alpha(q2,0.5);

p1 = plot(x_vec,err_1_m,'b','linewidth',1.5);
p2 = plot(x_vec,err_2_m,'r','linewidth',1.5);

xlabel('Number of Data Points','Fontsize',22)
ylabel('Normalized Total Cases Error','Fontsize',22)
legend([p1 p2],{'1-Group Model','2-Group Model'},'FontSize',18);

% Take the x-limits from x_vec so the axis follows the plotted range rather
% than a duplicated constant. xlim needs strictly increasing limits, so
% skip it when x_vec spans a single value.
x_lo = min(x_vec);
x_hi = max(x_vec);
if x_hi > x_lo
    xlim([x_lo x_hi])
end
ylim([0 1.2])
box on
shg
