Import TensorFlow models in DJL¶
This document shows you how to import various TensorFlow models in DJL. DJL only supports loading TensorFlow models in SavedModel format, so we will walk you through the steps to convert other model formats to SavedModel.
How to import Keras models in DJL¶
The DJL TensorFlow engine and model zoo support only the SavedModel format (.pb files). However, many Keras users save their models with the Keras model.save API, which produces a .h5 file.
This section shows you how to convert a .h5 model file into a TensorFlow SavedModel (.pb) so it can be imported in DJL. All the code here is Python, so you need to install TensorFlow for Python.
For example, if you have a ResNet50 with trained weights, you can save it directly in SavedModel format using tf.saved_model.save.
Here we just use pre-trained weights of ResNet50 from Keras Applications:
import tensorflow as tf
import tensorflow.keras as keras
resnet = keras.applications.ResNet50()
tf.saved_model.save(resnet, "resnet/1/")
However, if you already saved your model in a .h5 file using model.save as shown below, you need a short Python script to convert it to SavedModel format.
resnet.save("resnet.h5")
Load your .h5 model file back into Keras and save it as a SavedModel:
loaded_model = keras.models.load_model("resnet.h5")
tf.saved_model.save(loaded_model, "resnet/1/")
Once you have a SavedModel, you can load your Keras model using the DJL TensorFlow engine. Refer to How to load models in DJL.
How to import TensorFlow Hub models¶
DJL supports loading models directly from TensorFlow Hub. Use the optModelUrls method in Criteria.Builder to specify the model URL.
The URL can be an HTTP link that points to a TensorFlow Hub model, or a local path that points to a model downloaded from TensorFlow Hub. Note that you need to click the download model button to find the actual Google Storage link that hosts the TensorFlow Hub model.
Please refer to these two examples:
- Object Detection with TensorFlow for loading from a TensorFlow Hub URL.
- BERT Classification for loading from a locally downloaded model.
How to import TensorFlow models that use TensorFlow Extensions¶
You can use DJL to run TensorFlow models that require ops available in TensorFlow extensions (not in core), such as models that use ops in TensorFlow Text.
To use such models in DJL you must know:
- Which extension contains the additional required ops
- The TensorFlow extension library version that is compatible with the version of TensorFlow currently being used in DJL
Here we show an example that uses DJL to run the Multilingual Universal Sentence Encoder. This is a text encoder from TensorFlow Hub that uses ops from the TensorFlow Text extension.
- You need to download the library files for the extension so that we can load them for use. Here are instructions for downloading TensorFlow Text. We'll use version 2.7.0 in this example.
- After the library files are downloaded, you are ready to load them for use in DJL. You do this by calling the TensorFlow Java TensorFlow.loadLibrary(...) API before loading the model in DJL.
- We can update the existing UniversalSentenceEncoder example to use the Multilingual model:
// Build Criteria for Multilingual Universal Sentence Encoder from TensorFlow Hub
String modelUrl = "https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder-multilingual/3.tar.gz";
Criteria<String[], float[][]> criteria =
        Criteria.builder()
                .optApplication(Application.NLP.TEXT_EMBEDDING)
                .setTypes(String[].class, float[][].class)
                .optModelUrls(modelUrl)
                .optTranslator(new MyTranslator())
                .optEngine("TensorFlow")
                .optProgress(new ProgressBar())
                .build();
// Make sure TensorFlow engine is loaded before loading extension
Engine.getEngine("TensorFlow");
// Load TensorFlow Text libraries for necessary ops
// Example Path of installation on Linux may look like: /home/<user>/.local/python-3.8.3/lib/python3.8/site-packages/tensorflow_text/
TensorFlow.loadLibrary("<path_to_library_installation>/python/ops/_sentencepiece_tokenizer.so");
// Run the model
try (ZooModel<String[], float[][]> model = criteria.loadModel();
        Predictor<String[], float[][]> predictor = model.newPredictor()) {
    return predictor.predict(inputs.toArray(new String[0]));
}
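The library path in the example above has to be filled in by hand, and a wrong path only fails once the library is loaded. As a minimal sketch, assuming the TensorFlow Text installation has the layout shown in the example comment, the path can be resolved and checked first. The class name TfTextLibraryLocator and the fallback directory in main are hypothetical and not part of DJL or TensorFlow Java.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical helper for illustration; not part of DJL or TensorFlow Java.
final class TfTextLibraryLocator {

    // Location of the op library relative to the tensorflow_text directory,
    // taken from the example path above.
    static final String SENTENCEPIECE_LIB = "python/ops/_sentencepiece_tokenizer.so";

    /** Builds the library path under the given tensorflow_text installation directory. */
    static Path resolve(String installDir) {
        return Paths.get(installDir).resolve(SENTENCEPIECE_LIB).normalize();
    }

    /** Returns the absolute library path, or throws if the file is not there. */
    static String requireExisting(String installDir) {
        Path lib = resolve(installDir);
        if (!Files.isRegularFile(lib)) {
            throw new IllegalStateException("TensorFlow Text library not found: " + lib);
        }
        return lib.toAbsolutePath().toString();
    }

    public static void main(String[] args) {
        // Assumed fallback; pass the real installation directory as the first argument.
        String installDir = args.length > 0 ? args[0] : "tensorflow_text";
        System.out.println(resolve(installDir));
    }
}
```

The string returned by requireExisting could then be passed to TensorFlow.loadLibrary(...) after the call to Engine.getEngine("TensorFlow").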
How to load TensorFlow Checkpoints¶
To load a TensorFlow Estimator checkpoint, you need to convert it to SavedModel format using Python:
1. Load the checkpoint file back as an Estimator using the Python API.
2. Export the Estimator to SavedModel using the Estimator.export_saved_model API.
Here is an example:
import tensorflow as tf

# define model_fn and serving_input_receiver_fn based on your training script
estimator = tf.estimator.Estimator(model_fn, "./checkpoint_dir")
estimator.export_saved_model("./export_dir", serving_input_receiver_fn)
For more information on saving, loading and exporting checkpoints, please refer to the TensorFlow documentation.
How to load DJL TensorFlow model zoo models¶
The steps are the same as for loading any other DJL model zoo model: use the Criteria API as documented here.
Note that for TensorFlow image classification models, you need to specify the translator manually instead of using the built-in one, because TensorFlow requires the channels last ("NHWC") image format while DJL uses the channels first ("NCHW") image format. By default, DJL adds a ToTensor transformation to the translator, which converts the image to channels first format. TensorFlow models do not need this step.
Here is an example:
ImageClassificationTranslator myTranslator =
        ImageClassificationTranslator.builder()
                .addTransform(new Resize(224, 224))
                .addTransform(array -> array.div(127.5f).sub(1f))
                .build();
Criteria<Image, Classifications> criteria =
        Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                .setTypes(Image.class, Classifications.class)
                .optTranslator(myTranslator)
                .optFilter("flavor", "v2")
                .optProgress(new ProgressBar())
                .build();
ZooModel<Image, Classifications> model = criteria.loadModel();
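To make the layout difference concrete, the following plain-Java sketch shows how the same pixel values are ordered in a channels last buffer and in a channels first buffer, and what the arithmetic in the lambda transform above does to a pixel value. It uses float arrays as a stand-in for NDArray, and the class and method names are hypothetical, so treat it as an illustration rather than as DJL code.

```java
import java.util.Arrays;

// Illustration only: plain float arrays stand in for NDArray.
final class ImageLayoutDemo {

    /** Flat index of element (h, w, c) in a channels last (HWC) buffer. */
    static int hwcIndex(int h, int w, int c, int width, int channels) {
        return (h * width + w) * channels + c;
    }

    /** Flat index of the same element in a channels first (CHW) buffer. */
    static int chwIndex(int h, int w, int c, int height, int width) {
        return (c * height + h) * width + w;
    }

    /** Rearranges a channels last buffer into channels first order. */
    static float[] hwcToChw(float[] hwc, int height, int width, int channels) {
        float[] chw = new float[hwc.length];
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                for (int c = 0; c < channels; c++) {
                    chw[chwIndex(h, w, c, height, width)] =
                            hwc[hwcIndex(h, w, c, width, channels)];
                }
            }
        }
        return chw;
    }

    /** Same arithmetic as array.div(127.5f).sub(1f): maps [0, 255] onto [-1, 1]. */
    static float normalize(float pixel) {
        return pixel / 127.5f - 1f;
    }

    public static void main(String[] args) {
        // One row, two pixels, three channels: (1, 2, 3) and (4, 5, 6).
        float[] hwc = {1, 2, 3, 4, 5, 6};
        // prints [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
        System.out.println(Arrays.toString(hwcToChw(hwc, 1, 2, 3)));
        // prints -1.0 1.0
        System.out.println(normalize(0f) + " " + normalize(255f));
    }
}
```

A TensorFlow model expects the first ordering, which is why the custom translator above only resizes and rescales the image and leaves the layout untouched.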
How to import TensorFlow 1.x models¶
DJL supports TensorFlow models trained using both 1.x and 2.x. When loading a model, you need to know the tag name saved in the SavedModel (.pb) file. By default, the tag name for 1.x models is "" (an empty String), and for 2.x models it is "serve". DJL uses "serve" by default to load the model.
You also need to know which SignatureDef to use; it is specified with a SignatureDefKey. By default, DJL uses "serving_default" as the key, which is common for TF 2.x models. Most TF 1.x models use "default" as the key.
In summary, you need to specify "Tags" and "SignatureDefKey" using optOption or optOptions in Criteria when loading 1.x models.
Here is an example that loads an SSD model trained in TensorFlow 1.x:
Criteria<Image, DetectedObjects> criteria =
        Criteria.builder()
                .optApplication(Application.CV.OBJECT_DETECTION)
                .setTypes(Image.class, DetectedObjects.class)
                .optFilter("backbone", "mobilenet_v2")
                .optEngine("TensorFlow")
                .optOption("Tags", "")
                .optOption("SignatureDefKey", "default")
                .optProgress(new ProgressBar())
                .build();
ZooModel<Image, DetectedObjects> model = criteria.loadModel();
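If you load both 1.x and 2.x models, it may help to keep the defaults described above in one place. The sketch below builds the option map for each major version in plain Java. The class and method names are hypothetical, and the values are only the common defaults named in this section, so a particular model may have been exported with different tags or signature keys.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper for illustration; the values are the common defaults
// described in this section, not a guarantee for every exported model.
final class SavedModelOptions {

    /** Returns the "Tags" and "SignatureDefKey" options for a TensorFlow major version. */
    static Map<String, String> forMajorVersion(int tfMajorVersion) {
        Map<String, String> options = new LinkedHashMap<>();
        if (tfMajorVersion == 1) {
            options.put("Tags", "");                     // empty tag name for 1.x models
            options.put("SignatureDefKey", "default");   // key used by most 1.x models
        } else if (tfMajorVersion == 2) {
            options.put("Tags", "serve");                // what DJL uses by default
            options.put("SignatureDefKey", "serving_default");
        } else {
            throw new IllegalArgumentException(
                    "Unsupported TensorFlow major version: " + tfMajorVersion);
        }
        return options;
    }

    public static void main(String[] args) {
        System.out.println(forMajorVersion(1));
        System.out.println(forMajorVersion(2));
    }
}
```

The resulting map could be passed to optOptions on the Criteria builder in place of the two optOption calls in the example above.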
Tips and tricks when writing Translator for TensorFlow models¶
You will often have to write your own translator for a new TensorFlow model. You can follow these steps:
- Identify the shapes, order and names of the required model inputs and outputs:
You can find them either by calling the model.describeInput and model.describeOutput methods after you load the model, or by using the Saved Model CLI tool.
Either approach shows you the number of inputs and outputs required, their names, and their order.
- Write the processInput and processOutput methods:
The input of processInput can be any Object, and its output must be an NDList for DJL to run inference.
Convert your inputs to NDArrays in the order described by model.describeInput(), and name each NDArray using the NDArray.setName method.
Then assemble the NDArrays into an NDList and return it to the Predictor.
Similarly, processOutput takes an NDList as input and converts it back to any Object.
For models with multiple outputs, you already know the order of the outputs from calling model.describeOutput in step 1.
You can also use NDArray.getName to figure out what each output represents.
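The ordering step in processInput is easy to get wrong when a model has several inputs. The following plain-Java sketch shows one way to arrange named values in the order reported for the model and to fail early when an input is missing. It uses ordinary collections as a stand-in for NDArray and NDList, and the class name, method name and input names (input_ids, input_mask) are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Illustration only: a List stands in for NDList and the map values for NDArrays.
final class InputOrdering {

    /**
     * Returns the values arranged in the order of expectedNames, which would
     * come from the names reported by model.describeInput().
     */
    static <T> List<T> inDescribedOrder(List<String> expectedNames, Map<String, T> byName) {
        List<T> ordered = new ArrayList<>(expectedNames.size());
        for (String name : expectedNames) {
            T value = byName.get(name);
            if (value == null) {
                throw new IllegalArgumentException("Missing model input: " + name);
            }
            ordered.add(value);
        }
        return ordered;
    }

    public static void main(String[] args) {
        // Hypothetical input names; use the ones your model actually reports.
        List<String> expected = Arrays.asList("input_ids", "input_mask");
        Map<String, int[]> inputs = new HashMap<>();
        inputs.put("input_mask", new int[] {1, 1, 0});
        inputs.put("input_ids", new int[] {101, 2023, 0});
        List<int[]> ordered = inDescribedOrder(expected, inputs);
        // prints [101, 2023, 0]
        System.out.println(Arrays.toString(ordered.get(0)));
    }
}
```

In a real translator, the same loop would add each NDArray to the NDList after naming it with NDArray.setName.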