// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/
%maven ai.djl:api:0.28.0
%maven ai.djl:model-zoo:0.28.0
%maven ai.djl.mxnet:mxnet-engine:0.28.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.28.0
%maven org.slf4j:slf4j-simple:1.7.36
import java.awt.image.*;
import java.nio.file.*;
import ai.djl.*;
import ai.djl.inference.*;
import ai.djl.ndarray.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.*;
import ai.djl.modality.cv.transform.*;
import ai.djl.modality.cv.translator.*;
import ai.djl.translate.*;
import ai.djl.training.util.*;
import ai.djl.util.*;
Step 1: Prepare your MXNet model
This tutorial assumes that you have an MXNet model trained using Python. An MXNet symbolic model usually contains the following files:
* Symbol file: {MODEL_NAME}-symbol.json - a JSON file that contains the network definition of the model
* Parameters file: {MODEL_NAME}-{EPOCH}.params - a binary file that stores the parameter weights and biases
* Synset file: synset.txt - an optional text file that stores the classification class labels
This tutorial uses a pre-trained MXNet resnet18_v1 model.
We use DownloadUtils to download files from the internet.
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-symbol.json", "build/resnet/resnet18_v1-symbol.json", new ProgressBar());
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/resnet/0.0.1/resnet18_v1-0000.params.gz", "build/resnet/resnet18_v1-0000.params", new ProgressBar());
DownloadUtils.download("https://mlrepo.djl.ai/model/cv/image_classification/ai/djl/mxnet/synset.txt", "build/resnet/synset.txt", new ProgressBar());
Step 2: Load your model
// Load the symbolic MXNet model from build/resnet. "resnet18_v1" matches the prefix of the
// downloaded files (resnet18_v1-symbol.json, resnet18_v1-0000.params).
Path modelDir = Path.of("build", "resnet");
var model = Model.newInstance("resnet");
model.load(modelDir, "resnet18_v1");
Step 3: Create a Translator
// Preprocessing pipeline applied to each input image before it reaches the network:
// center-crop, resize to 224x224, then convert the image to a tensor.
// NOTE(review): 224x224 is assumed to be the input size resnet18_v1 expects - confirm.
Pipeline pipeline = new Pipeline();
pipeline.add(new CenterCrop()).add(new Resize(224, 224)).add(new ToTensor());

// Translator that maps an Image to model input via the pipeline above, and maps the model
// output to Classifications. Class labels are read from synset.txt, and softmax is applied
// so the reported scores are normalized.
// The type arguments must be <Image, Classifications>; the lowercase/attribute form produced
// by an HTML export does not compile.
Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
        .setPipeline(pipeline)
        .optSynsetArtifactName("synset.txt")
        .optApplySoftmax(true)
        .build();
Step 4: Load an image for classification
// Fetch the sample kitten image. The final bare expression (no semicolon) is this
// cell's displayed value: the underlying wrapped image.
Image img = ImageFactory.getInstance()
        .fromUrl("https://resources.djl.ai/images/kitten.jpg");
img.getWrappedImage()
Step 5: Run inference
// Create a predictor that runs `model` through `translator`, classify the sample image,
// and show the resulting Classifications as this cell's output (final bare expression).
// The type arguments must be <Image, Classifications> to match the translator.
// NOTE(review): neither the predictor nor the model is closed here; DJL's Predictor is
// AutoCloseable, so call predictor.close() (and model.close()) once you are done.
Predictor<Image, Classifications> predictor = model.newPredictor(translator);
Classifications classifications = predictor.predict(img);
classifications
Summary
Now, you can load any MXNet symbolic model and run inference.
You might also want to check out load_pytorch_model, which demonstrates loading a local model using the ModelZoo API.