%% Make a forecast using the TS model
% Configuration for the forecast script. The remainder of the script reads
% these workspace variables by name, so the names must stay as they are.
close all

% ---- Plot styling: one entry per group gr ----
coll = {'b','r'};                          % line colour code per group
patch_col = [152, 188, 214;
             214, 152, 208] / 256;         % forecast-band fill colour per group

% ---- Forecast set-up ----
u = 2;           % passed through to MCMC_TS_v2 (meaning defined there)
K = 8;           % Number of training patients
d_end = 2;       % Forecast made at this time
gr = 1;          % group
in = 1;          % passed through to MCMC_TS_v2 (meaning defined there)

% ---- Data generation ----
sav = 0;         % Save figures?
z = 2;           % 1: Normal omega, eta; 2: rescaled
t_start = 1;     % Start time of sampling
t_end = 10;      % End time of sampling
inc = 0.5;       % Sampling interval
sc = repmat(0.3, 10, 1);   % passed to gen_data_SD for the training data
% sc = repmat(eps, 10, 1);
G = G_func(gr);

% ---- Noise levels ----
sigma1 = 0.278;  % Training data noise (default = 0.7 for non-scaling, default = 0.278 for scaling)
sigma3 = 0.278;  % MCMC likelihood function noise (default = 0.7 for non-scaling, default = 0.278 for scaling)
dis = 1;         % Display figures?
gen = 1;         % Generate data? (1 = regenerate, otherwise reuse workspace data)

% ---- MCMC parameters ----
sd = repmat(0.1, 1, 4);
M1 = 200;        % Maximum theta vec (prior run, MCMC_prior)
M2 = 200;        % Maximum theta vec (forecast run, MCMC_TS_v2)
tstop1 = 60;     % Time of MCMC simulation (prior run)
tstop2 = 60;     % Time of MCMC simulation (forecast run)
n = 32;          % Histogram refinement

%%% Generate Training Data %%%%%%%%%%%%%%%%%%%%%%%%%%%%
% When gen ~= 1 nothing is regenerated: train_vec, dat, dat1 and test_dat
% already present in the workspace are reused as-is.
tstart0 = tic;
if gen == 1
    if K ~= 0
        % K training trajectories; column 1 of train_vec holds the sample times.
        [dmat,dmatn,ti] = gen_data_SD(K,sc,G,sigma1,0,t_end,inc,z);
        train_vec = [ti',dmatn];
    end

    % --- Generate Test Data ---
    % A single test trajectory, generated with sc set to eps.
    sc = eps*ones(10,1);
    [dmat,dmatn,ti] = gen_data_SD(1,sc,G,sigma1,0,t_end,inc,z);

    % Number of samples that precede t_start. t_start/inc is used as an index,
    % so it must be a whole number; round away floating-point noise
    % (e.g. 0.3/0.1) and reject genuinely non-integer ratios explicitly.
    x = round(t_start/inc);
    if abs(x*inc - t_start) > 1e-9*max(1,abs(t_start))
        error('t_start (%g) must be a multiple of the sampling interval inc (%g).',t_start,inc);
    end

    % Shift time so that t_start becomes 0, then split the rows:
    % dat1 = samples before t_start, dat = samples from t_start onward.
    dat_all = [ti'-t_start,dmatn];
    dat1 = dat_all(1:x,:);
    dat = dat_all(x+1:end,:);
    ti = ti(x+1:end);
    test_dat = dat;
end

%%%%%% MCMC Forecast %%%%%%%%%%%%%%%%%%%
%%% Compute Prior %%%%%%%%%%%
% Time grid on which MCMC_sim is evaluated in the plotting section.
t = (0:inc:t_end)';
tstart = tic;
if K ~= 0
    % MCMC_prior receives the training data including its time column.
    prior_vec = MCMC_prior(train_vec,sigma1,sd,tstop1,M1,n,t_end);
    % Keep only the data columns (used for plotting). This is done inside the
    % branch because deleting column 1 of the empty matrix used when K == 0
    % raises an index-out-of-range error.
    train_vec(:,1) = [];
else
    prior_vec = [];
    train_vec = [];
end


%% Compute Forecast %%%%%%%%%%%
% u is configured once in the parameter section at the top of the script;
% it is deliberately not reassigned here so that setting stays in effect.
if d_end ~= 0
    % Sample parameters conditioned on the test data observed up to d_end.
    [theta_vec,acc_vec,guess,MLE] = MCMC_TS_v2(test_dat,M2,1,sigma3,sd,tstop2,ones(4,1),[],prior_vec,d_end,n,u,t_start,in);
else
    % Nothing observed yet: forecast directly from the prior samples.
    theta_vec = prior_vec;
end

%% Plot Forecast %%%%%%%%%%%%%%%
% Plots the training trajectories, a dashed curve for the first parameter
% sample, the forecast band from d_end onward and the test data, then
% computes DIC2 from the per-sample forecast errors D.

close all
tn = linspace(-t_start,d_end,90);
t2a = linspace(d_end,t_end);
[t1,d1,t2,d2] = split_data(dat,d_end);

clear Yint

% MCMC_sim output for the first parameter sample, shifted by its 4th
% parameter and resampled on tn (drawn as the dashed curve).
Y_vec = MCMC_sim(theta_vec,t);
Ya = interp1(t-theta_vec(1,4),Y_vec(1,:),tn,'linear','extrap');

t = linspace(0,t_start+d_end);
nSamp = size(theta_vec,1);
% Preallocate per-sample results. Rebuilding them from scratch on every run
% also prevents stale entries from an earlier, larger run leaking into the
% band or into DIC2.
Y2 = zeros(numel(t2a),nSamp);   % forecast of state component 2, one column per sample
D = zeros(1,nSamp);             % forecast error per sample
for i = 1:nSamp
    theta = theta_vec(i,:);
    % Integrate from 0 to t_start+d_end starting at [1 1e-2]; the first
    % component of the final state seeds the forecast.
    Y1 = ode_func3(t,theta,[1 1e-2]);
    Y0 = Y1(end,:);
    % Forecast over [d_end, t_end], restarting the second component at the
    % last observed datum.
    Y2a = ode_func3(t2a,theta,[Y0(1),d1(end)]);
    Y2(:,i) = Y2a(:,2);
    % Error against the future data d2 (first interpolated point skipped).
    Y3 = interp1(t2a,Y2(:,i),t2);
    D(i) = err_func(Y3(2:end),d2,4);
end

figure
hold on

if ~isempty(train_vec)
    t9 = (0:inc:t_end)-t_start;
    p1 = plot(t9,train_vec,'color',[0, 135, 0]/256,'linewidth',0.5);
end

% Dashed curve in the group's colour code from coll.
plot(tn,Ya,[coll{gr} '--'],'linewidth',1.5);

% me / yl / yu: centre curve and lower / upper band from shade_func
% (presumably a 95% interval across samples - confirm in shade_func).
[me,yl,yu] = shade_func(Y2',95);
% Band opacity; not named `alpha`, which would shadow MATLAB's alpha().
bandAlpha = 1;
patch([t2a fliplr(t2a)], [yl fliplr(yu)],[1 0 0],'Facecolor',patch_col(gr,:),'FaceAlpha',bandAlpha)
p4 = plot(t2a,me,coll{gr},'linewidth',1.5);
p5 = scatter(dat1(:,1),dat1(:,2),50,[0,0.6,0],'filled','MarkerEdgeColor','k');
p2 = scatter(t1,d1,50,'filled','k');
p3 = scatter(t2,d2,50,[255,130,0]./256,'filled','MarkerEdgeColor','k');
xline(d_end,'k--','linewidth',1.5)
axisfunc()
xlim([-t_start t_end-t_start])

if ~isempty(train_vec)
    legend([p1(1),p5,p2,p3,p4],{'Training Data','Unobserved Data','Observed Data','Future Data','Forecast'},'FontSize',20);
else
    legend([p5,p2,p3,p4],{'Unobserved Data','Observed Data','Future Data','Forecast'},'FontSize',20);
end
shg

% Mean of the per-sample errors plus half their variance.
DIC2 = mean(D) + 0.5*var(D);


function Y1 = ode_func3(t,theta,Y0)
% ODE_FUNC3  Integrate the TS model with ode45 and sample it at times t.
%   Y1 = ODE_FUNC3(t, theta, Y0) solves the system defined by TS using the
%   parameters theta(1), theta(2), theta(3), starting from Y0 at t(1), and
%   evaluates the solution structure with deval at every entry of t.
%   Y1 has one row per entry of t and one column per state component.
rhs = @(tau, state) TS(tau, state, theta(1), theta(2), theta(3));
solution = ode45(rhs, t, Y0);
Y1 = ctranspose(deval(solution, t));
end


% Y_vec = MCMC_sim(theta_vec,t);
% figure
% [D,Dm] = plot_forecast_MCMC(theta_vec,train_vec,test_dat,Y_vec,d_end,dis,gr);
% DIC2 = mean(D) + 0.5*var(D);

% %% Save Figures
% if sav == 1
% CL = clock;
% filename = ['Forecast_v',num2str(v),'_sigma',num2str(sigma),'_K',num2str(K),'_d',num2str(d_end),'_',date,'_',num2str(CL(4)),num2str(CL(5))];
% saveas(gcf,filename,'epsc')
% savefig(gcf,filename)
% end
% 
% toc(tstart0)

