MNIST 数据集 是一个手写数字组成的数据集,现在被当作一个机器学习算法评测的基准数据集。
这是一个下载并解压数据的脚本:
%%file download_mnist.py
import os
import os.path
import urllib
import gzip
import shutil
if not os.path.exists('mnist'):
os.mkdir('mnist')
def download_and_gzip(name):
if not os.path.exists(name + '.gz'):
urllib.urlretrieve('http://yann.lecun.com/exdb/' + name + '.gz', name + '.gz')
if not os.path.exists(name):
with gzip.open(name + '.gz', 'rb') as f_in, open(name, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
download_and_gzip('mnist/train-images-idx3-ubyte')
download_and_gzip('mnist/train-labels-idx1-ubyte')
download_and_gzip('mnist/t10k-images-idx3-ubyte')
download_and_gzip('mnist/t10k-labels-idx1-ubyte')
Overwriting download_mnist.py
可以运行这个脚本来下载和解压数据:
%run download_mnist.py
使用如下的脚本来导入 MNIST 数据,源码地址:
https://github.com/Newmu/Theano-Tutorials/blob/master/load.py
%%file load.py
import numpy as np
import os
datasets_dir = './'
def one_hot(x,n):
if type(x) == list:
x = np.array(x)
x = x.flatten()
o_h = np.zeros((len(x),n))
o_h[np.arange(len(x)),x] = 1
return o_h
def mnist(ntrain=60000,ntest=10000,onehot=True):
data_dir = os.path.join(datasets_dir,'mnist/')
fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
trX = loaded[16:].reshape((60000,28*28)).astype(float)
fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
trY = loaded[8:].reshape((60000))
fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teX = loaded[16:].reshape((10000,28*28)).astype(float)
fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte'))
loaded = np.fromfile(file=fd,dtype=np.uint8)
teY = loaded[8:].reshape((10000))
trX = trX/255.
teX = teX/255.
trX = trX[:ntrain]
trY = trY[:ntrain]
teX = teX[:ntest]
teY = teY[:ntest]
if onehot:
trY = one_hot(trY, 10)
teY = one_hot(teY, 10)
else:
trY = np.asarray(trY)
teY = np.asarray(teY)
return trX,teX,trY,teY
Overwriting load.py
Softmax
回归相当于 Logistic
回归的一个一般化,Logistic
回归处理的是两类问题,Softmax
回归处理的是 N
类问题。
Logistic
回归输出的是标签为 1 的概率(标签为 0 的概率也就知道了),对应地,对 N 类问题 Softmax
输出的是每个类对应的概率。
具体的内容,可以参考 UFLDL
教程:
http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92
import theano
from theano import tensor as T
import numpy as np
from load import mnist
Using gpu device 1: Tesla C2075 (CNMeM is disabled)
我们来看它具体的实现。
这两个函数一个是将数据转化为 GPU
计算的类型,另一个是初始化权重:
def floatX(X):
return np.asarray(X, dtype=theano.config.floatX)
def init_weights(shape):
return theano.shared(floatX(np.random.randn(*shape) * 0.01))
Softmax
的模型在 theano
中已经实现好了:
A = T.matrix()
B = T.nnet.softmax(A)
test_softmax = theano.function([A], B)
a = floatX(np.random.rand(3, 4))
b = test_softmax(a)
print b.shape
# 行和
print b.sum(1)
(3, 4) [ 1.00000012 1. 1. ]
softmax
函数会按照行对矩阵进行 Softmax
归一化。
所以我们的模型为:
def model(X, w):
return T.nnet.softmax(T.dot(X, w))
导入数据:
trX, teX, trY, teY = mnist(onehot=True)
定义变量,并初始化权重:
X = T.fmatrix()
Y = T.fmatrix()
w = init_weights((784, 10))
定义模型输出和预测:
py_x = model(X, w)
y_pred = T.argmax(py_x, axis=1)
损失函数为多类的交叉熵,这个在 theano
中也被定义好了:
cost = T.mean(T.nnet.categorical_crossentropy(py_x, Y))
gradient = T.grad(cost=cost, wrt=w)
update = [[w, w - gradient * 0.05]]
编译 train
和 predict
函数:
train = theano.function(inputs=[X, Y], outputs=cost, updates=update, allow_input_downcast=True)
predict = theano.function(inputs=[X], outputs=y_pred, allow_input_downcast=True)
迭代 100 次,测试集正确率为 0.925:
for i in range(100):
for start, end in zip(range(0, len(trX), 128), range(128, len(trX), 128)):
cost = train(trX[start:end], trY[start:end])
print "{0:03d}".format(i), np.mean(np.argmax(teY, axis=1) == predict(teX))
000 0.8862 001 0.8985 002 0.9042 003 0.9084 004 0.9104 005 0.9121 006 0.9121 007 0.9142 008 0.9158 009 0.9163 010 0.9162 011 0.9166 012 0.9171 013 0.9176 014 0.9182 015 0.9182 016 0.9184 017 0.9188 018 0.919 019 0.919 020 0.9194 021 0.9201 022 0.9204 023 0.9203 024 0.9205 025 0.9207 026 0.9207 027 0.9209 028 0.9214 029 0.9213 030 0.9212 031 0.9211 032 0.9217 033 0.9217 034 0.9217 035 0.922 036 0.9222 037 0.922 038 0.922 039 0.9218 040 0.9219 041 0.9223 042 0.9225 043 0.9226 044 0.9227 045 0.9225 046 0.9227 047 0.9231 048 0.9231 049 0.9231 050 0.9232 051 0.9232 052 0.9231 053 0.9231 054 0.9233 055 0.9233 056 0.9237 057 0.9239 058 0.9239 059 0.9239 060 0.924 061 0.9242 062 0.9242 063 0.9243 064 0.9243 065 0.9244 066 0.9244 067 0.9244 068 0.9245 069 0.9244 070 0.9244 071 0.9245 072 0.9244 073 0.9243 074 0.9243 075 0.9244 076 0.9243 077 0.9242 078 0.9244 079 0.9244 080 0.9243 081 0.9242 082 0.9239 083 0.9241 084 0.9242 085 0.9243 086 0.9244 087 0.9243 088 0.9243 089 0.9244 090 0.9246 091 0.9246 092 0.9246 093 0.9247 094 0.9246 095 0.9246 096 0.9246 097 0.9246 098 0.9246 099 0.9248