DJL BERT Inference Demo¶
Introduction¶
In this tutorial, you walk through running inference using DJL on a BERT QA model, first with a pre-trained MXNet model and then with a PyTorch one. You provide a question and a paragraph containing the answer to the model. The model then finds the best answer within that paragraph.
Example:
Q: When did BBC Japan start broadcasting?
Answer paragraph:
BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006.
It ceased operations after its Japanese distributor folded.
And the model picks the right answer:
A: December 2004
One of the most powerful features of DJL is that it's engine-agnostic: you can switch between backend engines seamlessly. We showcase BERT QA first with a pre-trained MXNet model, then with a PyTorch model.
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/
%maven ai.djl:api:0.28.0
%maven ai.djl.mxnet:mxnet-engine:0.28.0
%maven ai.djl.mxnet:mxnet-model-zoo:0.28.0
%maven ai.djl.pytorch:pytorch-engine:0.28.0
%maven ai.djl.pytorch:pytorch-model-zoo:0.28.0
%maven org.slf4j:slf4j-simple:1.7.36
Import the Java packages by running the following:¶
import ai.djl.*;
import ai.djl.engine.*;
import ai.djl.inference.*;
import ai.djl.modality.nlp.qa.*;
import ai.djl.repository.zoo.*;
import ai.djl.training.util.*;
Now that all of the prerequisites are complete, start writing code to run inference with this example.
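Before loading a model, you can optionally sanity-check the runtime. This is a minimal sketch: Engine.getAllEngines() lists every engine DJL found on the classpath, and Engine.getInstance() resolves the default one.
// Engines discovered on the classpath; should include "MXNet" and "PyTorch"
System.out.println(Engine.getAllEngines());
// The engine DJL picks by default when none is specified
System.out.println(Engine.getInstance().getEngineName());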
Load the model and input¶
First, load the input:
var question = "When did BBC Japan start broadcasting?";
var resourceDocument = "BBC Japan was a general entertainment channel,\n" +
    "which operated between December 2004 and April 2006.\n" +
    "It ceased operations after its Japanese distributor folded.";
QAInput input = new QAInput(question, resourceDocument);
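QAInput is just a small holder pairing the question with its context paragraph. As a quick sanity check, you can read both back (a sketch, assuming the standard getQuestion()/getParagraph() accessors):
// Print the wrapped question and context back out
System.out.println("Q: " + input.getQuestion());
System.out.println("Context: " + input.getParagraph());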
Then load the model and vocabulary. Create a variable model by using the ModelZoo, as shown in the following code.
Criteria<QAInput, String> criteria = Criteria.builder()
.optApplication(Application.NLP.QUESTION_ANSWER)
.setTypes(QAInput.class, String.class)
.optEngine("MXNet") // For DJL to use MXNet engine
.optProgress(new ProgressBar()).build();
ZooModel<QAInput, String> model = criteria.loadModel();
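If more than one model in the zoo matches the application and types, you can narrow the search with metadata filters. The sketch below assumes the MXNet BERT QA model's published metadata declares a "backbone" key with value "bert"; adjust the key/value pairs to whatever the model you target declares.
// Alternative criteria that pin a specific backbone via a metadata filter
Criteria<QAInput, String> filtered = Criteria.builder()
        .optApplication(Application.NLP.QUESTION_ANSWER)
        .setTypes(QAInput.class, String.class)
        .optFilter("backbone", "bert") // filter keys come from the model's metadata
        .optEngine("MXNet")
        .optProgress(new ProgressBar())
        .build();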
Run inference¶
Once the model is loaded, create a Predictor and run inference as follows:
Predictor<QAInput, String> predictor = model.newPredictor();
String answer = predictor.predict(input);
answer
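Both ZooModel and Predictor implement AutoCloseable and hold native resources, so outside a notebook you would normally scope them with try-with-resources. Here is a minimal sketch of the same inference with explicit cleanup (mxModel and qaPredictor are illustrative names):
// Same inference with deterministic cleanup of native resources
try (ZooModel<QAInput, String> mxModel = criteria.loadModel();
     Predictor<QAInput, String> qaPredictor = mxModel.newPredictor()) {
    System.out.println(qaPredictor.predict(input));
}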
Running inference with DJL is that easy. Now, let's try the PyTorch engine by specifying Criteria.optEngine("PyTorch") and rerun the inference code.
var question = "When did BBC Japan start broadcasting?";
var resourceDocument = "BBC Japan was a general entertainment channel,\n" +
    "which operated between December 2004 and April 2006.\n" +
    "It ceased operations after its Japanese distributor folded.";
QAInput input = new QAInput(question, resourceDocument);
Criteria<QAInput, String> criteria = Criteria.builder()
.optApplication(Application.NLP.QUESTION_ANSWER)
.setTypes(QAInput.class, String.class)
.optFilter("modelType", "distilbert")
.optEngine("PyTorch") // Use PyTorch engine
.optProgress(new ProgressBar()).build();
ZooModel<QAInput, String> model = criteria.loadModel();
Predictor<QAInput, String> predictor = model.newPredictor();
String answer = predictor.predict(input);
answer
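Predictor also exposes batchPredict for scoring several inputs in one call. A short sketch, reusing the paragraph above; the second question here is made up for illustration:
import java.util.List;

// Run two questions against the same paragraph in a single batch call
List<QAInput> batch = List.of(
        input,
        new QAInput("When did BBC Japan cease operations?", resourceDocument));
List<String> answers = predictor.batchPredict(batch);
answers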