function plotSVM_djm(varargin)
% PLOTSVM_DJM  Visualise a libsvm model's train/test data and decision space in 2-D.
%
% For output of TDT:
%
%   plotSVM_djm(results,cfg,passed_data [, istep [, fig [, mytit]]])
%
%     results     - TDT results struct; the model is read from
%                   results.model_parameters.output(1).model(istep)
%     cfg         - TDT cfg struct; cfg.design.train/test(:,istep) select the
%                   train/test rows, cfg.files.label supplies the labels and
%                   cfg.files.name is used to annotate each plotted point
%     passed_data - TDT passed_data struct; patterns are passed_data.data
%     istep       - decoding step to plot (default 1; [] also means 1)
%
% or just directly for output of libsvm:
%
%   plotSVM_djm(data_train,data_test,labels_train,labels_test,model [, fig [, mytit]])
%
%     data_train, data_test     - patterns, one row per sample
%     labels_train, labels_test - one label per row of the matching data
%     model                     - libsvm model struct (fields nr_class, SVs,
%                                 sv_coef, rho, Label, Parameters are used)
%
% Optional arguments common to both forms:
%     fig   - figure to draw into (default 999)
%     mytit - custom title; when given, the default descriptive title is
%             printed to the command window instead of shown on the plot
%
% If patterns have 2 features (e.g. after PCA), this function plots the
% libsvm output for a single step of TDT's decoding.m directly. If they have
% more than 2 features, an uncentred PCA is fitted to the training data here,
% and the test data and support vectors are projected onto the first two PCs.
%
% Requires libsvm's svmpredict on the path. For TDT input, spm_str_manip is
% used (if available) to shorten the data names shown next to each point.
%
% Daniel Mitchell. 17/3/17

%% check inputs
if isstruct(varargin{1}) % get variables from TDT structures
    isTDT=true;
    results=varargin{1};
    cfg=varargin{2};
    passed_data=varargin{3};
    istep=1;
    if nargin>3
        istep=varargin{4};
        if isempty(istep), istep=1; end
    end
    if nargin>4
        fig=varargin{5};
    end
    if nargin>5
        mytit=varargin{6};
    end

    % rows of the design used for training / testing in this step
    trainRows=logical(cfg.design.train(:,istep));
    testRows=logical(cfg.design.test(:,istep));

    data_train=passed_data.data(trainRows,:);
    data_test=passed_data.data(testRows,:);
    labels_train=cfg.files.label(trainRows);
    labels_test=cfg.files.label(testRows);
    model=results.model_parameters.output(1).model(istep);
else
    isTDT=false;
    data_train=varargin{1};
    data_test=varargin{2};
    labels_train=varargin{3};
    labels_test=varargin{4};
    model=varargin{5};
    if nargin>5
        fig=varargin{6};
    end
    if nargin>6
        mytit=varargin{7};
    end
end

%% get predicted labels
[labels_predicted, ~, ~]=svmpredict(labels_test,data_test,model,'-q');

%% remap labels to uniform scale 1..nLabels (helps colouring)
ulabels=unique([labels_train(:); labels_test(:)]);
nLabels=numel(ulabels);
ulabs_train=nan(size(labels_train));
ulabs_test=nan(size(labels_test));
ulabs_predicted=nan(size(labels_predicted));
for c=1:nLabels
    ulabs_train(labels_train==ulabels(c))=c;
    ulabs_test(labels_test==ulabels(c))=c;
    ulabs_predicted(labels_predicted==ulabels(c))=c;
end

%% prepare figure
if ~exist('fig','var') || isempty(fig), fig=999; end
figure(fig); clf;
set(fig,'color','w')
ms=15; % marker size
cmapc=hsv(model.nr_class + 1); % for contour overlay
cmap=hsv(128);
% One hue per label, spread over the first two thirds of the hsv map.
% Sized by nLabels (not model.nr_class) so that a test label unseen in
% training cannot index past the end of cind.
cind=round(linspace(1,128/3*2,nLabels));
hold on;

%% if data_train has > 2 features, do PCA on train data
% and apply same rotation to data_test and support vectors
if size(data_train,2)>2
    fullmodel=model;
    doPCA=true;

    try
        % Uncentred, so the same rotation PC_coeff applies directly to the
        % test data and support vectors below (and can be inverted later).
        [PC_coeff,data_train, ~, ~, percentvar] = pca(data_train,'Economy',true,'Centered',false);
    catch pcaErr
        % princomp always centres the data, which would break the shared
        % rotation, so there is no fallback: fail, but report why pca failed.
        error('To do this with princomp would need to undo centering (pca failed: %s)', pcaErr.message)
    end
    data_test = data_test * PC_coeff;
    model.SVs = model.SVs * PC_coeff;

    data_train=data_train(:,1:2);
    data_test=data_test(:,1:2);
    model.SVs=model.SVs(:,1:2);
else
    doPCA=false;
end

%% plot train and test data
for c=1:nLabels
    ind=ulabs_train==c;
    plot(data_train(ind,1),data_train(ind,2),'.','color',cmap(cind(c),:),'MarkerSize',ms*2);
    hold on

    ind=ulabs_test==c;
    plot(data_test(ind,1),data_test(ind,2),'o','color',cmap(cind(c),:),'MarkerSize',ms);
end
axis tight %square

%% print title and possibly data descriptors
tit={'Trained on coloured dots, tested on coloured circles;',...
    'support vectors highlighted by black circles;'};

if doPCA
    tit=[tit sprintf('%d dimensions in original feature space.',size(fullmodel.SVs,2))];
end

if isTDT
    try
        % shorten names by stripping their common prefix
        [~, nameInfo]=spm_str_manip(cfg.files.name,'C');
        nameInfo.m=nameInfo.m(:)'; % ensure row vector
        tit=[tit sprintf('Data names prefixed with "%s..."',nameInfo.s)];
    catch
        % spm_str_manip unavailable or failed: use the full names
        clear nameInfo
        nameInfo.m=cfg.files.name(:)';
    end
    labs=[nameInfo.m(trainRows), nameInfo.m(testRows)];
    text([data_train(:,1); data_test(:,1)],[data_train(:,2); data_test(:,2)],strcat('__',labs), 'interpreter','none');
end

hasMyTit=exist('mytit','var');
if hasMyTit
    title(mytit,'interpreter','none');
    disp(tit);
else
    title(tit,'interpreter','none');
end

if doPCA && exist('percentvar','var')
    xlabel(sprintf('PC 1 (%.0f%% variance)',percentvar(1)));
    ylabel(sprintf('PC 2 (%.0f%% variance)',percentvar(2)));
else
    xlabel('Feature (voxel) 1');
    ylabel('Feature (voxel) 2');
end

%% get decision space
myxlim=xlim;
myylim=ylim;
res=200;
[X, Y] = meshgrid(linspace(myxlim(1),myxlim(2),res),linspace(myylim(1),myylim(2),res));
Z = X;

if doPCA
    % Place the 2-D grid in the first two PCs (all other PCs zero) and map
    % it back to the original feature space to query the full model.
    XY=zeros(numel(X),size(PC_coeff,2));
    XY(:,1:2)=[X(:), Y(:)];
    XY=XY*pinv(PC_coeff);
    [Z(:), ~, ~]=svmpredict(zeros(size(X(:))),XY,fullmodel,'-q');
else
    XY=[X(:), Y(:)];
    [Z(:), ~, ~]=svmpredict(zeros(size(X(:))),XY,model,'-q');
end

%% remap labels to uniform scale (helps colouring)
uZ=nan(size(Z));
for c=1:nLabels
    uZ(Z==ulabels(c))=c;
end

%% highlight support vectors
sv=model.SVs;
plot(sv(:,1),sv(:,2),'ko','MarkerSize',ms);

%% plot hyperplane (if linear 2-class)
% (not really necessary since we colour decision space, but nice to see how to do it)
% model.Parameters(2) is the libsvm kernel type; 0 means linear.
if model.Parameters(2)==0 && model.nr_class==2
    w=model.SVs' * model.sv_coef;
    b=-model.rho;
    if model.Label(1)==-1
        % Flip orientation so w points towards the +1 class. Negating both
        % w and b leaves the line w'*x + b = 0 itself unchanged.
        w=-w; b=-b;
    end
    if w(2)~=0
        y=(-1/w(2)) * (w(1)*myxlim + b); % from http://stackoverflow.com/questions/28556266/plot-svm-margins-using-matlab-and-libsvm
        plot(myxlim, y, 'k-');
    elseif w(1)~=0
        % vertical boundary: w(1)*x + b = 0
        plot([1 1]*(-b/w(1)), myylim, 'k-');
    end
end

%% plot correct and incorrect predictions
% (not obvious from 2d decision space after PCA)
for c=1:nLabels
    ind=ulabs_predicted==c & ulabs_test==c;
    plot(data_test(ind,1),data_test(ind,2),'+','color',cmap(cind(c),:),'MarkerSize',ms);

    ind=ulabs_predicted==c & ulabs_test~=c;
    plot(data_test(ind,1),data_test(ind,2),'x','color',cmap(cind(c),:),'MarkerSize',ms);
end

%% plot and overlay transparent decision space
% Render the filled contours, grab them as an image, then delete them and
% show the image semi-transparently in an overlay axes.
divs=(1:nLabels)-0.5;
[~, ch]=contourf(X,Y,uZ, divs,'LineStyle','none');
colormap(cmapc);
axis off
drawnow
F=getframe(gca);
axis on
delete(ch);
ax=gca;

tax=axes;
ch=image(F.cdata,'parent',tax);
alpha(ch,0.2);
axis(tax,'off')
if hasMyTit
    title(mytit,'interpreter','none');
    disp(tit);
else
    title(tit,'interpreter','none');
end
axes(ax);
set(ax,'color','none')