Inference with your model
This is the third and final tutorial in our beginner tutorial series, which takes you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to run your trained image classification model in a production system.
In the previous tutorial, you successfully trained your model. Now, we will learn how to implement a Translator to convert between POJOs and NDArrays, as well as a Predictor to run inference.
Preparation
This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the Jupyter README.
// Add the snapshot repository to get the DJL snapshot artifacts
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/
// Add the maven dependencies
%maven ai.djl:api:0.28.0
%maven ai.djl:model-zoo:0.28.0
%maven ai.djl.mxnet:mxnet-engine:0.28.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.28.0
%maven org.slf4j:slf4j-simple:1.7.36
import java.awt.image.*;
import java.nio.file.*;
import java.util.*;
import java.util.stream.*;
import ai.djl.*;
import ai.djl.basicmodelzoo.basic.*;
import ai.djl.ndarray.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.translate.*;
Step 1: Load your handwritten digit image
We will start by loading the image that we want our model to classify.
var img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
img.getWrappedImage();
Step 2: Load your model
Next, we need to load the model to run inference with. This model should have been saved to the build/mlp directory when you ran the previous tutorial.
Path modelDir = Paths.get("build/mlp");
Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
model.load(modelDir);
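Because load restores the saved parameter values into the block you supply, that block must have the same layer sizes as the one you trained; otherwise the saved arrays will not match the expected parameter shapes. As a rough illustration, the plain-Java sketch below (it does not use DJL) counts the weights and biases of a stack of fully connected layers with the sizes passed to Mlp above. It assumes every layer is a dense layer with one bias per output unit; the class name MlpParamCount is hypothetical.

```java
// Plain-Java sketch (no DJL): counts the parameters of fully connected layers
// input -> hidden[0] -> ... -> hidden[n-1] -> output, assuming each layer
// has a weight matrix plus one bias per output unit.
final class MlpParamCount {
    static long count(int input, int output, int[] hidden) {
        long total = 0;
        int in = input;
        for (int units : hidden) {
            total += (long) in * units + units; // weights + biases of this hidden layer
            in = units;
        }
        total += (long) in * output + output;   // final output layer
        return total;
    }

    public static void main(String[] args) {
        // Same sizes as new Mlp(28 * 28, 10, new int[] {128, 64})
        System.out.println(count(28 * 28, 10, new int[] {128, 64}));
    }
}
```

For these sizes the count is 784 × 128 + 128, plus 128 × 64 + 64, plus 64 × 10 + 10, which is 109,386 parameters. Changing any hidden size changes these shapes, which is why the model must be rebuilt with exactly the block used during training.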
In addition to loading a local model, you can also find pretrained models within our model zoo. See more options in our model loading documentation.
Step 3: Create a Translator
The Translator encapsulates the pre-processing and post-processing logic of your application. The inputs to processInput and processOutput should be single data items, not batches.
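To see why the translator only deals with single items, it helps to picture the flow around it: the predictor pre-processes each input on its own, combines the results into one batch (the job of the Batchifier), runs the model once, splits the batch output back into rows, and post-processes each row on its own. The plain-Java sketch below mimics that flow with float arrays in place of NDArrays. SimpleTranslator and MiniPipeline are hypothetical names used only for illustration; they are not DJL classes.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical, simplified stand-in for a translator: each call sees ONE item.
interface SimpleTranslator<I, O> {
    float[] processInput(I input);   // one input -> one feature vector
    O processOutput(float[] row);    // one output row -> one result
}

final class MiniPipeline {
    // Mimics the predictor's flow: pre-process each item, stack the vectors into
    // a batch, run the model once, then post-process each output row separately.
    static <I, O> List<O> batchPredict(List<I> inputs,
                                       SimpleTranslator<I, O> translator,
                                       UnaryOperator<float[][]> model) {
        float[][] batch = new float[inputs.size()][];
        for (int i = 0; i < inputs.size(); i++) {
            batch[i] = translator.processInput(inputs.get(i)); // "stack" into [N, X]
        }
        float[][] outputs = model.apply(batch);                // single forward pass
        List<O> results = new ArrayList<>(outputs.length);
        for (float[] row : outputs) {                          // "unstack" into rows
            results.add(translator.processOutput(row));
        }
        return results;
    }
}
```

Because batching and unbatching happen outside the translator, the same translator code works whether you call predict with one image or batchPredict with many.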
Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
// Convert Image to NDArray
NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
return new NDList(NDImageUtils.toTensor(array));
}
@Override
public Classifications processOutput(TranslatorContext ctx, NDList list) {
// Create a Classifications with the output probabilities
NDArray probabilities = list.singletonOrThrow().softmax(0);
List<String> classNames = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
return new Classifications(classNames, probabilities);
}
@Override
public Batchifier getBatchifier() {
// The Batchifier describes how to combine a batch together
// Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array
return Batchifier.STACK;
}
};
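Two calls in this translator do the numeric work. NDImageUtils.toTensor converts an image array of shape (height, width, channels) with pixel values from 0 to 255 into a float array of shape (channels, height, width) with values between 0 and 1. Because processOutput receives the output for a single image, its array has shape (10), so softmax(0) normalizes across the ten class scores to produce probabilities that sum to 1. The plain-Java sketch below reproduces both computations on ordinary arrays so you can check them by hand; TranslatorMath is a hypothetical helper, and DJL performs the real operations on NDArrays.

```java
// Plain-Java sketch of what toTensor and softmax(0) compute (DJL does this on NDArrays).
final class TranslatorMath {
    // HWC pixels in [0, 255] -> CHW floats in [0, 1]
    static float[][][] toTensor(int[][][] hwc) {
        int h = hwc.length, w = hwc[0].length, c = hwc[0][0].length;
        float[][][] chw = new float[c][h][w];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                for (int ch = 0; ch < c; ch++) {
                    chw[ch][y][x] = hwc[y][x][ch] / 255f; // move channel axis first, rescale
                }
            }
        }
        return chw;
    }

    // Softmax over a 1-D array of scores; subtracting the max keeps exp() from
    // overflowing and does not change the result.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }
}
```

For the grayscale digit used here, a (28, 28, 1) input becomes a (1, 28, 28) tensor, which the Mlp block then flattens into the 28 * 28 = 784 inputs it was built with.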
Step 4: Create a Predictor
Using the translator, we will create a new Predictor. The predictor is the main class that orchestrates the inference process. During inference, a trained model is used to predict values, often for production use cases. The predictor is NOT thread-safe, so if you want to run predictions in parallel, call newPredictor once per thread so that each thread has its own predictor object.
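One way to follow this advice is to give each worker its own predictor and close it once that worker's share of the inputs is done. The plain-Java sketch below shows the pattern with a hypothetical, deliberately non-thread-safe FakePredictor. The Supplier stands in for calling model.newPredictor(translator), and the try-with-resources block mirrors how you would close a DJL Predictor, which implements AutoCloseable. Splitting the inputs into one contiguous slice per thread is just one simple scheme among many.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Supplier;

// Hypothetical stand-in for a predictor that is NOT safe to share between threads.
final class FakePredictor implements AutoCloseable {
    private int scratch; // mutable per-call state; concurrent calls would overwrite it

    int predict(int input) {
        scratch = input * 10;
        return scratch;
    }

    @Override
    public void close() {
        // A real predictor would release its resources here.
    }
}

final class PerThreadPredictors {
    // Splits the inputs into one contiguous slice per thread. Each task creates,
    // uses, and closes its OWN predictor, so no predictor is ever shared.
    static List<Integer> predictInParallel(List<Integer> inputs, int threads,
                                           Supplier<FakePredictor> newPredictor) {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Future<List<Integer>>> futures = new ArrayList<>();
            int chunk = (inputs.size() + threads - 1) / threads;
            for (int start = 0; start < inputs.size(); start += chunk) {
                List<Integer> slice = inputs.subList(start, Math.min(start + chunk, inputs.size()));
                futures.add(pool.submit(() -> {
                    try (FakePredictor predictor = newPredictor.get()) {
                        List<Integer> out = new ArrayList<>();
                        for (int x : slice) {
                            out.add(predictor.predict(x));
                        }
                        return out;
                    }
                }));
            }
            List<Integer> results = new ArrayList<>();
            for (Future<List<Integer>> f : futures) {
                results.addAll(f.get()); // collecting in submission order keeps input order
            }
            return results;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
        }
    }
}
```

The model itself can be shared across threads; only the predictor objects are kept per thread.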
var predictor = model.newPredictor(translator);
Step 5: Run inference
With our predictor, we can simply call the predict method to run inference. For better performance, you can also call batchPredict with a list of input items. Afterwards, reuse the same predictor for further inference calls rather than creating a new one.
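When you have many images, one option is to group them into fixed-size batches and call batchPredict once per batch, which can help bound how much data is processed at a time. The helper below is a hypothetical plain-Java sketch of that grouping, not a DJL API. Since batchPredict returns its outputs in the same order as its inputs, concatenating the per-batch results keeps them aligned with the original list. The images list and the batch size of 32 in the usage comment are arbitrary placeholders.

```java
import java.util.ArrayList;
import java.util.List;

final class Batches {
    // Splits items into consecutive batches of at most batchSize elements,
    // preserving order; only the last batch may be smaller.
    static <T> List<List<T>> partition(List<T> items, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<T>> batches = new ArrayList<>();
        for (int start = 0; start < items.size(); start += batchSize) {
            int end = Math.min(start + batchSize, items.size());
            batches.add(new ArrayList<>(items.subList(start, end)));
        }
        return batches;
    }

    // Usage with a DJL predictor (placeholders, not compiled here):
    // List<Classifications> results = new ArrayList<>();
    // for (List<Image> batch : Batches.partition(images, 32)) {
    //     results.addAll(predictor.batchPredict(batch));
    // }
}
```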
var classifications = predictor.predict(img);
classifications
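The classifications object printed above holds one probability per class name, and DJL's Classifications class also provides best() and topK(k) for picking out the most likely classes. To make that ranking concrete, the plain-Java sketch below sorts class indices by probability and returns the k most likely class names. TopK is a hypothetical helper, not part of DJL, and any probabilities you feed it by hand are example values rather than real model output.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

final class TopK {
    // Returns up to k class names ordered from most to least likely.
    // List.sort is stable, so ties keep their original class order.
    static List<String> topK(List<String> classNames, double[] probabilities, int k) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            order.add(i);
        }
        order.sort(Comparator.comparingDouble((Integer idx) -> probabilities[idx]).reversed());
        List<String> result = new ArrayList<>();
        for (int i = 0; i < Math.min(k, order.size()); i++) {
            result.add(classNames.get(order.get(i)));
        }
        return result;
    }
}
```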
Summary
You have now built a model, trained it, and run inference with it. Congratulations on finishing the beginner tutorial series! Next, read our other examples and Jupyter notebooks to learn more about DJL.
You can find the complete source code for this tutorial in the examples project.