from fastai.text import * # Quick access to NLP functionality
An example of creating a language model and then transferring it to a classifier.
path = untar_data(URLs.IMDB_SAMPLE)
path
PosixPath('/home/ubuntu/.fastai/data/imdb_sample')
Open and view the independent and dependent variables:
df = pd.read_csv(path/'texts.csv')
df.head()
| | label | text | is_valid |
|---|---|---|---|
0 | negative | Un-bleeping-believable! Meg Ryan doesn't even ... | False |
1 | positive | This is a extremely well-made film. The acting... | False |
2 | negative | Every once in a long while a movie will come a... | False |
3 | positive | Name just says it all. I watched this movie wi... | False |
4 | negative | This movie succeeds at being one of the most u... | False |
Create a DataBunch for each of the language model and the classifier:
data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')
data_clas = TextClasDataBunch.from_csv(path, 'texts.csv', vocab=data_lm.train_ds.vocab, bs=42)
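The classifier's data is built with `vocab=data_lm.train_ds.vocab` so that both data sets map tokens to the same integer ids; otherwise the fine-tuned encoder's embedding rows would not line up with the classifier's inputs. The snippet below is a minimal pure-Python sketch of that idea, not fastai's implementation. `build_vocab`, `numericalize`, and the toy sentences are hypothetical names and data chosen for illustration.

```python
from collections import Counter

UNK = "xxunk"  # stand-in token for out-of-vocabulary words (assumed name)

def build_vocab(texts, min_freq=1):
    """Return an index-to-token list; position 0 is reserved for UNK."""
    counts = Counter(tok for text in texts for tok in text.split())
    return [UNK] + sorted(tok for tok, c in counts.items() if c >= min_freq)

def numericalize(text, itos):
    """Map whitespace-separated tokens to ids using a fixed vocabulary."""
    stoi = {tok: i for i, tok in enumerate(itos)}
    return [stoi.get(tok, 0) for tok in text.split()]

# Vocabulary is built once, from the language-model corpus ...
lm_texts = ["the film was great", "the plot was thin"]
vocab = build_vocab(lm_texts)

# ... and reused unchanged for the classifier's texts.
ids_lm = numericalize("the film was thin", vocab)
ids_clas = numericalize("the film was dull", vocab)
print(ids_lm)    # [4, 1, 6, 5]
print(ids_clas)  # [4, 1, 6, 0]  ("dull" is unseen, so it maps to UNK)
```

Because both calls share one vocabulary, "the", "film" and "was" receive identical ids in either data set, which is what lets the saved encoder be reused later.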
We'll fine-tune the language model. fast.ai has a pre-trained English model available for download; we only have to specify it like this:
moms = (0.8,0.7)
learn = language_model_learner(data_lm, AWD_LSTM)
learn.unfreeze()
learn.fit_one_cycle(4, slice(1e-2), moms=moms)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
1 | 4.752841 | 3.936385 | 0.286949 | 00:17 |
2 | 4.450984 | 3.839346 | 0.292485 | 00:17 |
3 | 4.170427 | 3.803478 | 0.294866 | 00:17 |
4 | 3.946596 | 3.798583 | 0.295342 | 00:17 |
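The `fit_one_cycle` call above takes a peak learning rate and a `moms` pair. As a rough illustration of what a one-cycle schedule does with them, the sketch below warms the learning rate up to its peak while momentum drops from `moms[0]` to `moms[1]`, then anneals both back. This is a simplified sketch, not fastai's code: `one_cycle` and `cos_anneal` are hypothetical helpers, and the `pct_start`, `div_factor` and `final_div` defaults are assumed values that may differ from the library's.

```python
import math

def cos_anneal(start, end, pct):
    """Cosine interpolation: returns start at pct=0 and end at pct=1."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle(step, total_steps, lr_max=1e-2, moms=(0.8, 0.7),
              pct_start=0.3, div_factor=25.0, final_div=25.0e4):
    """Return (learning_rate, momentum) for a given step (assumed defaults)."""
    warm = int(total_steps * pct_start)
    if step < warm:
        # Phase 1: learning rate rises to lr_max, momentum falls to moms[1].
        pct = step / warm
        lr = cos_anneal(lr_max / div_factor, lr_max, pct)
        mom = cos_anneal(moms[0], moms[1], pct)
    else:
        # Phase 2: learning rate decays far below its start, momentum recovers.
        pct = (step - warm) / max(1, total_steps - warm - 1)
        lr = cos_anneal(lr_max, lr_max / final_div, pct)
        mom = cos_anneal(moms[1], moms[0], pct)
    return lr, mom

# Print the schedule at the start, the peak, and the final step.
for s in (0, 30, 99):
    lr, mom = one_cycle(s, total_steps=100)
    print(f"step {s:3d}  lr={lr:.2e}  mom={mom:.3f}")
```

The inverse movement of learning rate and momentum is the design point: momentum is lowest exactly when the learning rate is highest.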
Save our language model's encoder:
learn.save_encoder('enc')
Fine-tune it to create a classifier:
learn = text_classifier_learner(data_clas, AWD_LSTM)
learn.load_encoder('enc')
learn.fit_one_cycle(4, moms=moms)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
1 | 0.706597 | 0.635787 | 0.671642 | 00:33 |
2 | 0.687497 | 0.627676 | 0.651741 | 00:33 |
3 | 0.668437 | 0.603736 | 0.681592 | 00:30 |
4 | 0.650379 | 0.597291 | 0.676617 | 00:30 |
learn.unfreeze()
learn.fit_one_cycle(8, slice(1e-5,1e-3), moms=moms)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
1 | 0.635453 | 0.589016 | 0.696517 | 00:51 |
2 | 0.644397 | 0.604481 | 0.676617 | 00:55 |
3 | 0.646375 | 0.579778 | 0.696517 | 00:55 |
4 | 0.632535 | 0.552055 | 0.731343 | 00:53 |
5 | 0.638672 | 0.544292 | 0.736318 | 00:51 |
6 | 0.621485 | 0.549286 | 0.736318 | 00:55 |
7 | 0.612489 | 0.543637 | 0.741294 | 00:53 |
8 | 0.611454 | 0.542703 | 0.736318 | 00:54 |
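The two training calls pass learning rates as `slice(1e-2)` and `slice(1e-5,1e-3)`. My understanding is that fastai expands a slice into one learning rate per layer group, with earlier layers getting smaller rates (discriminative learning rates). The sketch below shows one plausible expansion and should be read as an assumption rather than the library's exact behaviour: `lr_range` here is a hypothetical stand-alone function, and the number of layer groups (5) is made up for the example.

```python
def lr_range(lr, n_groups):
    """Expand a learning rate or slice into one rate per layer group (sketch)."""
    if not isinstance(lr, slice):
        return [lr] * n_groups            # plain number: same rate everywhere
    if n_groups == 1:
        return [lr.stop]
    if lr.start is not None:
        # slice(start, stop): geometric spacing from first to last group.
        ratio = (lr.stop / lr.start) ** (1 / (n_groups - 1))
        return [lr.start * ratio ** i for i in range(n_groups)]
    # slice(stop): last group gets stop, earlier groups get stop / 10.
    return [lr.stop / 10] * (n_groups - 1) + [lr.stop]

# Hypothetical model with 5 layer groups.
for spec in (slice(1e-2), slice(1e-5, 1e-3)):
    print([f"{r:.2e}" for r in lr_range(spec, 5)])
```

Under this reading, `slice(1e-5,1e-3)` trains the earliest layers about 100 times more slowly than the classifier head, which limits how much the pre-trained encoder weights move after `unfreeze()`.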