%matplotlib inline %reload_ext autoreload %autoreload 2 from fastai.text import * PATH = Path('data/translate') TMP_PATH = PATH/'tmp' TMP_PATH.mkdir(exist_ok=True) fname='giga-fren.release2.fixed' en_fname = PATH/f'{fname}.en' fr_fname = PATH/f'{fname}.fr' re_eq = re.compile('^(Wh[^?.!]+\?)') re_fq = re.compile('^([^?.!]+\?)') lines = ((re_eq.search(eq), re_fq.search(fq)) for eq, fq in zip(open(en_fname, encoding='utf-8'), open(fr_fname, encoding='utf-8'))) qs = [(e.group(), f.group()) for e,f in lines if e and f] pickle.dump(qs, (PATH/'fr-en-qs.pkl').open('wb')) qs = pickle.load((PATH/'fr-en-qs.pkl').open('rb')) qs[:5], len(qs) en_qs,fr_qs = zip(*qs) en_tok = Tokenizer.proc_all_mp(partition_by_cores(en_qs)) fr_tok = Tokenizer.proc_all_mp(partition_by_cores(fr_qs), 'fr') en_tok[0], fr_tok[0] np.percentile([len(o) for o in en_tok], 90), np.percentile([len(o) for o in fr_tok], 90) keep = np.array([len(o)<30 for o in en_tok]) en_tok = np.array(en_tok)[keep] fr_tok = np.array(fr_tok)[keep] pickle.dump(en_tok, (PATH/'en_tok.pkl').open('wb')) pickle.dump(fr_tok, (PATH/'fr_tok.pkl').open('wb')) en_tok = pickle.load((PATH/'en_tok.pkl').open('rb')) fr_tok = pickle.load((PATH/'fr_tok.pkl').open('rb')) def toks2ids(tok,pre): freq = Counter(p for o in tok for p in o) itos = [o for o,c in freq.most_common(40000)] itos.insert(0, '_bos_') itos.insert(1, '_pad_') itos.insert(2, '_eos_') itos.insert(3, '_unk') stoi = collections.defaultdict(lambda: 3, {v:k for k,v in enumerate(itos)}) ids = np.array([([stoi[o] for o in p] + [2]) for p in tok]) np.save(TMP_PATH/f'{pre}_ids.npy', ids) pickle.dump(itos, open(TMP_PATH/f'{pre}_itos.pkl', 'wb')) return ids,itos,stoi en_ids,en_itos,en_stoi = toks2ids(en_tok,'en') fr_ids,fr_itos,fr_stoi = toks2ids(fr_tok,'fr') def load_ids(pre): ids = np.load(TMP_PATH/f'{pre}_ids.npy') itos = pickle.load(open(TMP_PATH/f'{pre}_itos.pkl', 'rb')) stoi = collections.defaultdict(lambda: 3, {v:k for k,v in enumerate(itos)}) return ids,itos,stoi 
en_ids,en_itos,en_stoi = load_ids('en')
fr_ids,fr_itos,fr_stoi = load_ids('fr')
[fr_itos[o] for o in fr_ids[0]], len(en_itos), len(fr_itos)

# ! pip install git+https://github.com/facebookresearch/fastText.git
import fastText as ft

en_vecs = ft.load_model(str((PATH/'wiki.en.bin')))
fr_vecs = ft.load_model(str((PATH/'wiki.fr.bin')))


def get_vecs(lang, ft_vecs):
    """Return {word: vector} for every word in a fastText model and pickle it
    to ``PATH/wiki.<lang>.pkl``."""
    vecd = {w:ft_vecs.get_word_vector(w) for w in ft_vecs.get_words()}
    with open(PATH/f'wiki.{lang}.pkl','wb') as fh:
        pickle.dump(vecd, fh)
    return vecd


en_vecd = get_vecs('en', en_vecs)
fr_vecd = get_vecs('fr', fr_vecs)
with open(PATH/'wiki.en.pkl','rb') as fh:
    en_vecd = pickle.load(fh)
with open(PATH/'wiki.fr.pkl','rb') as fh:
    fr_vecd = pickle.load(fh)

# English fastText vocabulary sorted by ascending frequency (inspection only).
ft_words = en_vecs.get_words(include_freq=True)
ft_word_dict = {k:v for k,v in zip(*ft_words)}
ft_words = sorted(ft_word_dict.keys(), key=lambda x: ft_word_dict[x])
len(ft_words)

dim_en_vec = len(en_vecd[','])
dim_fr_vec = len(fr_vecd[','])
dim_en_vec,dim_fr_vec
# Note: this rebinds en_vecs from the fastText model to a stacked ndarray.
en_vecs = np.stack(list(en_vecd.values()))
en_vecs.mean(),en_vecs.std()

# Despite the "_90" names, these are the 99th / 97th length percentiles.
enlen_90 = int(np.percentile([len(o) for o in en_ids], 99))
frlen_90 = int(np.percentile([len(o) for o in fr_ids], 97))
enlen_90,frlen_90


def _ragged(seqs):
    """Return a 1-D object ndarray with one element per sequence in ``seqs``.

    ``np.array`` on variable-length sequences raises ValueError on
    NumPy >= 1.24, so the object array is filled element by element.
    """
    arr = np.empty(len(seqs), dtype=object)
    for i, seq in enumerate(seqs):
        arr[i] = seq
    return arr


# Truncate every sentence to its language's length cap.
en_ids_tr = _ragged([o[:enlen_90] for o in en_ids])
fr_ids_tr = _ragged([o[:frlen_90] for o in fr_ids])


class Seq2SeqDataset(Dataset):
    """Pairs ``x[idx]`` with ``y[idx]`` (wrapped by fastai's ``A``)."""
    def __init__(self, x, y): self.x,self.y = x,y
    def __getitem__(self, idx): return A(self.x[idx], self.y[idx])
    def __len__(self): return len(self.x)


# Seeded random ~90/10 train/validation split.
np.random.seed(42)
trn_keep = np.random.rand(len(en_ids_tr))>0.1
en_trn,fr_trn = en_ids_tr[trn_keep],fr_ids_tr[trn_keep]
en_val,fr_val = en_ids_tr[~trn_keep],fr_ids_tr[~trn_keep]
len(en_trn),len(en_val)

# French is the input, English the target.
trn_ds = Seq2SeqDataset(fr_trn,en_trn)
val_ds = Seq2SeqDataset(fr_val,en_val)
bs=125
# Group by target length to reduce padding; pad_idx=1 is '_pad_' in the vocab.
trn_samp = SortishSampler(en_trn, key=lambda x: len(en_trn[x]), bs=bs)
val_samp = SortSampler(en_val, key=lambda x: len(en_val[x]))
trn_dl = DataLoader(trn_ds, bs, transpose=True, transpose_y=True, num_workers=1,
                    pad_idx=1, pre_pad=False, sampler=trn_samp)
val_dl = DataLoader(val_ds, int(bs*1.6), transpose=True, transpose_y=True, num_workers=1,
                    pad_idx=1, pre_pad=False, sampler=val_samp)
md = ModelData(PATH, trn_dl, val_dl)

it = iter(trn_dl)
its = [next(it) for i in range(5)]
[(len(x),len(y)) for x,y in its]


def create_emb(vecs, itos, em_sz):
    """Build an embedding (padding_idx=1) initialised from pretrained vectors.

    Rows for words found in ``vecs`` are set to the vector scaled by 3. Words
    missing from ``vecs`` keep nn.Embedding's default init; their count and a
    small sample are printed.
    """
    emb = nn.Embedding(len(itos), em_sz, padding_idx=1)
    wgts = emb.weight.data
    miss = []
    for i,w in enumerate(itos):
        try:
            wgts[i] = torch.from_numpy(vecs[w]*3)
        except KeyError:
            # Only "word not in pretrained vocab" is expected; let anything else surface.
            miss.append(w)
    print(len(miss),miss[5:10])
    return emb


nh,nl = 256,2


class Seq2SeqRNN(nn.Module):
    """GRU encoder -> GRU decoder with greedy decoding of up to ``out_sl`` steps."""
    def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):
        super().__init__()
        self.nl,self.nh,self.out_sl = nl,nh,out_sl
        self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)
        self.emb_enc_drop = nn.Dropout(0.15)
        self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)
        # Projects the encoder's final hidden state to the decoder's hidden size.
        self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)
        self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)
        self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(em_sz_dec, len(itos_dec))
        # Weight tying: the output layer reuses the decoder embedding tensor.
        self.out.weight.data = self.emb_dec.weight.data

    def forward(self, inp):
        sl,bs = inp.size()
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, h = self.gru_enc(emb, h)
        h = self.out_enc(h)
        dec_inp = V(torch.zeros(bs).long())  # id 0 == '_bos_'
        res = []
        for i in range(self.out_sl):
            emb = self.emb_dec(dec_inp).unsqueeze(0)
            outp, h = self.gru_dec(emb, h)
            outp = self.out(self.out_drop(outp[0]))
            res.append(outp)
            dec_inp = V(outp.data.max(1)[1])
            # Stop early once every sequence in the batch predicts padding.
            if (dec_inp==1).all(): break
        return torch.stack(res)

    def initHidden(self, bs): return V(torch.zeros(self.nl, bs, self.nh))


def seq2seq_loss(input, target):
    """Cross-entropy between (sl_in, bs, nc) logits and an (sl, bs) target.

    Logits are zero-padded or truncated along the time axis to the target's
    length. Padding positions are NOT ignored (``ignore_index`` is disabled).
    """
    sl,bs = target.size()
    sl_in,bs_in,nc = input.size()
    if sl>sl_in: input = F.pad(input, (0,0,0,0,0,sl-sl_in))
    input = input[:sl]
    return F.cross_entropy(input.view(-1,nc), target.view(-1))#, ignore_index=1)


opt_fn = partial(optim.Adam, betas=(0.8, 0.99))

rnn = Seq2SeqRNN(fr_vecd, fr_itos, dim_fr_vec, en_vecd, en_itos, dim_en_vec, nh, enlen_90)
learn = RNN_Learner(md, SingleModel(to_gpu(rnn)), opt_fn=opt_fn)
learn.crit = seq2seq_loss
learn.lr_find()
learn.sched.plot()
lr=3e-3
learn.fit(lr, 1, cycle_len=12, use_clr=(20,10))
learn.save('initial')
learn.load('initial')

# Print source / target / prediction for a few validation sentences.
x,y = next(iter(val_dl))
probs = learn.model(V(x))
preds = to_np(probs.max(2)[1])
for i in range(180,190):
    print(' '.join([fr_itos[o] for o in x[:,i] if o != 1]))
    print(' '.join([en_itos[o] for o in y[:,i] if o != 1]))
    print(' '.join([en_itos[o] for o in preds[:,i] if o!=1]))
    print()


class Seq2SeqRNN_Bidir(nn.Module):
    """Like Seq2SeqRNN but with a bidirectional encoder."""
    def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):
        super().__init__()
        self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)
        self.nl,self.nh,self.out_sl = nl,nh,out_sl
        self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25, bidirectional=True)
        self.out_enc = nn.Linear(nh*2, em_sz_dec, bias=False)
        self.drop_enc = nn.Dropout(0.05)
        self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)
        self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)
        self.emb_enc_drop = nn.Dropout(0.15)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(em_sz_dec, len(itos_dec))
        self.out.weight.data = self.emb_dec.weight.data

    def forward(self, inp):
        sl,bs = inp.size()
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, h = self.gru_enc(emb, h)
        # h is (nl*2, bs, nh): separate layers and directions, then concatenate
        # both directions per layer -> (nl, bs, 2*nh). Uses self.nl rather than
        # a hard-coded 2 so other layer counts work.
        h = h.view(self.nl,2,bs,-1).permute(0,2,1,3).contiguous().view(self.nl,bs,-1)
        h = self.out_enc(self.drop_enc(h))
        dec_inp = V(torch.zeros(bs).long())
        res = []
        for i in range(self.out_sl):
            emb = self.emb_dec(dec_inp).unsqueeze(0)
            outp, h = self.gru_dec(emb, h)
            outp = self.out(self.out_drop(outp[0]))
            res.append(outp)
            dec_inp = V(outp.data.max(1)[1])
            if (dec_inp==1).all(): break
        return torch.stack(res)

    def initHidden(self, bs): return V(torch.zeros(self.nl*2, bs, self.nh))


rnn = Seq2SeqRNN_Bidir(fr_vecd, fr_itos, dim_fr_vec, en_vecd, en_itos, dim_en_vec, nh, enlen_90)
learn = RNN_Learner(md, SingleModel(to_gpu(rnn)), opt_fn=opt_fn)
learn.crit = seq2seq_loss
learn.fit(lr, 1, cycle_len=12, use_clr=(20,10))
learn.save('bidir')


class Seq2SeqStepper(Stepper):
    """Stepper that passes the target into the model and anneals teacher forcing.

    ``pr_force`` goes linearly from 1.0 (epoch 0) to 0.1 (epoch 9), then 0.
    """
    def step(self, xs, y, epoch):
        self.m.pr_force = (10-epoch)*0.1 if epoch<10 else 0
        xtra = []
        output = self.m(*xs, y)
        if isinstance(output,tuple): output,*xtra = output
        self.opt.zero_grad()
        loss = raw_loss = self.crit(output, y)
        if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
        loss.backward()
        if self.clip:   # Gradient clipping
            nn.utils.clip_grad_norm(trainable_params_(self.m), self.clip)
        self.opt.step()
        # NOTE(review): `.data[0]` is the pre-0.4 PyTorch scalar idiom used throughout this file.
        return raw_loss.data[0]


class Seq2SeqRNN_TeacherForcing(nn.Module):
    """Seq2SeqRNN whose decoder, with probability ``pr_force``, feeds the true
    previous target token instead of its own prediction (training only)."""
    def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):
        super().__init__()
        self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)
        self.nl,self.nh,self.out_sl = nl,nh,out_sl
        self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)
        self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)
        self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)
        self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)
        self.emb_enc_drop = nn.Dropout(0.15)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(em_sz_dec, len(itos_dec))
        self.out.weight.data = self.emb_dec.weight.data
        self.pr_force = 1.

    def forward(self, inp, y=None):
        sl,bs = inp.size()
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, h = self.gru_enc(emb, h)
        h = self.out_enc(h)
        dec_inp = V(torch.zeros(bs).long())
        res = []
        for i in range(self.out_sl):
            emb = self.emb_dec(dec_inp).unsqueeze(0)
            outp, h = self.gru_dec(emb, h)
            outp = self.out(self.out_drop(outp[0]))
            res.append(outp)
            dec_inp = V(outp.data.max(1)[1])
            if (dec_inp==1).all(): break
            # Teacher forcing: replace the prediction with the true token y[i].
            if (y is not None) and (random.random()<self.pr_force):
                if i>=len(y): break
                dec_inp = y[i]
        return torch.stack(res)

    def initHidden(self, bs): return V(torch.zeros(self.nl, bs, self.nh))


rnn = Seq2SeqRNN_TeacherForcing(fr_vecd, fr_itos, dim_fr_vec, en_vecd, en_itos, dim_en_vec, nh, enlen_90)
learn = RNN_Learner(md, SingleModel(to_gpu(rnn)), opt_fn=opt_fn)
learn.crit = seq2seq_loss
learn.fit(lr, 1, cycle_len=12, use_clr=(20,10), stepper=Seq2SeqStepper)
learn.save('forcing')


def rand_t(*sz):
    """Normal random tensor of shape ``sz``, scaled by 1/sqrt(sz[0])."""
    return torch.randn(sz)/math.sqrt(sz[0])


def rand_p(*sz):
    """``rand_t`` wrapped as an ``nn.Parameter``."""
    return nn.Parameter(rand_t(*sz))


class Seq2SeqAttnRNN(nn.Module):
    """Teacher-forced seq2seq with additive attention over encoder outputs."""
    def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):
        super().__init__()
        self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)
        self.nl,self.nh,self.out_sl = nl,nh,out_sl
        self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25)
        self.out_enc = nn.Linear(nh, em_sz_dec, bias=False)
        self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)
        self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)
        self.emb_enc_drop = nn.Dropout(0.15)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(em_sz_dec, len(itos_dec))
        self.out.weight.data = self.emb_dec.weight.data
        # Attention parameters: score = V . tanh(enc_out @ W1 + l2(h_dec)).
        self.W1 = rand_p(nh, em_sz_dec)
        self.l2 = nn.Linear(em_sz_dec, em_sz_dec)
        self.l3 = nn.Linear(em_sz_dec+nh, em_sz_dec)
        self.V = rand_p(em_sz_dec)
        # forward() reads this when y is given; Seq2SeqStepper overwrites it per epoch.
        self.pr_force = 1.

    def forward(self, inp, y=None, ret_attn=False):
        sl,bs = inp.size()
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, h = self.gru_enc(emb, h)
        h = self.out_enc(h)
        dec_inp = V(torch.zeros(bs).long())
        res,attns = [],[]
        w1e = enc_out @ self.W1  # encoder part of the score, computed once
        for i in range(self.out_sl):
            w2h = self.l2(h[-1])  # top decoder layer's hidden state
            u = F.tanh(w1e + w2h)
            a = F.softmax(u @ self.V, 0)  # normalise over source positions
            attns.append(a)
            Xa = (a.unsqueeze(2) * enc_out).sum(0)  # attention-weighted context
            emb = self.emb_dec(dec_inp)
            wgt_enc = self.l3(torch.cat([emb, Xa], 1))
            outp, h = self.gru_dec(wgt_enc.unsqueeze(0), h)
            outp = self.out(self.out_drop(outp[0]))
            res.append(outp)
            dec_inp = V(outp.data.max(1)[1])
            if (dec_inp==1).all(): break
            if (y is not None) and (random.random()<self.pr_force):
                if i>=len(y): break
                dec_inp = y[i]
        res = torch.stack(res)
        if ret_attn: res = res,torch.stack(attns)
        return res

    def initHidden(self, bs): return V(torch.zeros(self.nl, bs, self.nh))


rnn = Seq2SeqAttnRNN(fr_vecd, fr_itos, dim_fr_vec, en_vecd, en_itos, dim_en_vec, nh, enlen_90)
learn = RNN_Learner(md, SingleModel(to_gpu(rnn)), opt_fn=opt_fn)
learn.crit = seq2seq_loss
lr=2e-3
learn.fit(lr, 1, cycle_len=15, use_clr=(20,10), stepper=Seq2SeqStepper)
learn.save('attn')
learn.load('attn')

x,y = next(iter(val_dl))
probs,attns = learn.model(V(x),ret_attn=True)
preds = to_np(probs.max(2)[1])
for i in range(180,190):
    print(' '.join([fr_itos[o] for o in x[:,i] if o != 1]))
    print(' '.join([en_itos[o] for o in y[:,i] if o != 1]))
    print(' '.join([en_itos[o] for o in preds[:,i] if o!=1]))
    print()

# Attention weights for validation sentence 180, one plot per decoding step.
attn = to_np(attns[...,180])
fig, axes = plt.subplots(3, 3, figsize=(15, 10))
for i,ax in enumerate(axes.flat):
    ax.plot(attn[i])


class Seq2SeqRNN_All(nn.Module):
    """Bidirectional encoder + attention + teacher forcing combined."""
    def __init__(self, vecs_enc, itos_enc, em_sz_enc, vecs_dec, itos_dec, em_sz_dec, nh, out_sl, nl=2):
        super().__init__()
        self.emb_enc = create_emb(vecs_enc, itos_enc, em_sz_enc)
        self.nl,self.nh,self.out_sl = nl,nh,out_sl
        self.gru_enc = nn.GRU(em_sz_enc, nh, num_layers=nl, dropout=0.25, bidirectional=True)
        self.out_enc = nn.Linear(nh*2, em_sz_dec, bias=False)
        self.drop_enc = nn.Dropout(0.25)
        self.emb_dec = create_emb(vecs_dec, itos_dec, em_sz_dec)
        self.gru_dec = nn.GRU(em_sz_dec, em_sz_dec, num_layers=nl, dropout=0.1)
        self.emb_enc_drop = nn.Dropout(0.15)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(em_sz_dec, len(itos_dec))
        self.out.weight.data = self.emb_dec.weight.data
        self.W1 = rand_p(nh*2, em_sz_dec)
        self.l2 = nn.Linear(em_sz_dec, em_sz_dec)
        self.l3 = nn.Linear(em_sz_dec+nh*2, em_sz_dec)
        self.V = rand_p(em_sz_dec)
        # forward() reads this when y is given; Seq2SeqStepper overwrites it per epoch.
        self.pr_force = 1.

    def forward(self, inp, y=None):
        sl,bs = inp.size()
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, h = self.gru_enc(emb, h)
        # (nl*2, bs, nh) -> (nl, bs, 2*nh): concatenate both directions per layer.
        h = h.view(self.nl,2,bs,-1).permute(0,2,1,3).contiguous().view(self.nl,bs,-1)
        h = self.out_enc(self.drop_enc(h))
        dec_inp = V(torch.zeros(bs).long())
        res,attns = [],[]
        w1e = enc_out @ self.W1
        for i in range(self.out_sl):
            w2h = self.l2(h[-1])
            u = F.tanh(w1e + w2h)
            a = F.softmax(u @ self.V, 0)
            attns.append(a)
            Xa = (a.unsqueeze(2) * enc_out).sum(0)
            emb = self.emb_dec(dec_inp)
            wgt_enc = self.l3(torch.cat([emb, Xa], 1))
            outp, h = self.gru_dec(wgt_enc.unsqueeze(0), h)
            outp = self.out(self.out_drop(outp[0]))
            res.append(outp)
            dec_inp = V(outp.data.max(1)[1])
            if (dec_inp==1).all(): break
            if (y is not None) and (random.random()<self.pr_force):
                if i>=len(y): break
                dec_inp = y[i]
        return torch.stack(res)

    def initHidden(self, bs): return V(torch.zeros(self.nl*2, bs, self.nh))


rnn = Seq2SeqRNN_All(fr_vecd, fr_itos, dim_fr_vec, en_vecd, en_itos, dim_en_vec, nh, enlen_90)
learn = RNN_Learner(md, SingleModel(to_gpu(rnn)), opt_fn=opt_fn)
learn.crit = seq2seq_loss
learn.fit(lr, 1, cycle_len=15, use_clr=(20,10), stepper=Seq2SeqStepper)

x,y = next(iter(val_dl))
probs = learn.model(V(x))
preds = to_np(probs.max(2)[1])
for i in range(180,190):
    print(' '.join([fr_itos[o] for o in x[:,i] if o != 1]))
    print(' '.join([en_itos[o] for o in y[:,i] if o != 1]))
    print(' '.join([en_itos[o] for o in preds[:,i] if o!=1]))
    print()