#!/usr/bin/env python
# coding: utf-8

# # Using Large Language Models as text classifiers with an sklearn interface

# In this notebook, we will learn how to use skorch's `ZeroShotClassifier` and `FewShotClassifier` to perform classification without any training, thanks to the power of (Large) Language Models (LLMs). For this, we rely on the [Hugging Face transformers](https://huggingface.co/docs/transformers/index) library, which allows us to use all the available text generation models provided by Hugging Face.
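# As a quick preview, the overall pattern we will build up in this notebook looks roughly like this (a minimal sketch only; the model name, arguments, and labels are the ones used later in the notebook):
#
# ```
# from skorch.llm import ZeroShotClassifier
#
# clf = ZeroShotClassifier('google/flan-t5-small', use_caching=False)  # caching is turned off for this model, see the notes below
# clf.fit(X=None, y=['positive', 'negative'])  # essentially no training takes place
# clf.predict(["A masterpiece, instant classic, 5 stars out of 5"])
# ```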
# The notebook requires Hugging Face `transformers` and `datasets` as additional dependencies. If you have not already installed them, you can do so like this:
#
# `python -m pip install transformers datasets`

# In[1]:


import subprocess

# Installation on Google Colab
try:
    import google.colab
    subprocess.run(['python', '-m', 'pip', 'install', 'skorch', 'transformers', 'datasets'])
except ImportError:
    pass


# ## Imports

# In[2]:


import datasets
import numpy as np
import pandas as pd
import transformers
import torch

from sklearn.metrics import accuracy_score, log_loss
from sklearn.model_selection import GridSearchCV


# In[3]:


# let's reduce some of the noise from transformers and datasets logs
transformers.logging.set_verbosity_warning()
datasets.logging.set_verbosity_error()


# In[4]:


device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


# ## Load data

# For this example, we make use of the IMDB dataset. It consists of movie reviews written by IMDB users and the target is the sentiment, i.e. "positive" or "negative".

# In[5]:


imdb = datasets.load_dataset('imdb').shuffle(seed=0)


# We limit the number of samples to 100. Using zero/few-shot learning mostly makes sense when there are few labeled samples; otherwise, supervised machine learning methods will probably give better results.

# In[6]:


X = imdb['train'][:100]['text']
y = imdb['train'][:100]['label']


# Let's take a quick look at the data. Our `X` contains the user reviews:

# In[7]:


print(X[0])


# Our `y` contains the label-encoded targets:

# In[8]:


print(y[:5])


# For a standard machine learning solution, having label-encoded targets is desired. Here, however, we prefer to have the actual labels. It is much easier for the language model to predict the label "positive" for the text above than to predict "1". How would it know what "1" means? Sure, if we provide a few examples, it may work, but let's not make the language model's life harder than necessary and thus provide the actual labels.

# In[9]:


labels = np.array(['negative', 'positive'])[y]


# In[10]:


labels[:5]


# ## Zero-shot classification

# Now let's see how we can use zero-shot classification with skorch. First, let's load the `ZeroShotClassifier` class:

# In[11]:


from skorch.llm import ZeroShotClassifier


# ### "train" zero-shot classifier

# For demonstration purposes, we use a small language model here, `flan-t5-small`, which is hosted on Hugging Face. It has the advantage that it's quite fast and, as we'll see, still performs quite well. For more details on this model, check out [its model card on Hugging Face](https://huggingface.co/google/flan-t5-small).

# In[12]:


clf = ZeroShotClassifier('google/flan-t5-small', device=device, use_caching=False)


# Notes:
#
# - `flan-t5` has an encoder-decoder architecture, for which caching is not available, which is why we turn it off. The loss of speed shouldn't matter much for this task.
# - At the moment, we only support Hugging Face transformers models, or models that are compatible with it. We don't support APIs, so using OpenAI is not possible. For this, take a look at [scikit-llm](https://github.com/iryna-kondr/scikit-llm), which works with OpenAI. There are some restrictions associated with using an API, though, which means that not all features are available.

# In[13]:


get_ipython().run_line_magic('time', "clf.fit(X=None, y=['positive', 'negative'])")


# In general, fitting is fast because, basically, nothing happens. If the transformers model and tokenizer are not cached locally, they will, however, be downloaded from Hugging Face, which may take some time.
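# One of the few things that `fit` does is store the candidate labels in the `classes_` attribute, just like other sklearn classifiers (we will rely on this attribute again in the bonus section at the end). The columns returned by `predict_proba` below correspond to these classes:

# In[ ]:


clf.classes_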
# ### evaluation

# Let's evaluate how well the model works. As with any sklearn-compatible model, we can just call `predict_proba` to get the probabilities that the model assigns to each sample:

# In[14]:


get_ipython().run_line_magic('time', 'y_proba = clf.predict_proba(X)')


# The prediction speed is a bit slow, as should be expected from a language model. If runtime is a big concern, this is probably not the right approach.
#
# Now let's check how well the model does. First we check the log loss, then the accuracy:

# In[15]:


log_loss(y, y_proba)


# In[16]:


y_pred = y_proba.argmax(1)


# In[17]:


accuracy_score(y, y_pred)


# Given that this is zero-shot, those scores are actually not so bad!
#
# Sure, on the [leaderboard](https://huggingface.co/spaces/autoevaluate/leaderboards?dataset=imdb&only_verified=0&task=-any-&config=-unspecified-&split=-unspecified-&metric=accuracy) we can find models with better accuracy, but those are fine-tuned on the dataset.
#
# Notice that if we call `predict`, we get back the labels, i.e. "positive" or "negative".

# In[18]:


clf.predict(["A masterpiece, instant classic, 5 stars out of 5"])


# ### Grid searching the prompt

# Since `ZeroShotClassifier` is scikit-learn compatible, we can easily do a grid search for the best prompt. In this example, let's compare two prompts that are worded slightly differently. Could one of them be the better choice?

# In[19]:


prompt0 = """You are a text classification assistant.

The text to classify:

```
{text}
```

Choose the label among the following possibilities with the highest probability. Only return the label, nothing more:

{labels}

Your response:
"""


# In[20]:


prompt1 = """Your task is to classify text.

Choose the label among the following possibilities with the highest probability. Only return the label, nothing more:

{labels}

The text to classify:

```
{text}
```

Your response:
"""


# In[21]:


params = {'prompt': [prompt0, prompt1]}


# In[22]:


search = GridSearchCV(clf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False)


# In[23]:


get_ipython().run_line_magic('time', 'search.fit(X, labels)')


# Grid search results:

# In[24]:


pd.DataFrame(search.cv_results_)[['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_prompt', 'mean_score_time']]


# **Conclusion**: `prompt1` performs better than `prompt0`. The mean test accuracy of 93% and log loss of 0.25 are pretty good overall, given that we use zero-shot and don't perform any fine-tuning.
#
# Going further, we could also grid search different language models, or combinations of LLMs and prompts, to find the best working zero-shot model.
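# Such a combined search could look like the following sketch (not run here, since every additional model has to be downloaded and evaluated; `google/flan-t5-base` is just an arbitrary second candidate):
#
# ```
# params = {
#     'model_name': ['google/flan-t5-small', 'google/flan-t5-base'],
#     'prompt': [prompt0, prompt1],
# }
# search = GridSearchCV(clf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False)
# search.fit(X, labels)
# ```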
# ## Few-shot classification

# Sometimes, helping the language model out by providing a few examples will boost the performance. To test this, skorch provides the `FewShotClassifier` class. Let's try it out.

# In[25]:


from skorch.llm import FewShotClassifier
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


# ### train the few-shot classifier

# Instead of passing the model name to initialize the classifier, as in `clf = FewShotClassifier('google/flan-t5-small')`, it is also possible to pass the model and tokenizer explicitly. This is a good option if you need more control over them. In our case, it amounts to the same result. It's useful to keep this option in mind, though, if the model requires any changes or if you want to provide a model that is not uploaded to Hugging Face.

# In[26]:


model = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small').to(device)
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-small')


# To control the number of samples used for few-shot learning, use the `max_samples` parameter. In this case, let's use 5 examples:

# In[27]:


clf = FewShotClassifier(
    model=model,
    tokenizer=tokenizer,
    max_samples=5,
    use_caching=False
)


# In[28]:


get_ipython().run_line_magic('time', 'clf.fit(X, labels)')


# Let's make sure that everything works as expected by inspecting the prompt. This is possible using the `get_prompt` method:

# In[29]:


print(clf.get_prompt("A masterpiece, instant classic, 5 stars out of 5"))


# If we're unhappy with the prompt, we can also provide our own prompt using the `prompt` argument, as we saw earlier in this notebook.

# ### evaluation

# In[30]:


get_ipython().run_line_magic('time', 'y_proba = clf.predict_proba(X)')


# In[31]:


log_loss(y, y_proba)


# In[32]:


y_pred = y_proba.argmax(1)


# In[33]:


accuracy_score(y, y_pred)


# In[34]:


clf.predict(["Even if paid $1000, I would not watch this movie again"])


# This looks like a small improvement over what we got with zero-shot learning. Let's see if we can get even better.

# ### grid search best number of few-shot samples

# Maybe we can do even better if we pick a different number of samples for few-shot learning? Let's try this out with grid search.
#
# Note that grid search will split `X` and `y` for each run. Since the few-shot samples are taken from `X` and `y`, those will thus be different for each split, which could have a big influence on the performance of the model. If you always want to have the same few-shot samples in each split, you should craft your own prompt with those examples and then use it with `ZeroShotClassifier`. Just ensure that those examples are not part of the validation/test data!
#
# Now let's test 3, 5, and 7 samples and see what works best.

# In[35]:


params = {'max_samples': [3, 5, 7]}


# In[36]:


search = GridSearchCV(clf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False)


# In[37]:


get_ipython().run_line_magic('time', 'search.fit(X, labels)')


# In[38]:


pd.DataFrame(search.cv_results_)[['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_max_samples', 'mean_score_time']]


# **Conclusion**: There is no significant change in accuracy compared to zero-shot but a small improvement in log loss. Having more samples doesn't help but slows down the inference time, as we can see when looking at `mean_score_time`. Overall, few-shot learning helps a bit but makes inference slower. It's up to you to decide if the trade-off is worth it in this specific case.

# ## Debugging

# Working with LLMs can be difficult because it is hard to know for certain if the prompt works well and if the LLM is capable of classifying the input. For this reason, skorch provides a few options to help identify those issues.

# ### Returning unnormalized probabilities

# By default, the model will normalize the probabilities to sum to 1. This is what is expected when calling `predict_proba`. However, this can hide underlying issues. The LLM can, in theory, predict any token from its vocabulary; there is no guarantee that it will choose one of the provided labels. skorch will force the LLM to use one of the labels, but we also track the probabilities assigned, or not assigned, to these labels.
#
# To give an example, for a given input, it's possible that the LLM predicts a probability of 10% that the label is 'negative' and 70% that it is 'positive'. By default, we normalize the probabilities to sum to 1, i.e. we return 0.125 and 0.875. The problem is that we would return the same normalized probabilities even if the model predicts 1% and 7%. But if the model predicts such low probabilities, there is probably something wrong and we would like to know about it.
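# We can illustrate this with plain numpy (made-up numbers, not skorch output): after normalization, a confident model and a very unsure model can look exactly the same.

# In[ ]:


p_confident = np.array([[0.10, 0.70]])  # unnormalized probabilities for 'negative' and 'positive'
p_unsure = np.array([[0.01, 0.07]])     # same ratio, but much less total probability assigned to the labels

print(p_confident / p_confident.sum(1, keepdims=True))  # [[0.125 0.875]]
print(p_unsure / p_unsure.sum(1, keepdims=True))        # [[0.125 0.875]], indistinguishable after normalization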
# For this reason, we added the option to disable the normalization of probabilities. Let's check how well our zero-shot flan-t5 model is doing without normalization:

# In[39]:


clf = ZeroShotClassifier('google/flan-t5-small', use_caching=False, probas_sum_to_1=False)


# In[40]:


clf.fit(X=None, y=['positive', 'negative'])


# In[41]:


y_proba = clf.predict_proba(X[:3])


# In[42]:


y_proba


# Let's check the sum of the two classes combined:

# In[43]:


y_proba.sum(1)


# As you can see, the summed probabilities returned by flan-t5 are quite high. Without normalization, they still sum up to ~99%, which is very good.
#
# Now let's take a look at an LLM that doesn't work well for this task, GPT2.
#
# Note that, in contrast to flan-t5, GPT2 is a decoder-only language model, so we don't need to set `use_caching=False`.

# In[44]:


clf = ZeroShotClassifier('gpt2', probas_sum_to_1=False)


# In[45]:


clf.fit(X=None, y=['positive', 'negative'])


# In[46]:


y_proba = clf.predict_proba(X[:3])


# In[47]:


y_proba


# As we can see, the probabilities are really low, but if we had normalized them, we might not have noticed:

# In[48]:


# normalize probabilities to sum up to 1
y_proba / y_proba.sum(1, keepdims=True)


# This means we should probably use a different LLM or tinker with the prompt until we get better results.

# ### Specific actions when probabilities are low

# There are more options to identify low probabilities in a way that does not require manually inspecting the probabilities. For this, we provide two arguments for `ZeroShotClassifier` and `FewShotClassifier`:
#
# The first argument is called `error_low_prob`. It should be one of the following strings: `'ignore'`, `'warn'`, `'raise'`, or `'return_none'`.
#
# By default, it is `'ignore'`, which means that nothing happens, no matter how low the predicted probabilities are. By setting it to `'warn'`, there will be a warning when the total probability of at least one predicted sample is too low. Use this option if you want to get the result but be alerted about possible problems.
#
# By passing `error_low_prob='raise'`, an error will be raised as soon as a sample with low total probabilities is encountered. This is useful if you want inference to stop immediately, instead of waiting for all predictions to be made.
#
# Finally, you can set `error_low_prob='return_none'`. In this case, nothing changes when calling `predict_proba`. When calling `predict`, however, the probabilities for the samples will be checked and if they're too low, the prediction will be replaced by `None`. This is useful if the predictions are generally good, but some examples are, for one reason or another, hard to predict.
#
# The second parameter, which should be used in conjunction with `error_low_prob`, is called `threshold_low_prob`. This is simply a float between 0 and 1 that indicates which probability should be considered "low". Note that this value is compared to the _sum of the probabilities for all labels_ of a given sample. So if we set `threshold_low_prob=0.1` and the probability for 'negative' is 0.05 while the probability for 'positive' is 0.2, this sample would be fine, because in total, the probabilities exceed 0.1.
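# As a quick sketch of the `'return_none'` variant described above (not run here; which samples come back as `None` of course depends on the model, prompt, and threshold):
#
# ```
# clf = ZeroShotClassifier('gpt2', error_low_prob='return_none', threshold_low_prob=0.5)
# clf.fit(X=None, y=['positive', 'negative'])
# # predict_proba is unchanged, but predict returns None for samples whose summed label probability stays below the threshold
# clf.predict(X[:3])
# ```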
# Let's see how this works in practice by using the option to raise an error and setting the threshold to 0.5:

# In[49]:


# note that since GPT2 is a decoder-only language model, we don't need to set use_caching=False
clf = ZeroShotClassifier('gpt2', error_low_prob='raise', threshold_low_prob=0.5)


# In[50]:


clf.fit(X=None, y=['positive', 'negative'])


# In[51]:


try:
    clf.predict_proba(X[:3])
except Exception as exc:
    print("There was an error:", exc)


# As you can see, we indeed got an error, alerting us immediately to potential issues.

# ## Testing MNLI

# There are other zero-shot classification methods out there. One such method is to use natural language inference (NLI). In a nutshell, this method treats the input text as a premise and turns each candidate label into a hypothesis, then uses a model trained on NLI to score how strongly the text entails each label; these entailment scores are converted into class probabilities.
#
# Let's compare the results to https://huggingface.co/facebook/bart-large-mnli, which is the most used zero-shot classifier on Hugging Face at the time of writing.

# In[52]:


from transformers import pipeline


# In[53]:


classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli', device=device)


# In[54]:


get_ipython().run_line_magic('time', "preds = classifier(imdb['train'][:100]['text'], ['negative', 'positive'])")


# In[55]:


# the pipeline returns each sample's labels sorted by score, so bring every row into the fixed ['negative', 'positive'] order
y_proba = np.vstack([p['scores'] if p['labels'] == ['negative', 'positive'] else p['scores'][::-1] for p in preds])


# In[56]:


accuracy_score(y, y_proba.argmax(1))


# In[57]:


log_loss(y, y_proba)


# **Conclusion**: This model is slower than the tested zero-shot classifier, it is less flexible (we cannot adjust the prompt or other parameters), and it performs worse. For this task, it is, therefore, better to use skorch's `ZeroShotClassifier`.

# ## Testing a standard machine learning solution

# Finally, let's compare the results to a classical supervised machine learning approach. For this, we use TFIDF to vectorize the input and a logistic regression for classification. This is a standard pipeline for text classification tasks and works really well with enough data.

# In[58]:


from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate


# In[59]:


tfidf = Pipeline([
    ('tfidf', TfidfVectorizer()),
    ('clf', LogisticRegression()),
])


# Let's run a grid search on a couple of hyper-parameters to ensure we pick good ones.

# In[60]:


params = {'tfidf__max_features': [500, 1000], 'tfidf__ngram_range': [(1, 1), (1, 2), (1, 3)]}


# In[61]:


search = GridSearchCV(
    tfidf, param_grid=params, cv=2, scoring=['accuracy', 'neg_log_loss'], refit=False
)


# In[62]:


get_ipython().run_line_magic('time', 'search.fit(X, y)')


# The table is quite big, so let's look at the five best log losses:

# In[63]:


cols = ['mean_test_accuracy', 'mean_test_neg_log_loss', 'param_tfidf__max_features', 'param_tfidf__ngram_range']
pd.DataFrame(search.cv_results_)[cols].sort_values('mean_test_neg_log_loss', ascending=False).head()


# **Conclusion**: This classical model is much faster, even if we include the training time, because it is much smaller than a language model. However, its scores are also much worse, which is due to the small size of the dataset. If speed is no concern, using an LLM classifier would thus be a good option for this task.
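# If you want to double-check a single configuration without rerunning the whole grid, the `cross_validate` helper imported above can be used directly. A quick sketch (the hyper-parameters here are arbitrary examples, not necessarily the best ones from the search):

# In[ ]:


scores = cross_validate(
    tfidf.set_params(tfidf__max_features=1000, tfidf__ngram_range=(1, 2)),
    X, y, cv=2, scoring=['accuracy', 'neg_log_loss'],
)
{key: val.mean() for key, val in scores.items()}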
# ## Summary

# In this notebook, we learned how to use skorch's `ZeroShotClassifier` and `FewShotClassifier` for a text classification task. Let's list a few advantages that we gained from using those classes:
#
# - On this particular dataset, zero- and few-shot learning outperformed a classical supervised machine learning approach. We also got better scores than what we got from MNLI.
# - We can use `ZeroShotClassifier` and `FewShotClassifier` as drop-in replacements for other sklearn text classification models because `fit`, `predict`, and `predict_proba` work as expected from an sklearn model.
# - It is trivial to run a grid search. This way, we can find out what model works best, what prompt is optimal, and how many few-shot samples to provide.
# - We can call `predict_proba` to get the (relative) probability the model assigns to each label, which is not something we normally get from a language model.
# - `ZeroShotClassifier` and `FewShotClassifier` also give us some nice extra features. Most notably, they force the language models to predict one of the provided labels, which is typically not guaranteed when using language models directly. We also get easy ways to detect issues and caching (for decoder-only models).

# ---
#
# ## ✨ Bonus ✨
#
# Not every task is a classification task, but some tasks can be broken down into one!
#
# As an example, we show how a task such as [PIQA](https://huggingface.co/datasets/piqa) can be reformulated so that it can be solved with the LLM classifier. PIQA defines the task of choosing, out of two options, the right solution to achieve a given goal, where the more sensible of the two options is labelled as correct.
#
# Two example entries of the PIQA dataset:
#
# | goal (string) | sol1 (string) | sol2 (string) | label (class label) |
# | - | - | - | - |
# | "When boiling butter, when it's ready, you can" | "Pour it onto a plate" | "Pour it into a jar" | 1 |
# | "To permanently attach metal legs to a chair, you can" | "Weld the metal together to get it to stay firmly in place" | "Nail the metal together to get it to stay firmly in place" | 0 |
#
# A generative approach to this problem would be to tell the model to name the correct solution, match its answer against the two given options to determine the chosen number, and compare that number with the correct label. This approach doesn't work with the LLM classifier, of course. But we can re-phrase the task a bit: we give each solution a number and ask the model to predict the number of the correct solution.
#
# Therefore, for a zero-shot formulation, we could prompt the model like this:
#
# ```
# prompt = """Goal is: Do cardio exercise without running.
# Solution 1: Use a jump rope for 15 minutes.
# Solution 2: Run around a chair for 15 minutes.
# Correct: Solution """
# ```
#
# We then expect the model to complete `1` or `2`, which are now our classes. As always, the ideal way of prompting may differ depending on the model used, and instruction-tuned models may need a more precise prompt.
#
# We will test this task on `bloomz-1b1` to cover another popular LLM and because we know [what to expect from this model on this task](https://huggingface.co/datasets/bigscience/evaluation-results/viewer/bloom-1b1/test) (at best 67.14% zero-shot).
# In[64]:


dataset = datasets.load_dataset('piqa').shuffle(seed=42)


# In[65]:


template = """Goal is: {goal}
Solution 1: {sol1}
Solution 2: {sol2}"""

X = []
y = []

# iterating over dataset['train'] directly is not possible, since that only yields the keys
for i in range(len(dataset['train'])):
    row = dataset['train'][i]
    X.append(template.format(**row))
    y.append(" 1" if row['label'] == 0 else " 2")


# Take note that we chose to use " 1" and " 2" as labels. `bloomz` seems to be trained in such a way that it favors " 1" over "1". This is not the case for `flan-t5-*`, for example, but it is something to keep in mind when prompting and testing these models (see the quick tokenizer check at the end of this section).

# In[66]:


print(X[0])


# In[67]:


model = 'bigscience/bloomz-1b1'
prompt = """{text}
Correct: Solution"""


# In[68]:


clf = ZeroShotClassifier(model_name=model, prompt=prompt, probas_sum_to_1=False, device=device)


# In[69]:


get_ipython().run_line_magic('time', 'clf.fit(X=None, y=y)')


# To save you some time, we will just classify 1000 of the ~16,000 samples. This is an example and not a benchmark, after all.

# In[70]:


max_n = 1000
get_ipython().run_line_magic('time', 'y_proba = clf.predict_proba(X[:max_n])')


# ### Evaluation

# In[71]:


log_loss(y[:max_n], y_proba)


# In[72]:


y_pred = clf.classes_[y_proba.argmax(1)]
accuracy_score(y[:max_n], y_pred)


# In[73]:


y_proba_normed = y_proba / y_proba.sum(axis=1)[:, None]
abs(y_proba_normed[:, 0] - y_proba_normed[:, 1]).mean()


# You can see that the accuracy is below what we expected (67.14%) and the probabilities are very close. Why is that?
#
# The reported accuracy of the reference benchmark is determined by choosing the answer with the higher log-probability: the EleutherAI benchmark asks the model for the probability of each full solution text given the goal, and the solution with the higher probability wins. This *leverages common knowledge*: a more likely phrase correlates with a more common phrase, which is a good bias for correctness - you are more likely to find 'eat a burger' than 'throw a ball' in the context of the goal 'I am hungry'.
#
# Another aspect is that we are introducing an *indirection* with our task framing: we're not giving the answer itself but choosing a symbol *for* the answer. This makes it a lot harder for the model to choose the correct answer, because it not only needs to understand that it should answer with one of the valid options (1 and 2) but also to keep track of what these options signify in terms of the goal phrase. This is also indicated by the very low mean absolute difference between the probabilities of both options (of only ~15 percentage points). Different models handle this differently: `flan-t5` does this really well, while `bloomz` seems to perform worse; only `bloomz-3b` is able to achieve accuracies >65% with this setup.

# In[74]:


y_proba.min(0), y_proba.max(0)


# By looking at the minimum and maximum (absolute) probabilities over all samples, we can also see that the model is never really strongly certain about an answer. In itself, this is not a problem, but combined with the low accuracy, it is indicative of a limited capability in "understanding" the task at hand.
#
# **So... is this bad?!** - the performance? Yes. But in general: **no** - we tasked the model with a more complex task, and arguably it is a good thing that the probabilities we computed revealed to us that the model is not able to fit the data. If you see that the difference in probabilities between the options is quite small, it is likely that the model is not able to solve the task and might not be well-suited for it. This could be an indicator for you to choose a different model (or, maybe even simpler, a different prompt). Note that the reason might not simply be 'model complexity' but could also mean that there's an unfortunate tokenization that needs a bigger model to sort out - that's very hard to say.
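# To make the tokenization point from earlier in this section a bit more concrete, here is a quick check (just an illustration, not required for the workflow above) of how the tokenizer treats "1" versus " 1":

# In[ ]:


tok = AutoTokenizer.from_pretrained('bigscience/bloomz-1b1')
# for this tokenizer, the leading space is part of the token, so "1" and " 1" map to different token ids
tok('1')['input_ids'], tok(' 1')['input_ids']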
# The lesson here is that the probabilistic view lets you reason a bit more about the performance of these models in a familiar way.
#
# And this concludes this bonus section. ✨
#
# You have seen how to frame atypical tasks as a classification problem and experienced first-hand how much the capability to handle indirection can vary between models. You've also learned the importance of dealing with the tokenization preferences (`"1"` vs `" 1"`) of language models, and saw that having an interface to look into the probabilities and their differences can tell you a bit about your model and its ability to solve your task.

# In[ ]: