try: paraview.simple
except: from paraview.simple import *

import os

from subprocess import *

# Wakefield output file plotted by this script. The path is relative to the
# current working directory; edit it to plot a different run.
in_filename = "./t3p_results/OUTPUT/wakefield.out"
print("in_filename:", in_filename)


########################################


def get_data(filename):
  """Open *filename* with ParaView's CSV reader and execute it immediately.

  The file is read as single-space-delimited text with no header row.

  Args:
    filename: path of the delimited text file to read.

  Returns:
    The CSVReader proxy, already updated.
  """
  reader_params = dict(FileName=filename,
                       FieldDelimiterCharacters=' ',
                       HaveHeaders=0)
  reader = CSVReader(**reader_params)
  # Execute the reader right away; downstream objects come out blank otherwise.
  reader.UpdatePipeline()
  return reader


def show_wakeplot(wake_out):
  """Create a line chart of the wake potential W(s) and the bunch current I(s).

  A ProgrammableFilter turns the reader output into a table with columns
  "s", "W" and "I". Reader columns 0, 1 and 2 are taken as s, W and I.
  The current is multiplied by max(W) / max(I) so both curves share the
  y axis (left unscaled when max(I) is 0).

  The filter also prints two trapezoidal-rule results:
    * the bunch charge Q = integral(I ds);
    * the loss factor integral(I*W ds) / Q.

  Args:
    wake_out: reader proxy returned by get_data().

  Returns:
    The newly created 'XYChartView' holding the plot.
  """
  # Create a new 'Line Chart View'
  lineChartView1 = CreateView('XYChartView')
  lineChartView1.ViewTime = 0.0
  lineChartView1.ChartTitle = 'Wakefield'
  lineChartView1.LeftAxisTitle = 'W [V/pC]'
  lineChartView1.BottomAxisTitle = 's [m]'

  ######################################
  pf1 = ProgrammableFilter(registrationName='ProgrammableFilter1', Input=wake_out)
  pf1.OutputDataSetType = 'vtkTable'
  pf1.CopyArrays = 0
  pf1.RequestUpdateExtentScript = ''
  pf1.PythonPath = ''
  pf1.RequestInformationScript = ''
  pf1.Script = \
"""\
table = self.GetInput()

s1 = table.GetColumn(0)
W1 = table.GetColumn(1)
I1 = table.GetColumn(2)

n1 = s1.GetNumberOfTuples()

# Lift the current profile to the height of the wake so both curves share
# the y axis; leave it unscaled if the current peak is zero.
W_peak = table.GetRowData().GetArray(1).GetRange()[1]
I_peak = table.GetRowData().GetArray(2).GetRange()[1]
scale = W_peak / I_peak if I_peak != 0 else 1.0

col_s = vtk.vtkDoubleArray()
col_s.SetName("s")
col_s.SetNumberOfTuples(n1)

col_W = vtk.vtkDoubleArray()
col_W.SetName("W")
col_W.SetNumberOfTuples(n1)

col_I = vtk.vtkDoubleArray()
col_I.SetName("I")
col_I.SetNumberOfTuples(n1)

for i in range(n1):
  col_s.SetValue(i, s1.GetValue(i))
  col_W.SetValue(i, W1.GetValue(i))
  col_I.SetValue(i, scale * I1.GetValue(i))

# Trapezoidal-rule integrals of I*W and of I over s.
integral = 0.0
Q = 0.0
for i in range(n1 - 1):
  w = s1.GetValue(i+1) - s1.GetValue(i)
  integral += 0.5 * w * (I1.GetValue(i)*W1.GetValue(i) + I1.GetValue(i+1)*W1.GetValue(i+1))
  Q += 0.5 * w * (I1.GetValue(i) + I1.GetValue(i+1))

if Q != 0.0:
  # W is in V/pC, so the charge-weighted mean integral/Q is in V/pC too.
  print("Loss factor:", integral/Q, "V/pC")
else:
  print("Loss factor: undefined (zero bunch charge)")
print("Bunch Charge:", Q, "C")

####################
# populate the graph

output = self.GetOutputDataObject(0)
output.AddColumn(col_s)
output.AddColumn(col_I)
output.AddColumn(col_W)
"""

  ######################################
  pd1 = PlotData(registrationName='PlotData1', Input=pf1)

  # Configure the representation in this chart explicitly rather than in
  # whichever view happens to be active.
  dr1 = GetDisplayProperties(pd1, view=lineChartView1)
  dr1.Visibility = 0
  dr1.UseIndexForXAxis = 0
  dr1.XArrayName = 's'
  dr1.AttributeType = 'Row Data'
  dr1.SeriesVisibility = ['W', 'I']
  dr1.SeriesColor = ['s',                  '0.0', '1.0', '0.0',
                     'W',                  '1.0', '0.0', '0.0',
                     'I',                  '0.0', '0.7', '0.0',
                     'vtkOriginalIndices', '0.0', '0.0', '0.0']
  dr1.SeriesLineStyle = ['I', '2']
  dr1.Visibility = 1

  Show(pd1, lineChartView1, 'XYChartRepresentation')
  return lineChartView1

def show_impedanceplot(wake_out):
  """Create a line chart of the impedance spectrum derived from the wake.

  A ProgrammableFilter takes FFTs of reader column 1 (W) and column 2 (I).
  It then forms |Z| = |FFT(W) / (c * FFT(I)) * 1e12 * Q|, where Q is the
  trapezoidal integral of I over s.

  The output table has these columns:
    * "f": frequency in GHz;
    * "Z": the impedance magnitude above;
    * "I": |c * FFT(I)| rescaled so its peak equals the peak of Z;
    * "FFT(W)" and "FFT(I)": the raw FFT magnitudes.

  Only the low end of the spectrum is kept: at most n // 50 bins, stopping
  earlier at the first bin where |FFT(I)| drops below exp(-3) times max(I).

  Args:
    wake_out: reader proxy returned by get_data().

  Returns:
    The newly created 'XYChartView' holding the plot.
  """
  # Create a new 'Line Chart View'
  lineChartView2 = CreateView('XYChartView')
  lineChartView2.ChartTitle = 'Impedance Spectrum'
  lineChartView2.LeftAxisTitle = 'Z [Ohm]'
  lineChartView2.BottomAxisTitle = 'f [GHz]'

  pf2 = ProgrammableFilter(registrationName='ProgrammableFilter2', Input=wake_out)
  pf2.OutputDataSetType = 'vtkTable'
  pf2.CopyArrays = 0
  pf2.RequestUpdateExtentScript = ''
  pf2.PythonPath = ''
  pf2.RequestInformationScript = ''
  pf2.Script = \
"""\
import math
import numpy

table = self.GetInput()

c = 2.99792458e8  # m/s

s_col = table.GetColumn(0)
W_col = table.GetColumn(1)
I_col = table.GetColumn(2)

n = s_col.GetNumberOfTuples()
# Upper bound on the number of frequency bins shown.
n2 = n // 50

s0 = s_col.GetValue(0)
s1 = s_col.GetValue(1)
ds = s1 - s0
# Sample spacing in time; f_max = 1/dt is the sampling rate, so FFT bin i
# sits at f_max * i / n.
dt = ds/c
f_max = 1.0/dt

print("s0:", s0)
print("s1:", s1)
print("dt:", dt)
print("f_max:", f_max)

####################
# get FFTs

W_samples = []
for i in range(n):
  W_samples.append(W_col.GetValue(i))
w = numpy.fft.fft(W_samples)  # FFT(W)

# Q = integral of I ds by the trapezoidal rule: both end samples (i == 0 and
# i == n - 1) carry weight 1/2.
Q = 0.
I_samples = []
for i in range(n):
  value = I_col.GetValue(i)
  if i == 0 or i == n - 1:
    Q += 0.5*value*ds
  else:
    Q += value*ds
  I_samples.append(value)
v = numpy.fft.fft(I_samples)  # FFT(I)

fac = 1.e12*Q
print("Q", Q, fac)

####################
# figure out where to stop on x axis: the first bin whose |FFT(I)| falls
# below exp(-3) times the peak of I(s)

I_max = table.GetRowData().GetArray(2).GetRange()[1]
limit = math.exp(-3.0) * I_max

print("limit", limit)

for i in range(n2):
  if abs(v[i]) < limit:
    n2 = i
    break

####################
# frequency axis [GHz]

col_f = vtk.vtkDoubleArray()
col_f.SetName("f")
col_f.SetNumberOfTuples(n2)
for i in range(n2):
  col_f.SetValue(i, f_max*(i/n)/1.0e9)

####################
# FFT magnitudes; abs() of a complex value is a real float, so SetValue
# never has to convert complex numbers to double

col_W = vtk.vtkDoubleArray()
col_W.SetName("FFT(W)")
col_W.SetNumberOfTuples(n2)

col_I = vtk.vtkDoubleArray()
col_I.SetName("FFT(I)")
col_I.SetNumberOfTuples(n2)

for i in range(n2):
  col_W.SetValue(i, abs(w[i]))
  col_I.SetValue(i, abs(v[i]))

####################
# impedance

Zmax = 0.
Imax = 0.

col_Z = vtk.vtkDoubleArray()
col_Z.SetName("Z")
col_Z.SetNumberOfTuples(n2)
for i in range(n2):
  I = v[i]*c
  Z = abs(w[i]/I * fac)
  col_Z.SetValue(i, Z)
  Zmax = max(Zmax, Z)
  Imax = max(Imax, abs(I))

# Rescale |c*FFT(I)| so its peak matches the impedance peak on the shared
# axis. Imax stays 0 when no bins survive the cut above.
scale = Zmax/Imax if Imax > 0 else 0.0
col_I1 = vtk.vtkDoubleArray()
col_I1.SetName("I")
col_I1.SetNumberOfTuples(n2)
for i in range(n2):
  col_I1.SetValue(i, scale*abs(v[i]*c))

####################
# populate the graph

output = self.GetOutputDataObject(0)
output.AddColumn(col_f)
output.AddColumn(col_Z)
output.AddColumn(col_I1)
output.AddColumn(col_W)
output.AddColumn(col_I)
"""
  ######################################
  pd2 = PlotData(registrationName='PlotData2', Input=pf2)

  # Configure the representation in this chart explicitly rather than in
  # whichever view happens to be active.
  dr2 = GetDisplayProperties(pd2, view=lineChartView2)
  dr2.Visibility = 0
  dr2.UseIndexForXAxis = 0
  dr2.XArrayName = 'f'
  dr2.AttributeType = 'Row Data'
  dr2.SeriesVisibility = ['Z', 'I']
  dr2.SeriesColor = ['f',                  '0.0', '1.0', '0.0',
                     'FFT(W)',             '1.0', '0.0', '0.0',
                     'FFT(I)',             '0.0', '0.7', '0.0',
                     'I',                  '0.0', '0.7', '0.0',
                     'Z',                  '0.0', '0.0', '1.0',
                     'vtkOriginalIndices', '0.0', '0.0', '0.0']
  dr2.SeriesLineStyle = ['I', '2']
  dr2.Visibility = 1
  Show(pd2, lineChartView2, 'XYChartRepresentation')
  return lineChartView2

################################################################################
################################################################################

paraview.simple._DisableFirstRenderCameraReset()

SetActiveView(None)


def _write_data_lines(src_path, dst_path):
  """Copy the data rows of src_path into dst_path for the CSV reader.

  Blank lines and lines whose first non-blank character is '#' are dropped.
  Every remaining line is re-joined with single spaces, so runs of spaces or
  tabs become the single ' ' delimiter that get_data() configures.
  """
  with open(src_path) as src, open(dst_path, 'w') as dst:
    for raw in src:
      fields = raw.split()
      if not fields or fields[0].startswith('#'):
        continue
      print(' '.join(fields), file=dst)


# An empty input path means there is nothing to plot.
if not in_filename:
  print("Wake plot cancelled.")
else:
  tmp_filename = in_filename + ".tmp"
  _write_data_lines(in_filename, tmp_filename)
  print("temporary filename:", tmp_filename)

  wake_out = get_data(tmp_filename)
  # ----------------------------------------------------------------
  # setup view layouts
  # ----------------------------------------------------------------
  lineChartView1 = show_wakeplot(wake_out)
  lineChartView2 = show_impedanceplot(wake_out)
  # create new layout object 'wakeplot', split into two cells, one chart each
  layout1 = CreateLayout(name='wakeplot')
  layout1.SplitVertical(0, 0.500000)
  layout1.AssignView(1, lineChartView1)
  layout1.AssignView(2, lineChartView2)

  # ----------------------------------------------------------------
  # restore active view
  SetActiveView(lineChartView1)

  # The temporary file is left in place: the reader proxy's FileName still
  # points at it, so removing it would presumably break any later
  # re-execution of the pipeline.
  Render()

