

PySDK instructions for using the LMI container on SageMaker

In this tutorial, you will deploy an LMI container from the AWS Deep Learning Containers (DLC) to SageMaker and run inference with it.

Please make sure the following permissions are granted before running the notebook:

  • SageMaker access

Step 1: Let's upgrade the SageMaker SDK and import dependencies

%pip install sagemaker boto3 awscli --upgrade --quiet
import sagemaker
from sagemaker.djl_inference.model import DJLModel

role = sagemaker.get_execution_role()  # execution role for the endpoint
session = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs

Step 2: Start building SageMaker endpoint

In this step, we will build a SageMaker endpoint from scratch.

Getting the container image URI (optional)

Check out the available images: Large Model Inference available DLC

# Choose a specific version of LMI image directly:
# image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124"
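
If you prefer to resolve the image URI programmatically, the SageMaker SDK can look it up for the current region. This is a minimal sketch, assuming a recent sagemaker release that knows the djl-lmi framework key; the version string is only an example, pick one from the DLC list above.

# Optional: look up an LMI image URI for the current region.
# Assumption: the "djl-lmi" framework key is available in your sagemaker SDK version.
from sagemaker import image_uris

image_uri = image_uris.retrieve(
    framework="djl-lmi",              # LMI container family
    region=session.boto_region_name,  # region of the SageMaker session created above
    version="0.28.0",                 # example version, see the DLC list above
)
print(image_uri)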

Create SageMaker model

Here we are using the LMI PySDK to create the model.

Check out more configuration options.

# model_id = "s3://YOUR_BUCKET"            # download the model from your S3 bucket
model_id = "TheBloke/Llama-2-7B-Chat-fp16" # the model will be downloaded from the Hugging Face Hub

env = {
    "TENSOR_PARALLEL_DEGREE": "max",          # shard the model across all available GPUs
    "OPTION_ROLLING_BATCH": "vllm",           # use vLLM for rolling batching
    # "OPTION_MODEL_LOADING_TIMEOUT": "2400", # model loading timeout in seconds, default 1800
    # "OPTION_XXXX": "XXX",                   # set other model-specific options
}

model = DJLModel(
    model_id=model_id,
    #image_uri=image_uri, # choose a specific version of LMI DLC image
    env=env,
    role=role)

Create SageMaker endpoint

You need to specify the instance type to use and the endpoint name.

instance_type = "ml.g4dn.4xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    # container_startup_health_check_timeout=3600
)
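
Once the endpoint is in service, application code can also call it through the low-level SageMaker runtime API instead of the SDK predictor. A minimal sketch using boto3, with the same JSON payload format used in the next step:

import json
import boto3

# Invoke the deployed endpoint directly via the SageMaker runtime API.
runtime = boto3.client("sagemaker-runtime")
response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Body=json.dumps({"inputs": "Large model inference is", "parameters": {}}),
)
print(response["Body"].read().decode("utf-8"))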

Step 3: Test and benchmark the inference

%%timeit -n3 -r1
predictor.predict(
    {"inputs": "Large model inference is", "parameters": {}}
)
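
The parameters field carries generation options for the rolling batch backend. The values below are illustrative assumptions rather than recommended settings; adjust them for your model and use case.

# Example request with explicit generation parameters (values are illustrative).
predictor.predict(
    {
        "inputs": "Large model inference is",
        "parameters": {
            "max_new_tokens": 128,  # cap the number of generated tokens
            "temperature": 0.7,     # sampling temperature
            "do_sample": True,      # sample instead of greedy decoding
        },
    }
)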

Clean up the environment

session.delete_endpoint(endpoint_name)
session.delete_endpoint_config(endpoint_name)
model.delete_model()