
class ImpactTemplate(dict):
    """Template definition for an ImpactT input file.

    Each input-file section is stored as a dict entry (``'Procs'``,
    ``'Steps'``, ``'Parts'``, ``'Mesh'``, ``'DistType'``, ``'Dist_x'``,
    ``'Dist_y'``, ``'Dist_z'``, ``'Beam'``) plus a ``'Lattice'`` list holding
    one entry per lattice element line.
    """

    # Order in which the header sections are written by ``write_file``.
    _HEADER_KEYS = ('Procs', 'Steps', 'Parts', 'Mesh', 'DistType',
                    'Dist_x', 'Dist_y', 'Dist_z', 'Beam')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep a lattice supplied via the constructor (e.g. when copying an
        # existing template); only create an empty one when none was given.
        self.setdefault('Lattice', [])

    def set_procs(self, procs):
        """Set the processor array for simulation: total threads = col x row."""
        self['Procs'] = procs

    def set_steps(self, steps):
        """Set the timestep dt (in s), max number of steps, and number of bunches."""
        self['Steps'] = steps

    def set_parts(self, parts):
        """Set particle-related options.

        These are: phase space dimension, number of particles per bunch,
        integration flag, error output flag, diagnostic flag, image charge
        flag, and image charge cutoff.
        """
        self['Parts'] = parts

    def set_mesh(self, mesh):
        """Set the field solver mesh.

        These are: mesh dimensions nx, ny, nz; boundary condition; transverse
        pipe radii xrad, yrad (in m); and max domain length (in m).
        """
        self['Mesh'] = mesh

    def set_dist_type(self, dist):
        """Set the initial bunch distribution options.

        These are: distribution code, restart flag, substep flag, number of
        timesteps for emission, and max emission duration (in s).
        """
        self['DistType'] = dist

    def set_dist(self, coord, dist):
        """Set the distribution parameters for one coordinate (see guide).

        ``coord`` must be ``'x'``, ``'y'`` or ``'z'``; the value is stored
        under ``'Dist_<coord>'``. Any other coord prints a message and leaves
        the template unchanged.
        """
        if coord in ['x', 'y', 'z']:
            self['Dist_' + coord] = dist
        else:
            print('First arg. for set_dist must be \'x\', \'y\', or \'z\'.')

    def set_beam(self, beam):
        """Set beam parameters.

        These are: current (in A), kinetic energy (in eV), particle mass
        (in eV), charge (in q_e+), ref frequency (in Hz), and phase offset
        (in s).
        """
        self['Beam'] = beam

    def add_element(self, element):
        """Append one lattice element line to the end of the lattice."""
        # Rebind to a new list rather than mutating in place, so a list that
        # was passed in (or shared with another template) is not modified.
        self['Lattice'] = self['Lattice'] + [element]

    def remove_lattice_element(self, row=-1):
        """Remove one lattice element (default: the last element).

        ``row`` is 1-based when positive (1 = first element) and Python-style
        when negative (-1 = last, -2 = second-to-last, ...). ``0`` or a row
        beyond the lattice length prints a message and leaves the lattice
        unchanged.
        """
        lattice = self['Lattice']
        if row == 0 or abs(row) > len(lattice):
            print('Lattice element ' + str(row) + ' not found. There are '
                  + str(len(lattice)) + ' elements.')
            return
        index = row - 1 if row > 0 else row
        # Delete from a copy so the previous list object is not mutated.
        remaining = list(lattice)
        del remaining[index]
        self['Lattice'] = remaining

    @staticmethod
    def _format(value):
        """Render one input-file line: ``str(value)`` with commas removed."""
        return str(value).replace(',', '')

    def write_file(self, filename):
        """Write the template to ``filename`` as an ImpactT input file.

        One line is written per header section (in ``_HEADER_KEYS`` order),
        then one line per lattice element terminated by ``' / '``. Each value
        is rendered with ``str()`` and all commas are removed.

        Raises:
            KeyError: if any header section has not been set. The file is
                not opened in that case.
        """
        missing = [key for key in self._HEADER_KEYS if key not in self]
        if missing:
            raise KeyError('ImpactT input is missing section(s): '
                           + ', '.join(missing))
        lines = [self._format(self[key]) + '\n' for key in self._HEADER_KEYS]
        lines.extend(self._format(row) + ' / \n' for row in self['Lattice'])
        with open(filename, 'w') as handle:
            handle.writelines(lines)
