Motivation
We often need to run inference on a GPU over data stored, for instance, in Amazon S3. In such scenarios, where we don't need a persistent endpoint, Batch Transform Jobs from AWS SageMaker are a good fit. In this guide, we'll walk through setting up a SageMaker job to perform inference on files stored in S3. We'll execute all the steps with the AWS CLI in bash, so the resources are easy to manage and reproduce later.
Build Custom Container
In the initial step, we craft our custom container and subsequently push it to ECR (Amazon Elastic Container Registry). This approach gives us maximum flexibility in the implementation. For our batch transform job, we design a service exposing two APIs, ping and invocations. Below is app.py:
import os

from flask import Flask, request
from werkzeug.middleware.proxy_fix import ProxyFix

from code.inference import model_fn, predict_fn

app = Flask(__name__)

# Load the model by reading the `SM_MODEL_DIR` environment variable
# passed to the container by SageMaker (usually /opt/ml/model).
model = model_fn(os.environ["SM_MODEL_DIR"])

# As the web application operates behind a proxy (nginx), we include
# this setting in our app.
app.wsgi_app = ProxyFix(
    app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1
)


@app.route("/ping", methods=["GET"])
def ping():
    """
    Healthcheck function.
    """
    return "pong"


@app.route("/invocations", methods=["POST"])
def invocations():
    """
    Function to handle invocations requests.
    """
    if request.headers.get('Content-Type') == 'application/json':
        body = request.json
    else:
        body = request.get_data().decode('utf-8').strip().split('\n')
    return predict_fn(body, model)
This Python script sets up a Flask application and loads a machine learning model from the designated directory. The /ping endpoint serves as a health check, responding with "pong" to HTTP GET requests. The /invocations endpoint processes POST requests, invoking the model prediction function predict_fn with the input data and the loaded model. Depending on the content type of the request, the function handles JSON or raw newline-delimited data accordingly.
As depicted, the service relies on two essential functions: the former, model_fn, is responsible for loading the machine learning model, while the latter, predict_fn, is invoked when actual inference needs to be executed with the loaded model.
Below, we also provide a sample code/inference.py module to briefly explain the internal workings:
import json
import logging
import sys

logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))

from transformers import pipeline


def model_fn(model_dir):
    """
    Function to load the machine learning model.
    """
    pipe = pipeline("ner",
                    model=f'{model_dir}/model',
                    tokenizer=f'{model_dir}/tokenizer')
    return pipe


def predict_fn(data, pipe):
    """
    Function to perform inference using the loaded model.
    """
    if isinstance(data, dict):
        logger.info(f'Data received: {data}')
        data_dict = data
    elif isinstance(data, list):
        logger.info(f'Length of the data: {len(data)}')
        data_dict = {"sentences": [json.loads(el)["sentences"] for el in data]}
    else:
        logger.info(f'Data received: {data}')
        data_dict = json.loads(data)

    ner_sentences = pipe(data_dict["sentences"])

    if isinstance(data, list):
        # Batch transform job
        return [{"ner_sentences": json.dumps(str(ner_el))} for ner_el in ner_sentences]

    # Create response
    return {"ner_sentences": json.dumps(str(ner_sentences))}
- The logger configuration is set up to facilitate viewing logs on CloudWatch, hence it's important to include logging.StreamHandler(sys.stdout) as a handler for your logger.
- model_dir essentially corresponds to the SM_MODEL_DIR variable, which will be set in the Dockerfile later. SageMaker copies model data to this directory (/opt/ml/model by convention).
- Data in the predict_fn function will be received in list format during a batch transform job, hence the need to parse it as demonstrated in the code; a sample input file is sketched right after this list.
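For context, here is a sketch of what an input file for the batch transform job could look like. The file name input.jsonl is just an example, but each line must be a standalone JSON object with a "sentences" key, since predict_fn calls json.loads(el)["sentences"] on every line:
cat > input.jsonl << 'EOF'
{"sentences": "Angela Merkel visited Paris last week."}
{"sentences": "Apple opened a new office in Berlin."}
EOF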
With app.py
and code/inference.py
in place, we can create a Dockerfile as follows to define our container:
FROM python:3.10.4
WORKDIR /app
# Pip install for pytorch
RUN pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
ENV SM_MODEL_DIR /opt/ml/model
ENV PYTHONUNBUFFERED 1
ENTRYPOINT ["gunicorn", "-b", "0.0.0.0:8080", "app:app", "-n"]
This Dockerfile starts from a python:3.10.4 base image, sets the working directory to /app, installs PyTorch with CUDA support, installs the dependencies specified in requirements.txt, copies the project contents into the container, sets the SM_MODEL_DIR and PYTHONUNBUFFERED environment variables, and finally runs the Flask app with gunicorn on port 8080 when the container launches.
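The requirements.txt file itself is not shown in this guide; a plausible minimal version, based purely on the imports used in app.py and code/inference.py (torch is already installed in a separate layer above), might look like this:
cat > requirements.txt << 'EOF'
flask
gunicorn
transformers
EOF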
Below is a bash script that builds the Docker image and pushes it to ECR:
#!/bin/bash
# Define variables
REPO_NAME="your_repository_name"
ACCOUNT_ID="your_aws_account_id"
REGION="your_aws_region"
# Build Docker image
docker build -f Dockerfile -t "$REPO_NAME" .
# Tag the Docker image
docker tag "$REPO_NAME" "$ACCOUNT_ID.dkr.ecr.$REGION.amazonaws.com/$REPO_NAME:latest"
# Log in to AWS ECR
aws ecr get-login-password --region "$REGION" | docker login --username AWS --password-stdin "$ACCOUNT_ID.dkr.ecr.$REGION.amazonaws.com"
# Check if repository exists, if not, create it
aws ecr describe-repositories --repository-names "$REPO_NAME" || aws ecr create-repository --repository-name "$REPO_NAME"
# Push Docker image to ECR
docker push "$ACCOUNT_ID.dkr.ecr.$REGION.amazonaws.com/$REPO_NAME:latest"
- The shebang #!/bin/bash at the top explicitly specifies the interpreter to use.
- Variable names are enclosed in double quotes to prevent word splitting and globbing issues.
- Each step is commented for better readability.
- Variables follow a consistent naming convention.
- The variables are hardcoded so no input is needed during execution; you can still turn them into input parameters if required.
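Before moving on to SageMaker, you can optionally smoke-test the image locally. This is only a sketch: it assumes your unpacked model artifacts (the model/ and tokenizer/ subdirectories that model_fn expects) sit in a local artifacts/ directory, and the extra serve argument stands in for the one SageMaker normally passes to the container:
# Start the container locally, mounting the model artifacts to SM_MODEL_DIR
docker run --rm -d -p 8080:8080 -v "$(pwd)/artifacts:/opt/ml/model" "$REPO_NAME" serve
sleep 10
# Health check
curl http://localhost:8080/ping
# Single JSON inference request
curl -X POST -H "Content-Type: application/json" \
     -d '{"sentences": "Angela Merkel visited Paris last week."}' \
     http://localhost:8080/invocations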
Create Model
Here, "Model" primarily refers to utilizing the Docker image that has been previously constructed and pushed to Amazon ECR (Elastic Container Registry) for operation with SageMaker jobs. We typically configure parameters such as image URI and model path in Amazon S3, where the model is expected to be loaded. To upload your model to S3, you can follow these steps:
aws s3 cp "$MODEL_NAME" "s3://$S3_BUCKET_NAME/$TASK_NAME/"
ModelS3Input="s3://$S3_BUCKET_NAME/$TASK_NAME/$MODEL_NAME"
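One thing to be aware of: SageMaker expects ModelDataUrl to point to a gzipped tar archive of the model artifacts. If yours are not packaged yet, you could create the archive before the upload above; a sketch assuming the artifacts live in local model/ and tokenizer/ directories, matching the paths model_fn expects under /opt/ml/model:
MODEL_NAME="model.tar.gz"
tar -czvf "$MODEL_NAME" model/ tokenizer/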
Using the AWS CLI, you can easily create a model using a configuration JSON file, as demonstrated below. Let's call this file create_model.json
:
{
    "ModelName": "ModelNameInput",
    "PrimaryContainer": {
        "Image": "ImageInput",
        "ModelDataUrl": "ModelS3Input"
    },
    "ExecutionRoleArn": "ExecutionRoleArnInput"
}
and finally, we create the model from this configuration with the AWS CLI.
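A minimal sketch of that call, assuming the placeholders in create_model.json (ModelNameInput, ImageInput, ModelS3Input, ExecutionRoleArnInput) have been replaced with your actual model name, ECR image URI, S3 model location, and execution role ARN:
aws sagemaker create-model --cli-input-json file://create_model.json --region "$REGION"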
Create Batch Transform Job
To create a transform job, you can use a JSON configuration file similar to the one below, named create_transform_model.json
:
{
    "MaxPayloadInMB": 16,
    "ModelName": "ModelNameInput",
    "BatchStrategy": "MultiRecord",
    "TransformInput": {
        "ContentType": "application/jsonl",
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": "S3InputData"
            }
        },
        "SplitType": "Line"
    },
    "TransformJobName": "TransformNameInput",
    "TransformOutput": {
        "S3OutputPath": "S3Output",
        "Accept": "application/jsonl",
        "AssembleWith": "Line"
    },
    "TransformResources": {
        "InstanceCount": 1,
        "InstanceType": "ml.g4dn.xlarge"
    }
}
Here are some key points to note:
- The BatchStrategy is set to "MultiRecord", indicating that the job will split files into multiple records, with the split type defined as "Line".
- Make sure to provide the correct S3 URI for the input data under "S3Uri" in "S3DataSource".
- The instance type is defined as "ml.g4dn.xlarge" under "TransformResources". Note that you may need to request a quota increase for this instance type.
- Other configurations such as MaxPayloadInMB, ModelName, TransformJobName, and the output settings are self-explanatory.
and finally, we launch the job with the AWS CLI.
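Again a sketch, assuming the placeholders in create_transform_model.json (ModelNameInput, TransformNameInput, S3InputData, S3Output) have been replaced with your own values:
aws sagemaker create-transform-job --cli-input-json file://create_transform_model.json --region "$REGION"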
Now you can view your model and job in the SageMaker inference panel and monitor your logs in Amazon CloudWatch.
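If you prefer to stay in the terminal, you can also poll the job status with the CLI; a sketch assuming the job name used in the configuration above:
aws sagemaker describe-transform-job \
    --transform-job-name "TransformNameInput" \
    --query "TransformJobStatus" \
    --region "$REGION"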