from qick.qick import *

class AxisConstantIQ(SocIp):
    """
    Driver for the AXIS constant-IQ source.

    Register map:
      * real_reg : 16-bit real (I) value.
      * imag_reg : 16-bit imaginary (Q) value.
      * we_reg   : 1-bit write enable; pulsed to apply real_reg/imag_reg.
    """
    bindto = ['user.org:user:axis_constant_iq:1.0']
    REGISTERS = {'real_reg':0, 'imag_reg':1, 'we_reg':2}

    def __init__(self, description):
        super().__init__(description)

        # Power-up I/Q register values.
        self.real_reg = 30000
        self.imag_reg = 30000

        # Generics taken from the IP description.
        params = description['parameters']
        self.B = int(params['B'])
        self.N = int(params['N'])
        # Largest positive value of a B-bit two's-complement number.
        self.MAX_V = 2**(self.B-1)-1

        # Apply the power-up values.
        self.update()

    def config(self, tile, block, fs):
        """Record the DAC tile/block this source is associated with and its sampling frequency."""
        self.tile, self.dac, self.fs = tile, block, fs

    def update(self):
        """Pulse we_reg (1 then 0) so the new real_reg/imag_reg values take effect."""
        for level in (1, 0):
            self.we_reg = level

    def set_iq(self,i=1,q=1):
        """
        Set the constant I/Q output as fractions of full scale.

        i, q are multiplied by MAX_V and truncated to int. Values are not
        range-checked; presumably callers keep them within [-1, 1].
        """
        full_scale = self.MAX_V
        self.real_reg = int(i*full_scale)
        self.imag_reg = int(q*full_scale)

        # Apply the new values.
        self.update()
        
class AxisDdsCicV1(SocIp):
    """
    Driver for the AXIS DDS + CIC block.

    prodsel_reg selects what feeds the output path ("product", "dds" or
    "input"), cicsel_reg selects whether the CIC decimator is used, and
    qprod_reg / qcic_reg select the quantization applied after the product
    and after the CIC respectively. dec_reg holds the decimation factor.
    """
    bindto = ['user.org:user:axis_ddscic_v1:1.0']
    REGISTERS = {'pinc_reg'     : 0,
                 'pinc_we_reg'  : 1,
                 'prodsel_reg'  : 2,
                 'cicsel_reg'   : 3,
                 'qprod_reg'    : 4,
                 'qcic_reg'     : 5,
                 'dec_reg'      : 6}

    # Decimation range.
    MIN_D       = 2
    MAX_D       = 1000

    # Quantization range for product.
    MIN_QPROD   = 0
    MAX_QPROD   = 16

    # Quantization range for cic.
    MIN_QCIC    = 0
    MAX_QCIC    = 30

    # Sampling frequency and frequency resolution (Hz).
    # Placeholders until configure() is called with the real sampling rate.
    FS_DDS      = 1000
    DF_DDS      = 1

    # DDS bits.
    B_DDS       = 32

    def __init__(self, description):
        # Initialize ip
        super().__init__(description)

        # Default registers.
        self.pinc_reg       = 0              # DC frequency.
        self.pinc_we_reg    = 0              # Don't write.
        self.prodsel_reg    = 2              # By-pass DDS.
        self.cicsel_reg     = 1              # By-pass CIC.
        self.qprod_reg      = self.MIN_QPROD # Lower bits.
        self.qcic_reg       = self.MIN_QCIC  # Lower bits.
        self.dec_reg        = self.MIN_D     # Minimum decimation.

    def configure(self, fs):
        """
        Set the DDS sampling frequency.

        fs: sampling frequency in MHz. FS_DDS is stored in Hz and DF_DDS is
        the resulting frequency resolution FS_DDS / 2**B_DDS.
        """
        fs_hz = fs*1000*1000
        self.FS_DDS = fs_hz
        self.DF_DDS = self.FS_DDS/2**self.B_DDS

    def ddsfreq(self, f=0):
        """
        Program the DDS frequency.

        f: frequency in Hz; it must satisfy 0 <= f < FS_DDS, otherwise the
        call is ignored.
        """
        # Sanity check.
        if 0 <= f < self.FS_DDS:
            # Compute register value. For f just below FS_DDS the rounding can
            # reach 2**B_DDS, which does not fit in the B_DDS-bit phase
            # increment; wrap it (a full-scale increment is equivalent to 0).
            ki = int(round(f/self.DF_DDS)) % 2**self.B_DDS

            # Write value into hardware, then pulse the write enable.
            self.pinc_reg       = ki
            self.pinc_we_reg    = 1
            self.pinc_we_reg    = 0

    def prodsel(self, sel="product"):
        """Select the output source: "product", "dds" or "input". Unknown values are ignored."""
        if sel == "product":
            self.prodsel_reg = 0
        elif sel == "dds":
            self.prodsel_reg = 1
        elif sel == "input":
            self.prodsel_reg = 2

    def cicsel(self, sel="yes"):
        """Enable ("yes") or by-pass ("no") the CIC. Unknown values are ignored."""
        if sel == "yes":
            self.cicsel_reg = 0
        elif sel == "no":
            self.cicsel_reg = 1

    def outsel(self, data="product", cic="yes"):
        """Select both the output source and the CIC usage in one call."""
        self.prodsel(data)
        self.cicsel(cic)

    def set_qprod(self, value=0):
        """Set the product quantization if MIN_QPROD <= value <= MAX_QPROD; otherwise ignored."""
        if self.MIN_QPROD <= value <= self.MAX_QPROD:
            self.qprod_reg = value

    def get_qprod(self):
        """Return the current product quantization register."""
        return self.qprod_reg

    def set_qcic(self, value=0):
        """Set the CIC quantization if MIN_QCIC <= value <= MAX_QCIC; otherwise ignored."""
        if self.MIN_QCIC <= value <= self.MAX_QCIC:
            self.qcic_reg = value

    def get_qcic(self):
        """Return the current CIC quantization register."""
        return self.qcic_reg

    def set_dec(self, value):
        """Set the decimation factor if MIN_D <= value <= MAX_D; otherwise ignored."""
        if self.MIN_D <= value <= self.MAX_D:
            self.dec_reg = value

    def decimation(self, value):
        """
        Set the decimation factor and the matching CIC output quantization.

        The quantization is ceil(3*log2(value)). This is presumably the bit
        growth of a 3-stage CIC; confirm against the HDL. Out-of-range values
        are ignored.
        """
        if self.MIN_D <= value <= self.MAX_D:
            # Compute CIC output quantization as a plain int (np.ceil returns
            # a numpy float, which is not a clean register value).
            qsel = int(np.ceil(3*np.log2(value)))

            # Set values.
            self.set_dec(value)
            self.set_qcic(qsel)

class AxisWxfft65536(SocIp):
    """
    Driver for the 65536-point windowed FFT block.

    Window coefficients are streamed into the block through an AXI DMA
    (configure()/load()), and the FFT scale is set through xfft_scale_reg
    (scale()).
    """
    bindto = ['user.org:user:axis_wxfft_65536:1.0']
    REGISTERS = {'dw_addr_reg':0, 'dw_we_reg':1, 'xfft_scale_reg':2, 'xfft_we_reg':3}

    # Number of FFT points.
    N = 65536

    # Number of bits.
    B = 16

    # Scale factors.
    SCALE_MIN = 0
    SCALE_MAX = 0xffff
    SCALE_NORM = 0xaaaa

    def __init__(self, description):
        # Initialize ip
        super().__init__(description)

        # Default registers.
        self.dw_addr_reg    = 0
        self.dw_we_reg      = 0 # Don't write.
        self.xfft_we_reg    = 0 # Don't write.

    def configure(self, axi_dma):
        """Store the AXI DMA used to send window coefficients to the block."""
        self.dma = axi_dma

    def window(self, wtype="hanning"):
        """Generate a window of type wtype ("hanning" or "rect") and load it."""
        w = self.gen_window(wtype)
        self.load(win=w)

    def gen_window(self, wtype="hanning"):
        """
        Return an N-point window scaled to the full positive int16 range.

        Raises ValueError if wtype is not "hanning" or "rect".
        """
        amplitude = 2**(self.B-1)-1
        if wtype == "hanning":
            return amplitude*np.hanning(self.N)
        if wtype == "rect":
            return amplitude*np.ones(self.N)
        # Previously an unknown type fell through to an UnboundLocalError.
        raise ValueError("%s: unsupported window type '%s' (use 'hanning' or 'rect')." %
                         (self.__class__.__name__, wtype))

    def load(self, win, addr=0):
        """
        Load window coefficients into the block via DMA.

        win:  sequence of at most N values within the int16 range.
        addr: start address written to dw_addr_reg before the transfer.

        Raises RuntimeError if win is longer than N, and ValueError if it is
        empty or exceeds the int16 range.
        """
        # Accept any sequence, not only numpy arrays.
        win = np.asarray(win)

        # Check for max length.
        if len(win) > self.N:
            raise RuntimeError("%s: buffer length must be %d samples or less." %
                  (self.__class__.__name__, self.N))

        if len(win) == 0:
            raise ValueError("%s: window must contain at least one sample." %
                             self.__class__.__name__)

        # Check for max value.
        if np.max(win) > np.iinfo(np.int16).max or np.min(win) < np.iinfo(np.int16).min:
            raise ValueError(
                "window data exceeds limits of int16 datatype")

        # Format data.
        win = win.astype(np.int16)

        # Define buffer (kept on self, as before).
        self.buff = allocate(shape=len(win), dtype=np.int16)
        np.copyto(self.buff, win)

        # Enable writes, DMA the data, and always disable writes again, even
        # if the transfer fails.
        self._wr_enable(addr)
        try:
            self.dma.sendchannel.transfer(self.buff)
            self.dma.sendchannel.wait()
        finally:
            self._wr_disable()

    def _wr_enable(self, addr=0):
        # Point the window memory at addr and enable writes.
        self.dw_addr_reg = addr
        self.dw_we_reg = 1

    def _wr_disable(self):
        # Disable window-memory writes.
        self.dw_we_reg = 0

    def scale(self, value=0):
        """Write the FFT scale value and pulse xfft_we_reg to apply it (no range check)."""
        self.xfft_scale_reg  = value
        self.xfft_we_reg     = 1
        self.xfft_we_reg     = 0

class AxisBufferUram(SocIp):
    # AXIS_buffer URAM registers.
    # RW_REG
    # * 0 : read operation.
    # * 1 : write operation.
    #
    # START_REG
    # * 0 : stop.
    # * 1 : start operation.
    #
    # SYNC_REG
    # * 0 : don't sync with Tlast.
    # * 1 : sync capture with Tlast.
    #
    # The block either captures or sends data out, depending on the RW_REG
    # operation. Read/write operations use the entire buffer. Tlast is
    # generated at the end of the read so the DMA does not hang. Both
    # s_axis_tdata and tuser are captured. Output is always 64 bits, with the
    # lower B bits holding the data and the upper 16 bits holding tuser.
    # Unused bits should be zero.
    #
    # With SYNC_REG, the user can make the capture start after a Tlast has
    # been received at the input interface. Previous samples are discarded,
    # including the one carrying the Tlast flag, and the capture starts right
    # after that sample. If SYNC_REG is set to 0, the block starts capturing
    # data without waiting for Tlast.
    bindto = ['user.org:user:axis_buffer_uram:1.0']
    REGISTERS = {   'rw_reg'    : 0,
                    'start_reg' : 1,
                    'sync_reg'  : 2}

    def __init__(self, description):
        # Initialize ip
        super().__init__(description)

        # Default registers.
        self.rw_reg     = 0 # Read operation.
        self.start_reg  = 0 # Stop.
        self.sync_reg   = 0 # Don't sync on TLAST.

        # Generics
        self.B = int(description['parameters']['B'])
        self.N = int(description['parameters']['N'])

        # Buffer length (N is the address width).
        self.BUFFER_LENGTH = (1 << self.N)

    def configure(self, dma, sync="no"):
        """Store the DMA used for read-out and set Tlast sync ("yes"/"no"; other values ignored)."""
        # DMA block.
        self.dma = dma

        if sync == "no":
            self.sync_reg = 0
        elif sync == "yes":
            self.sync_reg = 1

    def capture(self, t=0.1):
        """Capture input data into the buffer for t seconds."""
        # Enable write operation.
        self.rw_reg     = 1
        self.start_reg  = 1
        try:
            # Wait for capture.
            time.sleep(t)
        finally:
            # Stop capture even if the wait is interrupted.
            self.start_reg = 0

    def transfer(self):
        """
        Read the whole buffer out through the DMA and unpack it.

        Returns an array of shape (3, BUFFER_LENGTH) with rows I, Q and tuser.
        I comes from bits [15:0], Q from bits [31:16] and tuser from bits
        [63:48] of each 64-bit word. This unpacking assumes 16-bit I/Q
        (presumably B == 32); confirm for other builds.
        """
        # Enable read operation.
        self.rw_reg = 0

        # Define buffer.
        buff = allocate(shape=(self.BUFFER_LENGTH,), dtype=np.uint64)

        # Start transfer; always clear start_reg afterwards, even if the DMA fails.
        self.start_reg = 1
        try:
            self.dma.recvchannel.transfer(buff)
            self.dma.recvchannel.wait()
        finally:
            self.start_reg = 0

        # Format data. The int16 casts wrap values >= 0x8000 to negative,
        # i.e. they reinterpret the 16-bit fields as signed samples.
        dataI = (buff & 0xFFFF).astype(np.int16)
        dataQ = ((buff >> 16) & 0xFFFF).astype(np.int16)
        tuser = ((buff >> 48) & 0xFFFF).astype(np.uint16)

        # Return data
        return np.stack((dataI,dataQ,tuser))

    def getdata(self, t):
        """Capture for t seconds, then return the unpacked buffer (see transfer())."""
        self.capture(t)
        return self.transfer()

    def length(self):
        """Return the buffer length in samples."""
        return self.BUFFER_LENGTH

class AxisAccumulatorV1(SocIp):
    """
    Driver for the AXIS accumulator block.

    The constructor only accepts the configuration this driver supports:
    32-bit input, FFT_AW == 16 (65536 points), a single bank
    (BANK_ARRAY_AW == 0) and FFT_STORE == 0.
    """
    bindto = ['user.org:user:axis_accumulator_v1:1.0']
    REGISTERS = {   'process_reg'           :0,
                    'tx_and_cnt_reg'        :1,
                    'tx_and_rst_reg'        :2,
                    'usr_round_samples_reg' :3,
                    'usr_epoch_rounds_reg'  :4,
                    'debug_reg'             :12,
                    'round_cnt_reg'         :13,
                    'epoch_cnt_reg'         :14,
                    'transmitting_reg'      :15}

    def __init__(self, description):
        # Initialize ip
        super().__init__(description)

        # Default registers.
        self.process_reg            = 0
        self.tx_and_cnt_reg         = 0
        self.tx_and_rst_reg         = 0
        self.usr_round_samples_reg  = 100
        self.usr_epoch_rounds_reg   = 1

        # Generics
        self.AXIS_IN_DW     = int(description['parameters']['AXIS_IN_DW'])
        self.AXIS_OUT_DW    = int(description['parameters']['AXIS_OUT_DW'])
        self.FFT_AW         = int(description['parameters']['FFT_AW'])
        self.BANK_ARRAY_AW  = int(description['parameters']['BANK_ARRAY_AW'])
        self.MEM_DW         = int(description['parameters']['MEM_DW'])
        self.MEM_PIPE       = int(description['parameters']['MEM_PIPE'])
        self.FFT_STORE      = int(description['parameters']['FFT_STORE'])

        # Check Parameters.
        if (self.AXIS_IN_DW != 32):
            raise ValueError('Data Width=%d not supported. Must be 32-bit'%self.AXIS_IN_DW)

        if (self.FFT_AW != 16):
            raise ValueError('FFT length=%d not supported. Must be 65536'%2**(self.FFT_AW))

        if (self.BANK_ARRAY_AW != 0):
            raise ValueError('Number of parallel input=%d not supported. Must be 1'%2**(self.BANK_ARRAY_AW))

        if (self.FFT_STORE != 0):
            raise ValueError('FFT_STORE must be set to all (0)')

        # Buffer length (one more for metadata).
        self.BUFFER_LENGTH = 2**self.FFT_AW + 1

    def configure(self, dma):
        """Store the DMA used to read accumulated data."""
        self.dma = dma

    def start(self):
        """Set process_reg to 1 (start processing)."""
        self.process_reg = 1

    def stop(self):
        """Set process_reg to 0 (stop processing)."""
        self.process_reg = 0

    def setavg(self, N = 100):
        """Set the number of samples per round (usr_round_samples_reg)."""
        self.usr_round_samples_reg  = N

    def transmitting(self):
        """Return the value of transmitting_reg."""
        return self.transmitting_reg

    def transfer(self):
        """
        Read one accumulated frame out through the DMA.

        Returns [samples, tot_smp]:
          * samples is column 0 of every row except the last (metadata) row.
          * tot_smp is the upper 32 bits of column 0 of the metadata row.
        """
        # Define buffer: BUFFER_LENGTH rows of two int64 words.
        buff = allocate(shape=(self.BUFFER_LENGTH,2), dtype=np.int64)

        # DMA data.
        self.dma.recvchannel.transfer(buff)
        self.dma.recvchannel.wait()

        # Format data: data rows first, metadata in the last row.
        samples = buff[:-1, 0]
        tot_smp = buff[-1, 0] >> np.int64(32)
        return [samples,tot_smp]
        
class Mixer:
    """
    Small helper around an RF data converter driver for setting per-block
    mixer frequency and DAC Nyquist zone.
    """
    # RF data converter driver (class-level placeholder; set per instance).
    rf = 0

    def __init__(self, ip):
        # Keep a handle on the RF data converter driver.
        self.rf = ip

    def set_freq(self,f,tile,block,btype="dac"):
        """
        Set the mixer frequency of one ADC or DAC block and reset its phase offset.

        f:     new value for the 'Freq' mixer setting.
        tile:  tile index.
        block: block index within the tile.
        btype: "adc" or "dac"; anything else raises ValueError.
        """
        if btype == "adc":
            tiles = self.rf.adc_tiles
        elif btype == "dac":
            tiles = self.rf.dac_tiles
        else:
            raise ValueError('Block Type %s not supported.'%btype)
        rf_block = tiles[tile].blocks[block]

        # Modify a copy of the current settings: only Freq and PhaseOffset change.
        new_mixcfg = rf_block.MixerSettings.copy()
        new_mixcfg.update({'Freq': f, 'PhaseOffset': 0})

        # Write the settings back, then trigger the mixer update event.
        rf_block.MixerSettings = new_mixcfg
        rf_block.UpdateEvent(xrfdc.EVENT_MIXER)

    def set_nyquist(self,nz,tile,dac):
        """Set the Nyquist zone of DAC block `dac` in tile `tile` to nz."""
        self.rf.dac_tiles[tile].blocks[dac].NyquistZone = nz

class TopSoc(Overlay):
    """
    Top-level overlay for this firmware.

    Loads the bitstream, optionally resets the board clocks first, and sets up
    driver handles for the DDS/CIC, windowed FFT, accumulator, RF data
    converter, mixer helper and constant-IQ IPs.
    """
    # Constructor.
    def __init__(self, bitfile=None, force_init_clks=False,ignore_version=True, **kwargs):
        # Load overlay (don't download to PL yet, so clocks can be set first).
        Overlay.__init__(self, bitfile, ignore_version=ignore_version, download=False, **kwargs)

        # Configuration dictionary.
        self.cfg = {}
        self.cfg['board'] = os.environ["BOARD"]
        self.cfg['refclk_freq'] = 204.8

        # Read the config to get a list of enabled ADCs and DACs, and the sampling frequencies.
        self.list_rf_blocks(self.ip_dict['usp_rf_data_converter_0']['parameters'])

        # Reset the board clocks only when explicitly requested, then download
        # the bitstream.
        if force_init_clks:
            self.set_all_clks()
        self.download()

        # Sampling frequency and decimation at ADC. Computed once so the
        # DDS/CIC configuration below uses the same decimation factor.
        self.fs_adc = self.adcs['00']['fs']
        self.D_adc = 8
        self.fs = self.fs_adc/self.D_adc

        # DDS + CIC, running at the decimated ADC rate.
        self.ddscic = self.axis_ddscic_v1_0
        self.ddscic.configure(self.fs)

        # WXFFT.
        self.fft = self.axis_wxfft_65536_0
        self.fft.configure(self.axi_dma_coef)
        self.fft.scale(self.fft.SCALE_NORM)
        self.fft.window(wtype="hanning")

        # Accumulator.
        self.acc = self.axis_accumulator_v1_0
        self.acc.configure(self.axi_dma_0)

        # RF data converter (for configuring ADCs and DACs)
        self.rf = self.usp_rf_data_converter_0

        # Mixer.
        self.mixer = Mixer(self.usp_rf_data_converter_0)

        # Constant.
        self.iq = self.axis_constant_iq_0

    def list_rf_blocks(self, rf_config):
        """
        Lists the enabled ADCs and DACs and get the sampling frequencies.
        XRFdc_CheckBlockEnabled in xrfdc_ap.c is not accessible from the Python interface to the XRFdc driver.
        This re-implements that functionality.

        Fills self.dac_tiles / self.adc_tiles with enabled tile indices and
        self.dacs / self.adcs with entries keyed "<tile><block>", each holding
        'fs' (Sampling_Rate * 1000, presumably GHz -> MHz), 'f_fabric',
        'tile' and 'block'.
        """

        hs_adc = rf_config['C_High_Speed_ADC']=='1'

        self.dac_tiles = []
        self.adc_tiles = []
        # NOTE(review): the three lists below are collected but not used here.
        dac_fabric_freqs = []
        adc_fabric_freqs = []
        refclk_freqs = []
        self.dacs = {}
        self.adcs = {}

        for iTile in range(4):
            if rf_config['C_DAC%d_Enable'%(iTile)]!='1':
                continue
            self.dac_tiles.append(iTile)
            f_fabric = float(rf_config['C_DAC%d_Fabric_Freq'%(iTile)])
            f_refclk = float(rf_config['C_DAC%d_Refclk_Freq'%(iTile)])
            dac_fabric_freqs.append(f_fabric)
            refclk_freqs.append(f_refclk)
            fs = float(rf_config['C_DAC%d_Sampling_Rate'%(iTile)])*1000
            for iBlock in range(4):
                if rf_config['C_DAC_Slice%d%d_Enable'%(iTile,iBlock)]!='true':
                    continue
                self.dacs["%d%d"%(iTile,iBlock)] = {'fs':fs,
                                                    'f_fabric':f_fabric,
                                                    'tile':iTile,
                                                    'block':iBlock}

        for iTile in range(4):
            if rf_config['C_ADC%d_Enable'%(iTile)]!='1':
                continue
            self.adc_tiles.append(iTile)
            f_fabric = float(rf_config['C_ADC%d_Fabric_Freq'%(iTile)])
            f_refclk = float(rf_config['C_ADC%d_Refclk_Freq'%(iTile)])
            adc_fabric_freqs.append(f_fabric)
            refclk_freqs.append(f_refclk)
            fs = float(rf_config['C_ADC%d_Sampling_Rate'%(iTile)])*1000
            for iBlock in range(4):
                if hs_adc:
                    # High-speed ADC tiles: only blocks 0-1 exist, and block i
                    # maps to slice 2*i in the config.
                    if iBlock>=2 or rf_config['C_ADC_Slice%d%d_Enable'%(iTile,2*iBlock)]!='true':
                        continue
                else:
                    if rf_config['C_ADC_Slice%d%d_Enable'%(iTile,iBlock)]!='true':
                        continue
                self.adcs["%d%d"%(iTile,iBlock)] = {'fs':fs,
                                                    'f_fabric':f_fabric,
                                                    'tile':iTile,
                                                    'block':iBlock}

    def set_all_clks(self):
        """
        Resets all the board clocks.

        Supported boards are ZCU111 and ZCU216. For ZCU216 the LMX frequency is
        twice refclk_freq. Any other board value leaves the clocks untouched.
        """
        if self.cfg['board']=='ZCU111':
            print("resetting clocks:",self.cfg['refclk_freq'])
            xrfclk.set_all_ref_clks(self.cfg['refclk_freq'])
        elif self.cfg['board']=='ZCU216':
            lmk_freq = self.cfg['refclk_freq']
            lmx_freq = self.cfg['refclk_freq']*2
            print("resetting clocks:",lmk_freq, lmx_freq)
            xrfclk.set_ref_clks(lmk_freq=lmk_freq, lmx_freq=lmx_freq)
