%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai import *
from fastai.text import *
bs=128
torch.cuda.set_device(2)
data_path = Config.data_path()
lang = 'en'  # Wikipedia language code; use e.g. 'vi' for the Vietnamese run whose outputs appear later
name = f'{lang}wiki'
path = data_path/name
path.mkdir(exist_ok=True, parents=True)
lm_fns = [f'{lang}_wt', f'{lang}_wt_vocab']
from nlputils import split_wiki,get_wiki
get_wiki(path,lang)
!head -n4 {path}/{name}
/home/jhoward/.fastai/data/enwiki/enwiki already exists; not downloading

<doc id="12" url="https://en.wikipedia.org/wiki?curid=12" title="Anarchism">
Anarchism

Anarchism is an anti-authoritarian political philosophy that advocates self-managed, self-governed societies based on voluntary, cooperative institutions and the rejection of hierarchies those societies view as unjust. These institutions are often described as stateless societies, although several authors have defined them more specifically as distinct institutions based on non-hierarchical or free associations. Anarchism's central disagreement with other ideologies is that it holds the state to be undesirable, unnecessary, and harmful.
dest = split_wiki(path,lang)
/home/jhoward/.fastai/data/enwiki/docs already exists; not splitting
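split_wiki (from the course's nlputils.py) turns the single extracted dump into one file per article so that TextList.from_folder can pick them up. Below is a rough sketch of the idea, assuming wikiextractor's <doc ...> markers; the real implementation in nlputils.py differs in details (e.g. it also filters articles):

import re

# Illustrative sketch only - not the actual split_wiki from nlputils.py
def split_wiki_sketch(path, lang):
    dest = path/'docs'
    dest.mkdir(exist_ok=True)
    title_re = re.compile(r'<doc id="\d+" url="[^"]*" title="([^"]*)">')
    f = None
    for line in (path/f'{lang}wiki').open(encoding='utf8'):
        m = title_re.match(line)
        if m:                                    # a new article starts here
            if f: f.close()
            name = m.group(1).replace('/', '_')  # keep titles filename-safe
            f = (dest/f'{name}.txt').open('w', encoding='utf8')
        elif f and not line.startswith('</doc>'):
            f.write(line)
    if f: f.close()
    return dest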
data = (TextList.from_folder(dest)
.split_by_rand_pct(0.1, seed=42)
.label_for_lm()
.databunch(bs=bs, num_workers=1))
data.save(f'{lang}_databunch')
len(data.vocab.itos),len(data.train_ds)
The databunch build failed on this run: fastai's tokenizer forks one worker process per CPU, and the machine could not allocate memory for os.fork() (abbreviated traceback):

---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
<ipython-input-6-3e15cb154237> in <module>
      1 data = (TextList.from_folder(dest)
----> 2         .split_by_rand_pct(0.1, seed=42)
      3         .label_for_lm()
      4         .databunch(bs=bs, num_workers=1))

~/git/fastai/fastai/text/transform.py in process_all(self, texts)
    118         if self.n_cpus <= 1: return self._process_all_1(texts)
    119         with ProcessPoolExecutor(self.n_cpus) as e:
--> 120             return sum(e.map(self._process_all_1, partition_by_cores(texts, self.n_cpus)), [])

~/anaconda3/lib/python3.7/multiprocessing/popen_fork.py in _launch(self, process_obj)
     68         code = 1
     69         parent_r, child_w = os.pipe()
---> 70         self.pid = os.fork()

OSError: [Errno 12] Cannot allocate memory
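Since the tokenized databunch had already been saved, the simplest recovery is to reload it with load_data, as the next cell does. Had the save not succeeded, one workaround is to cap fastai's tokenizer parallelism so it never forks; this is a sketch, not what the notebook actually ran:

# Hypothetical workaround: the Tokenizer spawns defaults.cpus worker processes;
# setting it to 1 keeps tokenization in-process (slower, but no os.fork()).
defaults.cpus = 1
data = (TextList.from_folder(dest)
        .split_by_rand_pct(0.1, seed=42)
        .label_for_lm()
        .databunch(bs=bs, num_workers=1))
data.save(f'{lang}_databunch')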
data = load_data(path, f'{lang}_databunch', bs=bs)
learn = language_model_learner(data, AWD_LSTM, drop_mult=0.5, pretrained=False).to_fp16()  # pretrained=False: train the LM from scratch; fp16 for speed and memory
lr = 1e-2
lr *= bs/48  # base lr was tuned at bs=48; scale linearly with batch size (here 1e-2 * 128/48 ≈ 2.7e-2)
learn.unfreeze()
learn.fit_one_cycle(10, lr, moms=(0.8,0.7))  # 1cycle with momentum cycling between 0.8 and 0.7
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 3.436113 | 3.491434 | 0.366925 | 28:52 |
1 | 3.441240 | 3.544118 | 0.361326 | 28:33 |
2 | 3.571766 | 3.556932 | 0.358438 | 28:31 |
3 | 3.510540 | 3.519243 | 0.362278 | 28:27 |
4 | 3.447639 | 3.449320 | 0.369404 | 28:29 |
5 | 3.412284 | 3.406376 | 0.375022 | 28:20 |
6 | 3.286754 | 3.255309 | 0.391874 | 28:19 |
7 | 3.172497 | 3.128522 | 0.406803 | 28:37 |
8 | 3.126867 | 3.025249 | 0.419882 | 28:36 |
9 | 3.128793 | 2.991077 | 0.424622 | 28:39 |
Save the pretrained model and vocab:
path.ls()
[PosixPath('/home/jhoward/data/viwiki/docs'), PosixPath('/home/jhoward/data/viwiki/viwiki-latest-pages-articles.xml'), PosixPath('/home/jhoward/data/viwiki/vi_wt87_vocab.pkl'), PosixPath('/home/jhoward/data/viwiki/extract'), PosixPath('/home/jhoward/data/viwiki/tmp'), PosixPath('/home/jhoward/data/viwiki/test.csv'), PosixPath('/home/jhoward/data/viwiki/viwiki'), PosixPath('/home/jhoward/data/viwiki/log'), PosixPath('/home/jhoward/data/viwiki/train.csv')]
mdl_path = path/'models'
mdl_path.mkdir(exist_ok=True)
learn.to_fp32().save(mdl_path/lm_fns[0], with_opt=False)  # back to fp32; with_opt=False drops the optimizer state
learn.data.vocab.save(mdl_path/(lm_fns[1] + '.pkl'))
train_df = pd.read_csv(path/'train.csv')
train_df.loc[pd.isna(train_df.comment),'comment'] = 'NA'  # pandas reads empty comments as NaN, which would break tokenization
train_df.head()
 | id | comment | label |
---|---|---|---|
0 | train_000000 | Dung dc sp tot cam on \nshop Đóng gói sản phẩm... | 0 |
1 | train_000001 | Chất lượng sản phẩm tuyệt vời . Son mịn nhưng... | 0 |
2 | train_000002 | Chất lượng sản phẩm tuyệt vời nhưng k có hộp ... | 0 |
3 | train_000003 | :(( Mình hơi thất vọng 1 chút vì mình đã kỳ vọ... | 1 |
4 | train_000004 | Lần trước mình mua áo gió màu hồng rất ok mà đ... | 1 |
test_df = pd.read_csv(path/'test.csv')
test_df.loc[pd.isna(test_df.comment),'comment']='NA'
test_df.head()
 | id | comment |
---|---|---|
0 | test_000000 | Chưa dùng thử nên chưa biết |
1 | test_000001 | Không đáng tiềnVì ngay đợt sale nên mới mua n... |
2 | test_000002 | Cám ơn shop. Đóng gói sản phẩm rất đẹp và chắc... |
3 | test_000003 | Vải đẹp.phom oki luôn.quá ưng |
4 | test_000004 | Chuẩn hàng đóng gói đẹp |
df = pd.concat([train_df,test_df], sort=False)  # LM fine-tuning is self-supervised, so the unlabeled test comments can be used too
data_lm = (TextList.from_df(df, path, cols='comment')
.split_by_rand_pct(0.1, seed=42)
.label_for_lm()
.databunch(bs=bs, num_workers=1))
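Before fine-tuning, it is worth eyeballing how fastai tokenizes the Vietnamese comments; show_batch prints a few rows with the special tokens (xxbos, xxmaj, ...) visible:

# Display a few tokenized training sequences from the comment LM data
data_lm.show_batch()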
learn_lm = language_model_learner(data_lm, AWD_LSTM, pretrained_fnames=lm_fns, drop_mult=1.0)  # start from our Wikipedia-pretrained weights and vocab
lr = 1e-3
lr *= bs/48
learn_lm.fit_one_cycle(2, lr*10, moms=(0.8,0.7))
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 4.975080 | 4.138585 | 0.317773 | 00:07 |
1 | 4.408635 | 4.025489 | 0.326423 | 00:07 |
learn_lm.unfreeze()  # now fine-tune all layers at a lower learning rate
learn_lm.fit_one_cycle(8, lr, moms=(0.8,0.7))
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 4.142114 | 3.928278 | 0.336230 | 00:09 |
1 | 4.010835 | 3.793583 | 0.349972 | 00:09 |
2 | 3.873617 | 3.694702 | 0.357240 | 00:09 |
3 | 3.761377 | 3.632186 | 0.364648 | 00:09 |
4 | 3.679017 | 3.595601 | 0.366964 | 00:09 |
5 | 3.614548 | 3.576386 | 0.369224 | 00:09 |
6 | 3.575895 | 3.567496 | 0.370285 | 00:09 |
7 | 3.560278 | 3.566525 | 0.370173 | 00:10 |
learn_lm.save(f'{lang}fine_tuned')
learn_lm.save_encoder(f'{lang}fine_tuned_enc')  # the classifier reuses only the encoder (embeddings + LSTM stack)
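A quick sanity check of the fine-tuned LM is to let it generate some text. The prompt below is just an illustrative fragment ("Sản phẩm này" = "This product"); output varies from run to run:

# Generate 10 words following the prompt with the fine-tuned language model
learn_lm.predict("Sản phẩm này", n_words=10)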
data_clas = (TextList.from_df(train_df, path, vocab=data_lm.vocab, cols='comment')  # reuse the LM vocab so the encoder embeddings line up
.split_by_rand_pct(0.1, seed=42)
.label_from_df(cols='label')
.databunch(bs=bs, num_workers=1))
data_clas.save(f'{lang}_textlist_class')
data_clas = load_data(path, f'{lang}_textlist_class', bs=bs, num_workers=1)
from sklearn.metrics import f1_score

# np_func converts tensor arguments to numpy and wraps the result back in a
# tensor; its wrapper function is why the metric column below shows as "_inner".
@np_func
def f1(inp,targ): return f1_score(targ, np.argmax(inp, axis=-1))
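As a quick check of the metric's plumbing, here is a tiny made-up example; thanks to the decorator, f1 accepts tensors directly:

# 4 samples with 2-class scores; argmax gives predictions [0, 1, 0, 1]
inp  = tensor([[0.9,0.1],[0.2,0.8],[0.6,0.4],[0.3,0.7]])
targ = tensor([0,1,1,1])
f1(inp, targ)   # precision 1.0, recall 2/3 -> tensor(0.8000)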
learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, metrics=[accuracy,f1]).to_fp16()
learn_c.load_encoder(f'{lang}fine_tuned_enc')
learn_c.freeze()  # ULMFiT gradual unfreezing: train only the classifier head first
lr = 2e-2
lr *= bs/48  # same linear batch-size scaling as for the LM
learn_c.fit_one_cycle(2, lr, moms=(0.8,0.7))
epoch | train_loss | valid_loss | accuracy | _inner | time |
---|---|---|---|---|---|
0 | 0.338150 | 0.275298 | 0.899876 | 0.878430 | 00:02 |
1 | 0.302302 | 0.245949 | 0.902985 | 0.877226 | 00:02 |
learn_c.fit_one_cycle(2, lr, moms=(0.8,0.7))
epoch | train_loss | valid_loss | accuracy | _inner | time |
---|---|---|---|---|---|
0 | 0.321768 | 0.255457 | 0.899254 | 0.871367 | 00:02 |
1 | 0.305934 | 0.250888 | 0.894901 | 0.872021 | 00:02 |
learn_c.freeze_to(-2)  # gradual unfreezing: also train the last two layer groups
learn_c.fit_one_cycle(2, slice(lr/(2.6**4),lr), moms=(0.8,0.7))
epoch | train_loss | valid_loss | accuracy | _inner | time |
---|---|---|---|---|---|
0 | 0.300939 | 0.261080 | 0.893657 | 0.866201 | 00:03 |
1 | 0.263790 | 0.220207 | 0.906716 | 0.886115 | 00:03 |
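The slice(lr/(2.6**4), lr) form is the discriminative learning rates trick from the ULMFiT paper: assuming this classifier has five layer groups, each group trains at roughly 2.6x the rate of the group below it, so the earliest layers change least. fastai expands the slice into per-group rates with even_mults:

# Per-group learning rates, lowest (first layers) to highest (head),
# log-spaced between the slice endpoints
even_mults(lr/(2.6**4), lr, 5)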
learn_c.freeze_to(-3)
learn_c.fit_one_cycle(2, slice(lr/2/(2.6**4),lr/2), moms=(0.8,0.7))
epoch | train_loss | valid_loss | accuracy | _inner | time |
---|---|---|---|---|---|
0 | 0.282888 | 0.238203 | 0.905473 | 0.886483 | 00:04 |
1 | 0.248599 | 0.216489 | 0.918532 | 0.901550 | 00:04 |
learn_c.unfreeze()
learn_c.fit_one_cycle(1, slice(lr/10/(2.6**4),lr/10), moms=(0.8,0.7))
epoch | train_loss | valid_loss | accuracy | _inner | time |
---|---|---|---|---|---|
0 | 0.201508 | 0.217176 | 0.911070 | 0.890084 | 00:05 |
learn_c.save(f'{lang}clas')
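With the classifier trained and saved, individual comments can be scored directly; the text below is just an illustrative example:

# Returns (predicted category, class index, class probabilities)
learn_c.predict("Chất lượng sản phẩm tuyệt vời")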
For comparison, the competition's top three F1 scores were 0.90, 0.89, and 0.89, and the winner used an ensemble of four models: TextCNN, VDCNN, HARNN, and SARNN.
data_clas = load_data(path, f'{lang}_textlist_class', bs=bs, num_workers=1)
learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, metrics=[accuracy,f1]).to_fp16()
learn_c.load(f'{lang}clas', purge=False);
preds,targs = learn_c.get_preds(ordered=True)  # ordered=True returns predictions in dataset order (text batches are length-sorted internally)
accuracy(preds,targs),f1(preds,targs)
(tensor(0.9111), tensor(0.8952))
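The same pipeline can be repeated with backwards=True to train a model that reads each comment right-to-left; its databunch and classifier were saved as {lang}_textlist_class_bwd and {lang}clas_bwd (that training run is not shown here, and the cells below just load the saved artifacts). A minimal sketch of how the backward classification data would be built, assuming everything else mirrors the forward run:

# Sketch (assumption): identical to the forward pipeline except backwards=True
data_clas_bwd = (TextList.from_df(train_df, path, vocab=data_lm.vocab, cols='comment')
                 .split_by_rand_pct(0.1, seed=42)
                 .label_from_df(cols='label')
                 .databunch(bs=bs, num_workers=1, backwards=True))
data_clas_bwd.save(f'{lang}_textlist_class_bwd')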
data_clas_bwd = load_data(path, f'{lang}_textlist_class_bwd', bs=bs, num_workers=1, backwards=True)
learn_c_bwd = text_classifier_learner(data_clas_bwd, AWD_LSTM, drop_mult=0.5, metrics=[accuracy,f1]).to_fp16()
learn_c_bwd.load(f'{lang}clas_bwd', purge=False);
preds_b,targs_b = learn_c_bwd.get_preds(ordered=True)
accuracy(preds_b,targs_b),f1(preds_b,targs_b)
(tensor(0.9092), tensor(0.8957))
preds_avg = (preds+preds_b)/2  # ensemble: average the forward and backward class probabilities
accuracy(preds_avg,targs_b),f1(preds_avg,targs_b)
(tensor(0.9154), tensor(0.9016))
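Beyond the headline numbers, a confusion matrix shows where the ensemble still errs; since the predictions and targets are plain CPU tensors, sklearn can consume them directly:

from sklearn.metrics import confusion_matrix

# Rows are true labels, columns are ensemble predictions
confusion_matrix(targs_b, preds_avg.argmax(-1))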