In this notebook, you will learn how to use Spark NLP and Legal NLP to train multilabel classification models.
Let's dive in!
! pip install -q johnsnowlabs
Using my.johnsnowlabs.com SSO
from johnsnowlabs import nlp, legal
# nlp.install(force_browser=True)
If you are not registered in my.johnsnowlabs.com, if you received a license via e-mail, or if you are using Safari, you may need to upload the license manually.
from google.colab import files
print('Please Upload your John Snow Labs License using the button below')
license_keys = files.upload()
nlp.install()
spark = nlp.start()
For the text classification tasks, we will use two annotators:

MultiClassifierDL: multilabel classification (can predict more than one class for each text) using a Bidirectional GRU with Convolution architecture built with TensorFlow, supporting up to 100 classes. The inputs are sentence embeddings such as the state-of-the-art UniversalSentenceEncoder, BertSentenceEmbeddings, or SentenceEmbeddings.

ClassifierDL: uses the state-of-the-art Universal Sentence Encoder as input for text classification, followed by a deep learning model (DNN) built with TensorFlow that supports binary and multiclass classification (up to 100 classes).
! wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/legal-nlp/data/finance_data.csv
import pandas as pd
df = pd.read_csv('./finance_data.csv')
df['label'] = df['label'].apply(eval)
print(f"Shape of the full dataset: {df.shape}")
Shape of the full dataset: (27527, 2)
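Before sampling, it can help to inspect how often each class appears, since the label column holds lists rather than single values. A minimal sketch of flattening and counting the list-valued labels, shown here on a tiny hand-made sample standing in for the real CSV:

```python
import pandas as pd
from collections import Counter

# Tiny hand-made sample standing in for df (the notebook loads finance_data.csv)
sample = pd.DataFrame({
    "provision": ["text a", "text b", "text c"],
    "label": [["notices"], ["waivers", "amendments"], ["notices"]],
})

# Flatten the list-valued label column and count each class
label_counts = Counter(l for labels in sample["label"] for l in labels)
print(label_counts.most_common())
```

The same expression applied to the full `df` gives a quick view of class imbalance before you decide how to sample.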
We will use a sample from this dataset to keep the training process fast (the goal here is to illustrate the workflow). Use the full dataset if you want to experiment with it and achieve more realistic results.
The sample has only 1,000 observations; please keep in mind that this will impact the accuracy and generalization capabilities of the model. Since the dataset is smaller now, we use 70% of it to train the model and the other 30% for testing.
data = spark.createDataFrame(df)
# If you have a single dataset, then split it or else you can load the test dataset the same way that you load the train data.
train, test = data.limit(1000).randomSplit([0.7, 0.3], seed=42)
train.show(truncate=50)
+--------------------------------------------------+-----------------------------------+
|                                         provision|                              label|
+--------------------------------------------------+-----------------------------------+
|(a) No failure or delay of the Administrative A...|              [waivers, amendments]|
|(a) Seller, the Agent, each Managing Agent, eac...|                      [assignments]|
|(a) To induce the other parties hereto to enter...|      [representations, warranties]|
| (a) The provisions of this Agreement shall be ...|              [assigns, successors]|
|(a) All of the representations and warranties m...|      [representations, warranties]|
|(a) THIS AGREEMENT AND ANY CLAIM, CONTROVERSY, ...|[governing laws, entire agreements]|
|All Bank Expenses (including reasonable attorne...|                         [expenses]|
|All agreements, covenants, representations, war...|                     [terminations]|
|All agreements, representations and warranties ...|                         [survival]|
|All communications hereunder will be in writing...|                          [notices]|
|All costs and expenses incurred in connection w...|                         [expenses]|
|All covenants of the Company contained in this ...|                         [survival]|
|All covenants, agreements, representations and ...|                         [survival]|
|All covenants, agreements, representations and ...|                         [survival]|
|All covenants, agreements, representations and ...|                         [survival]|
|All demands, notices and communications hereund...|                          [notices]|
|All demands, notices and communications hereund...|                          [notices]|
|All indemnities set forth herein including, wit...|                         [survival]|
|All non-competition, non-solicitation, non-disc...|                         [survival]|
|All notices and other communications given or m...|                          [notices]|
+--------------------------------------------------+-----------------------------------+
only showing top 20 rows
from pyspark.sql.functions import col
test.groupBy("label").count().orderBy(col("count").desc()).show()
+--------------------+-----+
|               label|count|
+--------------------+-----+
|    [governing laws]|   35|
|           [notices]|   31|
|      [severability]|   27|
| [entire agreements]|   27|
|      [counterparts]|   24|
|          [survival]|   19|
|[assigns, success...|   14|
|      [terminations]|   14|
|        [amendments]|   13|
|          [expenses]|   11|
|       [assignments]|   10|
|[waivers, amendme...|    8|
|           [waivers]|    7|
|[amendments, enti...|    3|
|   [representations]|    3|
|        [successors]|    2|
|[amendments, term...|    2|
|[representations,...|    2|
|        [warranties]|    1|
|[severability, su...|    1|
+--------------------+-----+
only showing top 20 rows
document_assembler = (
nlp.DocumentAssembler()
.setInputCol("provision")
.setOutputCol("document")
.setCleanupMode("shrink")
)
embeddings = (
nlp.UniversalSentenceEncoder.pretrained()
.setInputCols("document")
.setOutputCol("sentence_embeddings")
)
classifierdl = (
nlp.MultiClassifierDLApproach()
.setInputCols(["sentence_embeddings"])
.setOutputCol("class")
.setLabelColumn("label")
.setMaxEpochs(20)
.setLr(0.001)
.setRandomSeed(42)
.setEnableOutputLogs(True)
.setOutputLogsPath("multilabel_use_logs")
.setBatchSize(8)
)
clf_pipeline = nlp.Pipeline(stages=[document_assembler, embeddings, classifierdl])
tfhub_use download started this may take some time. Approximate size to download 923.7 MB [OK!]
Since this model can take a long time to train, we limited the size of the training data to avoid training for hours.
Please note that this reduction can greatly impact the performance of the model.
%%time
clf_pipelineModel = clf_pipeline.fit(train)
CPU times: user 494 ms, sys: 67.1 ms, total: 561 ms Wall time: 1min 26s
import os
log_file_name = os.listdir("multilabel_use_logs")[0]
with open("multilabel_use_logs/" + log_file_name, "r") as log_file:
    print(log_file.read())
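If you want to chart the learning curve rather than read the raw log, each epoch line can be parsed with a small regular expression. A sketch, assuming the `Epoch i/n - ... - loss: ... - acc: ...` line format that Spark NLP writes to this log:

```python
import re

# One line in the format found in the training log above
line = "Epoch 0/20 - 5.90s - loss: 0.31367278 - acc: 0.91523325 - batches: 93"

# Pull epoch number, loss, and accuracy out of a log line
pattern = re.compile(r"Epoch (\d+)/\d+ - [\d.]+s - loss: ([\d.]+) - acc: ([\d.]+)")
m = pattern.match(line)
epoch, loss, acc = int(m.group(1)), float(m.group(2)), float(m.group(3))
print(epoch, loss, acc)  # 0 0.31367278 0.91523325
```

Applying this to every `Epoch` line of the log file yields (epoch, loss, acc) tuples ready for plotting.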
Training started - epochs: 20 - learning_rate: 0.001 - batch_size: 8 - training_examples: 744 - classes: 15
Epoch 0/20 - 5.90s - loss: 0.31367278 - acc: 0.91523325 - batches: 93
Epoch 1/20 - 2.32s - loss: 0.20648386 - acc: 0.93324363 - batches: 93
Epoch 2/20 - 1.74s - loss: 0.15775694 - acc: 0.9456988 - batches: 93
Epoch 3/20 - 1.76s - loss: 0.13085833 - acc: 0.9548385 - batches: 93
Epoch 4/20 - 1.72s - loss: 0.11435161 - acc: 0.9614694 - batches: 93
Epoch 5/20 - 1.71s - loss: 0.1033926 - acc: 0.965412 - batches: 93
Epoch 6/20 - 1.67s - loss: 0.09538201 - acc: 0.96827936 - batches: 93
Epoch 7/20 - 1.69s - loss: 0.08915223 - acc: 0.9700714 - batches: 93
Epoch 8/20 - 1.72s - loss: 0.08416093 - acc: 0.9717739 - batches: 93
Epoch 9/20 - 1.67s - loss: 0.08005884 - acc: 0.9731181 - batches: 93
Epoch 10/20 - 1.68s - loss: 0.07660815 - acc: 0.9741037 - batches: 93
Epoch 11/20 - 1.66s - loss: 0.07365137 - acc: 0.9750894 - batches: 93
Epoch 12/20 - 1.67s - loss: 0.071067244 - acc: 0.9752686 - batches: 93
Epoch 13/20 - 1.71s - loss: 0.0687695 - acc: 0.97598547 - batches: 93
Epoch 14/20 - 1.70s - loss: 0.066702016 - acc: 0.97697115 - batches: 93
Epoch 15/20 - 1.69s - loss: 0.06482201 - acc: 0.97732955 - batches: 93
Epoch 16/20 - 1.66s - loss: 0.063097924 - acc: 0.9780463 - batches: 93
Epoch 17/20 - 2.00s - loss: 0.061507963 - acc: 0.9785841 - batches: 93
Epoch 18/20 - 1.69s - loss: 0.060031768 - acc: 0.978853 - batches: 93
Epoch 19/20 - 1.73s - loss: 0.05865668 - acc: 0.9790321 - batches: 93
preds = clf_pipelineModel.transform(test)
preds_df = preds.select('label','provision',"class.result").toPandas()
preds_df.head()
|   | label | provision | result |
|---|---|---|---|
| 0 | [assigns, successors] | (a) The provisions of this Agreement shall be ... | [successors] |
| 1 | [waivers] | (a) Any provision of this Agreement may be wai... | [waivers, amendments] |
| 2 | [waivers, amendments] | (a) This Agreement may be amended, supplemente... | [waivers] |
| 3 | [counterparts] | (a) This Agreement may be executed by one or m... | [counterparts] |
| 4 | [survival] | All agreements, representations and warranties... | [survival] |
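Beyond eyeballing the head of the predictions, a quick sanity check for multilabel output is row-level exact-match accuracy, which only credits a prediction when the full label set is correct. A sketch on a tiny hand-made stand-in for `preds_df`:

```python
import pandas as pd

# Tiny stand-in for preds_df: gold labels vs. predicted classes per row
preds_df = pd.DataFrame({
    "label":  [["assigns", "successors"], ["waivers"], ["counterparts"]],
    "result": [["successors"], ["waivers", "amendments"], ["counterparts"]],
})

# Exact-match (subset) accuracy: a row counts only if the label sets are identical
exact = preds_df.apply(lambda r: set(r["label"]) == set(r["result"]), axis=1).mean()
print(f"Exact-match accuracy: {exact:.2f}")
```

This is a stricter metric than the per-label scores computed below, so it is usually noticeably lower.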
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
mlb = MultiLabelBinarizer()
y_true = mlb.fit_transform(preds_df['label'])
y_pred = mlb.transform(preds_df['result'])
print("Classification report: \n", (classification_report(y_true, y_pred)))
print("F1 micro averaging:",(f1_score(y_true, y_pred, average='micro')))
print("ROC: ",(roc_auc_score(y_true, y_pred, average="micro")))
Classification report:
               precision    recall  f1-score   support

           0       0.85      0.42      0.56        26
           1       0.00      0.00      0.00        10
           2       0.88      0.50      0.64        14
           3       1.00      1.00      1.00        24
           4       0.97      0.97      0.97        30
           5       0.86      0.55      0.67        11
           6       0.97      0.92      0.94        36
           7       0.93      0.81      0.86        31
           8       0.60      0.60      0.60         5
           9       0.93      0.93      0.93        30
          10       0.91      0.62      0.74        16
          11       0.79      0.55      0.65        20
          12       0.88      0.44      0.58        16
          13       1.00      0.69      0.81        16
          14       0.40      0.67      0.50         3

   micro avg       0.91      0.72      0.80       288
   macro avg       0.80      0.64      0.70       288
weighted avg       0.88      0.72      0.78       288
 samples avg       0.73      0.74      0.73       288

F1 micro averaging: 0.8038834951456311
ROC:  0.8565596846846847
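The row indices in the report above are positions in `mlb.classes_`, so mapping them back to label names requires an extra lookup. Passing `target_names` makes `classification_report` print the actual class names instead; a minimal sketch on synthetic data:

```python
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report

# Synthetic multilabel gold/predicted label lists
y_true_lists = [["notices"], ["waivers", "amendments"], ["notices"]]
y_pred_lists = [["notices"], ["waivers"], []]

mlb = MultiLabelBinarizer()
y_true = mlb.fit_transform(y_true_lists)
y_pred = mlb.transform(y_pred_lists)

# target_names replaces the 0..N-1 row indices with the fitted class names
report = classification_report(y_true, y_pred, target_names=mlb.classes_,
                               zero_division=0)
print(report)
```

With the real `preds_df`, the same `target_names=mlb.classes_` argument would label each row of the report with its provision category.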