Detecting Pneumonia from X-ray images using Deep Java Library
Disclaimer: this blog post is intended for educational purposes only. The application was developed using experimental code. The results should not be used for any medical diagnosis of pneumonia. This content has not been reviewed or approved by any scientists or medical professionals.
Introduction
In this example, we demonstrate how deep learning (DL) can be used to detect pneumonia from chest X-ray images. This work is inspired by the Chest X-ray Images Challenge on Kaggle and a related paper. In this notebook, we illustrate how artificial intelligence can assist clinical decision-making, with a focus on enterprise deployment. This work leverages a model trained using Keras and TensorFlow with this Kaggle kernel. In this blog post, we focus on generating predictions with this model using Deep Java Library (DJL), an open-source library for building and deploying DL models in Java.
Preparation
This tutorial requires the Java kernel for Jupyter. To install the Java kernel, see the documentation.
These are the dependencies we will use:
// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/
%maven ai.djl:api:0.28.0
%maven ai.djl.tensorflow:tensorflow-api:0.28.0
%maven ai.djl.tensorflow:tensorflow-engine:0.28.0
%maven ai.djl.tensorflow:tensorflow-model-zoo:0.28.0
%maven org.slf4j:slf4j-simple:1.7.36
%%loadFromPOM
<dependency>
    <groupId>com.google.protobuf</groupId>
    <artifactId>protobuf-java</artifactId>
    <version>3.19.2</version>
</dependency>
Import Java packages
import ai.djl.inference.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.*;
import ai.djl.ndarray.*;
import ai.djl.repository.zoo.*;
import ai.djl.translate.*;
import ai.djl.training.util.*;
import ai.djl.util.*;
import java.net.*;
import java.nio.file.*;
import java.util.*;
Set the model URL
var modelUrl = "https://resources.djl.ai/demo/pneumonia-detection-model/saved_model.zip";
Dive deep into the Translator
To run inference successfully, we need to define preprocessing and post-processing logic: preprocessing converts the input image into the format the model saw during training, and post-processing turns the raw model output into an understandable result.
class MyTranslator implements Translator<Image, Classifications> {

    private static final List<String> CLASSES = Arrays.asList("Normal", "Pneumonia");

    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDManager manager = ctx.getNDManager();
        // Convert the image to a height x width x channels NDArray with 3 color channels
        NDArray array = input.toNDArray(manager, Image.Flag.COLOR);
        // Resize to 224x224 and scale pixel values from [0, 255] to [0, 1]
        array = NDImageUtils.resize(array, 224).div(255.0f);
        return new NDList(array);
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // The model outputs one probability per class, in the order of CLASSES
        NDArray probabilities = list.singletonOrThrow();
        return new Classifications(CLASSES, probabilities);
    }

    @Override
    public Batchifier getBatchifier() {
        // Stack individual inputs into a batch along a new leading dimension
        return Batchifier.STACK;
    }
}
As you can see above, the translator resizes the image to 224x224 and normalizes it by dividing each pixel value by 255, which scales the values into the range [0, 1], before feeding it into the model. At inference time, you need to follow the same preprocessing procedure that was used during training; in this case, that means matching the Keras training code. After running prediction, the model outputs the probability of each class as an NDArray, and processOutput maps these probabilities back to the class names "Normal" and "Pneumonia".
At this point, all the preparation work is done, and we can start working on the prediction logic. First, we load a sample chest X-ray image:
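To make the arithmetic concrete, here is a plain-Java sketch, with no DJL dependency, of two steps the translator delegates to DJL: scaling 8-bit pixel values into [0, 1] and pairing the output probabilities with the class names to pick the most likely one. The class TranslatorMath and its methods are hypothetical illustrations, not DJL API. Resizing is omitted, and the probabilities in main are made-up inputs, not real model output.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper illustrating the arithmetic behind MyTranslator; not part of DJL.
class TranslatorMath {
    static final List<String> CLASSES = Arrays.asList("Normal", "Pneumonia");

    // Preprocessing: scale 8-bit pixel values (0..255) into floats in [0, 1],
    // mirroring array.div(255.0f) in processInput.
    static float[] normalize(int[] pixels) {
        float[] out = new float[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            out[i] = pixels[i] / 255.0f;
        }
        return out;
    }

    // Post-processing: return the label with the highest probability,
    // similar in spirit to Classifications.best().
    static String bestClass(float[] probabilities) {
        if (probabilities.length != CLASSES.size()) {
            throw new IllegalArgumentException("expected one probability per class");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return CLASSES.get(best);
    }

    public static void main(String[] args) {
        float[] scaled = normalize(new int[] {0, 51, 255});
        System.out.println(Arrays.toString(scaled)); // [0.0, 0.2, 1.0]
        System.out.println(bestClass(new float[] {0.1f, 0.9f})); // Pneumonia
    }
}
```

In the notebook itself, DJL performs both steps on NDArrays, so this sketch is only meant to show what the numbers look like.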
var imagePath = "https://resources.djl.ai/images/chest_xray.jpg";
var image = ImageFactory.getInstance().fromUrl(imagePath);
image.getWrappedImage();
Load your model
Next, we download the model from modelUrl. This downloads the model into the local DJL cache.
Criteria<Image, Classifications> criteria =
        Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelUrls(modelUrl)
                .optTranslator(new MyTranslator())
                .optProgress(new ProgressBar())
                .build();
ZooModel<Image, Classifications> model = criteria.loadModel();
Run inference
Finally, we create a predictor from the model; the predictor uses the translator we set in the criteria. Once we have a predictor, we simply call its predict method on the test image.
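Conceptually, predict runs the translator's preprocessing, the model's forward pass, and the translator's post-processing in sequence. The following dependency-free sketch shows that flow under simplifying assumptions: SimpleTranslator, SimplePredictor, PipelineDemo, and the toy "model" function are hypothetical stand-ins, not DJL types, and the sketch ignores batching, NDManager memory handling, and engine details that DJL's real Predictor takes care of.

```java
import java.util.function.Function;

// Hypothetical, simplified stand-in for a DJL Translator (not the real interface).
interface SimpleTranslator<I, O> {
    // Convert raw input into the flat float array the model expects.
    float[] processInput(I input);

    // Convert the model's raw output into a user-facing result.
    O processOutput(float[] modelOutput);
}

// Hypothetical predictor: runs preprocessing, the model, then post-processing.
class SimplePredictor<I, O> {
    private final Function<float[], float[]> model;
    private final SimpleTranslator<I, O> translator;

    SimplePredictor(Function<float[], float[]> model, SimpleTranslator<I, O> translator) {
        this.model = model;
        this.translator = translator;
    }

    O predict(I input) {
        float[] modelInput = translator.processInput(input);
        float[] modelOutput = model.apply(modelInput);
        return translator.processOutput(modelOutput);
    }
}

class PipelineDemo {
    static SimplePredictor<int[], String> build() {
        // Toy "model": returns [1 - mean, mean] of its inputs as two pseudo-probabilities.
        Function<float[], float[]> toyModel = x -> {
            float sum = 0f;
            for (float v : x) {
                sum += v;
            }
            float mean = sum / x.length;
            return new float[] {1f - mean, mean};
        };

        SimpleTranslator<int[], String> translator = new SimpleTranslator<int[], String>() {
            @Override
            public float[] processInput(int[] pixels) {
                // Same scaling idea as MyTranslator: 0..255 -> 0..1
                float[] out = new float[pixels.length];
                for (int i = 0; i < pixels.length; i++) {
                    out[i] = pixels[i] / 255.0f;
                }
                return out;
            }

            @Override
            public String processOutput(float[] probs) {
                // Index 0 -> "Normal", index 1 -> "Pneumonia", as in CLASSES
                return probs[1] > probs[0] ? "Pneumonia" : "Normal";
            }
        };
        return new SimplePredictor<>(toyModel, translator);
    }

    public static void main(String[] args) {
        SimplePredictor<int[], String> predictor = build();
        System.out.println(predictor.predict(new int[] {255, 255, 0})); // Pneumonia
        System.out.println(predictor.predict(new int[] {0, 0, 255}));   // Normal
    }
}
```

The toy model's output carries no medical meaning; only the order of calls mirrors what predictor.predict(image) does with MyTranslator.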
Predictor<Image, Classifications> predictor = model.newPredictor();
Classifications classifications = predictor.predict(image);
classifications