
LLAMA2-13B SmoothQuant with Rolling Batch deployment guide

In this tutorial, you will use a Large Model Inference (LMI) container from the AWS Deep Learning Containers (DLC) on SageMaker and run inference with it.

Please make sure the following permissions are granted before running the notebook:

  • S3 bucket push access
  • SageMaker access

Continuous batching (Blog)

Continuous batching is an optimization specific to text generation. It improves throughput without sacrificing time-to-first-byte latency. Continuous batching (also known as iterative or rolling batching) addresses the challenge of idle GPU time by building on the dynamic batching approach and continuously pushing new requests into the batch. The following diagram shows continuous batching of requests: when requests 2 and 3 finish processing, another set of requests is scheduled.

(Figure: continuous batching of requests)

The following interactive diagram dives deeper into how continuous batching works.

(Figure: interactive diagram of continuous batching)

(Courtesy: https://github.com/InternLM/lmdeploy)
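To make the scheduling idea concrete, here is a minimal, self-contained sketch in plain Python (purely illustrative; it is not the scheduler the LMI container actually uses). Finished requests leave the batch at the end of each step and queued requests immediately take the free slots; the batch size of 4 mirrors the max_rolling_batch_size used later, and the per-request token counts are arbitrary.

from collections import deque

# Purely illustrative simulation of continuous (rolling) batching.
MAX_BATCH_SIZE = 4  # mirrors option.max_rolling_batch_size below

queue = deque({"id": i, "tokens_left": n} for i, n in enumerate([3, 5, 2, 4, 6, 1]))
batch = []
step = 0

while queue or batch:
    # Fill free slots right away instead of waiting for the whole batch
    # to drain -- this is the "rolling" part.
    while queue and len(batch) < MAX_BATCH_SIZE:
        req = queue.popleft()
        print(f"step {step}: request {req['id']} joins the batch (prefill)")
        batch.append(req)

    # One decode step: every active request generates one token.
    for req in batch:
        req["tokens_left"] -= 1

    finished = [r for r in batch if r["tokens_left"] == 0]
    batch = [r for r in batch if r["tokens_left"] > 0]
    for req in finished:
        print(f"step {step}: request {req['id']} finished and leaves the batch")
    step += 1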

You can use a powerful technique to make LLMs and text generation efficient: caching some of the attention matrices. This means that the first pass over a prompt is different from the subsequent forward passes. For the first pass, you have to compute the entire attention matrix, whereas the follow-ups only require computing attention for the newly generated token. The first pass is called prefill throughout this code base, whereas the follow-ups are called decode. Prefill is much more expensive than decode, so we don't want to run it all the time, and queries that are already running are mostly doing decode. To use continuous batching as explained previously, a new request has to run prefill at some point to build the attention (KV) cache it needs before it can join the decode group.

This technique may allow up to a 20-times increase in throughput compared to no batching by effectively utilizing otherwise idle GPU capacity.
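The prefill/decode split can be illustrated with a toy single-head attention example in NumPy (a conceptual sketch only, with made-up dimensions; it is not the kernel the container runs). Prefill computes attention over the whole prompt and caches the keys and values; each decode step then only computes attention for the single new token against the cached keys and values.

import numpy as np

d = 8  # toy head dimension
rng = np.random.default_rng(0)
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Prefill: process the whole prompt once and cache K and V
prompt = rng.standard_normal((5, d))          # 5 prompt "token embeddings"
K_cache, V_cache = prompt @ Wk, prompt @ Wv   # kept around for every later step
prefill_out = softmax((prompt @ Wq) @ K_cache.T / np.sqrt(d)) @ V_cache  # full attention matrix

# Decode: one new token per step, attention computed only for that token
new_token = rng.standard_normal((1, d))
K_cache = np.vstack([K_cache, new_token @ Wk])  # append to the cache
V_cache = np.vstack([V_cache, new_token @ Wv])
decode_out = softmax((new_token @ Wq) @ K_cache.T / np.sqrt(d)) @ V_cache  # a single new row

print(prefill_out.shape, decode_out.shape)  # (5, 8) and (1, 8)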

Step 1: Let's bump up SageMaker and import stuff

%pip install sagemaker --upgrade  --quiet
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess._region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account_id of the current SageMaker Studio environment

Step 2: Start preparing model artifacts

In the LMI container, we expect some artifacts to help set up the model:

  • serving.properties (required): Defines the model server settings
  • model.py (optional): A Python file that defines the core inference logic (an illustrative sketch follows at the end of this step)
  • requirements.txt (optional): Any additional pip packages that need to be installed

%%writefile serving.properties
engine=DeepSpeed
option.model_id=TheBloke/Llama-2-13B-fp16
option.tensor_parallel_degree=1
option.dtype=fp16
option.quantize=smoothquant
option.rolling_batch=deepspeed
option.max_tokens=1024
option.max_rolling_batch_size=4
%%sh
mkdir mymodel
mv serving.properties mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel
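This tutorial ships only serving.properties, so no model.py is packaged and the container's built-in DeepSpeed handler with rolling batch is used. For reference, a custom model.py for the DJL Python engine typically exposes a handle() function; the following is a purely illustrative sketch (the echo logic and response shape are made up), not the handler this example runs.

# model.py -- illustrative only, NOT packaged in this tutorial
from djl_python import Input, Output

def handle(inputs: Input) -> Output:
    """Entry point the DJL Python engine calls for each request."""
    if inputs.is_empty():
        # The server sends an empty request at startup for initialization.
        return None

    payload = inputs.get_as_json()   # e.g. {"inputs": "...", "parameters": {...}}
    prompt = payload.get("inputs", "")

    # A real handler would load the model once and run generation here;
    # this sketch just echoes the prompt back.
    return Output().add_as_json({"generated_text": f"(echo) {prompt}"})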

Step 3: Start building SageMaker endpoint

In this step, we will build the SageMaker endpoint from scratch.

Getting the container image URI

See the available Large Model Inference DLC images for the full list of supported versions.

image_uri = image_uris.retrieve(
    framework="djl-deepspeed",
    region=sess.boto_session.region_name,
    version="0.27.0",
)
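
It can be helpful to print the resolved URI to confirm which image will actually be pulled:

print(f"Image going to be used is ---- > {image_uri}")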

Upload the artifact to S3 and create a SageMaker model

s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

model = Model(image_uri=image_uri, model_data=code_artifact, role=role)

Step 4: Create SageMaker endpoint

You need to specify the instance type to use and the endpoint name.

instance_type = "ml.g5.2xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")

model.deploy(initial_instance_count=1,
             instance_type=instance_type,
             endpoint_name=endpoint_name,
             # container_startup_health_check_timeout=3600
            )

# our requests and responses will be in JSON format, so we specify the serializer and the deserializer
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)

Step 5: Test and benchmark the inference

Once the endpoint is deployed, we can run the following inference:

predictor.predict(
    {"inputs": "Deep Learning is", "parameters": {"max_new_tokens":128, "do_sample":True}}
)
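
Rolling batching only pays off when several requests are in flight at once, so a quick way to benchmark the endpoint is to send a few concurrent requests. The snippet below is a minimal sketch: the prompts are arbitrary and the concurrency of 4 is chosen to match max_rolling_batch_size.

import time
from concurrent.futures import ThreadPoolExecutor

prompts = [
    "Deep Learning is",
    "The capital of France is",
    "Large language models are",
    "Amazon SageMaker is",
]

def run(prompt):
    start = time.time()
    predictor.predict({"inputs": prompt, "parameters": {"max_new_tokens": 128, "do_sample": True}})
    return time.time() - start

# 4 concurrent requests exercise the rolling batch (max_rolling_batch_size=4)
wall_start = time.time()
with ThreadPoolExecutor(max_workers=4) as pool:
    latencies = list(pool.map(run, prompts))
wall_time = time.time() - wall_start

print(f"per-request latencies (s): {[round(l, 2) for l in latencies]}")
print(f"wall time for {len(prompts)} concurrent requests: {wall_time:.2f} s")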

Clean up the environment

sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()