import numpy as np
import shutil
import array
import os
from deap import base, creator, tools
from libensemble.libE import libE
from libensemble.alloc_funcs.start_only_persistent import only_persistent_gens as alloc_f
from libensemble.tools import add_unique_random_streams
from libensemble.message_numbers import STOP_TAG, PERSIS_STOP, FINISHED_PERSISTENT_GEN_TAG
from libensemble.tools.gen_support import sendrecv_mgr_worker_msg
from libensemble.executors.mpi_executor import MPIExecutor
from configparser import ConfigParser, ExtendedInterpolation
from A3PI_WorkflowTemplate import A3PI_Workflow
import mpi4py
mpi4py.rc.recv_mprobe = False
from mpi4py import MPI
# libEnsemble run options: MPI comms (other comms options are not supported for A3PI)
libE_specs = dict(comms='mpi')

def setup_libensemble(workflow):
    '''
    Configure and run a libEnsemble NSGA-II optimization of an A3PI workflow.

    Reads the LIBENSEMBLE, WORKFLOW and PATHS sections of ``workflow.config``,
    registers the applications used by the workflow steps with a central-mode
    MPI executor, builds the gen/sim/alloc specs and calls ``libE`` with the
    module-level ``libE_specs``. The manager process (rank 0) prints the
    resulting history array.

    :param workflow: A3PI workflow object whose ``config`` provides
        ``has_option``/``get``/``items`` (presumably a ConfigParser — verify).
    '''
    config = workflow.config

    def _config_list(key):
        #Whitespace-separated LIBENSEMBLE option; trailing commas are ignored
        return [x.strip(',') for x in config.get('LIBENSEMBLE', key).split()]

    def _config_floats(key):
        return [float(x) for x in _config_list(key)]

    def _config_float_or(key, default):
        if config.has_option('LIBENSEMBLE', key):
            return float(config.get('LIBENSEMBLE', key))
        return default

    #Number of workers: an integer or 'auto' (the default). 'auto' reserves one
    #MPI process for the manager. Config values are strings, so an explicit
    #count must be converted to int before it is used in arithmetic below.
    nworkers_opt = 'auto'
    if config.has_option('LIBENSEMBLE', 'num_workers'):
        nworkers_opt = config.get('LIBENSEMBLE', 'num_workers').strip()
    if nworkers_opt == 'auto':
        nworkers = MPI.COMM_WORLD.Get_size() - 1
    else:
        nworkers = int(nworkers_opt)
    is_master = (MPI.COMM_WORLD.Get_rank() == 0)    #Master process (manager) is rank 0

    exctr = MPIExecutor(central_mode=True)  #Create MPI executor for A3PI
    persis_info = add_unique_random_streams({}, nworkers + 1)  #Manager + worker random streams

    #Register the simulation app needed by each workflow step:
    #step name -> (PATHS option holding the executable, executor app name)
    step_apps = {
        'cubit': ('cubit_path', 'cubit'),
        'rfpost': ('acdtool_path', 'acdtool'),
        'omega3p': ('omega3p_path', 'omega3p'),
        'impactt': ('impactt_path', 'impactt'),
        'impactz': ('impactz_path', 'impactz'),
    }
    steps = {value.split()[0] for _, value in config.items('WORKFLOW')}
    for step, (path_option, app_name) in step_apps.items():
        if step in steps:
            app_path = os.path.abspath(config.get('PATHS', path_option))
            exctr.register_calc(full_path=app_path, app_name=app_name, calc_type='sim')

    #Population size, max number of generations, and input/output dimensions
    pop_size = int(config.get('LIBENSEMBLE', 'pop_size'))
    num_gen = int(config.get('LIBENSEMBLE', 'num_gen'))
    ind_size = len(config.get('LIBENSEMBLE', 'vars').split())
    num_objectives = len(config.get('LIBENSEMBLE', 'objectives').split())

    #Bounds are lists (the generator rejects other types); weights must be a tuple
    lb = _config_floats('lower_bounds')
    ub = _config_floats('upper_bounds')
    w = tuple(_config_floats('weights'))

    #NSGA-II crossover/mutation parameters. eta and indpb used to be read
    #(mistakenly) from 'cxpb'; they now use their own options and fall back to
    #cxpb only when those options are absent, so older config files still run.
    cxpb = float(config.get('LIBENSEMBLE', 'cxpb'))
    eta = _config_float_or('eta', cxpb)
    indpb = _config_float_or('indpb', cxpb)

    #Construct generator function parameters with DEAP parameters (user-passed arguments)
    gen_specs = {'gen_f': deap_nsga2,
                 'in': ['sim_id'],
                 'out': [('individual', float, ind_size), ('generation', int)],
                 'user': {
                     'lb': lb,
                     'ub': ub,
                     'weights': w,
                     'pop_size': pop_size,
                     'indiv_size': ind_size,
                     'cxpb': cxpb,
                     'eta': eta,
                     'indpb': indpb}
                 }

    #Construct simulation function parameters for optimization
    sim_specs = {'sim_f': A3PI_sim_workflow,
                 'in': ['individual'],
                 'out': [('fitness_values', float, num_objectives)],
                 'user': {'workflow': workflow}
                 }

    #Resource allocation parameters (using persistent generator)
    alloc_specs = {'out': [('given_back', bool)], 'alloc_f': alloc_f}

    #Termination criteria: initial population plus num_gen generations of pop_size sims
    exit_criteria = {'sim_max': pop_size * (num_gen + 1)}

    #Run libEnsemble with the module-level run options (copied so libE cannot
    #mutate the shared dict)
    H, persis_info, flag = libE(sim_specs, gen_specs, exit_criteria, persis_info,
                                alloc_specs, libE_specs=dict(libE_specs))

    if is_master:
        print('---')
        print(H)
        print('---')
        for individual in H['individual']:
            print(individual)

def A3PI_sim_workflow(H, persis_info, sim_specs, _):
    '''
    libEnsemble sim function: evaluate one individual with the A3PI workflow.

    Builds a worker-specific config from the base workflow config file: the
    workflow path gets a '_worker_<rank>' suffix, A3PI_mode is set to 'worker',
    and the optimized VARS are overwritten with the individual's values. It
    then builds and runs an ``A3PI_Workflow`` from that config and evaluates
    the configured objectives. If the optional 'constraints' option is set,
    each constraint violation grows a penalty factor, which then pushes every
    objective toward a worse value for its weight.

    :param H: input rows with an 'individual' field; only the first row is
        evaluated (assumes libEnsemble sends one point per call — verify).
    :param persis_info: returned unchanged.
    :param sim_specs: carries the base workflow in ``sim_specs['user']['workflow']``
        and the output dtype in ``sim_specs['out']``.
    :returns: ``(out, persis_info)``, where ``out`` has one row holding the
        penalized 'fitness_values'.
    '''
    xvals = H['individual'][0]
    workflow = sim_specs['user']['workflow']
    config = workflow.config
    rank = str(MPI.COMM_WORLD.Get_rank())

    def _config_list(key):
        #Whitespace-separated LIBENSEMBLE option; trailing commas are ignored
        return [x.strip(',') for x in config.get('LIBENSEMBLE', key).split()]

    def _config_floats(key):
        return [float(x) for x in _config_list(key)]

    #Read the base workflow config directly. The worker file is written once
    #below, so copying the base file there first was redundant.
    worker_configfile = 'workflow_worker_' + rank
    worker_config = ConfigParser(interpolation=ExtendedInterpolation())
    with open(workflow.configfile) as base_file:
        worker_config.read_file(base_file)
    worker_path = worker_config.get('PATHS', 'workflow_path') + '_worker_' + rank
    worker_config.set('PATHS', 'workflow_path', worker_path)
    worker_config.set('RUN_PARAMETERS', 'A3PI_mode', 'worker')

    #Update the values in VARS requested for optimization
    for i, name in enumerate(_config_list('vars')):
        worker_config.set('VARS', name, str(xvals[i]))

    #Write updated values to worker config file
    with open(worker_configfile, 'w') as config_file:
        worker_config.write(config_file)

    #The worker config file is removed right after the workflow object is
    #built (before run_all). The try/finally also removes it if construction fails.
    try:
        worker_workflow = A3PI_Workflow(worker_configfile)
    finally:
        os.remove(worker_configfile)
    worker_workflow.run_all()

    #Raw objective values as reported by the worker workflow
    out_names = _config_list('objectives')
    yvals = np.zeros(len(out_names))
    for i, name in enumerate(out_names):
        yvals[i] = worker_workflow.evaluate(name)

    #Constraint violations raise the penalty factor (1.0 = no penalty) by the
    #distance outside [lower, upper] times that constraint's const_penalty
    penalty = 1.0
    w = _config_floats('weights')
    if config.has_option('LIBENSEMBLE', 'constraints'):
        const_lb = _config_floats('const_lower_bounds')
        const_ub = _config_floats('const_upper_bounds')
        const_pn = _config_floats('const_penalty')
        for i, name in enumerate(_config_list('constraints')):
            value = float(worker_workflow.evaluate(name))
            if value < const_lb[i]:
                penalty += abs(value - const_lb[i]) * const_pn[i]
            if value > const_ub[i]:
                penalty += abs(value - const_ub[i]) * const_pn[i]

    #Factor penalty**(-sign(y)*w) moves each objective toward a worse value
    #for its weight: down when w > 0, up when w < 0 (penalty >= 1)
    yvals_pn = [y * penalty**(-np.sign(y) * w[i]) for i, y in enumerate(yvals)]

    #One output row, matching the single input row evaluated above
    out = np.zeros(1, dtype=sim_specs['out'])
    out['fitness_values'] = tuple(yvals_pn)

    #Send output to DEAP for optimization
    return out, persis_info

def nsga2_toolbox(gen_specs):
    '''
    Build a DEAP toolbox for an NSGA-II loop, derived from `this example
    <https://github.com/ChristopherMayes/xdeap/blob/master/xdeap/nsga2_tools.py>`_.

    (Re)creates the DEAP creator classes 'MyFitness' (with the configured
    weights) and 'Individual' (a double array carrying that fitness), then
    registers on a fresh toolbox:
    'attr_float', 'individual', 'population', 'mate'
    (``cxSimulatedBinaryBounded``), 'mutate' (``mutPolynomialBounded``) and
    'select' (``selNSGA2``).

    :param gen_specs: reads 'weights', 'eta', 'indpb', 'lb', 'ub' and
        'indiv_size' from ``gen_specs['user']``.
    :returns: the configured ``base.Toolbox``.
    '''
    params = gen_specs['user']
    weights, eta, mut_prob = params['weights'], params['eta'], params['indpb']
    low, up, n_dims = params['lb'], params['ub'], params['indiv_size']

    creator.create('MyFitness', base.Fitness, weights=weights)
    creator.create('Individual', array.array, typecode='d', fitness=creator.MyFitness)

    tb = base.Toolbox()

    #Random initialization inside [low, up] per dimension
    tb.register('attr_float', uniform, low, up, n_dims)
    tb.register('individual', tools.initIterate, creator.Individual, tb.attr_float)
    tb.register('population', tools.initRepeat, list, tb.individual)

    #Bounded variation operators and NSGA-II selection
    tb.register('mate', tools.cxSimulatedBinaryBounded, low=low, up=up, eta=eta)
    tb.register('mutate', tools.mutPolynomialBounded, low=low, up=up, eta=eta, indpb=mut_prob)
    tb.register('select', tools.selNSGA2)

    return tb

def evaluate_pop(g, deap_object, Out, comm):
    '''
    Evaluate the fitness of a batch of individuals through libEnsemble.

    The individuals are written, with generation number ``g``, into the first
    ``len(deap_object)`` rows of ``Out``. Only those rows are sent to the
    manager, and the returned 'fitness_values' are attached to the
    individuals in order.

    :param g: generation number stored in the 'generation' field.
    :param deap_object: population or list of DEAP individuals (fitness set in place).
    :param Out: preallocated gen output array with at least ``len(deap_object)`` rows.
    :param comm: libEnsemble communicator (``libE_info['comm']``).
    :returns: ``(deap_object, tag)``. Fitness values are only assigned when
        ``tag`` is neither STOP_TAG nor PERSIS_STOP.
    '''
    n_ind = len(deap_object)
    #Send only this batch: rows past n_ind may still hold individuals from an
    #earlier, larger batch, and sending them would re-evaluate them (and count
    #them toward sim_max).
    batch = Out[:n_ind]
    for row, ind in enumerate(deap_object):
        batch['individual'][row] = ind
        batch['generation'][row] = g

    #Sending work to sim_f; fitness values come back in calc_in
    tag, _, calc_in = sendrecv_mgr_worker_msg(comm, batch)

    if tag not in [STOP_TAG, PERSIS_STOP]:
        fitness = calc_in['fitness_values']
        for i, ind in enumerate(deap_object):
            #A single-objective field can arrive as a numpy scalar and a
            #multi-objective one as an array; DEAP needs a tuple of floats.
            ind.fitness.values = tuple(np.atleast_1d(fitness[i]).tolist())

    return deap_object, tag

def deap_nsga2(H, persis_info, gen_specs, libE_info):
    '''
    Persistent libEnsemble generator implementing the NSGA-II evolutionary
    algorithm with DEAP.

    Generation 0 is a random population of ``pop_size`` individuals. Each
    later generation does the following:
    - picks parents with ``selTournamentDCD``;
    - crosses each pair with probability ``cxpb`` and mutates both;
    - evaluates the offspring with invalid fitness through libEnsemble;
    - keeps ``pop_size`` individuals from parents + offspring via NSGA-II
      selection.
    The loop ends when the manager sends STOP_TAG or PERSIS_STOP.

    :param H: unused.
    :param persis_info: returned unchanged.
    :param gen_specs: needs 'out' plus the 'user' entries read by
        ``nsga2_toolbox`` and 'pop_size' / 'cxpb'.
    :param libE_info: must provide the manager communicator as 'comm'.
    :returns: ``(Out, persis_info, FINISHED_PERSISTENT_GEN_TAG)``.
    :raises TypeError: if the lower or upper bound is not a list.
    '''
    user = gen_specs['user']

    #Non-list bounds break the DEAP crossover function; check both bounds
    #(the old check let a non-list 'ub' through whenever 'lb' was a list).
    for key in ('lb', 'ub'):
        if not isinstance(user[key], list):
            raise TypeError("gen_specs['user']['%s'] must be a list, got %s; "
                            "other types break the DEAP crossover function"
                            % (key, type(user[key]).__name__))

    #Initialize NSGA2 DEAP toolbox
    toolbox = nsga2_toolbox(gen_specs)

    #MU is the population size; CXPB the probability that two individuals are crossed
    MU, CXPB = user['pop_size'], user['cxpb']
    pop = toolbox.population(n=MU)
    comm = libE_info['comm']

    #Fitness evaluation of the first generation
    g = 0
    Out = np.zeros(MU, dtype=gen_specs['out'])
    pop, tag = evaluate_pop(g, pop, Out, comm)

    #Assigns crowding distances to the individuals; no actual selection is done
    pop = toolbox.select(pop, len(pop))

    #Begin the multi-objective optimization evolution
    while tag not in [STOP_TAG, PERSIS_STOP]:
        g += 1
        print("-- Generation %i --" % g, flush=True)

        #Select parents, then apply crossover and mutation to cloned offspring
        offspring = [toolbox.clone(ind) for ind in tools.selTournamentDCD(pop, len(pop))]
        for ind1, ind2 in zip(offspring[::2], offspring[1::2]):
            if np.random.uniform() <= CXPB:
                toolbox.mate(ind1, ind2)
            toolbox.mutate(ind1)
            toolbox.mutate(ind2)
            del ind1.fitness.values, ind2.fitness.values

        #Individuals whose fitness was deleted above must be re-evaluated
        invalid_ind = [ind for ind in offspring if not ind.fitness.valid]

        if invalid_ind:
            #Fitness evaluation for generations > 0
            invalid_ind, tag = evaluate_pop(g, invalid_ind, Out, comm)
            if tag not in [STOP_TAG, PERSIS_STOP]:
                print('Finished evaluating population, doing selection now.')
                #Select the next generation population
                pop = toolbox.select(pop + offspring, MU)
        else:
            #Every paired offspring loses its fitness above, so this is only
            #reachable with fewer than two offspring.
            print('There were no invalid individuals')

    return Out, persis_info, FINISHED_PERSISTENT_GEN_TAG

def uniform(low, up, size=None):
    '''
    Draw one uniform random value per dimension using numpy's global RNG.

    :param low: per-dimension lower bounds (a sequence), or a scalar applied
        to every dimension.
    :param up: upper bound(s), in the same form as ``low``.
    :param size: number of dimensions; required when the bounds are scalars,
        ignored when they are sequences.
    :returns: list with one draw per dimension.
    :raises TypeError: if the bounds are scalars and ``size`` is None.
    '''
    #Only the iterability check belongs in the try: errors from the draws
    #themselves must not be mistaken for "scalar bounds".
    try:
        bounds = list(zip(low, up))
    except TypeError:
        if size is None:
            raise TypeError('size is required when low and up are scalars') from None
        bounds = [(low, up)] * size
    return [np.random.uniform(a, b) for a, b in bounds]
