Introduction to BNNs with Larq

This tutorial demonstrates how to train a simple binarized Convolutional Neural Network (CNN) to classify MNIST digits. The network reaches approximately 98% accuracy on the MNIST test set. Because the tutorial uses Larq and the Keras Sequential API, creating and training the model takes only a few lines of code.

import tensorflow as tf
import larq as lq

Download and prepare the MNIST dataset

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))

# Normalize pixel values to be between -1 and 1
train_images, test_images = train_images / 127.5 - 1, test_images / 127.5 - 1
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 1s 0us/step
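As a quick sanity check on the normalization above, the following minimal sketch applies the same arithmetic to individual pixel values. The helper name `normalize_pixel` is hypothetical and only used for illustration; the tutorial itself applies the expression directly to the whole image arrays.

```python
def normalize_pixel(value):
    """Map an 8-bit pixel intensity in [0, 255] to the range [-1, 1]."""
    return value / 127.5 - 1


# The darkest pixel maps to -1, the midpoint to 0, the brightest to +1.
lowest = normalize_pixel(0)       # -1.0
middle = normalize_pixel(127.5)   # 0.0
highest = normalize_pixel(255)    # 1.0
print(lowest, middle, highest)
```

Centering the inputs around zero means the sign of an activation carries information once later layers binarize it.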

Create the model

The following will create a simple binarized CNN.

The quantization function $$ q(x) = \begin{cases} -1 & x < 0 \\ 1 & x \geq 0 \end{cases} $$ is used in the forward pass to binarize the activations and the latent full-precision weights. The gradient of this function is zero almost everywhere, which would prevent the model from learning with standard backpropagation.

To make the model trainable, the gradient is instead estimated using the Straight-Through Estimator (STE), which essentially replaces the binarization with a clipped identity on the backward pass: $$ \frac{\partial q(x)}{\partial x} = \begin{cases} 1 & \left|x\right| \leq 1 \\ 0 & \left|x\right| > 1 \end{cases} $$
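The two formulas can be illustrated with a few lines of plain Python. This is a minimal sketch of the math only, not Larq's implementation; the function names `sign` and `ste_grad` are hypothetical.

```python
def sign(x):
    """Forward pass: q(x) = -1 for x < 0 and +1 for x >= 0."""
    return -1.0 if x < 0 else 1.0


def ste_grad(x, upstream=1.0):
    """Backward pass: let the upstream gradient through where |x| <= 1,
    and block it (return 0) where |x| > 1."""
    return upstream if abs(x) <= 1 else 0.0


xs = [-1.5, -0.3, 0.0, 0.7, 2.0]
forward = [sign(x) for x in xs]       # [-1.0, -1.0, 1.0, 1.0, 1.0]
backward = [ste_grad(x) for x in xs]  # [0.0, 1.0, 1.0, 1.0, 0.0]
print(forward)
print(backward)
```

Note that the forward values depend only on the sign of the input, while the estimated gradient depends only on whether the input lies inside [-1, 1].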

In Larq, this is done by setting input_quantizer="ste_sign" and kernel_quantizer="ste_sign". Additionally, the latent full-precision weights are clipped to the range [-1, 1] using kernel_constraint="weight_clip".
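To see what the weight clipping does, here is a minimal sketch of a single update of a few latent weights. It assumes a plain gradient descent step for simplicity (the tutorial itself trains with Adam), and the names `weight_clip` and `sgd_step` are hypothetical helpers, not Larq functions.

```python
def weight_clip(w, clip_value=1.0):
    """Clamp a latent weight to the range [-clip_value, clip_value]."""
    return max(-clip_value, min(clip_value, w))


def sgd_step(latent, grads, lr):
    """Apply one gradient descent step, then clip every latent weight."""
    return [weight_clip(w - lr * g) for w, g in zip(latent, grads)]


latent = [0.95, -0.2, -0.99]
grads = [-1.0, 0.5, 1.0]
updated = sgd_step(latent, grads, lr=0.1)
print(updated)  # the first and last weights are clamped to +1 and -1
```

Clipping keeps the latent weights inside the region |x| <= 1, which is where the STE gradient is nonzero.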

# All quantized layers except the first will use the same options
kwargs = dict(input_quantizer="ste_sign",
              kernel_quantizer="ste_sign",
              kernel_constraint="weight_clip")

model = tf.keras.models.Sequential()

# In the first layer we only quantize the weights and not the input
model.add(lq.layers.QuantConv2D(32, (3, 3),
                                kernel_quantizer="ste_sign",
                                kernel_constraint="weight_clip",
                                use_bias=False,
                                input_shape=(28, 28, 1)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.BatchNormalization(scale=False))

model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.BatchNormalization(scale=False))

model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(tf.keras.layers.Flatten())

model.add(lq.layers.QuantDense(64, use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(lq.layers.QuantDense(10, use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(tf.keras.layers.Activation("softmax"))

Almost all parameters in the network are binarized, so each is either -1 or 1 (93,088 of the 93,556 parameters in the summary below). This would make the network extremely fast if deployed on custom BNN hardware.
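One common reason binarized arithmetic can be fast is that a dot product of two vectors with entries in {-1, 1} can be computed with bitwise operations and a bit count instead of multiplications. The sketch below illustrates that general idea in plain Python. It is an assumption-laden illustration, not a description of how Larq or any particular hardware implements it, and the helpers `pack` and `binary_dot` are hypothetical.

```python
def pack(values):
    """Pack a list of +1/-1 values into an int, one bit each (+1 -> 1, -1 -> 0)."""
    bits = 0
    for i, v in enumerate(values):
        if v == 1:
            bits |= 1 << i
    return bits


def binary_dot(a_bits, b_bits, n):
    """Dot product of two packed +1/-1 vectors of length n.

    Matching positions contribute +1 and mismatching positions -1,
    so the result is n - 2 * (number of mismatches).
    """
    mismatches = bin(a_bits ^ b_bits).count("1")
    return n - 2 * mismatches


a = [1, -1, -1, 1, 1, -1, 1, 1]
b = [1, 1, -1, -1, 1, -1, -1, 1]

reference = sum(x * y for x, y in zip(a, b))   # ordinary multiply-accumulate
fast = binary_dot(pack(a), pack(b), len(a))    # XOR plus bit count
print(reference, fast)  # both are 2
```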

Here is the complete architecture of our model:

lq.models.summary(model)
+sequential stats------------------------------------------------------------------------------------------+
| Layer                  Input prec.           Outputs  # 1-bit  # 32-bit  Memory  1-bit MACs  32-bit MACs |
|                              (bit)                        x 1       x 1    (kB)                          |
+----------------------------------------------------------------------------------------------------------+
| quant_conv2d                     -  (-1, 26, 26, 32)      288         0    0.04           0       194688 |
| max_pooling2d                    -  (-1, 13, 13, 32)        0         0       0           0            0 |
| batch_normalization              -  (-1, 13, 13, 32)        0        64    0.25           0            0 |
| quant_conv2d_1                   1  (-1, 11, 11, 64)    18432         0    2.25     2230272            0 |
| max_pooling2d_1                  -    (-1, 5, 5, 64)        0         0       0           0            0 |
| batch_normalization_1            -    (-1, 5, 5, 64)        0       128    0.50           0            0 |
| quant_conv2d_2                   1    (-1, 3, 3, 64)    36864         0    4.50      331776            0 |
| batch_normalization_2            -    (-1, 3, 3, 64)        0       128    0.50           0            0 |
| flatten                          -         (-1, 576)        0         0       0           0            0 |
| quant_dense                      1          (-1, 64)    36864         0    4.50       36864            0 |
| batch_normalization_3            -          (-1, 64)        0       128    0.50           0            0 |
| quant_dense_1                    1          (-1, 10)      640         0    0.08         640            0 |
| batch_normalization_4            -          (-1, 10)        0        20    0.08           0            0 |
| activation                       -          (-1, 10)        0         0       0           ?            ? |
+----------------------------------------------------------------------------------------------------------+
| Total                                                   93088       468   13.19     2599552       194688 |
+----------------------------------------------------------------------------------------------------------+
+sequential summary----------------------------+
| Total params                      93.6 k     |
| Trainable params                  93.1 k     |
| Non-trainable params              468        |
| Model size                        13.19 KiB  |
| Model size (8-bit FP weights)     11.82 KiB  |
| Float-32 Equivalent               365.45 KiB |
| Compression Ratio of Memory       0.04       |
| Number of MACs                    2.79 M     |
| Ratio of MACs that are binarized  0.9303     |
+----------------------------------------------+
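The numbers in the summary can be reproduced by hand from the layer shapes. The following sketch recomputes the parameter counts, memory figures, and MAC counts; the 32-bit parameter counts are copied from the batch normalization rows of the table rather than derived, and all variable names are illustrative.

```python
# (kernel_h, kernel_w, in_channels, out_channels, out_h, out_w) per conv layer
conv_layers = [
    (3, 3, 1, 32, 26, 26),   # quant_conv2d: binary weights, 32-bit inputs
    (3, 3, 32, 64, 11, 11),  # quant_conv2d_1
    (3, 3, 64, 64, 3, 3),    # quant_conv2d_2
]
# (inputs, outputs) per dense layer
dense_layers = [(576, 64), (64, 10)]

conv_params = [kh * kw * cin * cout for kh, kw, cin, cout, _, _ in conv_layers]
conv_macs = [p * oh * ow for p, (_, _, _, _, oh, ow) in zip(conv_params, conv_layers)]
dense_params = [n_in * n_out for n_in, n_out in dense_layers]  # one MAC per weight

binary_params = sum(conv_params) + sum(dense_params)
float_params = sum([64, 128, 128, 128, 20])  # copied from the table

# 1 bit per binary parameter, 4 bytes per 32-bit parameter, 1024 bytes per KiB
model_kib = (binary_params / 8 + float_params * 4) / 1024
float32_kib = (binary_params + float_params) * 4 / 1024

# Only the first convolution sees non-binarized inputs
float_macs = conv_macs[0]
binary_macs = sum(conv_macs[1:]) + sum(dense_params)
binarized_ratio = binary_macs / (binary_macs + float_macs)

print(binary_params, float_params)                   # 93088 468
print(round(model_kib, 2), round(float32_kib, 2))    # 13.19 365.45
print(binary_macs, float_macs)                       # 2599552 194688
print(round(binarized_ratio, 4))                     # 0.9303
```

The compression ratio reported in the summary is the quotient of the two memory figures, 13.19 / 365.45, which rounds to 0.04.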

Compile and train the model

Note: This may take a few minutes depending on your system.

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images, train_labels, batch_size=64, epochs=6)

test_loss, test_acc = model.evaluate(test_images, test_labels)
Train on 60000 samples
Epoch 1/6
60000/60000 [==============================] - 16s 270us/sample - loss: 0.6429 - accuracy: 0.9095
Epoch 2/6
60000/60000 [==============================] - 15s 255us/sample - loss: 0.4732 - accuracy: 0.9622
Epoch 3/6
60000/60000 [==============================] - 15s 251us/sample - loss: 0.4483 - accuracy: 0.9691
Epoch 4/6
60000/60000 [==============================] - 15s 256us/sample - loss: 0.4356 - accuracy: 0.9737
Epoch 5/6
60000/60000 [==============================] - 15s 253us/sample - loss: 0.4304 - accuracy: 0.9754
Epoch 6/6
60000/60000 [==============================] - 15s 257us/sample - loss: 0.4276 - accuracy: 0.9767
10000/10000 [==============================] - 1s 135us/sample - loss: 0.3899 - accuracy: 0.9800

Evaluate the model

print(f"Test accuracy {test_acc * 100:.2f} %")
Test accuracy 98.00 %

As you can see, our simple binarized CNN has achieved a test accuracy of 98 %. Not bad for a few lines of code!