The current set of notebooks is under constant development.
If you have previously cloned the tutorial repository, you may need to get the latest versions of the notebooks.
First check the status of your repository:
cd hls4ml-tutorial
make clean
git status
You may have some modified notebooks. For example:
# On branch csee-e6868-spring2021
# Changes not staged for commit:
# (use "git add <file>..." to update what will be committed)
# (use "git checkout -- <file>..." to discard changes in working directory)
#
# modified: part1_getting_started.ipynb
# modified: part2_advanced_config.ipynb
# modified: part2b_advanced_config.ipynb
#
no changes added to commit (use "git add" and/or "git commit -a")
You can make a copy of the modified notebooks if you made significant changes; otherwise, the easiest thing to do is to discard those changes.
ATTENTION: You will lose your local changes!
git checkout *.ipynb
At this point, you can update your copy of the repository:
git pull
It is likely that you are running this notebook in the Conda environment hls4ml-tutorial-cu.
If you have not done so yet, you should update the hls4ml package with the latest changes in the working branch.
conda activate hls4ml-tutorial-cu
pip uninstall hls4ml
pip install git+https://github.com/GiuseppeDiGuglielmo/hls4ml.git@gdg/cosmetics#egg=hls4ml[profiling]
You may need to restart the Jupyter notebook.
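After reinstalling, it can help to confirm which hls4ml version the running kernel actually sees. The snippet below is a minimal sketch that uses only the standard library (`importlib.metadata`, Python 3.8 or newer); the helper name `installed_version` is made up for this illustration and is not part of hls4ml.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Prints None if hls4ml is not installed in the active environment.
print("hls4ml version:", installed_version("hls4ml"))
```

If the printed version does not change after the reinstall, restarting the Jupyter kernel is the first thing to try.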
In this notebook, we will leverage Python programming to scan the design space of our HLS model.
As we did in the previous notebooks, let's import the libraries, call the magic functions, and set up the environment variables.
from tensorflow.keras.utils import to_categorical
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt
import hls4ml
%matplotlib inline
import os
os.environ['PATH'] = '/opt/Xilinx/Vivado/2019.2/bin:' + os.environ['PATH']
def is_tool(name):
    # shutil.which replaces distutils.spawn.find_executable (distutils was removed in Python 3.12)
    from shutil import which
    return which(name) is not None
print('-----------------------------------')
if not is_tool('vivado_hls'):
    print('Xilinx Vivado HLS is NOT in the PATH')
else:
    print('Xilinx Vivado HLS is in the PATH')
print('-----------------------------------')
----------------------------------- Xilinx Vivado HLS is in the PATH -----------------------------------
# Load dataset
X_train_val = np.load('X_train_val.npy')
X_test = np.load('X_test.npy')
y_train_val = np.load('y_train_val.npy')
y_test = np.load('y_test.npy', allow_pickle=True)
classes = np.load('classes.npy', allow_pickle=True)
# Load Keras model
from tensorflow.keras.models import load_model
model = load_model('model_1/KERAS_check_best_model.h5')
y_keras = model.predict(X_test)
Let's combine everything: Python programming and the Precision and Reuse Factor knobs.
First, we encapsulate in a function the creation of the hls4ml configuration dictionary and the creation, compilation, and build (C-synthesis) of the hls4ml model.
import time
def process_hls4ml(model, fxd_w, fxd_i, rf, fpga_part, dse_hls_results):
    # Generate a hls4ml configuration dictionary from the Keras model
    config = hls4ml.utils.config_from_keras_model(model, granularity='Model')
    # Update the knobs
    config['Model']['ReuseFactor'] = rf
    config['Model']['Precision'] = 'ap_fixed<' + str(fxd_w) + ',' + str(fxd_i) + '>'
    # Each hls4ml project / synthesis run must have its own working directory
    output_dir = 'model_1/hls4ml_prj_rf' + str(rf) + '_fxd' + str(fxd_w) + '.' + str(fxd_i)
    # Create an HLS model from the Keras model and the updated hls4ml configuration dictionary
    hls_model = hls4ml.converters.convert_from_keras_model(model,
                                                           hls_config=config,
                                                           output_dir=output_dir,
                                                           fpga_part=fpga_part)
    _ = hls_model.compile()
    # C-synthesis
    start_time = time.time()
    hls_results = hls_model.build(csim=False)
    exec_time = time.time() - start_time
    # Add extra information to the synthesis results
    hls_results["ExecutionTime"] = exec_time
    hls_results["WbitsFixedPoint"] = fxd_w
    hls_results["IbitsFixedPoint"] = fxd_i
    hls_results["FPGApart"] = fpga_part
    # Return results to the shared dictionary
    dse_hls_results[rf] = hls_results
    return hls_results
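The Precision knob is passed as a string of the form ap_fixed<W,I>, where W is the total word width and I is the number of integer bits (sign bit included), which leaves W - I fractional bits. As a rough aid when choosing values for fxd_w and fxd_i, the sketch below computes the representable range and resolution of such a signed type. The helper `ap_fixed_range` is hypothetical and not part of hls4ml.

```python
def ap_fixed_range(width, integer):
    """Return (min, max, step) representable by a signed ap_fixed<width, integer>."""
    frac_bits = width - integer          # bits to the right of the binary point
    step = 2.0 ** -frac_bits             # smallest representable increment
    lo = -(2.0 ** (integer - 1))         # most negative value (two's complement)
    hi = 2.0 ** (integer - 1) - step     # most positive value
    return lo, hi, step


lo, hi, step = ap_fixed_range(16, 6)
print(f"ap_fixed<16,6>: [{lo}, {hi}] in steps of {step}")
```

Reducing W while keeping I fixed trades resolution for resources; reducing I narrows the range and may cause overflow for large weights or activations.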
We will use the Python multiprocessing package to run multiple synthesis runs in parallel.
ATTENTION: Pay attention to the amount of memory and CPU cores/threads available on your host!
import psutil
def get_size(bytes, suffix="B"):
    """
    Scale bytes to its proper format
    e.g:
        1253656 => '1.20MB'
        1253656678 => '1.17GB'
    """
    factor = 1024
    for unit in ["", "K", "M", "G", "T", "P"]:
        if bytes < factor:
            return f"{bytes:.2f}{unit}{suffix}"
        bytes /= factor
# Memory Information
print("="*33, "Memory Information", "="*33)
# Get the memory details
svmem = psutil.virtual_memory()
print(f"Total: {get_size(svmem.total)}")
print(f"Available: {get_size(svmem.available)}")
print(f"Used: {get_size(svmem.used)}")
print(f"Percentage: {svmem.percent}%")
print("="*18, "SWAP", "="*18)
# Get the swap memory details (if exists)
swap = psutil.swap_memory()
print(f"Total: {get_size(swap.total)}")
print(f"Free: {get_size(swap.free)}")
print(f"Used: {get_size(swap.used)}")
print(f"Percentage: {swap.percent}%")
# CPU Information
print("="*38, "CPU Info", "="*38)
# Number of cores
print("Physical cores:", psutil.cpu_count(logical=False))
print("Total cores:", psutil.cpu_count(logical=True))
# CPU frequencies
cpufreq = psutil.cpu_freq()
print(f"Max Frequency: {cpufreq.max:.2f}Mhz")
print(f"Min Frequency: {cpufreq.min:.2f}Mhz")
print(f"Current Frequency: {cpufreq.current:.2f}Mhz")
# CPU usage
print("CPU Usage Per Core:")
for i, percentage in enumerate(psutil.cpu_percent(percpu=True, interval=1)):
    print(f"Core {i}: {percentage}%")
print(f"Total CPU Usage: {psutil.cpu_percent()}%")
================================= Memory Information ================================= Total: 62.85GB Available: 60.15GB Used: 1.95GB Percentage: 4.3% ================== SWAP ================== Total: 14.90GB Free: 11.01GB Used: 3.89GB Percentage: 26.1% ====================================== CPU Info ====================================== Physical cores: 8 Total cores: 64 Max Frequency: 0.00Mhz Min Frequency: 0.00Mhz Current Frequency: 1127.41Mhz CPU Usage Per Core: Core 0: 0.0% Core 1: 0.0% Core 2: 1.0% Core 3: 1.0% Core 4: 0.0% Core 5: 0.0% Core 6: 0.0% Core 7: 0.0% Core 8: 0.0% Core 9: 0.0% Core 10: 6.0% Core 11: 0.0% Core 12: 0.0% Core 13: 0.0% Core 14: 0.0% Core 15: 0.0% Core 16: 0.0% Core 17: 0.0% Core 18: 0.0% Core 19: 0.0% Core 20: 0.0% Core 21: 0.0% Core 22: 0.0% Core 23: 0.0% Core 24: 0.0% Core 25: 0.0% Core 26: 0.0% Core 27: 0.0% Core 28: 0.0% Core 29: 0.0% Core 30: 0.0% Core 31: 0.0% Core 32: 0.0% Core 33: 0.0% Core 34: 0.0% Core 35: 0.0% Core 36: 0.0% Core 37: 0.0% Core 38: 0.0% Core 39: 0.0% Core 40: 0.0% Core 41: 0.0% Core 42: 0.0% Core 43: 0.0% Core 44: 0.0% Core 45: 0.0% Core 46: 0.0% Core 47: 0.0% Core 48: 25.0% Core 49: 0.0% Core 50: 0.0% Core 51: 0.0% Core 52: 0.0% Core 53: 0.0% Core 54: 0.0% Core 55: 0.0% Core 56: 9.9% Core 57: 0.0% Core 58: 0.0% Core 59: 0.0% Core 60: 0.0% Core 61: 0.0% Core 62: 0.0% Core 63: 0.0% Total CPU Usage: 3.6%
You can also monitor the situation at runtime with the htop command in a console.
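Since every synthesis run consumes both a CPU thread and a share of memory, one way to pick a safe level of parallelism is to take the smaller of the two limits. The sketch below is only an illustration: the helper `max_parallel_jobs` and the 4 GiB per-job memory figure are assumptions, not measured values, so adjust them to what you observe on your host.

```python
import os


def max_parallel_jobs(cpu_threads, available_bytes, bytes_per_job, reserve_threads=1):
    """Estimate how many jobs fit, limited by CPU threads and by memory."""
    by_cpu = max(1, cpu_threads - reserve_threads)       # keep some threads free
    by_mem = max(1, available_bytes // bytes_per_job)    # jobs that fit in memory
    return int(min(by_cpu, by_mem))


GIB = 1024 ** 3
# Assumed figures: 60 GiB available, 4 GiB per synthesis job (hypothetical).
print(max_parallel_jobs(os.cpu_count() or 1, 60 * GIB, 4 * GIB))
```

In the notebook, `psutil.cpu_count(logical=True)` and `psutil.virtual_memory().available` could supply the first two arguments.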
Let's start with this initial configuration, but later you can change it as you prefer.
%%time
import multiprocessing
# Choose the target FPGA chip
target_fpga_part='xczu7ev-ffvc1156-2-e' # ZCU106
#target_fpga_part='xczu3eg-sbva484-1-e' # Ultra96
#target_fpga_part='xc7z020clg400-1' # Pynq-Z1
#target_fpga_part='xc7z007sclg225-1' # Minized
# DSE: Cartesian product of ReuseFactor and Precision
reuse_factor_values = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1028]
fxd_w_values = [16]
#, 14, 12, 10, 8]
fxd_i_values = [6]
#, 4]
# Append here the processes
processes = list()
# DSE results shared among processes
manager = multiprocessing.Manager()
dse_hls_results = manager.dict()
# Sweep over the Precision and ReuseFactor values and spawn synthesis processes
for fxd_w in fxd_w_values:
    for fxd_i in fxd_i_values:
        for rf in reuse_factor_values:
            t = multiprocessing.Process(target=process_hls4ml, args=(model, fxd_w, fxd_i, rf, target_fpga_part, dse_hls_results))
            processes.append(t)
            t.start()
# Wait for completion
for t in processes:
    t.join()
CPU times: user 12.1 ms, sys: 566 ms, total: 578 ms Wall time: 18min 41s
Print the DSE results to the console.
def print_hls_results(hls_results):
    print('-----------------------------------')
    #print(hls_results) # Print hashmap
    print("FPGA part: {}".format(hls_results['FPGApart']))
    print("Fixed-Point Word Size: {}".format(hls_results['WbitsFixedPoint']))
    print("Fixed-Point Integer-Part Size: {}".format(hls_results['IbitsFixedPoint']))
    print("Execution Time: {} s".format(hls_results['ExecutionTime']))
    print("Estimated Clock Period: {} ns".format(hls_results['EstimatedClockPeriod']))
    print("Best/Worst Latency: {} / {}".format(hls_results['BestLatency'], hls_results['WorstLatency']))
    print("Interval Min/Max: {} / {}".format(hls_results['IntervalMin'], hls_results['IntervalMax']))
    print("BRAM_18K: {} (Aval. {})".format(hls_results['BRAM_18K'], hls_results['AvailableBRAM_18K']))
    print("DSP48E: {} (Aval. {})".format(hls_results['DSP48E'], hls_results['AvailableDSP48E']))
    print("FF: {} (Aval. {})".format(hls_results['FF'], hls_results['AvailableFF']))
    print("LUT: {} (Aval. {})".format(hls_results['LUT'], hls_results['AvailableLUT']))
    print("URAM: {} (Aval. {})".format(hls_results['URAM'], hls_results['AvailableURAM']))
    print('-----------------------------------')
for rf in dse_hls_results.keys():
    print_hls_results(dse_hls_results[rf])
----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 892.5826549530029 s Estimated Clock Period: 5.195 ns Best/Worst Latency: 75 / 75 Interval Min/Max: 16 / 16 BRAM_18K: 2 (Aval. 624) DSP48E: 271 (Aval. 1728) FF: 53911 (Aval. 460800) LUT: 116058 (Aval. 230400) URAM: 0 (Aval. 96) ----------------------------------- ----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 905.0179946422577 s Estimated Clock Period: 5.915 ns Best/Worst Latency: 244 / 244 Interval Min/Max: 62 / 62 BRAM_18K: 2 (Aval. 624) DSP48E: 72 (Aval. 1728) FF: 60508 (Aval. 460800) LUT: 112394 (Aval. 230400) URAM: 0 (Aval. 96) ----------------------------------- ----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 912.8469898700714 s Estimated Clock Period: 5.915 ns Best/Worst Latency: 446 / 446 Interval Min/Max: 123 / 123 BRAM_18K: 2 (Aval. 624) DSP48E: 39 (Aval. 1728) FF: 59139 (Aval. 460800) LUT: 108225 (Aval. 230400) URAM: 0 (Aval. 96) ----------------------------------- ----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 913.1918203830719 s Estimated Clock Period: 5.915 ns Best/Worst Latency: 13 / 13 Interval Min/Max: 14 / 14 BRAM_18K: 4 (Aval. 624) DSP48E: 3917 (Aval. 1728) FF: 24785 (Aval. 460800) LUT: 87244 (Aval. 230400) URAM: 0 (Aval. 96) ----------------------------------- ----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 931.0812089443207 s Estimated Clock Period: 5.685 ns Best/Worst Latency: 135 / 135 Interval Min/Max: 32 / 32 BRAM_18K: 2 (Aval. 624) DSP48E: 138 (Aval. 1728) FF: 60435 (Aval. 460800) LUT: 115443 (Aval. 230400) URAM: 0 (Aval. 
96) ----------------------------------- ----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 962.3046152591705 s Estimated Clock Period: 4.882 ns Best/Worst Latency: 45 / 45 Interval Min/Max: 8 / 8 BRAM_18K: 2 (Aval. 624) DSP48E: 537 (Aval. 1728) FF: 42436 (Aval. 460800) LUT: 124997 (Aval. 230400) URAM: 0 (Aval. 96) ----------------------------------- ----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 969.2283346652985 s Estimated Clock Period: 4.766 ns Best/Worst Latency: 25 / 25 Interval Min/Max: 4 / 4 BRAM_18K: 2 (Aval. 624) DSP48E: 1069 (Aval. 1728) FF: 34329 (Aval. 460800) LUT: 130912 (Aval. 230400) URAM: 0 (Aval. 96) ----------------------------------- ----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 971.6928608417511 s Estimated Clock Period: 5.915 ns Best/Worst Latency: 13 / 13 Interval Min/Max: 14 / 14 BRAM_18K: 4 (Aval. 624) DSP48E: 3917 (Aval. 1728) FF: 24785 (Aval. 460800) LUT: 87244 (Aval. 230400) URAM: 0 (Aval. 96) ----------------------------------- ----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 1004.8475732803345 s Estimated Clock Period: 4.729 ns Best/Worst Latency: 15 / 15 Interval Min/Max: 2 / 2 BRAM_18K: 3 (Aval. 624) DSP48E: 2133 (Aval. 1728) FF: 30582 (Aval. 460800) LUT: 127780 (Aval. 230400) URAM: 0 (Aval. 96) ----------------------------------- ----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 1014.7339470386505 s Estimated Clock Period: 5.915 ns Best/Worst Latency: 877 / 877 Interval Min/Max: 246 / 246 BRAM_18K: 2 (Aval. 624) DSP48E: 22 (Aval. 1728) FF: 58701 (Aval. 
460800) LUT: 104502 (Aval. 230400) URAM: 0 (Aval. 96) ----------------------------------- ----------------------------------- FPGA part: xczu7ev-ffvc1156-2-e Fixed-Point Word Size: 16 Fixed-Point Integer-Part Size: 6 Execution Time: 1104.1488993167877 s Estimated Clock Period: 4.145 ns Best/Worst Latency: 9 / 9 Interval Min/Max: 1 / 1 BRAM_18K: 4 (Aval. 624) DSP48E: 3917 (Aval. 1728) FF: 26905 (Aval. 460800) LUT: 88335 (Aval. 230400) URAM: 0 (Aval. 96) -----------------------------------
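Several of the runs above report more DSP48E blocks than the part provides (for example 3917 against 1728 available), so not every point in the design space is implementable. As a rough illustration, the sketch below filters a results dictionary down to the designs that fit and then picks the one with the lowest latency. The helpers `feasible_designs` and `lowest_latency` are hypothetical, and the sample reproduces only the DSP48E and latency figures for reuse factors 1, 2, 4, and 8 from the output above.

```python
def feasible_designs(results, resources=("BRAM_18K", "DSP48E", "FF", "LUT", "URAM")):
    """Return the keys whose estimated usage fits within the available resources."""
    fits = []
    for key, r in results.items():
        # hls4ml reports the estimates as strings, hence the int() conversions
        if all(int(r[name]) <= int(r["Available" + name]) for name in resources):
            fits.append(key)
    return fits


def lowest_latency(results, keys):
    """Among `keys`, return the one with the smallest worst-case latency."""
    return min(keys, key=lambda k: int(results[k]["WorstLatency"]))


# Excerpt of the estimates printed above (DSP48E only).
sample = {
    1: {"DSP48E": "3917", "AvailableDSP48E": "1728", "WorstLatency": "9"},
    2: {"DSP48E": "2133", "AvailableDSP48E": "1728", "WorstLatency": "15"},
    4: {"DSP48E": "1069", "AvailableDSP48E": "1728", "WorstLatency": "25"},
    8: {"DSP48E": "537", "AvailableDSP48E": "1728", "WorstLatency": "45"},
}
fits = feasible_designs(sample, resources=("DSP48E",))
print(fits, lowest_latency(sample, fits))
```

On the full dictionary, calling `feasible_designs(dse_hls_results)` with the default resource list would check all five resource types at once.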
Define a 2D plotting function.
import matplotlib.pyplot as plt
import collections
import matplotlib.ticker as plticker
import numpy as np
# available_resources: True when plotting a hardware resource (e.g. DSPs); False for metrics such as runtime or latency
# show_values: show values as labels on the data points
# show_precision: show the fixed-point precision as labels on the data points
# xscale_log: log2 scale on the x axis
# yscale_log: log2 scale on the y axis
def plot_hls_results(dse_hls_results, resource_label, unit='HLS Estimates', available_resources=True, show_values=False, show_precision=True, xscale_log=False, yscale_log=False):
    # Reorder the results according to the ReuseFactor (dictionary key)
    dse_hls_results_ordered = collections.OrderedDict(sorted(dse_hls_results.items(), key=lambda x:x[0], reverse=False))
    # Fonts
    SMALL_SIZE = 8
    MEDIUM_SIZE = 10
    BIGGER_SIZE = 14
    plt.rc('font', size=MEDIUM_SIZE) # controls default text sizes
    plt.rc('axes', titlesize=BIGGER_SIZE) # fontsize of the axes title
    plt.rc('axes', labelsize=BIGGER_SIZE) # fontsize of the x and y labels
    plt.rc('xtick', labelsize=BIGGER_SIZE) # fontsize of the tick labels
    plt.rc('ytick', labelsize=BIGGER_SIZE) # fontsize of the tick labels
    plt.rc('legend', fontsize=BIGGER_SIZE) # legend fontsize
    plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title
    # Extract values of the x and y axes
    reuse_factor_x_axis = np.array(list(dse_hls_results_ordered.keys()))
    resource_usage_y_axis = []
    fxd_w_precision = []
    fxd_i_precision = []
    resource_available_y_axis=[]
    fpga_parts=[]
    for i, d in dse_hls_results_ordered.items():
        resource_usage_y_axis.append(int(dse_hls_results_ordered[i][resource_label]))
    if (available_resources):
        for i, d in dse_hls_results_ordered.items():
            resource_available_y_axis.append(int(dse_hls_results_ordered[i]['Available' + resource_label]))
    for i, d in dse_hls_results_ordered.items():
        fxd_w_precision.append(int(dse_hls_results_ordered[i]['WbitsFixedPoint']))
    for i, d in dse_hls_results_ordered.items():
        fxd_i_precision.append(int(dse_hls_results_ordered[i]['IbitsFixedPoint']))
    for i, d in dse_hls_results_ordered.items():
        fpga_parts.append(dse_hls_results_ordered[i]['FPGApart'])
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlabel('Reuse Factor')
    ax.set_ylabel(resource_label + ' (' + unit + ')')
    ax.plot(reuse_factor_x_axis, resource_usage_y_axis, linestyle='--', marker='o')
    if (available_resources):
        ax.axhline(y=resource_available_y_axis[0], color='r', linestyle='-', label = resource_label + ' (Available)')
        ax.legend(loc='upper right')
    ax.set_xticks(reuse_factor_x_axis)
    ax.grid(True)
    plt.figtext(0.90, 0.90, 'hls4ml', fontweight='bold', wrap=True, horizontalalignment='right', fontsize=14)
    plt.figtext(0.4, 0.90, fpga_parts[0], wrap=True, horizontalalignment='right', fontsize=14)
    if xscale_log: ax.set_xscale('log', base=2)
    if yscale_log: ax.set_yscale('log', base=2)
    if show_values:
        for x,y in zip(reuse_factor_x_axis,resource_usage_y_axis):
            label = "{:d}".format(y)
            plt.annotate(label, # this is the text
                         (x,y), # this is the point to label
                         textcoords="offset points", # how to position the text
                         xytext=(0,10), # distance from text to points (x,y)
                         ha='center') # horizontal alignment can be left, right or center
    if show_precision:
        for x,y,w,i in zip(reuse_factor_x_axis,resource_usage_y_axis,fxd_w_precision,fxd_i_precision):
            label = '<{},{}>'.format(w,i)
            plt.annotate(label, # this is the text
                         (x,y), # this is the point to label
                         textcoords="offset points", # how to position the text
                         xytext=(0,10), # distance from text to points (x,y)
                         ha='center') # horizontal alignment can be left, right or center
Plot some of the resources from the previous DSE.
plot_hls_results(dse_hls_results, 'ExecutionTime', unit='s', available_resources=False, xscale_log=True)
plot_hls_results(dse_hls_results, 'DSP48E', xscale_log=True)
plot_hls_results(dse_hls_results, 'BRAM_18K', xscale_log=True)
plot_hls_results(dse_hls_results, 'FF', xscale_log=True)
plot_hls_results(dse_hls_results, 'LUT', xscale_log=True)
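Save the DSE results to a file for future reference.
import json
# Copy the shared (manager) dictionary into a plain dict and serialize it.
# Writing with json.dump avoids rebinding the name `json` to a string.
with open("dse_hls_results.json", "w") as f:
    json.dump(dse_hls_results.copy(), f)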
Load DSE results from file.
import json
with open('dse_hls_results.json') as json_file:
    data = json.load(json_file)
# Print the loaded results
print(data)
{'16': {'EstimatedClockPeriod': '5.195', 'BestLatency': '75', 'WorstLatency': '75', 'IntervalMin': '16', 'IntervalMax': '16', 'BRAM_18K': '2', 'DSP48E': '271', 'FF': '53911', 'LUT': '116058', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 892.5826549530029, 'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}, '64': {'EstimatedClockPeriod': '5.915', 'BestLatency': '244', 'WorstLatency': '244', 'IntervalMin': '62', 'IntervalMax': '62', 'BRAM_18K': '2', 'DSP48E': '72', 'FF': '60508', 'LUT': '112394', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 905.0179946422577, 'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}, '128': {'EstimatedClockPeriod': '5.915', 'BestLatency': '446', 'WorstLatency': '446', 'IntervalMin': '123', 'IntervalMax': '123', 'BRAM_18K': '2', 'DSP48E': '39', 'FF': '59139', 'LUT': '108225', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 912.8469898700714, 'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}, '512': {'EstimatedClockPeriod': '5.915', 'BestLatency': '13', 'WorstLatency': '13', 'IntervalMin': '14', 'IntervalMax': '14', 'BRAM_18K': '4', 'DSP48E': '3917', 'FF': '24785', 'LUT': '87244', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 913.1918203830719, 'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}, '32': {'EstimatedClockPeriod': '5.685', 'BestLatency': '135', 'WorstLatency': '135', 'IntervalMin': '32', 'IntervalMax': '32', 'BRAM_18K': '2', 'DSP48E': '138', 'FF': '60435', 'LUT': 
'115443', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 931.0812089443207, 'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}, '8': {'EstimatedClockPeriod': '4.882', 'BestLatency': '45', 'WorstLatency': '45', 'IntervalMin': '8', 'IntervalMax': '8', 'BRAM_18K': '2', 'DSP48E': '537', 'FF': '42436', 'LUT': '124997', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 962.3046152591705, 'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}, '4': {'EstimatedClockPeriod': '4.766', 'BestLatency': '25', 'WorstLatency': '25', 'IntervalMin': '4', 'IntervalMax': '4', 'BRAM_18K': '2', 'DSP48E': '1069', 'FF': '34329', 'LUT': '130912', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 969.2283346652985, 'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}, '1028': {'EstimatedClockPeriod': '5.915', 'BestLatency': '13', 'WorstLatency': '13', 'IntervalMin': '14', 'IntervalMax': '14', 'BRAM_18K': '4', 'DSP48E': '3917', 'FF': '24785', 'LUT': '87244', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 971.6928608417511, 'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}, '2': {'EstimatedClockPeriod': '4.729', 'BestLatency': '15', 'WorstLatency': '15', 'IntervalMin': '2', 'IntervalMax': '2', 'BRAM_18K': '3', 'DSP48E': '2133', 'FF': '30582', 'LUT': '127780', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 1004.8475732803345, 
'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}, '256': {'EstimatedClockPeriod': '5.915', 'BestLatency': '877', 'WorstLatency': '877', 'IntervalMin': '246', 'IntervalMax': '246', 'BRAM_18K': '2', 'DSP48E': '22', 'FF': '58701', 'LUT': '104502', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 1014.7339470386505, 'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}, '1': {'EstimatedClockPeriod': '4.145', 'BestLatency': '9', 'WorstLatency': '9', 'IntervalMin': '1', 'IntervalMax': '1', 'BRAM_18K': '4', 'DSP48E': '3917', 'FF': '26905', 'LUT': '88335', 'URAM': '0', 'AvailableBRAM_18K': '624', 'AvailableDSP48E': '1728', 'AvailableFF': '460800', 'AvailableLUT': '230400', 'AvailableURAM': '96', 'ExecutionTime': 1104.1488993167877, 'WbitsFixedPoint': 16, 'IbitsFixedPoint': 6, 'FPGApart': 'xczu7ev-ffvc1156-2-e'}}
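Note that the reuse factors come back from the file as strings ('16', '64', ...), because JSON object keys are always strings. Sorting string keys is lexicographic, which would scramble the x axis if the loaded data were plotted directly. The sketch below shows one way to restore integer keys; the helper `restore_reuse_factor_keys` is hypothetical, and the small dictionary reuses three DSP48E values from the output above.

```python
import json


def restore_reuse_factor_keys(loaded):
    """Convert the string keys produced by JSON back to integer reuse factors."""
    return {int(rf): results for rf, results in loaded.items()}


original = {16: {"DSP48E": "271"}, 4: {"DSP48E": "1069"}, 128: {"DSP48E": "39"}}
round_trip = json.loads(json.dumps(original))

print(sorted(round_trip))                             # string keys, lexicographic order
print(sorted(restore_reuse_factor_keys(round_trip)))  # integer keys, numeric order
```

Applying the same conversion to `data` before plotting should keep the reuse factors in numeric order.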
Pool Python package to allocate a number of jobs no greater than the number of CPU cores/threads, similarly to the -j option in GNU Parallel.
Precision and ReuseFactor.
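The pool idea can be sketched as follows. This is a minimal illustration, not the tutorial's solution: it uses `multiprocessing.pool.ThreadPool`, the thread-backed class that shares the interface of `multiprocessing.Pool`, so that the example runs anywhere; for real synthesis runs a process-based `Pool` may be the better fit. The worker `fake_synthesis` is a hypothetical stand-in for `process_hls4ml`, and `run_dse` and `max_jobs` are names invented here. Results are keyed by the full (fxd_w, fxd_i, rf) tuple so that sweeping several precisions does not overwrite entries that share a reuse factor.

```python
import itertools
from multiprocessing.pool import ThreadPool


def fake_synthesis(fxd_w, fxd_i, rf):
    """Hypothetical stand-in for process_hls4ml; returns a small result dict."""
    return {"ReuseFactor": rf, "WbitsFixedPoint": fxd_w, "IbitsFixedPoint": fxd_i}


def run_dse(worker, fxd_w_values, fxd_i_values, rf_values, max_jobs=4):
    """Run worker over the Cartesian product of the knobs, at most max_jobs at a time."""
    jobs = list(itertools.product(fxd_w_values, fxd_i_values, rf_values))
    with ThreadPool(processes=max_jobs) as pool:
        results = pool.starmap(worker, jobs)   # blocks until every job has finished
    # starmap preserves job order, so jobs and results line up one to one
    return dict(zip(jobs, results))


dse = run_dse(fake_synthesis, [16, 8], [6], [1, 2, 4], max_jobs=2)
print(len(dse), dse[(8, 6, 4)])
```

Unlike spawning one process per configuration, the pool keeps at most `max_jobs` runs active and queues the rest, which is the behavior the -j option gives in GNU Parallel.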