The current set of notebooks is under constant development.
If you have previously cloned the tutorial repository, you may need to get the latest versions of the notebooks.
First check the status of your repository:
cd hls4ml-tutorial
make clean
git status
You may have some modified notebooks. For example:
# On branch csee-e6868-spring2022
# Changes not staged for commit:
# (use "git add <file>..." to update what will be committed)
# (use "git checkout -- <file>..." to discard changes in working directory)
#
# modified: part1_getting_started.ipynb
# modified: part2_advanced_config.ipynb
#
no changes added to commit (use "git add" and/or "git commit -a")
You can make a copy of those modified notebooks if you had significant changes; otherwise, the easiest thing to do is to discard those changes.
ATTENTION: You will lose your local changes!
git checkout *.ipynb
At this point, you can update your copy of the repository:
git pull
Import packages from TensorFlow, scikit-learn, and NumPy.
from tensorflow.keras.utils import to_categorical
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
import matplotlib.pyplot as plt
import numpy as np
Use a magic function to include matplotlib graphs in the notebook.
%matplotlib inline
Force deterministic behaviour with a constant seed. In TensorFlow, tf.random.set_seed sets a global random seed; you can also specify operation-level seeds. See the tf.random.set_seed documentation for more details.
seed = 0
np.random.seed(seed)
import tensorflow as tf
tf.random.set_seed(seed)
Specify where to find the Xilinx Vivado HLS executable. The path on the Columbia servers is /opt/xilinx/Vivado/2019.1/bin, but you can change it if you are running this notebook with a local installation of Vivado HLS.
import os
os.environ['PATH'] = '/opt/xilinx/Vivado/2019.1/bin:' + os.environ['PATH']
def is_tool(name):
    # Return True if the named executable can be found on the PATH
    from distutils.spawn import find_executable
    return find_executable(name) is not None

print('-----------------------------------')
if not is_tool('vivado_hls'):
    print('Xilinx Vivado HLS is NOT in the PATH')
else:
    print('Xilinx Vivado HLS is in the PATH')
print('-----------------------------------')
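Note that distutils is deprecated in recent Python versions. As a minimal alternative sketch (not part of the original tutorial), you can perform the same check with shutil.which from the standard library:

import shutil

def is_tool_alt(name):
    # shutil.which returns the full path of the executable, or None if it is not on the PATH
    return shutil.which(name) is not None

print(is_tool_alt('vivado_hls'))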
The jet tagging dataset is publicly available on OpenML.
data = fetch_openml('hls4ml_lhc_jets_hlf')
X, y = data['data'], data['target']
Let's print some information about the dataset (e.g. feature names and the dataset shape).
print('-----------------------------------')
print('Feature names')
print(data['feature_names'])
print('-----------------------------------')
print('Shape of the data and label (target) arrays')
print(X.shape, y.shape)
print('-----------------------------------')
Let's print some data and labels.
import pandas as pd
print('-----------------------------------')
print('\nFirst five samples in the data set')
display(pd.DataFrame(data=X[:5]))
print('\nFirst five labels (targets) in the data set')
display(pd.DataFrame(data=y[:5]))
print('-----------------------------------')
We can visualize the data with boxplots and notice that some features span a much wider range than others. You can also plot the outliers by setting showfliers=True.
plt.boxplot(X, showfliers=False)
_ = plt.xticks(np.arange(1, X.shape[1] + 1), data['feature_names'], rotation=30, ha="right")
As you saw above, the y target is an array of strings, e.g. ['g', 'w', ...]. We need to convert it to a one-hot encoding for the training phase.
print('-----------------------------------')
print(y[:5]) # Target labels
print('-----------------------------------')
le = LabelEncoder()
y = le.fit_transform(y) # Encode target labels with values
print(y[:5])
print('-----------------------------------')
y = to_categorical(y, 5) # Convert those values to one-hot encoding
print(y[:5])
print('-----------------------------------')
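As an optional sanity check, you can map the one-hot vectors back to the original class names with the fitted LabelEncoder:

# Recover the class names from the one-hot encoded labels
print(le.classes_)                                      # the five class names
print(le.inverse_transform(np.argmax(y[:5], axis=1)))   # first five labels as strings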
Split the dataset into training (80% of the samples) and test (20%) sets.
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print('-----------------------------------')
print('*** Shape of the split arrays ***')
print(X_train_val.shape, X_test.shape, y_train_val.shape, y_test.shape)
print('-----------------------------------')
As we have done before, let's plot the boxplots for the training-validation set only.
plt.boxplot(X_train_val, showfliers=False)
_ = plt.xticks(np.arange(1, X_train_val.shape[1] + 1), data['feature_names'], rotation=30, ha="right")
Preprocess the data X with the StandardScaler: z = (x - u) / s, where u is the mean of the training samples and s is their standard deviation. The resulting features will have a mean close to 0 and a standard deviation close to 1.
scaler = StandardScaler()
X_train_val = scaler.fit_transform(X_train_val)
X_test = scaler.transform(X_test)
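As a quick optional check, the scaled training set should now have per-feature means close to 0 and standard deviations close to 1:

# Per-feature statistics after standard scaling
print(np.mean(X_train_val, axis=0))  # approximately 0
print(np.std(X_train_val, axis=0))   # approximately 1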
Finally let's plot the boxplots for the training-validation set after the standard scaling.
plt.boxplot(X_train_val, showfliers=False)
_ = plt.xticks(np.arange(1, X_train_val.shape[1] + 1), data['feature_names'], rotation=30, ha="right")
Save the NumPy arrays to files for this notebook and the next ones (so you do not have to run the preprocessing again).
np.save('X_train_val.npy', X_train_val)
np.save('X_test.npy', X_test)
np.save('y_train_val.npy', y_train_val)
np.save('y_test.npy', y_test)
np.save('classes.npy', le.classes_)
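In the following notebooks you can reload these arrays instead of repeating the preprocessing; a minimal sketch:

# Reload the preprocessed arrays saved above
X_train_val = np.load('X_train_val.npy')
X_test = np.load('X_test.npy')
y_train_val = np.load('y_train_val.npy')
y_test = np.load('y_test.npy')
classes = np.load('classes.npy', allow_pickle=True)  # class names may be stored as an object array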
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l1
from callbacks import all_callbacks
We create a simple multi-layer perceptron (MLP) model. An MLP consists of at least three dense layers of nodes alternating with activation functions.
model = Sequential()
model.add(Dense(64, input_shape=(16,), name='fc1', kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))
model.add(Activation(activation='relu', name='relu1'))
model.add(Dense(32, name='fc2', kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))
model.add(Activation(activation='relu', name='relu2'))
model.add(Dense(32, name='fc3', kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))
model.add(Activation(activation='relu', name='relu3'))
model.add(Dense(5, name='output', kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))
model.add(Activation(activation='softmax', name='softmax'))
Plot the model. See this post on How do you visualize neural network architectures?
The question mark (?) or None in the output shapes stands for the batch size, which is unknown to the model.
tf.keras.utils.plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)
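If plot_model is not available (for example, because pydot or graphviz is not installed), model.summary() prints a textual description of the same architecture:

# Textual summary of layers, output shapes, and parameter counts
model.summary()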
If this is the first time you run the notebook, set train = True; if you've restarted the notebook kernel after training once, set train = False to load the trained model from file.
train = True
Once the model is created, you can configure it with a loss function and metrics using model.compile(), and then train it with model.fit(). The callbacks below save the training outputs into the model_1 directory.
if train:
    adam = Adam(lr=0.0001)
    model.compile(optimizer=adam, loss=['categorical_crossentropy'], metrics=['accuracy'])
    callbacks = all_callbacks(stop_patience = 1000,
                              lr_factor = 0.5,
                              lr_patience = 10,
                              lr_epsilon = 0.000001,
                              lr_cooldown = 2,
                              lr_minimum = 0.0000001,
                              outputDir = 'model_1')
    model.fit(X_train_val, y_train_val, batch_size=1024,
              epochs=30, validation_split=0.25, shuffle=True,
              callbacks = callbacks.callbacks)
else:
    from tensorflow.keras.models import load_model
    model = load_model('model_1/KERAS_check_best_model.h5')
Check the accuracy.
y_keras = model.predict(X_test)
from sklearn.metrics import accuracy_score
print('-----------------------------------')
print("Keras Accuracy: {}".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_keras, axis=1))))
print('-----------------------------------')
Accuracy may not be the best or the only metric to consider when you are dealing with a classification problem, especially with a skewed dataset.
A confusion matrix is a tool that you can use to get a better understanding of how a classifier performs.
import plotting # Import local package plotting.py
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_true=np.argmax(y_test, axis=1), y_pred=np.argmax(y_keras, axis=1))
plt.figure(figsize=(9,9))
_ = plotting.plot_confusion_matrix(cm, le.classes_)
Another tool that you can use is the ROC curve.
A ROC curve (typically) features the true positive rate (TPR) on the vertical axis and the false positive rate (FPR) on the horizontal axis. The top-left corner of the plot is the ideal point: an FPR of zero and a TPR of one. This also means that a larger area under the curve (AUC) is usually better.
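Before plotting the curves, you can also compute the per-class AUC values numerically with scikit-learn (a small sketch; the plotting helper below reports the same information graphically):

from sklearn.metrics import roc_curve, auc

# One-vs-rest ROC AUC for each of the five jet classes
for i, label in enumerate(le.classes_):
    fpr, tpr, _ = roc_curve(y_test[:, i], y_keras[:, i])
    print('AUC ({}): {:.4f}'.format(label, auc(fpr, tpr)))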
plt.figure(figsize=(9,9))
_ = plotting.plotMultiClassRoc(y_test, y_keras, le.classes_)
Now we will go through the steps to convert the model we trained into low-latency, optimized FPGA firmware with hls4ml.
hls4ml comes with a Python API, so all of the next steps, including HLS, will be run from the notebook.
# Let's import hls4ml package!
import hls4ml
hls4ml is controlled through an hls4ml configuration dictionary. In this example, we'll use the simplest variation (granularity='model'); later exercises will look at more advanced configurations.
# Generate a hls4ml configuration dictionary from the Keras model
config = hls4ml.utils.config_from_keras_model(model, granularity='model')
print('-----------------------------------')
# Show the generated configuration dictionary for hls4ml
plotting.print_dict(config)
print('-----------------------------------')
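With granularity='model', the dictionary contains a single Model entry holding the default settings, such as the fixed-point precision and the ReuseFactor. If you want to experiment, you could inspect and override these values before the conversion (the exact key names may vary between hls4ml versions, so treat the overrides below as a hypothetical sketch):

# Inspect the model-level settings generated above
print(config['Model'])
# Hypothetical overrides (commented out): check the printed dictionary for the actual key names
#config['Model']['Precision'] = 'ap_fixed<16,6>'  # default fixed-point data type
#config['Model']['ReuseFactor'] = 1               # 1 = fully parallel multipliers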
hls_model = hls4ml.converters.convert_from_keras_model(model,
                                                       hls_config=config,
                                                       output_dir='model_1/hls4ml_prj',
                                                       #part='xczu7ev-ffvc1156-2-e') # ZCU106
                                                       part='xczu3eg-sbva484-1-e') # Ultra96
                                                       #part='xc7z020clg400-1') # Pynq-Z1
                                                       #part='xc7z007sclg225-1') # MiniZed
Let's visualise the HLS model that we created. The model architecture is shown annotated with the layer shapes and data types. Please note that we are converting the trained model from a floating-point implementation to a fixed-point implementation. Post-training quantization is a conversion technique that can reduce the resource requirements and latency of the final hardware accelerator, with little degradation in model accuracy.
hls4ml.utils.plot_model(hls_model, show_shapes=True, show_precision=True, to_file=None)
Because of the quantization, we now need to check that the HLS model performance is still good. We first compile the hls_model.
%%time
hls_model.compile()
Then we use hls_model.predict to execute the FPGA firmware with bit-accurate emulation on the CPU.
%%time
y_hls = hls_model.predict(np.ascontiguousarray(X_test))
# this is an alternative to np.ascontiguousarray()
#X_test = X_test.copy(order='C')
#y_hls = hls_model.predict(X_test)
That was easy! Now let's see how the performance compares to Keras:
print('-----------------------------------')
print("Keras Accuracy: {}".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_keras, axis=1))))
print("hls4ml Accuracy: {}".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_hls, axis=1))))
print('-----------------------------------')
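You can also look at the raw output probabilities to get a feeling for how small the quantization error is (a quick optional check):

# Compare the floating-point (Keras) and fixed-point (hls4ml) outputs
print('Max abs difference: {}'.format(np.max(np.abs(y_keras - y_hls))))
print('Matching predictions: {}'.format(np.mean(np.argmax(y_keras, axis=1) == np.argmax(y_hls, axis=1))))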
# Enable logarithmic scale on TPR and FPR axes
logscale_tpr = False # Y axis
logscale_fpr = False # X axis
fig, ax = plt.subplots(figsize=(9, 9))
_ = plotting.plotMultiClassRoc(y_test, y_keras, le.classes_, logscale_tpr=logscale_tpr, logscale_fpr=logscale_fpr)
plt.gca().set_prop_cycle(None) # reset the colors
_ = plotting.plotMultiClassRoc(y_test, y_hls, le.classes_, logscale_tpr=logscale_tpr, logscale_fpr=logscale_fpr, linestyle='--')
from matplotlib.lines import Line2D
lines = [Line2D([0], [0], ls='-'),
Line2D([0], [0], ls='--')]
from matplotlib.legend import Legend
leg = Legend(ax, lines, labels=['keras', 'hls4ml'],
loc='center right', frameon=False)
_ = ax.add_artist(leg)
The AUC results for the Keras and hls4ml implementations are really close, up to the second decimal place. You can notice the difference in the ROC curves if you apply a logarithmic scale on the FPR axis (logscale_fpr=True).
Now we'll actually use Vivado HLS to synthesize the model (C-Synthesis). We can run the build using a method of our hls_model object.
After running this step, we can integrate the generated IP into a workflow to compile for a specific FPGA board. In this case, we'll just review the reports that Vivado HLS generates, checking the latency and resource usage.
%%time
hls_results = hls_model.build(csim=False)
This takes approximately 15 minutes on the Columbia servers.
While the C-Synthesis is running, we can monitor the progress by looking at the log file: open a terminal from the notebook home and execute:
tail -f model_1/hls4ml_prj/vivado_hls.log
You can print the HLS results from the synthesis run in the previous step.
print('-----------------------------------')
#print(hls_results) # Print hashmap
print("Estimated Clock Period: {} ns".format(hls_results['EstimatedClockPeriod']))
print("Best/Worst Latency: {} / {}".format(hls_results['BestLatency'], hls_results['WorstLatency']))
print("Interval Min/Max: {} / {}".format(hls_results['IntervalMin'], hls_results['IntervalMax']))
print("BRAM_18K: {} (Aval. {})".format(hls_results['BRAM_18K'], hls_results['AvailableBRAM_18K']))
print("DSP48E: {} (Aval. {})".format(hls_results['DSP48E'], hls_results['AvailableDSP48E']))
print("FF: {} (Aval. {})".format(hls_results['FF'], hls_results['AvailableFF']))
print("LUT: {} (Aval. {})".format(hls_results['LUT'], hls_results['AvailableLUT']))
print("URAM: {} (Aval. {})".format(hls_results['URAM'], hls_results['AvailableURAM']))
print('-----------------------------------')
print(hls_results)
You can also view the entire reports generated by Vivado HLS. Pay attention to the Latency and the Utilization Estimates sections.
hls4ml.report.read_vivado_report('model_1/hls4ml_prj/')
The hls_model, and in particular all of the hls4ml-generated files, are in the model_1/hls4ml_prj directory.
In this tutorial we use the Python API of hls4ml, but the tool also comes with a command-line interface.
With the current hls4ml configuration, the resource usage that HLS estimates for this design exceeds the resources available on each of the boards (ZCU106, Ultra96, Pynq-Z1, and MiniZed).
In the next notebooks, we will learn how to reduce the hardware-resource usage without affecting the model accuracy.
Here we summarize the expected latency and resource costs for each of these boards from the previous synthesis runs.
ZCU106
+-----------------+----------+--------+--------+--------+------+
| Name            | BRAM_18K | DSP48E | FF     | LUT    | URAM |
+-----------------+----------+--------+--------+--------+------+
| Total           |        4 |   3911 |  26921 |  88404 |    0 |
| Available       |      624 |   1728 | 460800 | 230400 |   96 |
| Utilization (%) |       ~0 |    226 |      5 |     38 |    0 |
+-----------------+----------+--------+--------+--------+------+
Latency (min/max): 9 / 9, Interval (min/max): 1 / 1, Pipeline Type: function

Ultra96
+-----------------+----------+--------+--------+--------+------+
| Name            | BRAM_18K | DSP48E | FF     | LUT    | URAM |
+-----------------+----------+--------+--------+--------+------+
| Total           |        4 |   3911 |  49742 |  88564 |    0 |
| Available       |      432 |    360 | 141120 |  70560 |    0 |
| Utilization (%) |       ~0 |   1086 |     35 |    125 |    0 |
+-----------------+----------+--------+--------+--------+------+
Latency (min/max): 14 / 14, Interval (min/max): 1 / 1, Pipeline Type: function

Pynq-Z1
+-----------------+----------+--------+--------+--------+------+
| Name            | BRAM_18K | DSP48E | FF     | LUT    | URAM |
+-----------------+----------+--------+--------+--------+------+
| Total           |        4 |   3911 | 270258 |  90772 |    0 |
| Available       |      280 |    220 | 106400 |  53200 |    0 |
| Utilization (%) |        1 |   1777 |    254 |    170 |    0 |
+-----------------+----------+--------+--------+--------+------+
Latency (min/max): 52 / 52, Interval (min/max): 1 / 1, Pipeline Type: function

MiniZed
+-----------------+----------+--------+--------+--------+------+
| Name            | BRAM_18K | DSP48E | FF     | LUT    | URAM |
+-----------------+----------+--------+--------+--------+------+
| Total           |        4 |   3911 | 270258 |  90772 |    0 |
| Available       |      100 |     66 |  28800 |  14400 |    0 |
| Utilization (%) |        4 |   5925 |    938 |    630 |    0 |
+-----------------+----------+--------+--------+--------+------+
Latency (min/max): 52 / 52, Interval (min/max): 1 / 1, Pipeline Type: function
Since ReuseFactor = 1, we expect each multiplication used in the inference of our neural network to use 1 DSP. Is this what we see? (Note that the Softmax layer should use 5 DSPs, or 1 per class.)
Calculate how many multiplications are performed for the inference of this network...
(We'll discuss the outcome)
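One way to start (a sketch, not the solution we will discuss) is to count the weight-matrix entries of each Dense layer, since each weight contributes one multiplication per inference:

from tensorflow.keras.layers import Dense  # already imported above

# Multiplications per Dense layer = n_inputs x n_outputs (the weight-matrix size)
# For this architecture: 16*64 + 64*32 + 32*32 + 32*5 = 4256
mults = 0
for layer in model.layers:
    if isinstance(layer, Dense):
        w = layer.get_weights()[0]  # weight matrix of shape (n_inputs, n_outputs)
        mults += w.shape[0] * w.shape[1]
print('Multiplications in the Dense layers: {}'.format(mults))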