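Let's try transfer learning with more photos and see whether the results improve. This also serves as a demonstration of how to turn a collection of photos into training data.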
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
Download a .zip file from the web and save it to the /content folder that Colab provides.
!wget --no-check-certificate \
https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip \
-O /content/cats_and_dogs_filtered.zip
--2021-08-05 22:28:12-- https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.2.128, 142.250.141.128, 2607:f8b0:4023:c0b::80, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.2.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 68606236 (65M) [application/zip]
Saving to: ‘/content/cats_and_dogs_filtered.zip’
/content/cats_and_d 100%[===================>] 65.43M 255MB/s in 0.3s
2021-08-05 22:28:12 (255 MB/s) - ‘/content/cats_and_dogs_filtered.zip’ saved [68606236/68606236]
Next we extract the .zip file, again into the /content folder.
import os
import zipfile
local_zip = '/content/cats_and_dogs_filtered.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('/content')
zip_ref.close()
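As a hedged alternative to the explicit open and close calls above, the same extraction can be written with a context manager, which closes the archive even if extraction raises an error. The sketch below builds a tiny throwaway archive in a temporary directory so that it runs anywhere; the helper name extract_zip and the demo file names are illustrative assumptions, not part of the notebook.

```python
import os
import tempfile
import zipfile


def extract_zip(zip_path, dest_dir):
    """Extract every member of zip_path into dest_dir and return the member names."""
    # The with-statement closes the archive automatically, even on errors.
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)
        return zip_ref.namelist()


# Build a small demo archive (hypothetical contents) and extract it.
work_dir = tempfile.mkdtemp()
demo_zip = os.path.join(work_dir, 'demo.zip')
with zipfile.ZipFile(demo_zip, 'w') as zf:
    zf.writestr('demo/cats/cat.0.txt', 'placeholder')

extracted_names = extract_zip(demo_zip, work_dir)
print(extracted_names)
```

In the notebook itself, the call would presumably be extract_zip(local_zip, '/content').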
After extraction, /content/cats_and_dogs_filtered/validation/cats contains 500 cat photos, so we store that path as cats_dir. Likewise, we store the path to the dog photos as dogs_dir.
base_dir = '/content/cats_and_dogs_filtered/validation'
cats_dir = base_dir + '/cats'
dogs_dir = base_dir + '/dogs'
We put the file names of the cat and dog photos into two lists, cat_fnames and dog_fnames.
cat_fnames = os.listdir(cats_dir)
dog_fnames = os.listdir(dogs_dir)
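os.listdir returns names in arbitrary order and includes every entry in the folder, not only images. When collecting your own photos, a slightly more defensive listing may be useful. The following is a minimal sketch under that assumption; the helper list_image_files, the extension list, and the demo file names are all hypothetical.

```python
import os
import tempfile


def list_image_files(folder, extensions=('.jpg', '.jpeg', '.png')):
    """Return the image file names in folder, sorted for a reproducible order."""
    return sorted(
        name for name in os.listdir(folder)
        if name.lower().endswith(extensions)
    )


# Demo on a throwaway folder with hypothetical file names.
demo_dir = tempfile.mkdtemp()
for name in ['cat.2.jpg', 'cat.10.JPG', 'notes.txt']:
    open(os.path.join(demo_dir, name), 'w').close()

image_names = list_image_files(demo_dir)
print(image_names)
```

The sort is plain string order, so cat.10.JPG comes before cat.2.jpg; that is enough for reproducibility, which is all the training data needs.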
Now data will hold the photos converted to arrays, and target will hold the answers: 0 for cat, 1 for dog.
data = []
target = []
# Load each cat photo, resize it to 224x224, and label it 0
for cat in cat_fnames:
    img = load_img(cats_dir + '/' + cat, target_size=(224, 224))
    x = np.array(img)
    data.append(x)
    target.append(0)
# Load each dog photo the same way and label it 1
for dog in dog_fnames:
    img = load_img(dogs_dir + '/' + dog, target_size=(224, 224))
    x = np.array(img)
    data.append(x)
    target.append(1)
data = np.array(data)
target = np.array(target)
Checking the shape of data, we see 1,000 photos of size 224x224x3.
data.shape
(1000, 224, 224, 3)
target, naturally, holds 1,000 correct answers (cat or dog).
target.shape
(1000,)
Finally, we apply the standard preprocessing for ResNet50V2.
data = preprocess_input(data)
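For the ResNet V2 family, the Keras preprocessing step rescales pixel values from the range [0, 255] to [-1, 1]. The NumPy sketch below illustrates that rescaling on a tiny fake batch; the function name scale_to_minus_one_one is hypothetical, and it is meant as an illustration of the idea rather than a replacement for preprocess_input.

```python
import numpy as np


def scale_to_minus_one_one(images):
    """Rescale pixel values from [0, 255] to [-1, 1] (NumPy illustration only)."""
    images = np.asarray(images, dtype=np.float32)
    return images / 127.5 - 1.0


# A fake 1x2x2x3 "photo" batch with the extreme and middle pixel values.
demo_batch = np.array([[[[0, 0, 0], [255, 255, 255]],
                        [[127.5, 127.5, 127.5], [255, 0, 0]]]])
scaled = scale_to_minus_one_one(demo_batch)
print(scaled.min(), scaled.max())
```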
We use train_test_split, one of the most commonly used functions in scikit-learn, to split the data into training and test sets.
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(data, target,
                                                    test_size=0.2,
                                                    random_state=0)
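Conceptually, the split shuffles the sample indices once and then cuts them into two parts. The NumPy sketch below shows that idea on stand-in arrays; simple_split is a hypothetical helper, and because scikit-learn uses its own shuffling, the exact samples chosen will differ from what train_test_split returns.

```python
import numpy as np


def simple_split(data, target, test_size=0.2, seed=0):
    """Shuffle indices once, then cut them into a training part and a test part."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(data))
    n_test = int(round(len(data) * test_size))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return data[train_idx], data[test_idx], target[train_idx], target[test_idx]


# Stand-in arrays with the same number of samples as the notebook (1,000).
demo_data = np.arange(1000).reshape(1000, 1)
demo_target = np.array([0] * 500 + [1] * 500)
xs_train, xs_test, ys_train, ys_test = simple_split(demo_data, demo_target)
print(xs_train.shape, xs_test.shape)
```

With test_size=0.2 and 1,000 photos, this gives 800 training samples and 200 test samples.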
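We load ResNet50V2 without its top classification layers (include_top=False) and apply Global Average Pooling (pooling="avg"). Note that the weights of the loaded resnet must be frozen.
Finally, since we only have two classes, the output is a single number! To make sure the output lies between 0 and 1, we use the sigmoid function as the activation function.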
resnet = ResNet50V2(include_top=False, pooling="avg")
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50v2_weights_tf_dim_ordering_tf_kernels_notop.h5 94674944/94668760 [==============================] - 1s 0us/step
resnet.trainable = False
model = Sequential()
model.add(resnet)
model.add(Dense(1, activation='sigmoid'))
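To make the two ideas concrete, here is a minimal NumPy sketch of global average pooling followed by a single sigmoid unit. The 7x7 spatial size of the feature maps is an assumption for illustration, and the weights are random stand-ins rather than trained values; the function names are hypothetical.

```python
import numpy as np


def global_average_pool(feature_maps):
    """Average over the two spatial axes: (batch, h, w, channels) -> (batch, channels)."""
    return feature_maps.mean(axis=(1, 2))


def sigmoid(z):
    """Squash any real number into the open interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)
# Hypothetical feature maps for 4 images; 7x7x2048 is assumed for illustration.
feature_maps = rng.normal(size=(4, 7, 7, 2048))
pooled = global_average_pool(feature_maps)

# A stand-in for Dense(1): random weights and a zero bias (not trained values).
weights = rng.normal(scale=0.01, size=(2048, 1))
bias = np.zeros(1)
probabilities = sigmoid(pooled @ weights + bias)
print(pooled.shape, probabilities.shape)
```

Each image ends up as one vector of 2,048 numbers, and the final layer maps that vector to a single value between 0 and 1.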
Since we have only two classes, we use binary_crossentropy as the loss.
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
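For intuition, binary cross-entropy averages -[y log(p) + (1 - y) log(1 - p)] over the samples, where y is the label (0 or 1) and p is the predicted probability. The NumPy sketch below is an illustration of that formula, not Keras's own implementation; the clipping constant eps is an assumption chosen to avoid log(0).

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)] over all samples."""
    y_true = np.asarray(y_true, dtype=np.float64)
    # Clip predictions so that log(0) never occurs.
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1 - eps)
    losses = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return losses.mean()


confident_right = binary_crossentropy([0, 1], [0.1, 0.9])
unsure = binary_crossentropy([0, 1], [0.5, 0.5])
confident_wrong = binary_crossentropy([0, 1], [0.9, 0.1])
print(confident_right, unsure, confident_wrong)
```

The loss is small for confident correct predictions, equals log 2 when every prediction is 0.5, and grows quickly for confident wrong predictions.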
Let's take a look at what we have built.
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
resnet50v2 (Functional)      (None, 2048)              23564800
_________________________________________________________________
dense (Dense)                (None, 1)                 2049
=================================================================
Total params: 23,566,849
Trainable params: 2,049
Non-trainable params: 23,564,800
_________________________________________________________________
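The parameter counts in the summary can be checked by hand: the only trainable layer is the final Dense layer, which has one weight per input feature plus one bias. A small arithmetic sketch, using only the numbers printed in the summary:

```python
# Numbers taken from the summary above.
features = 2048              # output size of the pooled ResNet50V2 base
resnet_params = 23_564_800   # frozen, hence non-trainable

# Dense(1): one weight per input feature plus one bias.
dense_params = features * 1 + 1
total_params = resnet_params + dense_params
print(dense_params, total_params)
```

So freezing the base leaves only 2,049 numbers to learn, which helps explain why training is so fast.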
model.fit(x_train, y_train, batch_size=128, epochs=5)
Epoch 1/5
7/7 [==============================] - 38s 457ms/step - loss: 0.4688 - accuracy: 0.7862
Epoch 2/5
7/7 [==============================] - 2s 297ms/step - loss: 0.1920 - accuracy: 0.9513
Epoch 3/5
7/7 [==============================] - 2s 295ms/step - loss: 0.1002 - accuracy: 0.9837
Epoch 4/5
7/7 [==============================] - 2s 302ms/step - loss: 0.0662 - accuracy: 0.9887
Epoch 5/5
7/7 [==============================] - 2s 301ms/step - loss: 0.0509 - accuracy: 0.9925
<tensorflow.python.keras.callbacks.History at 0x7fc1760b4c90>
loss, acc = model.evaluate(x_test, y_test)
print(f"Test loss: {loss:.4f}")
print(f"Test accuracy: {acc*100:.2f}%")
7/7 [==============================] - 2s 183ms/step - loss: 0.0640 - accuracy: 0.9800
Test loss: 0.0640
Test accuracy: 98.00%
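The numbers in the training and evaluation logs can be reproduced with a little arithmetic, and the accuracy metric itself is easy to sketch. The block below assumes that evaluate uses its default batch size of 32 and that a prediction counts as "dog" when the sigmoid output exceeds 0.5; binary_accuracy is a hypothetical helper, and the demo labels and probabilities are made up.

```python
import math

import numpy as np

n_test = round(1000 * 0.2)                # 200 test photos
n_train = 1000 - n_test                   # 800 training photos

train_steps = math.ceil(n_train / 128)    # batch_size=128 was passed to fit
eval_steps = math.ceil(n_test / 32)       # assumes evaluate's default batch size of 32
n_correct = round(0.98 * n_test)          # 98.00% of the test photos
print(train_steps, eval_steps, n_correct)


def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of samples whose thresholded probability matches the label."""
    y_pred = (np.asarray(y_prob) > threshold).astype(int)
    return float((y_pred == np.asarray(y_true)).mean())


# Made-up labels and probabilities, only to exercise the function.
demo_true = [0, 0, 1, 1, 1]
demo_prob = [0.1, 0.7, 0.8, 0.6, 0.3]
demo_acc = binary_accuracy(demo_true, demo_prob)
print(demo_acc)
```

Both progress bars show 7 steps, and 98.00% accuracy on 200 test photos corresponds to 196 correct predictions.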
!pip install gradio
Successfully installed Flask-Cors-3.0.10 Flask-Login-0.5.0 analytics-python-1.4.0 backoff-1.10.0 bcrypt-3.2.0 cryptography-3.4.7 ffmpy-0.3.0 flask-cachebuster-1.0.0 gradio-2.2.7 markdown2-2.4.0 monotonic-1.6 paramiko-2.7.2 pycryptodome-3.10.1 pynacl-1.4.0
import gradio as gr
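labels = ['cat', 'dog']
def classify_image(inp):
    # Add a batch dimension: (224, 224, 3) -> (1, 224, 224, 3)
    inp = inp.reshape((-1, 224, 224, 3))
    inp = preprocess_input(inp)
    # The sigmoid output p is the predicted probability of "dog"
    p = model.predict(inp).flatten()[0]
    return {'cat': float(1-p), 'dog': float(p)}
# gr.inputs / gr.outputs is the API of the Gradio 2.x release installed above
image = gr.inputs.Image(shape=(224, 224), label="Photo of a dog or a cat")
label = gr.outputs.Label(label="AI prediction")
gr.Interface(fn=classify_image, inputs=image, outputs=label,
             title="AI Dog and Cat Classifier",
             description="Upload a photo of a dog or a cat and see whether I can tell which it is!"
             ).launch(debug=True)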
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
This share link will expire in 24 hours. If you need a permanent link, visit: https://gradio.app/introducing-hosted (NEW!)
Running on External URL: https://24744.gradio.app
Interface loading below...
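Two small steps inside classify_image are worth seeing in isolation: the reshape that adds a batch dimension, and the conversion of the single sigmoid output into a score for each class. The sketch below uses a blank stand-in image and a made-up model output; to_label_scores is a hypothetical helper, and the English keys are used here only for illustration.

```python
import numpy as np


def to_label_scores(p):
    """Turn a single sigmoid output p = P(dog) into scores for both classes."""
    return {'cat': float(1 - p), 'dog': float(p)}


# A blank stand-in image with the shape the interface is configured to pass in.
single_image = np.zeros((224, 224, 3), dtype=np.uint8)
batch = single_image.reshape((-1, 224, 224, 3))
print(batch.shape)

scores = to_label_scores(0.98)   # 0.98 is a made-up model output
print(scores)
```

Because the model has a single output, the two scores always sum to 1, which is the form the Label output expects for displaying class confidences.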