Serve Multiple Fine-Tuned LoRA Adapters with DJL Serving (Advanced)

This notebook demonstrates how to deploy multiple fine-tuned LoRA adapters with a single copy of the base model on SageMaker, using the DJL Serving Large Model Inference DLC. LoRA (Low-Rank Adaptation) is a technique for fine-tuning large language models. It significantly reduces the number of trainable parameters compared to full fine-tuning while often achieving comparable performance. You can learn more about the technique in the paper "LoRA: Low-Rank Adaptation of Large Language Models" (Hu et al., 2021).

A major benefit of LoRA is that fine-tuned adapters can easily be added to and removed from the base model, which makes switching adapters cheap enough to do at runtime. In this notebook we show how to deploy a SageMaker endpoint with a single base model and multiple LoRA adapters, and how to select a different adapter for each request.

Since LoRA adapters are much smaller than the base model (realistically 100x-1000x smaller), we can deploy an endpoint with a single base model and multiple LoRA adapters using much less hardware than an equivalent number of fully fine-tuned models would require.
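
To make the size gap concrete, here is a rough back-of-the-envelope sketch. It assumes a LLaMA-7B-like shape (hidden size 4096, 32 layers, roughly 6.7 billion parameters) and a hypothetical adapter of rank 16 applied to four square attention projections per layer. These values are illustrative assumptions; the adapters used later in this notebook may be configured differently.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters added by one LoRA pair: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)


# Assumed, illustrative values (not read from any checkpoint).
HIDDEN_SIZE = 4096
NUM_LAYERS = 32
TARGETS_PER_LAYER = 4   # hypothetical: four attention projections per layer
RANK = 16               # hypothetical LoRA rank
BASE_PARAMS = 6.7e9     # approximate size of a 7B-class base model

adapter_params = NUM_LAYERS * TARGETS_PER_LAYER * lora_param_count(
    HIDDEN_SIZE, HIDDEN_SIZE, RANK
)
ratio = BASE_PARAMS / adapter_params

print(f"adapter parameters: {adapter_params:,}")
print(f"base model is roughly {ratio:.0f}x larger")
```

Under these assumptions the adapter holds about 16.8 million parameters, which places it inside the 100x-1000x range mentioned above.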

The example we will work through in this notebook is guided by the multi adapter example in HuggingFace's PEFT library: https://github.com/huggingface/peft/blob/main/examples/multi_adapter_examples/PEFT_Multi_LoRA_Inference.ipynb.

This is the advanced notebook demonstrating the usage of a custom handler. For the basic usage, see the main adapters notebook.

Install Packages and Import Dependencies

!pip install huggingface_hub sagemaker boto3 awscli --upgrade --quiet
import sagemaker
from sagemaker import image_uris
import boto3
import os
import time
import json
from pathlib import Path
from sagemaker.utils import name_from_base
from huggingface_hub import snapshot_download

Download Model Artifacts and Upload to S3

We will be deploying an endpoint with 2 LoRA adapters. These are the models we will be using:

- Base Model: https://huggingface.co/huggyllama/llama-7b
- LoRA Fine Tuned Adapter 1: https://huggingface.co/tloen/alpaca-lora-7b
- LoRA Fine Tuned Adapter 2: https://huggingface.co/22h/cabrita-lora-v0-1

!rm -rf lora-multi-adapter
!mkdir -p lora-multi-adapter/adapters
snapshot_download("tloen/alpaca-lora-7b", local_dir="lora-multi-adapter/adapters/eng_alpaca", local_dir_use_symlinks=False)
snapshot_download("22h/cabrita-lora-v0-1", local_dir="lora-multi-adapter/adapters/portuguese_alpaca", local_dir_use_symlinks=False)

Creating Inference Handler and DJL Serving Configuration

The following files cover the model server configuration (serving.properties) and the custom inference handler (model.py). The custom inference handler is optional; if it is not specified, the default handler from djl-serving is used. This configuration can serve as an example for writing your own inference handler for different models.

The core structure to cover here is the model directory. We include both the base model and LoRA adapters in the model directory like this:

|- model_dir
    |- adapters/
        |--- <adapter_1>/
        |--- <adapter_2>/
        |--- ...
        |--- <adapter_n>/
    |- serving.properties
    |- model.py (optional)

Each directory under adapters/ contains the artifacts for one LoRA adapter. Typically there are two files: adapter_model.bin (or adapter_model.safetensors in newer PEFT versions), which holds the adapter weights, and adapter_config.json, which holds the adapter configuration. These are typically produced by the PEFT library via the PeftModel.save_pretrained() method.
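
Before packaging, it can help to confirm that every adapter directory has the expected layout. The helper below is a minimal sketch, not part of DJL Serving: `describe_adapters` is a hypothetical name, and the config keys it reads (`r`, `base_model_name_or_path`) are the ones PEFT normally writes to adapter_config.json.

```python
import json
from pathlib import Path


def describe_adapters(adapters_dir):
    """Return {adapter_name: summary} for each sub-directory of adapters_dir."""
    summaries = {}
    for adapter_path in sorted(Path(adapters_dir).iterdir()):
        if not adapter_path.is_dir():
            continue  # ignore stray files next to the adapter folders
        config_file = adapter_path / "adapter_config.json"
        config = json.loads(config_file.read_text()) if config_file.exists() else {}
        summaries[adapter_path.name] = {
            "has_config": config_file.exists(),
            "files": sorted(p.name for p in adapter_path.iterdir() if p.is_file()),
            "rank": config.get("r"),
            "base_model": config.get("base_model_name_or_path"),
        }
    return summaries


# Only runs when the directory created earlier in this notebook is present.
if Path("lora-multi-adapter/adapters").is_dir():
    for name, summary in describe_adapters("lora-multi-adapter/adapters").items():
        print(name, summary)
```

The directory names reported here (for example eng_alpaca and portuguese_alpaca) are the adapter names that requests refer to later.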

%%writefile lora-multi-adapter/serving.properties
engine=Python
option.model_id=huggyllama/llama-7b
option.dtype=fp16
option.entryPoint=model.py
option.tensor_parallel_degree=1
load_on_devices=0

%%writefile lora-multi-adapter/model.py
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
from peft import PeftModel
import torch
import os
from djl_python.inputs import Input
from djl_python.outputs import Output
import logging

model = None
tokenizer = None

def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. 
        Write a response that appropriately completes the request. ### Instruction: {instruction} ### Input: {input} 
        ### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the 
        request.### Instruction: {instruction} ### Response:"""


def evaluate(
        instruction,
        adapters,
        input=None,
        max_new_tokens=64,
        **kwargs,
):
    prompts = []
    for inp in instruction:
        prompts.append(generate_prompt(inp, input))
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    input_ids = inputs["input_ids"].to(torch.cuda.current_device())
    attention_mask = inputs["attention_mask"].to(torch.cuda.current_device())
    generation_config = GenerationConfig(num_beams=1, do_sample=False)

    logging.info(f"using adapters: {adapters}")
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            adapters=adapters,
            generation_config=generation_config,
            max_new_tokens=max_new_tokens,
        )
    output = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
    return output


def load_model(model_id):
    model = LlamaForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tokenizer = LlamaTokenizer.from_pretrained(model_id)
    if not tokenizer.pad_token:
        tokenizer.pad_token = '[PAD]'
    logging.info(f"Loaded Base Model {model_id}")
    return model, tokenizer


def register_adapter(inputs: Input):
    """
    Registers lora adapter with the model.
    """
    global model
    adapter_name = inputs.get_property("name")
    adapter_model_id_or_path = inputs.get_property("src")
    logging.info(
        f"Registering adapter {adapter_name} from {adapter_model_id_or_path}")
    if isinstance(model, PeftModel):
        model.load_adapter(adapter_model_id_or_path, adapter_name)
    else:
        model = PeftModel.from_pretrained(model,
                                           adapter_model_id_or_path,
                                           adapter_name)


def handle(inputs: Input):
    global model, tokenizer
    if not model:
        properties = inputs.get_properties()
        model_id = properties.get("model_id")
        model, tokenizer = load_model(model_id)

    if inputs.is_empty():
        return None


    json_inputs = inputs.get_as_json()
    sentence = json_inputs.get("inputs")
    adapters = json_inputs.get("adapters", [])
    generation_kwargs = json_inputs.get("parameters", {})
    outputs = evaluate(sentence, adapters, **generation_kwargs)

    return Output().add_as_json(outputs)

!rm -f model.tar.gz
!rm -rf lora-multi-adapter/.ipynb_checkpoints
!tar czvf model.tar.gz -C lora-multi-adapter .

Create SageMaker Model and Endpoint

role = "arn:aws:iam::125045733377:role/AmazonSageMaker-ExecutionRole-djl"  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
model_bucket = sess.default_bucket()  # bucket to house artifacts
s3_code_prefix = "hf-large-model-djl/lora-multi-adapter"  # folder within bucket where code artifact will go

region = sess.boto_region_name  # public attribute; avoids relying on the private _region_name
account_id = sess.account_id()

s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")
s3_code_artifact_accelerate = sess.upload_data("model.tar.gz", bucket, s3_code_prefix)
inference_image_uri = image_uris.retrieve(
    framework="djl-deepspeed",
    region=region,
    version="0.27.0",
)
model_name_acc = name_from_base("lora-multi-adapter")

# The LoRA adapters feature is a preview feature; the ENABLE_ADAPTERS_PREVIEW environment variable should be set to use it
create_model_response = sm_client.create_model(
    ModelName=model_name_acc,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": inference_image_uri,
                      "ModelDataUrl": s3_code_artifact_accelerate,
                     })
model_arn = create_model_response["ModelArn"]
endpoint_config_name = f"{model_name_acc}-config"
endpoint_name = f"{model_name_acc}-endpoint"

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": model_name_acc,
            "InstanceType": "ml.g5.12xlarge",
            "InitialInstanceCount": 1,
            "ModelDataDownloadTimeoutInSeconds": 1800,
            "ContainerStartupHealthCheckTimeoutInSeconds": 1800,
        },
    ],
)
print(f"endpoint_name: {endpoint_name}")
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=f"{endpoint_name}", EndpointConfigName=endpoint_config_name
)
import time

resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
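
The loop above keeps polling for as long as the status is "Creating" and prints whatever status it ends on. A variant that also enforces a timeout and raises on failure could look like the sketch below. `wait_for_endpoint` is a hypothetical helper; it takes any callable shaped like `sm_client.describe_endpoint`, so it can be exercised without AWS access.

```python
import time


def wait_for_endpoint(describe, endpoint_name, poll_seconds=60, timeout_seconds=3600):
    """Poll until the endpoint is InService; raise if it fails or times out.

    `describe` is any callable with the shape of sm_client.describe_endpoint.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        resp = describe(EndpointName=endpoint_name)
        status = resp["EndpointStatus"]
        if status == "InService":
            return resp
        if status != "Creating":
            # Any other status at creation time (e.g. Failed) is treated as an error.
            reason = resp.get("FailureReason", "no reason given")
            raise RuntimeError(f"Endpoint {endpoint_name} ended in status {status}: {reason}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint {endpoint_name} still Creating after {timeout_seconds}s")
        time.sleep(poll_seconds)


# With a live client this would be called as:
# resp = wait_for_endpoint(sm_client.describe_endpoint, endpoint_name)
```

Passing the describe call in as a parameter is a design choice made here for testability; calling `sm_client.describe_endpoint` directly inside the loop works just as well.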

Make Inference Requests

%%time

response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps({"inputs": ["Tell me about Alpacas", "Invente uma desculpa criativa pra dizer que não preciso ir à festa.", "Tell me about AWS"],
                     "adapters": ["eng_alpaca", "portuguese_alpaca", "eng_alpaca"]}),
    ContentType="application/json",
)

response_model["Body"].read().decode("utf8")
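
The handler in model.py returns the decoded generations through `Output().add_as_json(outputs)`, so the response body should be a JSON array with one string per input, in request order; that shape is an assumption inferred from the handler rather than a documented contract. Because the decoded text includes the prompt, the sketch below keeps only what follows the "### Response:" marker. `pair_generations` is a hypothetical helper, and the sample body is made up for illustration.

```python
import json

RESPONSE_MARKER = "### Response:"


def pair_generations(prompts, adapters, raw_body):
    """Match each generation with its prompt and adapter, trimming the echoed prompt."""
    generations = json.loads(raw_body)
    if len(generations) != len(prompts):
        raise ValueError(f"expected {len(prompts)} generations, got {len(generations)}")
    results = []
    for prompt, adapter, text in zip(prompts, adapters, generations):
        # Keep the text after the last marker; if the marker is absent, keep everything.
        answer = text.split(RESPONSE_MARKER)[-1].strip()
        results.append({"prompt": prompt, "adapter": adapter, "answer": answer})
    return results


# Made-up sample body, standing in for response_model["Body"].read().decode("utf8").
sample_body = json.dumps(["... ### Instruction: Tell me about Alpacas ### Response: A sample answer."])
for item in pair_generations(["Tell me about Alpacas"], ["eng_alpaca"], sample_body):
    print(item["adapter"], "->", item["answer"])
```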

Inference Request Targeting the Base Model Without Any Adapters
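
A request of this kind could be sketched as follows. In model.py, `handle` reads the adapter list with `json_inputs.get("adapters", [])`, so a body without an "adapters" field passes an empty list to `evaluate`. Whether `model.generate` then runs on the base weights alone depends on the PEFT build in the container, which this notebook does not show, so treat that behavior as an assumption to verify. `build_base_model_request` is a hypothetical helper.

```python
import json


def build_base_model_request(prompts):
    """Build a request body that does not name any adapter."""
    return json.dumps({"inputs": list(prompts)})


body = build_base_model_request(["Tell me about Alpacas"])
print(body)

# With a live endpoint, the body would be sent the same way as the request above:
# response_model = smr_client.invoke_endpoint(
#     EndpointName=endpoint_name,
#     Body=body,
#     ContentType="application/json",
# )
# response_model["Body"].read().decode("utf8")
```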

Clean up Resources

sm_client.delete_endpoint(EndpointName=endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)