%% Fit the 2 group model to Japan data

%%%%%%%%% Parameters %%%%%%%%%%%%%%%%%%%%%%%
% Load the data set and define the fixed inputs handed to fit_func_2.
% The .mat file is expected to provide the matrix `dat`; its first column
% is used as the time axis and its second column as the observed values
% throughout this script.
load('Japan_2009_flu.mat');
K = 100;                          % number of synthetic data sets to generate and refit
% Number of observations = number of rows of dat. size(dat,1) is used
% instead of length(dat) because length returns the LONGEST dimension,
% which is not the row count when dat has fewer rows than columns.
x = size(dat,1);
N_tot = 4.7e06;                   % presumably the total population size -- TODO confirm
Im_frac = 0.3149;                 % fraction used to split N_tot into the two groups below
N = N_tot*[Im_frac,1-Im_frac];    % sizes of the two groups
beta = 2;                         % value passed to fit_func_2 (presumably an initial guess -- confirm)
mu = [7/3,7/7];                   % presumably per-week rates (7 / duration in days) -- TODO confirm
kappa = 7/4;                      % presumably a per-week rate (7 / duration in days) -- TODO confirm
C0 = dat(1,2);                    % first observed value (row 1, column 2 of dat)
col = [1 0 0];                    % RGB triple (red) passed to fit_func_2

%%%%%%%%% Initial Fit %%%%%%%%%%%%%%%%%%%%%%%
% Fit the model to the loaded data with fit_func_2 and derive summary
% quantities from the result. Lines without a trailing semicolon print
% their value to the command window on purpose.
[theta,C,t,err,res] = fit_func_2(dat,N_tot,N,Im_frac,beta,mu,kappa,1,0,0,0,0,3,col,x,1,C0);
% theta(1) and theta(2) are the two fitted parameters.
N_s = theta(1)
beta_s = theta(2);
% Noise scale used to generate the synthetic data below: err divided by
% (x - 2), square-rooted. Presumably err is the residual sum of squares
% and 2 the number of fitted parameters -- TODO confirm against fit_func_2.
sigma = sqrt(err/(x-2))
% sigma = mean(abs(res));
nu = Im_frac + sum(C)/N_s
% NOTE(review): the penalty term here is the constant 3; the usual
% least-squares AIC penalty is 2*(number of parameters) -- confirm that 3
% is intended.
AIC_2 = x*log(err/x) + 3
R0_s = R0_func(beta_s,mu,N_s,Im_frac);

%% Generate Synthetic Data  %%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Build K synthetic data sets from the fitted curve C. Each column of S
% holds the result of one call to noisy_data(C,sigma).

lC = length(C);
S = zeros(lC, K);
for k = 1:K
    S(:, k) = noisy_data(C, sigma);
end

%% Fit Model to Synthetic Data  %%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Refit the model to every synthetic data set (one column of S each) and
% record the fitted parameters, the fitted curve and the AIC value of each
% replicate. Per-replicate temporaries (dat_k, theta_k, C_k, err_k) are
% used so that the real data in `dat` and the initial-fit results
% (theta, C, err) are not overwritten by the loop.

beta_vec_2 = zeros(K,1);    % theta(2) of each replicate fit
N_vec_2 = zeros(K,1);       % theta(1) of each replicate fit
C_vec_2 = zeros(K,lC);      % fitted curve of each replicate, one per row
AIC_vec_2 = zeros(K,1);     % AIC value of each replicate fit

tic
for k = 1:K
    % Progress report every 10th replicate.
    if mod(k,10) == 0
        fprintf('Synthetic fit %d of %d\n', k, K);
    end
    % Same time column as the real data, synthetic values in column 2.
    dat_k = [dat(:,1),S(:,k)];
    [theta_k,C_k,~,err_k] = fit_func_2(dat_k,N_tot,N,Im_frac,beta,mu,kappa,1,0,0,0,0,3,col,x,1,C0);
    N_vec_2(k) = theta_k(1);
    beta_vec_2(k) = theta_k(2);
    C_vec_2(k,:) = C_k;
    % Same AIC expression as used for the fit to the real data.
    AIC_vec_2(k) = x*log(err_k/x) + 3;
end
toc

%% Plot Simulation Results %%%%%%%%%%%%%%%%%%
% Draw the bands computed by shade() from the replicate curves, overlay
% the curve m against t, and scatter the data on top. Then summarise the
% replicate parameter distributions with conf_int().

ci = 95;    % level passed to shade() and conf_int() (presumably percent -- confirm)
[yl,yu,bl,bu,m] = shade(C_vec_2,sigma,ci);
% plot_shade presumably creates the figure and leaves hold on, since the
% following plot/scatter calls add to the same axes -- TODO confirm.
plot_shade(dat(:,1)',yl,yu,bl,bu,2,0);
p2 = plot(t,m,'r','linewidth',1.5);
% Reload the data file so that `dat` holds the values stored on disk
% before the data points are drawn.
load('Japan_2009_flu.mat');
p1 = scatter(dat(:,1),dat(:,2),30,'k','filled');
legend([p1 p2],{'Data','2-Group Model'},'FontSize',18);

% Interval and central value of each replicate quantity, as returned by
% conf_int at level ci.
[beta_CI,beta_m] = conf_int(beta_vec_2,ci);
[N_CI,N_m] = conf_int(N_vec_2,ci);
[AIC_CI,AIC_m] = conf_int(AIC_vec_2,ci);
% R0 evaluated for every replicate's (beta, N) pair.
R0_vec = R0_func(beta_vec_2,mu,N_vec_2,Im_frac);
[R0_CI,R0_m] = conf_int(R0_vec,ci);

%% R0, N scatter plot %%%%%%%%%%%%
% Scatter R0 against N for every synthetic-data fit (small black markers)
% and highlight the fit to the real data (large red marker).
figure
hold on
scatter(N_vec_2,R0_func(beta_vec_2,mu,N_vec_2,Im_frac),10,'k','filled');
scatter(N_s,R0_func(beta_s,mu,N_s,Im_frac),100,'r','filled');
% The labels contain LaTeX math ($...$), so the LaTeX interpreter is
% selected explicitly; with the default TeX interpreter the dollar signs
% would be printed literally.
xlabel('Effective Population Size $N$','Interpreter','latex','FontSize',22)
ylabel('Basic Reproduction Number $R_0$','Interpreter','latex','FontSize',22)
% xlim([2.9e7, 5.5e7])
% ylim([3.4e-8, 5.4e-8])
box on
% Spearman rank correlation between the replicate beta and N estimates.
% NOTE(review): this correlates beta with N, whereas the plot above shows
% R0 against N -- confirm which pair is intended.
rho = corr(beta_vec_2,N_vec_2,'type','Spearman');
shg

%% Plot Residuals %%%%%%%%%%%%
% Residuals of the fit to the real data (res from the initial fit_func_2
% call), first against time and then against the second data column.

% Open a new figure: without it this scatter is drawn onto the previous
% axes (which still have hold on) and overwrites their axis labels.
figure
scatter(dat(:,1),res,30,'k','filled');
xlabel('Time (Weeks)','FontSize',22)
ylabel('Residuals','FontSize',22)
box on

figure
% NOTE(review): the x-values are dat(:,2) (the second data column) while
% the axis label reads 'Model Value' -- confirm whether the fitted model
% values were intended here.
scatter(dat(:,2),res,30,'k','filled');
xlabel('Model Value (Cases)','FontSize',22)
ylabel('Residuals','FontSize',22)
box on