How to convert your PyTorch model to TorchScript

There are two ways to convert your model to TorchScript: tracing and scripting. We will only demonstrate the first one, tracing, but you can find information about scripting in the PyTorch documentation. When tracing, we use an example input to record the actions taken and capture the model architecture. This works best when your model doesn't have control flow; if your model does have control flow, you will need to use the scripting approach instead, as shown in the sketch below. In DJL, we use tracing to create TorchScript for our ModelZoo models.
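For contrast, here is a minimal sketch of the scripting approach. The toy Gate module and its threshold logic are illustrative only; torch.jit.script compiles the Python source, so both branches of the if statement are preserved, whereas tracing would record only the branch taken by the example input.

import torch

class Gate(torch.nn.Module):
    # Data-dependent control flow: which branch runs depends on the input values.
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        else:
            return -x

# Compile the module's code rather than tracing it with an example input.
scripted = torch.jit.script(Gate())
scripted.save("gate.pt")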

Basic conversion

Here is an example of tracing in action:

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18(pretrained=True)

# Switch the model to eval mode
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")

You can then use the saved model in DJL like any other model.
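Before deploying, it is worth verifying the exported file. A quick sanity check (a sketch reusing the model and example variables from the snippet above) is to reload it with torch.jit.load and compare outputs:

import torch

# Reload the traced model; model and example are defined in the snippet above.
loaded = torch.jit.load("traced_resnet_model.pt")
loaded.eval()

with torch.no_grad():
    expected = model(example)
    actual = loaded(example)

# In eval mode, the traced module should reproduce the original outputs closely.
assert torch.allclose(expected, actual, atol=1e-5)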

Advanced case

In some cases, the method you want to trace is not forward, as with Hugging Face models like the one below:

from transformers import CLIPModel, CLIPProcessor
import torch
from PIL import Image
import requests
model_name = "openai/clip-vit-base-patch32"

model = CLIPModel.from_pretrained(model_name, torchscript=True, return_dict=False)
processor = CLIPProcessor.from_pretrained(model_name)

test_text = "this is a cat"
text_inputs = processor(text=test_text, return_tensors="pt")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image_inputs = processor(images=image, return_tensors="pt")

converted = torch.jit.trace_module(model, {
    'get_text_features': [text_inputs['input_ids'], text_inputs['attention_mask']],
    'get_image_features': [image_inputs['pixel_values']],
})

torch.jit.save(converted, "cliptext.pt")

To trace multiple methods at once, use the torch.jit.trace_module function, which takes a dict mapping each method name to its example inputs.
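As a quick check (again a sketch, reusing the inputs prepared above), each method passed to trace_module is exposed by name on the saved module, so you can reload it and call the entry points directly in Python:

import torch

# Reload the multi-method TorchScript module saved above.
loaded = torch.jit.load("cliptext.pt")

# Each traced method is available on the loaded module by name.
text_features = loaded.get_text_features(
    text_inputs["input_ids"], text_inputs["attention_mask"])
image_features = loaded.get_image_features(image_inputs["pixel_values"])

print(text_features.shape, image_features.shape)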

To run inference with such a model in DJL, you can provide a placeholder NDArray like the one below:

NDArray array = manager.create("");  // manager is your NDManager instance
array.setName("module_method:get_text_features");

This tells DJL which entry point method to invoke; the placeholder NDArray itself will NOT be used during inference.

For more information on exporting TorchScript, see the Loading a TorchScript Model in C++ tutorial in the PyTorch documentation.