
Inference with your model

This is the third and final tutorial in our beginner series, which takes you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to run your image classification model in a production system.

In the previous tutorial, you trained your model. Now you will learn how to implement a Translator, which converts between plain Java objects (POJOs) and NDArrays, and how to use a Predictor to run inference.

Preparation

This tutorial requires the Java Jupyter kernel. For installation instructions, see the Jupyter README.

// Add the snapshot repository to get the DJL snapshot artifacts
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

// Add the maven dependencies
%maven ai.djl:api:0.27.0
%maven ai.djl:model-zoo:0.27.0
%maven ai.djl.mxnet:mxnet-engine:0.27.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.27.0
%maven org.slf4j:slf4j-simple:1.7.36
import java.awt.image.*;
import java.nio.file.*;
import java.util.*;
import java.util.stream.*;
import ai.djl.*;
import ai.djl.basicmodelzoo.basic.*;
import ai.djl.ndarray.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.translate.*;

Step 1: Load your handwritten digit image

We start by loading the image that we want our model to classify.

var img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/0.png");
img.getWrappedImage();
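As a plain-Java illustration of what "grayscale" means for a digit image like this one, the sketch below converts a java.awt.image.BufferedImage into a [height][width] array of gray levels from 0 to 255. The luminance weights (ITU-R BT.601: 0.299, 0.587, 0.114) and the class name GrayscaleSketch are assumptions made for this illustration. DJL's own conversion, which Step 3 requests through Image.Flag.GRAYSCALE, may compute gray levels differently.

```java
import java.awt.image.BufferedImage;

/** Sketch: read a BufferedImage as a [height][width] array of 0-255 gray levels. */
class GrayscaleSketch {

    // Assumption: ITU-R BT.601 luma weights; DJL's internal conversion may differ.
    static int[][] toGray(BufferedImage image) {
        int h = image.getHeight();
        int w = image.getWidth();
        int[][] gray = new int[h][w];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int rgb = image.getRGB(x, y);      // packed ARGB, 8 bits per channel
                int r = (rgb >> 16) & 0xFF;
                int g = (rgb >> 8) & 0xFF;
                int b = rgb & 0xFF;
                gray[y][x] = (int) Math.round(0.299 * r + 0.587 * g + 0.114 * b);
            }
        }
        return gray;
    }

    public static void main(String[] args) {
        // A 28x28 black canvas with one white pixel, standing in for a digit image.
        BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
        img.setRGB(14, 14, 0xFFFFFF);
        int[][] gray = toGray(img);
        System.out.println(gray.length + "x" + gray[0].length + ", center=" + gray[14][14]);
    }
}
```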

Step 2: Load your model

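Next, we load the model to run inference with. The previous tutorial saved this model to the build/mlp directory. Because the network is defined in Java, we first recreate the Mlp block with the same layer sizes used during training and then load the saved parameters into it.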

Path modelDir = Paths.get("build/mlp");
Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
model.load(modelDir);
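To make the arguments of new Mlp(28 * 28, 10, new int[] {128, 64}) concrete, the sketch below counts weights and biases layer by layer for 784 inputs, hidden layers of 128 and 64 units, and 10 outputs. It assumes that every layer is fully connected with a bias vector. This is plain arithmetic, not a DJL API, and the class name MlpParamCount is hypothetical.

```java
/** Sketch: count the trainable parameters of a fully connected network. */
class MlpParamCount {

    // Assumption: each layer is dense with a bias, so a layer mapping `in` inputs
    // to `out` outputs has in * out weights plus out biases.
    static long countParams(int input, int output, int[] hidden) {
        long total = 0;
        int in = input;
        for (int h : hidden) {
            total += (long) in * h + h;
            in = h;
        }
        total += (long) in * output + output;   // final output layer
        return total;
    }

    public static void main(String[] args) {
        // Same sizes as new Mlp(28 * 28, 10, new int[] {128, 64}) in this tutorial.
        System.out.println(countParams(28 * 28, 10, new int[] {128, 64}));  // prints 109386
    }
}
```

The total is 100480 + 8256 + 650 = 109386 values. A block with different layer sizes would need a different number of parameters, which is why the block you set must match the one used in training.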

Besides loading a local model, you can also use pretrained models from our model zoo. For more options, see our model loading documentation.

Step 3: Create a Translator

The Translator encapsulates the pre-processing and post-processing logic of your application. The inputs to processInput and processOutput should be single data items, not batches. The Batchifier returned by getBatchifier is responsible for combining items into batches.

Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {

    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        // Convert Image to NDArray
        NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
        return new NDList(NDImageUtils.toTensor(array));
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // Create a Classifications with the output probabilities
        NDArray probabilities = list.singletonOrThrow().softmax(0);
        List<String> classNames = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
        return new Classifications(classNames, probabilities);
    }

    @Override
    public Batchifier getBatchifier() {
        // The Batchifier describes how to combine a batch together
        // Stacking, the most common batchifier, takes N [X1, X2, ...] arrays to a single [N, X1, X2, ...] array
        return Batchifier.STACK;
    }
};
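The translator changes array shapes at two points. NDImageUtils.toTensor turns the [28, 28, 1] grayscale image (height, width, channel) into a [1, 28, 28] float array divided by 255. Batchifier.STACK then combines N such arrays into a single [N, 1, 28, 28] array. The plain-Java sketch below imitates both steps on flat arrays so that the index arithmetic is visible. It illustrates the layouts described above and is not DJL's implementation. The class and method names are hypothetical.

```java
/** Sketch: HWC 0-255 values -> CHW floats in [0, 1], then stack N items into one batch. */
class ShapeSketch {

    // Imitates toTensor: reorder (H, W, C) to (C, H, W) and divide by 255.
    static float[] toTensor(int[] hwc, int h, int w, int c) {
        float[] chw = new float[c * h * w];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                for (int ch = 0; ch < c; ch++) {
                    int src = (y * w + x) * c + ch;   // index in (H, W, C) order
                    int dst = (ch * h + y) * w + x;   // index in (C, H, W) order
                    chw[dst] = hwc[src] / 255f;
                }
            }
        }
        return chw;
    }

    // Imitates STACK: N arrays of identical shape [X...] become one [N, X...] array.
    static float[] stack(float[][] items) {
        int size = items[0].length;
        float[] batch = new float[items.length * size];
        for (int i = 0; i < items.length; i++) {
            if (items[i].length != size) {
                throw new IllegalArgumentException("all items must have the same shape");
            }
            System.arraycopy(items[i], 0, batch, i * size, size);
        }
        return batch;
    }

    public static void main(String[] args) {
        int h = 28, w = 28, c = 1;
        int[] pixels = new int[h * w * c];
        pixels[0] = 255;                                   // one white pixel at the top-left
        float[] tensor = toTensor(pixels, h, w, c);        // logical shape [1, 28, 28]
        float[] batch = stack(new float[][] {tensor, tensor, tensor});  // [3, 1, 28, 28]
        System.out.println(tensor[0] + " " + batch.length);
    }
}
```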

Step 4: Create Predictor

Using the translator, we create a new Predictor. The predictor is the main class that orchestrates the inference process, in which a trained model is used to predict values, often in production. The predictor is NOT thread-safe, so to run predictions in parallel, call newPredictor once per thread so that each thread has its own predictor object.

var predictor = model.newPredictor(translator);
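Because a Predictor is not thread-safe, one common pattern is to give each worker thread its own instance, for example with java.lang.ThreadLocal. The sketch below shows that pattern with a hypothetical NotThreadSafePredictor stand-in so that it runs without DJL. In this tutorial you would create each instance with model.newPredictor(translator) instead and close each predictor when its thread is finished with it.

```java
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/** Sketch: one predictor instance per thread via ThreadLocal. */
class PerThreadPredictors {

    // Hypothetical stand-in for a non-thread-safe predictor; with DJL this would be
    // the object returned by model.newPredictor(translator).
    static class NotThreadSafePredictor {
        String predict(String input) {
            return "prediction for " + input;
        }
    }

    // Each thread lazily creates its own predictor the first time it calls get().
    static final ThreadLocal<NotThreadSafePredictor> PREDICTOR =
            ThreadLocal.withInitial(NotThreadSafePredictor::new);

    /** Runs tasks on a fixed pool and returns the distinct predictor instances used. */
    static Set<NotThreadSafePredictor> run(int threads, int tasks) {
        Set<NotThreadSafePredictor> used = Collections.synchronizedSet(
                Collections.newSetFromMap(new IdentityHashMap<NotThreadSafePredictor, Boolean>()));
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        for (int i = 0; i < tasks; i++) {
            String input = "image-" + i;
            pool.submit(() -> {
                NotThreadSafePredictor p = PREDICTOR.get();  // never shared across threads
                used.add(p);
                p.predict(input);
            });
        }
        pool.shutdown();
        try {
            pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return used;
    }

    public static void main(String[] args) {
        System.out.println(run(4, 100).size() + " predictors used");
    }
}
```

With a fixed pool of four threads, at most four predictors are ever created, no matter how many inputs are processed.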

Step 5: Run inference

With our predictor, we simply call the predict method to run inference. For better throughput, you can also call batchPredict with a list of input items. Reuse the same predictor for subsequent inference calls rather than creating a new one each time.

var classifications = predictor.predict(img);

classifications
[
    {"class": "0", "probability": 0.99999}
    {"class": "2", "probability": 5.9e-06}
    {"class": "9", "probability": 2.2e-06}
    {"class": "6", "probability": 1.5e-07}
    {"class": "8", "probability": 1.1e-07}
]
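The probabilities above come from softmax(0) in processOutput. Softmax turns the ten raw network outputs into values between 0 and 1 that sum to 1. The listing then shows the most likely classes first. The plain-Java sketch below reproduces that arithmetic on made-up logits and sorts class indices by probability. The logits are hypothetical values, not this model's real output, and the sketch only illustrates the math. With DJL you would work with the Classifications object directly, for example through its best() or topK(int) methods.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Sketch: softmax over raw scores, then rank classes by probability. */
class SoftmaxRanking {

    // Numerically stable softmax: subtract the max score before exponentiating.
    static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0.0);
        double[] exp = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            exp[i] = Math.exp(logits[i] - max);
            sum += exp[i];
        }
        for (int i = 0; i < exp.length; i++) {
            exp[i] /= sum;
        }
        return exp;
    }

    // Returns class indices sorted from most to least probable, keeping the first k.
    static Integer[] topK(double[] probabilities, int k) {
        Integer[] idx = new Integer[probabilities.length];
        for (int i = 0; i < idx.length; i++) {
            idx[i] = i;
        }
        Arrays.sort(idx, Comparator.comparingDouble((Integer c) -> probabilities[c]).reversed());
        return Arrays.copyOf(idx, Math.min(k, idx.length));
    }

    public static void main(String[] args) {
        // Hypothetical logits for digits 0-9, made up for illustration.
        double[] logits = {9.0, 0.5, 2.0, -1.0, 0.0, -0.5, 1.0, 0.2, 0.8, 1.5};
        double[] p = softmax(logits);
        for (int cls : topK(p, 5)) {
            System.out.printf("class %d: %.5f%n", cls, p[cls]);
        }
    }
}
```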

Summary

You have now built a model, trained it, and run inference with it. Congratulations on finishing the beginner tutorial series! To learn more about DJL, explore our other examples and Jupyter notebooks.

You can find the complete source code for this tutorial in the examples project.