%% Fit the 1-group SEIR model to severe 2-group data

% x: number of data points the model is fitted to

function [theta,C,t,err,res,one_Total,L] = fit_func(dat,N,Im_frac,beta,mu,kappa,p,dis_data,dis,scat,titl,f,col,x,v,C0)
% FIT_FUNC  Fit N and beta of the 1-group model to the first x data points,
% solve the model over the full data time span, score the fit and optionally plot it.
%
% Inputs (as used here):
%   dat       - data matrix; column 1 = time, column 2 = observed case counts
%   N, beta   - initial guesses for the fitted parameters (used as-is when p ~= 1)
%   Im_frac   - fraction shown (as a percentage) in the plot title
%   mu, kappa - model parameters passed unchanged to IC1 / dis_cases_1 (not fitted)
%   p         - 1 to fit N and beta; any other value skips fitting
%   dis_data  - 1 to plot the data
%   dis       - 1 to plot the model curve (with labels and legend)
%   scat      - 1 to plot data as scatter points, otherwise as a line
%   titl      - 1 to add a title
%   f         - passed through to err_func / err_func_MCMC (meaning defined there)
%   col       - colour of the model curve
%   x         - number of leading data points the model is fitted to (2 <= x <= rows)
%   v         - optimiser when p == 1: 1 = fminsearch, 2 = fmincon (bounded)
%   C0        - passed through to dis_cases_1
%
% Outputs:
%   theta     - [N, beta] returned by the optimiser (raw, may be negative when
%               v == 1), or the input [N, beta] when p ~= 1
%   C, t      - model output of dis_cases_1 over all data times; t is transposed
%   err, res  - outputs of err_func on the first x points
%   one_Total - sum(C)
%   L         - third output of err_func_MCMC on the first x points

% --- Validate and split the data ---------------------------------------------
n_rows = size(dat,1);
if x < 2 || x > n_rows
    % at least two fitted points are needed to derive the time step inc
    error('fit_func:badX', ...
        'x must be between 2 and the number of data rows (%d); got %g.', n_rows, x);
end

dat_full = dat;
dat = dat_full(1:x,:);
C_min = -10;          % passed to dis_cases_1; meaning defined there
ti  = dat(:,1);       % times of the fitted points
d1  = dat(:,2);       % observed counts at the fitted points
tl  = dat_full(:,1);  % times of all data points
d1l = dat_full(:,2);  % observed counts at all data points (kept for reference)
inc = ti(2) - ti(1);  % step from the first two samples (assumes uniform spacing - TODO confirm)

% --- Fit N and beta ----------------------------------------------------------
theta = [N, beta];    % returned unchanged when no fit is requested
if p == 1
    theta_0 = [N, beta];
    cost = @(params) odefit(params,d1,kappa,mu,ti,f);
    switch v
        case 1
            % unconstrained; odefit uses |params| to keep N and beta positive
            options = optimset('Display','off');
            theta = fminsearch(cost, theta_0, options);
        case 2
            lb = [0, 0];
            ub = [130e6, 1e-6];
            options = optimoptions(@fmincon,'Algorithm','interior-point','Display','off');
            theta = fmincon(cost, theta_0, [],[],[],[], lb, ub, [], options);
        otherwise
            error('fit_func:badV', ...
                'v must be 1 (fminsearch) or 2 (fmincon); got %g.', v);
    end
    % the cost function works with magnitudes, so use them for the final solve
    N = abs(theta(1));
    beta = abs(theta(2));
end

% --- Solve the model with the (fitted) parameters over all data times --------
Y0 = IC1(d1(1),N,kappa,mu,0);
[C,t] = dis_cases_1(N,beta,mu,kappa,tl,Y0,C_min,inc,C0);
t = t';
one_Total = sum(C);

% goodness of fit is measured on the fitted window only
C_fit = C(1:x);
[err,res] = err_func(d1,C_fit,f);
[~,~,L] = err_func_MCMC(d1,C_fit,f,120);   % NOTE(review): meaning of 120 defined in err_func_MCMC

% --- Plotting ----------------------------------------------------------------
p1 = [];   % data handle for the legend; stays empty when the data is not drawn
if dis_data == 1
    hold on
    if scat == 1
        scatter(dat_full(:,1),dat_full(:,2),30,'r','filled');   % all points (red)
        p1 = scatter(ti,d1,30,'k','filled');                     % fitted points (black), shown in legend
    else
        p1 = plot(dat_full(:,1),dat_full(:,2),'r','Linewidth',1.5);
    end
    box on
end

if dis == 1
    hold on
    p2 = plot(t,C,'Color',col,'Linewidth',1.5);
    xlabel('Time (Weeks)','FontSize',22)
    ylabel('Number of Cases','FontSize',22)
    if titl == 1
        % the title string is LaTeX markup, so request the LaTeX interpreter explicitly
        title(['$N_I$ = ', num2str(100*Im_frac), '$\%$'],'Fontsize',22,'Interpreter','latex');
    end
    xlim([tl(1) tl(end)])
    shg
    if isempty(p1)
        legend(p2,{'1-Group Model'},'FontSize',18);
    else
        legend([p1 p2],{'Data','1-Group Model'},'FontSize',18);
    end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    function err = odefit(params,d1,kappa,mu,ti,f)
    % ODEFIT  Optimiser cost: err_func between observed counts d1 and the model
    % solved at times ti with N = |params(1)| and beta = |params(2)|.
    % Reads C_min, inc and C0 from the enclosing fit_func workspace. Locals use
    % distinct names so that cost evaluations do not overwrite the parent's
    % N, beta, Y0, C or t (nested functions share same-named variables).
        N_try = abs(params(1));
        beta_try = abs(params(2));
        Y0_try = IC1(d1(1),N_try,kappa,mu,0);
        [C_try,~] = dis_cases_1(N_try,beta_try,mu,kappa,ti,Y0_try,C_min,inc,C0);
        if length(C_try) ~= length(d1)
            % model output and data lengths differ: return a large penalty
            err = 1e10;
        else
            err = err_func(d1,C_try,f);
        end
    end

end
