BERT text embedding inference deployment guide

In this tutorial, you will deploy the Large Model Inference (LMI) Deep Learning Container (DLC) to SageMaker and run inference with it.

Please make sure the following permission is granted before running the notebook:

  • SageMaker access

Step 1: Let's bump up SageMaker and import stuff

%pip install sagemaker --upgrade --quiet
import sagemaker
from sagemaker.djl_inference.model import DJLModel

role = sagemaker.get_execution_role()  # execution role for the endpoint
session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
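
As a quick sanity check, you can print the region and role the session resolved above (these are the same variables used for the rest of the notebook):

print(f"SageMaker session region: {session.boto_region_name}")  # region used for API calls
print(f"Execution role: {role}")  # IAM role the endpoint will assume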

Step 2: Start building SageMaker endpoint

In this step, we will build a SageMaker endpoint from scratch.

Getting the container image URI

Check out the available images: Large Model Inference available DLC

# Choose a specific version of LMI image directly:
# image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124"
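
If you prefer to resolve the image URI programmatically, recent versions of the SageMaker SDK can look it up for you. A minimal sketch, assuming your SDK release supports the "djl-lmi" framework key and this version:

# from sagemaker import image_uris
# # Assumption: the "djl-lmi" framework key and version below exist in your SDK release
# image_uri = image_uris.retrieve(framework="djl-lmi", region=session.boto_region_name, version="0.31.0")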

Create SageMaker model

You can deploy a model from the Hugging Face Hub or the DJL model zoo.

# model_id = "djl://ai.djl.huggingface.rust/BAAI/bge-base-en-v1.5" # use DJL model zoo model
# model_id = "s3://YOUR_BUCKET" # download model from your s3 bucket
model_id = "BAAI/bge-base-en-v1.5" # model will be download form Huggingface hub

env = {
    "SERVING_BATCH_SIZE": "32", # enable dynamic batch with max batch size 32
    "SERVING_MIN_WORKERS": "1", # make sure min and max Workers are equals when deploy model on GPU
    "SERVING_MAX_WORKERS": "1",
}

model = DJLModel(
    model_id=model_id,
    task="text-embedding",
    #engine="Rust",   # explicitly choose Rust engine if needed
    #image_uri=image_uri,     # choose a specific version of LMI DLC image
    env=env,
    role=role)

Create SageMaker endpoint

You need to specify the instance type to use and the endpoint name.

instance_type = "ml.g4dn.2xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-text-embedding")

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
)
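
Deployment takes several minutes. deploy() waits for the endpoint by default, but if you want to re-check the status later, here is a minimal sketch using the standard boto3 SageMaker client:

import boto3

# Describe the endpoint; status should read "InService" once deployment completes
sm_client = boto3.client("sagemaker", region_name=session.boto_region_name)
status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
print(f"Endpoint status: {status}")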

Step 3: Test and benchmark the inference

Let's try running inference with a single input:

predictor.predict(
    {"inputs": "What is Deep Learning?"}
)

You can also make requests with client-side batching:

predictor.predict(
    {"inputs": ["What is Deep Learning?", "How does Deep Learning work?"]}
)
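
Each input maps to one embedding vector in the response. As a minimal sketch (assuming the response body is a list of float vectors, one per input), you can compare the two sentences with cosine similarity:

import numpy as np

# Assumption: the endpoint returns a list of embedding vectors, one per input
embeddings = predictor.predict(
    {"inputs": ["What is Deep Learning?", "How does Deep Learning work?"]}
)
a, b = np.array(embeddings[0]), np.array(embeddings[1])
cosine = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"Cosine similarity: {cosine:.4f}")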

Clean up the environment

session.delete_endpoint(endpoint_name)
session.delete_endpoint_config(endpoint_name)
model.delete_model()