LLaMA 7B with customized preprocessing
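In this tutorial, you will deploy a model to SageMaker with the Large Model Inference (LMI) container from the AWS Deep Learning Containers (DLC) and run inference against it.
Please make sure the following permissions are granted before running the notebook:
- S3 bucket write access (to upload the model artifacts)
- SageMaker access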
Step 1: Upgrade the SageMaker SDK and import dependencies
%pip install sagemaker --upgrade --quiet
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers
role = sagemaker.get_execution_role() # execution role for the endpoint
sess = sagemaker.session.Session() # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name # AWS region of the current SageMaker session
account_id = sess.account_id() # account_id of the current SageMaker Studio environment
Step 2: Prepare the model artifacts
In the LMI container, we expect some artifacts that help set up the model:
- serving.properties (required): defines the model server settings
- model.py (optional): a Python file that defines the core inference logic
- requirements.txt (optional): any additional pip packages to install
%%writefile serving.properties
engine=MPI
option.model_id=openlm-research/open_llama_7b
option.task=text-generation
option.trust_remote_code=true
option.tensor_parallel_degree=1
option.max_rolling_batch_size=32
option.rolling_batch=lmi-dist
option.dtype=fp16
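Because serving.properties is a flat key=value file, a few lines of Python can sanity-check it before it is packaged. The helper below, parse_properties, is a hypothetical name used only for this local check. It is not part of LMI, which does its own parsing. It assumes the simple one-setting-per-line form used above and ignores Java-properties features such as escapes or ":" separators.

```python
def parse_properties(text):
    """Parse simple key=value lines (as in serving.properties) into a dict.

    Hypothetical local helper: blank lines and lines starting with '#'
    are skipped; any other line without '=' is treated as malformed.
    """
    props = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if not sep:
            raise ValueError(f"Malformed line (no '='): {raw_line!r}")
        props[key.strip()] = value.strip()
    return props
```

For example, after the cell above has written the file, parse_properties(open("serving.properties").read())["option.rolling_batch"] returns "lmi-dist".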
In this step, we customize the request preprocessing of the default Hugging Face handler provided by DJLServing by supplying a custom input formatter. It adds an extra parameter called password to the payload and rejects the request if the password is incorrect.
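The password rule can be exercised locally in plain Python, without DJLServing. The sketch below mirrors the check in the model.py cell on a payload that has already been decoded into a dict. check_password and EXPECTED_PASSWORD are hypothetical names used only for this illustration.

```python
EXPECTED_PASSWORD = "12345"  # the demo password hard-coded in model.py


def check_password(payload):
    """Mirror the model.py password check on a decoded JSON payload (a dict).

    Returns (inputs, parameters) when the request is allowed and raises
    ValueError otherwise. The dict is copied so the caller's payload is
    not modified by the pops below.
    """
    payload = dict(payload)
    inputs = payload.pop("inputs", payload)
    password = payload.pop("password", "")
    # Same condition as model.py: an input equal to [""] skips the check.
    if inputs != [""] and password != EXPECTED_PASSWORD:
        raise ValueError("Incorrect password!")
    return inputs, payload.pop("parameters", {})
```

Keeping the check as a pure function of the payload makes it easy to test before the container is built. A hard-coded password is only suitable for a demo like this one.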
%%writefile model.py
from djl_python import Input
from djl_python.encode_decode import decode
from djl_python.input_parser import input_formatter
from djl_python.request_io import TextInput
@input_formatter
def custom_input_formatter(self, input_item: Input, **kwargs) -> TextInput:
    """
    Custom input formatter that checks a password before inference.
    Args:
        input_item (Input): a single request from the incoming batch
    Returns:
        (TextInput): the prompt and generation parameters for the model
    """
    content_type = input_item.get_property("Content-Type")
    input_map = decode(input_item, content_type)
    inputs = input_map.pop("inputs", input_map)
    password = input_map.pop("password", "")
    # password checker: reject non-empty requests without the demo password
    if inputs != [""] and password != "12345":
        raise ValueError("Incorrect password!")
    request_input = TextInput()
    # pass the prompt sent under "inputs" through to the model
    request_input.input_text = inputs
    request_input.parameters = input_map.pop("parameters", {})
    return request_input
%%sh
mkdir -p mymodel
mv serving.properties mymodel/
mv model.py mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel
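The shell cell above can also be written in Python with the standard tarfile module. This is useful when a shell is not available. The sketch below assumes the artifacts sit in a local mymodel/ folder. package_model_dir is a hypothetical helper name.

```python
import tarfile
from pathlib import Path


def package_model_dir(src_dir, archive_path="mymodel.tar.gz"):
    """Pack src_dir into a gzip-compressed tarball.

    Like `tar czvf mymodel.tar.gz mymodel/`, the archive keeps the top-level
    folder name, so members look like mymodel/serving.properties.
    """
    src = Path(src_dir)
    if not (src / "serving.properties").is_file():
        # serving.properties is the one required LMI artifact
        raise FileNotFoundError(f"{src}/serving.properties is missing")
    with tarfile.open(archive_path, "w:gz") as tar:
        tar.add(str(src), arcname=src.name)  # adds the folder recursively
    return archive_path
```

Setting arcname to the folder name keeps absolute or temporary path prefixes out of the archive, so the layout matches what the tar command produces.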
Step 3: Build the SageMaker endpoint
In this step, we will build a SageMaker endpoint from scratch.
Getting the container image URI
image_uri = image_uris.retrieve(
framework="djl-lmi",
region=sess.boto_session.region_name,
version="0.30.0"
)
Upload the artifacts to S3 and create the SageMaker model
s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket() # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")
model = Model(image_uri=image_uri, model_data=code_artifact, role=role)
Step 4: Create the SageMaker endpoint
You need to specify the instance type to use and the endpoint name.
instance_type = "ml.g5.2xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-model")
model.deploy(initial_instance_count=1,
instance_type=instance_type,
endpoint_name=endpoint_name,
# container_startup_health_check_timeout=3600 # seconds; uncomment to allow more time for model download and loading
)
# our requests and responses will be in json format so we specify the serializer and the deserializer
predictor = sagemaker.Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sess,
serializer=serializers.JSONSerializer(),
deserializer=deserializers.JSONDeserializer(),
)
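The predictor is a convenience wrapper around the SageMaker Runtime InvokeEndpoint API. For callers outside the SageMaker SDK, a minimal sketch with a boto3 "sagemaker-runtime" client could look like the code below. It assumes the endpoint returns a JSON body. invoke_json is a hypothetical helper, and the client is passed in so the function can be tried without AWS access.

```python
import json


def invoke_json(runtime_client, endpoint_name, payload):
    """Send a JSON payload to a SageMaker endpoint and decode the JSON reply.

    runtime_client is expected to be boto3.client("sagemaker-runtime");
    any object with a compatible invoke_endpoint method also works.
    """
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps(payload).encode("utf-8"),
    )
    # The response body is a stream; read it fully and parse it as JSON.
    return json.loads(response["Body"].read())
```

Usage, once the endpoint is in service: invoke_json(boto3.client("sagemaker-runtime"), endpoint_name, {"inputs": "Large model inference is", "parameters": {}, "password": "12345"}).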
Step 5: Test and benchmark the inference
First, let's send a request without the password. The custom input formatter should reject it, so this call is expected to fail with an error.
predictor.predict(
{"inputs": "Large model inference is", "parameters": {}}
)
Then let's send the same request with the correct password:
predictor.predict(
{"inputs": "Large model inference is", "parameters": {}, "password": "12345"}
)
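This step is titled "benchmark", but the notebook only sends single requests. A simple sequential latency measurement is sketched below. benchmark is a hypothetical helper that works with any callable, such as predictor.predict. Because it sends one request at a time, it measures single-client latency, not throughput under concurrent load.

```python
import math
import statistics
import time


def benchmark(predict_fn, payload, n_requests=10):
    """Call predict_fn(payload) sequentially and summarize wall-clock latency.

    Returns the mean, median (p50) and nearest-rank p90 latency in seconds.
    """
    if n_requests < 1:
        raise ValueError("n_requests must be at least 1")
    latencies = []
    for _ in range(n_requests):
        start = time.perf_counter()
        predict_fn(payload)
        latencies.append(time.perf_counter() - start)
    latencies.sort()
    p90_rank = math.ceil(0.9 * n_requests)  # 1-based nearest-rank index
    return {
        "mean_s": statistics.mean(latencies),
        "p50_s": statistics.median(latencies),
        "p90_s": latencies[p90_rank - 1],
    }
```

For example: benchmark(predictor.predict, {"inputs": "Large model inference is", "parameters": {}, "password": "12345"}, n_requests=20).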
Clean up the environment
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()