%% Fit data to the 2-group severe SEIR model 

function [theta,C,t,err,res,s_sq,L] = fit_func_2(dat,N_tot,N,Im_frac,beta,mu,kappa,p,dis_data,dis,scat,titl,f,col,x,v,C0)
%FIT_FUNC_2 Fit case data with the 2-group model and optionally plot the result.
%
%   Inputs (as used in this file):
%     dat      - data matrix; column 1 is time, column 2 is case counts.
%                Rows 1:x are used for fitting; all rows are used for the
%                final model evaluation, error and plotting.
%     N_tot    - total population; initial guess for the fit when p == 1.
%     N        - 1x2 group sizes, used as-is when p ~= 1 (recomputed from
%                the fitted N_tot when p == 1).
%     Im_frac  - fraction of N_tot assigned to the first group.
%     beta     - model parameter; initial guess when p == 1, fixed otherwise.
%     mu,kappa - model parameters forwarded to IC2 and dis_cases_2.
%     p        - 1 to fit [N_tot, beta]; any other value skips fitting.
%     dis_data - 1 to plot the data.
%     dis      - 1 to plot the model curve, axis labels and legend.
%     scat     - 1 to plot the data as a scatter, otherwise as a line.
%     titl     - 1 to add a title showing 100*Im_frac as a percentage.
%     f        - forwarded to err_func / err_func_MCMC.
%     col      - colour of the model curve.
%     x        - number of leading rows of dat used for fitting.
%     v        - optimiser: 1 = fminsearch, 2 = fmincon (interior-point).
%     C0       - forwarded to dis_cases_2.
%
%   Outputs:
%     theta    - optimiser result [N_tot, beta] (abs() is applied before use),
%                or [] when p ~= 1.
%     C, t     - outputs of dis_cases_2 over the full time vector.
%     err, res - outputs of err_func(full case counts, C, f).
%     s_sq     - err / (x - 2).
%     L        - third output of err_func_MCMC(full case counts, C, f, 120).

    dat_full = dat;
    dat = dat_full(1:x,:);      % fitting window: first x rows
    ti  = dat(:,1);             % fitting times
    d1  = dat(:,2);             % fitting case counts
    tl  = dat_full(:,1);        % full time vector
    d1l = dat_full(:,2);        % full case counts
    % Time step passed to dis_cases_2; also read by the nested odefit.
    % Assumes uniformly spaced times (only the first gap is used) -- TODO confirm.
    inc = ti(2)-ti(1);

    % Assigned up front so the output exists even when no fit is performed.
    theta = [];

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % beta, N fitting parameters
    if p == 1
        theta_0 = [N_tot,beta];     % initial guess [N_tot, beta]

        if v == 1
            options = optimset('Display','off');
            theta = fminsearch(@(th) odefit(th,d1,kappa,mu,Im_frac,ti,f), theta_0, options);
        elseif v == 2
            options = optimoptions(@fmincon,'Algorithm','interior-point','Display','off');
            % Box constraints: 1e7 <= N_tot <= 1.3e8, beta >= 0.
            theta = fmincon(@(th) odefit(th,d1,kappa,mu,Im_frac,ti,f), theta_0, ...
                [],[],[],[],[10e6,0],[130e6,Inf],[],options);
        else
            error('fit_func_2:badOptimiser', ...
                'v must be 1 (fminsearch) or 2 (fmincon); got %g.', v);
        end

        % fminsearch is unconstrained, so parameters are used in absolute value.
        beta  = abs(theta(2));
        N_tot = abs(theta(1));
        N = N_tot*[Im_frac,1-Im_frac];
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Solve best fit model over the full time range
    Y0 = IC2(d1(1),N,kappa,mu,0);
    [C,t] = dis_cases_2(N,beta,mu,kappa,tl,Y0,-1,inc,C0);
    [err,res] = err_func(d1l,C,f);
    [~,~,L] = err_func_MCMC(d1l,C,f,120);
    % NOTE(review): err is computed over all rows of dat (d1l), but the
    % denominator counts only the x fitted rows minus 2 parameters. These
    % coincide only when x == size(dat,1); confirm which is intended.
    s_sq = err/(length(d1)-2);

    %%%%% Plot %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    if dis_data == 1
        hold on
        if scat == 1
            p1 = scatter(tl,d1l,30,'k','filled');
        else
            p1 = plot(tl,d1l,'k','Linewidth',1.5);
        end
        box on
    end

    if dis == 1
        hold on
        p2 = plot(t,C,'Color',col,'Linewidth',1.5);
        xlabel('Time (Weeks)','FontSize',22)
        ylabel('Number of Cases','FontSize',22)
        if titl == 1
            % The string uses LaTeX markup, so request the LaTeX interpreter.
            title(['$N_I$ = ', num2str(100*Im_frac), '$\%$'],'Fontsize',22, ...
                'Interpreter','latex');
        end
        xlim([tl(1) tl(end)])
        shg
        % p1 exists only when the data were plotted above.
        if dis_data == 1
            legend([p1 p2],{'Data','2-Group Model'},'FontSize',18);
        else
            legend(p2,{'2-Group Model'},'FontSize',18);
        end
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    function obj = odefit(th,d1,kappa,mu,Im_frac,ti,f)
    %ODEFIT Optimiser objective for th = [N_tot, beta].
    %   Simulates the model on the fitting times ti and returns
    %   err_func(d1, C, f), or 1e10 when dis_cases_2 returns a different
    %   number of points than d1.
    %   Nested function: reads inc and C0 from the fit_func_2 workspace.
    %   Locals carry a *_try suffix so they are NOT shared with (and do not
    %   overwrite) the parent's beta, N_tot, N, Y0, C and t.
        beta_try = abs(th(2));
        Ntot_try = abs(th(1));
        N_try    = Ntot_try*[Im_frac,1-Im_frac];
        Y0_try   = IC2(d1(1),N_try,kappa,mu,0);
        [C_try,~] = dis_cases_2(N_try,beta_try,mu,kappa,ti,Y0_try,0,inc,C0);

        if length(d1) ~= length(C_try)
            obj = 1e10;     % penalise simulations that return a mismatched length
        else
            obj = err_func(d1,C_try,f);
        end
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

end