The three most common mynas in Taiwan are the crested myna, the Javan (white-vented) myna, and the common myna. Let's take on a challenge: using fewer than thirty photos in total across the three species, can we build a neural network that learns to tell them apart?
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from tensorflow.keras.preprocessing.image import load_img, img_to_array
!wget --no-check-certificate \
https://github.com/yenlung/Deep-Learning-Basics/raw/master/images/myna.zip \
-O /content/myna.zip
2022-08-20 16:20:11 (19.6 MB/s) - ‘/content/myna.zip’ saved [964098/964098]
import os
import zipfile
local_zip = '/content/myna.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('/content')
zip_ref.close()
base_dir = '/content/'
myna_folders = ['crested_myna', 'javan_myna', 'common_myna']
We can list the filenames in any folder. For example, here is the crested myna folder.
thedir = base_dir + myna_folders[0]
os.listdir(thedir)
['crested_myna02.jpg', 'crested_myna03.jpg', 'crested_myna01.jpg']
Next, we turn the photos in these three folders into inputs (data) and outputs (target).
data = []
target = []
for i in range(3):
    thedir = base_dir + myna_folders[i]
    myna_fnames = os.listdir(thedir)
    for myna in myna_fnames:
        img_path = thedir + '/' + myna
        img = load_img(img_path, target_size=(256, 256))
        x = img_to_array(img)
        data.append(x)
        target.append(i)
data = np.array(data)
While we're at it, let's check how many images we have in total.
data.shape
(23, 256, 256, 3)
Next, pick a photo at random and see what kind of "bird" we are dealing with.
n = 1
plt.imshow(data[n]/255)
plt.axis('off');
No surprises there; it's a bird picture. Let's view it again after ResNet's preprocessing.
x_train = preprocess_input(data)
plt.imshow(x_train[n])
plt.axis('off');
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
A small warning pops up here. This is because ResNet's preprocessing normalizes the pixel values to the range [-1, 1], slightly different from the [0, 1] normalization we used before, so matplotlib automatically clips the data when plotting.
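A quick numeric sketch of what the preprocessing does. For the `resnet_v2` module, `preprocess_input` applies "tf"-style rescaling, which is equivalent to `x / 127.5 - 1` (stated here as an assumption about this Keras version; the exact mode is documented in `keras.applications`):

```python
import numpy as np

# ResNet V2's preprocess_input rescales pixels from [0, 255] to [-1, 1],
# which is equivalent to x / 127.5 - 1 ("tf"-mode preprocessing).
pixels = np.array([0.0, 127.5, 255.0])
scaled = pixels / 127.5 - 1.0
print(scaled)  # [-1.  0.  1.]
```

This is why values below 0 appear in the preprocessed image and trigger matplotlib's clipping warning.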
Each image's answer is one of the numbers 0, 1, or 2.
target[n]
0
Do one-hot encoding.
y_train = to_categorical(target, 3)
y_train[0]
array([1., 0., 0.], dtype=float32)
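Under the hood, `to_categorical` is just indexing rows of an identity matrix; a minimal NumPy sketch of the same idea:

```python
import numpy as np

# One-hot encoding by indexing rows of a 3x3 identity matrix;
# this mirrors what to_categorical(target, 3) produces.
target = [0, 1, 2, 2]
one_hot = np.eye(3, dtype="float32")[target]
print(one_hot[0])  # [1. 0. 0.]
```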
ResNet50 was the 2015 ImageNet champion; we'll try version 2. The original network classifies ImageNet's 1,000 categories. To reuse it for mynas, we chop off the last layer (normally a dense layer with 1,000 outputs) by passing include_top=False, then attach our own. We also want to summarize each filter's output, for example by averaging each filter's activation map (global average pooling). Normally we would have to do this ourselves, but tf.keras is kind enough to do it for us: just pass the parameter pooling="avg".
resnet = ResNet50V2(include_top=False, pooling="avg")
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50v2_weights_tf_dim_ordering_tf_kernels_notop.h5
Now let's officially build our transfer-learning function approximator! Notice that all we do is add a final layer...
model = Sequential()
model.add(resnet)
model.add(Dense(3, activation='softmax'))
Since this is transfer learning, we naturally have no intention of retraining the ResNet part, so we mark it as non-trainable.
resnet.trainable = False
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
resnet50v2 (Functional)      (None, 2048)              23564800
dense (Dense)                (None, 3)                 6147
=================================================================
Total params: 23,570,947
Trainable params: 6,147
Non-trainable params: 23,564,800
_________________________________________________________________
Here we use categorical_crossentropy, the standard loss for classification, and take the chance to try the famous adam optimizer.
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
Notice the original model has over twenty million parameters, but after we steal, er, borrow them, only 6,147 parameters remain to be tuned.
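The 6,147 figure checks out by hand: with pooling="avg", ResNet50V2 emits a 2048-dimensional feature vector, so our Dense layer needs 2048 × 3 weights plus 3 biases.

```python
# Trainable parameters come from the final Dense layer alone:
# a 2048-dim feature vector mapped to 3 classes.
n_features, n_classes = 2048, 3
trainable_params = n_features * n_classes + n_classes
print(trainable_params)  # 6147
```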
Since we have only 23 samples in total, we simply set batch_size to 23...
model.fit(x_train, y_train, batch_size=23, epochs=10)
Epoch 1/10
1/1 [==============================] - 14s 14s/step - loss: 1.4943 - accuracy: 0.2609
Epoch 2/10
1/1 [==============================] - 0s 80ms/step - loss: 1.2043 - accuracy: 0.4348
Epoch 3/10
1/1 [==============================] - 0s 81ms/step - loss: 1.0581 - accuracy: 0.4783
Epoch 4/10
1/1 [==============================] - 0s 83ms/step - loss: 0.9931 - accuracy: 0.5217
Epoch 5/10
1/1 [==============================] - 0s 81ms/step - loss: 0.9404 - accuracy: 0.5652
Epoch 6/10
1/1 [==============================] - 0s 83ms/step - loss: 0.8728 - accuracy: 0.5652
Epoch 7/10
1/1 [==============================] - 0s 83ms/step - loss: 0.7905 - accuracy: 0.6087
Epoch 8/10
1/1 [==============================] - 0s 81ms/step - loss: 0.7039 - accuracy: 0.6957
Epoch 9/10
1/1 [==============================] - 0s 81ms/step - loss: 0.6234 - accuracy: 0.6957
Epoch 10/10
1/1 [==============================] - 0s 84ms/step - loss: 0.5555 - accuracy: 0.8261
<keras.callbacks.History at 0x7f103c057ed0>
Let's first use model.evaluate to see how the model performs.
loss, acc = model.evaluate(x_train, y_train)
print(f"Loss: {loss}")
print(f"Accuracy: {acc}")
1/1 [==============================] - 1s 888ms/step - loss: 0.5000 - accuracy: 0.8696
Loss: 0.4999895393848419
Accuracy: 0.8695651888847351
For later use, we store the names of the three myna species as labels.
labels = ["土八哥", "白尾八哥", "家八哥"]
By the way, why didn't we hold out test data this time? Because we have only 23 photos in total, and only a handful per species. Let's take a look at the training results.
y_predict = np.argmax(model.predict(x_train), -1)
y_predict
array([2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
If you're curious, compare these with the correct answers.
target
[0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
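To see exactly where the three mistakes happen, here is a small confusion-matrix sketch over the two arrays above, in plain NumPy (rows are true classes, columns are predictions):

```python
import numpy as np

target = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
                   2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
y_predict = np.array([2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                      2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

# Confusion matrix: rows = true class, columns = predicted class.
cm = np.zeros((3, 3), dtype=int)
for t, p in zip(target, y_predict):
    cm[t, p] += 1
print(cm)
print("accuracy:", (target == y_predict).mean())  # 20/23, matching model.evaluate
```

All three errors are crested mynas (class 0) mistaken for the other two species, consistent with class 0 having the fewest training photos.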
Building a myna-recognition web app with gradio!
!pip install gradio
Successfully installed gradio-3.1.6 ...
import gradio as gr
Note that the main function is very similar to the one in our Cooper-recognition example, except now there are only three classes, and the model is our own!
def classify_image(inp):
    inp = inp.reshape((-1, 256, 256, 3))
    inp = preprocess_input(inp)
    prediction = model.predict(inp).flatten()
    return {labels[i]: float(prediction[i]) for i in range(3)}
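Gradio's Label component expects exactly this shape of output: a dict mapping class names to confidences. We can sanity-check the comprehension on a made-up softmax output (the probabilities and English label names below are dummies for illustration):

```python
import numpy as np

# English stand-ins for the labels list; the probabilities are made up.
labels = ["crested myna", "Javan myna", "common myna"]
prediction = np.array([0.1, 0.7, 0.2], dtype="float32")

# Same dict comprehension as in classify_image.
result = {labels[i]: float(prediction[i]) for i in range(3)}
print(result)
```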
image = gr.Image(shape=(256, 256), label="八哥照片")
label = gr.Label(num_top_classes=3, label="AI辨識結果")
some_text="我能辨識(土)八哥、白尾八哥、家八哥。找張八哥照片來考我吧!"
We pull images from our myna dataset to serve as example images for users to try.
sample_images = []
for i in range(3):
    thedir = base_dir + myna_folders[i]
    for file in os.listdir(thedir):
        sample_images.append(thedir + '/' + file)  # full path, so this works from any working directory
Finally, assemble everything, and we're done!
gr.Interface(fn=classify_image,
             inputs=image,
             outputs=label,
             title="AI 八哥辨識機",
             description=some_text,
             examples=sample_images).launch(share=True)
Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()` Running on public URL: https://52478.gradio.app This share link expires in 72 hours. For free permanent hosting, check out Spaces: https://huggingface.co/spaces
(<gradio.routes.App at 0x7f0da6c42dd0>, 'http://127.0.0.1:7860/', 'https://52478.gradio.app')