

Load MXNet model

In this tutorial, you learn how to load an existing MXNet model and use it to run a prediction task.

Preparation

This tutorial requires the Java Kernel for Jupyter to be installed. For installation instructions, see the README.

// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.8.0
%maven ai.djl:model-zoo:0.8.0
%maven ai.djl.mxnet:mxnet-engine:0.8.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.8.0
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0

// See https://github.com/awslabs/djl/blob/master/mxnet/mxnet-engine/README.md
// for more MXNet library selection options
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport
import java.awt.image.*;
import java.nio.file.*;
import ai.djl.*;
import ai.djl.inference.*;
import ai.djl.ndarray.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.*;
import ai.djl.modality.cv.transform.*;
import ai.djl.modality.cv.translator.*;
import ai.djl.translate.*;
import ai.djl.training.util.*;
import ai.djl.util.*;

Step 1: Prepare your MXNet model

This tutorial assumes that you have an MXNet model trained using Python. An MXNet symbolic model usually contains the following files:

* Symbol file: {MODEL_NAME}-symbol.json, a JSON file that contains the network definition of the model
* Parameters file: {MODEL_NAME}-{EPOCH}.params, a binary file that stores the parameter weights and biases
* Synset file: synset.txt, an optional text file that stores the classification class labels
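The naming convention above fits in a few lines of code. This is a minimal sketch, assuming the standard MXNet checkpoint pattern in which the epoch is zero-padded to four digits (so epoch 0 becomes `0000`, as in `resnet18_v1-0000.params`). `MxnetModelFiles` is a hypothetical helper and is not part of DJL.

```java
import java.util.List;
import java.util.Locale;

// Hypothetical helper: derive the file names of an MXNet symbolic model.
// Assumes the "{name}-symbol.json" / "{name}-{epoch:04d}.params" convention.
class MxnetModelFiles {

    // Symbol file: the network definition as JSON.
    static String symbolFile(String modelName) {
        return modelName + "-symbol.json";
    }

    // Parameters file: epoch zero-padded to four digits (Locale.ROOT keeps ASCII digits).
    static String paramsFile(String modelName, int epoch) {
        return String.format(Locale.ROOT, "%s-%04d.params", modelName, epoch);
    }

    // Files expected in the model directory; synset.txt is optional.
    static List<String> expectedFiles(String modelName, int epoch, boolean withSynset) {
        return withSynset
                ? List.of(symbolFile(modelName), paramsFile(modelName, epoch), "synset.txt")
                : List.of(symbolFile(modelName), paramsFile(modelName, epoch));
    }

    public static void main(String[] args) {
        // Prints the three file names used in this tutorial, one per line.
        expectedFiles("resnet18_v1", 0, true).forEach(System.out::println);
    }
}
```

Running `main` lists `resnet18_v1-symbol.json`, `resnet18_v1-0000.params`, and `synset.txt`, matching the files downloaded in this tutorial.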

This tutorial uses a pre-trained MXNet resnet18_v1 model.

We use DownloadUtils to download the files from the internet.

DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-symbol.json", "build/resnet/resnet18_v1-symbol.json", new ProgressBar());
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-0000.params.gz", "build/resnet/resnet18_v1-0000.params", new ProgressBar());
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset.txt", "build/resnet/synset.txt", new ProgressBar());
Downloading: 100% |████████████████████████████████████████| resnet18_v1-symbol.json
Downloading: 100% |████████████████████████████████████████| resnet18_v1-0000.params
Downloading: 100% |████████████████████████████████████████| synset.txt
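The second `download` call reads `resnet18_v1-0000.params.gz` but writes `resnet18_v1-0000.params`, so the gzip payload is evidently decompressed along the way. If you already have a compressed `.params.gz` file on disk, a plain-JDK sketch of the same step looks like this. `Gunzip` is a hypothetical helper that assumes standard gzip format; it is not part of DJL.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.zip.GZIPInputStream;

// Hypothetical helper: decompress a local .gz file to a target path.
class Gunzip {

    static void gunzip(Path source, Path target) throws IOException {
        // Create parent directories (e.g. build/resnet) if they do not exist yet.
        Path parent = target.toAbsolutePath().getParent();
        if (parent != null) {
            Files.createDirectories(parent);
        }
        // Open the raw stream first so it is closed even if the gzip header is invalid.
        try (InputStream raw = Files.newInputStream(source);
                InputStream in = new GZIPInputStream(raw)) {
            Files.copy(in, target, StandardCopyOption.REPLACE_EXISTING);
        }
    }

    public static void main(String[] args) throws IOException {
        if (args.length != 2) {
            System.err.println("usage: java Gunzip <source.gz> <target>");
            return;
        }
        gunzip(Paths.get(args[0]), Paths.get(args[1]));
    }
}
```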

Step 2: Load your model

Path modelDir = Paths.get("build/resnet");
Model model = Model.newInstance("resnet");
model.load(modelDir, "resnet18_v1");
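`model.load(modelDir, "resnet18_v1")` receives only a directory and a name prefix, not an epoch number. One plausible way for a loader to choose the parameters file is to scan the directory for `{prefix}-NNNN.params` and take the highest epoch. The sketch below illustrates that idea; `ParamsResolver` is hypothetical and is not DJL's actual lookup logic.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

// Hypothetical resolver: find "{prefix}-NNNN.params" with the largest epoch in a directory.
class ParamsResolver {

    static Optional<Path> latestParams(Path dir, String prefix) throws IOException {
        // Quote the prefix so characters like '.' in model names are matched literally.
        Pattern pattern = Pattern.compile(Pattern.quote(prefix) + "-(\\d{4})\\.params");
        try (Stream<Path> files = Files.list(dir)) {
            return files
                    .filter(p -> pattern.matcher(p.getFileName().toString()).matches())
                    .max(Comparator.comparingInt(p -> epochOf(pattern, p)));
        }
    }

    // Extract the four-digit epoch from a file name already known to match.
    private static int epochOf(Pattern pattern, Path p) {
        Matcher m = pattern.matcher(p.getFileName().toString());
        m.matches();
        return Integer.parseInt(m.group(1));
    }
}
```

With only `resnet18_v1-0000.params` in `build/resnet`, this resolves to that file; an empty `Optional` signals that no matching parameters file exists.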

Step 3: Create a Translator

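// Preprocessing: center-crop, resize to 224x224, then convert HWC [0, 255] pixels to a CHW [0, 1] tensor
Pipeline pipeline = new Pipeline();
pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());
Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
        .setPipeline(pipeline)
        .optSynsetArtifactName("synset.txt") // class labels shipped with the model
        .optApplySoftmax(true)               // turn raw scores into probabilities
        .build();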
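To make the pipeline stages more concrete, here is a plain-Java sketch of two of them operating on a flat pixel array. It assumes that the no-argument `CenterCrop` crops a square whose side is the shorter image dimension, and that `ToTensor` rescales values from [0, 255] to [0, 1] while reordering from height-width-channel (HWC) to channel-height-width (CHW) layout. `Resize` is omitted because interpolation would add little to the illustration. `PreprocessSketch` is hypothetical.

```java
// Hypothetical sketch of two preprocessing stages on a flat HWC pixel array.
class PreprocessSketch {

    // Center crop to a square whose side is the shorter image dimension.
    // Returns {x, y, side}: the top-left corner and edge length of the crop box.
    static int[] centerCropBox(int width, int height) {
        int side = Math.min(width, height);
        int x = (width - side) / 2;
        int y = (height - side) / 2;
        return new int[] {x, y, side};
    }

    // ToTensor-style conversion: HWC values in [0, 255] -> CHW floats in [0, 1].
    static float[] toTensor(int[] hwc, int height, int width, int channels) {
        float[] chw = new float[channels * height * width];
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                for (int c = 0; c < channels; c++) {
                    int src = (h * width + w) * channels + c; // channel-last index
                    int dst = (c * height + h) * width + w;   // channel-first index
                    chw[dst] = hwc[src] / 255f;
                }
            }
        }
        return chw;
    }
}
```

For a 640x480 image, `centerCropBox(640, 480)` yields a 480x480 box starting at x = 80, which `Resize(224, 224)` would then scale down to the network's input size.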

Step 4: Load image for classification

var img = ImageFactory.getInstance().fromUrl("https://djl-ai.s3.amazonaws.com/resources/images/kitten.jpg");
img.getWrappedImage()
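Before the pipeline can run, the decoded image has to be turned into numbers. A minimal sketch using `java.awt.image.BufferedImage` (the notebook already imports `java.awt.image.*`), assuming RGB channel order and HWC layout. `ImagePixels.toHwc` is a hypothetical helper, not the conversion DJL performs internally.

```java
import java.awt.image.BufferedImage;

// Hypothetical helper: flatten an image into HWC order with RGB values in [0, 255].
class ImagePixels {

    static int[] toHwc(BufferedImage img) {
        int width = img.getWidth();
        int height = img.getHeight();
        int[] hwc = new int[height * width * 3];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                int argb = img.getRGB(x, y);           // packed 0xAARRGGBB
                int base = (y * width + x) * 3;
                hwc[base] = (argb >> 16) & 0xFF;       // red
                hwc[base + 1] = (argb >> 8) & 0xFF;    // green
                hwc[base + 2] = argb & 0xFF;           // blue
            }
        }
        return hwc;
    }
}
```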

Step 5: Run inference

Predictor<Image, Classifications> predictor = model.newPredictor(translator);
Classifications classifications = predictor.predict(img);
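Conceptually, `predictor.predict(img)` chains three stages: the translator preprocesses the input, the model runs its forward pass, and the translator post-processes the raw output into `Classifications`. The sketch below mirrors that split with hypothetical `MiniTranslator` and `MiniPredictor` types; they are deliberately simplified stand-ins, not DJL's `Translator` or `Predictor` interfaces.

```java
import java.util.function.Function;

// Hypothetical, simplified mirror of the translator/predictor split (not DJL's API).
interface MiniTranslator<I, O> {
    // Turn a user-facing input into model-ready numbers.
    float[] processInput(I input);

    // Turn raw model output back into a user-facing result.
    O processOutput(float[] modelOutput);
}

class MiniPredictor<I, O> {
    private final Function<float[], float[]> model; // stands in for the network's forward pass
    private final MiniTranslator<I, O> translator;

    MiniPredictor(Function<float[], float[]> model, MiniTranslator<I, O> translator) {
        this.model = model;
        this.translator = translator;
    }

    // predict = preprocess -> forward -> postprocess
    O predict(I input) {
        return translator.processOutput(model.apply(translator.processInput(input)));
    }
}
```

Keeping preprocessing and postprocessing in the translator means the same model can be reused with different input and output types simply by swapping translators.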

classifications
[
    class: "n02123045 tabby, tabby cat", probability: 0.48384
    class: "n02123159 tiger cat", probability: 0.20599
    class: "n02124075 Egyptian cat", probability: 0.18810
    class: "n02123394 Persian cat", probability: 0.06411
    class: "n02127052 lynx, catamount", probability: 0.01021
]
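Because `optApplySoftmax(true)` was set, the raw network scores are converted to probabilities, and the printout lists the five most probable classes. A minimal sketch of softmax followed by top-k selection, assuming plain `double` logits; `SoftmaxTopK` is hypothetical.

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical sketch: softmax over raw scores, then pick the k most probable indices.
class SoftmaxTopK {

    // Numerically stable softmax: subtract the max logit before exponentiating.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] probs = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    // Indices of the k largest probabilities, highest first.
    static List<Integer> topK(double[] probs, int k) {
        return IntStream.range(0, probs.length)
                .boxed()
                .sorted(Comparator.comparingDouble((Integer i) -> probs[i]).reversed())
                .limit(k)
                .collect(Collectors.toList());
    }
}
```

The five probabilities shown above sum to about 0.95; the remaining probability mass is spread across the other classes in `synset.txt`.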

Summary

You can now load your own MXNet symbolic model and run inference with it.

You might also want to check out load_pytorch_model.ipynb, which demonstrates loading a local model using the ModelZoo API.