import glob
import os
import shutil
import sys
from A3PI_Cubit import run_cubit
from A3PI_Acdtool import run_meshconvert, run_rfpost
from A3PI_Omega3P import run_omega3p
from A3PI_Impact import run_impactt, run_impactz, read_impact_output
from A3PI_Interpolate import run_field_interpolate_3d, run_field_interpolate_1d
from configparser import ConfigParser, ExtendedInterpolation

class A3PI_Workflow:    #Class definition for the ACE3P-IMPACT workflow
    """Drive an ACE3P-IMPACT (A3PI) workflow described by a config file.

    The config file is parsed with ``ConfigParser`` using
    ``ExtendedInterpolation``.  ``[RUN_PARAMETERS] A3PI_mode`` selects what
    happens on construction:

    * ``'optimize'`` -- the object is handed to ``setup_libensemble``.
    * ``'worker'``   -- workflow steps are parsed and the workflow folder is
      prepared, but nothing is run.
    * ``'single'``   -- steps are parsed, the folder is prepared and every
      step is run immediately.

    Any other mode value leaves the object with only ``configfile``,
    ``config`` and ``A3PI_mode`` set.
    """

    #Initialize with configuration file (required) and folder (recommended)
    def __init__(self, setupfile):
        """Load *setupfile* and act according to its ``A3PI_mode``.

        Args:
            setupfile: Path to the workflow configuration file.

        Raises:
            FileNotFoundError: If *setupfile* could not be read.
            configparser.NoSectionError, configparser.NoOptionError: If
                ``[RUN_PARAMETERS] A3PI_mode`` is missing.
        """
        self.configfile = setupfile
        self.config = ConfigParser(interpolation=ExtendedInterpolation())
        # ConfigParser.read() silently skips files it cannot open; fail here
        # with a clear message rather than with a NoSectionError below.
        if not self.config.read(self.configfile):
            raise FileNotFoundError(
                'Configuration file not found or unreadable: '
                + str(self.configfile))
        self.A3PI_mode = self.config.get('RUN_PARAMETERS', 'A3PI_mode')

        #Check which mode A3PI is configured in
        if self.A3PI_mode == 'optimize':
            from A3PI_LibEnsemble import setup_libensemble
            setup_libensemble(self)

        elif self.A3PI_mode in ('worker', 'single'):
            # Both modes need a folder; default to the current directory when
            # no workflow_path is configured (worker mode used to leave
            # self.folder unset and crash in make_folder).
            self.folder = self.config.get('PATHS', 'workflow_path',
                                          fallback=os.getcwd())
            self.read_steps()
            self.make_folder()
            if self.A3PI_mode == 'single':
                self.run_all()

    def make_folder(self):    #Create workflow folder and copy files
        """Create ``self.folder`` if needed and copy input files into it.

        The config file is always copied (overwriting any older copy), unless
        the destination already *is* the config file itself.  Entries of
        ``[RUN_PARAMETERS] static_files`` -- separated by whitespace and/or
        commas -- are copied only when a file with the same name is not
        already present in the folder.
        """
        # makedirs also creates missing parent folders and tolerates an
        # existing folder.
        os.makedirs(self.folder, exist_ok=True)
        self._copy_into_folder(self.configfile, overwrite=True)
        if self.config.has_option('RUN_PARAMETERS', 'static_files'):
            listing = self.config.get('RUN_PARAMETERS', 'static_files')
            for name in listing.replace(',', ' ').split():
                self._copy_into_folder(name, overwrite=False)

    def _copy_into_folder(self, src, overwrite):
        """Copy file *src* into ``self.folder`` and return the destination.

        Args:
            src: Path of the file to copy.
            overwrite: Replace an existing destination file when True;
                leave it untouched when False.
        """
        # shutil.copy into a directory writes to <dir>/<basename of src>,
        # so check exactly that path.
        dest = os.path.join(self.folder, os.path.basename(src))
        if os.path.exists(dest):
            # Copying a file onto itself raises shutil.SameFileError.
            if not overwrite or os.path.samefile(src, dest):
                return dest
        shutil.copy(src, dest)
        return dest

    @staticmethod
    def _natural_key(name):
        """Sort key treating digit runs as integers ('step10' after 'step9')."""
        import re
        # re.split with a capturing group alternates text, digits, text, ...
        # so odd positions are always the digit runs.
        parts = re.split(r'(\d+)', name)
        return [int(part) if index % 2 else part
                for index, part in enumerate(parts)]

    def read_steps(self):    #Add all steps in workflow config file
        """Populate ``self.steps`` from the ``[WORKFLOW]`` section.

        Each option value has the form ``"<step> [arg ...]"``.  Steps are
        ordered by option name using a natural sort, so ``step10`` follows
        ``step9`` (plain string sorting placed it after ``step1``).  A step
        with exactly one argument stores it as a plain string; otherwise the
        arguments are stored as a (possibly empty) list.

        Raises:
            ValueError: If a workflow entry has an empty value.
        """
        self.steps = []    #List of (step_name, step_args) tuples in run order
        if not self.config.has_section('WORKFLOW'):
            return
        # NOTE(review): ConfigParser.items(section) also includes [DEFAULT]
        # options, which would be parsed as workflow steps here.
        tasks = sorted(self.config.items('WORKFLOW'),
                       key=lambda item: (self._natural_key(item[0]), item[0]))
        for key, value in tasks:
            task_args = value.split()
            if not task_args:
                raise ValueError('Workflow entry ' + repr(key)
                                 + ' does not name a step.')
            step, step_args = task_args[0], task_args[1:]
            if len(step_args) == 1:
                step_args = step_args[0]    # a single argument is passed unwrapped
            self.steps.append((step, step_args))

    def run_step(self):    #Run next single step in the workflow
        """Remove the next step from ``self.steps`` and run it.

        Step names are matched case-insensitively.  An unrecognized step is
        reported with a printed message and skipped.

        Raises:
            IndexError: If no steps remain.
        """
        step, step_args = self.steps.pop(0)
        # Built per call so the runner names are resolved only when used.
        runners = {
            'cubit': run_cubit,
            'meshconvert': run_meshconvert,
            'omega3p': run_omega3p,
            'rfpost': run_rfpost,
            'interpolate_3d': run_field_interpolate_3d,
            'interpolate_1d': run_field_interpolate_1d,
            'impactt': run_impactt,
            'impactz': run_impactz,
        }
        runner = runners.get(step.lower())
        if runner is None:
            print('Workflow step: ' + str(step) + ' not recognized.')
        else:
            runner(self, step_args)

    def evaluate(self, args):   #Evaluate specific output from ImpactT
        """Return ``read_impact_output(self, args)``."""
        return read_impact_output(self, args)

    def run_all(self):    #Run entire workflow in sequence
        """Run the remaining steps in order until ``self.steps`` is empty."""
        while self.steps:
            self.run_step()

if __name__ == '__main__':
    # Command-line entry point: run the workflow described by one config file.
    # A wrong argument count used to exit silently; report usage instead
    # (sys.exit with a string prints it to stderr and exits with status 1).
    if len(sys.argv) != 2:
        sys.exit('Usage: ' + os.path.basename(sys.argv[0]) + ' <config_file>')
    configfile = sys.argv[1]
    A3PI_Workflow(configfile)