%pip install sagemaker --upgrade --quiet
import sagemaker
from sagemaker.djl_inference.model import DJLModel
role = sagemaker.get_execution_role() # execution role for the endpoint
session = sagemaker.session.Session() # sagemaker session for interacting with different AWS APIs
Step 2: Start building the SageMaker endpoint
In this step, we will build a SageMaker endpoint from scratch.
Getting the container image URI
Check out the available images: Large Model Inference available DLC
# Choose a specific version of the LMI image directly:
# image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.28.0-lmi10.0.0-cu124"
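As a hedged alternative, you can look up the image URI for your region instead of hard-coding it. Note that the "djl-lmi" framework key and the version string below are assumptions; check which DJL framework keys and versions your installed sagemaker SDK supports.
# image_uri = sagemaker.image_uris.retrieve(
#     framework="djl-lmi",               # assumed framework key, verify for your SDK version
#     region=session.boto_region_name,   # resolve the image for the session's region
#     version="0.28.0",                  # assumed version, pin to an available release
# )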
Create the SageMaker model
You can deploy a model from the Hugging Face Hub or the DJL model zoo.
# model_id = "djl://ai.djl.huggingface.onnxruntime/BAAI/bge-base-en-v1.5" # use a DJL model zoo model
# model_id = "s3://YOUR_BUCKET" # download the model from your S3 bucket
model_id = "BAAI/bge-base-en-v1.5" # the model will be downloaded from the Hugging Face Hub
env = {
# "SERVING_BATCH_SIZE": "32", # enable dynamic batch with max batch size 32
"SERVING_MIN_WORKERS": "1", # make sure min and max Workers are equals when deploy model on GPU
"SERVING_MAX_WORKERS": "1",
# "OPTION_OPTIMIZATION=O2", # use OnnxRuntime O2 optimization
}
model = DJLModel(
model_id=model_id,
task="text-embedding",
# engine="OnnxRuntime", # explicitly choose OnnxRuntime engine if needed
#image_uri=image_uri, # choose a specific version of LMI DLC image
env=env,
role=role)
Create the SageMaker endpoint
You need to specify the instance type and the endpoint name.
instance_type = "ml.g4dn.2xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-text-embedding")
predictor = model.deploy(initial_instance_count=1,
instance_type=instance_type,
endpoint_name=endpoint_name,
)
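As an optional sanity check (a minimal sketch, not part of the original notebook), you can confirm the endpoint has reached the "InService" status before sending requests, using the boto3 client held by the SageMaker session:
# Describe the endpoint and print its current status
status = session.sagemaker_client.describe_endpoint(
    EndpointName=endpoint_name
)["EndpointStatus"]
print(f"Endpoint {endpoint_name} status: {status}")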
Step 3: Test and benchmark the inference
Let's try running inference with a single input:
predictor.predict(
{"inputs": "What is Deep Learning?"}
)
You can also send batched requests from the client side:
predictor.predict(
{"inputs": ["What is Deep Learning?", "How does Deep Learning work?"]}
)
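For a rough latency benchmark (a minimal sketch, assuming the endpoint is already InService and using a hypothetical request count), you can time a few repeated requests:
import time

payload = {"inputs": "What is Deep Learning?"}
n_requests = 20  # assumed sample size for this sketch

start = time.perf_counter()
for _ in range(n_requests):
    predictor.predict(payload)  # each call is a full round trip to the endpoint
elapsed = time.perf_counter() - start

print(f"Average latency over {n_requests} requests: {elapsed / n_requests * 1000:.1f} ms")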
Clean up the environment
session.delete_endpoint(endpoint_name)
session.delete_endpoint_config(endpoint_name)
model.delete_model()