Retrain a classification model on-device with backpropagation

If you're familiar with backpropagation, then you know it's used to train a neural network by updating the weights in every layer after you determine the model's current loss. However, you can also use backpropagation to update weights for only the last layer, which allows you to retrain your model very quickly. And it's this technique that our SoftmaxRegression Python API provides so you can accelerate transfer-learning with the Edge TPU.

Overview

Ordinarily, because a TensorFlow Lite model must be compiled to run on the Edge TPU, the weights inside the neural network are locked and cannot be modified by training on the device. However, if you remove the last layer from the model before compiling it (thus creating an embedding extractor model that outputs an image embedding), then you can implement the last layer on the device in a way that allows for retraining of that layer. So that's what we do to enable transfer-learning with SoftmaxRegression.

The SoftmaxRegression class is an on-device implementation of the fully-connected layer with softmax activation that performs final classification. And with its APIs, you can train the weights of the layer using stochastic gradient descent (SGD), immediately run inferences using the new weights, and save the retrained model as a new .tflite model file.
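
For example, constructing the layer is just a matter of telling it the embedding size and the number of classes. This is a minimal sketch; the import path and the feature_dim/num_classes parameter names are assumptions based on this description, so confirm them against the edgetpu API reference:

    from edgetpu.learn.backprop.softmax_regression import SoftmaxRegression

    # A softmax classification head for 1024-dimension embeddings
    # (MobileNet v1) and a 5-class dataset. Parameter names are assumed;
    # check your edgetpu API reference.
    classifier = SoftmaxRegression(feature_dim=1024, num_classes=5)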

Of course, this strategy has both benefits and drawbacks:

Benefits:

  • Transfer-learning happens on-device, at near-realtime speed.
  • You don't need to recompile the model.

Drawbacks:

  • The fully-connected layer with softmax activation executes on the CPU, not the Edge TPU. However, this layer represents a very small portion of the overall network, so impact on the inference speed is minimal.
  • It's compatible with image classification models only—officially, only MobileNet and Inception.
Note: We offer an alternative on-device transfer-learning API called ImprintingEngine, which uses weight imprinting instead of backpropagation to update the weights of the last layer. For a comparison of these two techniques, read Transfer-learning on-device.

API summary

The SoftmaxRegression class represents only the softmax layer for a classification model. Unlike the ImprintingEngine, it does not encapsulate the entire model graph. So in order to perform training, you must run training data through the base model (the embedding extractor) and then feed the results to this softmax layer.

The basic procedure to train using backpropagation with the SoftmaxRegression API is as follows:

  1. Create an instance of BasicEngine with your embedding extractor model.

  2. For each training image, call run_inference() and collect the returned image embeddings.

  3. Create an instance of SoftmaxRegression and call train_with_sgd(), passing it all the image embeddings.

Once training completes, you can perform inferences in a similar fashion: pass the input data to BasicEngine.run_inference() and then pass the resulting image embedding to the softmax layer with SoftmaxRegression.run_inference(). Alternatively, you can call save_as_tflite_model() to save the whole combined graph as a .tflite file and then perform inferences using ClassificationEngine.
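
To make that data flow concrete, here's a minimal sketch of the whole procedure. It assumes a 1024-dimension MobileNet v1 embedding extractor and a hypothetical helper, get_training_images(), that yields preprocessed image arrays with integer labels; the exact method signatures and the dictionary keys expected by train_with_sgd() may differ in your version of the edgetpu library, so treat this as an outline rather than copy-paste code.

    import numpy as np
    from edgetpu.basic.basic_engine import BasicEngine
    from edgetpu.learn.backprop.softmax_regression import SoftmaxRegression

    # 1. Load the embedding extractor (already compiled for the Edge TPU).
    extractor = BasicEngine('mobilenet_v1_embedding_extractor_edgetpu.tflite')

    # 2. Run every training image through the extractor and collect the
    #    image embeddings. get_training_images() is a hypothetical helper
    #    that yields (flattened uint8 image array, integer label) pairs.
    embeddings, labels = [], []
    for image, label in get_training_images():
        _, embedding = extractor.run_inference(image)
        embeddings.append(embedding)
        labels.append(label)

    # 3. Train the softmax layer with SGD. The dictionary keys and the
    #    validation split shown here are assumptions about what
    #    train_with_sgd() expects.
    classifier = SoftmaxRegression(feature_dim=1024, num_classes=5)
    dataset = {
        'data_train': np.array(embeddings[:-50], dtype=np.float32),
        'labels_train': np.array(labels[:-50]),
        'data_val': np.array(embeddings[-50:], dtype=np.float32),
        'labels_val': np.array(labels[-50:]),
    }
    classifier.train_with_sgd(dataset, num_iter=500, learning_rate=0.01)

    # Run an inference directly against the retrained softmax layer...
    scores = classifier.run_inference(np.array(embeddings[:1], dtype=np.float32))

    # ...or save the extractor plus the retrained layer as one .tflite file
    # that you can load with ClassificationEngine.
    classifier.save_as_tflite_model(
        'mobilenet_v1_embedding_extractor_edgetpu.tflite',
        'retrained_model_edgetpu.tflite')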

See the next section for a walkthrough with our sample code.

Retrain a model with our sample code

To better illustrate how you can use the SoftmaxRegression API, we've created a sample script: backprop_last_layer.py. Follow the procedure below to try it with a flowers dataset.

If you're using the Dev Board, execute these commands on the board's terminal; if you're using the USB Accelerator, be sure it's connected to the host computer where you'll run these commands.

  1. Download and extract the flowers dataset:

    DEMO_DIR=/tmp
    
    wget -P ${DEMO_DIR} http://download.tensorflow.org/example_images/flower_photos.tgz
    tar zxf ${DEMO_DIR}/flower_photos.tgz -C ${DEMO_DIR}
  2. Download our embedding extractor (a version of the neural network without the final fully-connected layer, and pre-trained on ImageNet):

    wget -P ${DEMO_DIR} https://dl.google.com/coral/canned_models/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite
    

    If you want to use your own model, see the section below about how to create your own embedding extractor.

  3. Start on-device transfer learning:

    # If you're on the Dev Board:
    cd /usr/lib/python3/dist-packages/edgetpu/
    
    # If you're using the USB Accelerator:
    cd /usr/local/lib/python3.6/dist-packages/edgetpu/

    # Start training
    python3 demo/backprop_last_layer.py \
    --data_dir ${DEMO_DIR}/flower_photos \
    --embedding_extractor_path \
    ${DEMO_DIR}/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite \
    --output_dir ${DEMO_DIR}
  4. Check that the retrained model works by running it through the classify_image.py script. (The training script saves the retrained model as retrained_model_edgetpu.tflite and the labels as label_map.txt in the directory you specified with --output_dir.)

    # Download a rose image from Open Images:
    curl -o ${DEMO_DIR}/rose.jpg https://c2.staticflickr.com/4/3062/3067374593_f2963e50b7_o.jpg
    
    python3 demo/classify_image.py \
    --model ${DEMO_DIR}/retrained_model_edgetpu.tflite \
    --label ${DEMO_DIR}/label_map.txt \
    --image ${DEMO_DIR}/rose.jpg

    You should see results such as this:


    ---------------------------
    roses
    Score : 0.99609375
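
If you'd rather load the retrained model from your own Python code instead of the classify_image.py demo, a minimal sketch using ClassificationEngine might look like this. The classify_with_image() call and the label_map.txt line format ('<id> <name>') are assumptions based on the edgetpu classification API, so verify them against your library version:

    from PIL import Image
    from edgetpu.classification.engine import ClassificationEngine

    # Load the combined model (embedding extractor + retrained softmax layer)
    # and the label file written by the training script.
    engine = ClassificationEngine('/tmp/retrained_model_edgetpu.tflite')
    labels = {}
    with open('/tmp/label_map.txt') as f:
        for line in f:
            label_id, name = line.strip().split(maxsplit=1)
            labels[int(label_id)] = name

    # Classify the downloaded rose image; results are (label_id, score) pairs.
    img = Image.open('/tmp/rose.jpg')
    for label_id, score in engine.classify_with_image(img, top_k=3):
        print(labels[label_id], score)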

Create an embedding extractor

To use this backpropagation technique with your own model, you need to compile your TensorFlow Lite model with its last layer removed. Doing so creates a model called an embedding extractor, which outputs an image embedding (also called a feature embedding tensor).

Separating out the embedding extractor allows the last fully-connected layer to be implemented on-device (with SoftmaxRegression) so we can backpropagate new weights. Assuming you've already trained a classification model with one of the supported model architectures, you can follow the steps below to create an embedding extractor from that pre-trained model.

  1. Identify the feature embedding tensor. A feature embedding tensor is the input tensor for the last fully-connected layer. For the classification model architectures we officially support, the following table lists their feature embedding tensor names, and the feature dimensions.

    Model name                  Feature embedding tensor name              Size
    mobilenet_v1_1.0_224_quant  MobilenetV1/Logits/AvgPool_1a/AvgPool      1024
    mobilenet_v2_1.0_224_quant  MobilenetV2/Logits/AvgPool                 1280
    inception_v1_224_quant      InceptionV1/Logits/AvgPool_0a_7x7/AvgPool  1024
    inception_v2_224_quant      InceptionV2/Logits/AvgPool_1a_7x7/AvgPool  1024
    inception_v3_224_quant      InceptionV3/Logits/AvgPool_1a_8x8/AvgPool  2048
    inception_v4_224_quant      InceptionV4/Logits/AvgPool_1a/AvgPool      1536

    (You can also find the feature embedding tensor name when you visualize your model or list all of its layers using tools such as tflite_convert; a short script for listing layer names appears after these steps.)

  2. Cut off the last fully-connected layer from the pre-trained classification model. Because you'll be changing the weights in the last fully-connected layer, your embedding extractor model is just a new version of the existing model but with this last layer removed. So you'll remove this layer using the tflite_convert tool, which converts the TensorFlow frozen graph into the TensorFlow Lite format. You just need to specify the output array that is the input for the last fully-connected layer (the feature embedding tensor).

    For example, the following command extracts the embedding extractor from a MobileNet v1 model, and saves it as a TensorFlow Lite model.

    # Create embedding extractor from MobileNet v1 classification model
    tflite_convert \
    --output_file=mobilenet_v1_embedding_extractor.tflite \
    --graph_def_file=mobilenet_v1_1.0_224_quant_frozen.pb \
    --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 \
    --std_dev_values=128 \
    --input_arrays=input \
    --output_arrays=MobilenetV1/Logits/AvgPool_1a/AvgPool
    
  3. Compile the embedding extractor. The model you just created runs on the CPU, so you now need to compile it for the Edge TPU, using the Edge TPU Compiler. (This is no different than compiling a full classification model.)
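
As noted in step 1, you might need to find the feature embedding tensor name for your own model. One way to do that with a TensorFlow 1.x frozen graph is to load the GraphDef and print every node name; this is a minimal sketch using plain TensorFlow, not the edgetpu library:

    import tensorflow as tf  # TensorFlow 1.x

    # Load the frozen graph and print every node name so you can locate
    # the input tensor of the last fully-connected layer.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile('mobilenet_v1_1.0_224_quant_frozen.pb', 'rb') as f:
        graph_def.ParseFromString(f.read())
    for node in graph_def.node:
        print(node.name)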

Now just follow the procedures described in the API summary to perform training, or pass your model to the backprop_last_layer.py demo script.