!pip install -U tensorflow-probability
!pip install typing-extensions --upgrade
!pip install gradio
!pip install langchain
Here we import some packages; don't worry about the details for now.
%matplotlib inline
# Standard data analysis and plotting packages
import numpy as np
import matplotlib.pyplot as plt
# Neural network components
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
# For interactive widgets
from ipywidgets import interact_manual
Keras conveniently ships with the MNIST dataset, so we can load it like this (the first time takes a little while to download).
(x_train, y_train), (x_test, y_test) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11490434/11490434 [==============================] - 2s 0us/step
Let's check that the training set has 60,000 samples and the test set has 10,000.
print(f'Training set: {len(x_train)} samples')
print(f'Test set: {len(x_test)} samples')
Training set: 60000 samples
Test set: 10000 samples
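Each input (x) is a 28x28 image of a single handwritten digit from 0 to 9, and the output (y) is, of course, the "correct answer" (the label). Let's look at what the input x and the output y of a given training sample look like.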
def show_xy(n=0):
    X = x_train[n]
    # Hide the axis ticks and draw the image in grayscale
    plt.xticks([], [])
    plt.yticks([], [])
    plt.imshow(X, cmap='Greys')
    print(f'Label y for this sample: {y_train[n]}')

interact_manual(show_xy, n=(0,59999));
def show_data(n=100):
    X = x_train[n]
    # Print the raw 28x28 pixel values (0 = background, up to 255 = ink)
    print(X)

interact_manual(show_data, n=(0,59999));
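We now want a standard (fully connected) neural network to learn handwriting recognition. Each original sample is a 28x28 matrix (array), but a standard neural network only accepts "flat" input, i.e., a vector of length 28x28 = 784. So we use reshape to adjust the data, and we also divide by 255 so that the pixel values, originally 0 to 255, fall between 0 and 1.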
x_train = x_train.reshape(60000, 784)/255
x_test = x_test.reshape(10000, 784)/255
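To see what reshape and the division by 255 actually do, here is a small sketch on random stand-in arrays (not real MNIST data). It shows that flattening is row-major, so pixel (r, c) lands at index r*28 + c, and that a 784-vector can be reshaped back into a 28x28 image.

```python
import numpy as np

# Two fake 28x28 "images" with pixel values 0-255 (random stand-ins, not MNIST)
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(2, 28, 28), dtype=np.uint8)

# Flatten each image into a 784-long row, then scale to [0, 1]
flat = images.reshape(len(images), 784) / 255
print(flat.shape)  # (2, 784)

# Row-major order: pixel (r, c) ends up at index r*28 + c
r, c = 5, 17
print(flat[0, r * 28 + c] == images[0, r, c] / 255)  # True

# Reshaping back recovers the 28x28 layout (needed later for plotting)
restored = (flat[0].reshape(28, 28) * 255).round().astype(np.uint8)
print(np.array_equal(restored, images[0]))  # True
```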
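We might think the function we want to learn has this form:

$$\hat{f} \colon \mathbb{R}^{784} \to \mathbb{R}$$

Actually, this isn't a good idea! Why not? Suppose our input x is an image of a 0. Since a trained neural network always has some error, we might get:

$$\hat{f}(x) = 0.5$$

Does that mean it could be a 0, or it could be a 1?! But 0 and 1 don't look alike at all. In other words, treating a classification problem this way doesn't make sense!

So instead we use "one-hot encoding": each label becomes a 10-dimensional vector with a 1 in the position of the digit and 0 everywhere else, that is,

$$0 \mapsto (1, 0, 0, \ldots, 0), \quad 1 \mapsto (0, 1, 0, \ldots, 0), \quad \ldots, \quad 9 \mapsto (0, 0, \ldots, 0, 1)$$

and so on. Because essentially every classification problem needs this step, Keras already provides a function for it!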
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
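One way to see what to_categorical produces: the one-hot vector for digit k is just row k of the 10x10 identity matrix. This NumPy-only sketch is an equivalent reformulation of the idea, not Keras's internal implementation.

```python
import numpy as np

labels = np.array([0, 3, 9])

# Row k of the 10x10 identity matrix is the one-hot vector for digit k
one_hot = np.eye(10, dtype=np.float32)[labels]
print(one_hot[2])  # [0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]

# Each row sums to 1, and argmax recovers the original label
print(one_hot.sum(axis=1))     # [1. 1. 1.]
print(one_hot.argmax(axis=1))  # [0 3 9]
```

The argmax step is exactly how we will later turn the network's 10 outputs back into a single digit.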
Let's look at the (now one-hot) label of one of the training samples.
n = 87
y_train[n]
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.], dtype=float32)
Just as we expected (sample 87 is a 9)! Now we can build our neural network.
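We have decided that our function has the form

$$\hat{f} \colon \mathbb{R}^{784} \to \mathbb{R}^{10}$$

And since we said we would start with a standard neural network, all that remains is to choose how many hidden layers to use, how many neurons each layer has, and which activation function to use. If we pick ReLU as the activation function, designing the network comes down to specifying the number of hidden layers and the number of neurons per layer!

Once the design is settled, we simply describe it to TensorFlow!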
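#@title Design your neural network
# 隱藏層數 = number of hidden layers; 神經元1-3 = neurons in hidden layers 1-3 (the model below uses 20 per layer)
隱藏層數 = 3 #@param{type:"integer"}
神經元1 = 20 #@param{type:"integer"}
神經元2 = 20 #@param{type:"integer"}
神經元3 = 20 #@param{type:"integer"}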
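Just as with regression and machine learning before, we open up a "function learning machine." A standard network that passes data forward one layer at a time is called Sequential, so we start with an empty Sequential model.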
model = Sequential()
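Each call to add appends one layer, starting with the first hidden layer. For the first hidden layer, TensorFlow has no way of guessing that the input has 784 features, so we have to tell it (input_dim=784).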
model.add(Dense(20, input_dim=784, activation='relu'))
From the second layer on, we no longer need to give the number of inputs (it is simply the number of neurons in the previous layer).
model.add(Dense(20, activation='relu'))
model.add(Dense(20, activation='relu'))
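Each Dense layer computes activation(x · W + b), where W is a weight matrix of shape (inputs, neurons) and b is a bias vector. The sketch below chains three such ReLU layers in plain NumPy, with random untrained weights and a random fake input (both assumptions for illustration), to show why each layer's input size is simply the previous layer's output size.

```python
import numpy as np

rng = np.random.default_rng(0)

def dense_relu(x, W, b):
    """One fully connected layer with ReLU: relu(x @ W + b)."""
    return np.maximum(0, x @ W + b)

# Random (untrained) weights with the same shapes as our hidden layers: 784 -> 20 -> 20 -> 20
sizes = [784, 20, 20, 20]
params = [(rng.normal(0, 0.05, size=(m, n)), np.zeros(n))
          for m, n in zip(sizes[:-1], sizes[1:])]

x = rng.random((1, 784))  # one fake flattened image
h = x
for W, b in params:
    h = dense_relu(h, W, b)  # output of this layer becomes input of the next
    print(h.shape)           # (1, 20) each time
```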
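The output consists of 10 numbers, so the output layer has 10 neurons! And if our network's output is

$$(y_1, y_2, \ldots, y_{10})$$

we would also like

$$\sum_{i=1}^{10} y_i = 1$$

Is that possible? It turns out to be easy: just use softmax as the activation function!!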
model.add(Dense(10, activation='softmax'))
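Softmax turns any 10 scores z into y_i = e^{z_i} / Σ_j e^{z_j}. Every y_i is positive and they sum to 1, and the largest score still gets the largest probability. A minimal NumPy sketch (the scores are arbitrary made-up numbers):

```python
import numpy as np

def softmax(z):
    # Subtracting the max avoids overflow in exp; it cancels out in the ratio
    e = np.exp(z - np.max(z))
    return e / e.sum()

# Arbitrary raw scores for the 10 classes
z = np.array([2.0, 1.0, 0.1, -1.0, 0.0, 0.5, 3.0, -2.0, 1.5, 0.2])
y = softmax(z)
print(y.sum())     # sums to 1 (up to floating-point rounding)
print(y.argmax())  # 6: the largest score keeps the largest probability
```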
With that, our first neural network is built!
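What's different from before is that we also need to call compile to officially finish building the network. At this step we have to decide a few things: the loss function, for which we use mean squared error (mse), and the optimizer, for which we use plain stochastic gradient descent (SGD) with learning rate 0.087.

To see results while training, we also set metrics=['accuracy']. This setting only reports accuracy during training; it has essentially nothing to do with how the network learns.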
model.compile(loss='mse', optimizer=SGD(learning_rate=0.087), metrics=['accuracy'])
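With one-hot labels, mse averages the squared differences between the 10 target numbers and the 10 outputs for each sample (Keras then averages over the batch). As a worked example, assume an untrained-looking network that outputs 0.1 for every class on an image of a 9:

```python
import numpy as np

y_true = np.zeros(10)
y_true[9] = 1.0                # one-hot label for digit 9
y_uniform = np.full(10, 0.1)   # assumed output: equal probability for every class

# Mean squared error over the 10 outputs: (9 * 0.1**2 + 0.9**2) / 10
mse = np.mean((y_true - y_uniform) ** 2)
print(round(mse, 4))  # 0.09
```

For comparison, the first-epoch loss in the training log below is 0.0891, close to this "know nothing yet" value.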
We can inspect the network's architecture to confirm it matches what we had in mind.
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense (Dense)               (None, 20)                15700
 dense_1 (Dense)             (None, 20)                420
 dense_2 (Dense)             (None, 20)                420
 dense_3 (Dense)             (None, 10)                210
=================================================================
Total params: 16750 (65.43 KB)
Trainable params: 16750 (65.43 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Let's quickly check whether the parameter counts match what we expect!
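A Dense layer with n_in inputs and n_out neurons has n_in × n_out weights plus n_out biases. Computing this for our 784 → 20 → 20 → 20 → 10 design reproduces the Param # column of the summary; the KB figure assumes 4-byte float32 parameters.

```python
# Weights (n_in * n_out) plus one bias per neuron (n_out)
def dense_params(n_in, n_out):
    return n_in * n_out + n_out

layer_sizes = [784, 20, 20, 20, 10]
counts = [dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]
print(counts)                 # [15700, 420, 420, 210]
print(sum(counts))            # 16750
print(sum(counts) * 4 / 1024) # 65.4296875, i.e. the "65.43 KB" in the summary
```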
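Congratulations! We've completed our first neural network. When it's time to train, you'll find it isn't like before, where we just fed all the training data in without further thought. Here we still have two things to decide. First, the batch size (batch_size), i.e., how many samples the network sees before each parameter update; let's update once every 100 samples. Second, the number of training epochs (epochs), i.e., how many times we go through the whole training set; let's try 10.

Now comes the most exciting part. Be prepared to wait...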
model.fit(x_train, y_train, batch_size=100, epochs=10)
Epoch 1/10
600/600 [==============================] - 9s 5ms/step - loss: 0.0891 - accuracy: 0.1645
Epoch 2/10
600/600 [==============================] - 2s 3ms/step - loss: 0.0860 - accuracy: 0.3095
Epoch 3/10
600/600 [==============================] - 2s 3ms/step - loss: 0.0758 - accuracy: 0.4489
Epoch 4/10
600/600 [==============================] - 2s 3ms/step - loss: 0.0555 - accuracy: 0.6170
Epoch 5/10
600/600 [==============================] - 2s 3ms/step - loss: 0.0369 - accuracy: 0.7638
Epoch 6/10
600/600 [==============================] - 2s 3ms/step - loss: 0.0272 - accuracy: 0.8270
Epoch 7/10
600/600 [==============================] - 3s 5ms/step - loss: 0.0226 - accuracy: 0.8544
Epoch 8/10
600/600 [==============================] - 3s 5ms/step - loss: 0.0201 - accuracy: 0.8699
Epoch 9/10
600/600 [==============================] - 2s 4ms/step - loss: 0.0183 - accuracy: 0.8816
Epoch 10/10
600/600 [==============================] - 2s 3ms/step - loss: 0.0169 - accuracy: 0.8911
<keras.src.callbacks.History at 0x7d8e760d9030>
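The "600/600" in each epoch of the log is the number of batches (steps) per epoch. A quick check of that arithmetic, using the values passed to fit:

```python
import math

n_train = 60000
batch_size = 100
epochs = 10

# One parameter update per batch; a final partial batch would count as one more step
steps_per_epoch = math.ceil(n_train / batch_size)
print(steps_per_epoch)           # 600, the "600/600" in the log
print(steps_per_epoch * epochs)  # 6000 parameter updates in total
```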
Let's look at what our lovely neural network has learned in a fancier way. If you have questions about the commands, see 《少年Py的大冒險:成為Python數據分析達人的第一門課》.
loss, acc = model.evaluate(x_test, y_test)
313/313 [==============================] - 1s 3ms/step - loss: 0.0156 - accuracy: 0.8997
print(f"Test accuracy: {acc*100:.2f}%")
Test accuracy: 89.97%
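The variable predict holds our network's predictions. model.predict returns 10 probabilities for each image, and argmax then picks the position of the largest one, i.e., the predicted digit.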
predict = np.argmax(model.predict(x_test), axis=-1)
313/313 [==============================] - 1s 2ms/step
predict
array([7, 2, 1, ..., 4, 8, 6])
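What argmax with axis=-1 does to the output of model.predict can be seen on two made-up rows of softmax probabilities (hypothetical numbers, not actual model output):

```python
import numpy as np

# Two made-up rows of 10 class probabilities, standing in for model.predict output
probs = np.array([
    [0.01, 0.02, 0.05, 0.01, 0.01, 0.01, 0.01, 0.85, 0.02, 0.01],
    [0.10, 0.05, 0.60, 0.05, 0.05, 0.05, 0.02, 0.03, 0.03, 0.02],
])
pred = np.argmax(probs, axis=-1)  # index of the largest probability in each row
print(pred)  # [7 2]
```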
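Don't forget that each sample in x_test has been turned into a 784-dimensional vector; we need to reshape it back into a 28x28 matrix before we can display it as an image!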
def test(測試編號):
    # 測試編號 = index of the test sample to display
    plt.imshow(x_test[測試編號].reshape(28, 28), cmap='Greys')
    print('Network prediction:', predict[測試編號])

test(87)
Network prediction: 3
interact_manual(test, 測試編號=(0, 9999));
So how does the network do on the test set as a whole? We can give it a "final exam."
score = model.evaluate(x_test, y_test)
313/313 [==============================] - 1s 2ms/step - loss: 0.0156 - accuracy: 0.8997
print('loss:', score[0])
print('accuracy:', score[1])
loss: 0.01560008805245161
accuracy: 0.8996999859809875
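The accuracy figure is the fraction of samples whose predicted digit (argmax of the output) matches the true digit (argmax of the one-hot label); as far as we know, this is how Keras interprets 'accuracy' for one-hot targets like ours. A sketch with made-up labels and predictions (the helper name accuracy is ours):

```python
import numpy as np

def accuracy(y_onehot, probs):
    # Fraction of samples whose predicted class matches the true class
    return np.mean(np.argmax(probs, axis=-1) == np.argmax(y_onehot, axis=-1))

y_onehot = np.eye(10)[[7, 2, 1, 0]]            # made-up true labels: 7, 2, 1, 0
probs = np.eye(10)[[7, 2, 1, 6]] * 0.9 + 0.01  # made-up outputs; the last one guesses 6
print(accuracy(y_onehot, probs))  # 0.75
```

Applied to the real data as accuracy(y_test, model.predict(x_test)), this should agree with the score[1] reported above.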
import gradio as gr
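def recognize_digit(img):
    # Assumes img is a 28x28 grayscale array with values 0-255, as the "sketchpad"
    # shortcut returned in Gradio 3.x; Gradio 4.x may pass a dict instead and need extra preprocessing
    img = np.asarray(img, dtype='float32').reshape(1, 784) / 255  # same scaling as the training data
    prediction = model.predict(img).flatten()
    labels = list('0123456789')
    # Map each digit to its predicted probability for the "label" output
    return {labels[i]: float(prediction[i]) for i in range(10)}

gr.Interface(fn=recognize_digit, inputs="sketchpad", outputs="label").launch(share=True, debug=True)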
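What the sketchpad hands to recognize_digit depends on the Gradio version, so one option is to first extract a plain image array from it and then convert that array into MNIST's format. The sketch below is a hypothetical helper, to_mnist_input, that assumes dark ink on a light background at any size (RGB or grayscale uint8) and uses Pillow to convert to grayscale, shrink to 28x28, invert (MNIST digits are bright strokes on black), and scale to [0, 1]. It is tested here on a fake canvas, not on real Gradio output.

```python
import numpy as np
from PIL import Image

def to_mnist_input(img):
    """Hypothetical helper: drawn digit (dark ink on light background, uint8 array)
    -> (1, 784) float32 array in [0, 1] with bright ink on a dark background."""
    gray = Image.fromarray(np.asarray(img, dtype=np.uint8)).convert("L")
    small = np.asarray(gray.resize((28, 28)), dtype=np.float32)
    inverted = 255.0 - small  # flip to MNIST's bright-stroke-on-black convention
    return (inverted / 255.0).reshape(1, 784)

# Fake drawing: white 280x280 RGB canvas with a black vertical bar (a crude "1")
canvas = np.full((280, 280, 3), 255, dtype=np.uint8)
canvas[40:240, 130:150] = 0
x = to_mnist_input(canvas)
print(x.shape)  # (1, 784)
```

With such a helper, recognize_digit could call model.predict(to_mnist_input(image_array)) instead of reshaping the raw input directly.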
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch(). Running on public URL: https://5fd54ac1b42ab99d52.gradio.live This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
Keyboard interruption in main thread... closing server. Killing tunnel 127.0.0.1:7862 <> https://5fd54ac1b42ab99d52.gradio.live