(gradio-serve-tutorial)=
In this example, we will show you how to wrap a machine learning model served by Ray Serve in a Gradio demo.
Specifically, we're going to download a GPT-2 model from the `transformers` library,
define a Ray Serve deployment with it, and then define and launch a Gradio `Interface`.
Let's take a look.
```python
# Install all dependencies for this example.
! pip install ray gradio transformers requests
```
To start off, we import Ray Serve, Gradio, and the `transformers` and `requests` libraries, and then simply start Ray Serve:
```python
import gradio as gr
import requests
from transformers import pipeline

from ray import serve

serve.start()
```
Next, we define a Ray Serve deployment with a GPT-2 model by using the `@serve.deployment` decorator on a `model` function that takes a `request` argument.
In this function we define a GPT-2 model with a call to `pipeline` and return the result of querying the model.
```python
@serve.deployment
def model(request):
    # Create a GPT-2 text-generation pipeline and run the
    # incoming query against it.
    language_model = pipeline("text-generation", model="gpt2")
    query = request.query_params["query"]
    return language_model(query, max_length=100)
```
This `model` can now easily be deployed using a `model.deploy()` call.
To test this deployment, we use a simple example query to get a response from the model running on `localhost:8000/model`.
The first time you use this endpoint, the model has to be downloaded, which can take a while.
Subsequent calls will be faster.
```python
model.deploy()

example = "What's the meaning of life?"
response = requests.get(f"http://localhost:8000/model?query={example}")
print(response.text)
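```

The raw response body is the JSON-serialized output of the `transformers` pipeline, i.e. a list with one dictionary per generated sequence.
To extract just the generated text, you can index into the parsed JSON, as the interface wrapper below will also do:

```python
# The text-generation pipeline returns a list of dicts, each carrying a
# "generated_text" key; print the text of the first (and only) sequence.
print(response.json()[0]["generated_text"])
```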
Defining a Gradio interface is now straightforward.
All we need is a function that Gradio can call to get the response from the model.
That's just a thin wrapper around our previous `requests` call:
```python
def gpt2(query):
    response = requests.get(f"http://localhost:8000/model?query={query}")
    return response.json()[0]["generated_text"]
```
Apart from our `gpt2` function, the only other thing we need to define a Gradio interface is a description of the model's inputs and outputs in terms that Gradio understands.
Since our model takes text as input and produces text as output, this turns out to be pretty simple:
```python
iface = gr.Interface(
    fn=gpt2,
    inputs=[gr.inputs.Textbox(default=example, label="Input prompt")],
    outputs=[gr.outputs.Textbox(label="Model output")],
)
```
For more complex models served with Ray, you might need multiple `gr.inputs` and `gr.outputs` components of different types.

```{margin}
The [Gradio documentation](https://gradio.app/docs/) covers all viable input and output components in detail.
```
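To give a flavor of what that can look like, here's a purely illustrative sketch: the `classify_image` function and the component choices below are hypothetical stand-ins, not part of this tutorial's deployments.

```python
# Hypothetical multi-input, multi-output interface. `classify_image` would
# wrap a request to your own Ray Serve endpoint and return (label, answer).
def classify_image(image, prompt):
    ...

iface_multi = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.inputs.Image(label="Input image"),
        gr.inputs.Textbox(label="Prompt"),
    ],
    outputs=[
        gr.outputs.Label(label="Predicted class"),
        gr.outputs.Textbox(label="Answer"),
    ],
)
```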
Coming back to our GPT-2 demo, we can finally launch the interface using `iface.launch()`:

```python
iface.launch()
```
This should launch an interactive interface that looks like this:

*(Screenshot of the Gradio interface for the GPT-2 text-generation demo.)*
You can run this example directly in the browser, for instance by launching this notebook in Google Colab or Binder, by clicking on the rocket icon at the top right of this page.
If you run this code locally in Python, the Gradio app will be served at `http://127.0.0.1:7861/`.
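If you'd like to share the running demo with others, Gradio's `launch()` also accepts a `share=True` flag that generates a temporary public link; this is standard Gradio functionality and works the same way here:

```python
# Serve the interface locally and additionally create a temporary
# public share link (standard Gradio behavior, independent of Ray Serve).
iface.launch(share=True)
```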
Let's take a look at another example, so you can see how it differs from the first one in direct comparison.
```python
# Install all dependencies for this example.
! pip install ray gradio requests scikit-learn
```
This time we're going to use a scikit-learn model that we quickly train ourselves on the famous Iris dataset.
To do this, we'll load the Iris dataset using the built-in `load_iris` function from the `sklearn` library, and we'll use the `GradientBoostingClassifier` from the `sklearn.ensemble` module for training.
Also, this time we use the `@serve.deployment` decorator on a class called `BoostingModel`, which has an asynchronous `__call__` method, which Ray Serve requires to define your deployment.
All else remains the same as in the first example.
```python
import gradio as gr
import requests
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier

from ray import serve

# Train your model.
iris_dataset = load_iris()
model = GradientBoostingClassifier()
model.fit(iris_dataset["data"], iris_dataset["target"])

# Start Ray Serve.
serve.start()


# Define your deployment.
@serve.deployment(route_prefix="/iris")
class BoostingModel:
    def __init__(self, model):
        self.model = model
        self.label_list = iris_dataset["target_names"].tolist()

    async def __call__(self, request):
        payload = (await request.json())["vector"]
        print(f"Received http request with data {payload}")
        prediction = self.model.predict([payload])[0]
        human_name = self.label_list[prediction]
        return {"result": human_name}


# Deploy your model.
BoostingModel.deploy(model)
```
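Before wiring the deployment into Gradio, you can sanity-check it directly, just as in the first example.
A minimal sketch, assuming the deployment above is running locally (the feature values here are arbitrary):

```python
# Send a feature vector to the deployed model; the "vector" key matches
# the payload that BoostingModel.__call__ expects.
sample_request_input = {"vector": [1.2, 1.0, 1.1, 0.9]}
response = requests.get("http://localhost:8000/iris", json=sample_request_input)
print(response.json())  # e.g. {"result": "versicolor"}
```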
Equipped with our `BoostingModel` class, we can now define and launch a Gradio interface as follows.
The Iris dataset has a total of four features, namely the numeric values sepal length, sepal width, petal length, and petal width.
We use this fact to define an `iris` function that takes these four features and returns the predicted class, using our deployed model.
This time, the Gradio interface takes four input `Number`s and returns the predicted class as `text`.
Go ahead and try it out in the browser yourself.
```python
# Define the gradio function that queries our deployed model.
def iris(sl, sw, pl, pw):
    request_input = {"vector": [sl, sw, pl, pw]}
    response = requests.get("http://localhost:8000/iris", json=request_input)
    # The deployment returns a dict of the form {"result": <class name>}.
    return response.json()["result"]

# Define the gradio interface.
iface = gr.Interface(
    fn=iris,
    inputs=[
        gr.inputs.Number(default=1.0, label="sepal length (cm)"),
        gr.inputs.Number(default=1.0, label="sepal width (cm)"),
        gr.inputs.Number(default=1.0, label="petal length (cm)"),
        gr.inputs.Number(default=1.0, label="petal width (cm)"),
    ],
    outputs="text",
)

# Launch the gradio interface.
iface.launch()
```
Launching this interface, you should see an interactive app that looks like this:

*(Screenshot of the Gradio interface for the Iris classification demo.)*
To summarize, it's easy to build Gradio apps from Ray Serve deployments. You only need to properly encode your model's inputs and outputs in a Gradio interface, and you're good to go!