using IJuliaPortrayals
using MXNet
# Build the CNN.
# Network input; batches are fed in under the name :data.
data = mx.Variable(:data)

# Convolution block 1: 5x5 convolution (32 filters) -> ReLU -> 2x2 max pooling, stride 2.
conv1 = @mx.chain mx.Convolution(data=data, kernel=(5,5), num_filter=32) =>
                  mx.Activation(act_type=:relu) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

# Convolution block 2: 5x5 convolution (64 filters) -> ReLU -> 2x2 max pooling, stride 2.
conv2 = @mx.chain mx.Convolution(data=conv1, kernel=(5,5), num_filter=64) =>
                  mx.Activation(act_type=:relu) =>
                  mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))

# Fully-connected block 1: flatten the feature maps -> 1024 hidden units -> ReLU,
# then dropout with p = 0.5.
fc1 = @mx.chain mx.Flatten(data=conv2) =>
                mx.FullyConnected(num_hidden=1024) =>
                mx.Activation(act_type=:relu)
dp_fc1 = mx.Dropout(data=fc1, p=0.5)

# Fully-connected block 2: one output unit per digit class (0...9).
fc2 = mx.FullyConnected(data=dp_fc1, num_hidden=10)

# Softmax output layer (with loss), named :softmax.
cnn = mx.SoftmaxOutput(data=fc2, name=:softmax)
MXNet.mx.SymbolicNode(MXNet.mx.MX_SymbolHandle(Ptr{Void} @0x00007f828448a390))
# Render the network definition as a GraphViz diagram in the notebook.
mx.to_graphviz(cnn) |> GraphViz
# Data loading: create the MNIST training and evaluation data providers.
# Alternative: the loader shipped with the MXNet examples
#   include(Pkg.dir("MXNet", "examples", "mnist", "mnist-data.jl"))
#   train_provider, eval_provider = get_mnist_providers(batch_size)
batch_size = 100
flat = false  # keep images as 28x28x1 arrays (not flat vectors) for the conv layers
# :data matches the network's input variable; :softmax_label presumably matches the
# label input implied by SoftmaxOutput(name=:softmax) -- confirm if renaming either.
data_name = :data
label_name = :softmax_label

# Training set: shuffled every epoch.
train_provider = mx.MNISTProvider(
    image="MNIST_data/train-images-idx3-ubyte",
    label="MNIST_data/train-labels-idx1-ubyte",
    data_name=data_name,
    label_name=label_name,
    batch_size=batch_size,
    shuffle=true,
    flat=flat,
    silent=true)

# Evaluation set: NOT shuffled, so prediction columns stay in dataset order.
eval_provider = mx.MNISTProvider(
    image="MNIST_data/t10k-images-idx3-ubyte",
    label="MNIST_data/t10k-labels-idx1-ubyte",
    data_name=data_name,
    label_name=label_name,
    batch_size=batch_size,
    shuffle=false,
    flat=flat,
    silent=true)
MXNet.mx.MXDataProvider(MXNet.mx.MX_DataIterHandle(Ptr{Void} @0x00007f82848a5910),Tuple{Symbol,Tuple}[(:data,(28,28,1,100))],Tuple{Symbol,Tuple}[(:softmax_label,(100,))],100,true,true)
# Model construction and optimization.
# Optimizer: SGD with momentum and a small weight decay.
optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)
# Checkpoint callback; files are written with the prefix "MNIST_CNN3"
# (the training log shows one MNIST_CNN3-NNNN.params file per epoch).
save_checkpoint = mx.do_checkpoint("MNIST_CNN3")
# Feed-forward model wrapping the `cnn` symbol, run on the CPU.
model = mx.FeedForward(cnn, context=mx.cpu())

# Train for 30 epochs, validating on the evaluation set after each one.
mx.fit(model, optimizer, train_provider,
       n_epoch=30,
       eval_data=eval_provider,
       callbacks=[save_checkpoint])
INFO: Start training on [CPU0] INFO: Initializing parameters... INFO: Creating KVStore... INFO: Start training... INFO: == Epoch 001 ========== INFO: ## Training summary INFO: accuracy = 0.7142 INFO: time = 223.0584 seconds INFO: ## Validation summary INFO: accuracy = 0.9740 INFO: Saved checkpoint to 'MNIST_CNN3-0001.params' INFO: == Epoch 002 ========== INFO: ## Training summary INFO: accuracy = 0.9804 INFO: time = 208.5880 seconds INFO: ## Validation summary INFO: accuracy = 0.9865 INFO: Saved checkpoint to 'MNIST_CNN3-0002.params' INFO: == Epoch 003 ========== INFO: ## Training summary INFO: accuracy = 0.9869 INFO: time = 211.7859 seconds INFO: ## Validation summary INFO: accuracy = 0.9892 INFO: Saved checkpoint to 'MNIST_CNN3-0003.params' INFO: == Epoch 004 ========== INFO: ## Training summary INFO: accuracy = 0.9898 INFO: time = 206.5210 seconds INFO: ## Validation summary INFO: accuracy = 0.9915 INFO: Saved checkpoint to 'MNIST_CNN3-0004.params' INFO: == Epoch 005 ========== INFO: ## Training summary INFO: accuracy = 0.9917 INFO: time = 206.9616 seconds INFO: ## Validation summary INFO: accuracy = 0.9890 INFO: Saved checkpoint to 'MNIST_CNN3-0005.params' INFO: == Epoch 006 ========== INFO: ## Training summary INFO: accuracy = 0.9928 INFO: time = 206.5133 seconds INFO: ## Validation summary INFO: accuracy = 0.9909 INFO: Saved checkpoint to 'MNIST_CNN3-0006.params' INFO: == Epoch 007 ========== INFO: ## Training summary INFO: accuracy = 0.9936 INFO: time = 206.3797 seconds INFO: ## Validation summary INFO: accuracy = 0.9908 INFO: Saved checkpoint to 'MNIST_CNN3-0007.params' INFO: == Epoch 008 ========== INFO: ## Training summary INFO: accuracy = 0.9945 INFO: time = 206.2362 seconds INFO: ## Validation summary INFO: accuracy = 0.9901 INFO: Saved checkpoint to 'MNIST_CNN3-0008.params' INFO: == Epoch 009 ========== INFO: ## Training summary INFO: accuracy = 0.9951 INFO: time = 206.4188 seconds INFO: ## Validation summary INFO: accuracy = 0.9910 INFO: Saved 
checkpoint to 'MNIST_CNN3-0009.params' INFO: == Epoch 010 ========== INFO: ## Training summary INFO: accuracy = 0.9962 INFO: time = 206.4152 seconds INFO: ## Validation summary INFO: accuracy = 0.9913 INFO: Saved checkpoint to 'MNIST_CNN3-0010.params' INFO: == Epoch 011 ========== INFO: ## Training summary INFO: accuracy = 0.9966 INFO: time = 206.2151 seconds INFO: ## Validation summary INFO: accuracy = 0.9908 INFO: Saved checkpoint to 'MNIST_CNN3-0011.params' INFO: == Epoch 012 ========== INFO: ## Training summary INFO: accuracy = 0.9964 INFO: time = 206.3259 seconds INFO: ## Validation summary INFO: accuracy = 0.9920 INFO: Saved checkpoint to 'MNIST_CNN3-0012.params' INFO: == Epoch 013 ========== INFO: ## Training summary INFO: accuracy = 0.9967 INFO: time = 206.3057 seconds INFO: ## Validation summary INFO: accuracy = 0.9907 INFO: Saved checkpoint to 'MNIST_CNN3-0013.params' INFO: == Epoch 014 ========== INFO: ## Training summary INFO: accuracy = 0.9974 INFO: time = 206.4475 seconds INFO: ## Validation summary INFO: accuracy = 0.9909 INFO: Saved checkpoint to 'MNIST_CNN3-0014.params' INFO: == Epoch 015 ========== INFO: ## Training summary INFO: accuracy = 0.9979 INFO: time = 206.1776 seconds INFO: ## Validation summary INFO: accuracy = 0.9921 INFO: Saved checkpoint to 'MNIST_CNN3-0015.params' INFO: == Epoch 016 ========== INFO: ## Training summary INFO: accuracy = 0.9979 INFO: time = 206.1897 seconds INFO: ## Validation summary INFO: accuracy = 0.9908 INFO: Saved checkpoint to 'MNIST_CNN3-0016.params' INFO: == Epoch 017 ========== INFO: ## Training summary INFO: accuracy = 0.9980 INFO: time = 206.0748 seconds INFO: ## Validation summary INFO: accuracy = 0.9921 INFO: Saved checkpoint to 'MNIST_CNN3-0017.params' INFO: == Epoch 018 ========== INFO: ## Training summary INFO: accuracy = 0.9983 INFO: time = 205.9644 seconds INFO: ## Validation summary INFO: accuracy = 0.9931 INFO: Saved checkpoint to 'MNIST_CNN3-0018.params' INFO: == Epoch 019 ========== INFO: ## 
Training summary INFO: accuracy = 0.9982 INFO: time = 208.8835 seconds INFO: ## Validation summary INFO: accuracy = 0.9887 INFO: Saved checkpoint to 'MNIST_CNN3-0019.params' INFO: == Epoch 020 ========== INFO: ## Training summary INFO: accuracy = 0.9976 INFO: time = 205.8203 seconds INFO: ## Validation summary INFO: accuracy = 0.9903 INFO: Saved checkpoint to 'MNIST_CNN3-0020.params' INFO: == Epoch 021 ========== INFO: ## Training summary INFO: accuracy = 0.9984 INFO: time = 206.7411 seconds INFO: ## Validation summary INFO: accuracy = 0.9925 INFO: Saved checkpoint to 'MNIST_CNN3-0021.params' INFO: == Epoch 022 ========== INFO: ## Training summary INFO: accuracy = 0.9988 INFO: time = 206.3165 seconds INFO: ## Validation summary INFO: accuracy = 0.9914 INFO: Saved checkpoint to 'MNIST_CNN3-0022.params' INFO: == Epoch 023 ========== INFO: ## Training summary INFO: accuracy = 0.9983 INFO: time = 207.6620 seconds INFO: ## Validation summary INFO: accuracy = 0.9923 INFO: Saved checkpoint to 'MNIST_CNN3-0023.params' INFO: == Epoch 024 ========== INFO: ## Training summary INFO: accuracy = 0.9977 INFO: time = 205.8147 seconds INFO: ## Validation summary INFO: accuracy = 0.9927 INFO: Saved checkpoint to 'MNIST_CNN3-0024.params' INFO: == Epoch 025 ========== INFO: ## Training summary INFO: accuracy = 0.9986 INFO: time = 205.6267 seconds INFO: ## Validation summary INFO: accuracy = 0.9919 INFO: Saved checkpoint to 'MNIST_CNN3-0025.params' INFO: == Epoch 026 ========== INFO: ## Training summary INFO: accuracy = 0.9986 INFO: time = 1883.2102 seconds INFO: ## Validation summary INFO: accuracy = 0.9924 INFO: Saved checkpoint to 'MNIST_CNN3-0026.params' INFO: == Epoch 027 ========== INFO: ## Training summary INFO: accuracy = 0.9992 INFO: time = 208.2260 seconds INFO: ## Validation summary INFO: accuracy = 0.9930 INFO: Saved checkpoint to 'MNIST_CNN3-0027.params' INFO: == Epoch 028 ========== INFO: ## Training summary INFO: accuracy = 0.9987 INFO: time = 207.1775 seconds INFO: ## 
Validation summary INFO: accuracy = 0.9925 INFO: Saved checkpoint to 'MNIST_CNN3-0028.params' INFO: == Epoch 029 ========== INFO: ## Training summary INFO: accuracy = 0.9993 INFO: time = 206.7774 seconds INFO: ## Validation summary INFO: accuracy = 0.9932 INFO: Saved checkpoint to 'MNIST_CNN3-0029.params' INFO: == Epoch 030 ========== INFO: ## Training summary INFO: accuracy = 0.9995 INFO: time = 206.6229 seconds INFO: ## Validation summary INFO: accuracy = 0.9928 INFO: Saved checkpoint to 'MNIST_CNN3-0030.params'
# Prediction
# Run the trained model over the whole evaluation set. `probs` holds one column per
# sample and one row per class (10x10000 for the MNIST test set).
probs = mx.predict(model, eval_provider)
10x10000 Array{Float32,2}: 6.6912e-25 1.84421e-19 3.34017e-14 … 6.27268e-23 5.81799e-18 7.64353e-21 6.56743e-20 1.0 8.11345e-25 2.53667e-19 5.96325e-20 1.0 7.18634e-14 1.18043e-28 3.24815e-19 7.11701e-20 5.54137e-23 1.69543e-18 2.38322e-18 1.05279e-21 1.03662e-19 1.21309e-25 8.94513e-13 1.3275e-28 1.67418e-19 2.30987e-25 1.73122e-29 1.34319e-14 … 1.0 5.61854e-19 1.45137e-27 7.20162e-20 9.96349e-14 2.55676e-18 1.0 1.0 1.75189e-23 1.66512e-12 7.50446e-26 2.32328e-25 6.29969e-28 4.04601e-22 9.14715e-14 8.67362e-15 5.61445e-20 2.57974e-17 1.27717e-27 3.43039e-16 4.33827e-21 3.61943e-22
# Check prediction accuracy.
# Collect the ground-truth labels of the evaluation set, batch by batch. The
# provider was created with shuffle=false, so this order matches the columns of `probs`.
labels = let label_batches = Array[]
    for batch in eval_provider
        push!(label_batches, copy(mx.get(eval_provider, batch, :softmax_label)))
    end
    vcat(label_batches...)
end

# `probs` has exactly one column per predicted sample.
# NOTE(review): if the final batch is padded, mx.get presumably returns the padded
# labels as well; keep only as many labels as there are predictions so the lookup
# below cannot index past the end of `probs`.
n_predicted = size(probs, 2)
length(labels) >= n_predicted ||
    throw(DimensionMismatch("only $(length(labels)) labels for $n_predicted predictions"))
labels = labels[1:n_predicted]

# Labels are digits 0...9, while indmax returns a 1-based row index, hence the +1.
# `count` avoids mutating a global counter inside a top-level loop.
correct = count(i -> indmax(probs[:, i]) == labels[i] + 1, eachindex(labels))
accuracy = 100correct / length(labels)
println(mx.format("Accuracy on eval set: {1:.2f}%", accuracy))
Accuracy on eval set: 99.43%
# Take the first evaluation batch and cut out its first image, slicing with 1:1 so
# the trailing batch dimension is kept (a 28x28x1x1 array).
batch = first(eval_provider)
image = let batch_data = copy(mx.get(eval_provider, batch, :data))
    batch_data[:, :, :, 1:1]
end
# Sanity check done during development: all(x -> 0.0 <= x <= 1.0, vec(image)) was true.
# With flat (784xN) data the equivalent would be mx.ArrayDataProvider(images[:, 1:1]).
provider = mx.ArrayDataProvider(image)
MXNet.mx.ArrayDataProvider(Array{Float32,N}[ 784x1 Array{Float32,2}: 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ⋮ 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0],[:data],Array{Float32,N}[],Symbol[],1,1,false,0.0f0,0.0f0,[mx.NDArray(28,28,1,1)],MXNet.mx.NDArray[])
# Predict the single image; the result is one column of 10 scores
# (row k is the score for digit k-1).
mx.predict(model, provider)
10x1 Array{Float32,2}: 6.6912e-25 7.64353e-21 5.96325e-20 7.11701e-20 1.03662e-19 2.30987e-25 1.45137e-27 1.0 6.29969e-28 2.57974e-17
import JSON
"""
    classify(a::Vector{Float32})

Classify one 28x28 image, given as a flat vector of 28*28 pixel values, with the
trained global `model`. Return the 10 output scores as a JSON array string, where
entry k is the score for digit k-1.

Throws `ArgumentError` if `a` does not contain exactly 28*28 values.
"""
function classify(a::Vector{Float32})
    # The input presumably arrives from an external front end (e.g. the
    # classify_canvas.html page), so reject wrong sizes with a clear message
    # instead of letting reshape fail with a DimensionMismatch.
    npixels = 28 * 28
    length(a) == npixels ||
        throw(ArgumentError("expected $npixels pixel values for a 28x28 image, got $(length(a))"))
    # Same (28, 28, 1, batch) layout as the evaluation data, with a batch of one.
    image = reshape(a, (28, 28, 1, 1))
    result = mx.predict(model, mx.ArrayDataProvider(image))
    return JSON.json(vec(result))
end

"""
    classify(a::AbstractVector)

Convert numeric pixel values (for example the array parsed from a JSON string) to
`Vector{Float32}` and classify them.
"""
function classify(a::AbstractVector)
    return classify(convert(Vector{Float32}, a))
end

"""
    classify(s::AbstractString)

Parse `s` as a JSON array of pixel values and classify it.
"""
function classify(s::AbstractString)
    return classify(JSON.parse(s))
end
classify (generic function with 3 methods)
# Show the front-end page stored in classify_canvas.html inline in the notebook.
HTML(open("classify_canvas.html") do io
    readall(io)
end)
Result: