Well, it finally happened, hahaha. Basically no machine learning beginner can dodge this example: open data, high quality, uniformly sized samples, and a simple problem. It really is the perfect beginner's meal.
Today I got the code running, and I'm taking the weekend to chew over all the details inside.
Code implementation
First, a shout-out to Baidu AI: the generated code really did run as-is, which felt great. Almost nine years ago, as a sophomore, getting this little bit of code meant sitting through an hours-long English video. And even though the network looks very, very shallow, it already achieves pretty good prediction accuracy.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# define the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)  # 1 input channel, 10 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc = nn.Linear(20 * 4 * 4, 10)

    def forward(self, x):
        batch_size = x.size(0)         # the first dimension is the batch; with 1*28*28 images the input is 64*1*28*28
        x = torch.relu(self.conv1(x))  # in 64*1*28*28,  out 64*10*24*24
        x = torch.max_pool2d(x, 2, 2)  # in 64*10*24*24, out 64*10*12*12 (pooling)
        x = torch.relu(self.conv2(x))  # in 64*10*12*12, out 64*20*8*8
        x = torch.max_pool2d(x, 2, 2)  # in 64*20*8*8,   out 64*20*4*4
        x = x.view(batch_size, -1)     # in 64*20*4*4,   out 64*320
        x = self.fc(x)                 # in 64*320,      out 64*10
        return x

if __name__ == "__main__":
    # hyperparameters
    batch_size = 64
    epochs = 10
    learning_rate = 0.01

    # preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # load the training/test data; batch_size: samples per step, shuffle: whether to reshuffle each epoch
    train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = datasets.MNIST('data', train=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    # instantiate the model, loss function and optimizer
    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # train
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):  # the DataLoader batches automatically
            optimizer.zero_grad()  # the canonical training steps
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

    # test
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # note: this sums per-batch mean losses, then divides by the dataset size below,
            # so the printed "average loss" is smaller than the true per-sample loss by ~batch_size
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
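One small thing worth flagging before the output: the magic numbers in Normalize((0.1307,), (0.3081,)) are the commonly quoted mean and standard deviation of the MNIST training pixels. If you're curious, they can be recomputed yourself; a minimal sketch, loading the dataset without normalization:

import torch
from torchvision import datasets, transforms

# load the training split with only ToTensor, so pixels stay in [0, 1]
ds = datasets.MNIST('data', train=True, download=True,
                    transform=transforms.ToTensor())
pixels = torch.stack([img for img, _ in ds])      # 60000 x 1 x 28 x 28
print(pixels.mean().item(), pixels.std().item())  # roughly 0.1307 and 0.3081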
Running it produces the following output:
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data\MNIST\raw\train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [02:41<00:00, 61401.03it/s]
Extracting data\MNIST\raw\train-images-idx3-ubyte.gz to data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data\MNIST\raw\train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 97971.03it/s]
Extracting data\MNIST\raw\train-labels-idx1-ubyte.gz to data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data\MNIST\raw\t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:29<00:00, 56423.58it/s]
Extracting data\MNIST\raw\t10k-images-idx3-ubyte.gz to data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data\MNIST\raw\t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 4339528.19it/s]
Extracting data\MNIST\raw\t10k-labels-idx1-ubyte.gz to data\MNIST\raw
Train Epoch: 0 [0/60000 (0%)] Loss: 2.275243
Train Epoch: 0 [6400/60000 (0%)] Loss: 0.200208
Train Epoch: 0 [12800/60000 (0%)] Loss: 0.064670
Train Epoch: 0 [19200/60000 (0%)] Loss: 0.066074
Train Epoch: 0 [25600/60000 (0%)] Loss: 0.115960
Train Epoch: 0 [32000/60000 (0%)] Loss: 0.171170
Train Epoch: 0 [38400/60000 (0%)] Loss: 0.041663
Train Epoch: 0 [44800/60000 (0%)] Loss: 0.179172
Train Epoch: 0 [51200/60000 (0%)] Loss: 0.014898
Train Epoch: 0 [57600/60000 (0%)] Loss: 0.035095
Train Epoch: 1 [0/60000 (0%)] Loss: 0.016566
Train Epoch: 1 [6400/60000 (0%)] Loss: 0.008371
Train Epoch: 1 [12800/60000 (0%)] Loss: 0.006069
Train Epoch: 1 [19200/60000 (0%)] Loss: 0.009995
Train Epoch: 1 [25600/60000 (0%)] Loss: 0.020422
Train Epoch: 1 [32000/60000 (0%)] Loss: 0.155348
Train Epoch: 1 [38400/60000 (0%)] Loss: 0.059595
Train Epoch: 1 [44800/60000 (0%)] Loss: 0.038654
Train Epoch: 1 [51200/60000 (0%)] Loss: 0.084179
Train Epoch: 1 [57600/60000 (0%)] Loss: 0.147250
Train Epoch: 2 [0/60000 (0%)] Loss: 0.040161
Train Epoch: 2 [6400/60000 (0%)] Loss: 0.147080
Train Epoch: 2 [12800/60000 (0%)] Loss: 0.037228
Train Epoch: 2 [19200/60000 (0%)] Loss: 0.257872
Train Epoch: 2 [25600/60000 (0%)] Loss: 0.052811
Train Epoch: 2 [32000/60000 (0%)] Loss: 0.005805
Train Epoch: 2 [38400/60000 (0%)] Loss: 0.092318
Train Epoch: 2 [44800/60000 (0%)] Loss: 0.084066
Train Epoch: 2 [51200/60000 (0%)] Loss: 0.000331
Train Epoch: 2 [57600/60000 (0%)] Loss: 0.011482
Train Epoch: 3 [0/60000 (0%)] Loss: 0.042851
Train Epoch: 3 [6400/60000 (0%)] Loss: 0.004001
Train Epoch: 3 [12800/60000 (0%)] Loss: 0.008942
Train Epoch: 3 [19200/60000 (0%)] Loss: 0.045065
Train Epoch: 3 [25600/60000 (0%)] Loss: 0.099309
Train Epoch: 3 [32000/60000 (0%)] Loss: 0.054098
Train Epoch: 3 [38400/60000 (0%)] Loss: 0.059155
Train Epoch: 3 [44800/60000 (0%)] Loss: 0.016098
Train Epoch: 3 [51200/60000 (0%)] Loss: 0.114458
Train Epoch: 3 [57600/60000 (0%)] Loss: 0.231477
Train Epoch: 4 [0/60000 (0%)] Loss: 0.003781
Train Epoch: 4 [6400/60000 (0%)] Loss: 0.068822
Train Epoch: 4 [12800/60000 (0%)] Loss: 0.103501
Train Epoch: 4 [19200/60000 (0%)] Loss: 0.002396
Train Epoch: 4 [25600/60000 (0%)] Loss: 0.174503
Train Epoch: 4 [32000/60000 (0%)] Loss: 0.027796
Train Epoch: 4 [38400/60000 (0%)] Loss: 0.013167
Train Epoch: 4 [44800/60000 (0%)] Loss: 0.011576
Train Epoch: 4 [51200/60000 (0%)] Loss: 0.000726
Train Epoch: 4 [57600/60000 (0%)] Loss: 0.069251
Train Epoch: 5 [0/60000 (0%)] Loss: 0.006919
Train Epoch: 5 [6400/60000 (0%)] Loss: 0.015165
Train Epoch: 5 [12800/60000 (0%)] Loss: 0.117820
Train Epoch: 5 [19200/60000 (0%)] Loss: 0.031030
Train Epoch: 5 [25600/60000 (0%)] Loss: 0.031566
Train Epoch: 5 [32000/60000 (0%)] Loss: 0.046268
Train Epoch: 5 [38400/60000 (0%)] Loss: 0.055709
Train Epoch: 5 [44800/60000 (0%)] Loss: 0.021299
Train Epoch: 5 [51200/60000 (0%)] Loss: 0.004246
Train Epoch: 5 [57600/60000 (0%)] Loss: 0.014340
Train Epoch: 6 [0/60000 (0%)] Loss: 0.056358
Train Epoch: 6 [6400/60000 (0%)] Loss: 0.104084
Train Epoch: 6 [12800/60000 (0%)] Loss: 0.097005
Train Epoch: 6 [19200/60000 (0%)] Loss: 0.009379
Train Epoch: 6 [25600/60000 (0%)] Loss: 0.078417
Train Epoch: 6 [32000/60000 (0%)] Loss: 0.217889
Train Epoch: 6 [38400/60000 (0%)] Loss: 0.079795
Train Epoch: 6 [44800/60000 (0%)] Loss: 0.052873
Train Epoch: 6 [51200/60000 (0%)] Loss: 0.127716
Train Epoch: 6 [57600/60000 (0%)] Loss: 0.087016
Train Epoch: 7 [0/60000 (0%)] Loss: 0.045884
Train Epoch: 7 [6400/60000 (0%)] Loss: 0.087923
Train Epoch: 7 [12800/60000 (0%)] Loss: 0.164549
Train Epoch: 7 [19200/60000 (0%)] Loss: 0.111163
Train Epoch: 7 [25600/60000 (0%)] Loss: 0.300172
Train Epoch: 7 [32000/60000 (0%)] Loss: 0.045357
Train Epoch: 7 [38400/60000 (0%)] Loss: 0.087294
Train Epoch: 7 [44800/60000 (0%)] Loss: 0.110581
Train Epoch: 7 [51200/60000 (0%)] Loss: 0.001932
Train Epoch: 7 [57600/60000 (0%)] Loss: 0.066714
Train Epoch: 8 [0/60000 (0%)] Loss: 0.047415
Train Epoch: 8 [6400/60000 (0%)] Loss: 0.106327
Train Epoch: 8 [12800/60000 (0%)] Loss: 0.016832
Train Epoch: 8 [19200/60000 (0%)] Loss: 0.013452
Train Epoch: 8 [25600/60000 (0%)] Loss: 0.035256
Train Epoch: 8 [32000/60000 (0%)] Loss: 0.026502
Train Epoch: 8 [38400/60000 (0%)] Loss: 0.011809
Train Epoch: 8 [44800/60000 (0%)] Loss: 0.171943
Train Epoch: 8 [51200/60000 (0%)] Loss: 0.209570
Train Epoch: 8 [57600/60000 (0%)] Loss: 0.047113
Train Epoch: 9 [0/60000 (0%)] Loss: 0.126423
Train Epoch: 9 [6400/60000 (0%)] Loss: 0.016720
Train Epoch: 9 [12800/60000 (0%)] Loss: 0.210951
Train Epoch: 9 [19200/60000 (0%)] Loss: 0.072410
Train Epoch: 9 [25600/60000 (0%)] Loss: 0.042366
Train Epoch: 9 [32000/60000 (0%)] Loss: 0.002912
Train Epoch: 9 [38400/60000 (0%)] Loss: 0.074261
Train Epoch: 9 [44800/60000 (0%)] Loss: 0.004673
Train Epoch: 9 [51200/60000 (0%)] Loss: 0.074964
Train Epoch: 9 [57600/60000 (0%)] Loss: 0.040360

Test set: Average loss: 0.0011, Accuracy: 9795/10000 (98%)
Some notes on the code
The following call defines a 2D convolutional layer. Its full signature is:
nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
This blog post is a useful reference: https://blog.csdn.net/qq_60245590/article/details/135856418
Baidu AI also provided an explanation of the parameters.
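The shape bookkeeping in the forward() comments all follows the standard formula out = (in + 2*padding - kernel_size) / stride + 1 (with dilation=1). A quick sanity check of the 28 -> 24 step through conv1:

import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 10, kernel_size=5)  # stride=1, padding=0 by default
x = torch.randn(64, 1, 28, 28)           # a fake batch of MNIST-sized images
print(conv1(x).shape)                     # torch.Size([64, 10, 24, 24]): (28 + 0 - 5)/1 + 1 = 24

The 2x2 max pool then just halves each spatial dimension (24 -> 12, 8 -> 4), which is how we end up with 20 * 4 * 4 = 320 features feeding the fully connected layer.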
The training data is downloaded on the fly from the internet by Python. Opening the folder, there is quite a bit in there; the most important parts should be the training data and the test data. But in that case, why do we load a train_dataset and a test_dataset separately? I was a little puzzled at first.
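Digging a little: MNIST ships pre-split as four files (train images/labels and test images/labels, which is exactly what the download log above shows), and the train= flag just selects which pair to load, so the two calls aren't downloading the same thing twice. A quick check:

from torchvision import datasets, transforms

train_dataset = datasets.MNIST('data', train=True, download=True,
                               transform=transforms.ToTensor())
test_dataset = datasets.MNIST('data', train=False,
                              transform=transforms.ToTensor())
print(len(train_dataset), len(test_dataset))  # 60000 10000, the official pre-made split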
And we don't even have to batch the data ourselves, huh? Does MindSpore have this feature too? Can data I cook up myself be batched automatically? That would be very convenient.
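For PyTorch at least, the answer is yes: DataLoader batches anything that implements the Dataset interface, and for hand-made tensors TensorDataset is enough. A minimal sketch with made-up data (the shapes and names here are just for illustration):

import torch
from torch.utils.data import TensorDataset, DataLoader

x = torch.randn(100, 8)           # 100 hand-made samples with 8 features each
y = torch.randint(0, 10, (100,))  # 100 made-up integer labels
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

for xb, yb in loader:
    print(xb.shape, yb.shape)     # torch.Size([16, 8]) torch.Size([16])
    break

The MindSpore side I still need to look up.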
Alright! The Honkai: Star Rail version preview stream is today. Off to play games~