Dataset download
The datasets directory of the MINST_PNG_Training project on GitHub contains a zip archive of the MNIST dataset in PNG format, which we will use to train the neural network model.
Training with a custom dataset
In the previous article, [PyTorch] 13. Building the complete CIFAR10 model, we covered the basic framework for building and training a neural network, but the training data came from the official CIFAR10 dataset in torchvision:
train_dataset = torchvision.datasets.CIFAR10('../datasets', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_dataset = torchvision.datasets.CIFAR10('../datasets', train=False, download=True, transform=torchvision.transforms.ToTensor())
This article trains the same kind of network on a dataset stored as image files instead.
By printing the dataset

print(train_dataset)
# Dataset CIFAR10
#     Number of datapoints: 50000
#     Root location: ../datasets
#     Split: Train
#     StandardTransform
# Transform: ToTensor()

we can see the format of the downloaded data. So the main problem is how to load our custom dataset and convert it into this same form; after that, the remaining steps are the same as in the previous article.
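Both CIFAR10 and the custom dataset we are about to build implement the same Dataset interface: `__len__` and `__getitem__` returning an (image, label) pair. A minimal sketch of that contract, using fake data purely for illustration:

```python
import torch
from torch.utils.data import Dataset

# A minimal sketch of the Dataset interface that CIFAR10 and ImageFolder
# both implement: __len__ and __getitem__ returning (image, label).
class ToyDataset(Dataset):
    def __init__(self, n=10):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        image = torch.zeros(1, 28, 28)   # fake 1x28x28 grayscale image
        label = idx % 10                 # fake class label
        return image, label

dataset = ToyDataset()
image, label = dataset[3]
print(len(dataset))    # 10
print(image.shape)     # torch.Size([1, 28, 28])
```

Any object with this shape can be handed to DataLoader, which is all the later training code relies on.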
Converting the data
Our first goal is to turn the dataset directories into train_dataset and test_dataset. For this we use the ImageFolder class:
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *

# training set path
train_root = "../datasets/mnist_png/training"
# test set path
test_root = '../datasets/mnist_png/testing'

# define the data transforms
data_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor()
])

# load the datasets
train_dataset = ImageFolder(train_root, transform=data_transform)
test_dataset = ImageFolder(test_root, transform=data_transform)
First we preprocess the data: transforms.Compose builds the data_transform object, which applies three steps:
- Resize the image to 28×28 pixels so it matches the network's input size
- Convert the image to grayscale: handwritten-digit recognition does not need three-channel color images, and PNG images can carry four channels (RGBA), while a single grayscale channel is enough
- Convert the image to the tensor data type
Then, by giving ImageFolder the dataset path and the transform, we get a dataset in the same format as the officially downloaded one:
print(train_dataset)
# Dataset ImageFolder
#     Number of datapoints: 60000
#     Root location: ../datasets/mnist_png/training
#     StandardTransform
# Transform: Compose(
#     Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
#     Grayscale(num_output_channels=1)
#     ToTensor()
# )
Everything else is essentially the same as in the previous article, [PyTorch] 13. Building the complete CIFAR10 model.
Complete code
Network model
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(3136, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

if __name__ == "__main__":
    model = Net()
    input = torch.ones((1, 1, 28, 28))
    output = model(input)
    print(output.shape)
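As a side note, the in_features=3136 of fc1 follows from the pooling arithmetic: two 2×2 max-pools shrink 28 → 14 → 7, and conv2 outputs 64 channels, so the flattened size is 64 × 7 × 7 = 3136. A quick check of just the feature stack:

```python
import torch
from torch import nn

# The conv/pool part of Net, rebuilt as a Sequential purely to verify the
# flatten size: 28 -> 14 -> 7 spatially, 64 channels, so 64 * 7 * 7 = 3136.
features = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.MaxPool2d(2, stride=2),
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.MaxPool2d(2, stride=2),
    nn.Flatten(),
)
x = torch.ones(1, 1, 28, 28)
print(features(x).shape)   # torch.Size([1, 3136])
```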
Training script
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchvision.datasets import ImageFolder
from model import *

# training set path
train_root = "../datasets/mnist_png/training"
# test set path
test_root = '../datasets/mnist_png/testing'

# define the data transforms
data_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor()
])

# load the datasets
train_dataset = ImageFolder(train_root, transform=data_transform)
test_dataset = ImageFolder(test_root, transform=data_transform)
# Dataset ImageFolder
#     Number of datapoints: 60000
#     Root location: ../datasets/mnist_png/training
#     StandardTransform
# Transform: Compose(
#     Resize(size=(28, 28), interpolation=bilinear, max_size=None, antialias=True)
#     ToTensor()
# )
# print(train_dataset)
# print(train_dataset[0])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = Net().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epoch = 10
writer = SummaryWriter('../logs')
total_step = 0

for i in range(epoch):
    model.train()
    pre_step = 0
    pre_loss = 0
    for data in train_loader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        pre_loss = pre_loss + loss.item()
        pre_step += 1
        total_step += 1
        if pre_step % 100 == 0:
            print(f"Epoch: {i+1} ,pre_loss = {pre_loss/pre_step}")
            writer.add_scalar('train_loss', pre_loss / pre_step, total_step)

    model.eval()
    pre_accuracy = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            pre_accuracy += outputs.argmax(1).eq(labels).sum().item()
    print(f"Test_accuracy: {pre_accuracy/len(test_dataset)}")
    writer.add_scalar('test_accuracy', pre_accuracy / len(test_dataset), i)
    torch.save(model, f'../models/model{i}.pth')

writer.close()
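Since the loop saves whole models with torch.save(model, path), the checkpoints can later be reloaded for inference. A sketch using a stand-in nn.Linear model and a temporary file; for the real checkpoints, point the path at e.g. ../models/model9.pth, with the Net class importable at load time since the file holds a full pickled module:

```python
import os
import tempfile
import torch
from torch import nn

# Save a stand-in model the same way the training loop does, then reload it.
model = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(model, path)

# weights_only=False because the file is a full pickled module, not a state_dict
loaded = torch.load(path, weights_only=False)
loaded.eval()
with torch.no_grad():
    out = loaded(torch.ones(1, 4))
print(out.shape)   # torch.Size([1, 2])
```

Saving only model.state_dict() instead is the more portable option, at the cost of having to reconstruct Net() before loading.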
References
[CNN] Building an AlexNet network and handling a custom dataset (cat/dog classification)
How to download MNIST images as PNGs