Welcome to this getting started guide. We will use the Hugging Face Inference DLCs and the Amazon SageMaker Python SDK to create a real-time inference endpoint running a Sentence Transformers model for document embeddings. Currently, the SageMaker Hugging Face Inference Toolkit supports the pipeline feature from Transformers for zero-code deployment. This means you can run compatible Hugging Face Transformer models without providing pre- & post-processing code. We therefore only need to provide the environment variables HF_TASK and HF_MODEL_ID when creating our endpoint, and the Inference Toolkit takes care of the rest. This is a great feature if you are working with existing pipelines.
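As a sketch of that zero-code path (not what we do in the rest of this guide), the configuration is just the two variables; the model id and task below are illustrative examples, not the model used later:

```python
# Hypothetical zero-code configuration: the Inference Toolkit reads these
# environment variables and builds a Transformers pipeline for you.
hub = {
    "HF_MODEL_ID": "distilbert-base-uncased-finetuned-sst-2-english",  # example model id from hf.co/models
    "HF_TASK": "text-classification",                                  # example pipeline task
}
# Later passed to the model class as: HuggingFaceModel(env=hub, role=role, ...)
print(hub["HF_TASK"])
```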
If you want to run other tasks, such as creating document embeddings, you can provide the pre- and post-processing code yourself via an inference.py script. The Hugging Face Inference Toolkit allows the user to override the default methods of the HuggingFaceHandlerService.
The custom module can override the following methods:

- model_fn(model_dir) overrides the default method for loading a model. The return value model will be used in predict_fn for predictions. model_dir is the path to your unzipped model.tar.gz.
- input_fn(input_data, content_type) overrides the default method for pre-processing. The return value data will be used in predict_fn for predictions. The inputs are:
  - input_data is the raw body of your request.
  - content_type is the content type from the request header.
- predict_fn(processed_data, model) overrides the default method for predictions. The return value predictions will be used in output_fn.
  - model is the returned value from the model_fn method.
  - processed_data is the returned value from the input_fn method.
- output_fn(prediction, accept) overrides the default method for post-processing. The return value result will be the response to your request (e.g. JSON). The inputs are:
  - predictions is the result from predict_fn.
  - accept is the accept type from the HTTP request header, e.g. application/json.

In this example, we are going to use Sentence Transformers to create sentence embeddings using a mean pooling layer on the raw model output.
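Before moving to the actual example, here is a minimal skeleton of such a custom module. All bodies below are placeholder sketches; the JSON handling only mimics what the toolkit's defaults already do for application/json requests:

```python
import json

def model_fn(model_dir):
    # load and return your model from model_dir (placeholder here)
    return {"name": "dummy-model"}

def input_fn(input_data, content_type):
    # deserialize the raw request body
    assert content_type == "application/json"
    return json.loads(input_data)

def predict_fn(processed_data, model):
    # run the actual prediction (this sketch just echoes the input)
    return {"model": model["name"], "inputs": processed_data["inputs"]}

def output_fn(prediction, accept):
    # serialize the prediction into the response format
    assert accept == "application/json"
    return json.dumps(prediction)

# Simulate one request cycle locally:
model = model_fn("/opt/ml/model")
data = input_fn('{"inputs": "hello"}', "application/json")
pred = predict_fn(data, model)
print(output_fn(pred, "application/json"))
```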
NOTE: You can run this demo in SageMaker Studio, on your local machine, or on SageMaker Notebook Instances.
%pip install sagemaker --upgrade
Install git and git-lfs:
# For notebook instances (Amazon Linux)
!sudo yum update -y
!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash
!sudo yum install git-lfs git -y
# For other environments (Ubuntu)
!sudo apt-get update -y
!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
!sudo apt-get install git-lfs git -y
import sagemaker
import boto3
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")
Couldn't call 'get_role' to get Role ARN from role name philippschmid to get Role path.
sagemaker role arn: arn:aws:iam::558105141721:role/sagemaker_execution_role
sagemaker bucket: sagemaker-us-east-1-558105141721
sagemaker session region: us-east-1
Create custom inference.py script

To use the custom inference script, you need to create an inference.py script. In our example, we are going to overwrite the model_fn to load our sentence transformer correctly and the predict_fn to apply mean pooling.
We are going to use the sentence-transformers/all-MiniLM-L6-v2 model. It maps sentences & paragraphs to a 384-dimensional dense vector space and can be used for tasks like clustering or semantic search.
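Because the embeddings we return later are normalized, semantic similarity between two documents reduces to a dot product. As a plain-Python sketch with made-up 3-dimensional vectors (real embeddings from this model have 384 dimensions):

```python
import math

def cosine_similarity(a, b):
    # dot product divided by the product of the vector norms
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Made-up toy embeddings:
doc_a = [0.1, 0.3, 0.5]
doc_b = [0.1, 0.3, 0.5]
doc_c = [-0.5, 0.1, -0.3]
print(cosine_similarity(doc_a, doc_b))  # identical vectors -> ~1.0
print(cosine_similarity(doc_a, doc_c))  # dissimilar vectors -> lower score
```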
!mkdir code
%%writefile code/inference.py
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
# Helper: Mean Pooling - take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def model_fn(model_dir):
    # load model and tokenizer from the unzipped model.tar.gz
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModel.from_pretrained(model_dir)
    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    # unpack model and tokenizer
    model, tokenizer = model_and_tokenizer

    # tokenize sentences
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # perform mean pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # return dictionary, which will be json serializable
    return {"vectors": sentence_embeddings[0].tolist()}
Overwriting code/inference.py
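To see what mean_pooling computes, here is the same masked average replicated on plain Python lists, using a made-up sequence of two real tokens followed by one padding token:

```python
# token embeddings for one sequence: 3 tokens x 2 dimensions (made-up numbers)
token_embeddings = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last token is padding
attention_mask = [1, 1, 0]  # padding token is masked out

def mean_pool(token_embeddings, attention_mask):
    dims = len(token_embeddings[0])
    n = max(sum(attention_mask), 1)  # avoid division by zero, like the clamp in the torch version
    return [
        sum(emb[d] for emb, m in zip(token_embeddings, attention_mask) if m) / n
        for d in range(dims)
    ]

print(mean_pool(token_embeddings, attention_mask))  # padding is ignored: [2.0, 3.0]
```

The padding embedding never contributes: only the two real tokens are averaged, which is exactly why the attention mask has to be taken into account.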
Create model.tar.gz with inference script and model

To use our inference.py we need to bundle it into a model.tar.gz archive with all our model artifacts, e.g. pytorch_model.bin. The inference.py script will be placed into a code/ folder. We will use git and git-lfs to easily download our model from hf.co/models and upload it to Amazon S3 so we can use it when creating our SageMaker endpoint.
repository = "sentence-transformers/all-MiniLM-L6-v2"
model_id=repository.split("/")[-1]
s3_location=f"s3://{sess.default_bucket()}/custom_inference/{model_id}/model.tar.gz"
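The archive we build in the next steps places inference.py under code/ next to the model artifacts. The following self-contained sketch reproduces that layout with Python's tarfile module and dummy files, so you can check what the archive should contain (the artifact names besides code/inference.py are just examples):

```python
import os
import tarfile
import tempfile

# Build a dummy model directory mirroring the expected layout.
workdir = tempfile.mkdtemp()
os.makedirs(os.path.join(workdir, "code"))
for name in ["config.json", "pytorch_model.bin", "code/inference.py"]:
    with open(os.path.join(workdir, name), "w") as f:
        f.write("dummy")

# Tar the *contents* of the directory, like `tar zcvf model.tar.gz *`
archive = os.path.join(workdir, "model.tar.gz")
with tarfile.open(archive, "w:gz") as tar:
    for name in ["config.json", "pytorch_model.bin", "code"]:
        tar.add(os.path.join(workdir, name), arcname=name)

with tarfile.open(archive, "r:gz") as tar:
    names = sorted(tar.getnames())
print(names)
```

Note that the paths inside the archive are relative (config.json, code/inference.py, ...), not prefixed by a top-level folder; this is why we cd into the model directory before running tar below.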
Download the model from hf.co/models with git clone.

!git lfs install
!git clone https://huggingface.co/$repository
Updated git hooks. Git LFS initialized. Cloning into 'all-MiniLM-L6-v2'... remote: Enumerating objects: 25, done. remote: Counting objects: 100% (25/25), done. remote: Compressing objects: 100% (23/23), done. remote: Total 25 (delta 3), reused 0 (delta 0).00 KiB/s Unpacking objects: 100% (25/25), 308.60 KiB | 454.00 KiB/s, done.
Copy the inference.py into the code/ directory of the model directory.

!cp -r code/ $model_id/code/
Create a model.tar.gz archive with all the model artifacts and the inference.py script.

%cd $model_id
!tar zcvf model.tar.gz *
/Users/philipp/.Trash/all-MiniLM-L6-v2/all-MiniLM-L6-v2 a 1_Pooling a 1_Pooling/config.json a README.md a code a code/inference.py a config.json a config_sentence_transformers.json a data_config.json a modules.json a pytorch_model.bin a sentence_bert_config.json a special_tokens_map.json a tokenizer.json a tokenizer_config.json a train_script.py a vocab.txt
Upload the model.tar.gz to Amazon S3:

!aws s3 cp model.tar.gz $s3_location
upload: ./model.tar.gz to s3://sagemaker-us-east-1-558105141721/custom_inference/all-MiniLM-L6-v2/model.tar.gz
Create custom HuggingFaceModel

After we have created and uploaded our model.tar.gz archive to Amazon S3, we can create a custom HuggingFaceModel class. This class will be used to create and deploy our SageMaker endpoint.
from sagemaker.huggingface.model import HuggingFaceModel
# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
model_data=s3_location, # path to your model and script
role=role, # iam role with permissions to create an Endpoint
transformers_version="4.26", # transformers version used
pytorch_version="1.13", # pytorch version used
py_version='py39', # python version used
)
# deploy the endpoint
predictor = huggingface_model.deploy(
initial_instance_count=1,
instance_type="ml.g4dn.xlarge"
)
-----------!
Request inference using the HuggingFacePredictor

The .deploy() call returns a HuggingFacePredictor object which can be used to request inference.
data = {
"inputs": "the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .",
}
res = predictor.predict(data=data)
print(res)
{'vectors': [0.005078191868960857, -0.0036594511475414038, 0.016988741233944893, -0.0015786211006343365, 0.030203675851225853, 0.09331899881362915, ..., 0.01087606605142355, -0.006803632713854313, 0.05783495306968689]}
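The endpoint can also be invoked without the SageMaker SDK, e.g. from another service, by sending the same JSON body over the SageMaker runtime API. The build_request helper below is hypothetical, and the commented boto3 call assumes the endpoint deployed above is still running:

```python
import json

def build_request(sentence):
    # same JSON shape the HuggingFacePredictor serializes for us
    return json.dumps({"inputs": sentence})

# With boto3, the raw call would look roughly like this (not executed here):
# runtime = boto3.client("sagemaker-runtime")
# response = runtime.invoke_endpoint(
#     EndpointName=predictor.endpoint_name,
#     ContentType="application/json",
#     Body=build_request("the mesmerizing performances of the leads ..."),
# )
# vectors = json.loads(response["Body"].read())["vectors"]

print(build_request("hello"))
```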
To clean up, we can delete the model and endpoint.
predictor.delete_model()
predictor.delete_endpoint()