%% Create a vessel mask
%
% [vesselMask] = createVesselMask(im, mask, diskRadius, thresh, plot_bool)
%
% in:
%      im         - input image for creating vessel mask
%      mask       - mask to apply to image e.g. brain mask
%                   default - true(size(im))
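%      diskRadius - radius of disk used for background signal removal
%                   (number of pixels)
%                   default - 10
%      thresh     - multiplicative factor applied to the calculated global
%                   threshold
%                   default - 1 (threshold used unmodified)
%      plot_bool  - plot results yes/no
%                   default - false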
%
% out:
%      vesselMask - the calculated vessel mask
%
% Written by Joseph G. Woods, FMRIB, Oxford, June 2019

function [vesselMask] = createVesselMask(im, mask, diskRadius, thresh, plot_bool)
% Build a binary vessel mask from IM by removing slowly varying background
% (morphological opening), stretching contrast, applying MASK, thresholding
% with a scaled Otsu level and discarding very small connected components.

% ---- Defaults for omitted or empty optional arguments -------------------
if nargin < 5 || isempty(plot_bool)
    plot_bool = false; % Don't plot
end
if nargin < 4 || isempty(thresh)
    thresh = 1; % Use the calculated threshold unmodified
end
if nargin < 3 || isempty(diskRadius)
    diskRadius = 10; % Assume pixels are 1 mm in size and choose diskRadius of 10 pixels
end
if nargin < 2 || isempty(mask)
    mask = true(size(im));
elseif ~isequal(size(mask), size(im))
    % A mask of a different size would either be applied through linear
    % indexing to the wrong pixels (smaller mask) or fail with an obscure
    % indexing error (larger mask), so reject it up front.
    error('createVesselMask:maskSizeMismatch', ...
          'mask must be the same size as im (size(mask) = %s, size(im) = %s).', ...
          mat2str(size(mask)), mat2str(size(im)));
end

% Rescale image for use with imbinarize
% (imRescale is defined outside this file; presumably maps to [0, 1] - confirm)
im = imRescale(im);

% Remove background signal variation: the opening with a disk estimates the
% background, which is then subtracted from the image
se         = strel('disk', diskRadius); % Disk-shaped structuring element
background = imopen(im, se);            % Morphological opening with disk
im2        = imRescale(im-background);  % Rescale again to keep 0 and 1 as bounds

% Improve contrast of image, saturating the brightest 1% of pixels
im3 = imadjust(im2, stretchlim(im2,[0,0.99]));

% Mask the image with the supplied mask (pixels outside the mask become 0)
im4 = im3;
im4(~mask) = 0;

% Use a global threshold to binarize image (Otsu's method), scaled by thresh.
% Note: the threshold is computed over the whole masked image, i.e. the
% zeroed pixels outside the mask are included in the histogram.
level       = graythresh(im4);
maskGlobal  = imbinarize(im4, level*thresh);

% Finally, remove connected objects that contain fewer than 5 pixels
vesselMask   = bwareaopen(maskGlobal, 5);

% Show images at the end (im here is the rescaled input image)
if plot_bool
    figure('Renderer', 'painters', 'Position', [100 100 1200 400])
    subplot(1,3,1)
    imagesc(im), colormap hot
    axis equal off
    title('Input image')
    
    subplot(1,3,2)
    imagesc(vesselMask), colormap hot
    axis equal off
    title('Mask')
    
    subplot(1,3,3)
    imagesc(im.*vesselMask), colormap hot
    axis equal off
    title('Masked image')

end


