We're finally going to build the first neural network of our lives...
This is the TensorFlow 2 version: TensorFlow has fully absorbed a package called Keras, and the whole thing has become a very easy-to-use deep learning framework.
Here we import a few packages; for now, don't worry about the details.
%matplotlib inline
# standard data analysis and plotting packages
import numpy as np
import matplotlib.pyplot as plt
# neural network components
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
# for interactive widgets
from ipywidgets import interact_manual
Keras kindly prepares the MNIST dataset for us, and we can load it like this (the first time takes a moment).
(x_train, y_train), (x_test, y_test) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step
Let's check that the training set really has 60,000 samples and the test set 10,000.
print(f'Number of training samples: {len(x_train)}')
print(f'Number of test samples: {len(x_test)}')
Number of training samples: 60000
Number of test samples: 10000
Each input (x) is a 28x28 image of one handwritten digit 0-9, and the output (y) is of course the "correct answer". Let's see what the input and output parts of a training sample look like.
def show_xy(n=0):
    ax = plt.gca()
    X = x_train[n]
    plt.xticks([], [])
    plt.yticks([], [])
    plt.imshow(X, cmap='Greys')
    print(f'The answer y gives for this sample is: {y_train[n]}')
interact_manual(show_xy, n=(0,59999));
interactive(children=(IntSlider(value=0, description='n', max=59999), Button(description='Run Interact', style…
def show_data(n=100):
    X = x_train[n]
    print(X)
interact_manual(show_data, n=(0,59999));
interactive(children=(IntSlider(value=100, description='n', max=59999), Button(description='Run Interact', sty…
Now let's use a standard neural network to learn handwritten digit recognition. Each original sample is a 28x28 array, but a standard network only accepts "flat" input, that is, a vector of length 28x28=784 per sample. So we use reshape to adjust the data, and at the same time divide by 255 so that the pixel values, originally 0-255, are scaled into the range [0, 1].
x_train.shape
(60000, 28, 28)
x_test.shape
(10000, 28, 28)
x_train = x_train.reshape(60000, 784)/255
x_test = x_test.reshape(10000, 784)/255
x_train[87]
array([0.        , 0.        , 0.        , ..., 0.15294118, 0.49019608,
       0.88235294, 0.99607843, 0.99607843, 1.        , 0.99607843,
       0.66666667, 0.18823529, ..., 0.        , 0.        , 0.        ])
We might think that the function we want to learn looks like this:
$$\hat{f} \colon \mathbb{R}^{784} \to \mathbb{R}$$
But this is actually not a good idea! Why? Suppose our input x is an image of a 0. Since a trained network always has some error, we might get:
$$\hat{f}(x) = 0.5$$
Does that mean it could be a 0, but could also be a 1?? But 0 and 1 look nothing alike! In other words, treating a classification problem this way simply doesn't make sense.
So instead we use "one-hot encoding", that is,

0 → (1, 0, 0, 0, 0, 0, 0, 0, 0, 0)
1 → (0, 1, 0, 0, 0, 0, 0, 0, 0, 0)

and so on. Since essentially every classification problem needs this step, Keras already has a function ready for us!
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
Let's look at the answer for the sample we inspected earlier.
n = 87
y_train[n]
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], dtype=float32)
Just as we expected! Now we can build our neural network.
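Conceptually, one-hot encoding just places a 1 at the position given by the class label and 0 everywhere else. A minimal standalone NumPy sketch of the same idea (not part of the notebook's pipeline; to_categorical does this for us):

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Return one row per label, with a 1 at that label's index."""
    return np.eye(num_classes)[labels]

y = np.array([9, 0, 4])
print(one_hot(y))
# the first row has its single 1 in position 9, just like y_train[87] above
```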
We've decided that our function is of the form
$$\hat{f} \colon \mathbb{R}^{784} \to \mathbb{R}^{10}$$
And since we said we'd first try a standard neural network, all that's left to decide is the number of hidden layers, the number of neurons per layer, and which activation function to use.
If we take ReLU as the activation function, then designing the network comes down to just specifying how many hidden layers and how many neurons per layer!
Once the design is done, we basically just tell TensorFlow what we have in mind.
#@title Design your neural network
隱藏層數 = 3 #@param{type:"integer"}
神經元1 = 0#@param{type:"integer"}
神經元2 = 0 #@param{type:"integer"}
神經元3 = 0 #@param{type:"integer"}
神經元4 = 0 #@param{type:"integer"}
神經元5 = 0 #@param{type:"integer"}
As with regression or machine learning before, we open up a "function learning machine". A standard network that passes data through layer by layer is called Sequential, so we open an empty one.
model = Sequential()
We add one layer at a time with add, starting from the first hidden layer. For the first hidden layer, TensorFlow of course cannot guess that the input has 784 features, so we have to tell it.
model.add(Dense(20, input_dim=784, activation='relu'))
From the second layer on there's no need to specify the input size (it's simply the previous layer's number of neurons).
model.add(Dense(20, activation='relu'))
model.add(Dense(20, activation='relu'))
The output is one of 10 digits, so the output layer has 10 neurons! And if our network's output is
$$(y_1, y_2, \ldots, y_{10})$$
we would also like
$$\sum_{i=1}^{10} y_i = 1$$
Is that even possible? It turns out to be easy: just use softmax as the activation function!
model.add(Dense(10, activation='softmax'))
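The claim above, that softmax outputs are non-negative and sum to 1, can be checked numerically with a small standalone NumPy sketch (this is an illustration, not the layer Keras uses internally):

```python
import numpy as np

def softmax(z):
    # subtracting the max is a standard numerical-stability trick;
    # it does not change the result
    e = np.exp(z - z.max())
    return e / e.sum()

scores = np.array([2.0, 1.0, 0.1])
p = softmax(scores)
print(p, p.sum())  # non-negative values that sum to 1
```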
With that, our first neural network is built!
Unlike before, we also have to compile before our network is truly ready. You'll notice a few more things need to be specified:

- a loss function: here we use mse
- an optimizer: here SGD, with learning rate 0.087
- so we can watch the results while training, we also set metrics=['accuracy']; this part has essentially no effect on how the network itself works
model.compile(loss='mse', optimizer=SGD(learning_rate=0.087), metrics=['accuracy'])
We can inspect the architecture of our network to confirm that it matches what we had in mind.
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 20) 15700 _________________________________________________________________ dense_1 (Dense) (None, 20) 420 _________________________________________________________________ dense_2 (Dense) (None, 20) 420 _________________________________________________________________ dense_3 (Dense) (None, 10) 210 ================================================================= Total params: 16,750 Trainable params: 16,750 Non-trainable params: 0 _________________________________________________________________
Quickly work out the parameter counts and check that they match our expectations!
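The counts in the summary can be reproduced by hand: each Dense layer has (inputs × neurons) weights plus one bias per neuron. A quick standalone check:

```python
# weights + biases for each Dense layer
layer1 = 784 * 20 + 20   # first hidden layer: 15700
layer2 = 20 * 20 + 20    # second hidden layer: 420
layer3 = 20 * 20 + 20    # third hidden layer: 420
output = 20 * 10 + 10    # output layer: 210
total = layer1 + layer2 + layer3 + output
print(total)  # 16750, matching model.summary()
```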
Congratulations! We've completed our first neural network. When it comes to training, you'll find that it's not just mindlessly feeding in the training data as before. Here we still have two things to decide:

- how many samples per parameter update (batch_size): let's update once every 100 samples
- how many passes over the training data (epochs): let's try training for 10 epochs

And now for the most exciting part. Be mentally prepared to wait...
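Before running it, a quick sanity check on these numbers: with 60,000 training samples and batch_size=100, each epoch performs 600 parameter updates, which is exactly the 600/600 step counter Keras will display.

```python
num_samples = 60000
batch_size = 100
epochs = 10

steps_per_epoch = num_samples // batch_size
print(steps_per_epoch)           # 600 updates per epoch
print(steps_per_epoch * epochs)  # 6000 updates over the whole run
```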
model.fit(x_train, y_train, batch_size=100, epochs=10)
Epoch 1/10 600/600 [==============================] - 4s 2ms/step - loss: 0.0892 - accuracy: 0.1038 Epoch 2/10 600/600 [==============================] - 1s 2ms/step - loss: 0.0855 - accuracy: 0.1593 Epoch 3/10 600/600 [==============================] - 1s 2ms/step - loss: 0.0808 - accuracy: 0.3638 Epoch 4/10 600/600 [==============================] - 1s 2ms/step - loss: 0.0674 - accuracy: 0.4868 Epoch 5/10 600/600 [==============================] - 1s 2ms/step - loss: 0.0477 - accuracy: 0.6993 Epoch 6/10 600/600 [==============================] - 1s 2ms/step - loss: 0.0306 - accuracy: 0.8251 Epoch 7/10 600/600 [==============================] - 1s 2ms/step - loss: 0.0230 - accuracy: 0.8566 Epoch 8/10 600/600 [==============================] - 1s 2ms/step - loss: 0.0198 - accuracy: 0.8731 Epoch 9/10 600/600 [==============================] - 1s 2ms/step - loss: 0.0178 - accuracy: 0.8855 Epoch 10/10 600/600 [==============================] - 1s 2ms/step - loss: 0.0164 - accuracy: 0.8947
<tensorflow.python.keras.callbacks.History at 0x7ff210050a50>
Let's view our cute neural network's learning results in a fancier way. If you have questions about these commands, see 《少年Py的大冒險：成為Python數據分析達人的第一門課》.
loss, acc = model.evaluate(x_test, y_test)
313/313 [==============================] - 1s 2ms/step - loss: 0.0150 - accuracy: 0.9029
print(f"Test accuracy: {acc*100:.2f}%")
Test accuracy: 90.29%
In predict we store our network's outputs. After predicting, we use argmax to find the entry with the largest value.
predict = np.argmax(model.predict(x_test), axis=-1)
predict
array([7, 2, 1, ..., 4, 8, 6])
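To see what this argmax step does, here is a standalone sketch with made-up softmax outputs for three samples (these numbers are invented for illustration, not the real model's predictions):

```python
import numpy as np

# pretend softmax outputs for 3 samples, 10 classes each
probs = np.array([
    [0.01, 0.02, 0.02, 0.05, 0.05, 0.05, 0.05, 0.70, 0.03, 0.02],  # most mass on 7
    [0.05, 0.80, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.01],  # most mass on 1
    [0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.82, 0.02],  # most mass on 8
])

# for each row, pick the index of the largest probability
labels = np.argmax(probs, axis=-1)
print(labels)  # [7 1 8]
```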
Don't forget that each sample in x_test has been flattened into a 784-dimensional vector; we have to reshape it back into a 28x28 array before we can display it as an image!
def test(測試編號):
    plt.imshow(x_test[測試編號].reshape(28,28), cmap='Greys')
    print('The network predicts:', predict[測試編號])
test(87)
The network predicts: 5
interact_manual(test, 測試編號=(0, 9999));
interactive(children=(IntSlider(value=4999, description='測試編號', max=9999), Button(description='Run Interact', …
So how does the network do on the test data overall? We can give our neural network a "final evaluation".
score = model.evaluate(x_test, y_test)
313/313 [==============================] - 1s 2ms/step - loss: 0.0150 - accuracy: 0.9029
print('loss:', score[0])
print('accuracy:', score[1])
loss: 0.015002136118710041
accuracy: 0.902899980545044
If we're happy with the training results, we certainly don't want to retrain every time! We can save both the network's architecture and its trained weights for later use.
On Colab, we first mount our own Google Drive.
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
Then cd into your folder; we usually put things in our own Colab Notebooks folder, but you can of course choose a different one.
%cd '/content/drive/My Drive/Colab Notebooks'
/content/drive/My Drive/Colab Notebooks
model.save('my_model')
INFO:tensorflow:Assets written to: my_model/assets
Later, to read the model back in, use load_model from tensorflow.keras.models:
from tensorflow.keras.models import load_model
Mount your Google Drive and cd into the folder where you saved the model.
model = load_model('my_model')
And that's it! Finally, let's use the gradio package to turn our handwritten-digit recognizer into a small interactive web demo.
!pip install gradio
Collecting gradio
  Downloading gradio-2.0.10-py3-none-any.whl (2.4MB)
...
Successfully installed Flask-Cors-3.0.10 Flask-Login-0.5.0 analytics-python-1.3.1 backoff-1.10.0 bcrypt-3.2.0 cryptography-3.4.7 ffmpy-0.3.0 flask-cachebuster-1.0.0 gradio-2.0.10 markdown2-2.4.0 monotonic-1.6 paramiko-2.7.2 pycryptodome-3.10.1 pynacl-1.4.0
import gradio as gr
def recognize_digit(img):
    img = img.reshape(1, 784)
    prediction = model.predict(img).flatten()
    labels = list('0123456789')
    return {labels[i]: float(prediction[i]) for i in range(10)}
gr.Interface(fn=recognize_digit, inputs="sketchpad", outputs="label").launch()
Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()` This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!) Running on External URL: https://32792.gradio.app Interface loading below...
(<Flask 'gradio.networking'>, 'http://127.0.0.1:7860/', 'https://32792.gradio.app')