%reload_ext autoreload
%autoreload 2
%matplotlib inline
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID";
os.environ["CUDA_VISIBLE_DEVICES"]="0"
In this notebook, we will demonstrate zero-shot topic classification. Zero-Shot Learning (ZSL) is the ability to solve a task despite having received no training examples for that task. The ZeroShotClassifier class in ktrain can be used to perform topic classification with no training examples. The technique is based on Natural Language Inference (NLI), as described in this interesting blog post by Joe Davison.
We first instantiate the zero-shot classifier and then describe the topic labels for our classifier with strings.
from ktrain.text.zsl import ZeroShotClassifier
zsl = ZeroShotClassifier()
labels=['politics', 'elections', 'sports', 'films', 'television']
There is no training involved here, as we are using zero-shot learning. We simply supply the document to be classified along with the labels defined earlier. The predict method uses Natural Language Inference (NLI) to infer the topic probabilities.
doc = 'I am extremely dissatisfied with the President and will definitely vote in 2020.'
zsl.predict(doc, labels=labels, include_labels=True)
[('politics', 0.979189932346344), ('elections', 0.9874580502510071), ('sports', 0.0005765462410636246), ('films', 0.0022924456279724836), ('television', 0.0010546103585511446)]
As you can see, our model correctly assigned the highest probabilities to politics and elections, as the supplied text pertains to both of these topics.
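Since predict returns a list of (label, score) pairs, the topics can be ranked with plain Python. A minimal sketch using the output shown above:

```python
# Predictions returned by ZeroShotClassifier.predict (copied from the output above)
predictions = [('politics', 0.979189932346344), ('elections', 0.9874580502510071),
               ('sports', 0.0005765462410636246), ('films', 0.0022924456279724836),
               ('television', 0.0010546103585511446)]

# Rank topics from most to least probable
ranked = sorted(predictions, key=lambda p: p[1], reverse=True)
top_label, top_score = ranked[0]
print(top_label)  # 'elections' has the highest score for this document
```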
Let's try some other examples.
television
doc = 'What is your favorite sitcom of all time?'
zsl.predict(doc, labels=labels, include_labels=True)
[('politics', 0.00015667644038330764), ('elections', 0.00032881161314435303), ('sports', 0.00013884963118471205), ('films', 0.07557642459869385), ('television', 0.9813269376754761)]
politics and television
doc = """
President Donald Trump's senior adviser and son-in-law, Jared Kushner, praised
the administration's response to the coronavirus pandemic as a \"great success story\" on Wednesday --
less than a day after the number of confirmed coronavirus cases in the United States topped 1 million.
Kushner painted a rosy picture for \"Fox and Friends\" Wednesday morning,
saying that \"the federal government rose to the challenge and
this is a great success story and I think that that's really what needs to be told.\"
"""
zsl.predict(doc, labels=labels, include_labels=True)
[('politics', 0.8049427270889282), ('elections', 0.01889326609671116), ('sports', 0.005504833068698645), ('films', 0.05876927077770233), ('television', 0.8776823878288269)]
sports, television, and films
doc = "The Last Dance is a 2020 American basketball documentary miniseries co-produced by ESPN Films and Netflix."
zsl.predict(doc, labels=labels, include_labels=True)
[('politics', 0.0005349867278710008), ('elections', 0.0007852867711335421), ('sports', 0.9848827123641968), ('films', 0.9576993584632874), ('television', 0.941143274307251)]
As stated above, the ZeroShotClassifier is implemented using Natural Language Inference (NLI). That is, the document is treated as a premise, and each label is treated as a hypothesis. To predict labels, an NLI model is used to predict whether or not each label is entailed by the premise. By default, the template used for the hypothesis is of the form "This text is about <label>.", where <label> is replaced with a candidate label (e.g., politics, sports, etc.). Although this works well for many text classification problems such as the topic classification examples above, we can customize the template with the nli_template parameter if necessary. For instance, when predicting the sentiment of movie reviews, we might change the template as follows:
doc = "I will definitely not be seeing this movie again."
zsl.predict(doc, labels=['negative', 'positive'], include_labels=True,
nli_template="The sentiment of this movie review is {}.")
[('negative', 0.9995395541191101), ('positive', 0.011613081209361553)]
If you compare with the default template, you'll see the negative score is higher with the custom template.
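To see how the template shapes the hypothesis fed to the NLI model, we can render both the default and the custom template by hand. This is a plain-Python sketch of the idea described above; the exact hypothesis construction inside ktrain may differ in detail:

```python
labels = ['negative', 'positive']

# Default hypothesis template described above
default_template = "This text is about {}."
# Custom template passed via nli_template
custom_template = "The sentiment of this movie review is {}."

# Each candidate label is slotted into the template, giving one hypothesis per label
default_hypotheses = [default_template.format(label) for label in labels]
custom_hypotheses = [custom_template.format(label) for label in labels]
print(custom_hypotheses)
# ['The sentiment of this movie review is negative.',
#  'The sentiment of this movie review is positive.']
```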
Let's now consider a more ambiguous review:
I will definitely not be seeing this movie again, but the acting was good.
doc = "I will definitely not be seeing this movie again, but the acting was good."
zsl.predict(doc, labels=['negative', 'positive'], include_labels=True,
nli_template="The sentiment of this movie review is {}.")
[('negative', 0.8110149502754211), ('positive', 0.5280577540397644)]
From the output above, we see that the scores do NOT sum to one and that both labels are above a standard threshold of 0.5. By default, ZeroShotClassifier treats the task as a multilabel problem, which allows multiple labels to be true. Since the review is both negative and positive, both scores are above the 0.5 threshold (although the positive class is only slightly above it when using the custom template).
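In the multilabel setting, each score can be thresholded independently. A small sketch applying the standard 0.5 cutoff to the scores shown above:

```python
# Multilabel scores from the ambiguous review above
predictions = [('negative', 0.8110149502754211), ('positive', 0.5280577540397644)]

# Each label is judged independently against the threshold
threshold = 0.5
assigned = [label for label, score in predictions if score > threshold]
print(assigned)  # both labels exceed 0.5 for this review
```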
If the labels are to be treated as mutually exclusive, we can set multilabel=False, in which case the scores will sum to 1 and we will classify the review as negative overall:
doc = "I will definitely not be seeing this movie again, but the acting was good."
zsl.predict(doc, labels=['negative', 'positive'], include_labels=True,
nli_template="The sentiment of this movie review is {}.",
multilabel=False)
[('negative', 0.6576023101806641), ('positive', 0.34239766001701355)]
The predict method can accept a large list of documents. Documents are automatically split into batches based on the batch_size parameter, which can be increased to speed up predictions.
Note also that the predict method of ZeroShotClassifier generates a separate NLI prediction for each label included in the labels parameter. As len(labels) and the number of documents fed to predict increase, the prediction time will also increase. You can speed up predictions by increasing the batch_size. The default batch_size is currently set conservatively at 8:
batch_size=1
%%time
doc = 'I am extremely dissatisfied with the President and will definitely vote in 2020.'
labels=['politics', 'elections', 'sports', 'films', 'television']
predictions = zsl.predict(doc, labels=labels*160, include_labels=True, batch_size=1)
CPU times: user 53.2 s, sys: 728 ms, total: 53.9 s Wall time: 26 s
As you can see, 26 seconds is slow. We can speed things up by increasing batch_size:
batch_size=64
%%time
doc = 'I am extremely dissatisfied with the President and will definitely vote in 2020.'
predictions = zsl.predict(doc, labels=labels*160, include_labels=True, batch_size=64)
CPU times: user 1.74 s, sys: 480 ms, total: 2.22 s Wall time: 1.67 s
batch_size=64
%%time
doc = 'I am extremely dissatisfied with the President and will definitely vote in 2020.'
predictions = zsl.predict([doc]*1000, labels=labels, include_labels=True, batch_size=64)
CPU times: user 11.1 s, sys: 2.57 s, total: 13.7 s Wall time: 10.3 s
With 1000 documents and 5 topics, we are essentially making 5000 predictions in roughly 10 seconds with batch_size=64. Lower batch sizes would be much slower for this many predictions. The batch_size should be set based on available memory.
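The timing above follows from simple arithmetic: one NLI prediction is made per (document, label) pair, and the predictions are processed in chunks of batch_size. A sketch of the bookkeeping (the exact batching inside ktrain may differ):

```python
import math

n_docs, n_labels = 1000, 5
batch_size = 64

# One NLI forward pass per (document, label) pair
n_predictions = n_docs * n_labels
# Predictions are processed in chunks of batch_size
n_batches = math.ceil(n_predictions / batch_size)
print(n_predictions, n_batches)  # 5000 predictions in 79 batches
```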
Finally, there is a max_length parameter, which controls the maximum token length of the input sequences and defaults to 512.