import h5py
import numpy as np
import shutil
from scipy.interpolate import interpn, interp1d
from scipy.spatial import Delaunay
import subprocess

def run_field_interpolate_3d(workflow, args):
    """Interpolate 3D field maps for the configured parameter value and write the T7 file.

    Reads the ``INTERPOLATE_FIELD`` section of ``workflow.config``:
    ``folder_name`` (an expression template in which ``?`` is replaced by each
    entry of ``folder_vals`` and then evaluated to a folder name),
    ``folder_vals`` (whitespace/comma separated sample values) and
    ``interp_var`` (the query value). The bracketing runs and their weights are
    found with ``param_interpolate``; the interpolated HDF5 fields are written
    to ``workflow.folder`` and converted with ``make_t7_file``.

    Parameters
    ----------
    workflow : object
        Must provide ``config.get(section, option)`` (presumably a
        ``configparser`` instance -- TODO confirm) and ``folder``.
    args : str
        Expression evaluated to the field file number ``fid``.

    Raises
    ------
    ValueError
        If ``interp_var`` cannot be bracketed by ``folder_vals``.
    """
    # NOTE(review): eval() executes arbitrary code taken from ``args`` and the
    # config file; acceptable only while both come from a trusted source.
    fid = eval(args)  #Fix later for multiple fids
    section = 'INTERPOLATE_FIELD'
    folder_name = workflow.config.get(section, 'folder_name')
    vals = [x.strip(',') for x in workflow.config.get(section, 'folder_vals').split()]

    wnames = [eval(folder_name.replace('?', x)) for x in vals]
    interp_var = eval(workflow.config.get(section, 'interp_var'))
    interp_pts = np.asarray([eval(x) for x in vals])

    wids, wts = param_interpolate(interp_var, interp_pts)
    if np.isscalar(wids):
        # param_interpolate signals an unbracketed query with the (0, 0) sentinel;
        # stop here instead of failing obscurely while indexing ``wnames``.
        raise ValueError('interp_var %r is outside the range of folder_vals'
                         % (interp_var,))
    construct_interp_fields_3d(fid, wnames, wids, wts, workflow.folder)

    from A3PI_ImpactT import make_t7_file
    make_t7_file(fid, workflow.folder)
    
def run_field_interpolate_1d(workflow, args):
    """Interpolate the on-axis Ez profile for the configured parameter value.

    Reads the ``INTERPOLATE_FIELD`` section of ``workflow.config``:
    ``folder_name`` (an expression template in which ``?`` is replaced by each
    entry of ``folder_vals`` and then evaluated to a folder name),
    ``folder_vals`` (whitespace/comma separated sample values) and
    ``interp_var`` (the query value). The bracketing runs and their weights are
    found with ``param_interpolate`` and an ``rfdata`` file is written to
    ``workflow.folder`` by ``construct_interp_fields_1d``.

    Parameters
    ----------
    workflow : object
        Must provide ``config.get(section, option)`` (presumably a
        ``configparser`` instance -- TODO confirm) and ``folder``.
    args : str
        Expression evaluated to the field file number ``fid``.

    Raises
    ------
    ValueError
        If ``interp_var`` cannot be bracketed by ``folder_vals``.
    """
    # NOTE(review): eval() executes arbitrary code taken from ``args`` and the
    # config file; acceptable only while both come from a trusted source.
    fid = eval(args)  #Fix later for multiple fids
    section = 'INTERPOLATE_FIELD'
    folder_name = workflow.config.get(section, 'folder_name')
    vals = [x.strip(',') for x in workflow.config.get(section, 'folder_vals').split()]

    wnames = [eval(folder_name.replace('?', x)) for x in vals]
    interp_var = eval(workflow.config.get(section, 'interp_var'))
    interp_pts = np.asarray([eval(x) for x in vals])

    wids, wts = param_interpolate(interp_var, interp_pts)
    if np.isscalar(wids):
        # param_interpolate signals an unbracketed query with the (0, 0) sentinel;
        # stop here instead of failing obscurely while indexing ``wnames``.
        raise ValueError('interp_var %r is outside the range of folder_vals'
                         % (interp_var,))
    construct_interp_fields_1d(fid, wnames, wids, wts, workflow.folder)

def param_interpolate(x, xp):
    """Find the sample points enclosing ``x`` and their linear-interpolation weights.

    Parameters
    ----------
    x : float or sequence of float
        Query point in parameter space.
    xp : array_like
        Sample points. A 1-D array (or any array with one value per row) is
        treated as a one-dimensional scan and bracketed by two neighbouring
        samples; an ``(n, d)`` array with ``d > 1`` is triangulated with
        ``scipy.spatial.Delaunay`` and ``x`` is located in a simplex.

    Returns
    -------
    wids : list of int (1-D) or ndarray of int (N-D)
        Indices into ``xp`` of the samples used: the two bracketing samples in
        1-D, the ``d + 1`` simplex vertices otherwise.
    wts : ndarray of float
        Linear (barycentric) weights matching ``wids``; they sum to 1.

    If ``x`` lies outside the sampled range/convex hull, an error message is
    printed and the sentinel ``(0, 0)`` is returned instead.
    """
    xp = np.asarray(xp)
    if xp.size == xp.shape[0]:  #Single dimension case (structured data)
        vals = xp.reshape(-1)
        xq = float(np.squeeze(x))
        order = np.argsort(vals, kind='stable')
        svals = vals[order]
        # Need two distinct samples with svals[0] <= xq <= svals[-1]
        # (the "not (...)" form also rejects NaN queries).
        if (svals.size < 2 or not (svals[0] <= xq <= svals[-1])
                or svals[0] == svals[-1]):
            print('Error: query point outside convex hull of data')
            return 0, 0
        # Choose i with svals[i] <= xq < svals[i+1]; a query equal to the
        # maximum uses the last interval with a strictly smaller left end
        # (the old loop wrongly rejected this endpoint).
        side = 'left' if xq == svals[-1] else 'right'
        i = int(np.searchsorted(svals, xq, side=side)) - 1
        wids = [int(order[i]), int(order[i + 1])]
        x_lo, x_hi = vals[wids[0]], vals[wids[1]]
        wts = np.zeros(2)
        wts[0] = (x_hi - xq)/(x_hi - x_lo)
        wts[1] = (xq - x_lo)/(x_hi - x_lo)
    else:    #Multi-dimensional case (unstructured data)
        xq = np.asarray(x, dtype=float).reshape(-1)
        tri = Delaunay(xp)
        simp = int(tri.find_simplex(xq))
        if simp < 0:
            # find_simplex returns -1 outside the hull; indexing simplices with
            # -1 would silently use the last simplex and give wrong weights.
            print('Error: query point outside convex hull of data')
            return 0, 0
        wids = tri.simplices[simp]
        xis = tri.points[wids]
        # Barycentric coordinates: solve [vertices^T; 1...1] w = [x; 1]
        T = np.vstack([xis.T, np.ones(len(xq) + 1)])
        wts = np.linalg.solve(T, np.hstack([xq, 1.0]))
        wts = wts/wts.sum()

    return wids, wts

def _h5repack_in_place(fname):
    """Best-effort compaction of HDF5 file ``fname`` with the external ``h5repack`` tool.

    Returns True if the file was repacked and replaced, False otherwise. On any
    failure (tool not installed, non-zero exit status, failed move) the
    original file is left in place.
    """
    import os
    tmp = fname + '.repack_tmp'
    try:
        # Pass an argument list: without shell=True a single command string is
        # taken as the program name on POSIX, so the tool would never run.
        status = subprocess.call(['h5repack', fname, tmp])
    except OSError:
        return False
    if status != 0:
        if os.path.exists(tmp):
            os.remove(tmp)
        return False
    try:
        shutil.move(tmp, fname)
    except OSError:
        return False
    return True


def construct_interp_fields_3d(fid, wnames, wids, wts, wfolder = './'):
    """Write 3D E/B field files interpolated between several source runs.

    For every source run ``wnames[wids[i]]`` the files ``data<fid>.h5`` to
    ``data<fid+3>.h5`` (E_Real, E_Imag, B_Real, B_Imag under
    ``data/200/fields``) are read. Each complex field is rotated by minus the
    phase of the on-axis Ez at its peak amplitude, resampled onto a common
    normalized grid and accumulated with weight ``wts[i]``. The results go to
    files of the same names in ``wfolder``, built from copies of the last
    source run's files. The output grid extent is the weighted average of the
    source extents, and its point count per axis is the largest among the
    sources. The written files are then compacted with ``h5repack`` when that
    tool is available.

    Parameters
    ----------
    fid : int
        Number of the E_Real file; E_Imag, B_Real and B_Imag are fid+1..fid+3.
    wnames : sequence of str
        Source run folders.
    wids : sequence of int
        Indices into ``wnames`` of the runs to combine.
    wts : array_like of float
        Interpolation weight for each entry of ``wids``.
    wfolder : str
        Destination folder for the interpolated files.
    """
    n_src = len(wids)
    xmin = np.zeros(n_src);  ymin = xmin.copy();  zmin = xmin.copy()
    xmax = np.zeros(n_src);  ymax = xmax.copy();  zmax = xmax.copy()
    xres = [0]*n_src;  yres = xres.copy();  zres = xres.copy()
    dx = np.zeros(n_src);  dy = dx.copy();  dz = dx.copy()
    phase = np.zeros(n_src)

    def nearest_to_zero(grid):
        """Index of the grid point closest to 0 (the later one on a tie)."""
        # argmin on the reversed grid returns the last minimum; on an ascending
        # grid this is the +|x| point the former list.index() lookup found,
        # but it no longer raises when the nearest point is negative.
        return len(grid) - 1 - int(np.argmin(np.abs(grid[::-1])))

    for i in range(n_src):
        folder = wnames[wids[i]]
        with h5py.File(folder + '/data' + str(fid) + '.h5', 'r') as Ereal_f, \
                h5py.File(folder + '/data' + str(fid+1) + '.h5', 'r') as Eimag_f:
            Ereal_d = Ereal_f['data/200/fields/E_Real']
            Eimag_d = Eimag_f['data/200/fields/E_Imag']

            [xmin[i], ymin[i], zmin[i]] = Ereal_d.attrs['gridGlobalOffset']
            [xres[i], yres[i], zres[i]] = Ereal_d['x'].shape[0:3]
            [dx[i], dy[i], dz[i]] = Ereal_d.attrs['gridSpacing']
            xmax[i] = xmin[i] + dx[i]*(xres[i]-1)
            ymax[i] = ymin[i] + dy[i]*(yres[i]-1)
            zmax[i] = zmin[i] + dz[i]*(zres[i]-1)
            xgrid = np.linspace(xmin[i], xmax[i], xres[i])
            ygrid = np.linspace(ymin[i], ymax[i], yres[i])

            # On-axis Ez: transverse grid point nearest x = y = 0
            x0 = nearest_to_zero(xgrid)
            y0 = nearest_to_zero(ygrid)
            Ez = Ereal_d['z'][x0, y0, :] + 1j*Eimag_d['z'][x0, y0, :]
        z0 = np.argmax(abs(Ez))  #Ez maximum amplitude position
        phase[i] = np.angle(Ez[z0])  #Ez maximum amplitude phase

    # Files of the last source run serve as templates for the output files
    template = wnames[wids[-1]]

    xmin_itrp = np.dot(xmin, wts);    xmax_itrp = np.dot(xmax, wts)
    ymin_itrp = np.dot(ymin, wts);    ymax_itrp = np.dot(ymax, wts)
    zmin_itrp = np.dot(zmin, wts);    zmax_itrp = np.dot(zmax, wts)
    xres_itrp = max(xres);  dx_itrp = (xmax_itrp-xmin_itrp)/(xres_itrp-1)
    yres_itrp = max(yres);  dy_itrp = (ymax_itrp-ymin_itrp)/(yres_itrp-1)
    zres_itrp = max(zres);  dz_itrp = (zmax_itrp-zmin_itrp)/(zres_itrp-1)
    # Interpolate on normalized [0, 1] coordinates so grids of different
    # physical extent are mapped onto each other point for point.
    xg_itrp = np.linspace(0, 1, xres_itrp)
    yg_itrp = np.linspace(0, 1, yres_itrp)
    zg_itrp = np.linspace(0, 1, zres_itrp)
    Yg, Xg, Zg = np.meshgrid(yg_itrp, xg_itrp, zg_itrp)
    itrp_pts = np.vstack([Xg.flatten(), Yg.flatten(), Zg.flatten()]).transpose()

    #Create new h5 data files from existing files and adjust data array shapes
    #to store interpolated field data
    field_names = ['E_Real','E_Imag','B_Real','B_Imag']
    field_sz = [xres_itrp, yres_itrp, zres_itrp]
    for fid_ind in range(4):
        data_num = str(fid + fid_ind)
        out_fname = wfolder + '/data' + data_num + '.h5'
        shutil.copyfile(template + '/data' + data_num + '.h5', out_fname)
        with h5py.File(out_fname, 'r+') as field_n:
            field_name = 'data/200/fields/' + field_names[fid_ind]
            for field_comp in ['x','y','z']:
                field_pos = [0.5 if c==field_comp else 0 for c in ['x','y','z']]
                field_str = field_name + '/' + field_comp
                del field_n[field_str]
                field_n.create_dataset(field_str, field_sz, 'f8')
                field_n[field_str].attrs.modify('position', field_pos)
                field_n[field_str].attrs.modify('unitSI', 1.0)
            field_n[field_name].attrs.modify('gridGlobalOffset',
                                             [xmin_itrp, ymin_itrp, zmin_itrp])
            field_n[field_name].attrs.modify('gridSpacing',
                                             [dx_itrp, dy_itrp, dz_itrp])

    fields = ['E','B']
    for fid_ind in range(2):
        data_num_r = str(fid + 2*fid_ind)
        data_num_i = str(fid + 2*fid_ind + 1)
        field_name_r = 'data/200/fields/' + fields[fid_ind] + '_Real'
        field_name_i = 'data/200/fields/' + fields[fid_ind] + '_Imag'
        temp_r = {c: np.zeros(field_sz) for c in ['x','y','z']}
        temp_i = {c: np.zeros(field_sz) for c in ['x','y','z']}
        for i in range(n_src):
            folder = wnames[wids[i]]
            grid_ref = [np.linspace(0, 1, xres[i]),
                        np.linspace(0, 1, yres[i]),
                        np.linspace(0, 1, zres[i])]
            with h5py.File(folder + '/data' + data_num_r + '.h5', 'r') as field_r, \
                    h5py.File(folder + '/data' + data_num_i + '.h5', 'r') as field_i:
                for field_comp in ['x','y','z']:
                    # Complex field = Re + 1j*Im (the 1j factor was previously
                    # missing, which summed the two parts as real numbers).
                    field_total = (field_r[field_name_r + '/' + field_comp][...]
                                   + 1j*field_i[field_name_i + '/' + field_comp][...])
                    # Align phases so that the Ez peak of every source is real
                    field_total = field_total * np.exp(-1j*phase[i])
                    field_itrp = interpn(grid_ref, field_total, itrp_pts)*wts[i]
                    field_itrp = field_itrp.reshape(xres_itrp, yres_itrp, zres_itrp)
                    temp_r[field_comp] += np.real(field_itrp)
                    temp_i[field_comp] += np.imag(field_itrp)
        field_r_fname = wfolder + '/data' + data_num_r + '.h5'
        field_i_fname = wfolder + '/data' + data_num_i + '.h5'
        with h5py.File(field_r_fname, 'r+') as field_n_r, \
                h5py.File(field_i_fname, 'r+') as field_n_i:
            for field_comp in ['x','y','z']:
                field_n_r[field_name_r + '/' + field_comp][...] = temp_r[field_comp]
                field_n_i[field_name_i + '/' + field_comp][...] = temp_i[field_comp]
        # Best-effort compaction, presumably to reclaim space left by the
        # datasets deleted and re-created above.
        _h5repack_in_place(field_r_fname)
        _h5repack_in_place(field_i_fname)
        
def construct_interp_fields_1d(fid, wnames, wids, wts, wfolder = './'):
    """Write an ImpactT ``rfdata<fid>`` file from an interpolated on-axis Ez profile.

    For every source run ``wnames[wids[i]]`` the on-axis Ez is read from the
    transverse grid point nearest x = y = 0. It comes from ``data<fid>.h5``
    (E_Real) and ``data<fid+1>.h5`` (E_Imag) under ``data/200/fields``. The
    profile is rotated by minus the phase at its peak amplitude, resampled onto
    a common normalized z grid and accumulated with weight ``wts[i]``. The real
    part and its first three z-derivatives (``np.gradient``) are written to
    ``wfolder/rfdata<fid>``.

    File layout: a header line ``npts zmin zmax zlen``, one line per z point
    with Ez and its three derivatives, then two fixed trailing lines.

    Parameters
    ----------
    fid : int
        Number of the E_Real file (E_Imag is ``fid + 1``); also used in the
        output file name.
    wnames : sequence of str
        Source run folders.
    wids : sequence of int
        Indices into ``wnames`` of the runs to combine.
    wts : array_like of float
        Interpolation weight for each entry of ``wids``.
    wfolder : str
        Destination folder for the rfdata file.
    """
    n_src = len(wids)
    Ez = [0]*n_src
    zmin = np.zeros(n_src);  zmax = zmin.copy()
    zres = [0]*n_src
    phase = np.zeros(n_src)

    def nearest_to_zero(grid):
        """Index of the grid point closest to 0 (the later one on a tie)."""
        # argmin on the reversed grid returns the last minimum; on an ascending
        # grid this is the +|x| point the former list.index() lookup found,
        # but it no longer raises when the nearest point is negative.
        return len(grid) - 1 - int(np.argmin(np.abs(grid[::-1])))

    for i in range(n_src):
        folder = wnames[wids[i]]
        with h5py.File(folder + '/data' + str(fid) + '.h5', 'r') as Ereal_f, \
                h5py.File(folder + '/data' + str(fid+1) + '.h5', 'r') as Eimag_f:
            Ereal_d = Ereal_f['data/200/fields/E_Real']
            Eimag_d = Eimag_f['data/200/fields/E_Imag']

            xmin, ymin, zmin[i] = Ereal_d.attrs['gridGlobalOffset']
            xres, yres, zres[i] = Ereal_d['x'].shape[0:3]
            dx, dy, dz = Ereal_d.attrs['gridSpacing']
            zmax[i] = zmin[i] + dz*(zres[i]-1)
            xgrid = np.linspace(xmin, xmin + dx*(xres-1), xres)
            ygrid = np.linspace(ymin, ymin + dy*(yres-1), yres)

            x0 = nearest_to_zero(xgrid)
            y0 = nearest_to_zero(ygrid)
            Ez[i] = Ereal_d['z'][x0, y0, :] + 1j*Eimag_d['z'][x0, y0, :]
        z0 = np.argmax(abs(Ez[i]))  #Ez maximum amplitude position
        phase[i] = np.angle(Ez[i][z0])  #Ez maximum amplitude phase

    zmin_itrp = np.dot(zmin, wts);    zmax_itrp = np.dot(zmax, wts)
    zres_itrp = max(zres);  dz_itrp = (zmax_itrp-zmin_itrp)/(zres_itrp-1)
    # Normalized [0, 1] coordinate so profiles of different length line up
    zg_itrp = np.linspace(0, 1, zres_itrp)

    Ez_interp = np.zeros(zres_itrp)
    zlen = zmax_itrp - zmin_itrp
    for i in range(n_src):
        zg_ref = np.linspace(0, 1, zres[i])
        Ez_interpf = interp1d(zg_ref, Ez[i])
        Ez_interp = Ez_interp + Ez_interpf(zg_itrp)*wts[i]*np.exp(-1j*phase[i])

    Ez0 = np.real(Ez_interp)    #Drop the imaginary part left after phase alignment
    Ez0p = np.gradient(Ez0, dz_itrp, edge_order=2)    #First derivative of Ez
    Ez0pp = np.gradient(Ez0p, dz_itrp, edge_order=2)    #Second derivative of Ez
    Ez0ppp = np.gradient(Ez0pp, dz_itrp, edge_order=2)    #Third derivative of Ez

    #Write rfdata file using 1D on-axis Ez field and its derivatives
    lines = [str(zres_itrp) + '  ' + "{:.6e}".format(zmin_itrp) + '  '
             + "{:.6e}".format(zmax_itrp) + '  ' + "{:.6e}".format(zlen) + '\n']
    for zind in range(zres_itrp):
        lines.append('  ' + "{: .9e}".format(Ez0[zind])
                     + '  ' + "{: .9e}".format(Ez0p[zind])
                     + '  ' + "{: .9e}".format(Ez0pp[zind])
                     + '  ' + "{: .9e}".format(Ez0ppp[zind]) + '\n')
    lines.append('1  0.0  0.0  0.0 \n')    #Ignore on-axis magnetic field
    lines.append('0.0  0.0  0.0  0.0 \n')
    with open(wfolder + '/rfdata' + str(fid), 'w') as file:
        file.writelines(lines)