Important: This notebook will only work with fastai-0.7.x. Do not try to run any fastai-1.x code from this path in the repository, because it will load fastai-0.7.x.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.io import *
from fastai.conv_learner import *
from fastai.column_data import *
We're going to download the collected works of Nietzsche to use as our data for this class.
PATH='data/nietzsche/'
get_data("https://s3.amazonaws.com/text-datasets/nietzsche.txt", f'{PATH}nietzsche.txt')
text = open(f'{PATH}nietzsche.txt').read()
print('corpus length:', len(text))
corpus length: 600893
text[:400]
'PREFACE\n\n\nSUPPOSING that Truth is a woman--what then? Is there not ground\nfor suspecting that all philosophers, in so far as they have been\ndogmatists, have failed to understand women--that the terrible\nseriousness and clumsy importunity with which they have usually paid\ntheir addresses to Truth, have been unskilled and unseemly methods for\nwinning a woman? Certainly she has never allowed herself '
chars = sorted(list(set(text)))
vocab_size = len(chars)+1
print('total chars:', vocab_size)
total chars: 85
Sometimes it's useful to have a zero value in the dataset, e.g. for padding
chars.insert(0, "\0")
''.join(chars[1:-6])
'\n !"\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxy'
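As a quick aside (not something this notebook relies on), that index-0 character is what you would pad shorter sequences with to give a batch a uniform length; a tiny illustration:
seqs = [[chars.index(c) for c in s] for s in ('cat', 'mouse')]   # illustrative only
maxlen = max(len(s) for s in seqs)
padded = [s + [0] * (maxlen - len(s)) for s in seqs]             # 'cat' gets two index-0 padding characters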
Map from chars to indices and back again
char_indices = {c: i for i, c in enumerate(chars)}
indices_char = {i: c for i, c in enumerate(chars)}
idx will be the data we use from now on - it simply converts all the characters to their index (based on the mapping above)
idx = [char_indices[c] for c in text]
idx[:10]
[40, 42, 29, 30, 25, 27, 29, 1, 1, 1]
''.join(indices_char[i] for i in idx[:70])
'PREFACE\n\n\nSUPPOSING that Truth is a woman--what then? Is there not gro'
Create four lists of every 3rd character: one starting at the 0th character, one at the 1st, one at the 2nd, and one at the 3rd. The first three will be our inputs; the last is the character we want to predict.
cs=3
c1_dat = [idx[i] for i in range(0, len(idx)-cs, cs)]
c2_dat = [idx[i+1] for i in range(0, len(idx)-cs, cs)]
c3_dat = [idx[i+2] for i in range(0, len(idx)-cs, cs)]
c4_dat = [idx[i+3] for i in range(0, len(idx)-cs, cs)]
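To make the indexing concrete, here is the same stride-3 construction on a toy sequence (illustrative only, not part of the notebook's pipeline):
toy = list(range(10))                                # stands in for idx
[toy[i]   for i in range(0, len(toy)-3, 3)]          # [0, 3, 6]  - 1st character of each group of three
[toy[i+3] for i in range(0, len(toy)-3, 3)]          # [3, 6, 9]  - the character to predict
So the model will see (0, 1, 2) and predict 3, see (3, 4, 5) and predict 6, and so on.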
Our inputs
x1 = np.stack(c1_dat)
x2 = np.stack(c2_dat)
x3 = np.stack(c3_dat)
Our output
y = np.stack(c4_dat)
The first 4 inputs and outputs
x1[:4], x2[:4], x3[:4]
(array([40, 30, 29, 1]), array([42, 25, 1, 43]), array([29, 27, 1, 45]))
y[:4]
array([30, 29, 1, 40])
x1.shape, y.shape
((200295,), (200295,))
Pick a size for our hidden state
n_hidden = 256
The number of latent factors to create (i.e. the size of each character's embedding vector)
n_fac = 42
class Char3Model(nn.Module):
def __init__(self, vocab_size, n_fac):
super().__init__()
self.e = nn.Embedding(vocab_size, n_fac)
# The 'green arrow' from our diagram - the layer operation from input to hidden
self.l_in = nn.Linear(n_fac, n_hidden)
# The 'orange arrow' from our diagram - the layer operation from hidden to hidden
self.l_hidden = nn.Linear(n_hidden, n_hidden)
# The 'blue arrow' from our diagram - the layer operation from hidden to output
self.l_out = nn.Linear(n_hidden, vocab_size)
def forward(self, c1, c2, c3):
in1 = F.relu(self.l_in(self.e(c1)))
in2 = F.relu(self.l_in(self.e(c2)))
in3 = F.relu(self.l_in(self.e(c3)))
h = V(torch.zeros(in1.size()).cuda())
h = F.tanh(self.l_hidden(h+in1))
h = F.tanh(self.l_hidden(h+in2))
h = F.tanh(self.l_hidden(h+in3))
return F.log_softmax(self.l_out(h))
md = ColumnarModelData.from_arrays('.', [-1], np.stack([x1,x2,x3], axis=1), y, bs=512)
m = Char3Model(vocab_size, n_fac).cuda()
it = iter(md.trn_dl)
*xs,yt = next(it)
t = m(*V(xs))
opt = optim.Adam(m.parameters(), 1e-2)
fit(m, md, 1, opt, F.nll_loss)
[ 0. 2.09627 6.52849]
set_lrs(opt, 0.001)
fit(m, md, 1, opt, F.nll_loss)
[ 0. 1.84525 6.52312]
def get_next(inp):
idxs = T(np.array([char_indices[c] for c in inp]))
p = m(*VV(idxs))
i = np.argmax(to_np(p))
return chars[i]
get_next('y. ')
'T'
get_next('ppl')
'e'
get_next(' th')
'e'
get_next('and')
' '
This is the size of our unrolled RNN.
cs=8
For every position in the text, take the sequence of the 8 characters starting there (so consecutive sequences overlap by 7 characters). These will be the 8 inputs to our model.
c_in_dat = [[idx[i+j] for i in range(cs)] for j in range(len(idx)-cs)]
Then create a list of the character that follows each of these sequences. These will be the labels for our model.
c_out_dat = [idx[j+cs] for j in range(len(idx)-cs)]
xs = np.stack(c_in_dat, axis=0)
xs.shape
(600884, 8)
y = np.stack(c_out_dat)
So each row below is one series of 8 characters from the text.
xs[:cs,:cs]
array([[40, 42, 29, 30, 25, 27, 29,  1],
       [42, 29, 30, 25, 27, 29,  1,  1],
       [29, 30, 25, 27, 29,  1,  1,  1],
       [30, 25, 27, 29,  1,  1,  1, 43],
       [25, 27, 29,  1,  1,  1, 43, 45],
       [27, 29,  1,  1,  1, 43, 45, 40],
       [29,  1,  1,  1, 43, 45, 40, 40],
       [ 1,  1,  1, 43, 45, 40, 40, 39]])
...and this is the next character after each sequence.
y[:cs]
array([ 1, 1, 43, 45, 40, 40, 39, 43])
val_idx = get_cv_idxs(len(idx)-cs-1)
md = ColumnarModelData.from_arrays('.', val_idx, xs, y, bs=512)
class CharLoopModel(nn.Module):
# This is an RNN!
def __init__(self, vocab_size, n_fac):
super().__init__()
self.e = nn.Embedding(vocab_size, n_fac)
self.l_in = nn.Linear(n_fac, n_hidden)
self.l_hidden = nn.Linear(n_hidden, n_hidden)
self.l_out = nn.Linear(n_hidden, vocab_size)
def forward(self, *cs):
bs = cs[0].size(0)
h = V(torch.zeros(bs, n_hidden).cuda())
for c in cs:
inp = F.relu(self.l_in(self.e(c)))
h = F.tanh(self.l_hidden(h+inp))
return F.log_softmax(self.l_out(h), dim=-1)
m = CharLoopModel(vocab_size, n_fac).cuda()
opt = optim.Adam(m.parameters(), 1e-2)
fit(m, md, 1, opt, F.nll_loss)
[ 0. 2.02986 1.99268]
set_lrs(opt, 0.001)
fit(m, md, 1, opt, F.nll_loss)
[ 0. 1.73588 1.75103]
class CharLoopConcatModel(nn.Module):
def __init__(self, vocab_size, n_fac):
super().__init__()
self.e = nn.Embedding(vocab_size, n_fac)
self.l_in = nn.Linear(n_fac+n_hidden, n_hidden)
self.l_hidden = nn.Linear(n_hidden, n_hidden)
self.l_out = nn.Linear(n_hidden, vocab_size)
def forward(self, *cs):
bs = cs[0].size(0)
h = V(torch.zeros(bs, n_hidden).cuda())
for c in cs:
inp = torch.cat((h, self.e(c)), 1)
inp = F.relu(self.l_in(inp))
h = F.tanh(self.l_hidden(inp))
return F.log_softmax(self.l_out(h), dim=-1)
m = CharLoopConcatModel(vocab_size, n_fac).cuda()
opt = optim.Adam(m.parameters(), 1e-3)
it = iter(md.trn_dl)
*xs,yt = next(it)
t = m(*V(xs))
fit(m, md, 1, opt, F.nll_loss)
[ 0. 1.81654 1.78501]
set_lrs(opt, 1e-4)
fit(m, md, 1, opt, F.nll_loss)
[ 0. 1.69008 1.69936]
def get_next(inp):
idxs = T(np.array([char_indices[c] for c in inp]))
p = m(*VV(idxs))
i = np.argmax(to_np(p))
return chars[i]
get_next('for thos')
'e'
get_next('part of ')
't'
get_next('queens a')
'n'
class CharRnn(nn.Module):
def __init__(self, vocab_size, n_fac):
super().__init__()
self.e = nn.Embedding(vocab_size, n_fac)
self.rnn = nn.RNN(n_fac, n_hidden)
self.l_out = nn.Linear(n_hidden, vocab_size)
def forward(self, *cs):
bs = cs[0].size(0)
h = V(torch.zeros(1, bs, n_hidden))
inp = self.e(torch.stack(cs))
outp,h = self.rnn(inp, h)
return F.log_softmax(self.l_out(outp[-1]), dim=-1)
m = CharRnn(vocab_size, n_fac).cuda()
opt = optim.Adam(m.parameters(), 1e-3)
it = iter(md.trn_dl)
*xs,yt = next(it)
t = m.e(V(torch.stack(xs)))
t.size()
torch.Size([8, 512, 42])
ht = V(torch.zeros(1, 512,n_hidden))
outp, hn = m.rnn(t, ht)
outp.size(), hn.size()
(torch.Size([8, 512, 256]), torch.Size([1, 512, 256]))
t = m(*V(xs)); t.size()
torch.Size([512, 85])
fit(m, md, 4, opt, F.nll_loss)
[ 0. 1.86065 1.84255]
[ 1. 1.68014 1.67387]
[ 2. 1.58828 1.59169]
[ 3. 1.52989 1.54942]
set_lrs(opt, 1e-4)
fit(m, md, 2, opt, F.nll_loss)
[ 0. 1.46841 1.50966]
[ 1. 1.46482 1.5039 ]
def get_next(inp):
idxs = T(np.array([char_indices[c] for c in inp]))
p = m(*VV(idxs))
i = np.argmax(to_np(p))
return chars[i]
get_next('for thos')
'e'
def get_next_n(inp, n):
res = inp
for i in range(n):
c = get_next(inp)
res += c
inp = inp[1:]+c
return res
get_next_n('for thos', 40)
'for those the same the same the same the same th'
Let's take non-overlapping sets of characters this time
c_in_dat = [[idx[i+j] for i in range(cs)] for j in range(0, len(idx)-cs-1, cs)]
Then create the exact same thing, offset by 1, as our labels
c_out_dat = [[idx[i+j] for i in range(cs)] for j in range(1, len(idx)-cs, cs)]
xs = np.stack(c_in_dat)
xs.shape
(75111, 8)
ys = np.stack(c_out_dat)
ys.shape
(75111, 8)
xs[:cs,:cs]
array([[40, 42, 29, 30, 25, 27, 29,  1],
       [ 1,  1, 43, 45, 40, 40, 39, 43],
       [33, 38, 31,  2, 73, 61, 54, 73],
       [ 2, 44, 71, 74, 73, 61,  2, 62],
       [72,  2, 54,  2, 76, 68, 66, 54],
       [67,  9,  9, 76, 61, 54, 73,  2],
       [73, 61, 58, 67, 24,  2, 33, 72],
       [ 2, 73, 61, 58, 71, 58,  2, 67]])
ys[:cs,:cs]
array([[42, 29, 30, 25, 27, 29,  1,  1],
       [ 1, 43, 45, 40, 40, 39, 43, 33],
       [38, 31,  2, 73, 61, 54, 73,  2],
       [44, 71, 74, 73, 61,  2, 62, 72],
       [ 2, 54,  2, 76, 68, 66, 54, 67],
       [ 9,  9, 76, 61, 54, 73,  2, 73],
       [61, 58, 67, 24,  2, 33, 72,  2],
       [73, 61, 58, 71, 58,  2, 67, 68]])
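A quick sanity check (not in the original notebook) that the labels really are the inputs shifted one character to the left:
assert (ys[0, :-1] == xs[0, 1:]).all()   # each ys row is its xs row shifted left by one...
assert ys[0, -1] == xs[1, 0]             # ...with the first character of the next window appended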
val_idx = get_cv_idxs(len(xs)-cs-1)
md = ColumnarModelData.from_arrays('.', val_idx, xs, ys, bs=512)
class CharSeqRnn(nn.Module):
def __init__(self, vocab_size, n_fac):
super().__init__()
self.e = nn.Embedding(vocab_size, n_fac)
self.rnn = nn.RNN(n_fac, n_hidden)
self.l_out = nn.Linear(n_hidden, vocab_size)
def forward(self, *cs):
bs = cs[0].size(0)
h = V(torch.zeros(1, bs, n_hidden))
inp = self.e(torch.stack(cs))
outp,h = self.rnn(inp, h)
return F.log_softmax(self.l_out(outp), dim=-1)
m = CharSeqRnn(vocab_size, n_fac).cuda()
opt = optim.Adam(m.parameters(), 1e-3)
it = iter(md.trn_dl)
*xst,yt = next(it)
def nll_loss_seq(inp, targ):
sl,bs,nh = inp.size()
targ = targ.transpose(0,1).contiguous().view(-1)
return F.nll_loss(inp.view(-1,nh), targ)
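A note on the shapes (illustrative, not in the original): the model returns log-probabilities of shape (sl, bs, vocab_size) - so the nh unpacked above is really the vocab size - while the DataLoader yields targets of shape (bs, sl). Transposing the targets and flattening both tensors lines each of the sl*bs predictions up with its target character before calling F.nll_loss:
t = m(*V(xst)); t.size(), yt.size()   # e.g. (torch.Size([8, 512, 85]), torch.Size([512, 8]))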
fit(m, md, 4, opt, nll_loss_seq)
[ 0. 2.59241 2.40251]
[ 1. 2.28474 2.19859]
[ 2. 2.13883 2.08836]
[ 3. 2.04892 2.01564]
set_lrs(opt, 1e-4)
fit(m, md, 1, opt, nll_loss_seq)
[ 0. 1.99819 2.00106]
m = CharSeqRnn(vocab_size, n_fac).cuda()
opt = optim.Adam(m.parameters(), 1e-2)
m.rnn.weight_hh_l0.data.copy_(torch.eye(n_hidden))
    1     0     0  ...      0     0     0
    0     1     0  ...      0     0     0
    0     0     1  ...      0     0     0
       ...               ⋱        ...
    0     0     0  ...      1     0     0
    0     0     0  ...      0     1     0
    0     0     0  ...      0     0     1
[torch.cuda.FloatTensor of size 256x256 (GPU 0)]
fit(m, md, 4, opt, nll_loss_seq)
[ 0. 2.39428 2.21111]
[ 1. 2.10381 2.03275]
[ 2. 1.99451 1.96393]
[ 3. 1.93492 1.91763]
set_lrs(opt, 1e-3)
fit(m, md, 4, opt, nll_loss_seq)
[ 0. 1.84035 1.85742]
[ 1. 1.82896 1.84887]
[ 2. 1.81879 1.84281]
[ 3. 1.81337 1.83801]
from torchtext import vocab, data
from fastai.nlp import *
from fastai.lm_rnn import *
PATH='data/nietzsche/'
TRN_PATH = 'trn/'
VAL_PATH = 'val/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'
# Note: The student needs to practice her shell skills and prepare her own dataset before proceeding:
# - trn/trn.txt (first 80% of nietzsche.txt)
# - val/val.txt (last 20% of nietzsche.txt)
%ls {PATH}
models/ nietzsche.txt trn/ val/
%ls {PATH}trn
trn.txt
TEXT = data.Field(lower=True, tokenize=list)
bs=64; bptt=8; n_fac=42; n_hidden=256
FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)
md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=3)
len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)
(963, 56, 1, 493747)
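Roughly: 963 is the number of (bptt x bs) batches per epoch (≈ 493747 / (8*64)), 56 is the vocabulary size torchtext built (md.nt, after min_freq=3), 1 is the number of text fields in the training set (a single long stream), and 493747 is the length of that stream in tokens - here, characters.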
class CharSeqStatefulRnn(nn.Module):
def __init__(self, vocab_size, n_fac, bs):
self.vocab_size = vocab_size
super().__init__()
self.e = nn.Embedding(vocab_size, n_fac)
self.rnn = nn.RNN(n_fac, n_hidden)
self.l_out = nn.Linear(n_hidden, vocab_size)
self.init_hidden(bs)
def forward(self, cs):
bs = cs[0].size(0)
if self.h.size(1) != bs: self.init_hidden(bs)
outp,h = self.rnn(self.e(cs), self.h)
self.h = repackage_var(h)
return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)
def init_hidden(self, bs): self.h = V(torch.zeros(1, bs, n_hidden))
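repackage_var keeps the values of the hidden state between minibatches but throws away its history, so backpropagation is truncated at each batch boundary (otherwise the graph would stretch back to the start of the corpus and memory use would grow without bound). A minimal sketch of the idea (the real fastai 0.7 helper also handles pre-0.4 Variables):
def repackage_var_sketch(h):
    # Detach a hidden state - or a tuple of states, as an LSTM returns - from its history.
    if isinstance(h, tuple): return tuple(repackage_var_sketch(v) for v in h)
    return h.detach()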
m = CharSeqStatefulRnn(md.nt, n_fac, 512).cuda()
opt = optim.Adam(m.parameters(), 1e-3)
fit(m, md, 4, opt, F.nll_loss)
[ 0. 1.81983 1.81247]
[ 1. 1.63097 1.66228]
[ 2. 1.54433 1.57824]
[ 3. 1.48563 1.54505]
set_lrs(opt, 1e-4)
fit(m, md, 4, opt, F.nll_loss)
[ 0. 1.4187 1.50374]
[ 1. 1.41492 1.49391]
[ 2. 1.41001 1.49339]
[ 3. 1.40756 1.486 ]
# From the pytorch source
def RNNCell(input, hidden, w_ih, w_hh, b_ih, b_hh):
return F.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))
class CharSeqStatefulRnn2(nn.Module):
def __init__(self, vocab_size, n_fac, bs):
super().__init__()
self.vocab_size = vocab_size
self.e = nn.Embedding(vocab_size, n_fac)
self.rnn = nn.RNNCell(n_fac, n_hidden)
self.l_out = nn.Linear(n_hidden, vocab_size)
self.init_hidden(bs)
def forward(self, cs):
bs = cs[0].size(0)
if self.h.size(1) != bs: self.init_hidden(bs)
outp = []
o = self.h
for c in cs:
o = self.rnn(self.e(c), o)
outp.append(o)
outp = self.l_out(torch.stack(outp))
self.h = repackage_var(o)
return F.log_softmax(outp, dim=-1).view(-1, self.vocab_size)
def init_hidden(self, bs): self.h = V(torch.zeros(1, bs, n_hidden))
m = CharSeqStatefulRnn2(md.nt, n_fac, 512).cuda()
opt = optim.Adam(m.parameters(), 1e-3)
fit(m, md, 4, opt, F.nll_loss)
[ 0. 1.81013 1.7969 ]
[ 1. 1.62515 1.65346]
[ 2. 1.53913 1.58065]
[ 3. 1.48698 1.54217]
class CharSeqStatefulGRU(nn.Module):
def __init__(self, vocab_size, n_fac, bs):
super().__init__()
self.vocab_size = vocab_size
self.e = nn.Embedding(vocab_size, n_fac)
self.rnn = nn.GRU(n_fac, n_hidden)
self.l_out = nn.Linear(n_hidden, vocab_size)
self.init_hidden(bs)
def forward(self, cs):
bs = cs[0].size(0)
if self.h.size(1) != bs: self.init_hidden(bs)
outp,h = self.rnn(self.e(cs), self.h)
self.h = repackage_var(h)
return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)
def init_hidden(self, bs): self.h = V(torch.zeros(1, bs, n_hidden))
# From the pytorch source code - for reference
def GRUCell(input, hidden, w_ih, w_hh, b_ih, b_hh):
gi = F.linear(input, w_ih, b_ih)
gh = F.linear(hidden, w_hh, b_hh)
i_r, i_i, i_n = gi.chunk(3, 1)
h_r, h_i, h_n = gh.chunk(3, 1)
resetgate = F.sigmoid(i_r + h_r)
inputgate = F.sigmoid(i_i + h_i)
newgate = F.tanh(i_n + resetgate * h_n)
return newgate + inputgate * (hidden - newgate)
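The return value is the standard GRU update written in a rearranged form; a quick numeric check of the identity (illustrative only):
# newgate + inputgate*(hidden - newgate) is the usual interpolation (1 - inputgate)*newgate + inputgate*hidden
z, n, h = 0.25, 0.8, 0.2
assert abs((n + z * (h - n)) - ((1 - z) * n + z * h)) < 1e-9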
m = CharSeqStatefulGRU(md.nt, n_fac, 512).cuda()
opt = optim.Adam(m.parameters(), 1e-3)
fit(m, md, 6, opt, F.nll_loss)
[ 0. 1.68409 1.67784]
[ 1. 1.49813 1.52661]
[ 2. 1.41674 1.46769]
[ 3. 1.36359 1.43818]
[ 4. 1.33223 1.41777]
[ 5. 1.30217 1.40511]
set_lrs(opt, 1e-4)
fit(m, md, 3, opt, F.nll_loss)
[ 0. 1.22708 1.36926]
[ 1. 1.21948 1.3696 ]
[ 2. 1.22541 1.36969]
from fastai import sgdr
n_hidden=512
class CharSeqStatefulLSTM(nn.Module):
def __init__(self, vocab_size, n_fac, bs, nl):
super().__init__()
self.vocab_size,self.nl = vocab_size,nl
self.e = nn.Embedding(vocab_size, n_fac)
self.rnn = nn.LSTM(n_fac, n_hidden, nl, dropout=0.5)
self.l_out = nn.Linear(n_hidden, vocab_size)
self.init_hidden(bs)
def forward(self, cs):
bs = cs[0].size(0)
if self.h[0].size(1) != bs: self.init_hidden(bs)
outp,h = self.rnn(self.e(cs), self.h)
self.h = repackage_var(h)
return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)
def init_hidden(self, bs):
self.h = (V(torch.zeros(self.nl, bs, n_hidden)),
V(torch.zeros(self.nl, bs, n_hidden)))
m = CharSeqStatefulLSTM(md.nt, n_fac, 512, 2).cuda()
lo = LayerOptimizer(optim.Adam, m, 1e-2, 1e-5)
os.makedirs(f'{PATH}models', exist_ok=True)
fit(m, md, 2, lo.opt, F.nll_loss)
[ 0. 1.72032 1.64016]
[ 1. 1.62891 1.58176]
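With cycle_mult=2, each cosine-annealing restart doubles the cycle length, so the 2**4-1 = 15 epochs in the next fit are exactly four cycles of 1 + 2 + 4 + 8 epochs (and the 2**6-1 = 63 epochs further down are six cycles: 1 + 2 + 4 + 8 + 16 + 32).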
on_end = lambda sched, cycle: save_model(m, f'{PATH}models/cyc_{cycle}')
cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2, on_cycle_end=on_end)]
fit(m, md, 2**4-1, lo.opt, F.nll_loss, callbacks=cb)
[ 0. 1.47969 1.4472 ]
[ 1. 1.51411 1.46612]
[ 2. 1.412 1.39909]
[ 3. 1.53689 1.48337]
[ 4. 1.47375 1.43169]
[ 5. 1.39828 1.37963]
[ 6. 1.34546 1.35795]
[ 7. 1.51999 1.47165]
[ 8. 1.48992 1.46146]
[ 9. 1.45492 1.42829]
[ 10. 1.42027 1.39028]
[ 11. 1.3814 1.36539]
[ 12. 1.33895 1.34178]
[ 13. 1.30737 1.32871]
[ 14. 1.28244 1.31518]
on_end = lambda sched, cycle: save_model(m, f'{PATH}models/cyc_{cycle}')
cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2, on_cycle_end=on_end)]
fit(m, md, 2**6-1, lo.opt, F.nll_loss, callbacks=cb)
[ 0. 1.46053 1.43462]
[ 1. 1.51537 1.47747]
[ 2. 1.39208 1.38293]
[ 3. 1.53056 1.49371]
[ 4. 1.46812 1.43389]
[ 5. 1.37624 1.37523]
[ 6. 1.3173 1.34022]
[ 7. 1.51783 1.47554]
[ 8. 1.4921 1.45785]
[ 9. 1.44843 1.42215]
[ 10. 1.40948 1.40858]
[ 11. 1.37098 1.36648]
[ 12. 1.32255 1.33842]
[ 13. 1.28243 1.31106]
[ 14. 1.25031 1.2918 ]
[ 15. 1.49236 1.45316]
[ 16. 1.46041 1.43622]
[ 17. 1.45043 1.4498 ]
[ 18. 1.43331 1.41297]
[ 19. 1.43841 1.41704]
[ 20. 1.41536 1.40521]
[ 21. 1.39829 1.37656]
[ 22. 1.37001 1.36891]
[ 23. 1.35469 1.35909]
[ 24. 1.32202 1.34228]
[ 25. 1.29972 1.32256]
[ 26. 1.28007 1.30903]
[ 27. 1.24503 1.29125]
[ 28. 1.22261 1.28316]
[ 29. 1.20563 1.27397]
[ 30. 1.18764 1.27178]
[ 31. 1.18114 1.26694]
[ 32. 1.44344 1.42405]
[ 33. 1.43344 1.41616]
[ 34. 1.4346 1.40442]
[ 35. 1.42152 1.41359]
[ 36. 1.42072 1.40835]
[ 37. 1.41732 1.40498]
[ 38. 1.41268 1.395 ]
[ 39. 1.40725 1.39433]
[ 40. 1.40181 1.39864]
[ 41. 1.38621 1.37549]
[ 42. 1.3838 1.38587]
[ 43. 1.37644 1.37118]
[ 44. 1.36287 1.36211]
[ 45. 1.35942 1.36145]
[ 46. 1.34712 1.34924]
[ 47. 1.32994 1.34884]
[ 48. 1.32788 1.33387]
[ 49. 1.31553 1.342 ]
[ 50. 1.30088 1.32435]
[ 51. 1.28446 1.31166]
[ 52. 1.27058 1.30807]
[ 53. 1.26271 1.29935]
[ 54. 1.24351 1.28942]
[ 55. 1.23119 1.2838 ]
[ 56. 1.2086 1.28364]
[ 57. 1.19742 1.27375]
[ 58. 1.18127 1.26758]
[ 59. 1.17475 1.26858]
[ 60. 1.15349 1.25999]
[ 61. 1.14718 1.25779]
[ 62. 1.13174 1.2524 ]
def get_next(inp):
idxs = TEXT.numericalize(inp)
p = m(VV(idxs.transpose(0,1)))
r = torch.multinomial(p[-1].exp(), 1)
return TEXT.vocab.itos[to_np(r)[0]]
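Unlike the earlier get_next, this version samples the next character from the predicted distribution with torch.multinomial instead of always taking the argmax, which is what stops the generated text below from getting stuck in a loop like 'the same the same'.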
get_next('for thos')
'e'
def get_next_n(inp, n):
res = inp
for i in range(n):
c = get_next(inp)
res += c
inp = inp[1:]+c
return res
print(get_next_n('for thos', 400))
for those the skemps), or imaginates, though they deceives. it should so each ourselvess and new present, step absolutely for the science." the contradity and measuring, the whole! 293. perhaps, that every life a values of blood of intercourse when it senses there is unscrupulus, his very rights, and still impulse, love? just after that thereby how made with the way anything, and set for harmless philos