Detecting Pneumonia from X-ray images using Deep Java Library

Disclaimer: this blog post is intended for educational purposes only. The application was developed using experimental code. The result should not be used for any medical diagnoses of pneumonia. This content has not been reviewed or approved by any scientists or medical professionals.

Introduction

In this example, we demonstrate how deep learning (DL) can be used to detect pneumonia from chest X-ray images. This work is inspired by the Chest X-ray Images Challenge on Kaggle and a related paper. In this notebook, we illustrate how artificial intelligence can assist clinical decision making, with a focus on enterprise deployment. This work leverages a model trained using Keras and TensorFlow with this Kaggle kernel. In this blog post, we will focus on generating predictions with this model using Deep Java Library (DJL), an open source library to build and deploy DL in Java.

Preparation

This tutorial requires the Java Kernel to be installed. To install it, see the documentation.

These are the dependencies we will use:

// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.8.0
%maven ai.djl.tensorflow:tensorflow-api:0.8.0
%maven ai.djl.tensorflow:tensorflow-engine:0.8.0
%maven ai.djl.tensorflow:tensorflow-model-zoo:0.8.0
%maven org.bytedeco:javacpp:1.5.4
%maven org.slf4j:slf4j-simple:1.7.26

// See https://github.com/awslabs/djl/blob/master/tensorflow/tensorflow-engine/README.md
// for more TensorFlow library selection options
%maven ai.djl.tensorflow:tensorflow-native-auto:2.3.1
%%loadFromPOM
<dependency>
    <groupId>com.google.protobuf</groupId>
    <artifactId>protobuf-java</artifactId>
    <version>3.8.0</version>
</dependency>

Import Java packages

import ai.djl.inference.*;
import ai.djl.modality.*;
import ai.djl.modality.cv.*;
import ai.djl.modality.cv.util.*;
import ai.djl.ndarray.*;
import ai.djl.repository.zoo.*;
import ai.djl.translate.*;
import ai.djl.training.util.*;
import ai.djl.util.*;
import java.net.*;
import java.nio.file.*;
import java.util.*;

Set the model URL

var modelUrl = "https://djl-ai.s3.amazonaws.com/resources/demo/pneumonia-detection-model/saved_model.zip";

Dive deep into Translator

To run inference successfully, we need to define preprocessing and postprocessing logic so that the model receives input in the form it expects and produces output that is easy to understand.

class MyTranslator implements Translator<Image, Classifications> {

    private static final List<String> CLASSES = Arrays.asList("Normal", "Pneumonia");

    @Override
    public NDList processInput(TranslatorContext ctx, Image input) {
        NDManager manager = ctx.getNDManager();
        // Convert the image to a color NDArray
        NDArray array = input.toNDArray(manager, Image.Flag.COLOR);
        // Resize to 224x224 and divide pixel values by 255 to match the Keras training code
        array = NDImageUtils.resize(array, 224).div(255.0f);
        return new NDList(array);
    }

    @Override
    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        NDArray probabilities = list.singletonOrThrow();
        return new Classifications(CLASSES, probabilities);
    }

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
}

As you can see above, the translator resizes the image to 224x224 and normalizes it by dividing by 255 before feeding it into the model. When doing inference, you need to follow the same preprocessing procedure that was used during training. In this case, we need to match the Keras training code. After running prediction, the model outputs the probability of each class as an NDArray. We need to tell the predictor to translate it back to classes, namely "Normal" or "Pneumonia".

At this point, all preparation work is done, and we can start working on the prediction logic.
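
To make the arithmetic behind these two steps concrete, here is a minimal plain-Java sketch that needs no DJL dependency. It scales 8-bit pixel values by 255 and picks the label with the highest probability. The class `TranslatorSketch` and its methods are hypothetical helpers written for illustration only; they are not part of DJL, and the sample values in `main` are made up.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper that mirrors the translator's arithmetic (not a DJL API). */
final class TranslatorSketch {

    static final List<String> CLASSES = Arrays.asList("Normal", "Pneumonia");

    /** Scales 8-bit pixel values (0..255) by 1/255, like div(255.0f) in the translator. */
    static float[] normalize(int[] pixels) {
        float[] scaled = new float[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            scaled[i] = pixels[i] / 255.0f;
        }
        return scaled;
    }

    /** Returns the label whose probability is highest. */
    static String bestClass(float[] probabilities) {
        if (probabilities.length != CLASSES.size()) {
            throw new IllegalArgumentException("Expected one probability per class");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return CLASSES.get(best);
    }

    public static void main(String[] args) {
        // Made-up sample values, for illustration only
        System.out.println(Arrays.toString(normalize(new int[] {0, 128, 255})));
        System.out.println(bestClass(new float[] {0.1f, 0.9f}));
    }
}
```

In the actual notebook, DJL's `Classifications` object handles the mapping from probabilities to labels, so this helper only shows what happens conceptually.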

Predict using DJL

Load the image

We are going to load a chest X-ray image of an infected lung from the internet.

var imagePath = "https://djl-ai.s3.amazonaws.com/resources/images/chest_xray.jpg";
var image = ImageFactory.getInstance().fromUrl(imagePath);
image.getWrappedImage();