Retrain an image classification model on-device

Instead of retraining a classification model with TensorFlow transfer-learning on your desktop computer (as described in this other tutorial), you can also perform transfer-learning accelerated by the Edge TPU, using the ImprintingEngine API (for classification models only).


This API is based on the weight imprinting technique proposed in Low-Shot Learning with Imprinted Weights. It uses a trained CNN (all layers except the last fully-connected layer) as an embedding (feature) extractor. When given new training images, it computes their embedding vectors with this extractor, averages the vectors for each class, and imprints each class's averaged vector into the weights of the last fully-connected layer. This process does not require backpropagation, so the retraining can be accelerated on the Edge TPU.
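To make the imprinting step concrete, here is a minimal NumPy sketch of the idea from the paper. It assumes the embeddings have already been computed by the extractor, and it follows the paper in L2-normalizing embeddings before averaging. It is not the ImprintingEngine's actual implementation; the function names `l2_normalize`, `imprint_weights`, and `classify` are hypothetical.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale vectors to unit L2 norm along `axis`."""
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def imprint_weights(embeddings_by_class):
    """Build last-layer weights from per-class embeddings (hypothetical helper).

    embeddings_by_class: list where entry c is an (n_c, d) array-like of
    embeddings for class c. Returns a (num_classes, d) weight matrix.
    No gradients or backpropagation are involved.
    """
    rows = []
    for class_embeddings in embeddings_by_class:
        normalized = l2_normalize(np.asarray(class_embeddings, dtype=np.float32))
        # Average the normalized embeddings, then renormalize the mean.
        rows.append(l2_normalize(normalized.mean(axis=0)))
    return np.stack(rows)


def classify(weights, embedding):
    """Score one embedding against every class (cosine similarity)."""
    return weights @ l2_normalize(np.asarray(embedding, dtype=np.float32))
```

For example, `classify(imprint_weights([class_a, class_b]), new_embedding)` returns one similarity score per class, and the highest score wins. A real model would typically feed these scores through a softmax; that step is omitted here.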

Of course, this strategy has both benefits and drawbacks.

Benefits:

  • Transfer learning happens on-device, at near-real-time speed.
  • You don't need to recompile the model.

Drawbacks:

  • Training data is limited to a maximum of 200 images per class.
  • It is best suited to datasets with small intra-class variation (see "About data variation" below).
  • The last fully-connected layer runs on the CPU, not the Edge TPU, so inference is slightly less efficient than running a model that is fully compiled for the Edge TPU.

To show you how this works, we've created a sample script,, which performs on-device transfer-learning using the ImprintingEngine API. The script accepts the embedding extractor model, the new training data, and a value specifying the ratio of data to use for testing. The script then outputs a new classification model that you can immediately run using the ClassificationEngine.

Run the imprinting demo

Follow these steps to try it out using the flowers dataset. If you're using the Dev Board, execute these commands on the board's terminal; if you're using the USB Accelerator, be sure it's connected to the host computer where you'll run these commands.

  1. Download and extract the dataset:

    wget -P ${DEMO_DIR}
    tar zxf ${DEMO_DIR}/flower_photos.tgz -C ${DEMO_DIR}
  2. Download our embedding extractor (a MobileNet v1 model pretrained on ImageNet, with the final fully-connected layer removed):

    wget -P ${DEMO_DIR}

    (The section below explains how to create your own embedding extractor.)

  3. Start on-device transfer learning:

    # If you're on the Dev Board:
    cd /usr/lib/python3/dist-packages/edgetpu/
    # If you're using the USB Accelerator: cd python-tflite-source/edgetpu/
    python3 demo/ \
      --extractor ${DEMO_DIR}/mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite \
      --data ${DEMO_DIR}/flower_photos \
      --output ${DEMO_DIR}/flower_model.tflite \
      --test_ratio 0.95

    Training and evaluation take 1 to 2 minutes, depending on the host platform and input image resolution. Even though a test_ratio of 0.95 means only 5% of the images are used for training, the model reaches roughly 76% top-1 accuracy.

    The script arguments are as follows:

    • extractor: Path to the embedding extractor. Defaults to mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite under test_data/imprinting.

    • data: Path to the dataset directory. Defaults to test_data/open_image_v4_subset. Note that you must first run test_data/ to generate this dataset.

    • output: File name of the retrained model. Defaults to [extractor_name]_retrained.tflite.

    • test_ratio: Fraction of the images held out for testing. Defaults to 0.25.

  4. Verify that the transfer-learned model works by running it through the script:

    # Download a rose image from Open Images:
    wget -O ${DEMO_DIR}/rose.jpg
    python3 demo/ \
      --model ${DEMO_DIR}/flower_model.tflite \
      --label ${DEMO_DIR}/flower_model.txt \
      --image ${DEMO_DIR}/rose.jpg

    It should print results similar to the following:

    --------------------------- roses Score : 0.226562
    --------------------------- tulips Score : 0.207031
    --------------------------- sunflowers Score : 0.191406
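If you are writing your own retraining script, the arguments documented in step 3 map naturally onto argparse. The sketch below is a hypothetical mirror of that interface, not the demo script's actual code. Deriving the default output name from the extractor's file name (without its directory or .tflite extension) is an assumption about what [extractor_name] means.

```python
import argparse
import os


def build_parser():
    """Hypothetical parser mirroring the demo's documented arguments."""
    parser = argparse.ArgumentParser(description="On-device imprinting (sketch)")
    parser.add_argument(
        "--extractor",
        default="test_data/imprinting/"
                "mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu.tflite",
        help="Path to the embedding extractor model.")
    parser.add_argument(
        "--data", default="test_data/open_image_v4_subset",
        help="Dataset directory with one subfolder of images per class.")
    parser.add_argument(
        "--output", default=None,
        help="Retrained model path; defaults to <extractor_name>_retrained.tflite.")
    parser.add_argument(
        "--test_ratio", type=float, default=0.25,
        help="Fraction of the images held out for testing.")
    return parser


def resolve_output(args):
    """Return --output if given, else derive it from the extractor file name."""
    if args.output:
        return args.output
    stem, _ext = os.path.splitext(os.path.basename(args.extractor))
    return stem + "_retrained.tflite"
```

With no flags, `resolve_output(build_parser().parse_args([]))` yields `mobilenet_v1_1.0_224_quant_embedding_extractor_edgetpu_retrained.tflite` in the current directory.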


Keep in mind:

  • You must use the embedding extractor version of a model for retraining.
  • Our demo script requires a specific directory layout for the training data: all images of one class go in one subfolder.
  • The API supports a maximum of 200 training images per class. Use the test_ratio flag to control how many of your images are used for training versus testing.
  • The training script outputs a TensorFlow Lite (.tflite) model and the corresponding label file.
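The directory layout and the 200-image limit above can be checked before you start training. The following is a minimal sketch, assuming one subfolder per class and a per-class random split by test_ratio; the demo script may split differently, and `split_dataset` and the accepted file extensions are hypothetical choices.

```python
import os
import random

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")  # assumption
MAX_TRAIN_IMAGES_PER_CLASS = 200  # limit stated for the ImprintingEngine API


def split_dataset(data_dir, test_ratio, seed=0):
    """Split a class-per-subfolder image directory into train/test lists.

    Returns (labels, train, test): labels is the sorted list of subfolder
    names; train and test map each class index to a list of file paths.
    """
    if not 0.0 <= test_ratio < 1.0:
        raise ValueError("test_ratio must be in [0, 1)")
    labels = sorted(
        d for d in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, d)))
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    train, test = {}, {}
    for index, label in enumerate(labels):
        class_dir = os.path.join(data_dir, label)
        files = sorted(
            os.path.join(class_dir, f) for f in os.listdir(class_dir)
            if f.lower().endswith(IMAGE_EXTENSIONS))
        rng.shuffle(files)
        num_test = int(len(files) * test_ratio)
        test[index] = files[:num_test]
        train[index] = files[num_test:]
        if len(train[index]) > MAX_TRAIN_IMAGES_PER_CLASS:
            raise ValueError(
                f"class {label!r} has {len(train[index])} training images; "
                f"increase test_ratio to stay within {MAX_TRAIN_IMAGES_PER_CLASS}")
    return labels, train, test
```

With the flowers dataset and a test_ratio of 0.95, each class keeps about 5% of its images for training, which stays well under the limit.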

About data variation:

You might have noticed that in the flowers demo, the scores for the different classes are close together. This is because, unlike traditional training, which uses backpropagation to minimize a loss function, this method simply imprints an averaged embedding vector into the weights of the fully-connected layer. When the training images of one class vary a lot (different lighting, angles, aspect ratios, and so on), the imprinted weights form one set of hyperplanes that separates the classes, but not the optimal one.
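A toy example shows why averaging hurts when a class varies a lot. The 2-D vectors below are invented stand-ins for real embeddings: class 0 has two very different "modes" (think of two lighting conditions), while class 1 is tight. The averaged row for class 0 lands between its modes, so one of its own training samples ends up closer to class 1.

```python
import numpy as np


def normalize(v):
    """Scale each row (or a single vector) to unit L2 norm."""
    v = np.asarray(v, dtype=float)
    return v / np.linalg.norm(v, axis=-1, keepdims=True)


def imprint(class_samples):
    """One weight row per class: the normalized mean of normalized samples."""
    return np.stack([normalize(normalize(s).mean(axis=0)) for s in class_samples])


def predict(weights, samples):
    """Index of the highest-scoring class for each sample."""
    return np.argmax(normalize(samples) @ weights.T, axis=1)


class0 = [[1.0, 0.0], [0.0, 1.0]]  # bimodal class (toy data)
class1 = [[0.6, 0.8], [0.6, 0.8]]  # low-variation class (toy data)
weights = imprint([class0, class1])

print(predict(weights, class0))  # prints [0 1]: the second mode is misclassified
```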

However, this method works well when the training images (and the images to classify) vary less. In that case, only a few images (fewer than 10) are needed for training. For example, when an assembly line produces a new type of defective part, you can use this method to teach the device what that defective part looks like on the fly.
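Because each class is just one imprinted weight row, teaching the device a new class amounts to appending a row computed from a few images, without touching the rows for existing classes. The sketch below illustrates this with invented 3-D vectors standing in for real embeddings; `add_class` is a hypothetical helper, and whether a given ImprintingEngine version keeps existing classes when you retrain is not covered here.

```python
import numpy as np


def normalize(v):
    """Scale each row (or a single vector) to unit L2 norm."""
    v = np.asarray(v, dtype=float)
    return v / np.linalg.norm(v, axis=-1, keepdims=True)


def add_class(weights, new_class_embeddings):
    """Append one imprinted row for a new class; existing rows are unchanged."""
    row = normalize(normalize(new_class_embeddings).mean(axis=0))
    return np.vstack([weights, row])


# Toy weights for "good part" (row 0) and a known defect (row 1).
weights = normalize(np.array([[1.0, 0.1, 0.0],
                              [0.1, 1.0, 0.0]]))
# Three low-variation shots of a newly seen defect type (toy data).
new_defect_shots = [[0.0, 0.1, 1.0], [0.05, 0.1, 0.95], [0.0, 0.15, 1.0]]
weights = add_class(weights, new_defect_shots)

query = normalize([0.02, 0.12, 0.98])
print(int(np.argmax(weights @ query)))  # prints 2, the new defect class
```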

Create an embedding extractor

For your own applications, you'll probably want to create your own embedding extractor instead of using the one provided above, which was trained on a very generic dataset (ImageNet).

An embedding extractor is a subgraph of a pretrained classification model that allows the last fully-connected layer to be imprinted with new embedding vectors. Assuming you've already trained a classification model using one of the supported model architectures, you can follow the steps below to create an embedding extractor from that pretrained model.

  1. Identify the feature embedding tensor. The feature embedding tensor is the input tensor of the last classification layer, which is where the new embedding vectors will be imprinted during on-device transfer learning. For each classification model architecture we officially support, the following table lists the feature embedding tensor name and the feature dimension.

    Model name                   Feature embedding tensor name                Feature dimension
    mobilenet_v1_1.0_224_quant   MobilenetV1/Logits/AvgPool_1a/AvgPool        1024
    mobilenet_v2_1.0_224_quant   MobilenetV2/Logits/AvgPool                   1280
    inception_v1_224_quant       InceptionV1/Logits/AvgPool_0a_7x7/AvgPool    1024
    inception_v2_224_quant       InceptionV2/Logits/AvgPool_1a_7x7/AvgPool    1024
    inception_v3_224_quant       InceptionV3/Logits/AvgPool_1a_8x8/AvgPool    2048
    inception_v4_224_quant       InceptionV4/Logits/AvgPool_1a/AvgPool        1536

    (You can also find the feature embedding tensor name when you visualize your model or list all the layers of your model using tools such as tflite_convert.)

  2. Cut off the last fully-connected layer from the pretrained classification model. Because you'll be imprinting new weights into the last fully-connected layer, your embedding extractor model is just a new version of the existing model but with this last layer removed. So you'll remove this layer using the tflite_convert tool, which converts the TensorFlow frozen graph into the TensorFlow Lite format. You just need to specify the output array that is the input for the last fully-connected layer (the feature embedding tensor).

    For example, the following command extracts the embedding extractor from a MobileNet v1 model, and saves it as a TensorFlow Lite model.

    # Create embedding extractor from MobileNet v1 classification model.
    # --output_arrays is the feature embedding tensor listed in the table in step 1.
    tflite_convert \
    --output_file=mobilenet_v1_embedding_extractor.tflite \
    --graph_def_file=mobilenet_v1_1.0_224_quant_frozen.pb \
    --inference_type=QUANTIZED_UINT8 \
    --mean_values=128 \
    --std_dev_values=128 \
    --input_arrays=input \
    --output_arrays=MobilenetV1/Logits/AvgPool_1a/AvgPool

  3. Compile the embedding extractor. The embedding extractor you just created runs only on the CPU, so you need to compile it for the Edge TPU using the Edge TPU Compiler. (This is no different from compiling a full classification model.)
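If your model isn't in the table, steps 1 and 2 can be partly scripted. The sketch below reads the node names from a frozen graph with `tf.compat.v1.GraphDef` and applies a simple heuristic (average-pooling ops inside a `Logits` scope) to suggest candidates for `--output_arrays`. The heuristic and the helper names are assumptions; confirm the choice by inspecting the graph yourself. TensorFlow is imported lazily so that the other helpers work without it.

```python
# Feature embedding tensors for the officially supported models (from the table in step 1).
EMBEDDING_TENSORS = {
    "mobilenet_v1_1.0_224_quant": ("MobilenetV1/Logits/AvgPool_1a/AvgPool", 1024),
    "mobilenet_v2_1.0_224_quant": ("MobilenetV2/Logits/AvgPool", 1280),
    "inception_v1_224_quant": ("InceptionV1/Logits/AvgPool_0a_7x7/AvgPool", 1024),
    "inception_v2_224_quant": ("InceptionV2/Logits/AvgPool_1a_7x7/AvgPool", 1024),
    "inception_v3_224_quant": ("InceptionV3/Logits/AvgPool_1a_8x8/AvgPool", 2048),
    "inception_v4_224_quant": ("InceptionV4/Logits/AvgPool_1a/AvgPool", 1536),
}


def load_node_names(frozen_graph_path):
    """Return every node name in a frozen TensorFlow GraphDef (.pb) file."""
    import tensorflow as tf  # lazy import: only needed when reading a real graph

    graph_def = tf.compat.v1.GraphDef()
    with open(frozen_graph_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return [node.name for node in graph_def.node]


def embedding_candidates(node_names):
    """Heuristic: average-pooling ops in a Logits scope usually feed the last FC layer."""
    return [n for n in node_names if "/Logits/" in n and n.endswith("AvgPool")]


def output_arrays_flag(model_name):
    """Build the tflite_convert --output_arrays flag for a supported model."""
    tensor_name, _size = EMBEDDING_TENSORS[model_name]
    return f"--output_arrays={tensor_name}"
```

For example, `embedding_candidates(load_node_names("mobilenet_v1_1.0_224_quant_frozen.pb"))` should point you at the same tensor that the table lists for MobileNet v1.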