Inference with your model

This is the third and final tutorial in our beginner series, which takes you through creating, training, and running inference on a neural network. In this tutorial, you will learn how to run your image classification model as you would in a production system.

In the previous tutorial, you trained your model. Now you will implement a Translator that converts between plain Java objects (POJOs) and NDArray, and create a Predictor to run inference.

Preparation

This tutorial requires the installation of the Java Jupyter Kernel. To install the kernel, see the Jupyter README.

// Add the snapshot repository to get the DJL snapshot artifacts
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

// Add the maven dependencies
%maven ai.djl:api:0.6.0
%maven ai.djl:model-zoo:0.6.0
%maven ai.djl.mxnet:mxnet-engine:0.6.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.6.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0

// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md
// for more MXNet library selection options
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-b

import java.awt.image.*;
import java.nio.file.*;
import java.util.*;
import java.util.stream.*;
import ai.djl.*;
import ai.djl.basicmodelzoo.basic.*;
import ai.djl.ndarray.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.translate.*;

Step 1: Load your handwritten digit image

We will start by loading the image that we want our model to classify.

var img = ImageFactory.getInstance().fromUrl("https://djl-ai.s3.amazonaws.com/resources/images/0.png");
img.getWrappedImage();

Step 2: Load your model

Next, we need to load the model that we will run inference with. The previous tutorial should have saved this model to the build/mlp directory.

Path modelDir = Paths.get("build/mlp");
Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
model.load(modelDir);

Step 3: Create a Translator

The Translator encapsulates the pre-processing and post-processing functionality of your application. The processInput and processOutput methods each work on a single data item, not on a batch.

Translator<Image, Classifications> translator = new Translator<Image, Classifications>() {

    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        // Convert Image to NDArray
        NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.GRAYSCALE);
        return new NDList(NDImageUtils.toTensor(array));
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
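        // Turn the raw model outputs into probabilities over the classes
        NDArray probabilities = list.singletonOrThrow().softmax(0);
        // Class names are the digits "0" through "9"
        List<String> indices = IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
        return new Classifications(indices, probabilities);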
    }

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
};
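To make the pre-processing step more concrete, here is a minimal plain-Java sketch of the kind of conversion NDImageUtils.toTensor performs, assuming a height x width x channel input with values in [0, 255] and a channel x height x width output scaled to [0, 1]. The class name ToTensorSketch and the nested-array representation are illustrative only and are not part of DJL.

```java
import java.util.Arrays;

/** Illustrative stand-in for a to-tensor conversion; not a DJL class. */
final class ToTensorSketch {

    /**
     * Converts pixel values laid out as [height][width][channel] in the
     * range [0, 255] into floats laid out as [channel][height][width]
     * in the range [0, 1].
     */
    static float[][][] toTensor(int[][][] hwc) {
        int height = hwc.length;
        int width = hwc[0].length;
        int channels = hwc[0][0].length;
        float[][][] chw = new float[channels][height][width];
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                for (int c = 0; c < channels; c++) {
                    // Move the channel axis to the front and rescale the value
                    chw[c][h][w] = hwc[h][w][c] / 255f;
                }
            }
        }
        return chw;
    }

    public static void main(String[] args) {
        // A made-up 2 x 2 grayscale image (one channel per pixel)
        int[][][] image = {
            {{0}, {255}},
            {{51}, {102}}
        };
        System.out.println(Arrays.deepToString(toTensor(image)));
    }
}
```

In the Translator above, the same idea is applied to an NDArray instead of nested Java arrays, so the work is done by the engine rather than by explicit loops.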

Step 4: Create a Predictor

Using the translator, we will create a new Predictor. The predictor is the main class that orchestrates the inference process. During inference, a trained model is used to predict values, often for production use cases. The predictor is NOT thread-safe, so if you want to run predictions in parallel, create a separate predictor object (from the same model) for each thread.

var predictor = model.newPredictor(translator);
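One way to give each thread its own predictor is a ThreadLocal. The sketch below uses only the Java standard library: PerThread is a hypothetical helper (not a DJL class), and a plain Object stands in for the predictor so the example runs without a model. In this notebook, the factory would presumably be something like () -> model.newPredictor(translator).

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

/** Hypothetical helper: hands each thread its own instance of a non-thread-safe object. */
final class PerThread<T> {

    private final ThreadLocal<T> local;

    PerThread(Supplier<? extends T> factory) {
        // The factory runs once per thread, on that thread's first call to get()
        this.local = ThreadLocal.withInitial(factory);
    }

    /** Returns the calling thread's own instance, creating it if needed. */
    T get() {
        return local.get();
    }

    public static void main(String[] args) throws InterruptedException {
        // Stand-in for a predictor factory
        PerThread<Object> predictors = new PerThread<>(Object::new);
        Set<Object> seen = ConcurrentHashMap.newKeySet();

        Runnable task = () -> {
            Object first = predictors.get();
            Object second = predictors.get();
            // Within one thread, the same instance is reused
            if (first != second) {
                throw new IllegalStateException("expected one instance per thread");
            }
            seen.add(first);
        };

        Thread a = new Thread(task);
        Thread b = new Thread(task);
        a.start();
        b.start();
        a.join();
        b.join();

        // Each of the two threads contributed its own instance
        System.out.println("distinct instances: " + seen.size());
    }
}
```

With a thread pool, this pattern creates at most one instance per worker thread, which is then reused across the calls made on that thread.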

Step 5: Run inference

With our predictor, we can simply call the predict method to run inference. Reuse the same predictor for any further inference calls.

var classifications = predictor.predict(img);

classifications
[
    class: "0", probability: 0.99972
    class: "2", probability: 0.00027
    class: "7", probability: 2.5e-06
    class: "9", probability: 2.2e-06
    class: "6", probability: 3.9e-07
]
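The probabilities in this output come from the softmax call in processOutput, and the listing shows the highest-scoring classes first. As a rough illustration of both ideas, the following plain-Java sketch computes a numerically stable softmax and then picks the top k indices. The class name SoftmaxTopK and the input scores are made up for the example; this is not how DJL implements Classifications.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative softmax and top-k ranking; not a DJL class. */
final class SoftmaxTopK {

    /** Converts raw scores into probabilities that sum to 1. */
    static double[] softmax(double[] scores) {
        // Subtract the maximum before exponentiating to avoid overflow
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double[] out = new double[scores.length];
        double sum = 0;
        for (int i = 0; i < scores.length; i++) {
            out[i] = Math.exp(scores[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Returns the indices of the k largest probabilities, largest first. */
    static List<Integer> topK(double[] probabilities, int k) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            indices.add(i);
        }
        // Sort indices by descending probability
        indices.sort((a, b) -> Double.compare(probabilities[b], probabilities[a]));
        return new ArrayList<>(indices.subList(0, Math.min(k, indices.size())));
    }

    public static void main(String[] args) {
        // Made-up raw scores for four classes
        double[] scores = {9.0, 0.5, 1.0, -2.0};
        double[] probabilities = softmax(scores);
        for (int i : topK(probabilities, 3)) {
            System.out.printf("class: \"%d\", probability: %.5f%n", i, probabilities[i]);
        }
    }
}
```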

Summary

You have now built a model, trained it, and run inference with it. Congratulations on finishing the beginner tutorial series. From here, you can read our other examples and Jupyter notebooks to learn more about DJL.

You can find the complete source code for this tutorial in the examples project.