https://tinyurl.com/zxyxa9ww Copy the notebook to your GDrive to edit.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
!pip3 install transformers
import torch
import random
import numpy as np
from transformers import AutoTokenizer, BertForMaskedLM
def enforce_reproducibility(seed=42):
    # Sets seed manually for both CPU and CUDA
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # For atomic operations there is currently
    # no simple way to enforce determinism, as
    # the order of parallel operations is not known.
    # CUDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # System based
    random.seed(seed)
    np.random.seed(seed)
enforce_reproducibility()
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
device
device(type='cpu')
Error analysis involves examining the errors made by a system and developing a classification of them. (This is typically best done over dev data, to avoid compromising held-out test sets.) At a superficial level, this can involve breaking things down by input length, token frequency or looking at confusion matrices. But we should not limit ourselves to examining only labels (rather than input linguistic forms) as with confusion matrices, or superficial properties of the linguistic signal. Languages are, after all, complex systems and linguistic forms are structured. So a deeper error analysis involves examining those linguistic forms and looking for patterns.
A good error analysis tells us something about why method X is effective or ineffective for problem Y. This in turn provides a much richer starting point for further research, allowing us to go beyond throwing learning algorithms at the wall of tasks and seeing which stick, while allowing us to also discover which are the harder parts of a problem.
Error analysis: Does the project provide a thoughtful error analysis, which looks for linguistic patterns in the types of errors made by the system(s) evaluated and sheds light on either avenues for future work or the source of the strengths/weaknesses of the systems?
Source: http://coling2018.org/error-analysis-in-research-and-writing/
Error analysis — the attempt to analyze when, how, and why machine-learning models fail — is a crucial part of the development cycle: Researchers use it to suggest directions for future improvement, and practitioners make deployment decisions based on it. Since error analysis profoundly determines the direction of subsequent actions, we cannot afford it to be biased or incomplete.
Error analysis can include anything that helps you understand how the model behaves and what its strengths and weaknesses are.
When we have access to the model's decisions, e.g. the weights for each n-gram, and they are easily understandable, we say that the model is interpretable by design (Chapter 4, Interpretable Models). For such models, we can explore which features the model treats as important and use that in our analysis.
Language Modeling: Predict the likelihood of a sentence P(x)
P(x) is high: Barack Obama served as the 44th President of the United States.
P(x) is low: 44th the of the President United States served Barack Obama as. (syntax)
P(x) is low: Barack Obama barked as the 44th President of the kennel. (semantics)
P(x) is low: Barack Obama served as the 44th President of the United States. (facts)
P(x) is low: Barack Obama reached a height of 50 feet tall. (common sense)
Source: https://www.youtube.com/watch?v=Oh2StnRQ3qE&ab_channel=3Blue1Brown
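To make P(x) concrete, here is a minimal sketch of a sentence-level language model. It is not the BERT model used in this notebook: it assumes a toy three-sentence corpus, a bigram approximation of the chain rule, and add-one smoothing, and all names (`corpus`, `tokenize`, `log_prob`) are hypothetical.

```python
import math
from collections import Counter

# Toy training corpus (an assumption, for illustration only).
corpus = [
    "barack obama served as the 44th president of the united states",
    "the president of the united states served two terms",
    "obama served as president",
]

BOS, EOS = "<s>", "</s>"  # sentence boundary markers

def tokenize(sentence):
    return [BOS] + sentence.lower().split() + [EOS]

# Count single tokens and adjacent token pairs.
unigram_counts = Counter()
bigram_counts = Counter()
for line in corpus:
    tokens = tokenize(line)
    unigram_counts.update(tokens)
    bigram_counts.update(zip(tokens, tokens[1:]))

vocab_size = len(unigram_counts)

def log_prob(sentence):
    """log P(x) = sum of log P(w_t | w_{t-1}), with add-one smoothing."""
    tokens = tokenize(sentence)
    total = 0.0
    for prev, word in zip(tokens, tokens[1:]):
        numerator = bigram_counts[(prev, word)] + 1
        denominator = unigram_counts[prev] + vocab_size
        total += math.log(numerator / denominator)
    return total

fluent = "barack obama served as the 44th president of the united states"
scrambled = "44th the of the president united states served barack obama as"
print(log_prob(fluent), log_prob(scrambled))
```

Both sentences contain the same words, so the difference in score comes only from word order, which is the "syntax" case in the list above. A bigram model of this size cannot capture the semantics, facts or common-sense cases.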
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
bert = BertForMaskedLM.from_pretrained("bert-base-cased")
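def get_probs(sentence, word_idx):
    input_ids = tokenizer.encode(sentence)
    input_ids = torch.tensor([input_ids])
    # first output = prediction scores; the second [0] selects the only sentence in the batch
    logits = bert(input_ids)[0][0]
    # note: these are raw logits (no softmax applied), sorted from highest to lowest
    sorted_probs = logits[word_idx].sort(descending=True)
    return sorted_probs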
sentence = f"Copenhagen is the capital of {tokenizer.mask_token}."
sorted_probs = get_probs(sentence, 6)
for i in range(10):
    token = tokenizer._convert_id_to_token(sorted_probs.indices[i].numpy().tolist())
    conf = sorted_probs.values[i].detach().numpy().tolist()
    print(f"{token}: {conf}")
Denmark: 18.916259765625
Sweden: 12.33653736114502
Europe: 11.822346687316895
Copenhagen: 11.669342994689941
Danish: 11.636176109313965
Scandinavia: 11.140523910522461
Latvia: 11.12883186340332
Norway: 10.961599349975586
Zealand: 10.413448333740234
Greenland: 10.1879243850708
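The scores printed above are raw logits, not probabilities. A softmax turns them into values that sum to 1. The sketch below uses only the ten logits shown in the output, so it normalises over those ten candidates instead of BERT's full vocabulary. That is a simplifying assumption, and the resulting numbers are relative shares among the listed tokens, not the model's true probabilities.

```python
import math

def softmax(logits):
    """Convert raw scores to values in (0, 1) that sum to 1."""
    m = max(logits)  # subtracting the max keeps exp() numerically stable
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Top-10 logits copied from the output above.
top10 = {
    "Denmark": 18.916259765625,
    "Sweden": 12.33653736114502,
    "Europe": 11.822346687316895,
    "Copenhagen": 11.669342994689941,
    "Danish": 11.636176109313965,
    "Scandinavia": 11.140523910522461,
    "Latvia": 11.12883186340332,
    "Norway": 10.961599349975586,
    "Zealand": 10.413448333740234,
    "Greenland": 10.1879243850708,
}

# Share of each token among the ten listed candidates only.
shares = dict(zip(top10, softmax(list(top10.values()))))
for token, share in shares.items():
    print(f"{token}: {share:.4f}")
```

Because the softmax is exponential, a logit gap of about 6.6 between "Denmark" and "Sweden" means almost all of the mass among these candidates goes to "Denmark".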
However, since language models were never trained to solve the exact tasks that we're asking them to solve, they are:
sentence = f"The nurse finished {tokenizer.mask_token} work."
sorted_probs = get_probs(sentence, 4)
for i in range(10):
    token = tokenizer._convert_id_to_token(sorted_probs.indices[i].numpy().tolist())
    conf = sorted_probs.values[i].detach().numpy().tolist()
    print(f"{token}: {conf}")
her: 16.3854923248291
the: 12.918558120727539
his: 11.12331771850586
its: 10.893868446350098
their: 10.132686614990234
some: 9.50537395477295
all: 9.481547355651855
to: 9.095523834228516
up: 8.563329696655273
my: 8.469127655029297
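To quantify how strongly the model prefers "her" over "his" here, the two logits from the output above can be turned into an odds ratio. This is a small worked calculation, not part of the original notebook. It relies only on the softmax identity P(a) / P(b) = exp(logit_a - logit_b), in which the normalisation over the vocabulary cancels.

```python
import math

# Logits copied from the output above for the masked position in
# "The nurse finished [MASK] work."
logit_her = 16.3854923248291
logit_his = 11.12331771850586

# Under a softmax, P(a) / P(b) = exp(logit_a - logit_b);
# the shared normalisation constant cancels out.
odds_her_vs_his = math.exp(logit_her - logit_his)
print(f"'her' is about {odds_her_vs_his:.0f}x as likely as 'his'")
```

Repeating the same comparison for other occupations (for example "engineer" or "doctor") is one way to check whether the preference follows the occupation and not the sentence structure.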
Starting point for an error analysis: If my model hits 90% accuracy, why are the remaining 10% misclassified? Are there any patterns?
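One simple way to start looking for patterns is to break the error rate down by a surface property of the input, such as its length. The sketch below is only an illustration: the function name, the bucket boundaries in `edges` and the toy data are all assumptions, and whitespace splitting stands in for a real tokenizer.

```python
from collections import defaultdict

def error_rate_by_length(texts, gold, pred, edges=(10, 20)):
    """Group examples into token-length buckets and return the error rate per bucket.

    `edges` are hypothetical bucket boundaries: with (10, 20) the buckets are
    "<10", "10-19" and ">=20" whitespace-separated tokens.
    """
    totals = defaultdict(int)
    wrong = defaultdict(int)
    for text, g, p in zip(texts, gold, pred):
        n_tokens = len(text.split())
        if n_tokens < edges[0]:
            bucket = f"<{edges[0]}"
        elif n_tokens < edges[1]:
            bucket = f"{edges[0]}-{edges[1] - 1}"
        else:
            bucket = f">={edges[1]}"
        totals[bucket] += 1
        wrong[bucket] += int(g != p)
    return {b: wrong[b] / totals[b] for b in totals}

# Made-up toy data: two short, one medium and one long statement.
texts = [
    "short claim",
    "another short claim here",
    "a " * 12 + "mid length claim",
    "b " * 25 + "long claim",
]
gold = ["true", "false", "true", "false"]
pred = ["true", "false", "false", "true"]

rates = error_rate_by_length(texts, gold, pred)
print(rates)
```

The same grouping works for any other property (token frequency, topic, speaker); only the bucketing rule changes.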
import spacy
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import classification_report
We'll reuse the two basic logistic regression classifiers from lab 2 (where they were used for sentiment analysis): one with TF-IDF features and one with BPE embeddings. Here we train them on a fact-checking dataset, the Liar dataset.
from google.colab import files
uploaded = files.upload()
for fn in uploaded.keys():
    print('User uploaded file "{name}" with length {length} bytes'.format(
        name=fn, length=len(uploaded[fn])))
Saving train.tsv to train (1).tsv Saving valid.tsv to valid (1).tsv Saving test.tsv to test (1).tsv User uploaded file "train.tsv" with length 2408165 bytes User uploaded file "valid.tsv" with length 301556 bytes User uploaded file "test.tsv" with length 301118 bytes
import pandas as pd
import numpy as np
train_data = pd.read_csv('./train.tsv', sep='\t', header=None).fillna('')
valid_data = pd.read_csv('./valid.tsv', sep='\t', header=None).fillna('')
test_data = pd.read_csv('./test.tsv', sep='\t', header=None).fillna('')
train_data.head()
| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2635.json | false | Says the Annies List political group supports ... | abortion | dwayne-bohac | State representative | Texas | republican | 0 | 1 | 0 | 0 | 0 | a mailer |
| 1 | 10540.json | half-true | When did the decline of coal start? It started... | energy,history,job-accomplishments | scott-surovell | State delegate | Virginia | democrat | 0 | 0 | 1 | 1 | 0 | a floor speech. |
| 2 | 324.json | mostly-true | Hillary Clinton agrees with John McCain "by vo... | foreign-policy | barack-obama | President | Illinois | democrat | 70 | 71 | 160 | 163 | 9 | Denver |
| 3 | 1123.json | false | Health care reform legislation is likely to ma... | health-care | blog-posting | | | none | 7 | 19 | 3 | 5 | 44 | a news release |
| 4 | 9028.json | half-true | The economic turnaround started at the end of ... | economy,jobs | charlie-crist | | Florida | democrat | 15 | 9 | 20 | 19 | 2 | an interview on CNN |
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
n = 100
vectorizer = TfidfVectorizer(max_features=1000)
features = vectorizer.fit_transform(train_data.values[:, 2])
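# the raw/textual n-grams the vectorizer is using
# (scikit-learn >= 1.0 offers get_feature_names_out() as the replacement for this call)
feature_array = np.array(vectorizer.get_feature_names())
# argsort orders the feature indices of each statement by ascending TF-IDF score;
# after flatten()[::-1] the array starts with the last training statement's
# features (highest score first) and ends with the first statement's features
tfidf_sorting = np.argsort(features.toarray()).flatten()[::-1]
# With a TF-IDF vectorizer, we can already tell which words have high TF-IDF scores,
# even before feeding them to the model:
# look up the n highest-scoring words of the last statement and
# the n lowest-scoring words of the first statement
print(feature_array[tfidf_sorting][:n])
print(feature_array[tfidf_sorting][-n:])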
['veterans' 'you' 'really' 'community' 'know' 'department' 'your' 'like' 'them' 'there' 'out' 'to' 'of' 'our' 'are' 'has' 'the' 'equal' 'employees' 'entire' 'enough' 'energy' 'estimated' 'end' 'even' 'for' 'employee' 'else' 'every' 'elections' 'election' 'elected' 'either' 'eight' 'effect' 'education' 'economy' 'economic' 'earth' 'earned' 'earn' 'ever' 'fact' 'executive' 'feingold' 'food' 'florida' 'five' 'fiscal' 'first' 'fire' 'find' 'financial' 'fight' 'fewer' 'few' 'force' 'existing' 'fees' 'federal' 'fbi' 'favor' 'fastest' 'far' 'family' 'earmarks' 'failed' 'experience' 'expansion' 'families' 'due' 'early' 'currently' 'decades' 'decade' 'debt' 'debate' 'death' 'deal' 'days' 'day' 'david' 'data' 'cutting' 'cuts' 'cut' 'current' 'each' 'cruz' 'crist' 'crisis' 'criminal' 'crimes' 'crime' 'credit' 'creation' 'creating' 'created' 'create' 'coverage'] ['private' 'prison' 'primary' 'prices' 'previous' 'presidents' 'rape' 'rates' 'president' 'rating' 'residents' 'research' 'required' 'require' 'republicans' 'republican' 'representatives' 'report' 'rep' 'released' 'regulations' 'registered' 'refused' 'refugees' 'reform' 'reduced' 'reduce' 'recovery' 'record' 'recently' 'recent' 'received' 'receive' 'really' 'real' 'reagan' 're' 'presidential' 'premiums' 'officers' 'paul' 'passed' 'pass' 'party' 'part' 'parents' 'parenthood' 'paid' 'own' 'overseas' 'over' 'out' 'our' 'other' 'oregon' 'order' 'or' 'opposes' 'opposed' 'opponent' 'open' 'only' 'one' 'once' 'old' 'oil' 'ohio' 'officials' 'past' 'pay' 'power' 'paying' 'poverty' 'position' 'population' 'popular' 'poor' 'polls' 'poll' 'policy' 'policies' 'police' 'points' 'point' 'plant' 'plans' 'planned' 'plan' 'place' 'personal' 'person' 'perry' 'percentage' 'percent' 'per' 'people' 'pension' 'pelosi' 'pays' '000']
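To see where such scores come from, here is a small hand-computed TF-IDF example. It is a sketch under stated assumptions: the three toy documents are made up, tokens are split on whitespace (simpler than the vectorizer's own tokenization), and the weighting follows the smoothed formula that `TfidfVectorizer` uses by default, idf(t) = ln((1 + n) / (1 + df(t))) + 1, followed by L2 normalisation.

```python
import math
from collections import Counter

# Made-up toy documents (not from the Liar dataset).
docs = [
    "the tax cut helped veterans",
    "the tax plan",
    "the economy",
]

tokenized = [d.split() for d in docs]
n_docs = len(tokenized)

# Document frequency: in how many documents does each term occur?
df = Counter(term for tokens in tokenized for term in set(tokens))

def tfidf(tokens):
    """L2-normalised TF-IDF weights (as a dict) for one tokenised document."""
    tf = Counter(tokens)
    raw = {t: tf[t] * (math.log((1 + n_docs) / (1 + df[t])) + 1) for t in tf}
    norm = math.sqrt(sum(v * v for v in raw.values()))
    return {t: v / norm for t, v in raw.items()}

weights = tfidf(tokenized[0])
for term, w in sorted(weights.items(), key=lambda kv: -kv[1]):
    print(f"{term}: {w:.3f}")
```

Terms that occur in every document ("the") receive the lowest weight, while terms unique to one document ("veterans") receive the highest, which is why rare content words dominate the top of a TF-IDF ranking.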
vectorizer = TfidfVectorizer(ngram_range=(1,2), max_features=1000)
features = vectorizer.fit_transform(train_data.values[:, 2])
# training a linear model, which is interpretable by design
lr = LogisticRegression(penalty='l2', max_iter=1000, multi_class='multinomial')
lr.fit(features, train_data.values[:, 1])
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, l1_ratio=None, max_iter=1000, multi_class='multinomial', n_jobs=None, penalty='l2', random_state=None, solver='lbfgs', tol=0.0001, verbose=0, warm_start=False)
# reuse the vocabulary and IDF weights fitted on the training data (transform, not fit_transform)
features_test = vectorizer.transform(test_data.values[:, 2])
preds_tfidf = lr.predict(features_test)
features_valid = vectorizer.transform(valid_data.values[:, 2])
preds_valid_tfidf = lr.predict(features_valid)
from tqdm import tqdm
import nltk
!pip install bpemb
from bpemb import BPEmb
# Load english model with 25k word-pieces
bpemb_en = BPEmb(lang='en', dim=100, vs=25000)
def get_bpemb_features(dataset, bpemb):
    # With bpemb we can tokenize and embed an entire document using .embed(x)
    X = [bpemb.embed(x).mean(0) for x in tqdm(dataset[:,2])]
    y = dataset[:,1]
    return X,y
X_train,y_train = get_bpemb_features(train_data.values, bpemb_en)
X_valid,y_valid = get_bpemb_features(valid_data.values, bpemb_en)
X_test,y_test = get_bpemb_features(test_data.values, bpemb_en)
lr_bpemb = LogisticRegression(penalty='l2', max_iter=1000, multi_class='multinomial')
lr_bpemb.fit(X_train, y_train)
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, l1_ratio=None, max_iter=1000, multi_class='multinomial', n_jobs=None, penalty='l2', random_state=None, solver='lbfgs', tol=0.0001, verbose=0, warm_start=False)
preds_bpemb = lr_bpemb.predict(X_test)
preds_valid_bpemb = lr_bpemb.predict(X_valid)
Looking at the classification report and the confusion matrix is the most basic step of an error analysis: you can find which classes are most often confused with which other classes and compare the performance across classes. Differences in performance between models give you a well-founded basis for choosing a model.
# TF-IDF model
report = classification_report(y_test, preds_tfidf, output_dict=True)
pd.DataFrame(report).transpose()
| | precision | recall | f1-score | support |
|---|---|---|---|---|
| barely-true | 0.166667 | 0.099057 | 0.124260 | 212.000000 |
| false | 0.230986 | 0.329317 | 0.271523 | 249.000000 |
| half-true | 0.221198 | 0.181132 | 0.199170 | 265.000000 |
| mostly-true | 0.215152 | 0.294606 | 0.248687 | 241.000000 |
| pants-fire | 0.090909 | 0.021739 | 0.035088 | 92.000000 |
| true | 0.184332 | 0.192308 | 0.188235 | 208.000000 |
| accuracy | 0.208366 | 0.208366 | 0.208366 | 0.208366 |
| macro avg | 0.184874 | 0.186360 | 0.177827 | 1267.000000 |
| weighted avg | 0.197334 | 0.208366 | 0.196564 | 1267.000000 |
# BPEmb model
report = classification_report(y_test, preds_bpemb, output_dict=True)
pd.DataFrame(report).transpose()
| | precision | recall | f1-score | support |
|---|---|---|---|---|
| barely-true | 0.174312 | 0.089623 | 0.118380 | 212.000000 |
| false | 0.241758 | 0.353414 | 0.287113 | 249.000000 |
| half-true | 0.234604 | 0.301887 | 0.264026 | 265.000000 |
| mostly-true | 0.251656 | 0.315353 | 0.279926 | 241.000000 |
| pants-fire | 0.451613 | 0.152174 | 0.227642 | 92.000000 |
| true | 0.233333 | 0.134615 | 0.170732 | 208.000000 |
| accuracy | 0.240726 | 0.240726 | 0.240726 | 0.240726 |
| macro avg | 0.264546 | 0.224511 | 0.224637 | 1267.000000 |
| weighted avg | 0.244714 | 0.240726 | 0.229260 | 1267.000000 |
confusion_matrix(valid_data.values[:, 1], preds_valid_tfidf)
array([[19, 67, 42, 67,  2, 40],
       [32, 73, 53, 64,  4, 37],
       [19, 67, 50, 66,  0, 46],
       [25, 52, 54, 67,  7, 46],
       [ 8, 36, 25, 20,  2, 25],
       [17, 32, 43, 52,  2, 23]])
valid_values = valid_data.values[:, 1]
confusion_matrix(valid_values, preds_valid_bpemb)
array([[20, 83, 74, 40,  5, 15],
       [28, 98, 67, 37, 12, 21],
       [17, 63, 86, 60,  4, 18],
       [16, 70, 53, 76,  2, 34],
       [13, 48, 23,  9, 14,  9],
       [10, 37, 39, 53,  1, 29]])
Can you spot any insights?
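One way to make this question concrete is to rank the off-diagonal cells of a confusion matrix. The sketch below is an illustration and not part of the original lab: it copies the validation confusion matrix of the BPEmb model from the output above, assumes scikit-learn's convention (rows are gold labels, columns are predictions, labels in sorted order), and uses hypothetical helper names.

```python
# Label order as produced by sorting the class names (scikit-learn's default).
labels = ["barely-true", "false", "half-true", "mostly-true", "pants-fire", "true"]

# Validation confusion matrix of the BPEmb model, copied from the output above
# (rows: gold label, columns: predicted label).
cm = [
    [20, 83, 74, 40, 5, 15],
    [28, 98, 67, 37, 12, 21],
    [17, 63, 86, 60, 4, 18],
    [16, 70, 53, 76, 2, 34],
    [13, 48, 23, 9, 14, 9],
    [10, 37, 39, 53, 1, 29],
]

def most_confused(cm, labels, top_k=5):
    """Return the top_k (count, gold, predicted) triples with gold != predicted."""
    pairs = [
        (cm[i][j], labels[i], labels[j])
        for i in range(len(labels))
        for j in range(len(labels))
        if i != j
    ]
    return sorted(pairs, reverse=True)[:top_k]

def recall_per_class(cm, labels):
    """Recall = correctly predicted instances / all gold instances of the class."""
    return {labels[i]: cm[i][i] / sum(cm[i]) for i in range(len(labels))}

for count, gold_label, pred_label in most_confused(cm, labels):
    print(f"{gold_label} -> {pred_label}: {count}")
print(recall_per_class(cm, labels))
```

Ranking the cells this way shows quickly whether the errors fall mostly between neighbouring truthfulness levels or between distant ones.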
We can now see which words are most important for each class by looking at the weights the model assigned to each feature in the input.
top_features = 10
# get the model's weights: n_classes x n_features - (? , ?)
all_class_coef = lr.coef_
for i, cls in enumerate(lr.classes_): # for each of the classes
    print(cls)
    # get the weights for the class
    coef = all_class_coef[i]
    # find the top negative and positive features for the class
    top_positive_coefficients = np.argsort(coef)[-top_features:]
    top_negative_coefficients = np.argsort(coef)[:top_features]
    # combine them in one array
    top_coefficients = np.hstack([top_negative_coefficients, top_positive_coefficients])
    # create plot - humans tend to understand better plot visualizations
    feature_names = vectorizer.get_feature_names()
    plt.figure(figsize=(15, 5))
    colors = ['red' if c < 0 else 'blue' for c in coef[top_coefficients]]
    plt.bar(np.arange(2 * top_features), coef[top_coefficients], color=colors)
    feature_names = np.array(feature_names)
    plt.xticks(np.arange(1, 1 + 2 * top_features), feature_names[top_coefficients], rotation=60, ha='right')
    plt.show()
barely-true
false
half-true
mostly-true
pants-fire
true
We can now look at the predictions of the model to try and spot any problems or particular features of the model.
We will look at instances the model classifies correctly and incorrectly, each of which can be broken down further into instances predicted with high, medium, or low confidence. To do that, we'll use the probability (confidence) the model assigns to its prediction, which probabilistic classifiers such as logistic regression expose (in scikit-learn via predict_proba).
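The breakdown described above can be sketched independently of any particular model. In the minimal example below, the function name, the `low` and `high` thresholds and the toy predictions are assumptions chosen for illustration. With six classes and confidences as low as in this lab, different cut-offs (or quantiles) would be more sensible.

```python
from collections import defaultdict

def bucket_by_confidence(gold, pred, confidence, low=0.3, high=0.6):
    """Group example indices by (correct/wrong, low/medium/high confidence).

    `confidence` is the probability the model assigned to its predicted class;
    `low` and `high` are hypothetical thresholds.
    """
    groups = defaultdict(list)
    for i, (g, p, c) in enumerate(zip(gold, pred, confidence)):
        outcome = "correct" if g == p else "wrong"
        if c < low:
            level = "low"
        elif c < high:
            level = "medium"
        else:
            level = "high"
        groups[(outcome, level)].append(i)
    return dict(groups)

# Made-up toy predictions.
gold = ["true", "false", "half-true", "false"]
pred = ["true", "true", "half-true", "pants-fire"]
confidence = [0.72, 0.65, 0.25, 0.41]

groups = bucket_by_confidence(gold, pred, confidence)
for key, indices in sorted(groups.items()):
    print(key, indices)
```

High-confidence errors are usually the most informative group to read first, since they show where the model is confidently wrong.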
# get the probability of the model
valid_pred_prob = lr_bpemb.predict_proba(X_valid)
from collections import defaultdict
# collect correct and wrong predictions, keeping the confidence of the prediction
errors = defaultdict(list)
correct_preds = defaultdict(list)
for (i, instance), pred, pred_score in zip(valid_data.iterrows(), preds_valid_bpemb, valid_pred_prob):
    # get the index/id of the gold class in the probability array (n_classes x 1)
    index_of_class = np.where(lr_bpemb.classes_ == instance[1])
    # get the index/id of the predicted class in the probability array (n_classes x 1)
    index_of_pred_class = np.where(lr_bpemb.classes_ == pred)
    # depending on whether the prediction is correct, collect the instances as errors or correct predictions
    # each entry: (statement, probability of gold class, probability of predicted class, predicted label)
    if pred != instance[1]:
        errors[instance[1]].append((instance[2], pred_score[index_of_class], pred_score[index_of_pred_class], pred))
    else:
        correct_preds[instance[1]].append((instance[2], pred_score[index_of_class], pred_score[index_of_pred_class], pred))
import pprint
for cls in lr_bpemb.classes_:
    print(cls)
    print('High probability for correct class')
    # sort the errors by the probability of the gold class and look at the:
    # 1. instances where the gold class had a high probability
    pprint.pprint(sorted(errors[cls], key=lambda x: x[1])[-10:])
    print('Low probability for correct class')
    # 2. instances where the gold class had a low probability
    pprint.pprint(sorted(errors[cls], key=lambda x: x[1])[:10])
barely-true High probability for correct class [('Just like Donald Trump, David Jolly wants to outlaw a womans right to ' 'choose.', array([0.22251137]), array([0.23379486]), 'false'), ('Stimulus dollars paid for windmills from China.', array([0.22542609]), array([0.26144215]), 'half-true'), ('Obamacare is costing 2 million jobs.', array([0.22576654]), array([0.24788424]), 'half-true'), ('Gerry Connolly and his fellow Democrats went on a spending spree and now ' 'their credit card is maxed out.', array([0.22596097]), array([0.24121691]), 'half-true'), ('Says Hillary Clinton has been saying lately that she thinks that I am, not ' 'qualifiedto be president.', array([0.23431039]), array([0.24782692]), 'false'), ('Rick Scotts prison plan would cut Floridas prison budget in half, close ' 'prisons, and release tens of thousands of prisoners early -- murderers, ' 'rapists, sex offenders, armed robbers, drug dealers.', array([0.23454344]), array([0.23867524]), 'half-true'), ('Originally, Democrats promised that if you liked your health care plan, you ' 'could keep it. One year later we know that you need a waiver to keep your ' 'plan.', array([0.23811094]), array([0.24234747]), 'false'), ('DNC chair Debbie Wasserman Schultz denied unemployment went up under Obama.', array([0.24116964]), array([0.24564027]), 'false'), ('Says Gina Raimondos venture capital firm secured a secret no-bid contract ' 'funded by taxpayers.', array([0.24315745]), array([0.30545336]), 'false'), ('Says his plan to end the toll on Ga. 
400 fulfills his campaign promise to ' 'commuters.', array([0.26012345]), array([0.27767736]), 'false')] Low probability for correct class [('Rhode Island is one of the worst states for income equality.', array([0.05065124]), array([0.28476863]), 'true'), ('Veterans can now download their electronic medical records with a click of ' 'the mouse.', array([0.0535591]), array([0.25986084]), 'true'), ('Voted the best specialty plate in America', array([0.07670996]), array([0.46605402]), 'mostly-true'), ('The U.S. "only ranks 25th worldwide on defense spending as a percentage of ' 'GDP."', array([0.07743288]), array([0.32516985]), 'mostly-true'), ('In 2013, the United States accepted 67 percent of the worlds refugees.', array([0.07826386]), array([0.28742681]), 'mostly-true'), ('Nearly 25 percent of all automobile accidents are caused by texting while ' 'driving.', array([0.07840081]), array([0.28705188]), 'mostly-true'), ('The Wisconsin Economic Development Corp. cannot account for creating one ' 'single job.', array([0.08486728]), array([0.3132537]), 'false'), ('There are over 100 pipelines between the United States and Canada right ' 'now.', array([0.08665623]), array([0.2982351]), 'true'), ('We have less Americans working now than in the 70s.', array([0.08870408]), array([0.3382231]), 'mostly-true'), ('Financed the largest parking expansion program without a rate increase.', array([0.09100782]), array([0.35659618]), 'half-true')] false High probability for correct class [('During the Monica Lewinsky scandal, the Clintons brought the Rev. 
Jeremiah ' 'Wright to the White House for "spiritual counseling."', array([0.23538899]), array([0.23926563]), 'pants-fire'), ('Says Charlie Bass forfeits right to equal cost for TV ads under FCC rules', array([0.23605398]), array([0.25696707]), 'half-true'), ('By advocating new requirements for voters to show ID cards at the polls, ' 'Republicans want to literally drag us all the way back to Jim Crow laws.', array([0.23784883]), array([0.24978705]), 'half-true'), ('Says Bill White is for gay marriage.', array([0.24089753]), array([0.24428836]), 'pants-fire'), ('Says Robert Hurt supports the tax loopholes that send American jobs ' 'overseas.', array([0.24317522]), array([0.25988972]), 'half-true'), ('Says U.S. Rep. Tom Price is sending letters both supporting and opposing ' 'the small-business killing Internet Tax Mandate.', array([0.24392553]), array([0.25935293]), 'half-true'), ('Obama Makes Huge Move to BAN Social Security Recipients From Owning Guns', array([0.24778217]), array([0.31845966]), 'half-true'), ('The (State Board of Administration) transparency issue got a great airing ' 'the last legislative session.', array([0.25825491]), array([0.2697218]), 'true'), ('Social Security is a Ponzi scheme.', array([0.28221914]), array([0.32342824]), 'barely-true'), ('On expanding Medicaid as part of the health care law', array([0.28621979]), array([0.30116083]), 'half-true')] Low probability for correct class [('The Atlanta Beltline paid nearly $3.5 million for less than a quarter-acre.', array([0.05157401]), array([0.40129265]), 'mostly-true'), ('Rhode Island has the highest percentage of lawyers per capita in the ' 'country.', array([0.06375065]), array([0.3522662]), 'mostly-true'), ('Says most Texas schools spend 45 out of 180 school days in mandated ' 'testing.', array([0.07274155]), array([0.29758186]), 'mostly-true'), ('More than 7,000 Americans lost their lives to climate change-fueled events ' 'last year.', array([0.08237582]), array([0.32266697]), 
'mostly-true'), ('The average federal employee makes $120,000 a year. The average private ' 'employee makes $60,000 a year.', array([0.08437365]), array([0.36242032]), 'mostly-true'), ('Students at the University of Texas in Austin have been advised not to wear ' 'cowboy boots or cowboy hats on Halloween.', array([0.08573882]), array([0.35872029]), 'pants-fire'), ('State employees in Wisconsin earn about 8 percent less than if they worked ' 'in the private sector.', array([0.09202228]), array([0.33041181]), 'mostly-true'), ('My opponent supported policies that increased tuition by 18 percent.', array([0.09416232]), array([0.41780882]), 'half-true'), ('We have the highest tax rate anywhere in the world.', array([0.09905309]), array([0.35885587]), 'mostly-true'), ('Says the jurisdictions with the strictest gun control laws, almost without ' 'exception have the highest crime rates and the highest murder rates.', array([0.1068676]), array([0.33978121]), 'mostly-true')] half-true High probability for correct class [('He voted twice for a budget resolution that increases the taxes on ' 'individuals making $42,000 a year.', array([0.25026902]), array([0.31943667]), 'mostly-true'), ("''Both Democrats and Republicans are advocating for the use of student test " "scores to measure teacher effectiveness.''", array([0.25223047]), array([0.26439614]), 'mostly-true'), ('We lowered the business tax from 4 percent down to 1 percent.', array([0.25349313]), array([0.33560291]), 'mostly-true'), ('The violent crime rate in America is the same as it was in 1968, yet our ' 'prison system has grown by over 500 percent.', array([0.25502656]), array([0.27662334]), 'mostly-true'), ('One out of two Americans are living either in or near poverty. 
That means ' '150 million Americans, half of us.', array([0.26287953]), array([0.28627951]), 'mostly-true'), ('When I was governor, not only did test scores improve we also narrowed the ' 'achievement gap.', array([0.26813932]), array([0.26906317]), 'mostly-true'), ('Georgias high school graduation rate topped 80 percent in 2010.', array([0.27016056]), array([0.33132616]), 'mostly-true'), ('The accuracy of the Obama tax calculator', array([0.27463433]), array([0.31739085]), 'false'), ('The lieutenant governor has the power to be an economic ambassador and ' 'negotiate on economic development', array([0.2825785]), array([0.33883554]), 'false'), ('Violent crime is up since the last year of Sharpe James administration. ' 'This year its higher. The unemployment rate is almost 15 percent. The high ' 'school dropout rate is over 50 percent.', array([0.28397836]), array([0.29554822]), 'mostly-true')] Low probability for correct class [('President Obama is shrinking our military.', array([0.07725838]), array([0.68504352]), 'pants-fire'), ('New Jerseys once-broken pension system is now solvent.', array([0.0798887]), array([0.27843326]), 'false'), ('Ken Lanci is a lifelong Clevelander', array([0.08334688]), array([0.3849569]), 'pants-fire'), ('New Hampshire is currently the only state in the nation that does not have ' 'a full-service veterans hospital or equivalent access.', array([0.10698837]), array([0.23691774]), 'true'), ('Limberbutt McCubbins (a five-year-old cat) is a candidate in the 2016 ' 'presidential election.', array([0.11242005]), array([0.22600713]), 'false'), ('I waited until (the Trans-Pacific Partnership trade agreement) had actually ' 'been negotiated before deciding whether to endorse it.', array([0.12169258]), array([0.37495618]), 'false'), ('If you dont buy cigarettes at your local supermarket, your grocery bill ' 'wont go up a dime. The same is true of the sugary drink tax. 
If passed, you ' 'can avoid paying the tax by not buying sugary drinks.', array([0.12245907]), array([0.26155051]), 'barely-true'), ('The new Ukrainian government introduced a law abolishing the use of ' 'languages other than Ukrainian in official circumstances.', array([0.1248197]), array([0.24838289]), 'false'), ('Henry Kissinger "said that we should meet with Iran guess what without ' 'precondition."', array([0.12795797]), array([0.28149231]), 'false'), ('Each U.S. House member who voted to overhaul Social Security in 1983 was ' 're-elected.', array([0.13381434]), array([0.27317548]), 'false')] mostly-true High probability for correct class [('Half the lottery directors across the country had not run a lottery before ' 'they were hired.', array([0.2413346]), array([0.26183837]), 'true'), ('Texas had the worst voter participation in the country in the November 2014 ' 'elections.', array([0.24289069]), array([0.29152506]), 'true'), ('The port provides more than 297,000 jobs directly to the state of Georgia.', array([0.24370102]), array([0.30201776]), 'half-true'), ('Says Marco Rubio has a 98 percent voting record with the Koch brothers.', array([0.24370212]), array([0.25260857]), 'true'), ('Says we have now more Border Patrol officers than weve had at any time in ' 'our history.', array([0.24469714]), array([0.2686888]), 'true'), ('House Democrats have a two-to-one advantage money-wise against House ' 'Republicans.', array([0.24724759]), array([0.27087822]), 'barely-true'), ('I think with the exception of the last year or maybe the last two years, we ' 'were at 100 percent when it came to contributing to the Providence pension ' 'fund.', array([0.25466793]), array([0.29151177]), 'true'), ('Says roughly two-thirds of (state) corporations didnt pay any income tax in ' 'Virginia.', array([0.26660059]), array([0.30925857]), 'true'), ('Says 92 percent of Texas counties had no abortion provider in 2008.', array([0.26796873]), array([0.28335137]), 'true'), ('Says Texas 
has the fourth-lowest debt per capita of any state in the ' 'nation, and we are the lowest of any of the big states.', array([0.27790649]), array([0.30700115]), 'true')] Low probability for correct class [('Barack Obama refuses to acknowledge Jerusalem as the capital of Israel.', array([0.06177102]), array([0.31776937]), 'false'), ('Says Nelson Mandela was a communist.', array([0.06902091]), array([0.2780403]), 'false'), ('A proposed ban on hollow-point bullets and bullets that expand upon impact ' 'essentially bans deer hunting.', array([0.09091255]), array([0.28466038]), 'false'), ('Says Rick Perry supported a guest worker program to help people who would ' 'otherwise be illegal aliens.', array([0.09571792]), array([0.25810736]), 'barely-true'), ('Says Ohios economic recovery started in February 2010.', array([0.095992]), array([0.40859184]), 'false'), ('Zika mosquitoes cant catch me.', array([0.09634855]), array([0.3119181]), 'false'), ('Congress will begin its recess without having allocated one penny to fight ' 'Zika.', array([0.09879078]), array([0.27858522]), 'false'), ('Says that President Obama said in 2008 that his proposed greenhouse gas ' 'regulations will bankrupt anyone who wants to build a new coal-fired power ' 'plant.', array([0.10853077]), array([0.28406668]), 'false'), ('Says Barack Obama has pension investments that include Chinese firms, and ' 'investments through a Caymans trust.', array([0.11133482]), array([0.22399151]), 'false'), ('Access for 12,000 women to use Planned Parenthood -- not for the right to ' 'choose, but for basic health care -- was taken away by Wisconsin Gov. Scott ' 'Walker and Lt. Gov. Rebecca Kleefisch.', array([0.11281116]), array([0.2662776]), 'false')] pants-fire High probability for correct class [('Blue Cross Blue Shield cancelled all their individual (health care) ' 'policies in the state of Texas, effective Dec. 
31.', array([0.18454858]), array([0.20862296]), 'half-true'), ('Says Ted Cruz said, There is no place for gays or atheists in my America. ' 'None. Our Constitution makes that clear.', array([0.19096911]), array([0.22496935]), 'false'), ('Never mind that no red light camera, no speed camera, nor any radar gun has ' 'ever stopped one accident from occurring.', array([0.19363173]), array([0.25798594]), 'false'), ("Page 992 of the health care bill will establish school-based 'health' " 'clinics. Your children will be indoctrinated and your grandchildren may be ' 'aborted!', array([0.19934984]), array([0.23319877]), 'false'), ('Says Denzel Washington supports Donald Trump and speaks out against Barack ' 'Obama.', array([0.19942185]), array([0.28598604]), 'barely-true'), ('Democratic Sens. Ed Markey, Al Franken and Jeanne Shaheen took Bribes From ' 'Iran They Back Insane NUKE Deal.', array([0.20414693]), array([0.2669754]), 'barely-true'), ('Reporters rehearse questions with White House press (secretary).', array([0.20751695]), array([0.23709219]), 'true'), ('From Obama\'s book: "I found a solace in nursing a pervasive sense of ' 'grievance and animosity against my mother\'s race."', array([0.21706323]), array([0.21839428]), 'half-true'), ('Barack Obama is the first president to file lawsuits against the states he ' 'swore an oath to protect.', array([0.23675977]), array([0.24826283]), 'false'), ('Michelle Nunns own plan says she funded organizations linked to terrorists.', array([0.26219134]), array([0.30099145]), 'barely-true')] Low probability for correct class [('In the past two years, Democrats have spent more money than this country ' 'has spent in the last 200 years combined.', array([0.00941126]), array([0.30349954]), 'mostly-true'), ('The number of illegal immigrants in the United States is 30 million, it ' 'could be 34 million.', array([0.02548566]), array([0.30411461]), 'mostly-true'), ('I never supported a state income tax for Texas.', array([0.02934045]), 
array([0.22741555]), 'mostly-true'), ('CNN did a poll recently where Obama and I are statistically tied.', array([0.0352685]), array([0.29207709]), 'mostly-true'), ('The last quarter, it was just announced, our gross domestic product was ' 'below zero. Who ever heard of this? Its never below zero.', array([0.03689553]), array([0.24529246]), 'true'), ('The United States stood alone in the war in Iraq.', array([0.03703309]), array([0.28805474]), 'mostly-true'), ('Says President Barack Obama doubled the national debt, which had taken more ' 'than two centuries to accumulate, in one year.', array([0.03879033]), array([0.25173286]), 'half-true'), ('Gov. Deal has the worst record on education in the history of this state.', array([0.04273135]), array([0.24002365]), 'true'), ('Medicare insurance premiums will be rising to $120.20 per month in 2013 and ' '$247.00 per month in 2014.', array([0.04405389]), array([0.26876243]), 'mostly-true'), ('Out of 67 counties (in Florida), I won 66, which is unprecedented. Its ' 'never happened before.', array([0.04427767]), array([0.29171739]), 'mostly-true')] true High probability for correct class [('In 1950, the average American lived for 68 years and 16 workers supported ' 'one retiree. 
Today, the average life expectancy is 78 and three workers ' 'support one retiree.', array([0.24255829]), array([0.30739337]), 'mostly-true'), ('The commission form of government is definitely losing favor in the United ' 'States.', array([0.24277495]), array([0.25071386]), 'false'), ('Latinos are 17 percent of our countrys population but hold only 2 percent ' 'of its wealth.', array([0.24416761]), array([0.29033684]), 'mostly-true'), ('Only five states, including Georgia, do not have a hate crimes law.', array([0.24666866]), array([0.27113862]), 'mostly-true'), ('Wisconsin still ranks first among the 50 states in manufacturing jobs per ' 'capita.', array([0.26240926]), array([0.36058172]), 'mostly-true'), ('Georgia has had ʺmore bank failures than any other state.ʺ', array([0.2656559]), array([0.26706416]), 'mostly-true'), ('We spend more per student than almost any other major country in the world.', array([0.26705833]), array([0.34742938]), 'mostly-true'), ('In the last 50 years, (the federal government has) only balanced the budget ' 'five times.', array([0.26994446]), array([0.31491449]), 'mostly-true'), ('More than half of all black children live in single-parent households, a ' 'number that has doubled doubled since we were children.', array([0.28631034]), array([0.29431726]), 'mostly-true'), ('Says, Oregon has the third largest class size in the nation.', array([0.29808822]), array([0.31721509]), 'mostly-true')] Low probability for correct class [('Ronald Reagan did amnesty.', array([0.06374213]), array([0.36507979]), 'pants-fire'), ('Says Apple CEO Steve Jobs told President Obama that the company moved ' 'factories to China because it needed 30,000 engineers.', array([0.07745747]), array([0.25274685]), 'half-true'), ('The Obama administration is allowing state waivers from welfare work ' 'requirements only if they had a credible plan to increase employment by 20 ' 'percent.', array([0.07835933]), array([0.23909043]), 'false'), ('Says Jeb Bush -- not 
Charlie Crist -- signed legislation that let Duke ' 'Energy collect money for nuclear projects.', array([0.08510697]), array([0.3229992]), 'false'), ('John McCain and George Bush have "absolutely no plan for universal health ' 'care."', array([0.08515328]), array([0.30561856]), 'false'), ('Our trade deficit in goods reached nearly $800 billion last year alone.', array([0.0861052]), array([0.36136203]), 'mostly-true'), ('McCain "said he was \'stumped\' when asked whether contraceptives help stop ' 'the spread of HIV."', array([0.08706378]), array([0.29362644]), 'false'), ('Weare picking up oil with shovels and paper and plastic bags.', array([0.08975892]), array([0.29177719]), 'barely-true'), ('The House-passed budget proposal could cut funding for programs that help ' 'keep local neighborhoods safe, slash more than $1.7 million in anti-terror ' 'funds for Ohio.', array([0.09587181]), array([0.23432662]), 'half-true'), ('The Israelisgave up 1,000 terrorists in return for one sergeant.', array([0.09609552]), array([0.25524104]), 'half-true')]
# Inspect the TF-IDF scores (the model's input) for a specific instance
text = 'Ken Lanci is a lifelong Clevelander'
print(vectorizer.transform([text]))
# List the words of the statement that appear in the vectorizer's vocabulary
print([word for word in text.lower().split() if word in vectorizer.vocabulary_])
(0, 414) 1.0
['is']
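The output above shows that only 'is' survives the vocabulary lookup, and that its score is exactly 1.0. A plausible explanation, assuming `vectorizer` keeps scikit-learn's default L2 row normalization (the notebook section shown here does not confirm its configuration), is that a row with a single non-zero entry always normalizes to 1.0, whatever the raw tf-idf value was. The sketch below reproduces that arithmetic with the standard library only. The toy vocabulary and the raw weight are hypothetical; only the 'is' -> 414 entry is taken from the output.

```python
import math


def in_vocab_tokens(text, vocabulary):
    """Return the lowercased whitespace tokens of `text` found in `vocabulary`."""
    return [tok for tok in text.lower().split() if tok in vocabulary]


def l2_normalize(weights):
    """Scale a {token: weight} mapping to unit Euclidean length."""
    norm = math.sqrt(sum(w * w for w in weights.values()))
    if norm == 0.0:
        return dict(weights)  # nothing to scale for an all-zero row
    return {tok: w / norm for tok, w in weights.items()}


# Hypothetical toy vocabulary; only 'is' -> 414 comes from the notebook output.
toy_vocabulary = {"is": 414, "tax": 1, "obama": 2}
# Hypothetical raw (unnormalized) tf * idf value for the surviving token.
raw_weight = 2.5

tokens = in_vocab_tokens("Ken Lanci is a lifelong Clevelander", toy_vocabulary)
normalized = l2_normalize({tok: raw_weight for tok in tokens})
print(tokens)
print(normalized)
```

If this reading is right, the model sees the same input vector for every statement whose only in-vocabulary word is 'is', which may explain why such statements end up with a low probability for their true class.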
# Do the same for the correct predictions
for cls in lr_bpemb.classes_:
    print(cls)
    print('High probability for correct class')
    pprint.pprint(sorted(correct_preds[cls], key=lambda x: x[1])[-10:])
    print('Low probability for correct class')
    pprint.pprint(sorted(correct_preds[cls], key=lambda x: x[1])[:10])
barely-true High probability for correct class [('Austin politicians want to cram more kids into classrooms so they dont have ' 'to make the tough decisions to balance the budget.', array([0.26302518]), array([0.26302518, 0.18762024, 0.24028468, 0.15297156, 0.05470604, 0.10139231]), 'barely-true'), ('Rosalyn Dance voted against President Obamas Medicaid expansion.', array([0.27706225]), array([0.27706225, 0.27214181, 0.1821155 , 0.09545814, 0.11163969, 0.06158261]), 'barely-true'), ('Says Hillary Clinton was asked repeatedly to provide security in Benghazi ' 'on several occasions, including direct cables.', array([0.28209335]), array([0.28209335, 0.2077499 , 0.16141698, 0.11435118, 0.1009411 , 0.13344748]), 'barely-true'), ('Says Elizabeth Warren lied when she says I want to abolish the Federal ' 'Minimum Wage.', array([0.28353144]), array([0.28353144, 0.21341335, 0.18486839, 0.15152536, 0.05881866, 0.1078428 ]), 'barely-true'), ('Michelle Nunn has praised the Occupy movement.', array([0.28617855]), array([0.28617855, 0.17698218, 0.22941991, 0.13742239, 0.08343846, 0.08655851]), 'barely-true'), ('The federal government now tells us which light bulbs to buy.', array([0.28647602]), array([0.28647602, 0.21257576, 0.15413087, 0.10254278, 0.15447978, 0.0897948 ]), 'barely-true'), ('Says Georgia Democratic Senate hopeful Michelle Nunn supports higher taxes.', array([0.29150981]), array([0.29150981, 0.19441539, 0.16508727, 0.15857091, 0.11708733, 0.07332929]), 'barely-true'), ('Says Congressman Jon Runyan has a plan to raise Medicare costs $6,400 a ' 'year.', array([0.31599415]), array([0.31599415, 0.19356116, 0.17842892, 0.1420902 , 0.07725586, 0.09266971]), 'barely-true'), ('Says Russ Feingold voted to raise taxes on Social Security benefits for ' 'seniors, he even tried to give Social Security benefits to illegal ' 'immigrants.', array([0.33618429]), array([0.33618429, 0.21845278, 0.1865223 , 0.11323648, 0.08039732, 0.06520682]), 'barely-true'), ('Brendan Doherty wants 
to repeal Obamacare, increasing drug prices for ' 'seniors.', array([0.42830558]), array([0.42830558, 0.19574949, 0.17069489, 0.08435248, 0.07753514, 0.04336243]), 'barely-true')] Low probability for correct class [('In the last 24 months, 10 rural Texas hospitals have been forced to shut ' 'their doors because state leaders chose not to expand Medicaid.', array([0.19998145]), array([0.19998145, 0.19605453, 0.18811482, 0.19879774, 0.06443365, 0.15261781]), 'barely-true'), ('The brutal fact is, when it comes to education, America is slipping behind ' 'other nations.', array([0.20643283]), array([0.20643283, 0.18887344, 0.20050828, 0.17323369, 0.12188373, 0.10906803]), 'barely-true'), ('Says that President Obama promised that with the stimulus plan, ' 'unemployment would never go above 8 percent. He even said it would be 6 ' 'percent by now.', array([0.21532266]), array([0.21532266, 0.18253056, 0.21468393, 0.19822653, 0.06827082, 0.1209655 ]), 'barely-true'), ('Says there are a half a trillion dollars in cuts to Medicare that are going ' 'to go in place as a result of health care reform.', array([0.22312095]), array([0.22312095, 0.20618386, 0.17928683, 0.19845712, 0.06507634, 0.1278749 ]), 'barely-true'), ('John Boehner said the jobs of teachers and nurses and police officers and ' 'firefighters are government jobs that werent worth saving.', array([0.22520183]), array([0.22520183, 0.19357467, 0.22196905, 0.13771429, 0.08113373, 0.14040644]), 'barely-true'), ('The U.S. 
Army had a training program that put evangelical Christians, ' 'Catholics and Mormons in the same category of religious extremism as we do ' 'al-Qaida.', array([0.22967826]), array([0.22967826, 0.18338401, 0.190708 , 0.15855762, 0.13616616, 0.10150596]), 'barely-true'), ('Some even advocate wiping out 401(k)s entirely and replacing them with ' 'government-run accounts.', array([0.23115243]), array([0.23115243, 0.20258981, 0.16605986, 0.17082931, 0.10736821, 0.12200038]), 'barely-true'), ('When Peppers tax was finally voted down, Pepper laid off cops and closed ' 'jails and let criminals run free in the streets.', array([0.24628683]), array([0.24628683, 0.18695458, 0.2315768 , 0.13769275, 0.0651632 , 0.13232584]), 'barely-true'), ('Says Donald Trump has given more money to Democratic candidates than ' 'Republican candidates.', array([0.25240317]), array([0.25240317, 0.12307927, 0.21806967, 0.22313041, 0.03756317, 0.14575431]), 'barely-true'), ('Says Gwen Graham was a Washington lobbyist.', array([0.25427718]), array([0.25427718, 0.19731912, 0.10562465, 0.18591881, 0.16743849, 0.08942174]), 'barely-true')] false High probability for correct class [('Signs letter saying Consumer Product Safety Commission is acting without ' 'consultation or input from the company to stop the sale of Buckyballs.', array([0.31712839]), array([0.13331369, 0.31712839, 0.14417145, 0.11358885, 0.10241171, 0.18938591]), 'false'), ("Obama nominee Dawn Johnsen called motherhood 'involuntary servitude.'", array([0.31836704]), array([0.15713602, 0.31836704, 0.12322423, 0.05435816, 0.26798297, 0.07893158]), 'false'), ('On support for Trade Promotion Authority, calledfast-track', array([0.33002039]), array([0.15445088, 0.33002039, 0.2705143 , 0.09879464, 0.03093324, 0.11528656]), 'false'), ('No one questioned that she (Judge Sotomayor) was out of the mainstream.', array([0.33043471]), array([0.15803969, 0.33043471, 0.17683138, 0.11791864, 0.12958066, 0.08719492]), 'false'), ('One of the main 
functions of the Department of Homeland Securitys SAVE ' 'database is checking voter registration citizenship status.', array([0.34409945]), array([0.10120187, 0.34409945, 0.12620345, 0.12804975, 0.13388489, 0.16656059]), 'false'), ('Minnick voted to let the government fund abortion under Obamacare.', array([0.35505905]), array([0.21344999, 0.35505905, 0.13388265, 0.07698283, 0.15895469, 0.06167078]), 'false'), ('On support of Gov. John Kasichs drilling tax plan', array([0.35588041]), array([0.16152847, 0.35588041, 0.20941763, 0.06896681, 0.10600024, 0.09820645]), 'false'), ('Ed Gillespie supports a personhood amendment.', array([0.37261171]), array([0.10623974, 0.37261171, 0.19255768, 0.12480286, 0.09520349, 0.10858452]), 'false'), ('Says Democratic Party created Planned Parenthood', array([0.38097574]), array([0.19667093, 0.38097574, 0.07340432, 0.03324059, 0.20444405, 0.11126436]), 'false'), ('On birthright citizenship for illegal immigrants.', array([0.44855039]), array([0.07032212, 0.44855039, 0.18433352, 0.09191186, 0.08787095, 0.11701116]), 'false')] Low probability for correct class [("The Z-visa that was offered in that Senate bill let everybody who's here " 'illegally, other than criminals, stay here for the rest of their lives.', array([0.19679818]), array([0.16658818, 0.19679818, 0.18915559, 0.16522236, 0.09465374, 0.18758195]), 'false'), ('A major part of the climate change bill sponsored by Sens. John Kerry and ' 'Joe Lieberman was essentially written by BP.', array([0.19737013]), array([0.18665663, 0.19737013, 0.16271615, 0.16618049, 0.15887357, 0.12820303]), 'false'), ('Every person on death row was a foster kid.', array([0.20258425]), array([0.11883106, 0.20258425, 0.15938342, 0.15787086, 0.16009893, 0.20123148]), 'false'), ('The ammunition used in the Orlando shooting is banned by Geneva Convention. 
' 'It enters the body, spins explodes.', array([0.20282816]), array([0.11077081, 0.20282816, 0.16049543, 0.14251996, 0.19267677, 0.19070888]), 'false'), ('Liberals have figured out a Facebook algorithm and all the people getting ' 'banned from Facebook are somehow conservatives.', array([0.20584772]), array([0.1963959 , 0.20584772, 0.17838423, 0.12362139, 0.14751802, 0.14823273]), 'false'), ('Reporters have uncovered another Scott company accused of criminal acts. ' 'But Scott wont come clean.', array([0.20629242]), array([0.17971584, 0.20629242, 0.18833269, 0.17972274, 0.10024639, 0.14568992]), 'false'), ('We dont allow filming inside of the City Hall unless there is a specific ' 'reason.', array([0.20826226]), array([0.1833922 , 0.20826226, 0.13051615, 0.15941511, 0.15910828, 0.159306 ]), 'false'), ('The number of Americans who receive means-tested government benefits -- ' 'welfare -- now outnumbers those who are year-round full-time workers.', array([0.20972645]), array([0.20289108, 0.20972645, 0.15929458, 0.20311458, 0.0534685 , 0.17150482]), 'false'), ('In Mexico, they dont have birth certificates... they dont have registration ' 'cards for voters. 
They have one national ID.', array([0.21030414]), array([0.1849648 , 0.21030414, 0.13818485, 0.1895959 , 0.11062751, 0.1663228 ]), 'false'), ('In his first meeting with University of Wisconsin System officials, ' 'Republican Governor elect Scott Walker told them to prepare for cuts.', array([0.21111488]), array([0.20259171, 0.21111488, 0.18219824, 0.13625999, 0.12545277, 0.14238241]), 'false')] half-true High probability for correct class [('The planned expansion of Savannahs port is a jobs creating project.', array([0.32723884]), array([0.1342415 , 0.27495789, 0.32723884, 0.12055844, 0.04393291, 0.09907042]), 'half-true'), ('Banks paid Hillary Clinton over $1 million and are contributing millions ' 'more to elect her.', array([0.33252706]), array([0.17663693, 0.14261079, 0.33252706, 0.18351365, 0.02794379, 0.13676779]), 'half-true'), ('Every dollar we invested in high-quality, early education programs can save ' 'more than $7 later on by boosting graduation rates, reducing teen ' 'pregnancy, even reducing crime.', array([0.33281229]), array([0.1811519 , 0.13526641, 0.33281229, 0.19585909, 0.02833078, 0.12657953]), 'half-true'), ('After losing 750,000 jobs a month before this administration, the U.S. 
' 'economy under Barack Obama has had 20 straight months of growth, has added ' '2.8 million jobs in the private sector and added millions of jobs in ' 'manufacturing.', array([0.33877502]), array([0.14704204, 0.13690981, 0.33877502, 0.22483682, 0.02994485, 0.12249147]), 'half-true'), ('Says Texas unemployment rate has doubled on Rick Perrys watch.', array([0.35578192]), array([0.12502519, 0.17926101, 0.35578192, 0.14677991, 0.07851097, 0.114641 ]), 'half-true'), ('Bill Clinton cut the military drastically.', array([0.35923936]), array([0.17264085, 0.1754223 , 0.35923936, 0.12168385, 0.06506992, 0.10594372]), 'half-true'), ('For each $1 billion in infrastructure investment, 42,000 jobs are created.', array([0.36015066]), array([0.16897203, 0.10566123, 0.36015066, 0.25186987, 0.01831252, 0.09503369]), 'half-true'), ('Georgia can substantially increase its funding for education by going after ' '$2.5 billion in uncollected taxes.', array([0.36477485]), array([0.12987317, 0.1299191 , 0.36477485, 0.25847386, 0.03318873, 0.08377029]), 'half-true'), ('Over half of science, technology, engineering and mathematics students ' 'receiving advanced degrees are not citizens of the United States of ' 'America.', array([0.41276349]), array([0.08239517, 0.06830964, 0.41276349, 0.30943192, 0.02646618, 0.1006336 ]), 'half-true'), ('David Perdue led efforts to ship thousands of jobs overseas.', array([0.45944817]), array([0.10616967, 0.1811695 , 0.45944817, 0.13171546, 0.05535313, 0.06614408]), 'half-true')] Low probability for correct class [('I created the school choice program.', array([0.21036266]), array([0.17511044, 0.19106213, 0.21036266, 0.20476119, 0.09644064, 0.12226293]), 'half-true'), ('If Question 1 were to pass here in Nevada, we would have more restrictive ' 'gun laws here in Nevada dealing with the transfer of firearms than they do ' 'in California.', array([0.21339909]), array([0.14635815, 0.19714444, 0.21339909, 0.20312909, 0.04491439, 0.19505484]), 'half-true'), 
('Under (Obamacare), you cant reward a person for better behavior. You cant ' 'have incentives to be healthier.', array([0.21461677]), array([0.21150672, 0.20818181, 0.21461677, 0.16175092, 0.09431433, 0.10962944]), 'half-true'), ('Marco Rubio "supported $800,000 for AstroTurf for a field where he played ' 'flag football."', array([0.21503959]), array([0.15782732, 0.16778964, 0.21503959, 0.18418053, 0.08301532, 0.19214759]), 'half-true'), ('We caught (the Texas Commission on Environmental Quality) lying to us about ' 'the results of air quality studies in the Barnett Shale.', array([0.21555044]), array([0.11739072, 0.18516556, 0.21555044, 0.1962633 , 0.09398897, 0.19164101]), 'half-true'), ("He's leading by example, refusing contributions from PACs and Washington " 'lobbyists.', array([0.21889862]), array([0.18373547, 0.15530994, 0.21889862, 0.17217476, 0.10096145, 0.16891976]), 'half-true'), ('The same federal government that offers some money for a program is walking ' 'away from another health care program.', array([0.21980878]), array([0.21729909, 0.21384499, 0.21980878, 0.17316175, 0.06792889, 0.10795649]), 'half-true'), ('Americans who get their insurance through the workplace, cost savings could ' 'be as much as $3,000 less per employee than if we do nothing.', array([0.22062537]), array([0.21731516, 0.14280924, 0.22062537, 0.21867927, 0.03842627, 0.16214468]), 'half-true'), ('40 percent of illegal immigrants had a visa and then became illegal, mostly ' 'because they changed jobs.', array([0.22078109]), array([0.14852066, 0.21148064, 0.22078109, 0.21572661, 0.04314073, 0.16035027]), 'half-true'), ('Were talking right now about a $12 billion hole in our current, so-called ' 'balanced state budget.', array([0.2220753]), array([0.1400699 , 0.22012587, 0.2220753 , 0.1993672 , 0.04721604, 0.17114569]), 'half-true')] mostly-true High probability for correct class [('We admit more than 100,000 lifetime migrants from the Middle East each ' 'year.', 
array([0.35191719]), array([0.12078946, 0.0938687 , 0.22660249, 0.35191719, 0.0183323 , 0.18848986]), 'mostly-true'), ('More than 70 percent of American adults have committed a crime that could ' 'lead to imprisonment.', array([0.35193283]), array([0.09487863, 0.06720162, 0.29571508, 0.35193283, 0.01681881, 0.17345304]), 'mostly-true'), ('Georgia has more than 700 law enforcement agencies, and fewer than 20 ' 'percent of them are state-certified.', array([0.3591597]), array([0.09898033, 0.06596579, 0.24559441, 0.3591597 , 0.01683088, 0.21346889]), 'mostly-true'), ('We borrow a million dollars every minute.', array([0.36444243]), array([0.14579341, 0.08223078, 0.18832413, 0.36444243, 0.01223192, 0.20697733]), 'mostly-true'), ('Worldwide credit card transactions, the credit card fraud rate is 0.04 ' 'percent, compared to almost 8 percent, 9 percent, 10 percent of Medicare ' 'fraud.', array([0.36818189]), array([0.09659597, 0.08410966, 0.26479411, 0.36818189, 0.01356806, 0.1727503 ]), 'mostly-true'), ('94 percent of winning candidates in 2010 had more money than their ' 'opponents.', array([0.36915984]), array([0.1013368 , 0.061505 , 0.26146291, 0.36915984, 0.0077233 , 0.19881216]), 'mostly-true'), ('University of Texas undergraduate student debt is less than $21,000 ' 'probably one of the lowest debts across the nation.', array([0.37008317]), array([0.0715916 , 0.05839747, 0.21441318, 0.37008317, 0.01503068, 0.2704839 ]), 'mostly-true'), ('Ohios electricity rates are 10 percent below the national average.', array([0.3915675]), array([0.06960254, 0.08318926, 0.18605176, 0.3915675 , 0.01037493, 0.25921401]), 'mostly-true'), ('Americans work way more than an average of industrialized countries around ' 'the world.', array([0.39783377]), array([0.08611268, 0.07676001, 0.31685375, 0.39783377, 0.01350048, 0.10893931]), 'mostly-true'), ('Russia has an economy the size of Italy.', array([0.46374061]), array([0.05623691, 0.1219714 , 0.19440037, 0.46374061, 0.02596335, 
0.13768736]), 'mostly-true')] Low probability for correct class [('Unlike virtually every other campaign, we dont have a super PAC.', array([0.20859411]), array([0.20603525, 0.15023642, 0.15129618, 0.20859411, 0.12669266, 0.15714538]), 'mostly-true'), ('Says an illegal immigrant fraudulently claimed children who actually lived ' 'in Mexico on income tax forms to collect more than $29,000.', array([0.21290631]), array([0.17591866, 0.1853713 , 0.18096353, 0.21290631, 0.09106108, 0.15377913]), 'mostly-true'), ('In a poll, 53 percent of young Republican voters . . . under age 35 said ' 'that they would describe a climate [change] denier as ignorant, out of ' 'touch or crazy.', array([0.22158575]), array([0.17223146, 0.15299953, 0.20275629, 0.22158575, 0.07266858, 0.17775838]), 'mostly-true'), ('Barack Obama will somehow manage to add more than $8 trillion to the ' 'national debt, which is more debt than the 43 presidents who held office ' 'before him compiled together.', array([0.22255775]), array([0.16181738, 0.13840224, 0.21825155, 0.22255775, 0.04476479, 0.21420629]), 'mostly-true'), ('When I talk about (raising the) minimum wage ... 
half of Republicansagree ' 'with it.', array([0.22335972]), array([0.2005361 , 0.14035991, 0.19710784, 0.22335972, 0.02695197, 0.21168445]), 'mostly-true'), ('Between the year 2000 and 2006, (insurance) premiums in this country ' 'doubled.', array([0.22349629]), array([0.14335926, 0.18787051, 0.16310565, 0.22349629, 0.08177795, 0.20039035]), 'mostly-true'), ('Says legislation pending in the House would effectively limit or eliminate ' 'time-and-a-half for people who work overtime.', array([0.22616179]), array([0.18882889, 0.20120041, 0.19498199, 0.22616179, 0.02483301, 0.16399391]), 'mostly-true'), ("And he's the only candidate who will fight for a national catastrophe fund " 'to reduce insurance rates.', array([0.22731101]), array([0.17388294, 0.20250774, 0.17136885, 0.22731101, 0.08632988, 0.13859958]), 'mostly-true'), ('Says Donald Trump won more counties than any candidate on our side since ' 'Ronald Reagan.', array([0.23124282]), array([0.19341058, 0.15751294, 0.20413876, 0.23124282, 0.05840977, 0.15528513]), 'mostly-true'), ('Says out of 588 school districts, we give 31 (former Abbott) districts 70 ' 'percent of the aid.', array([0.23245719]), array([0.18717314, 0.13667077, 0.18528811, 0.23245719, 0.06039563, 0.19801516]), 'mostly-true')] pants-fire High probability for correct class [('The president is brain-dead.', array([0.26844632]), array([0.15402847, 0.25656952, 0.14715299, 0.08095951, 0.26844632, 0.09284319]), 'pants-fire'), ('Nobody covered the remarks of Patricia Smith, the mother of a Benghazi ' 'victim, live, but almost everybody covered KhizrKhans, Mr. 
Khans remarks ' 'live.', array([0.26948679]), array([0.21122634, 0.17606194, 0.10573763, 0.11362465, 0.26948679, 0.12386266]), 'pants-fire'), ('President Obama went around the world and apologized for America.', array([0.27453481]), array([0.13692896, 0.17150436, 0.20154439, 0.12943158, 0.27453481, 0.0860559 ]), 'pants-fire'), ('A friends sister died from Obamacare becauseBlue Shield completely just ' 'pulled out of California.', array([0.27873331]), array([0.12850287, 0.1980285 , 0.13664032, 0.11336825, 0.27873331, 0.14472673]), 'pants-fire'), ('Says a rape kit can be used to clean out women, basically like dilation and ' 'curettage.', array([0.29748012]), array([0.20915077, 0.17675712, 0.10787376, 0.13019329, 0.29748012, 0.07854493]), 'pants-fire'), ('Fidel Castro endorses Obama.', array([0.29803774]), array([0.29547501, 0.11488852, 0.2263122 , 0.01979901, 0.29803774, 0.04548753]), 'pants-fire'), ('Sheldon Whitehouse [got] a secret closed-door briefing, warning of the ' '[2008 economic] crash.', array([0.31680957]), array([0.18420339, 0.21531583, 0.08773441, 0.09093313, 0.31680957, 0.10500367]), 'pants-fire'), ('Says President Barack Obama has said that everybody should hate the police.', array([0.32923461]), array([0.16983357, 0.22409941, 0.11541697, 0.06636785, 0.32923461, 0.0950476 ]), 'pants-fire'), ('Says Barack Obama is a Muslim.', array([0.3633849]), array([0.17930742, 0.16205863, 0.11749332, 0.10132006, 0.3633849 , 0.07643567]), 'pants-fire'), ('His true name is Barak Hussein Muhammed Obama.', array([0.41112157]), array([0.08697252, 0.23778562, 0.08980858, 0.0237037 , 0.41112157, 0.15060801]), 'pants-fire')] Low probability for correct class [("Clinton's former pastor convicted of child molestation.", array([0.21295808]), array([0.13307059, 0.20963104, 0.18046706, 0.081746 , 0.21295808, 0.18212723]), 'pants-fire'), ('You cannot build a little guy up by tearing a big guy down -- Abraham ' 'Lincoln said it.', array([0.22148382]), array([0.16446767, 
0.18909148, 0.20660855, 0.11339259, 0.22148382, 0.10495589]), 'pants-fire'), ('President Barack Obamas latest executive order mandates the apprehension ' 'and detention of Americans who merely show signs of respiratory illness.', array([0.2646772]), array([0.15547019, 0.21085589, 0.16777896, 0.08752435, 0.2646772 , 0.11369342]), 'pants-fire'), ('Says the Democrats told the Catholic Church that theyll use federal powers ' 'to shut down church charities and hospitals if the church doesnt change its ' 'beliefs.', array([0.26727142]), array([0.17626063, 0.23986723, 0.11014199, 0.11434114, 0.26727142, 0.09211758]), 'pants-fire'), ('The president is brain-dead.', array([0.26844632]), array([0.15402847, 0.25656952, 0.14715299, 0.08095951, 0.26844632, 0.09284319]), 'pants-fire'), ('Nobody covered the remarks of Patricia Smith, the mother of a Benghazi ' 'victim, live, but almost everybody covered KhizrKhans, Mr. Khans remarks ' 'live.', array([0.26948679]), array([0.21122634, 0.17606194, 0.10573763, 0.11362465, 0.26948679, 0.12386266]), 'pants-fire'), ('President Obama went around the world and apologized for America.', array([0.27453481]), array([0.13692896, 0.17150436, 0.20154439, 0.12943158, 0.27453481, 0.0860559 ]), 'pants-fire'), ('A friends sister died from Obamacare becauseBlue Shield completely just ' 'pulled out of California.', array([0.27873331]), array([0.12850287, 0.1980285 , 0.13664032, 0.11336825, 0.27873331, 0.14472673]), 'pants-fire'), ('Says a rape kit can be used to clean out women, basically like dilation and ' 'curettage.', array([0.29748012]), array([0.20915077, 0.17675712, 0.10787376, 0.13019329, 0.29748012, 0.07854493]), 'pants-fire'), ('Fidel Castro endorses Obama.', array([0.29803774]), array([0.29547501, 0.11488852, 0.2263122 , 0.01979901, 0.29803774, 0.04548753]), 'pants-fire')] true High probability for correct class [('Congress can tell [the Supreme Court] which cases they ought to hear. 
We ' 'have that authority.', array([0.26751834]), array([0.17486497, 0.22588625, 0.09836713, 0.14207016, 0.09129315, 0.26751834]), 'true'), ('Three thousand felons voted in Rhode Island in 2008.', array([0.26954967]), array([0.07271338, 0.16109241, 0.20985999, 0.23901292, 0.04777164, 0.26954967]), 'true'), ('During my eight years as county executive, we cut the number of county ' 'workers by 20 percent.', array([0.27345031]), array([0.14659243, 0.11406933, 0.18530768, 0.25706216, 0.02351809, 0.27345031]), 'true'), ('Biden is "one of the least wealthy members of the U.S. Senate."', array([0.27752238]), array([0.12704761, 0.15308015, 0.13570624, 0.23677313, 0.06987049, 0.27752238]), 'true'), ('The Walton family, which owns Wal-Mart, controls a fortune equal to the ' 'wealth of the bottom 42 percent of Americans combined.', array([0.30783847]), array([0.07833839, 0.12055671, 0.23069266, 0.2266383 , 0.03593547, 0.30783847]), 'true'), ('Theres actually 600 abortions done after the 20th week of pregnancy every ' 'year in Ohio.', array([0.31190878]), array([0.11092438, 0.14964684, 0.13952456, 0.26394097, 0.02405447, 0.31190878]), 'true'), ('As a former federal prosecutor, I prosecuted over 4,000 cases.', array([0.31863488]), array([0.14044857, 0.15170702, 0.19081626, 0.17379286, 0.02460042, 0.31863488]), 'true'), ('Only 20 colleges and universities have athletic departments with revenue ' 'exceeding expenses.', array([0.3364042]), array([0.10639138, 0.04657964, 0.19563749, 0.29993382, 0.01505347, 0.3364042 ]), 'true'), ('The United states is borrowing more than 40 cents of every dollar we spend.', array([0.33731897]), array([0.09510039, 0.10495588, 0.17639187, 0.27213754, 0.01409534, 0.33731897]), 'true'), ('I am now the No. 
2 member of this House in terms of length of service.', array([0.40499398]), array([0.06856361, 0.18419065, 0.09427939, 0.18781724, 0.06015515, 0.40499398]), 'true')] Low probability for correct class [('We have a retiree that is collecting a $17,000 paycheck a month . . . tax ' 'free.', array([0.20305376]), array([0.1844469 , 0.18360528, 0.17992862, 0.15924116, 0.08972429, 0.20305376]), 'true'), ('Says Ron Johnson likes to say there are too many lawyers in the Senate 57. ' 'Hed be the 70th millionaire.', array([0.20788873]), array([0.18261734, 0.16861212, 0.17959899, 0.2072759 , 0.05400692, 0.20788873]), 'true'), ('Proposed fees for Rhode Island beaches will still be less than some of the ' 'town beaches.', array([0.21294838]), array([0.14966726, 0.17734789, 0.21276733, 0.18705369, 0.06021546, 0.21294838]), 'true'), ('MikeHuckabee.com gets "more hits than virtually any other presidential ' 'candidate."', array([0.21579208]), array([0.13535892, 0.18388853, 0.21267974, 0.16788574, 0.08439499, 0.21579208]), 'true'), ('More black babies are aborted in NYC than born.', array([0.21602022]), array([0.13585434, 0.16637365, 0.20901227, 0.18271742, 0.09002211, 0.21602022]), 'true'), ('As governor, Ted Strickland left only 89 cents in Ohios rainy day fund.', array([0.21679789]), array([0.14982887, 0.2144718 , 0.1433054 , 0.19560797, 0.07998807, 0.21679789]), 'true'), ("Since 1981, reconciliation has been used 21 times. 
Most of it's been used " 'by Republicans.', array([0.2332101]), array([0.12069822, 0.19182052, 0.16291276, 0.22483058, 0.06652783, 0.2332101 ]), 'true'), ('I have filed every disclosure that has ever been required.', array([0.24174713]), array([0.12592984, 0.20429919, 0.12502129, 0.22504939, 0.07795317, 0.24174713]), 'true'), ('Pinellas County voters elected me as their chief financial officer (and) ' 'elected me as (their) governor four years ago.', array([0.24206037]), array([0.16379328, 0.23483695, 0.1464253 , 0.14344739, 0.0694367 , 0.24206037]), 'true'), ('There have been three people tried and convicted by the last administration ' 'in military courts. Two are walking the street right now.', array([0.25299944]), array([0.15493046, 0.15993618, 0.1866245 , 0.19221128, 0.05329814, 0.25299944]), 'true')]
Error analysis can start as early as fine-tuning the model. Comparing the learning curves of the loss on the training set and on the development set can indicate whether we need more data, a bigger dev set, or a more complex model, and whether the model is learning anything at all or is overfitting.
http://mlwiki.org/index.php/Learning_Curves
https://medium.com/uwaterloo-voice/error-analysis-in-deep-learning-6df3b3d335af
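One way to make the reading of learning curves concrete is a small heuristic over the two loss series. The sketch below is only an illustration: the function name `diagnose_curves`, the `window` and `gap_tol` parameters, and the returned labels are assumptions made for this example, and the thresholds would need tuning for a real run.

```python
from statistics import mean


def diagnose_curves(train_losses, val_losses, window=3, gap_tol=0.1):
    """Heuristic reading of two loss curves sampled at the same steps."""
    if len(train_losses) != len(val_losses) or len(train_losses) < 2 * window:
        raise ValueError("need two equally long curves with at least 2 * window points")
    # Compare the average loss at the start and at the end of training
    train_start = mean(train_losses[:window])
    train_end = mean(train_losses[-window:])
    val_end = mean(val_losses[-window:])
    val_best = min(val_losses)
    if train_end >= train_start:
        # The training loss did not go down at all
        return "not learning"
    if val_end - val_best > gap_tol and val_end - train_end > gap_tol:
        # Validation loss rose again while the training loss kept falling
        return "overfitting"
    if val_end - train_end > gap_tol:
        # Both curves go down, but validation lags behind training
        return "generalization gap"
    return "still improving or converged"


flat = [1.8] * 6
print(diagnose_curves(flat, flat))
print(diagnose_curves([1.8, 1.6, 1.4, 1.0, 0.6, 0.3],
                      [1.8, 1.6, 1.5, 1.6, 1.8, 2.0]))
```

A flat training curve suggests a problem with the optimisation setup (learning rate, labels, data pipeline), while a validation curve that turns upwards points towards regularisation, early stopping, or more data.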
Model Debugging can be made easy with external tools, too:
"Those who don't track training are doomed to repeat it."
!pip3 install transformers
import torch
import random
import numpy as np
import pandas as pd
from functools import partial
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from typing import List, Tuple
from tqdm.notebook import tqdm
from transformers import PreTrainedTokenizer
from transformers import RobertaTokenizer
from transformers import RobertaConfig
from transformers import RobertaForSequenceClassification
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
def accuracy(logits, labels):
    logits = np.asarray(logits).reshape(-1, len(logits[0]))
    labels = np.asarray(labels).reshape(-1)
    return np.sum(np.argmax(logits, axis=-1) == labels).astype(np.float32) / float(labels.shape[0])
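A single accuracy number hides which classes the model confuses. As a rough sketch for error analysis, the predictions can be broken down per gold label. The helper names `confusion_counts` and `per_class_accuracy` are made up for this example, and `pred_ids` is assumed to hold the argmax over the logits for each instance.

```python
from collections import Counter


def confusion_counts(pred_ids, gold_ids):
    """Count how often each (gold, predicted) pair of label ids occurs."""
    if len(pred_ids) != len(gold_ids):
        raise ValueError("predictions and gold labels must have the same length")
    return Counter(zip(gold_ids, pred_ids))


def per_class_accuracy(pred_ids, gold_ids):
    """Fraction of correctly predicted instances for each gold label."""
    counts = confusion_counts(pred_ids, gold_ids)
    totals = Counter(gold_ids)
    return {label: counts[(label, label)] / n for label, n in totals.items()}


# Toy label ids, not taken from the LIAR data
gold = [0, 0, 1, 2, 2, 2]
pred = [0, 1, 1, 2, 0, 2]
print(confusion_counts(pred, gold))
print(per_class_accuracy(pred, gold))
```

Pairs with a high off-diagonal count, for example neighbouring veracity labels, are a natural place to start reading individual errors.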
def evaluate(model: nn.Module, valid_dl: DataLoader):
    """
    Evaluates the model on the given dataset
    :param model: The model under evaluation
    :param valid_dl: A `DataLoader` reading validation data
    :return: (acc, loss) The accuracy and the mean batch loss of the model on the dataset
    """
    # VERY IMPORTANT: Put your model in "eval" mode -- this disables dropout
    # (and freezes batch-norm statistics in models that use them)
    model.eval()
    labels_all = []
    logits_all = []
    losses_all = []
    # ALSO IMPORTANT: Don't track gradients during this process
    with torch.no_grad():
        for batch in tqdm(valid_dl, desc='Evaluation'):
            # NOTE: relies on the global `device` defined in the setup cell
            batch = tuple(t.to(device) for t in batch)
            input_ids = batch[0]
            attention_mask = batch[1]
            labels = batch[2]
            loss, logits = model(input_ids, attention_mask, labels=labels, return_dict=False)
            labels_all.extend(list(labels.detach().cpu().numpy()))
            logits_all.extend(list(logits.detach().cpu().numpy()))
            losses_all.append(loss.detach().cpu().numpy())
    acc = accuracy(logits_all, labels_all)
    return acc, np.mean(losses_all)
def train(
    model: nn.Module,
    train_dl: DataLoader,
    valid_dl: DataLoader,
    optimizer: torch.optim.Optimizer,
    n_epochs: int,
    device: torch.device,
    scheduler = None
):
    """
    The main training loop which will optimize a given model on a given dataset
    :param model: The model being optimized
    :param train_dl: The training dataset
    :param valid_dl: A validation dataset
    :param optimizer: The optimizer used to update the model parameters
    :param n_epochs: Number of epochs to train for
    :param device: The device to train on
    :param scheduler: Optional learning rate scheduler, stepped after every batch
    :return: (model, losses) The best model and the losses per iteration
    """
    # Keep track of the loss and best accuracy
    losses = []
    best_acc = 0.0
    # Iterate through epochs
    for ep in range(n_epochs):
        loss_epoch = []
        # Iterate through each batch in the dataloader
        for i, batch in tqdm(enumerate(train_dl)):
            # VERY IMPORTANT: Make sure the model is in training mode, which turns on
            # things like dropout (evaluate() switches the model to eval mode)
            model.train()
            # VERY IMPORTANT: zero out all of the gradients on each iteration -- PyTorch
            # accumulates them across backward passes, so you need to explicitly
            # zero them out
            optimizer.zero_grad()
            # Place each tensor on the target device
            batch = tuple(t.to(device) for t in batch)
            input_ids = batch[0]
            attention_mask = batch[1]
            labels = batch[2]
            # Pass the inputs through the model, get the current loss and logits
            loss, logits = model(input_ids, attention_mask, labels=labels, return_dict=False)
            wandb.log({'loss': loss.item()})
            losses.append(loss.item())
            loss_epoch.append(loss.item())
            # Calculate all of the gradients for the model
            loss.backward()
            # Optional: clip gradients
            #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # Finally, update the weights of the model
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            # CHANGED CODE -- as the Transformer model trains for only a few epochs,
            # we evaluate every 100 steps to get finer-grained learning curves
            if i % 100 == 0:
                acc, val_loss = evaluate(model, valid_dl)
                wandb.log({'acc': acc, 'train_loss': np.mean(loss_epoch), 'val_loss': val_loss})
        # Perform inline evaluation at the end of the epoch
        acc, val_loss = evaluate(model, valid_dl)
        wandb.log({'acc': acc, 'train_loss': np.mean(loss_epoch), 'val_loss': val_loss})
        print(f'Validation accuracy: {acc}, train loss: {np.mean(loss_epoch)}')
        # Keep track of the best model based on the accuracy
        if acc > best_acc:
            torch.save(model.state_dict(), 'best_model')
            best_acc = acc
        #gc.collect()
    # Reload the weights of the best checkpoint (if any epoch improved on best_acc)
    if best_acc > 0.0:
        model.load_state_dict(torch.load('best_model', map_location=device))
    return model, losses
def text_to_batch_transformer(text: List, tokenizer: PreTrainedTokenizer) -> Tuple[List, List]:
    """Turn a piece of text into a batch for transformer model
    :param text: The text to tokenize and encode
    :param tokenizer: The tokenizer to use
    :return: A list of IDs and a mask
    """
    input_ids = [tokenizer.encode(t, add_special_tokens=True, truncation=True) for t in text]
    masks = [[1] * len(i) for i in input_ids]
    return input_ids, masks
class ClassificationDatasetReader(Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.values[idx]
        # Calls the text_to_batch function
        input_ids, masks = text_to_batch_transformer([row[2]], self.tokenizer)
        label = label_map[row[1]]
        return input_ids, masks, label
label_map = {l:i for i,l in enumerate((set(train_data.values[:,1]) | set(valid_data.values[:,1]) | set(test_data.values[:,1])))}
num_labels = len(label_map)
print(label_map)
{'true': 0, 'barely-true': 1, 'false': 2, 'mostly-true': 3, 'pants-fire': 4, 'half-true': 5}
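Note that the mapping above is built by iterating over a `set`. Python randomizes string hashes per process by default, so the label-to-index assignment can differ between runs, which matters when a saved model is reloaded in a fresh session. A minimal sketch of a reproducible alternative, assuming a hypothetical helper `build_label_map` and toy label lists rather than the actual data frames, is to sort the labels first:

```python
def build_label_map(*label_columns):
    """Map every distinct label to an index, in sorted (reproducible) order."""
    labels = set()
    for column in label_columns:
        labels.update(column)
    return {label: idx for idx, label in enumerate(sorted(labels))}


# Toy stand-ins for the label columns of the train/valid/test splits
train_labels = ['true', 'false', 'half-true']
valid_labels = ['pants-fire', 'barely-true', 'true']
test_labels = ['mostly-true', 'false']

label_map_sorted = build_label_map(train_labels, valid_labels, test_labels)
print(label_map_sorted)
# {'barely-true': 0, 'false': 1, 'half-true': 2, 'mostly-true': 3, 'pants-fire': 4, 'true': 5}
```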
def collate_batch_transformer(pad_id, input_data: Tuple) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Each item is (input_ids, masks, label), where input_ids and masks are
    # single-element lists produced by text_to_batch_transformer
    input_ids = [i[0][0] for i in input_data]
    masks = [i[1][0] for i in input_data]
    labels = [i[2] for i in input_data]
    max_length = max([len(i) for i in input_ids])
    # Pad the token ids with pad_id up to the longest sequence in the batch
    input_ids = [(i + [pad_id] * (max_length - len(i))) for i in input_ids]
    # Padded positions always get mask value 0, independent of pad_id
    masks = [(m + [0] * (max_length - len(m))) for m in masks]
    assert (all(len(i) == max_length for i in input_ids))
    assert (all(len(m) == max_length for m in masks))
    return torch.tensor(input_ids), torch.tensor(masks), torch.tensor(labels)
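To see what the collate step produces, here is a pure-Python sketch of the same padding idea without tensors. The function name `pad_batch` and the token ids are made up for illustration. The example uses pad id 1, which to my knowledge is what the `roberta-base` tokenizer reports as `pad_token_id`, but it is worth checking `tokenizer.pad_token_id` rather than relying on this.

```python
def pad_batch(sequences, pad_id):
    """Pad token-id lists to a common length and build the attention masks."""
    max_length = max(len(s) for s in sequences)
    # Token ids are padded with the tokenizer's pad id
    input_ids = [s + [pad_id] * (max_length - len(s)) for s in sequences]
    # The mask is 1 for real tokens and 0 for padding, whatever the pad id is
    masks = [[1] * len(s) + [0] * (max_length - len(s)) for s in sequences]
    return input_ids, masks


# Hypothetical token ids for two sentences of different length
ids, masks = pad_batch([[0, 713, 2], [0, 713, 16, 10, 2]], pad_id=1)
print(ids)    # [[0, 713, 2, 1, 1], [0, 713, 16, 10, 2]]
print(masks)  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

Because the attention mask is 0 on padded positions, the model ignores whatever id is used there, but using the tokenizer's own pad id keeps the inputs consistent with pre-training.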
# a few steps needed to initialize the project in WANDB
!pip install wandb -qqq
import wandb
wandb.login()
wandb: Currently logged in as: kstanczak (use `wandb login --relogin` to force relogin)
True
wandb.init(project="lab-5-roberta",
           config={
               # Logged for reference only; keep in sync with the values used below
               "batch_size": 6,
               "learning_rate": 3e-5,
               "dataset": "LIAR",
           })
weight_decay = 0.01
n_epochs = 2
lr = 3e-5
# Get the device
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
# Create the tokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
vocabulary = tokenizer.get_vocab()
# Create the dataset readers
train_dataset = ClassificationDatasetReader(train_data, tokenizer)
# dataset loaded lazily with N workers in parallel
collate_fn = partial(collate_batch_transformer, 0)
train_dl = DataLoader(train_dataset, batch_size=6, shuffle=True, collate_fn=collate_fn, num_workers=8)
valid_dataset = ClassificationDatasetReader(valid_data, tokenizer)
valid_dl = DataLoader(valid_dataset, batch_size=6, collate_fn=collate_fn, num_workers=8)
config = RobertaConfig.from_pretrained('roberta-base', num_labels=6)
model = RobertaForSequenceClassification.from_pretrained('roberta-base', config=config).to(device)
# Create the optimizer
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
scheduler = get_linear_schedule_with_warmup(
optimizer,
200,
n_epochs * len(train_dl)
)
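The scheduler above takes 200 warmup steps and the total number of training steps. As far as I understand it, `get_linear_schedule_with_warmup` multiplies the base learning rate by a factor that rises linearly from 0 to 1 during warmup and then decays linearly back to 0. The sketch below reimplements that factor in plain Python for illustration; the function name `linear_warmup_factor` and the total of 1000 steps are assumptions, not values from this notebook.

```python
def linear_warmup_factor(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier for linear warmup followed by linear decay."""
    if step < num_warmup_steps:
        # Warmup: grow linearly from 0 to 1
        return step / max(1, num_warmup_steps)
    # Decay: shrink linearly from 1 to 0 over the remaining steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 3e-5
total_steps = 1000  # hypothetical; the notebook uses n_epochs * len(train_dl)
for step in (0, 100, 200, 600, 1000):
    print(step, base_lr * linear_warmup_factor(step, 200, total_steps))
```

Logging the current learning rate alongside the loss makes it easier to tell whether a bump in the learning curve comes from the schedule or from the data.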
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. cpuset_checked)) Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'roberta.pooler.dense.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias'] - This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.weight', 'classifier.out_proj.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# magic to have plots of the learning curves
%%wandb
model, losses = train(model, train_dl, valid_dl, optimizer, n_epochs, device, scheduler)
Validation accuracy: 0.20482866043613707, train loss: 1.7665990396208