Run this notebook online:Binder

Load MXNet model

In this tutorial, you learn how to load an existing MXNet model and use it to run a prediction task.

Preparation

This tutorial requires the Java Kernel to be installed. For more information on installing the Java Kernel, see the README.

// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.28.0
%maven ai.djl:model-zoo:0.28.0
%maven ai.djl.mxnet:mxnet-engine:0.28.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.28.0
%maven org.slf4j:slf4j-simple:1.7.36
import java.awt.image.*;
import java.nio.file.*;
import ai.djl.*;
import ai.djl.inference.*;
import ai.djl.ndarray.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.*;
import ai.djl.modality.cv.transform.*;
import ai.djl.modality.cv.translator.*;
import ai.djl.translate.*;
import ai.djl.training.util.*;
import ai.djl.util.*;

Step 1: Prepare your MXNet model

This tutorial assumes that you have an MXNet model trained using Python. An MXNet symbolic model usually contains the following files:

* Symbol file: {MODEL_NAME}-symbol.json - a JSON file that describes the network structure (the computation graph) of the model
* Parameters file: {MODEL_NAME}-{EPOCH}.params - a binary file that stores the trained parameters (weights and biases)
* Synset file: synset.txt - an optional text file that stores the classification labels
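Before loading, it can help to confirm that the expected files are present. The helper below is a hypothetical plain-Java sketch, not part of DJL. It assumes the usual MXNet checkpoint naming, in which the epoch is zero-padded to four digits (as in resnet18_v1-0000.params). It checks only the symbol and parameters files, because synset.txt is optional.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

class MxnetArtifactsSketch {
    /** {MODEL_NAME}-symbol.json */
    static String symbolFile(String modelName) {
        return modelName + "-symbol.json";
    }

    /** {MODEL_NAME}-{EPOCH}.params, assuming the epoch is zero-padded to four digits. */
    static String paramsFile(String modelName, int epoch) {
        // Locale.ROOT keeps the digits ASCII regardless of the default locale.
        return String.format(Locale.ROOT, "%s-%04d.params", modelName, epoch);
    }

    /** Lists the required files missing from modelDir (synset.txt is optional, so it is skipped). */
    static List<String> missingFiles(Path modelDir, String modelName, int epoch) {
        List<String> missing = new ArrayList<>();
        for (String name : List.of(symbolFile(modelName), paramsFile(modelName, epoch))) {
            if (!Files.isRegularFile(modelDir.resolve(name))) {
                missing.add(name);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        // Should print an empty list once the Step 1 downloads have completed.
        System.out.println(missingFiles(Paths.get("build/resnet"), "resnet18_v1", 0));
    }
}
```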

This tutorial uses a pre-trained MXNet resnet18_v1 model.

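We use DownloadUtils to download the model files from the internet. DownloadUtils decompresses files whose URL ends in .gz, which is why the parameters file is saved as resnet18_v1-0000.params.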

DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-symbol.json", "build/resnet/resnet18_v1-symbol.json", new ProgressBar());
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-0000.params.gz", "build/resnet/resnet18_v1-0000.params", new ProgressBar());
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset.txt", "build/resnet/synset.txt", new ProgressBar());
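The parameters file is fetched from a URL ending in .params.gz but saved as an uncompressed .params file, so it has to be decompressed on the way to disk. The sketch below illustrates that idea with plain JDK classes (GZIPInputStream and Files.copy). The class and method names are hypothetical; this is not DJL's DownloadUtils source. Unlike a cached download, it overwrites an existing output file.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.zip.GZIPInputStream;

class GunzipCopySketch {
    /**
     * Copies {@code in} to {@code out}. If {@code sourceName} ends with ".gz",
     * the stream is gunzipped on the fly so that {@code out} holds the raw file.
     */
    static void copyMaybeGunzip(InputStream in, String sourceName, Path out) throws IOException {
        Path parent = out.toAbsolutePath().getParent();
        if (parent != null) {
            Files.createDirectories(parent); // e.g. build/resnet
        }
        InputStream source = sourceName.endsWith(".gz") ? new GZIPInputStream(in) : in;
        try (source) { // closing the wrapper also closes the underlying stream
            Files.copy(source, out, StandardCopyOption.REPLACE_EXISTING);
        }
    }
}
```

To fetch a real file, you could pass the stream from URI.create(url).toURL().openStream() together with the URL string as sourceName.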

Step 2: Load your model

Path modelDir = Paths.get("build/resnet");
Model model = Model.newInstance("resnet");
model.load(modelDir, "resnet18_v1");

Step 3: Create a Translator

Pipeline pipeline = new Pipeline();
pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());
Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .optSynsetArtifactName("synset.txt")
                .optApplySoftmax(true)
                .build();
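The pipeline above crops, resizes, and converts the image into a tensor. As a rough illustration only, the plain-Java sketch below shows two of those steps on raw arrays. The first is the centered square that a no-argument CenterCrop is assumed to keep (side = min(width, height)). The second is the ToTensor layout change from height x width x channel values in 0..255 to channel x height x width floats in [0, 1]. Resize is omitted because its interpolation is engine-specific. The class and method names are hypothetical; this is not DJL's implementation.

```java
import java.util.Arrays;

class PreprocessSketch {
    /**
     * Returns {x, y, side}: the top-left corner and side length of the largest
     * square centered in a width x height image.
     */
    static int[] centerSquare(int width, int height) {
        int side = Math.min(width, height);
        return new int[] {(width - side) / 2, (height - side) / 2, side};
    }

    /**
     * Converts pixels laid out as [height][width][channel] with values 0..255
     * into a [channel][height][width] float array scaled to [0, 1].
     */
    static float[][][] toTensor(int[][][] hwc) {
        int height = hwc.length;
        int width = hwc[0].length;
        int channels = hwc[0][0].length;
        float[][][] chw = new float[channels][height][width];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                for (int c = 0; c < channels; c++) {
                    chw[c][y][x] = hwc[y][x][c] / 255f;
                }
            }
        }
        return chw;
    }

    public static void main(String[] args) {
        // A 640x480 landscape image keeps the 480x480 square starting at x = 80.
        System.out.println(Arrays.toString(centerSquare(640, 480))); // [80, 0, 480]
    }
}
```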

Step 4: Load image for classification

var img = ImageFactory.getInstance().fromUrl("https://resources.djl.ai/images/kitten.jpg");
img.getWrappedImage()

Step 5: Run inference

Predictor<Image, Classifications> predictor = model.newPredictor(translator);
Classifications classifications = predictor.predict(img);
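predictor.predict(img) runs the translator's preprocessing, then the model, then the translator's postprocessing. Because the translator was built with optApplySoftmax(true), the postprocessing turns the network's raw scores (logits) into probabilities that sum to 1, which is why the values in the output below lie between 0 and 1. The following minimal softmax is a plain-Java illustration, not DJL's code, and its class name is hypothetical. It subtracts the largest score before exponentiating so that large logits do not overflow.

```java
import java.util.Arrays;

class SoftmaxSketch {
    /** Numerically stable softmax: exp(x_i - max) / sum_j exp(x_j - max). */
    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double[] exps = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            exps[i] = Math.exp(logits[i] - max); // every exponent is <= 0, so no overflow
            sum += exps[i];
        }
        float[] probs = new float[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) (exps[i] / sum);
        }
        return probs;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(softmax(new float[] {2f, 1f, 0.1f})));
    }
}
```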

classifications
[
    {"class": "n02123045 tabby, tabby cat", "probability": 0.48384}
    {"class": "n02123159 tiger cat", "probability": 0.20599}
    {"class": "n02124075 Egyptian cat", "probability": 0.18810}
    {"class": "n02123394 Persian cat", "probability": 0.06411}
    {"class": "n02127052 lynx, catamount", "probability": 0.01021}
]
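Each class string in the output appears to be a full line of synset.txt, WordNet ID included. Assuming the synset file holds one label per line in class-index order, producing a ranked list like the one above amounts to sorting the class indices by probability and keeping the first k. The helper below is a hypothetical plain-Java sketch of that step, not DJL's Classifications implementation.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class TopKSketch {
    /** Reads one label per line; the line number (from 0) is assumed to be the class index. */
    static List<String> readSynset(Path file) throws IOException {
        return Files.readAllLines(file, StandardCharsets.UTF_8);
    }

    /** Returns the k (label, probability) pairs with the highest probability, highest first. */
    static List<Map.Entry<String, Float>> topK(List<String> synset, float[] probs, int k) {
        if (synset.size() != probs.length) {
            throw new IllegalArgumentException("expected one synset label per class");
        }
        return IntStream.range(0, probs.length)
                .boxed()
                // Sort class indices by descending probability.
                .sorted(Comparator.comparingDouble((Integer i) -> probs[i]).reversed())
                .limit(k)
                .map(i -> Map.entry(synset.get(i), probs[i]))
                .collect(Collectors.toList());
    }
}
```

With the files from Step 1, the labels could come from readSynset(Paths.get("build/resnet/synset.txt")), and k = 5 matches the length of the list printed above.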

Summary

You can now load other MXNet symbolic models the same way and run inference with them.

You might also want to check out the load_pytorch_model tutorial, which demonstrates loading a local model using the ModelZoo API.