#hide
from utils import *
#hide
from fastai2.text.all import *
path = untar_data(URLs.HUMAN_NUMBERS)
lines = L()
with open(path/'train.txt') as f: lines += L(*f.readlines())
with open(path/'valid.txt') as f: lines += L(*f.readlines())
text = ' . '.join([l.strip() for l in lines])
tokens = text.split(' ')
vocab = L(*tokens).unique()
word2idx = {w:i for i,w in enumerate(vocab)}
nums = L(word2idx[i] for i in tokens)
def group_chunks(ds, bs):
    m = len(ds) // bs
    new_ds = L()
    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))
    return new_ds
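To see what `group_chunks` does, here is a minimal sketch on a toy dataset (the sizes below are assumed, not from the original notebook): with 15 items and a batch size of 3, each of the 3 stream positions gets its own contiguous chunk of `m=5` items, interleaved so that consecutive batches line up for truncated backpropagation through time.

# Assumed toy example: each batch position walks through its own contiguous chunk
toy = L(range(15))
group_chunks(toy, 3)
# expected: [0, 5, 10, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14]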
sl,bs = 16,64
seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1]))
         for i in range(0,len(nums)-sl-1,sl))
cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),
                             group_chunks(seqs[cut:], bs),
                             bs=bs, drop_last=True, shuffle=False)
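As a quick sanity check (assuming the standard fastai `one_batch` method on the training `DataLoader`), each batch should contain `bs` sequences of `sl` tokens, with the targets shifted one token to the right:

xb,yb = dls.train.one_batch()
xb.shape, yb.shape   # expected: (torch.Size([64, 16]), torch.Size([64, 16]))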
class LMModel5(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.RNN(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = torch.zeros(n_layers, bs, n_hidden)

    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()
        return self.h_o(res)

    def reset(self): self.h.zero_()
learn = Learner(dls, LMModel5(len(vocab), 64, 2), loss_func=CrossEntropyLossFlat(), metrics=accuracy, cbs=ModelReseter)
learn.fit_one_cycle(15, 3e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 3.055853 | 2.591640 | 0.437907 | 00:01 |
1 | 2.162359 | 1.787310 | 0.471598 | 00:01 |
2 | 1.710663 | 1.941807 | 0.321777 | 00:01 |
3 | 1.520783 | 1.999726 | 0.312012 | 00:01 |
4 | 1.330846 | 2.012902 | 0.413249 | 00:01 |
5 | 1.163297 | 1.896192 | 0.450684 | 00:01 |
6 | 1.033813 | 2.005209 | 0.434814 | 00:01 |
7 | 0.919090 | 2.047083 | 0.456706 | 00:01 |
8 | 0.822939 | 2.068031 | 0.468831 | 00:01 |
9 | 0.750180 | 2.136064 | 0.475098 | 00:01 |
10 | 0.695120 | 2.139140 | 0.485433 | 00:01 |
11 | 0.655752 | 2.155081 | 0.493652 | 00:01 |
12 | 0.629650 | 2.162583 | 0.498535 | 00:01 |
13 | 0.613583 | 2.171649 | 0.491048 | 00:01 |
14 | 0.604309 | 2.180355 | 0.487874 | 00:01 |
class LSTMCell(Module):
    def __init__(self, ni, nh):
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h,c = state
        # Concatenate the previous hidden state and the new input along the feature axis
        h = torch.cat([h, input], dim=1)
        forget = torch.sigmoid(self.forget_gate(h))
        c = c * forget
        inp = torch.sigmoid(self.input_gate(h))
        cell = torch.tanh(self.cell_gate(h))
        c = c + inp * cell
        out = torch.sigmoid(self.output_gate(h))
        h = out * torch.tanh(c)
        return h, (h,c)
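A quick shape check on random tensors (the toy sizes below are assumed, not from the original notebook) confirms the cell maps a batch of inputs plus an `(h, c)` state to a new hidden state and state pair of hidden size `nh`:

ni, nh, bsz = 3, 5, 4          # assumed toy sizes for the check
cell = LSTMCell(ni, nh)
x  = torch.randn(bsz, ni)
h0 = torch.zeros(bsz, nh)
c0 = torch.zeros(bsz, nh)
out, (h1, c1) = cell(x, (h0, c0))
out.shape, h1.shape, c1.shape   # expected: three tensors of shape (4, 5)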
class LSTMCell(Module):
    def __init__(self, ni, nh):
        self.ih = nn.Linear(ni,4*nh)
        self.hh = nn.Linear(nh,4*nh)

    def forward(self, input, state):
        h,c = state
        # One big multiplication for all the gates is better than 4 smaller ones
        gates = (self.ih(input) + self.hh(h)).chunk(4, 1)
        ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])
        cellgate = gates[3].tanh()
        c = (forgetgate*c) + (ingate*cellgate)
        h = outgate * c.tanh()
        return h, (h,c)
t = torch.arange(0,10); t
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
t.chunk(2)
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
class LMModel6(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # nn.LSTM keeps two states (hidden and cell), each of shape (n_layers, bs, n_hidden)
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]

    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = [h_.detach() for h_ in h]
        return self.h_o(res)

    def reset(self):
        for h in self.h: h.zero_()
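As a hedged check (the shapes below follow from `bs=64`, `n_hidden=64`, and `n_layers=2` used in this notebook, with a random batch of token indices standing in for real data), after one forward pass the detached state should be two tensors, one hidden state and one cell state, each of shape `(n_layers, bs, n_hidden)`:

model = LMModel6(len(vocab), 64, 2)
xb = torch.randint(0, len(vocab), (bs, sl))   # assumed random batch, just for the shape check
_ = model(xb)
[h_.shape for h_ in model.h]   # expected: [torch.Size([2, 64, 64]), torch.Size([2, 64, 64])]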
learn = Learner(dls, LMModel6(len(vocab), 64, 2), loss_func=CrossEntropyLossFlat(), metrics=accuracy, cbs=ModelReseter)
learn.fit_one_cycle(15, 1e-2)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 3.000821 | 2.663942 | 0.438314 | 00:02 |
1 | 2.139642 | 2.184780 | 0.240479 | 00:02 |
2 | 1.607275 | 1.812682 | 0.439779 | 00:02 |
3 | 1.347711 | 1.830982 | 0.497477 | 00:02 |
4 | 1.123113 | 1.937766 | 0.594401 | 00:02 |
5 | 0.852042 | 2.012127 | 0.631592 | 00:02 |
6 | 0.565494 | 1.312742 | 0.725749 | 00:02 |
7 | 0.347445 | 1.297934 | 0.711263 | 00:02 |
8 | 0.208191 | 1.441269 | 0.731201 | 00:02 |
9 | 0.126335 | 1.569952 | 0.737305 | 00:02 |
10 | 0.079761 | 1.427187 | 0.754150 | 00:02 |
11 | 0.052990 | 1.494990 | 0.745117 | 00:02 |
12 | 0.039008 | 1.393731 | 0.757894 | 00:02 |
13 | 0.031502 | 1.373210 | 0.758464 | 00:02 |
14 | 0.028068 | 1.368083 | 0.758464 | 00:02 |
class Dropout(Module):
    def __init__(self, p): self.p = p
    def forward(self, x):
        if not self.training: return x
        mask = x.new(*x.shape).bernoulli_(1-self.p)
        return x * mask.div_(1-self.p)
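A small hedged illustration (the toy tensor and probability are chosen here, not taken from the original notebook): in training mode roughly a proportion `p` of the activations are zeroed and the survivors are scaled by `1/(1-p)`, so the expected value of each activation is unchanged; in eval mode the input passes through untouched.

drop = Dropout(0.5)
x = torch.ones(2, 8)
drop.train(); print(drop(x))   # a mix of 0.0 and 2.0, about half of each in expectation
drop.eval();  print(drop(x))   # identical to x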
class LMModel7(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers, p):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.drop = nn.Dropout(p)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # Weight tying: the output layer shares its weights with the input embedding
        self.h_o.weight = self.i_h.weight
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]

    def forward(self, x):
        raw,h = self.rnn(self.i_h(x), self.h)
        out = self.drop(raw)
        self.h = [h_.detach() for h_ in h]
        return self.h_o(out),raw,out

    def reset(self):
        for h in self.h: h.zero_()
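`LMModel7` returns three things, the decoded output plus the raw and dropped-out LSTM activations, because the `RNNRegularizer` callback uses the extra activations for activation regularization (AR) and temporal activation regularization (TAR). Here is a minimal sketch of what those two penalties compute, assuming `raw` and `out` have shape `(bs, sl, n_hidden)`; the helper name is hypothetical, not a fastai API:

def rnn_regularizer_sketch(loss, raw, out, alpha=2., beta=1.):
    # AR: penalize large dropped-out activations
    loss = loss + alpha * out.float().pow(2).mean()
    # TAR: penalize large changes in the raw activations between consecutive timesteps
    if raw.size(1) > 1:
        loss = loss + beta * (raw[:,1:] - raw[:,:-1]).float().pow(2).mean()
    return loss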
learn = Learner(dls, LMModel7(len(vocab), 64, 2, 0.5),
                loss_func=CrossEntropyLossFlat(), metrics=accuracy,
                cbs=[ModelReseter, RNNRegularizer(alpha=2, beta=1)])
learn = TextLearner(dls, LMModel7(len(vocab), 64, 2, 0.4),
                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)
learn.fit_one_cycle(15, 1e-2, wd=0.1)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 2.693885 | 2.013484 | 0.466634 | 00:02 |
1 | 1.685549 | 1.187310 | 0.629313 | 00:02 |
2 | 0.973307 | 0.791398 | 0.745605 | 00:02 |
3 | 0.555823 | 0.640412 | 0.794108 | 00:02 |
4 | 0.351802 | 0.557247 | 0.836100 | 00:02 |
5 | 0.244986 | 0.594977 | 0.807292 | 00:02 |
6 | 0.192231 | 0.511690 | 0.846761 | 00:02 |
7 | 0.162456 | 0.520370 | 0.858073 | 00:02 |
8 | 0.142664 | 0.525918 | 0.842285 | 00:02 |
9 | 0.128493 | 0.495029 | 0.858073 | 00:02 |
10 | 0.117589 | 0.464236 | 0.867188 | 00:02 |
11 | 0.109808 | 0.466550 | 0.869303 | 00:02 |
12 | 0.104216 | 0.455151 | 0.871826 | 00:02 |
13 | 0.100271 | 0.452659 | 0.873617 | 00:02 |
14 | 0.098121 | 0.458372 | 0.869385 | 00:02 |