Stable Diffusion in DJL

Stable Diffusion is an open-source model developed by Stability.ai. It produces images (artwork, pictures, etc.) from an input sentence and, optionally, an input image.

This example is a basic reimplementation of Stable Diffusion in Java. It can be run on CPU or GPU using the PyTorch engine.

Java solution developed by:

  • Tyler (GitHub: tosterberg)
  • Calvin (GitHub: mymagicpower)
  • Qing (GitHub: lanking520)

Model Architecture

We took four components from the original Stable Diffusion model and traced them in PyTorch (the sketch after this list shows how they fit together at generation time):

  • Text Encoder: The CLIP text encoder that turns the prompt into text embeddings
  • Image Encoder: The VAE encoder that maps an image into a latent embedding
  • Image Decoder: The VAE decoder that converts a latent embedding back into an image
  • UNet executor: The denoising network that iteratively refines the latent during generation
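
At generation time these four pieces are chained together: the prompt is encoded once, the UNet and a noise scheduler iteratively denoise a random latent, and the VAE decoder turns the final latent into pixels. The Python sketch below illustrates that data flow using the same Hugging Face diffusers/transformers classes as the conversion script further down; the guidance scale, step count, and latent size are illustrative assumptions, not values taken from the Java implementation.

import torch
from diffusers import EulerDiscreteScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, AutoTokenizer

model_id = "stabilityai/stable-diffusion-2-1"
tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")

prompt = "Photograph of an astronaut riding a horse in desert"
guidance_scale, steps = 7.5, 25  # illustrative values

with torch.no_grad():
    # 1. Text Encoder: prompt (plus empty prompt for classifier-free guidance) -> embeddings
    tokens = tokenizer([prompt, ""], padding="max_length",
                       max_length=tokenizer.model_max_length,
                       truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(tokens.input_ids)[0]  # (2, 77, hidden)

    # 2. UNet executor: iterative denoising of a random latent, driven by the scheduler
    latents = torch.randn((1, 4, 64, 64)) * scheduler.init_noise_sigma
    scheduler.set_timesteps(steps)
    for t in scheduler.timesteps:
        latent_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
        noise_pred = unet(latent_input, t, encoder_hidden_states=text_embeddings).sample
        cond, uncond = noise_pred.chunk(2)
        noise_pred = uncond + guidance_scale * (cond - uncond)
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # 3. Image Decoder: final latent -> image tensor in [0, 1]
    image = vae.decode(latents / 0.18215).sample
    image = (image / 2 + 0.5).clamp(0, 1)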

Getting started

We recommend running the model on GPU devices because CPU generation is slow. To run this example, just do:

cd examples
./gradlew run -Dmain=ai.djl.examples.inference.stablediffusion.ImageGeneration

Our input prompt is:

Photograph of an astronaut riding a horse in desert

Output: a generated image of an astronaut riding a horse in the desert.

Conversion script

Use the script below to trace and export the model components:

from diffusers import EulerDiscreteScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, AutoTokenizer
import torch

model_id = "stabilityai/stable-diffusion-2-1"

scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torchscript=True, return_dict=False)
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torchscript=True, return_dict=False)
unet_model = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torchscript=True, return_dict=False)
tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer")

# Switch all modules to eval mode before tracing.
vae.eval()
unet_model.eval()
text_encoder.eval()

prompt = "Astronaut riding a horse"
tokenized = tokenizer(prompt,
                      padding="max_length", max_length=tokenizer.model_max_length,
                      truncation=True, return_tensors="pt")

# Run the text encoder once to get example embeddings, then trace it.
text_embeddings = text_encoder(**tokenized)
traced_text = torch.jit.trace(text_encoder, (tokenized["input_ids"], tokenized["attention_mask"]))

# Example UNet inputs: latents and embeddings duplicated for classifier-free guidance.
latents = torch.randn((1, 4, 64, 64))
latent_model_input = torch.cat([latents] * 2)
text_embeddings = torch.stack([text_embeddings[0], text_embeddings[0]]).squeeze()


# Wrap forward so the traced UNet takes plain positional tensors and returns a tuple.
def forward(self, sample, timestep, encoder_hidden_states,
            class_labels=None, return_dict: bool = False):
    return UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states,
                                        class_labels=class_labels, return_dict=return_dict)


unet_model.forward = forward.__get__(unet_model, UNet2DConditionModel)
traced_unet = torch.jit.trace(unet_model, (latent_model_input, torch.tensor([981]), text_embeddings))


# The traced encoder returns the latent distribution moments (mean and log-variance);
# sampling from that distribution must be done by the caller.
def encode(self, x: torch.FloatTensor):
    h = self.encoder(x)
    return self.quant_conv(h)


def decode(self, z: torch.FloatTensor, return_dict: bool = False):
    return AutoencoderKL.decode(self, z, return_dict)


vae.encode = encode.__get__(vae, AutoencoderKL)
vae.decode = decode.__get__(vae, AutoencoderKL)

# Trace both VAE entry points in one module, then save the three traced artifacts.
traced_vae = torch.jit.trace_module(vae, {"encode": [torch.ones((1, 3, 512, 512), dtype=torch.float32)],
                                          "decode": [latents]})
torch.jit.save(traced_vae, "vae_model.pt")
torch.jit.save(traced_unet, "unet_model.pt")
torch.jit.save(traced_text, "clip_model.pt")
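
As a quick sanity check (not part of the export script itself), the traced artifacts can be reloaded and run once on dummy inputs shaped like the tracing examples above; the expected shapes below assume the 512x512 image / 64x64 latent configuration used in this script.

import torch

# Reload the three traced modules the way a consumer such as the DJL example would.
clip = torch.jit.load("clip_model.pt")
unet = torch.jit.load("unet_model.pt")
vae = torch.jit.load("vae_model.pt")

with torch.no_grad():
    # Text encoder: (1, 77) token ids + mask -> (1, 77, 1024) hidden states
    input_ids = torch.ones((1, 77), dtype=torch.long)
    attention_mask = torch.ones((1, 77), dtype=torch.long)
    embeddings = clip(input_ids, attention_mask)[0]

    # UNet: (2, 4, 64, 64) latents, timestep, (2, 77, 1024) embeddings -> (2, 4, 64, 64) noise prediction
    noise_pred = unet(torch.randn((2, 4, 64, 64)), torch.tensor([981]), embeddings.repeat(2, 1, 1))[0]

    # VAE: encode gives (1, 8, 64, 64) distribution moments, decode gives a (1, 3, 512, 512) image
    moments = vae.encode(torch.ones((1, 3, 512, 512), dtype=torch.float32))
    image = vae.decode(torch.randn((1, 4, 64, 64)))[0]

    print(embeddings.shape, noise_pred.shape, moments.shape, image.shape)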