The first group of imports loads our basic packages; the second group loads the packages we need from TensorFlow.
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import pickle
from tensorflow.keras.models import load_model, model_from_json
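This is adapted from the Shakespeare generator section of the book Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow (《精通機器學習-使用 Scikit-Learn, Keras 與 TensorFlow》). The architecture is simple: for every 100 input characters, predict the next character. It uses two LSTM layers with 128 units each. Training ran for 10 passes over the data and took about 10 hours on a machine with a 1080Ti GPU.
The cell below reloads the model from my GitHub. If you have already downloaded the model, skip straight to the next section.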
from urllib.request import urlretrieve
urlretrieve("https://raw.githubusercontent.com/yenlung/Deep-Learning-Basics/master/dream_rnn_architecture.json", "architecture.json")
urlretrieve("https://github.com/yenlung/Deep-Learning-Basics/raw/master/dream_rnn_weights.h5", "weights.h5")
urlretrieve("https://github.com/yenlung/Deep-Learning-Basics/raw/master/dream_tokenizer2.pkl", "tokenizer.pkl")
('tokenizer.pkl', <http.client.HTTPMessage at 0x7f1a7f4a5f10>)
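If the files may already be on disk, a small guard can avoid fetching them again. The following is a minimal sketch and not part of the original notebook: the helper name `fetch_if_missing` and its `fetch` parameter are hypothetical, introduced only for illustration.

```python
import os
from urllib.request import urlretrieve

def fetch_if_missing(url, filename, fetch=urlretrieve):
    """Download `url` to `filename` unless that file already exists.

    Hypothetical helper: returns True if a download happened, False otherwise.
    `fetch` defaults to urlretrieve and can be swapped out for testing.
    """
    if os.path.exists(filename):
        return False  # file is already there, nothing to do
    fetch(url, filename)
    return True
```

Calling it with one of the URLs above and the matching local name (for example `"tokenizer.pkl"`) would download the file only on the first run.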
with open('architecture.json', 'r') as f:
    loaded_model = f.read()
model = model_from_json(loaded_model)
WARNING:tensorflow:Layer lstm_1 will not use cuDNN kernels since it doesn't meet the criteria. It will use a generic GPU kernel as fallback when running on GPU.
model.load_weights("weights.h5")
with open('tokenizer.pkl', 'rb') as f:
    tokenizer = pickle.load(f)
#from google.colab import drive
#drive.mount('/content/drive')
#%cd '/content/drive/MyDrive/Colab Notebooks/'
#model = load_model('dream_rnn')
#f = open('dream_tokenizer2.pkl', 'rb')
#tokenizer = pickle.load(f)
#f.close()
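First, `max_id` records the number of distinct Chinese characters used in Dream of the Red Chamber (《紅樓夢》), including modern punctuation marks. Surprisingly (?), the count is not as large as one might expect.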
max_id = len(tokenizer.word_index)
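Next, given a piece of text, we use the previously fitted tokenizer to convert it into a sequence of integers, and return the result one-hot encoded.
def preprocess(texts):
    # Tokenizer ids start at 1, so subtract 1 to get 0-based indices for one-hot encoding.
    X = np.array(tokenizer.texts_to_sequences([texts])) - 1
    return tf.one_hot(X, max_id)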
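To make the two steps concrete, here is a toy version in plain Python that does not depend on the trained tokenizer. Everything in it is made up for illustration: the helper names `build_index`, `to_ids` and `one_hot`, the sample text, and the choice to number characters in order of first appearance (the real notebook uses the pickled Keras tokenizer and `tf.one_hot`).

```python
def build_index(text):
    """Give each distinct character an id starting at 1 (toy stand-in for word_index)."""
    index = {}
    for ch in text:
        if ch not in index:
            index[ch] = len(index) + 1
    return index

def to_ids(text, index):
    """Convert characters to 0-based ids by shifting the 1-based ids down by one."""
    return [index[ch] - 1 for ch in text]

def one_hot(ids, depth):
    """Turn each id into a vector of length `depth` with a single 1.0."""
    return [[1.0 if j == i else 0.0 for j in range(depth)] for i in ids]

toy_index = build_index("hello")          # {'h': 1, 'e': 2, 'l': 3, 'o': 4}
toy_ids = to_ids("lo", toy_index)         # 0-based ids for 'l' and 'o'
toy_encoded = one_hot(toy_ids, depth=len(toy_index))
```

Each row of `toy_encoded` has as many entries as there are distinct characters, which plays the role of `max_id` above.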
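This function takes a piece of text and uses our model to predict the next character. As in an ordinary classification problem, the model outputs a probability for every character, and the obvious choice is the most probable one. But if we always did that, the same input text would always be followed by exactly the same output! A common technique is to set a `temperature`: with `temperature` close to 0 we essentially take the most probable character, and the larger `temperature` is, the more random the choice becomes. Too much randomness and we are just drawing characters at random! A `temperature` of around 1 usually works best.
def next_char(texts, temperature=1):
    X_new = preprocess(texts)
    # Keep only the prediction for the last time step; shape (1, max_id).
    y_predict = model.predict(X_new)[0, -1:, :]
    # Dividing the log-probabilities by temperature sharpens (<1) or flattens (>1) the distribution.
    rescaled_logits = tf.math.log(y_predict) / temperature
    # Sample one id, then add 1 to return to the tokenizer's 1-based ids.
    char_id = tf.random.categorical(rescaled_logits, num_samples=1) + 1
    return tokenizer.sequences_to_texts(char_id.numpy())[0]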
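The effect of temperature is easier to see on a tiny made-up distribution. The sketch below, in plain Python, mirrors the log-then-divide step of `next_char`: sampling from log-probabilities divided by T is the same as sampling from probabilities proportional to p ** (1 / T). The helper names `apply_temperature` and `sample_index` and the three toy probabilities are assumptions for illustration, and the sketch assumes every probability is strictly positive.

```python
import math
import random

def apply_temperature(probs, temperature):
    """Return the distribution proportional to p ** (1 / temperature).

    Assumes all probabilities are strictly positive (math.log(0) would fail).
    """
    logits = [math.log(p) / temperature for p in probs]
    top = max(logits)  # subtract the max before exponentiating, for numerical stability
    weights = [math.exp(x - top) for x in logits]
    total = sum(weights)
    return [w / total for w in weights]

def sample_index(probs, temperature, rng=random):
    """Draw one index from the temperature-adjusted distribution."""
    adjusted = apply_temperature(probs, temperature)
    return rng.choices(range(len(probs)), weights=adjusted, k=1)[0]

toy_probs = [0.6, 0.3, 0.1]                  # made-up probabilities for three characters
sharp = apply_temperature(toy_probs, 0.2)    # low temperature: mass piles onto the top choice
same = apply_temperature(toy_probs, 1.0)     # temperature 1: distribution unchanged
flat = apply_temperature(toy_probs, 2.0)     # high temperature: closer to uniform
```

Comparing the first entry of `sharp`, `same` and `flat` shows the most probable character gaining weight at low temperature and losing it at high temperature.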
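Finally, given a piece of text, we generate `n_chars` more characters. One pass through the model produces only one character, so we generate one character at a time and repeat until we have as many as we want.
The model was trained on segments of 100 characters, so once the text is longer than 100 characters we feed only the last 100 characters into the model.
def complete_text(texts, n_chars=50, temperature=1):
    n_chars = int(n_chars)  # the value may arrive as a float, e.g. from a slider
    for _ in range(n_chars):
        # Only the last 100 characters are shown to the model.
        texts = texts + next_char(texts[-100:], temperature)
    return texts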
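The windowing behaviour of this loop can be checked without the trained model. The sketch below is a hypothetical variant, not the notebook's code: `complete_with` takes any next-character function, and `echo_last` is a stand-in "model" that simply repeats the last character. It also records how much context was passed in at each step.

```python
def complete_with(next_char_fn, text, n_chars=50, window=100):
    """Append n_chars characters, passing only the last `window` characters each step.

    Returns the completed text and the context length seen at every step.
    """
    seen_lengths = []
    for _ in range(int(n_chars)):
        context = text[-window:]
        seen_lengths.append(len(context))
        text += next_char_fn(context)
    return text, seen_lengths

def echo_last(context):
    """Stand-in for the model: repeats the last character of the context."""
    return context[-1]

# A small window makes the truncation visible after only a few steps.
result, lengths = complete_with(echo_last, "ab", n_chars=5, window=3)
```

Here `lengths` grows until it reaches the window size and then stays there, which is what happens with the 100-character window in `complete_text`.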
Before turning this into a web app, let's test it first.
complete_text("自孫悟空從石頭中蹦出來之後,", n_chars=300, temperature=0.2)
'自孫悟空從石頭中蹦出來之後,說道:「寶玉既有如此,但未知之言,但未知之方,可不是遁世離群、無關之處,望二二叔之基,就是寶玉之處,或者塵中勞動,聊倩鳥呼歸去;山靈好客,更從石化飛來,亦未可知。」雨村聽了,益發驚異:「你們不知道此,何以如此?」士隱道:「神仙名長,何如此?」士隱道:「此事不知。」雨村聽了,益發驚異:「請問仙長,何必如此?」士隱道:「此事不知。」雨村聽了,益發驚異:「請問仙長,何必如此?」士隱道:「此事傳出,夙世前因,自有一概不知。但是敝族閨秀,如此之多,何元妃以下,凡事的事,託他傳遍,知道奇而不奇,俗而不俗,真而不真,假而不假。或者塵夢勞人,聊倩鳥呼歸去;山靈好客,更從石化飛來,亦未可知。」雨村聽畢,仍舊擲下'
!pip install gradio
Successfully installed Flask-Cors-3.0.10 Flask-Login-0.5.0 analytics-python-1.3.1 backoff-1.10.0 bcrypt-3.2.0 cryptography-3.4.7 ffmpy-0.3.0 flask-cachebuster-1.0.0 gradio-2.1.2 markdown2-2.4.0 monotonic-1.6 paramiko-2.7.2 pycryptodome-3.10.1 pynacl-1.4.0
import gradio as gr
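iface = gr.Interface(
    fn=complete_text,
    inputs=[
        "text",
        gr.inputs.Slider(50, 200, 1, 50),    # n_chars slider
        gr.inputs.Slider(0.2, 2, 0.2, 1)],   # temperature slider
    outputs="text",
    title="Dream of the Red Chamber Generator",
    description="Give it an opening and it will complete a passage of Dream of the Red Chamber for you. You can change the temperature: the smaller it is, the more fixed the generated text; the larger it is, the more random.")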
iface.launch(share=True)
Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()` This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!) Running on External URL: https://15970.gradio.app Interface loading below...
(<Flask 'gradio.networking'>, 'http://127.0.0.1:7860/', 'https://15970.gradio.app')