Skip to content

Run this notebook online: Binder

LLaMA 7B with customized preprocessing

In this tutorial, you will deploy the Large Model Inference (LMI) container from the AWS Deep Learning Containers (DLC) to SageMaker and run inference with it.

Please make sure the following permissions are granted before running the notebook:

  • S3 bucket push access
  • SageMaker access

Step 1: Upgrade the SageMaker SDK and import dependencies

%pip install sagemaker --upgrade  --quiet
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # IAM execution role for the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with AWS APIs
# Use the public `boto_region_name` property rather than the private
# `_region_name` attribute, which is an SDK implementation detail.
region = sess.boto_region_name  # region name of the current SageMaker environment
account_id = sess.account_id()  # account ID of the current SageMaker environment

Step 2: Start preparing model artifacts

In the LMI container, we expect some artifacts that help set up the model:

  • serving.properties (required): defines the model server settings
  • model.py (optional): a Python file that defines the core inference logic
  • requirements.txt (optional): any additional pip packages that need to be installed

import os
# Record the model choice in the environment, then read it back to build
# the DJLServing configuration for the LMI container.
os.environ.update({
    'MODEL_ID': "openlm-research/open_llama_7b",
    'HF_TRUST_REMOTE_CODE': "TRUE",
})
model_id = os.environ.get('MODEL_ID')

# Model-server settings: Python engine, fp16 on a single GPU shard, and
# rolling batching with the seq-scheduler. Lines starting with '#' inside
# the text are comments in the generated properties file.
serving_properties = f"""engine=Python
option.model_id={model_id}
option.tensor_parallel_degree=1
option.dtype=fp16
option.model_loading_timeout=3600
option.trust_remote_code=true

# rolling-batch parameters
option.max_rolling_batch_size=32
option.rolling_batch=scheduler

# seq-scheduler parameters
# limits the max_sparsity in the token sequence caused by padding
option.max_sparsity=0.33
# limits the max number of batch splits, where each split has its own inference call
option.max_splits=3
# other options: contrastive, sample
option.decoding_strategy=greedy
# default: true
option.disable_flash_attn=true
"""
with open('serving.properties', 'w') as properties_file:
    properties_file.write(serving_properties)

In this step, we will override the default HuggingFace handler provided by DJLServing. We add an extra request field, password, and reject any request whose payload does not contain the correct password.

%%writefile model.py
from djl_python.huggingface import HuggingFaceService
from djl_python import Output
from djl_python.encode_decode import encode, decode
import logging
import json
import types

# Module-level service instance reused by every call to `handle`; it is
# initialized lazily on the first request and then gets its `parse_input`
# replaced by `custom_parse_input` (see `handle`).
_service = HuggingFaceService()

def custom_parse_input(self, inputs):
    """Parse a DJLServing request batch, rejecting items without the password.

    Replacement for ``HuggingFaceService.parse_input``; ``handle`` binds it
    onto the service instance with ``types.MethodType``. Each batch item is
    decoded according to its Content-Type, its ``password`` field is checked,
    and it is split into prompts and generation parameters.

    Args:
        self: The HuggingFaceService instance this function is bound to.
        inputs: The DJLServing input batch for this invocation.

    Returns:
        A tuple ``(input_data, input_size, parameters, errors, batch)``:
        the flattened list of prompts from all successfully parsed items,
        the number of prompts each of those items contributed, one
        parameters dict per successfully parsed item, a map from batch index
        to error message for items that failed, and the raw batch items.
    """
    import hmac  # local import keeps the cell's top-level imports unchanged

    input_data = []
    input_size = []
    parameters = []
    errors = {}
    batch = inputs.get_batches()
    for i, item in enumerate(batch):
        try:
            content_type = item.get_property("Content-Type")
            input_map = decode(item, content_type)
            _inputs = input_map.pop("inputs", input_map)
            password = input_map.pop("password", "")
            # Password checker. Constant-time comparison avoids leaking how
            # many leading characters matched; non-string passwords fail.
            # A payload whose inputs are exactly [""] skips the check.
            password_ok = isinstance(password, str) and hmac.compare_digest(
                password.encode("utf-8", "surrogatepass"), b"12345")
            if _inputs != [""] and not password_ok:
                raise ValueError("Incorrect password!")
            # Build this item's parameters in a local dict. Indexing
            # `parameters[i]` is wrong once an earlier item has failed,
            # because failed items append nothing to `parameters`.
            item_params = input_map.pop("parameters", {})
            if "cached_prompt" in input_map:
                item_params["cached_prompt"] = input_map.pop("cached_prompt")
            # NOTE(review): this key selection mirrors the upstream djl_python
            # parser; confirm it against the DJLServing version in use.
            seed_key = 'seed' if inputs.is_batch() else f'batch_{i}.seed'
            if item.contains_key(seed_key) and not item_params.get("seed"):
                # set server provided seed if seed is not part of request
                item_params["seed"] = item.get_as_string(key=seed_key)
            prompts = _inputs if isinstance(_inputs, list) else [_inputs]
        except Exception as e:  # pylint: disable=broad-except
            # Per-item boundary: record the failure and keep parsing the rest.
            logging.exception("Parse input failed: %d", i)
            errors[i] = str(e)
            continue
        # Commit only fully-parsed items so input_data / input_size /
        # parameters stay aligned and never include an item listed in errors.
        parameters.append(item_params)
        input_data.extend(prompts)
        input_size.append(len(prompts))

    return input_data, input_size, parameters, errors, batch

def handle(inputs):
    """Request handler for this model.py.

    On the first call, initializes the shared HuggingFace service from the
    request properties and swaps in ``custom_parse_input`` so every later
    request is password-checked. Empty inputs (initialization requests)
    return ``None``; everything else goes to ``_service.inference``.
    """
    first_call = not _service.initialized
    if first_call:
        _service.initialize(inputs.get_properties())
        # Bind the password-checking parser onto this service instance so
        # `_service.inference` uses it instead of the stock parse_input.
        _service.parse_input = types.MethodType(custom_parse_input, _service)

    # An empty request is an initialization request: there is nothing to run.
    return None if inputs.is_empty() else _service.inference(inputs)
%%sh
# Package serving.properties and model.py into mymodel.tar.gz for upload.
# Abort on the first failing command so a missing file (e.g. when this cell
# is re-run) cannot silently produce an incomplete tarball.
set -e
mkdir -p mymodel
mv serving.properties mymodel/
mv model.py mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel

Step 3: Start building SageMaker endpoint

In this step, we will build a SageMaker endpoint from scratch.

Getting the container image URI

Available Large Model Inference DLCs

# Look up the LMI (DJL DeepSpeed) container image for this region. Reuse the
# `region` resolved in Step 1 so every step agrees on the same region source.
image_uri = image_uris.retrieve(
    framework="djl-deepspeed",
    region=region,
    version="0.27.0",
)

Upload the artifact to S3 and create a SageMaker model

# Upload the packaged model directory to the session's default bucket and
# wrap it, together with the LMI image, in a SageMaker Model.
bucket = sess.default_bucket()  # bucket to house artifacts
s3_code_prefix = "large-model-lmi/code"
code_artifact = sess.upload_data(
    path="mymodel.tar.gz",
    bucket=bucket,
    key_prefix=s3_code_prefix,
)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

model = Model(
    image_uri=image_uri,
    model_data=code_artifact,
    role=role,
)

Step 4: Create the SageMaker endpoint

You need to specify the instance type to use and the endpoint name.

# Deployment target: a single GPU instance and a unique endpoint name
# derived from the "lmi-model" base.
instance_type = "ml.g5.2xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")

model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    # Optionally pass container_startup_health_check_timeout=3600 to give the
    # container as long to start as option.model_loading_timeout allows in
    # serving.properties.
)

# Requests and responses are both JSON: serialize the payload with a JSON
# serializer and attach a JSON deserializer so `predict` returns parsed
# Python objects instead of the raw response bytes.
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)

Step 5: Test and benchmark the inference

First, let's send a request without the password; the custom handler should reject it.

# No "password" field: custom_parse_input should reject this request.
predictor.predict({"inputs": "Large model inference is", "parameters": {}})

Then let's send the same request with the correct password.

# Same prompt, now carrying the password the custom handler expects.
predictor.predict(
    {
        "inputs": "Large model inference is",
        "parameters": {},
        "password": "12345",
    }
)

Clean up the environment

# Tear down the endpoint, its endpoint config, and the model created above.
# Predictor.delete_endpoint looks up the endpoint's actual config name and
# deletes that config, rather than assuming it is named after the endpoint.
predictor.delete_endpoint(delete_endpoint_config=True)
model.delete_model()