%This script creates an automated animation of a multi-objective
%optimization process by reading in text files of input parameters and
%output objectives taken from a parsed libEnsemble history array. To parse
%an array in .npy format, use the provided A3PI_ParseHistory.py Python
%script to generate the two required data text files.
%
%Alternatively, this script can create an interactive plot in which the
%user selects the generation with a slider and clicks on data points to
%view the full information of the corresponding individual. Although only
%two objectives are plotted, each data point reports the full list of
%input parameters {x_i} and output objectives {y_j}.
%
%Text files containing lists of input parameters and output objectives.
%The columns of data_inputs correspond to the different input parameters
%{x_i}, while the columns of data_outputs correspond to the different output
%objectives {y_j}. The number of rows in each file must equal the number of
%individual evaluations, which is (pop. size)x(num. gen. + 1).
%
%Written by David Bizzozero
%March 31, 2021
data_inputs = 'history_demo_inputs';
data_outputs = 'history_demo_outputs';

%Select mode: mode = 'animation' or mode = 'interactive'. If 'animation'
%mode is selected, then the video writing option can be used.
mode = 'interactive';

plot_ind = [1,2];       %Indices of the two objectives to plot
y_scale = [1e6,1e3];    %Scale factors applied to the two plotted objectives
                        %(e.g. a unit conversion); use [1,1] for raw values
pb1 = [0.5,1.2];        %Optional plot bounds for 1st objective
pb2 = [3,6];            %Optional plot bounds for 2nd objective
N_pop = 128;            %Population size (must be constant per generation)

%This section is only needed if video writing is desired (only used when
%the plot mode is set to 'animation')
v_flag = 0;                 %Video flag (toggle writing to file on/off)
v_name = 'my_video.mp4';    %Video filename (filename must be free)
v_fps = 16;                 %Video framerate

%Custom data labels for inputs and outputs
xlist = {'Cathode Depth [mm]','Initial \sigma_{xy} [m]','Initial \sigma_z [m]',...
    'Solenoid str. [T]','Cathode Phase','Booster 1 Phase','Booster 2 Phase',...
    'Booster 3 Phase','Booster 4 Phase'};
ylist = {'Transverse Emittance [mm mrad]','Bunch Length [mm]'};

%% End of user inputs section

%Set optional parameters to defaults
if ~exist('pb1','var'), pb1 = []; end
if ~exist('pb2','var'), pb2 = []; end
if ~exist('xlist','var'), xlist = {}; end
if ~exist('ylist','var'), ylist = {}; end
if ~exist('y_scale','var'), y_scale = [1e6,1e3]; end

%Read in history array of inputs and reshape to a 3D array indexed by
%individual, generation, and input parameter. The column-major reshape
%treats each consecutive block of N_pop rows as one generation, so the row
%count must be a positive multiple of N_pop.
x = importdata(data_inputs);
N_rows = size(x,1);
if N_rows < N_pop || mod(N_rows,N_pop) ~= 0
    error('A3PI:badHistorySize',...
        '%s has %d rows; expected a positive multiple of N_pop = %d.',...
        data_inputs,N_rows,N_pop);
end
N_gen = (N_rows/N_pop)-1;
N_i = size(x,2);
x = reshape(x,N_pop,N_gen+1,N_i);

%Read in history array of outputs and reshape the same way (individual,
%generation, objective value). Each output row pairs with an input row,
%so both files must have the same number of rows.
y = importdata(data_outputs);
if size(y,1) ~= N_rows
    error('A3PI:badHistorySize',...
        '%s has %d rows but %s has %d; both must list the same evaluations.',...
        data_outputs,size(y,1),data_inputs,N_rows);
end
N_o = size(y,2);
y = reshape(y,N_pop,N_gen+1,N_o);

%Rescale only the two plotted objectives by the user-supplied factors
y(:,:,plot_ind(1)) = y(:,:,plot_ind(1))*y_scale(1);
y(:,:,plot_ind(2)) = y(:,:,plot_ind(2))*y_scale(2);

%Slice history 3D array for 2 desired objective values to plot
y1 = y(:,:,plot_ind(1));
y2 = y(:,:,plot_ind(2));

if strcmpi(mode,'animation')

    if v_flag   %Set up video writer object
        video_obj = VideoWriter(v_name,'MPEG-4');
        video_obj.FrameRate = v_fps;
        open(video_obj)
    end

    figure('position',get(0,'Screensize'));
    %Column i of y1/y2 holds generation i-1, so iterate over all N_gen+1
    %columns to include the final generation.
    for i = 1:N_gen+1
        clf
        plot(y1(:,i),y2(:,i),'ob','markersize',10,'markerfacecolor','b')
        set(gca,'FontName','Times','FontSize',32)
        title('A3PI Multi-Objective Optimization Demo: generation '+string(i-1))
        if length(ylist)==N_o
            xlabel(ylist{plot_ind(1)})
            ylabel(ylist{plot_ind(2)})
        else
            %Concatenate with []; '+' between two char arrays would add
            %their character codes numerically instead of joining them.
            xlabel(['Objective ',num2str(plot_ind(1))])
            ylabel(['Objective ',num2str(plot_ind(2))])
        end
        if length(pb1)==2, xlim(pb1); end
        if length(pb2)==2, ylim(pb2); end
        drawnow;
        if v_flag
            %Each generation is written as two identical frames, so it is
            %shown for 2/v_fps seconds in the output video.
            frame = getframe(gcf);
            writeVideo(video_obj,frame)
            writeVideo(video_obj,frame)
        end
    end
    if v_flag; close(video_obj); end

elseif strcmpi(mode,'interactive')

    f = figure('position',get(0,'Screensize'),'color','w');

    slmin = 0;      %Minimum slider position (generation 0)
    slmax = N_gen;  %Maximum slider position (generation N_gen)
    plot_history(x,y,1,plot_ind,xlist,ylist,pb1,pb2)

    %A slider needs Max > Min; with a single generation there is nothing
    %to select, so only the static plot above is shown.
    if slmax > slmin
        %Set up slider control object and callback to plot_history. Slider
        %value v (generation v) maps to column v+1 of the history arrays.
        bgcolor = f.Color;
        fntsz = 24;
        hsl = uicontrol('Parent',f,'Style','slider','Min',slmin,'Max',slmax,...
                        'SliderStep',[1 1]./(slmax-slmin),'Value',slmin,...
                        'Position',[400,100,1100,23],...
                        'FontName','Times','FontSize',fntsz);
        bl1 = uicontrol('Parent',f,'Style','text','Position',[340,70,60,60],...
                        'String',num2str(slmin),'BackgroundColor',bgcolor,...
                        'FontName','Times','FontSize',fntsz);
        bl2 = uicontrol('Parent',f,'Style','text','Position',[1500,70,60,60],...
                        'String',num2str(slmax),'BackgroundColor',bgcolor,...
                        'FontName','Times','FontSize',fntsz);
        bl3 = uicontrol('Parent',f,'Style','text','Position',[800,30,300,60],...
                        'String','Generation','BackgroundColor',bgcolor,...
                        'FontName','Times','FontSize',fntsz);
        %Same redraw for both a released slider and a continuous drag
        slider_fcn = @(hObject,eventdata) ...
            plot_history(x,y,round(get(hObject,'Value')+1),...
            plot_ind,xlist,ylist,pb1,pb2);
        set(hsl,'Callback',slider_fcn)
        addlistener(hsl,'ContinuousValueChange',slider_fcn);
    end

else
    error('A3PI:badMode',...
        'Unknown mode ''%s''; use ''animation'' or ''interactive''.',mode);
end

function plot_history(x,y,n,plot_ind,xlist,ylist,pb1,pb2)
%PLOT_HISTORY Plot two objectives of one generation with a custom data cursor.
%   plot_history(x,y,n,plot_ind,xlist,ylist,pb1,pb2) plots objective
%   plot_ind(1) against objective plot_ind(2) for column n of y (column n is
%   labelled as generation n-1) in the current figure. It then installs
%   update_datacursor as the data-cursor UpdateFcn, so clicking a point shows
%   all of that individual's inputs and outputs.
%
%   x        - 3D array of inputs, indexed (individual, generation, input)
%   y        - 3D array of outputs, indexed (individual, generation, objective)
%   n        - 1-based generation column to plot
%   plot_ind - indices of the two objectives to plot
%   xlist    - input labels; used by the data cursor only if numel == size(x,3)
%   ylist    - output labels; used only if numel == size(y,3), otherwise
%              'Objective k' is shown
%   pb1, pb2 - optional [min max] limits for the x/y axes (ignored unless
%              they have exactly 2 elements)

N_o = size(y,3);
y1 = y(:,:,plot_ind(1));
y2 = y(:,:,plot_ind(2));

plot(y1(:,n),y2(:,n),'ob','markersize',10,'markerfacecolor','b');
if length(ylist)==N_o
    xlabel(ylist{plot_ind(1)})
    ylabel(ylist{plot_ind(2)})
else
    %Concatenate with []; '+' between two char arrays would add their
    %character codes numerically instead of joining them.
    xlabel(['Objective ',num2str(plot_ind(1))])
    ylabel(['Objective ',num2str(plot_ind(2))])
end
if length(pb1)==2, set(gca,'xlim',pb1); end
if length(pb2)==2, set(gca,'ylim',pb2); end
title('Population: ' + string(size(y1,1)) + ', Generation: ' + string(n-1),'interpreter','latex')
%plot() resets axes properties, so the font and position are reapplied on
%every redraw
set(gca,'FontName','Times','FontSize',36)
set(gca,'Position',[0.1 0.25 0.8 0.68]);
drawnow

%Snap data tips onto plotted vertices so the tip position coincides with
%an actual individual's objective values.
dcm_obj = datacursormode(gcf);
set(dcm_obj,'DisplayStyle','datatip','SnapToDataVertex','on','Enable','on')
set(dcm_obj,'UpdateFcn',{@update_datacursor,y1,y2,x,y,n,xlist,ylist})

end

function txt = update_datacursor(~,event_obj,y1,y2,x,y,n,xlist,ylist)
%UPDATE_DATACURSOR Data-tip text listing every input and output of a point.
%   txt = update_datacursor(~,event_obj,y1,y2,x,y,n,xlist,ylist) is meant to
%   be used as a data-cursor UpdateFcn. It finds the individual in
%   generation column n whose plotted objectives (y1(:,n), y2(:,n)) are
%   nearest to event_obj.Position. It returns a 1 x (N_i+N_o) cell array of
%   'name = value' lines, one per input x(:,n,i) followed by one per output
%   y(:,n,j). Names come from xlist/ylist when their lengths equal
%   N_i/N_o; otherwise x_1, x_2, ... and y_1, y_2, ... are used.

pos = event_obj.Position;

%Locate the plotted point nearest the data tip. Each axis is normalized
%by its data range so the two objectives' units do not bias the search.
%An exact vertex hit has distance zero. All individuals tied at the
%minimum (e.g. duplicate objective values) are reported together.
c1 = y1(:,n);
c2 = y2(:,n);
r1 = max(c1) - min(c1);
if ~(r1 > 0), r1 = 1; end   %constant or all-NaN column: avoid 0 or NaN scale
r2 = max(c2) - min(c2);
if ~(r2 > 0), r2 = 1; end
d = ((c1 - pos(1))/r1).^2 + ((c2 - pos(2))/r2).^2;
ind = find(d == min(d));

N_i = size(x,3);
N_o = size(y,3);
txt = cell(1, N_i + N_o);

for i = 1:N_i
    if length(xlist)==N_i
        label = xlist{i};
    else
        label = ['x_',num2str(i)];
    end
    %Transpose so several tied individuals print on a single line
    txt{i} = [label,' = ',num2str(x(ind,n,i).')];
end

for j = 1:N_o
    if length(ylist)==N_o
        label = ylist{j};
    else
        label = ['y_',num2str(j)];
    end
    txt{N_i+j} = [label,' = ',num2str(y(ind,n,j).')];
end

end
