
DJL BERT Inference Demo

Introduction

In this tutorial, you walk through running inference with DJL on a BERT QA model trained with MXNet and PyTorch. You provide the model with a question and a paragraph containing the answer, and the model then finds the best answer within that paragraph.

Example:

Q: When did BBC Japan start broadcasting?

Answer paragraph:

BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006.
It ceased operations after its Japanese distributor folded.

And it picked the right answer:

A: December 2004

One of the most powerful features of DJL is that it is engine agnostic: you can switch between different backend engines seamlessly (a quick check of which engines are available is shown after the setup below). We showcase BERT QA first with an MXNet pre-trained model, then with a PyTorch model.

Preparation

This tutorial requires the installation of the Java Kernel. To install the Java Kernel, see the README.

// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.28.0
%maven ai.djl.mxnet:mxnet-engine:0.28.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.28.0
%maven ai.djl.pytorch:pytorch-engine:0.28.0
%maven ai.djl.pytorch:pytorch-model-zoo:0.28.0
%maven org.slf4j:slf4j-simple:1.7.36

Import the Java packages by running the following:

import ai.djl.*;
import ai.djl.engine.*;
import ai.djl.inference.*;
import ai.djl.modality.nlp.qa.*;
import ai.djl.repository.zoo.*;
import ai.djl.training.util.*;

Now that all of the prerequisites are complete, start writing code to run inference with this example.
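Before loading any model, you can verify which engines DJL discovered on the classpath and which one is the default. This is a minimal sketch using the ai.djl.engine.Engine class (imported above); the exact output depends on the engine dependencies you added in the Preparation section.

// List every engine DJL registered from the classpath (here: MXNet and PyTorch)
System.out.println("Available engines: " + Engine.getAllEngines());
// The engine DJL uses when a Criteria does not request one explicitly
System.out.println("Default engine: " + Engine.getInstance().getEngineName());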

Load the model and input

First, load the input:

var question = "When did BBC Japan start broadcasting?";
var resourceDocument = "BBC Japan was a general entertainment Channel.\n" +
    "Which operated between December 2004 and April 2006.\n" +
    "It ceased operations after its Japanese distributor folded.";

QAInput input = new QAInput(question, resourceDocument);

Then load the model and its vocabulary. Create a model variable using the ModelZoo as shown in the following code.

Criteria<QAInput, String> criteria = Criteria.builder()
    .optApplication(Application.NLP.QUESTION_ANSWER)
    .setTypes(QAInput.class, String.class)
    .optEngine("MXNet") // For DJL to use the MXNet engine
    .optProgress(new ProgressBar()).build();
ZooModel<QAInput, String> model = criteria.loadModel();

Run inference

Once the model is loaded, you can create a Predictor and run inference as follows:

Predictor<QAInput, String> predictor = model.newPredictor();
String answer = predictor.predict(input);
answer
december 2004
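
Inside the notebook it is fine to leave these objects open, but in a standalone application you would release the underlying native resources when you are done. A minimal sketch, assuming the same criteria and input defined above; ZooModel and Predictor are AutoCloseable, so try-with-resources closes them automatically (the checked exceptions would need to be declared or handled in a regular Java program).

try (ZooModel<QAInput, String> model = criteria.loadModel();
     Predictor<QAInput, String> predictor = model.newPredictor()) {
    String answer = predictor.predict(input); // one forward pass through the BERT QA model
    System.out.println(answer);
}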

Running inference with DJL is that easy. Now, let's try the PyTorch engine by specifying it with Criteria.optEngine("PyTorch") and rerunning the inference code.

var question = "When did BBC Japan start broadcasting?";
var resourceDocument = "BBC Japan was a general entertainment Channel.\n" +
    "Which operated between December 2004 and April 2006.\n" +
    "It ceased operations after its Japanese distributor folded.";

QAInput input = new QAInput(question, resourceDocument);

Criteria<QAInput, String> criteria = Criteria.builder()
    .optApplication(Application.NLP.QUESTION_ANSWER)
    .setTypes(QAInput.class, String.class)
    .optFilter("modelType", "distilbert")
    .optEngine("PyTorch") // Use the PyTorch engine
    .optProgress(new ProgressBar()).build();
ZooModel<QAInput, String> model = criteria.loadModel();
Predictor<QAInput, String> predictor = model.newPredictor();
String answer = predictor.predict(input);
answer
[IJava-executor-0] INFO ai.djl.pytorch.engine.PtEngine - PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization

[IJava-executor-0] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 4

[IJava-executor-0] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 2

december 2004

Summary

Surprisingly, there are no differences between the PyTorch code snippet and the MXNet code snippet. This is the power of DJL: it defines a unified API that lets you switch between backend engines on the fly. Next chapter: Inference with your own BERT: MXNet, PyTorch.
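
If you want to make the engine switch explicit in your own code, one option is to factor the engine name out into a parameter. The helper below is a hypothetical sketch, not part of the tutorial: the answerQuestion name is made up, and the distilbert filter for PyTorch simply mirrors the PyTorch snippet above.

// Hypothetical helper: the only thing that changes between engines is the Criteria.
String answerQuestion(String engine, QAInput input) throws Exception {
    Criteria.Builder<QAInput, String> builder = Criteria.builder()
        .optApplication(Application.NLP.QUESTION_ANSWER)
        .setTypes(QAInput.class, String.class)
        .optEngine(engine)
        .optProgress(new ProgressBar());
    if ("PyTorch".equals(engine)) {
        builder.optFilter("modelType", "distilbert"); // mirrors the PyTorch example above
    }
    try (ZooModel<QAInput, String> model = builder.build().loadModel();
         Predictor<QAInput, String> predictor = model.newPredictor()) {
        return predictor.predict(input);
    }
}

System.out.println(answerQuestion("MXNet", input));   // december 2004
System.out.println(answerQuestion("PyTorch", input)); // december 2004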