This notebook shows how to serve a large language model on Vast's GPU platform using TGI, Hugging Face's open-source inference framework. TGI is particularly easy to use if you're familiar with Hugging Face: it automatically batches queries for you and is compatible with the OpenAI API. This notebook is adapted from our vLLM guide so that you can see exactly what needs to change between the two.
The commands in this notebook can be run here, or copied and pasted into your terminal (minus the %%bash). At the end, we include a way to query your TGI service either in Python or with a curl request from the terminal.
%%bash
# In an environment of your choice
pip install --upgrade vastai
%%bash
# Here we will set our Vast API key
vastai set api-key <Your-Vast-API-Key-Here>
# Here we set a Hugging Face token; we'll pass it to the instance when we create it
export HF_TOKEN="<Your-HuggingFace-Token-Here>"
Now we are going to look for GPUs on Vast. The model that we are using is quite small, but to allow for easily swapping in whichever model you desire, we will select machines that satisfy the following:
TGI primarily serves one copy of a model, so a single GPU with at least 40 GB of GPU RAM is enough.
The TGI image is based upon a CUDA 12.1 base image, so we need CUDA 12.1 or newer.
%%bash
vastai search offers 'compute_cap >= 800 gpu_ram >= 40 num_gpus = 1 static_ip=true direct_port_count > 1 cuda_vers >= 12.1'
Copy and paste the id of a machine that you would like to use below for <instance-id>, and do the same for <hf-token> with the Hugging Face token from above.
We will create this instance with the text-generation-inference:latest image. This image gives us a TGI server that is compatible with the OpenAI SDK, which means it can slot into any application that uses the OpenAI API. All you need to change in your app is the base_url and the model_id so that requests are properly routed to your model.
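For example, once your instance is running (we set it up in the next steps), an app that already uses the OpenAI Python SDK only needs changes along these lines. This is a sketch: the address and port are placeholders for your instance's forwarded IP and port, and the api_key value can be any string unless you configured TGI to require a key.
from openai import OpenAI
# Point the SDK at your TGI instance instead of api.openai.com
client = OpenAI(base_url="http://XX.XX.XXX.XX:YYYY/v1", api_key="key")
# Use the model_id you are serving on the instance
response = client.completions.create(
    model="text-generation-inference/gemma-7b-it-medusa",
    prompt="Hello, how are you?",
    max_tokens=50,
)
print(response.choices[0].text)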
This command also exposes port 8000 in the Docker container, the port our OpenAI-compatible server will listen on, and tells the container to automatically download and serve the text-generation-inference/gemma-7b-it-medusa model. You can change the model by using any Hugging Face model ID; we chose this one because it is fast to download and start playing with.
We use Vast's --args option to funnel the rest of the command to the container: in this case --model-id text-generation-inference/gemma-7b-it-medusa, which TGI uses to download the model, and --port 8000 to ensure that TGI is listening on the right port. --speculate 2 is where the magic happens: it enables speculative decoding, using the Medusa heads bundled with this model to speculate 2 tokens ahead at each step.
%%bash
vastai create instance <instance-id> --image ghcr.io/huggingface/text-generation-inference:latest --env '-p 8000:8000 -e HF_TOKEN=<hf-token>' --disk 60 --args --port 8000 --model-id text-generation-inference/gemma-7b-it-medusa --speculate 2
Now, we need to verify that our setup is working. We first need to wait for our machine to download the image and the model and start serving. This will take a few minutes. The logs will show you when it's done.
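If you'd rather watch from the terminal, you can also fetch the instance logs with the CLI (assuming the vastai logs subcommand is available in your CLI version; otherwise use the logs button in the web UI):
%%bash
vastai logs <instance-id>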
Then, at the top of the instance, there is a button with an IP address on it. Click it and a panel will appear showing the IP address and the forwarded ports. You should see something like:
Open Ports
XX.XX.XXX.XX:YYYY -> 8000/tcp
Copy and paste the IP address and the port into the curl command below.
This curl command sends an OpenAI-compatible request to your TGI server. If everything is set up correctly, you should see a response.
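Here is the request for the Medusa model we just launched (a sketch mirroring the curl request used later in this guide; replace [BASE_URL] and [PORT] with your instance's IP address and port):
%%bash
curl -X POST http://[BASE_URL]:[PORT]/v1/completions -H "Content-Type: application/json" -d '{"model" : "text-generation-inference/gemma-7b-it-medusa", "prompt": "Hello, how are you?", "max_tokens": 50}'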
%%bash
pip install openai torch transformers datasets
Next, we're going to download the lmsys dataset and load up the functions for performance testing Medusa (speculate=2) with Google's Gemma 7B.
import time
from openai import OpenAI
from transformers import AutoTokenizer
import datasets
def save_dataset():
    ds = datasets.load_dataset("lmsys/lmsys-chat-1m")
    ds.save_to_disk("~/.cache/lmsys/lmsys-chat-1m")
# Load the dataset, get the first 10 samples
def load_dataset():
    dataset = datasets.load_from_disk("~/.cache/lmsys/lmsys-chat-1m")
    # Each record's "conversation" is a list of {"role", "content"} turns;
    # use the first turn of the first 10 conversations as our prompt strings
    return [conv[0]["content"] for conv in dataset["train"][:10]["conversation"]]
# Load the tokenizer for our model
def load_tokenizer(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer
# Set up an OpenAI client using the base_url and port
def setup_openai_client(base_url, port, api_key):
    client = OpenAI(base_url=f"{base_url}:{port}/v1", api_key=api_key)
    return client
# Function to run samples and measure performance
def run_samples(samples, model_id, openai_client, tokenizer):
    latencies = []
    input_tokens = []
    output_tokens = []
    for sample in samples:
        start_time = time.time()
        # Tokenize the input
        input_ids = tokenizer.encode(sample, return_tensors="pt")
        input_token_count = len(input_ids[0])
        # Run the model
        response = openai_client.completions.create(model=model_id, prompt=sample)
        end_time = time.time()
        latency = end_time - start_time
        # Collect metrics
        latencies.append(latency)
        input_tokens.append(input_token_count)
        output_ids = tokenizer.encode(response.choices[0].text, return_tensors="pt")
        output_token_count = len(output_ids[0])
        output_tokens.append(output_token_count)
    # Calculate average latency, throughput, and total tokens
    avg_latency = sum(latencies) / len(latencies)
    total_input_tokens = sum(input_tokens)
    total_output_tokens = sum(output_tokens)
    throughput = total_output_tokens / sum(latencies)
    return avg_latency, throughput, total_input_tokens, total_output_tokens
save_dataset()
Now we'll performance test Medusa. Replace base_url and port below with the IP address and port from above.
model_name = "text-generation-inference/gemma-7b-it-medusa" # Replace with your model name
base_url = "http://[BASE_URL]" # replace with your instance's ip address
port = 0 # replace with your instance's port
# Load dataset and tokenizer
samples = load_dataset()
tokenizer = load_tokenizer(model_name)
# Set up OpenAI client
openai_client = setup_openai_client(base_url, port, "key")
# Run samples and get performance metrics
medusa_avg_latency, medusa_throughput, medusa_total_input_tokens, medusa_total_output_tokens = run_samples(
samples, model_name, openai_client, tokenizer
)
# Print results
print(f"Medusa Average Latency: {medusa_avg_latency:.4f} seconds")
print(f"Medusa Throughput: {medusa_throughput:.2f} tokens/second")
print(f"Medusa Total Input Tokens: {medusa_total_input_tokens}")
print(f"Medusa Total Output Tokens: {medusa_total_output_tokens}")
Next, we're going to test plain google/gemma-7b-it, without Medusa.
Copy and paste the id of a machine into <instance-id>, just as you did above, and your Hugging Face token into <hf-token>.
This command also exposes port 8000 in the Docker container, the port our OpenAI-compatible server will listen on, and tells the container to automatically download and serve the google/gemma-7b-it model. You can change the model by using any Hugging Face model ID; we chose this one because it is fast to download and start playing with.
We use Vast's --args option to funnel the rest of the command to the container: in this case --model-id google/gemma-7b-it, which TGI uses to download the model, and --port 8000 to ensure that TGI is listening on the right port.
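The create command mirrors the one above, just dropping --speculate and swapping in the new model id (a sketch; fill in your <instance-id> and <hf-token> as before):
%%bash
vastai create instance <instance-id> --image ghcr.io/huggingface/text-generation-inference:latest --env '-p 8000:8000 -e HF_TOKEN=<hf-token>' --disk 60 --args --port 8000 --model-id google/gemma-7b-it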
%%bash
# This request assumes you haven't changed the model. If you did, update the "model" value in the JSON payload below
curl -X POST http://[BASE_URL]:[PORT]/v1/completions -H "Content-Type: application/json" -d '{"model" : "google/gemma-7b-it", "prompt": "Hello, how are you?", "max_tokens": 50}'
# Configuration
model_name = "google/gemma-7b-it" # Replace with your model name
base_url = "http://[BASE_URL]" # replace with your instance's ip address
port = 0 # replace with your instance's port
# Load dataset and tokenizer
samples = load_dataset()
tokenizer = load_tokenizer(model_name)
# Set up OpenAI client
openai_client = setup_openai_client(base_url, port, "key")
# Run samples and get performance metrics
base_avg_latency, base_throughput, base_total_input_tokens, base_total_output_tokens = run_samples(
samples, model_name, openai_client, tokenizer
)
# Print results
print(f"Base Average Latency: {base_avg_latency:.4f} seconds")
print(f"Base Throughput: {base_throughput:.2f} tokens/second")
print(f"Base Total Input Tokens: {base_total_input_tokens}")
print(f"Base Total Output Tokens: {base_total_output_tokens}")
# compare it to Medusa!
print(f"\nMedusa Average Latency: {medusa_avg_latency:.4f} seconds")
print(f"Medusa Throughput: {medusa_throughput:.2f} tokens/second")
print(f"Medusa Total Input Tokens: {medusa_total_input_tokens}")
print(f"Medusa Total Output Tokens: {medusa_total_output_tokens}")