

Detecting Pneumonia from X-ray images using Deep Java Library

Disclaimer: this blog post is intended for educational purposes only. The application was developed using experimental code. The result should not be used for any medical diagnoses of pneumonia. This content has not been reviewed or approved by any scientists or medical professionals.

Introduction

In this example, we demonstrate how deep learning (DL) can be used to detect pneumonia from chest X-ray images. This work is inspired by the Chest X-ray Images Challenge on Kaggle and a related paper. In this notebook, we illustrate how artificial intelligence can assist clinical decision making, with a focus on enterprise deployment. This work uses a model trained with Keras and TensorFlow in this Kaggle kernel. Here we focus on generating predictions with that model using Deep Java Library (DJL), an open-source library for building and deploying DL models in Java.

Preparation

This tutorial requires the Java Jupyter kernel. To install the Java kernel, see the documentation.

These are the dependencies we will use:

// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.28.0
%maven ai.djl.tensorflow:tensorflow-api:0.28.0
%maven ai.djl.tensorflow:tensorflow-engine:0.28.0
%maven ai.djl.tensorflow:tensorflow-model-zoo:0.28.0
%maven org.slf4j:slf4j-simple:1.7.36
%%loadFromPOM
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>3.19.2</version>
</dependency>

Import Java packages

import ai.djl.inference.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.*;
import ai.djl.ndarray.*;
import ai.djl.repository.zoo.*;
import ai.djl.translate.*;
import ai.djl.training.util.*;
import ai.djl.util.*;
import java.net.*;
import java.nio.file.*;
import java.util.*;

Set the model URL

var modelUrl = "https://resources.djl.ai/demo/pneumonia-detection-model/saved_model.zip";

Dive deep into Translator

To run inference successfully, we need to define preprocessing logic that turns the input image into the tensor the model expects, and postprocessing logic that turns the model's raw output into a readable result.

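class MyTranslator implements Translator<Image, Classifications> {

    // Output labels, in the same order as the model's output vector
    private static final List<String> CLASSES = Arrays.asList("Normal", "Pneumonia");

    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDManager manager = ctx.getNDManager();
        // Convert the image to an RGB NDArray in HWC (height, width, channel) layout
        NDArray array = input.toNDArray(manager, Image.Flag.COLOR);
        // Resize to 224x224 and scale pixel values from [0, 255] to [0, 1]
        array = NDImageUtils.resize(array, 224).div(255.0f);
        return new NDList(array);
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        // The model returns a single NDArray holding one probability per class
        NDArray probabilities = list.singletonOrThrow();
        return new Classifications(CLASSES, probabilities);
    }

    @Override
    public Batchifier getBatchifier() {
        // Stack individual inputs along a new batch dimension
        return Batchifier.STACK;
    }
}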

As you can see above, the translator resizes the image to 224x224 and normalizes it by dividing each pixel value by 255 before feeding it into the model. At inference time you must apply the same preprocessing that was used during training. In this case, that means matching the Keras training code. After prediction, the model outputs a probability for each class as an NDArray, and processOutput maps these probabilities back to the class names "Normal" and "Pneumonia".
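To make the preprocessing concrete, here is a plain-Java sketch with no DJL dependency. It performs the same two steps as processInput: a bilinear resize to 224x224, followed by dividing each RGB channel by 255. The result is a height x width x channel float array. The names PreprocessSketch, normalize and preprocess are hypothetical. In this sketch, java.awt bilinear scaling stands in for NDImageUtils.resize, so the values are not expected to match DJL bit-for-bit.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

/**
 * Hypothetical plain-Java sketch of MyTranslator.processInput:
 * resize to 224x224, then scale each RGB channel from [0, 255] to [0, 1].
 * Output layout is HWC (height, width, channel) in RGB order.
 * java.awt bilinear scaling is an assumed stand-in for NDImageUtils.resize.
 */
public class PreprocessSketch {
    static final int SIZE = 224;

    /** Splits a packed 0x(AA)RRGGBB pixel into RGB floats in [0, 1]; alpha is ignored. */
    static float[] normalize(int rgb) {
        int r = (rgb >> 16) & 0xFF;
        int g = (rgb >> 8) & 0xFF;
        int b = rgb & 0xFF;
        return new float[] {r / 255.0f, g / 255.0f, b / 255.0f};
    }

    /** Bilinear resize to SIZE x SIZE, then normalize every pixel. */
    static float[][][] preprocess(BufferedImage input) {
        BufferedImage resized = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resized.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(input, 0, 0, SIZE, SIZE, null);
        g.dispose();

        float[][][] out = new float[SIZE][SIZE][];
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                out[y][x] = normalize(resized.getRGB(x, y));
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // A small synthetic image stands in for the chest X-ray
        BufferedImage img = new BufferedImage(32, 48, BufferedImage.TYPE_INT_RGB);
        img.setRGB(0, 0, 0xFF8000);
        float[][][] tensor = preprocess(img);
        System.out.println(tensor.length + "x" + tensor[0].length + "x" + tensor[0][0].length);
    }
}
```

Running main prints the tensor shape, 224x224x3. Whatever tool performs these steps, the key point is the one made above: the scaling and target size must match what the model saw during training.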

At this point, all the preparation work is done, and we can move on to the prediction logic.

Predict using DJL

Load the image

We load a chest X-ray image of an infected lung from the internet:

var imagePath = "https://resources.djl.ai/images/chest_xray.jpg";
var image = ImageFactory.getInstance().fromUrl(imagePath);
image.getWrappedImage();

Load your model

Next, we download the model from modelUrl. The model is stored in the DJL cache directory, so later runs can reuse it instead of downloading it again.

Criteria<Image, Classifications> criteria =
        Criteria.builder()
                .setTypes(Image.class, Classifications.class)
                .optModelUrls(modelUrl)
                .optTranslator(new MyTranslator())
                .optProgress(new ProgressBar())
                .build();
ZooModel<Image, Classifications> model = criteria.loadModel();
[IJava-executor-0] INFO ai.djl.tensorflow.engine.javacpp.LibUtils - Downloading https://publish.djl.ai/tensorflow-2.10.1/linux/cpu/THIRD_PARTY_TF_JNI_LICENSES.gz ...
[IJava-executor-0] INFO ai.djl.tensorflow.engine.javacpp.LibUtils - Downloading https://publish.djl.ai/tensorflow-2.10.1/linux/cpu/LICENSE.gz ...
[IJava-executor-0] INFO ai.djl.tensorflow.engine.javacpp.LibUtils - Downloading https://publish.djl.ai/tensorflow-2.10.1/linux/cpu/libjnitensorflow.so.gz ...
[IJava-executor-0] INFO ai.djl.tensorflow.engine.javacpp.LibUtils - Downloading https://publish.djl.ai/tensorflow-2.10.1/linux/cpu/libtensorflow_framework.so.2.gz ...
[IJava-executor-0] INFO ai.djl.tensorflow.engine.javacpp.LibUtils - Downloading https://publish.djl.ai/tensorflow-2.10.1/linux/cpu/libtensorflow_cc.so.2.gz ...

Run inference

Finally, we create a predictor from the model. The translator we passed to the criteria is applied automatically, so we only need to call the predictor's predict method on the test image.

Predictor<Image, Classifications> predictor = model.newPredictor();
Classifications classifications = predictor.predict(image);

classifications
[
    {"class": "Pneumonia", "probability": 0.63486}
    {"class": "Normal", "probability": 0.36513}
]
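In the output above, Pneumonia is listed first even though it is the second entry in CLASSES. This is because the printed classifications are ordered by probability, highest first. The plain-Java sketch below reproduces that ordering without DJL, using the probabilities printed above as input. RankSketch and rank are hypothetical names used only for illustration. In DJL itself, Classifications.best() returns the top entry directly.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: pair class names with probabilities and sort highest first. */
public class RankSketch {
    static List<Map.Entry<String, Float>> rank(List<String> classes, float[] probabilities) {
        if (classes.size() != probabilities.length) {
            throw new IllegalArgumentException("one probability per class is required");
        }
        List<Map.Entry<String, Float>> ranked = new ArrayList<>();
        for (int i = 0; i < probabilities.length; i++) {
            // Index i of the model output corresponds to CLASSES.get(i)
            ranked.add(Map.entry(classes.get(i), probabilities[i]));
        }
        // Descending by probability, so the most likely class comes first
        ranked.sort(Map.Entry.<String, Float>comparingByValue(Comparator.reverseOrder()));
        return ranked;
    }

    public static void main(String[] args) {
        List<String> classes = List.of("Normal", "Pneumonia");
        float[] probs = {0.36513f, 0.63486f};
        List<Map.Entry<String, Float>> ranked = rank(classes, probs);
        System.out.println("best: " + ranked.get(0).getKey());
    }
}
```

Running main prints best: Pneumonia. Sorting by value rather than relying on index order keeps the top result correct if the label list grows.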