import numpy as np
import pandas as pd
import torch
from IPython.display import clear_output
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
import torch.optim as optim
from torchvision import datasets, transforms
# download and transform train dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../mnist_data',
                   download=True,
                   train=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),                      # first, convert image to PyTorch tensor
                       transforms.Normalize((0.1307,), (0.3081,))  # normalize inputs
                   ])),
    batch_size=10,
    shuffle=True)
# download and transform test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../mnist_data',
                   download=True,
                   train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),                      # first, convert image to PyTorch tensor
                       transforms.Normalize((0.1307,), (0.3081,))  # normalize inputs
                   ])),
    batch_size=10,
    shuffle=True)
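# Optional sanity-check sketch (added, not part of the original notebook): pull one batch from
# train_loader to confirm the tensor shapes and the value range produced by Normalize.
images, labels = next(iter(train_loader))
print(images.shape)                               # expected: torch.Size([10, 1, 28, 28])
print(labels.shape)                               # expected: torch.Size([10])
print(images.min().item(), images.max().item())   # roughly -0.42 .. 2.82 after (x - 0.1307) / 0.3081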
class CNNClassifier(nn.Module):
    """Custom module for a simple convnet classifier"""
    def __init__(self):
        super(CNNClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.dropout = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # input is 28x28x1
        # conv1(kernel=5, filters=10): 28x28x1  -> 24x24x10
        # max_pool(kernel=2):          24x24x10 -> 12x12x10
        # Do not be afraid of the F's - those are just functional wrappers for modules from the nn package
        # Please, see for yourself - http://pytorch.org/docs/_modules/torch/nn/functional.html
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # conv2(kernel=5, filters=20): 12x12x10 -> 8x8x20
        # max_pool(kernel=2):          8x8x20   -> 4x4x20
        x = F.relu(F.max_pool2d(self.dropout(self.conv2(x)), 2))
        # flatten: 4x4x20 = 320
        x = x.view(-1, 320)
        # 320 -> 50
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        # 50 -> 10
        x = self.fc2(x)
        # return log-probabilities; dim=1 avoids the implicit-dimension deprecation warning
        return F.log_softmax(x, dim=1)
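# Optional shape-check sketch (added, not in the original notebook): push a dummy batch through an
# untrained classifier to confirm the arithmetic in forward():
# 28x28 -> conv1(5x5) 24x24 -> pool 12x12 -> conv2(5x5) 8x8 -> pool 4x4, so 20*4*4 = 320.
dummy = torch.zeros(2, 1, 28, 28)   # hypothetical batch of two blank "images"
out = CNNClassifier()(dummy)
print(out.shape)                     # expected: torch.Size([2, 10]) -- one log-probability per class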
# create classifier and optimizer objects
clf = CNNClassifier()
opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
loss_history = []
acc_history = []
avg_loss = []
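# Optional aside (added): count the trainable parameters of the classifier built above.
# conv1: 1*10*5*5 + 10 = 260, conv2: 10*20*5*5 + 20 = 5020,
# fc1: 320*50 + 50 = 16050, fc2: 50*10 + 10 = 510, total = 21840.
n_params = sum(p.numel() for p in clf.parameters() if p.requires_grad)
print(n_params)   # expected: 21840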
def train(epoch):
    clf.train()  # set model in training mode (need this because of dropout)
    # dataset API gives us pythonic batching
    for batch_id, (data, target) in enumerate(train_loader):
        # forward pass, calculate loss and backprop!
        opt.zero_grad()
        preds = clf(data)
        loss = F.nll_loss(preds, target)
        loss.backward()
        loss_history.append(loss.item())  # .item() replaces the deprecated loss.data[0]
        opt.step()
        if batch_id % 100 == 0:
            print(loss.item())
def test(epoch):
    clf.eval()  # set model in inference mode (need this because of dropout)
    test_loss = 0
    correct = 0
    with torch.no_grad():  # replaces the removed `volatile=True` flag
        for data, target in test_loader:
            output = clf(data)
            test_loss += F.nll_loss(output, target).item()
            pred = output.max(1)[1]  # get the index of the max log-probability
            correct += pred.eq(target).sum().item()
    test_loss /= len(test_loader)  # loss function already averages over the batch, so divide by the number of batches
    accuracy = 100. * correct / len(test_loader.dataset)
    acc_history.append(accuracy)
    avg_loss.append(test_loss)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy))
for epoch in range(0, 10):
    print("Epoch %d" % epoch)
    train(epoch)
    test(epoch)
[output truncated: per-batch training-loss prints and deprecation warnings omitted; per-epoch test results below]
Epoch 0  Test set: Average loss: 0.0808, Accuracy: 9754/10000 (97%)
Epoch 1  Test set: Average loss: 0.0592, Accuracy: 9811/10000 (98%)
Epoch 2  Test set: Average loss: 0.0541, Accuracy: 9830/10000 (98%)
Epoch 3  Test set: Average loss: 0.0522, Accuracy: 9841/10000 (98%)
Epoch 4  Test set: Average loss: 0.0508, Accuracy: 9856/10000 (98%)
Epoch 5  Test set: Average loss: 0.0431, Accuracy: 9872/10000 (98%)
Epoch 6  Test set: Average loss: 0.0431, Accuracy: 9870/10000 (98%)
Epoch 7  Test set: Average loss: 0.0454, Accuracy: 9873/10000 (98%)
Epoch 8  Test set: Average loss: 0.0440, Accuracy: 9869/10000 (98%)
Epoch 9  Test set: Average loss: 0.0430, Accuracy: 9870/10000 (98%)
plt.plot(range(len(loss_history)), loss_history)
plt.xlabel("number of iteration", fontsize = 17)
plt.ylabel("tain loss", fontsize = 17)
plt.title("Loss (iteration)")
plt.plot(range(len(acc_history)), acc_history)
plt.xlabel("epoch", fontsize = 17)
plt.ylabel("accuracy", fontsize = 17)
plt.title("accuracy(epoch)")
plt.grid(True)
plt.plot(range(len(avg_loss)), avg_loss)
plt.xlabel("epoch", fontsize = 17)
plt.ylabel("avg test loss", fontsize = 17)
plt.title("loss(epoch)")
plt.grid(True)
accuracy_final = []
def MonitoringAccuracy(disp):
    # create classifier and optimizer objects
    clf = CNNClassifier()
    opt = torch.optim.Adam(clf.parameters(), lr=1e-3)
    loss_history = []
    acc_history = []
    avg_loss = []

    def train(epoch, disp):
        clf.train()  # set model in training mode (need this because of dropout)
        # dataset API gives us pythonic batching
        for batch_id, (data, target) in enumerate(train_loader):
            # corrupt the inputs with zero-mean Gaussian noise of std `disp`
            # (randn_like avoids hard-coding the [10, 1, 28, 28] batch shape)
            data = data + disp * torch.randn_like(data)
            # forward pass, calculate loss and backprop!
            opt.zero_grad()
            preds = clf(data)
            loss = F.nll_loss(preds, target)
            loss.backward()
            loss_history.append(loss.item())
            opt.step()
            if batch_id % 100 == 0:
                print(loss.item())

    def test(epoch):
        clf.eval()  # set model in inference mode (need this because of dropout)
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                output = clf(data)
                test_loss += F.nll_loss(output, target).item()
                pred = output.max(1)[1]  # get the index of the max log-probability
                correct += pred.eq(target).sum().item()
        test_loss /= len(test_loader)  # loss function already averages over the batch
        accuracy = 100. * correct / len(test_loader.dataset)
        acc_history.append(accuracy)
        avg_loss.append(test_loss)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset), accuracy))
        accuracy_final.append(accuracy)

    for epoch in range(0, 1):
        train(epoch, disp)
        test(epoch)
for disp in np.linspace(0, 10, 10):
    MonitoringAccuracy(disp)
[output truncated: per-batch training-loss prints and deprecation warnings omitted; test results for each noise level (in loop order, sigma = np.linspace(0, 10, 10)) below]
sigma =  0.00  Test set: Average loss: 0.0784, Accuracy: 9761/10000 (97%)
sigma =  1.11  Test set: Average loss: 0.1223, Accuracy: 9637/10000 (96%)
sigma =  2.22  Test set: Average loss: 0.2547, Accuracy: 9339/10000 (93%)
sigma =  3.33  Test set: Average loss: 1.2857, Accuracy: 7485/10000 (74%)
sigma =  4.44  Test set: Average loss: 2.3002, Accuracy: 1135/10000 (11%)
sigma =  5.56  Test set: Average loss: 2.3014, Accuracy: 1135/10000 (11%)
sigma =  6.67  Test set: Average loss: 2.3013, Accuracy: 1135/10000 (11%)
sigma =  7.78  Test set: Average loss: 2.3015, Accuracy: 1135/10000 (11%)
sigma =  8.89  Test set: Average loss: 2.3012, Accuracy: 1135/10000 (11%)
sigma = 10.00  Test set: Average loss: 2.3013, Accuracy: 1135/10000 (11%)
plt.plot(np.linspace(0, 10, 10), accuracy_final)
plt.xlabel("$\sigma$", fontsize = 17)
plt.ylabel("accuracy, %", fontsize = 17)
plt.title("accuracy(dispersion)", fontsize = 17)
plt.grid(True)
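# Alternative design sketch (added, not in the original notebook): instead of adding noise inside
# train(), the corruption could be folded into the dataset pipeline with a transforms.Lambda step,
# so each noise level gets its own DataLoader. `sigma` is a hypothetical parameter here, playing
# the same role as `disp` above.
def noisy_mnist_loader(sigma, batch_size=10):
    noisy_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda x: x + sigma * torch.randn_like(x)),  # per-image Gaussian noise
    ])
    return torch.utils.data.DataLoader(
        datasets.MNIST('../mnist_data', download=True, train=True, transform=noisy_transform),
        batch_size=batch_size, shuffle=True)
# usage sketch: noisy_train_loader = noisy_mnist_loader(sigma=2.0)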