Based on: https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py
This notebook explores generating text from Nietzsche's writings, one character at a time.
At least 20 training epochs are required before the generated text starts sounding coherent.
It is recommended to run this notebook on a GPU, as recurrent networks are quite computationally intensive.
If you try this script on new data, make sure your corpus has at least ~100k characters. ~1M is better.
import conx as cx
Using TensorFlow backend.
ConX, version 3.7.5
cx.download('https://s3.amazonaws.com/text-datasets/nietzsche.txt')
Using cached https://s3.amazonaws.com/text-datasets/nietzsche.txt as './nietzsche.txt'.
text = open("nietzsche.txt").read().lower()
len(text)
600893
text[:100]
'preface\n\n\nsupposing that truth is a woman--what then? is there not ground\nfor suspecting that all ph'
chars = sorted(list(set(text)))
print('total unique chars:', len(chars))
total unique chars: 57
"".join(chars)
'\n !"\'(),-.0123456789:;=?[]_abcdefghijklmnopqrstuvwxyzäæéë'
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))
print("char to index:", char_indices)
print("index to char:", indices_char)
char to index: {'\n': 0, ' ': 1, '!': 2, '"': 3, "'": 4, '(': 5, ')': 6, ',': 7, '-': 8, '.': 9, '0': 10, '1': 11, '2': 12, '3': 13, '4': 14, '5': 15, '6': 16, '7': 17, '8': 18, '9': 19, ':': 20, ';': 21, '=': 22, '?': 23, '[': 24, ']': 25, '_': 26, 'a': 27, 'b': 28, 'c': 29, 'd': 30, 'e': 31, 'f': 32, 'g': 33, 'h': 34, 'i': 35, 'j': 36, 'k': 37, 'l': 38, 'm': 39, 'n': 40, 'o': 41, 'p': 42, 'q': 43, 'r': 44, 's': 45, 't': 46, 'u': 47, 'v': 48, 'w': 49, 'x': 50, 'y': 51, 'z': 52, 'ä': 53, 'æ': 54, 'é': 55, 'ë': 56}
index to char: {0: '\n', 1: ' ', 2: '!', 3: '"', 4: "'", 5: '(', 6: ')', 7: ',', 8: '-', 9: '.', 10: '0', 11: '1', 12: '2', 13: '3', 14: '4', 15: '5', 16: '6', 17: '7', 18: '8', 19: '9', 20: ':', 21: ';', 22: '=', 23: '?', 24: '[', 25: ']', 26: '_', 27: 'a', 28: 'b', 29: 'c', 30: 'd', 31: 'e', 32: 'f', 33: 'g', 34: 'h', 35: 'i', 36: 'j', 37: 'k', 38: 'l', 39: 'm', 40: 'n', 41: 'o', 42: 'p', 43: 'q', 44: 'r', 45: 's', 46: 't', 47: 'u', 48: 'v', 49: 'w', 50: 'x', 51: 'y', 52: 'z', 53: 'ä', 54: 'æ', 55: 'é', 56: 'ë'}
Cut the text into semi-redundant, overlapping sequences of maxlen + 1 characters, starting a new sequence every step (3) characters; the first maxlen characters of each sequence will serve as the input, and the final character as the target:
maxlen = 40
step = 3
sequences = []
for i in range(0, len(text) - maxlen - 1, step):
    sequences.append(text[i: i + maxlen + 1])
print('sequences:', len(sequences))
sequences: 200284
sequences[0:10]
['preface\n\n\nsupposing that truth is a woman', 'face\n\n\nsupposing that truth is a woman--w', 'e\n\n\nsupposing that truth is a woman--what', '\nsupposing that truth is a woman--what th', 'pposing that truth is a woman--what then?', 'sing that truth is a woman--what then? is', 'g that truth is a woman--what then? is th', 'hat truth is a woman--what then? is there', ' truth is a woman--what then? is there no', 'uth is a woman--what then? is there not g']
len(sequences[0])
41
(len(sequences), maxlen, len(chars))
(200284, 40, 57)
cx.onehot(2, 5)
[0, 0, 1, 0, 0]
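cx.onehot(index, size) returns a list of zeros with a single 1 at the given position. For readers without ConX, a minimal pure-Python equivalent (the name onehot here is just for illustration):

def onehot(index, size):
    # a list of `size` zeros with a 1 at position `index`
    vector = [0] * size
    vector[index] = 1
    return vector

assert onehot(2, 5) == [0, 0, 1, 0, 0]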
char_encode = {ch: cx.onehot(char_indices[ch], len(chars)) for ch in chars}
print(char_encode["a"])
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
inputs = []
targets = []
for s in range(len(sequences)):
    current = [char_encode[ch] for ch in sequences[s]]
    inputs.append(current[:-1])  # first maxlen characters, one-hot encoded
    targets.append(current[-1])  # final character, one-hot encoded
cx.shape(inputs)
(200284, 40, 57)
cx.shape(targets)
(200284, 57)
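As a sanity check, the first input pattern should decode back to the first 40 characters of the corpus, and its target to the 41st character (cx.argmax inverts the one-hot encoding):

seed = "".join(indices_char[cx.argmax(v)] for v in inputs[0])
print(repr(seed))                                 # 'preface\n\n\nsupposing that truth is a woma'
print(repr(indices_char[cx.argmax(targets[0])]))  # 'n'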
net = cx.Network("LSTM Text Generation")
net.add(
    cx.Layer("input", (maxlen, len(chars))),
    cx.LSTMLayer("lstm", 128),
    cx.Layer("output", len(chars), activation="softmax"),
)
net.connect()
net.compile(error="categorical_crossentropy", optimizer="RMSProp", lr=0.01)
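For reference, the equivalent model in the underlying Keras API (which ConX wraps), taken from the original script this notebook is based on; a sketch only, not a cell to run here:

from keras.models import Sequential
from keras.layers import Dense, LSTM
from keras.optimizers import RMSprop

model = Sequential()
model.add(LSTM(128, input_shape=(maxlen, len(chars))))
model.add(Dense(len(chars), activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer=RMSprop(lr=0.01))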
net.dataset.load(inputs=inputs, targets=targets)
net.dataset.summary()
_________________________________________________________________
LSTM Text Generation Dataset:
Patterns    Shape                 Range
=================================================================
inputs      (40, 57)              (0.0, 1.0)
targets     (57,)                 (0.0, 1.0)
=================================================================
Total patterns: 200284
   Training patterns: 200284
   Testing patterns: 0
_________________________________________________________________
net.dashboard()
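Training on all 200,284 patterns is slow, so for this demo we keep only the first 1% of them; chop(.99) drops the last 99% of the dataset: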
net.dataset.chop(.99)
net.dataset.summary()
_________________________________________________________________
LSTM Text Generation Dataset:
Patterns    Shape                 Range
=================================================================
inputs      (40, 57)              (0.0, 1.0)
targets     (57,)                 (0.0, 1.0)
=================================================================
Total patterns: 2003
   Training patterns: 2003
   Testing patterns: 0
_________________________________________________________________
"".join([indices_char[cx.argmax(v)] for v in net.dataset.inputs[0]])
probs = sorted(enumerate(net.propagate(net.dataset.inputs[0])),
key=lambda v: v[1], reverse=True)
[(indices_char[w[0]], round(w[1], 2)) for w in probs]
[('ë', 0.02), ("'", 0.02), ('m', 0.02), ('e', 0.02), ('t', 0.02), ('j', 0.02), ('s', 0.02), ('0', 0.02), ('l', 0.02), ('4', 0.02), ('é', 0.02), ('i', 0.02), ('?', 0.02), ('w', 0.02), ('r', 0.02), ('æ', 0.02), ('5', 0.02), ('(', 0.02), ('6', 0.02), ('"', 0.02), (';', 0.02), ('q', 0.02), ('-', 0.02), ('[', 0.02), ('3', 0.02), ('c', 0.02), ('p', 0.02), ('9', 0.02), (',', 0.02), ('_', 0.02), ('b', 0.02), ('y', 0.02), ('7', 0.02), (':', 0.02), ('ä', 0.02), ('.', 0.02), ('k', 0.02), ('2', 0.02), ('!', 0.02), ('x', 0.02), ('o', 0.02), ('n', 0.02), ('g', 0.02), ('f', 0.02), ('d', 0.02), (' ', 0.02), ('\n', 0.02), ('8', 0.02), ('z', 0.02), (')', 0.02), ('u', 0.02), ('v', 0.02), (']', 0.02), ('a', 0.02), ('1', 0.02), ('=', 0.02), ('h', 0.02)]
from IPython.display import clear_output
def on_epoch_end(network, epoch=None, logs=None):
    import io
    epoch = epoch if epoch is not None else network.epoch_count
    s = io.StringIO()
    s.write("\n")
    s.write('----- Generating text after Epoch: %d\n' % epoch)
    # pick a random 40-character seed from the corpus:
    start_index = cx.choice(len(text) - maxlen - 1)
    for diversity in [0.2, 0.5, 1.0, 1.2]:
        sentence = text[start_index: start_index + maxlen]
        s.write('----- diversity: %s\n' % diversity)
        s.write('----- Generating with seed: "' + sentence + '"\n\n')
        s.write(sentence)
        current = [char_encode[ch] for ch in sentence]
        for i in range(400):
            # predict the next-character distribution, sample from it
            # at this temperature, and slide the window forward by one:
            output = network.propagate(current)
            next_index = cx.choice(p=output, temperature=diversity, index=True)
            s.write(indices_char[next_index])
            next_char = char_encode[indices_char[next_index]]
            current = current[1:]
            current.append(next_char)
        s.write("\n")
    clear_output()
    print(s.getvalue())
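The temperature argument to cx.choice controls how adventurous the sampling is: temperatures below 1 sharpen the distribution toward the most likely characters, while temperatures above 1 flatten it toward uniform. The original Keras script implements the same re-weighting explicitly; a sketch of that helper (the 1e-8 guard against log(0) is an addition here):

import numpy as np

def sample(preds, temperature=1.0):
    # rescale log-probabilities by temperature, renormalize,
    # and draw one index from the re-weighted distribution
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds + 1e-8) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)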
%%time
net.train(1, batch_size=128, plot=False)
Training...
       |  Training |  Training
Epochs |     Error |  Accuracy
------ | --------- | ---------
#    1 |   3.23989 |   0.15127
#    2 |   2.95753 |   0.18822
========================================================
#    2 |   2.95753 |   0.18822
CPU times: user 3.86 s, sys: 613 ms, total: 4.47 s
Wall time: 1.54 s
%%time
net.train(1, batch_size=128, plot=False)
Training...
       |  Training |  Training
Epochs |     Error |  Accuracy
------ | --------- | ---------
#    2 |   2.95753 |   0.18822
#    3 |   2.81788 |   0.23665
========================================================
#    3 |   2.81788 |   0.23665
CPU times: user 3.98 s, sys: 634 ms, total: 4.61 s
Wall time: 1.6 s
"".join([indices_char[cx.argmax(v)] for v in net.dataset.inputs[0]])
'preface\n\n\nsupposing that truth is a woma'
probs = sorted(enumerate(net.propagate(net.dataset.inputs[0])),
key=lambda v: v[1], reverse=True)
probs[0]
(1, 0.1335555613040924)
[(indices_char[w[0]], round(w[1], 2)) for w in probs]
[(' ', 0.13), ('s', 0.13), ('t', 0.12), ('n', 0.11), ('e', 0.1), ('l', 0.07), ('a', 0.06), ('i', 0.04), ('o', 0.03), ('g', 0.03), ('d', 0.02), ('r', 0.02), ('c', 0.02), ('m', 0.01), ('\n', 0.01), ('b', 0.01), ('u', 0.01), ('p', 0.01), ('f', 0.01), ('y', 0.01), (',', 0.01), ('v', 0.01), ('k', 0.01), ('h', 0.0), ('w', 0.0), ('?', 0.0), ('-', 0.0), ('q', 0.0), ('.', 0.0), (':', 0.0), (';', 0.0), ('!', 0.0), ('"', 0.0), ('x', 0.0), ('z', 0.0), ('1', 0.0), ('8', 0.0), ('j', 0.0), ('9', 0.0), ('6', 0.0), ('ä', 0.0), (')', 0.0), ('=', 0.0), ("'", 0.0), ('[', 0.0), ('(', 0.0), ('5', 0.0), ('é', 0.0), ('ë', 0.0), ('7', 0.0), ('4', 0.0), ('2', 0.0), (']', 0.0), ('_', 0.0), ('3', 0.0), ('æ', 0.0), ('0', 0.0)]
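After only three epochs the distribution is no longer uniform: given a seed ending in "...a woma", most of the probability mass now falls on plausible continuations such as ' ', 's', 't', and 'n'.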
on_epoch_end(net)
----- Generating text after Epoch: 3

----- diversity: 0.2
----- Generating with seed: ", than the happily preserved petit fait "

, than the happily preserved petit fait te t o e s the t e e te the it e an as te an t e an the the te e s e an t an e e o e t e t e an at o e at t t e an e et an te t e ae an to e e te an at as at t e at an an at an e s e e te the an an the at an an te te the the te te e t t e an an e s at e t e e e e o at as e e an an an e e e o te t t e an as the as at o as the ae as te te t t t an the an e te t e as an at an e t an te t an an at

----- diversity: 0.5
----- Generating with seed: ", than the happily preserved petit fait "

, than the happily preserved petit fait iint an ane ta ecale s ane the to s an ete acas e e it the eet soe ian ie as thas shan t ie ool t e ale se e t b tae e an e t ie e in e so tot ee e s a e io o oe e te te t s as at s s as t ems g o at oo o s onie ie o aas ts to e te t e it the at ie as t ote it t t te se t oe e se s as te to le as to aio s te te t e at an abe as eie on it t og oe o t aiae ee ate sa e at ele ocoa n se a

----- diversity: 1.0
----- Generating with seed: ", than the happily preserved petit fait "

, than the happily preserved petit fait ?snwigeika eine wec v bso nay w soe an se -nm o d s tht enl adrsse t t oceangrme b vt "twtio. wtoom t elonm dpd lhiato he egt, ooraeiabe ot tteetepy it thtey a en ra t asu is ce w getr eonoi eed tk l? ao d nt aae a angsn n o, b asenteenertisio,,eiée goseqbiae imy s eo h aceuato yteyemieni, eoaeea siag nk anrbalre alasintot, ler o! oed annd suo. ritaastsilteef o g nus c st le aouueere s d

----- diversity: 1.2
----- Generating with seed: ", than the happily preserved petit fait "

, than the happily preserved petit fait metyebte s efoalliitul, an e , e, t smesol, s sctoovifaal igsuaew oy s uk eris ta hc en- saeue i"upilee ss hes r e he :seed iey s decenic opeemot , nauphln: wsa k loekcosahf o doinoa1l, ,omesroqidgdo s qtoos n nn ltnetgaa:os ssclodidy anuterngoxieoseans teoyregeonseny att usle wepms se shiaaieffeaaeyis,ratrepsefitit8ox por rei-gn dwo l e oa slbn:o rr nbos mede is ysbl c e oel r mry rpuitkiiaa ss