edgetpu.learn.imprinting.engine

A weight imprinting engine that performs low-shot transfer-learning for image classification models.

For more information about how to use this API and how to create the type of model required, see Retrain a classification model on-device with weight imprinting.

Note

We updated ImprintingEngine in the July 2019 library update (version 2.11.1), which requires code changes if you used the previous version. The API changes are as follows:

  • Most importantly, the input model has new architecture requirements. For details, read Retrain a classification model on-device with weight imprinting.
  • The initialization function accepts a new keep_classes boolean to indicate whether you want to keep the pre-trained classes from the provided model.
  • Train() now requires a second argument for the class ID you want to train, thus allowing you to retrain classes with additional data. (It no longer returns the class ID.)
  • TrainAll() requires a different format for the input data. It now uses a list in which each index corresponds to a class ID, and each list entry is an array of training images for that class. (It no longer returns a mapping of label IDs.)
  • New methods ClassifyWithResizedImage() and ClassifyWithInputTensor() allow you to immediately perform inferences, though you can still choose to save the trained model as a .tflite file with SaveModel().
class edgetpu.learn.imprinting.engine.ImprintingEngine(model_path, keep_classes=False)

Performs weight imprinting (transfer learning) with the given model.

Parameters:
  • model_path (str) – Path to the model you want to retrain. This model must be a .tflite file output by the join_tflite_models tool. For more information about how to create a compatible model, read Retrain an image classification model on-device.
  • keep_classes (bool) – If True, keep the existing classes from the pre-trained model (and use training to add additional classes). If False, drop the existing classes and train the model to include new classes only.
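As a hedged sketch of the overall workflow, the following shows how an engine might be initialized, trained on one new class, and saved. The model filenames, image size (224×224 RGB), and class ID are all illustrative assumptions; the engine itself requires the Edge TPU library and a compatible .tflite model, so its import is guarded here.

```python
import numpy as np

# The Edge TPU library is only present on a device with the runtime
# installed, so guard the import for this sketch.
try:
    from edgetpu.learn.imprinting.engine import ImprintingEngine
except ImportError:
    ImprintingEngine = None  # Edge TPU runtime not installed

# Prepare training images as 1-D uint8 tensors. The 224x224x3 input
# size is an assumption; use your model's actual input shape.
rng = np.random.default_rng(0)
train_images = [rng.integers(0, 256, size=224 * 224 * 3, dtype=np.uint8)
                for _ in range(5)]

if ImprintingEngine is not None:
    # keep_classes=True preserves the model's pre-trained classes.
    engine = ImprintingEngine('imprinting_model.tflite', keep_classes=True)
    # To add a new class, its ID must equal the current class count
    # (10 here is purely illustrative).
    engine.Train(train_images, class_id=10)
    engine.SaveModel('retrained_model.tflite')
```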
ClassifyWithInputTensor(input_tensor, threshold=0.0, top_k=3)

Performs classification with the retrained model using the given raw input tensor.

This requires you to process the input data (the image) and convert it to the appropriately formatted input tensor for your model.

Parameters:
  • input_tensor (numpy.ndarray) – A 1-D array as the input tensor.
  • threshold (float) – Minimum confidence threshold for returned classifications. For example, use 0.5 to receive only classifications with a confidence equal to or higher than 0.5.
  • top_k (int) – The maximum number of classifications to return.
Returns:

A list of classifications, each of which is a list [int, float] that represents the label ID (int) and the confidence score (float).

Raises:

ValueError – If argument values are invalid.
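A minimal sketch of preparing the raw input tensor this method expects. The image below is a synthetic placeholder; with a real photo you would load it (for example via PIL), resize it to the model's input size, and flatten it the same way. The 224×224×3 shape is an assumption, and the engine call is shown in comments because it requires Edge TPU hardware.

```python
import numpy as np

# Placeholder H x W x C image; replace with real pixel data in practice.
image = np.zeros((224, 224, 3), dtype=np.uint8)

# ClassifyWithInputTensor() takes a 1-D array, so flatten the image.
input_tensor = image.flatten()
assert input_tensor.ndim == 1

# With an initialized ImprintingEngine (hardware required):
# results = engine.ClassifyWithInputTensor(input_tensor, threshold=0.5, top_k=3)
# for label_id, score in results:
#     print(label_id, score)
```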

ClassifyWithResizedImage(img, threshold=0.1, top_k=3)

Performs classification with the retrained model using the given image.

Note: The given image must already be resized to match the model’s input tensor size.

Parameters:
  • img (PIL.Image) – The image you want to classify.
  • threshold (float) – Minimum confidence threshold for returned classifications. For example, use 0.5 to receive only classifications with a confidence equal to or higher than 0.5.
  • top_k (int) – The maximum number of classifications to return.
Returns:

A list of classifications, each of which is a list [int, float] that represents the label ID (int) and the confidence score (float).

Raises:

ValueError – If argument values are invalid.
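Because the image must already match the model's input tensor size, a resize step usually precedes this call. The sketch below uses Pillow for the resize; its import is guarded in case it is not installed, and the 224×224 input size is an assumption for illustration.

```python
import numpy as np

try:
    from PIL import Image
except ImportError:
    Image = None  # Pillow not installed

INPUT_SIZE = (224, 224)  # assumed model input width/height

if Image is not None:
    # A synthetic 640x480 RGB image standing in for a real photo.
    img = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
    # Resize to the model's expected input size before classifying.
    resized = img.resize(INPUT_SIZE)
    # results = engine.ClassifyWithResizedImage(resized, threshold=0.5)
```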

SaveModel(output_path)

Saves the newly trained model as a .tflite file.

You can then use the saved model to perform inferencing using ClassificationEngine. Alternatively, you can immediately perform inferences with the retrained model using the local inferencing methods, ClassifyWithResizedImage() or ClassifyWithInputTensor().

Parameters:output_path (str) – The path and filename where you’d like to save the trained model (must end with .tflite).
Train(input, class_id)

Trains the model with a set of images for one class.

You can use this to add new classes to the model or retrain classes that you previously added using this imprinting API.

Parameters:
  • input (list of numpy.array) – The images to use for training a single class. Each numpy.array in the list represents an image as a 1-D tensor. You can convert each image to this format by passing it as a PIL.Image to numpy.asarray(). The maximum number of images allowed in the list is 200.
  • class_id (int) – The label ID for this class. The index must be either the number of existing classes (to add a new class to the model) or the index of an existing class that was trained using this imprinting API (you can’t retrain classes from the pre-trained model).
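The input format above can be sketched as follows: a list of at most 200 images, each a 1-D numpy array, plus a class ID computed from the current class count when adding a new class. The image size and class count are illustrative assumptions, and the engine call is commented out because it requires hardware.

```python
import numpy as np

# Build one class's training set: a list of flattened uint8 images.
# The 224x224x3 shape is an assumption; match your model's input size.
rng = np.random.default_rng(1)
class_images = [rng.integers(0, 256, size=224 * 224 * 3, dtype=np.uint8)
                for _ in range(20)]
assert len(class_images) <= 200  # Train() accepts at most 200 images

# To add a NEW class, its ID must equal the current number of classes.
# num_classes would come from your own bookkeeping; 5 is illustrative.
num_classes = 5
new_class_id = num_classes
# engine.Train(class_images, class_id=new_class_id)
```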
TrainAll(input_data)

Trains the model with multiple sets of images for multiple classes.

This essentially calls Train() for each class of images you provide. You can use this to add a batch of new classes or retrain existing classes. Just beware that if you’ve already added new classes using the imprinting API, then the data input here must include the same classes in the same order. Alternatively, you can use Train() to retrain specific classes one at a time.

Parameters:input_data (list of lists of numpy.array) – The images for training multiple classes. Each entry in the outer list corresponds to one class and is itself a list of numpy.array objects, each of which represents an image as a 1-D tensor. You can convert each image to this format by passing it as a PIL.Image to numpy.asarray().
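The nested input structure described above can be sketched like this: a list where index i holds the training images for class ID i, and each image is a 1-D numpy array. The class count, image counts, and image size are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(2)

def make_images(n):
    """Generate n fake flattened images for one class (assumed 224x224x3)."""
    return [rng.integers(0, 256, size=224 * 224 * 3, dtype=np.uint8)
            for _ in range(n)]

# Index in the outer list is the class ID.
input_data = [make_images(4),  # class ID 0
              make_images(6),  # class ID 1
              make_images(3)]  # class ID 2

# engine.TrainAll(input_data)  # equivalent to calling Train() per class
```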