#!/usr/bin/env python3  

import math
import matplotlib.pyplot as plot
from dataclasses import dataclass
import socket

# FAST-PS-1K5 waveform generator limits.
WAVE_UP_FREQ = 100_000.0     # base waveform update rate [Hz] (100 kHz)
WAVE_MAX_POINTS = 500_000.0  # maximum number of waveform points

# When True, show a preview plot of the generated waveform.
PRINT_WAVEFORM = True

# Network endpoint of the power supply (TCP).
IP = "192.168.1.188"
PORT = 10001

@dataclass
class Ramp:
    """One linear segment of the current waveform.

    The segment goes from ``rangeL`` to ``rangeH`` (A) with slope ``slewRateI``
    (A/s). ``length_s`` is the segment duration in seconds; it defaults to 0.0
    and is filled in later as (rangeH - rangeL) / slewRateI. A falling segment
    (rangeH < rangeL) needs a negative ``slewRateI``.

    Raises:
        ValueError: if ``slewRateI`` is zero, or if its sign does not match the
            direction from ``rangeL`` to ``rangeH``. Such a segment would have a
            negative duration and be silently dropped from the sampled waveform.
    """

    rangeL: float          # start current [A]
    rangeH: float          # end current [A]
    slewRateI: float       # slope [A/s]; negative for a falling segment
    length_s: float = 0.0  # segment duration [s]

    def __post_init__(self):
        if self.slewRateI == 0:
            raise ValueError("slewRateI must be non-zero")
        if (self.rangeH - self.rangeL) / self.slewRateI < 0:
            raise ValueError(
                "slewRateI sign does not match the direction from rangeL to rangeH"
            )
    
# Ramp profile, one entry per segment (currents in A, slew rates in A/s):
#   0 A  -> 50 A    at 1 A/s
#   50 A -> 70 A    at 0.5 A/s
#   70 A -> 75.3 A  at 0.25 A/s
l = [
    Ramp(rangeL=0.0, rangeH=50.0, slewRateI=1.000),
    Ramp(rangeL=50.0, rangeH=70.0, slewRateI=0.500),
    Ramp(rangeL=70.0, rangeH=75.3, slewRateI=0.250),
]

# Each segment lasts (current span / slew rate) seconds; accumulate the total
# duration of the whole waveform in time_s.
time_s = 0
for ramp in l:
    span = ramp.rangeH - ramp.rangeL
    ramp.length_s = span / ramp.slewRateI
    time_s += ramp.length_s
print(f"Total waveform time is: {time_s} seconds")



# The FAST-PS-1K5 waveform update rate is 100 kHz and it accepts at most
# 500,000 points. A prescaler P lowers the update rate to 100 kHz / P, so the
# smallest P that fits the whole waveform is ceil(time_s * 100 kHz / 500,000).
# Clamp P to at least 1 (no prescaling): a zero-length profile would
# otherwise give P = 0 and a ZeroDivisionError on the next line.
prescaler = max(1, math.ceil((time_s * WAVE_UP_FREQ) / WAVE_MAX_POINTS))
updateFreq = WAVE_UP_FREQ / prescaler
print("The required prescaler (P) for the selected waveform is:", prescaler)
print("The update rate of the waveform is:", updateFreq, "Hz")

# Sample every segment at updateFreq. A segment contributes
# int(length_s * updateFreq) points starting at its rangeL. Because of the
# truncation, a segment's exact rangeH value is never emitted: the next
# segment starts at its own rangeL, and the last sample does not quite reach
# the final rangeH.
y = []
dx = 1.0 / updateFreq  # time between consecutive samples [s]

for ramp in l:
    sectionPoints = int(ramp.length_s * updateFreq)
    dy = ramp.slewRateI / updateFreq  # current increment per sample [A]
    y.extend(ramp.rangeL + j * dy for j in range(sectionPoints))

# The prescaler is chosen so this limit should never be exceeded; fail here
# rather than upload a waveform the supply cannot hold.
if len(y) > WAVE_MAX_POINTS:
    raise ValueError(
        f"waveform has {len(y)} points, more than the "
        f"{WAVE_MAX_POINTS:.0f} the device accepts"
    )

# Time axis for the plot: derive each timestamp from its sample index instead
# of summing dx repeatedly, so rounding error does not build up over hundreds
# of thousands of samples.
x = [n * dx for n in range(len(y))]
# ======================================
# Plot
# ======================================
# Optional preview of the generated samples (current vs. time), enabled by
# PRINT_WAVEFORM.
if PRINT_WAVEFORM:
    plot.plot(x, y, label="Ramp", color="C9")
    plot.title("Ramp")

    # Axis decoration: labels, grid and a black reference line at y = 0.
    plot.xlabel("Time")
    plot.ylabel("Amplitude")
    plot.grid(True, which="both")
    plot.axhline(y=0, color="k")

    plot.legend()
    # NOTE(review): in a non-interactive matplotlib session show() blocks, so
    # the upload that follows only starts once the window is closed; confirm
    # this pause is intended.
    plot.show()

# ======================================
# Create command
# ======================================
# Build the WAVE:POINTS command: every sample value formatted with str(),
# separated by ':' and terminated by CR+LF. str.join runs in time linear in
# the output size. Repeated += on a multi-megabyte string is not guaranteed
# to, and the index-based loop is unnecessary.
wavestring = "WAVE:POINTS:" + ":".join(map(str, y)) + "\r\n"


def _query(sock, description, command):
    """Print *description*, send *command* and print and return the reply.

    *command* must already end with its CR+LF terminator. Only a single
    recv() of at most 2048 bytes is performed, so a longer or fragmented
    reply is returned truncated.

    Raises:
        ConnectionError: if recv() returns no data, i.e. the power supply
            closed the connection.
    """
    print(description)
    sock.sendall(command.encode())
    raw = sock.recv(2048)
    if not raw:
        raise ConnectionError("connection closed by the power supply")
    reply = raw.decode()
    print(reply)
    return reply


# Open the TCP connection. The with-block guarantees the socket is closed
# even if connect() or any of the exchanges below raises.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    s.connect((IP, PORT))

    # Check version
    data = _query(s, "Check version", "VER:?\r\n")

    # Set prescaler
    data = _query(s, "Set prescaler",
                  "WAVE:PRESCALER:" + str(prescaler) + "\r\n")

    # Set number of periods to 1
    data = _query(s, "Set number of periods to 1", "WAVE:N_PERIODS:1\r\n")

    # Send points (the whole waveform travels as one command line)
    data = _query(s, "Send Waveform data...", wavestring)

# Now it is possible to execute the waveform with the cmd WAVE:START