%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.model import fit
from fastai.dataset import *
import torchtext
from torchtext import vocab, data
from torchtext.datasets import language_modeling
from fastai.rnn_reg import *
from fastai.rnn_train import *
from fastai.nlp import *
from fastai.lm_rnn import *
import dill as pickle
import random
# Batch size and BPTT length; both are passed to LanguageModelData.from_text_files,
# and `bs` is also passed to TextData.from_splits and restored by sample_model.
bs,bptt = 64,70
import os, requests, time
# feedparser isn't a fastai dependency so you may need to install it.
import feedparser
import pandas as pd
class GetArXiv(object):
    """Download arXiv abstracts from the export API and cache them in a pickle.

    The cache is a pandas DataFrame with the columns produced by
    `get_entry_dict`: title, authors, published, summary, link, category.
    """

    def __init__(self, pickle_path, categories=None):
        """
        :param pickle_path (str): pickle file to save/load; if it is an existing
            directory, ``all_arxiv.pkl`` inside that directory is used
        :param categories (list): arXiv categories to query; a default set is
            used when this is empty or omitted
        """
        if os.path.isdir(pickle_path):
            sep = '' if pickle_path.endswith('/') else '/'
            pickle_path = f"{pickle_path}{sep}all_arxiv.pkl"
        if not categories:
            categories = ['cs*', 'cond-mat.dis-nn', 'q-bio.NC', 'stat.CO', 'stat.ML']
            # categories += ['cs.CV', 'cs.AI', 'cs.LG', 'cs.CL']
        self.categories = categories
        self.pickle_path = pickle_path
        self.base_url = 'http://export.arxiv.org/api/query'

    @staticmethod
    def build_qs(categories):
        """Build query string from categories"""
        return '+OR+'.join(['cat:'+c for c in categories])

    @staticmethod
    def get_entry_dict(entry):
        """Return a dictionary with the items we want from a feedparser entry.

        Returns None (after printing the entry) when a required key is missing.
        """
        try:
            return dict(title=entry['title'], authors=[a['name'] for a in entry['authors']],
                        published=pd.Timestamp(entry['published']), summary=entry['summary'],
                        link=entry['link'], category=entry['category'])
        except KeyError:
            print('Missing keys in row: {}'.format(entry))
            return None

    @staticmethod
    def strip_version(link):
        """Strip the trailing ``vN`` version suffix from an arXiv paper link.

        Handles multi-digit versions (``v10``); links without a numeric version
        suffix are returned unchanged.
        """
        base, sep, version = link.rpartition('v')
        return base if sep and version.isdigit() else link

    def fetch_updated_data(self, max_retry=5, pg_offset=0, pg_size=1000, wait_time=15):
        """
        Get new papers from arXiv server, merge them into the cached frame,
        write the pickle and return the merged DataFrame.

        :param max_retry: max number of time to retry request
        :param pg_offset: number of pages to offset
        :param pg_size: num abstracts to fetch per request
        :param wait_time: num seconds to wait between requests
        """
        page, retry = pg_offset, 0
        df = pd.DataFrame()
        past_links = []
        if os.path.isfile(self.pickle_path):
            # reset_index returns a new frame, so the result must be kept.
            df = pd.read_pickle(self.pickle_path).reset_index(drop=True)
            if len(df) > 0: past_links = df.link.apply(self.strip_version)
        n_cached = len(df)
        query = self.build_qs(self.categories)

        while True:
            params = dict(search_query=query, sortBy='submittedDate',
                          start=pg_size*page, max_results=pg_size)
            # The query is handed to requests as a pre-joined string rather than a
            # dict -- presumably to keep '+' and ':' unescaped; TODO confirm.
            response = requests.get(self.base_url,
                                    params='&'.join([f'{k}={v}' for k, v in params.items()]))
            entries = feedparser.parse(response.text).entries
            # Entries with missing keys come back as None; drop them before
            # building the frame.
            rows = [row for row in map(self.get_entry_dict, entries) if row is not None]
            if not rows:
                # Empty (or entirely unusable) page: retry a few times, then stop.
                if retry < max_retry:
                    retry += 1
                    time.sleep(wait_time)
                    continue
                break

            results_df = pd.DataFrame(rows)
            max_date = results_df.published.max().date()
            new_links = ~results_df.link.apply(self.strip_version).isin(past_links)
            print(f'{page}. Fetched {len(results_df)} abstracts published {max_date} and earlier')
            # Stop at the first page that contains nothing new.
            if not new_links.any():
                break
            df = pd.concat((df, results_df.loc[new_links]), ignore_index=True)
            page += 1
            retry = 0
            time.sleep(wait_time)

        print(f'Downloaded {len(df)-n_cached} new abstracts')
        if len(df) > 0:
            # Keep the most recently published row per link; the original
            # computed a de-duplicated frame here but threw it away.
            df = (df.sort_values('published', ascending=False)
                    .drop_duplicates('link')
                    .reset_index(drop=True))
        df.to_pickle(self.pickle_path)
        return df

    @classmethod
    def load(cls, pickle_path):
        """Load the cached DataFrame from the pickle (file or containing directory)."""
        return pd.read_pickle(cls(pickle_path).pickle_path)

    @classmethod
    def update(cls, pickle_path, categories=None, **kwargs):
        """
        Update arXiv data pickle with the latest abstracts.

        Extra keyword arguments are forwarded to `fetch_updated_data`.
        Always returns True.
        """
        cls(pickle_path, categories).fetch_updated_data(**kwargs)
        return True
PATH='data/arxiv/'
ALL_ARXIV = f'{PATH}all_arxiv.pkl'
# Make sure the data directory exists up front: the download below only writes
# its pickle at the very end, so a missing directory would waste the whole fetch.
os.makedirs(PATH, exist_ok=True)
# all_arxiv.pkl: if arxiv hasn't been downloaded yet, it'll take some time to get it - go get some coffee
if not os.path.exists(ALL_ARXIV): GetArXiv.update(ALL_ARXIV)
# arxiv.csv: see dl1/nlp-arxiv.ipynb to get this one
df_mb = pd.read_csv(f'{PATH}arxiv.csv')
df_all = pd.read_pickle(ALL_ARXIV)
def get_txt(df):
    """Build the tagged text ``<CAT> cat <SUMM> summary <TITLE> title`` per row.

    Dots and hyphens are removed from the category (``cs.NI`` -> ``csNI``).
    `str.translate` is used instead of `str.replace` with a pattern so the
    result does not depend on the pandas default for ``regex``.

    :param df: DataFrame with string columns ``category``, ``summary``, ``title``
    :return: Series of tagged strings aligned with ``df``
    """
    category = df.category.str.translate(str.maketrans('', '', '.-'))
    return '<CAT> ' + category + ' <SUMM> ' + df.summary + ' <TITLE> ' + df.title
# Attach the tagged text column to the labelled frame and to the full arXiv frame.
for frame in (df_mb, df_all):
    frame['txt'] = get_txt(frame)
# Number of abstracts in the full dump (displayed as the cell result).
n = len(df_all)
n
49600
# Output folders: yes/no class folders for the classifier splits, plain
# trn/val folders under all/ for the language model, and a models folder.
for subdir in ('trn/yes', 'val/yes', 'trn/no', 'val/no', 'all/trn', 'all/val', 'models'):
    os.makedirs(f'{PATH}{subdir}', exist_ok=True)
# Write each abstract of the full dump to its own file, sending roughly 10% to
# validation. Files are opened with `with` so every handle is closed promptly.
# NOTE: the split is re-drawn on each run and files from earlier runs are not removed.
for i, txt in enumerate(df_all['txt']):
    dset = 'trn' if random.random()>0.1 else 'val'
    with open(f'{PATH}all/{dset}/{i}.txt', 'w') as f:
        f.write(txt)
# Same for the labelled set, additionally split into yes/no folders by `tweeted`.
for (i,(_,r)) in enumerate(df_mb.iterrows()):
    lbl = 'yes' if r.tweeted else 'no'
    dset = 'trn' if random.random()>0.1 else 'val'
    with open(f'{PATH}{dset}/{lbl}/{i}.txt', 'w') as f:
        f.write(r['txt'])
from spacy.symbols import ORTH
# install the 'en' model if the next line of code fails by running:
#python -m spacy download en # default English model (~50MB)
#python -m spacy download en_core_web_md # larger English model (~1GB)
my_tok = spacy.load('en')
# Register each markup tag as a tokenizer special case made of one ORTH piece.
for tag in ('<SUMM>', '<CAT>', '<TITLE>', '<BR />', '<BR>'):
    my_tok.tokenizer.add_special_case(tag, [{ORTH: tag}])
def my_spacy_tok(x):
    """Tokenize string `x` with `my_tok`'s tokenizer and return the token texts as a list."""
    texts = []
    for token in my_tok.tokenizer(x):
        texts.append(token.text)
    return texts
# Lower-casing text field that uses the customised spaCy tokenizer.
TEXT = data.Field(lower=True, tokenize=my_spacy_tok)
# The 'val' folder is used for both validation and test.
FILES = dict(train='trn', validation='val', test='val')
md = LanguageModelData.from_text_files(f'{PATH}all/', TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)
# Save the field so it can be reloaded later; `with` guarantees the file is
# flushed and closed.
with open(f'{PATH}models/TEXT.pkl','wb') as f:
    pickle.dump(TEXT, f)
len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)
(2129, 17951, 1, 9543290)
# Show the first 12 entries of the vocabulary's index-to-string list.
TEXT.vocab.itos[:12]
['<unk>', '<pad>', '\n', 'the', ',', '.', 'of', '-', 'and', 'a', 'to', 'in']
# Show the first 150 tokens of the training text, joined with spaces.
' '.join(md.trn_ds[0].text[:150])
'<cat> csni <summ> the exploitation of mm - wave bands is one of the key - enabler for 5 g mobile \n radio networks . however , the introduction of mm - wave technologies in cellular \n networks is not straightforward due to harsh propagation conditions that limit \n the mm - wave access availability . mm - wave technologies require high - gain antenna \n systems to compensate for high path loss and limited power . as a consequence , \n directional transmissions must be used for cell discovery and synchronization \n processes : this can lead to a non - negligible access delay caused by the \n exploration of the cell area with multiple transmissions along different \n directions . \n the integration of mm - wave technologies and conventional wireless access \n networks with the objective of speeding up the cell search process requires new \n'
# Model sizes; later passed to the classifier as emb_sz, n_hid and n_layers.
em_sz, nh, nl = 200, 500, 3
opt_fn = partial(optim.Adam, betas=(0.7, 0.99))
# Dropout settings used for the language model.
lm_dropouts = dict(dropout=0.05, dropouth=0.1, dropouti=0.05, dropoute=0.02, wdrop=0.2)
# Alternative settings kept for reference:
#   dropout=0.4, dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5
#   dropouti=0.05, dropout=0.05, wdrop=0.1, dropoute=0.02, dropouth=0.05)
learner = md.get_model(opt_fn, em_sz, nh, nl, **lm_dropouts)
learner.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
learner.clip = 0.3
# First pass at learning rate 3e-3 (one epoch in the log below).
learner.fit(3e-3, 1, wds=1e-6)
A Jupyter Widget
[ 0. 4.36189 4.2185 ]
# Three cycles with cycle_len=1 and cycle_mult=2; the log below shows 7 epochs.
learner.fit(3e-3, 3, wds=1e-6, cycle_len=1, cycle_mult=2)
A Jupyter Widget
[ 0. 4.11236 3.99207] [ 1. 4.03207 3.89298] [ 2. 3.91653 3.81915] [ 3. 3.97808 3.8428 ] [ 4. 3.88482 3.76226] [ 5. 3.79955 3.70472] [ 6. 3.75721 3.69048]
# Save the encoder under the name 'adam2_enc' before the longer run.
learner.save_encoder('adam2_enc')
# Ten cycles of five epochs each (50 epochs in the log below).
# cycle_save_name presumably saves weights after each cycle -- TODO confirm.
learner.fit(3e-3, 10, wds=1e-6, cycle_len=5, cycle_save_name='adam3_10')
A Jupyter Widget
[ 0. 3.89388 3.76575] [ 1. 3.82548 3.71875] [ 2. 3.76471 3.66974] [ 3. 3.71713 3.63861] [ 4. 3.67534 3.62983] [ 5. 3.83938 3.71551] [ 6. 3.78093 3.68056] [ 7. 3.72828 3.63638] [ 8. 3.66743 3.60355] [ 9. 3.65793 3.59448] [ 10. 3.80545 3.68213] [ 11. 3.75299 3.65219] [ 12. 3.7057 3.61324] [ 13. 3.63348 3.58048] [ 14. 3.62304 3.57257] [ 15. 3.78656 3.66324] [ 16. 3.73522 3.63348] [ 17. 3.67258 3.59369] [ 18. 3.6242 3.56674] [ 19. 3.61123 3.55783] [ 20. 3.77443 3.65472] [ 21. 3.7374 3.62169] [ 22. 3.68367 3.58247] [ 23. 3.61606 3.55567] [ 24. 3.58527 3.54725] [ 25. 3.75456 3.63861] [ 26. 3.72061 3.61084] [ 27. 3.65141 3.57073] [ 28. 3.59711 3.54414] [ 29. 3.57052 3.53622] [ 30. 3.75229 3.62935] [ 31. 3.70693 3.60198] [ 32. 3.65193 3.56444] [ 33. 3.59173 3.53742] [ 34. 3.58699 3.53152] [ 35. 3.74211 3.62154] [ 36. 3.70016 3.59831] [ 37. 3.64095 3.55689] [ 38. 3.60686 3.53296] [ 39. 3.5627 3.523 ] [ 40. 3.72944 3.61956] [ 41. 3.68161 3.58779] [ 42. 3.62305 3.55187] [ 43. 3.58559 3.52524] [ 44. 3.56087 3.51682] [ 45. 3.72533 3.61458] [ 46. 3.68025 3.58452] [ 47. 3.64447 3.55002] [ 48. 3.575 3.52066] [ 49. 3.54424 3.5133 ]
learner.save_encoder('adam3_10_enc')
# Eight cycles of ten epochs each (80 epochs in the log below).
# NOTE(review): this save name says '5' while cycle_len is 10 -- the name may be
# stale, but it is a file name and is left untouched.
learner.fit(3e-3, 8, wds=1e-6, cycle_len=10, cycle_save_name='adam3_5')
A Jupyter Widget
[ 0. 3.70587 3.61666] [ 1. 3.71738 3.61174] [ 2. 3.68606 3.59661] [ 3. 3.65407 3.5742 ] [ 4. 3.62901 3.55795] [ 5. 3.59921 3.53632] [ 6. 3.58401 3.52149] [ 7. 3.55126 3.50797] [ 8. 3.52965 3.50178] [ 9. 3.52336 3.49997] [ 10. 3.7109 3.60817] [ 11. 3.69879 3.60047] [ 12. 3.6735 3.58623] [ 13. 3.64365 3.56568] [ 14. 3.6099 3.54776] [ 15. 3.58244 3.52829] [ 16. 3.54894 3.51071] [ 17. 3.52702 3.50173] [ 18. 3.51357 3.49522] [ 19. 3.50302 3.49272] [ 20. 3.72505 3.60198] [ 21. 3.70037 3.59914] [ 22. 3.68386 3.58279] [ 23. 3.64176 3.56435] [ 24. 3.60259 3.54304] [ 25. 3.58669 3.52432] [ 26. 3.54075 3.50703] [ 27. 3.50951 3.49534] [ 28. 3.51915 3.4896 ] [ 29. 3.48695 3.48968] [ 30. 3.70563 3.59631] [ 31. 3.68822 3.58723] [ 32. 3.67549 3.58141] [ 33. 3.63267 3.55537] [ 34. 3.60638 3.5386 ] [ 35. 3.58803 3.52154] [ 36. 3.53987 3.50394] [ 37. 3.51036 3.49244] [ 38. 3.48651 3.48652] [ 39. 3.49061 3.48673] [ 40. 3.70093 3.59211] [ 41. 3.67371 3.58516] [ 42. 3.66558 3.57032] [ 43. 3.65089 3.55939] [ 44. 3.59885 3.53445] [ 45. 3.56369 3.51585] [ 46. 3.55304 3.50237] [ 47. 3.50469 3.48919] [ 48. 3.49559 3.48289] [ 49. 3.50912 3.48136] [ 50. 3.70603 3.59182] [ 51. 3.669 3.58069] [ 52. 3.64965 3.56896] [ 53. 3.62839 3.55251] [ 54. 3.59578 3.53297] [ 55. 3.55814 3.51205] [ 56. 3.53653 3.49682] [ 57. 3.50043 3.48502] [ 58. 3.49535 3.4797 ] [ 59. 3.48039 3.47882] [ 60. 3.68319 3.58874] [ 61. 3.68893 3.58173] [ 62. 3.6516 3.56403] [ 63. 3.63432 3.55047] [ 64. 3.59697 3.52815] [ 65. 3.55784 3.50832] [ 66. 3.52815 3.49319] [ 67. 3.50618 3.48222] [ 68. 3.48319 3.47645] [ 69. 3.49879 3.47596] [ 70. 3.68466 3.58318] [ 71. 3.67045 3.57351] [ 72. 3.64409 3.5606 ] [ 73. 3.61991 3.54552] [ 74. 3.60503 3.52782] [ 75. 3.56681 3.50743] [ 76. 3.52401 3.49046] [ 77. 3.50519 3.47875] [ 78. 3.49343 3.47452] [ 79. 3.49275 3.47175]
# One final cycle of 20 epochs.
learner.fit(3e-3, 1, wds=1e-6, cycle_len=20, cycle_save_name='adam3_20')
A Jupyter Widget
[ 0. 3.47841 3.4751 ] [ 1. 3.69717 3.57883] [ 2. 3.68267 3.57793] [ 3. 3.66797 3.57299] [ 4. 3.66805 3.56847] [ 5. 3.63489 3.56238] [ 6. 3.62479 3.54928] [ 7. 3.60663 3.53879] [ 8. 3.59124 3.53175] [ 9. 3.58617 3.52009] [ 10. 3.56924 3.51174] [ 11. 3.5509 3.49974] [ 12. 3.51595 3.49008] [ 13. 3.50939 3.48222] [ 14. 3.48886 3.47952] [ 15. 3.4676 3.47311] [ 16. 3.4856 3.46577] [ 17. 3.44909 3.46499] [ 18. 3.46791 3.46314] [ 19. 3.44028 3.46231]
# Save the encoder (loaded later by the classifier as 'adam3_20_enc') and the model.
learner.save_encoder('adam3_20_enc')
learner.save('adam3_20')
def proc_str(s):
    """Tokenize string `s` with the TEXT field, then apply the field's preprocessing."""
    tokens = TEXT.tokenize(s)
    return TEXT.preprocess(tokens)
def num_str(s):
    """Numericalize string `s` with the TEXT field as a batch holding one example."""
    single_example_batch = [proc_str(s)]
    return TEXT.numericalize(single_example_batch)
# Handle on the trained language model, used by sample_model below.
m=learner.model
# Example prompt in the same tagged format that get_txt produces.
s="""<CAT> cscv <SUMM> algorithms that"""
def sample_model(m, s, l=50):
    """Print a continuation of prompt `s` generated by model `m`.

    At each step the top-scoring token is taken, unless it is vocabulary index 0,
    in which case the runner-up is used. Generation stops after `l` tokens or as
    soon as '<eos>' is produced. Nothing is returned; tokens are printed.

    :param m: model; `m[0].bs` is set to 1 while sampling
    :param s: prompt string, numericalized with `num_str`
    :param l: maximum number of tokens to print
    """
    t = num_str(s)
    m[0].bs=1
    try:
        m.eval()
        m.reset()
        res,*_ = m(t)
        print('...', end='')

        for _i in range(l):
            # Indices of the two highest-scoring tokens for the last position.
            n=res[-1].topk(2)[1]
            # Skip index 0 by falling back to the second choice.
            n = n[1] if n.data[0]==0 else n[0]
            word = TEXT.vocab.itos[n.data[0]]
            print(word, end=' ')
            if word=='<eos>': break
            # Feed the chosen token back in to get the next prediction.
            res,*_ = m(n[0].unsqueeze(0))
    finally:
        # Always restore the batch size, even if sampling raised or was interrupted.
        m[0].bs=bs
# Continue a summary under the `csni` category tag.
sample_model(m,"<CAT> csni <SUMM> algorithms that")
...use the same network as a single node are not able to achieve the same performance as the traditional network - based routing algorithms . in this paper , we propose a novel routing scheme for routing protocols in wireless networks . the proposed scheme is based ...
# Same prompt under the `cscv` category tag, for comparison.
sample_model(m,"<CAT> cscv <SUMM> algorithms that")
...use the same data to perform image classification are increasingly being used to improve the performance of image classification algorithms . in this paper , we propose a novel method for image classification using a deep convolutional neural network ( cnn ) . the proposed method is ...
# Generate a title starting with "on" for the `cscv` tag.
sample_model(m,"<CAT> cscv <SUMM> algorithms. <TITLE> on ")
...the performance of deep learning for image classification <eos>
# Generate a title starting with "on" for the `csni` tag.
sample_model(m,"<CAT> csni <SUMM> algorithms. <TITLE> on ")
...the performance of wireless networks <eos>
# Generate a title starting with "towards" for the `cscv` tag.
sample_model(m,"<CAT> cscv <SUMM> algorithms. <TITLE> towards ")
...a new approach to image classification <eos>
# Generate a title starting with "towards" for the `csni` tag.
sample_model(m,"<CAT> csni <SUMM> algorithms. <TITLE> towards ")
...a new approach to the analysis of wireless networks <eos>
# Reload the saved text field. Unpickling can execute arbitrary code, so only
# load a TEXT.pkl that this notebook wrote itself.
with open(f'{PATH}models/TEXT.pkl','rb') as f:
    TEXT = pickle.load(f)
class ArxivDataset(torchtext.data.Dataset):
    """Dataset of text files labelled by their parent folder ('yes' or 'no')."""

    def __init__(self, path, text_field, label_field, **kwargs):
        """Load every ``*.txt`` file under ``path/yes`` and ``path/no``.

        :param path: directory containing the ``yes`` and ``no`` sub-folders
        :param text_field: field used for the file contents
        :param label_field: field used for the folder name
        :param kwargs: forwarded to the parent dataset constructor
        """
        fields = [('text', text_field), ('label', label_field)]
        examples = []
        for label in ['yes', 'no']:
            fnames = glob(os.path.join(path, label, '*.txt'))
            assert fnames, f"can't find any '*.txt' files under {path}/{label}"
            for fname in fnames:
                # Read the whole file: the texts can span several lines, so
                # readline() would drop everything after the first line break.
                with open(fname, 'r') as f: text = f.read()
                examples.append(data.Example.fromlist([text, label], fields))
        super().__init__(examples, fields, **kwargs)

    @staticmethod
    def sort_key(ex):
        """Sort examples by the length of their text."""
        return len(ex.text)

    @classmethod
    def splits(cls, text_field, label_field, root='.data',
               train='train', test='test', **kwargs):
        """Create the train and test datasets (no validation set).

        `root` is handed to the parent `splits` as its first positional
        argument; `train` and `test` name the sub-folders to load.
        """
        return super().splits(
            root, text_field=text_field, label_field=label_field,
            train=train, validation=None, test=test, **kwargs)
# Non-sequential field for the yes/no label.
ARX_LABEL = data.Field(sequential=False)
# Train split from the 'trn' folder and test split from the 'val' folder under PATH.
splits = ArxivDataset.splits(TEXT, ARX_LABEL, PATH, train='trn', test='val')
md2 = TextData.from_splits(PATH, splits, bs)
# dropout=0.3, dropouti=0.4, wdrop=0.3, dropoute=0.05, dropouth=0.2)
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt
def prec_at_6(preds,targs):
    """Print and return a recall value at precision >= 0.6.

    Column 2 of `preds` is scored against ``targs==2``; the result is the first
    recall entry, in the order returned by `precision_recall_curve`, whose
    precision is at least 0.6.
    """
    precision, recall, _ = precision_recall_curve(targs==2, preds[:,2])
    recall_at_threshold = recall[precision>=0.6][0]
    print(recall_at_threshold)
    return recall_at_threshold
# dropout=0.4, dropouth=0.3, dropouti=0.65, dropoute=0.1, wdrop=0.5
# Dropout settings used for the classifier.
clas_dropouts = dict(dropout=0.1, dropouti=0.65, wdrop=0.5, dropoute=0.1, dropouth=0.3)
# Classifier built with the same embedding/hidden/layer sizes as the language model.
m3 = md2.get_model(opt_fn, 1500, bptt, emb_sz=em_sz, n_hid=nh, n_layers=nl, **clas_dropouts)
m3.reg_fn = partial(seq2seq_reg, alpha=2, beta=1)
m3.clip = 25.
# this notebook has a mess of some things going under 'all/' others not, so a little hack here
# Symlink the language-model encoder weights into {PATH}models/ so that
# load_encoder below can find them by name.
!ln -sf ../all/models/adam3_20_enc.h5 {PATH}models/adam3_20_enc.h5
m3.load_encoder(f'adam3_20_enc')
# Five learning rates -- presumably one per layer group; TODO confirm.
lrs=np.array([1e-4,1e-3,1e-3,1e-2,3e-2])
# First fit with the model frozen via freeze_to(-1), at half the learning rates.
m3.freeze_to(-1)
m3.fit(lrs/2, 1, metrics=[accuracy])
# Then unfreeze and fit again at the full learning rates.
m3.unfreeze()
m3.fit(lrs, 1, metrics=[accuracy], cycle_len=1)
A Jupyter Widget
[ 0. 0.47654 0.44322 0.78525]
A Jupyter Widget
[ 0. 0.43033 0.40192 0.80087]
# Two cycles of four epochs each.
# NOTE(review): the save name 'imdb2' looks carried over from another notebook; left as is.
m3.fit(lrs, 2, metrics=[accuracy], cycle_len=4, cycle_save_name='imdb2')
A Jupyter Widget
[ 0. 0.42236 0.39006 0.8194 ] [ 1. 0.39477 0.37063 0.82086] [ 2. 0.39389 0.37082 0.82449] [ 3. 0.40728 0.36999 0.82195] [ 4. 0.39308 0.3675 0.81977] [ 5. 0.38662 0.36737 0.8234 ] [ 6. 0.39259 0.36512 0.82486] [ 7. 0.38047 0.36538 0.82522]
# Recall at precision >= 0.6; the value shows twice because prec_at_6 both prints and returns it.
prec_at_6(*m3.predict_with_targs())
0.659305993691
0.65930599369085174
# Four more cycles of two epochs each, reusing the same save name 'imdb2'.
m3.fit(lrs, 4, metrics=[accuracy], cycle_len=2, cycle_save_name='imdb2')
A Jupyter Widget
[ 0. 0.38752 0.36351 0.82486] [ 1. 0.38664 0.36123 0.82558] [ 2. 0.3904 0.36098 0.82486] [ 3. 0.37319 0.36144 0.82486] [ 4. 0.38074 0.36334 0.82595] [ 5. 0.36405 0.3594 0.82413] [ 6. 0.38781 0.35914 0.82522] [ 7. 0.37722 0.357 0.82631]
# Re-evaluate recall at precision >= 0.6 after the extra training.
prec_at_6(*m3.predict_with_targs())
0.695583596215
0.69558359621451105