pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练

文章目录

  • 1. MNIST 手写数字识别
  • 2. 聚焦数据集扩充后的模型训练
  • 3. pytorch 手写数字识别基本实现
    • 3.1完整代码及 MNIST 测试集测试结果
      • 3.1.1代码
      • 3.1.2 MNIST 测试集测试结果
    • 3.2 使用自己的图片进行测试
      • 3.2.1 测试图片预处理代码
      • 3.2.2 测试图片结果
  • 4. 数据增强
    • 4.1 手动读取 MNIST 数据集
    • 4.2 数据增强
      • 4.2.1 像素反转
      • 4.2.2 图像旋转
        • 4.2.2.1 图像类别统计
        • 4.2.2.2 根据类别进行等量均类划分
      • 4.2.3 像素反转 + 图像旋转
      • 4.2.4 选择加载不同的处理后的数据集
    • 4.3 完整代码
  • 5. 模型再训练
    • 5.1 怎么加载 split 后的数据?
      • 5.1.1 创建自己的 dataset 类
      • 5.1.2 load 分割好的数据
    • 5.2 加载完成后怎么和原始数据合并,然后送入模型进行训练?
    • 5.3 完整代码
    • 5.4 训练结果
      • 5.4.1 只进行像素反转
        • 5.4.1.1 测试结果
        • 5.4.1.2 在自己的数据上测试
          • 测试代码
          • 测试结果
      • 5.4.2 只进行图像旋转
        • 5.4.2.1 测试结果
        • 5.4.2.2 在自己的数据上测试
          • 测试代码
          • 测试结果
      • 5.4.3 二者同时进行
        • 5.4.3.1 测试结果
        • 5.4.3.2 在自己的数据上测试
          • 测试代码
          • 测试结果
    • 5.5 结果整合
  • 结语

1. MNIST 手写数字识别

MNIST 数据集分为两部分,分别是训练集和测试集,其中训练集含有 60000 张图片,测试集中含有 10000 张图片。从官网下载的数据集主要包括有 4 个文件:

文件名称文件用途
train-images-idx3-ubyte.gz训练集图像
train-labels-idx1-ubyte.gz训练集 label
t10k-images-idx3-ubyte.gz测试集图像
t10k-labels-idx1-ubyte.gz测试集 label

参考:
MNIST 数据集介绍 1
MNIST 数据集介绍 2

2. 聚焦数据集扩充后的模型训练

Internet 中有很多关于 pytorch 实现手写数字识别的博客了,所以本文不再对这一方面作过多的叙述。更多地,本文对 MNIST 数据集进行了扩充,利用 3 中不同的数据集构成对模型进行训练,每类数据集构成都包含了 12000 张图片。这 3 种不同的数据集构成如下:

  • 原始数据集(60000 张)+ 像素反转后的图片(60000 张)
  • 原始数据集(60000 张)+ 对图像进行 90°, 180°, 270° 等量均类旋转后的图片(60000 张)(注意:此处的等量均类是指对每个角度都旋转了 20000 张图片,同时,这 20000 张图片中包含了数字 0-9 这十个类别的图片各 2000 张)
  • 原始数据集(60000 张)+ 像素反转后的图片(30000 张)+ 等量均类旋转的图片(30000 张)

建议自己尝试进行数据分割,也可以利用分割好了的数据 click->已分割好了的数据

3. pytorch 手写数字识别基本实现

3.1完整代码及 MNIST 测试集测试结果

3.1.1代码

完整代码如下:

import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Imageclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, img):output = self.conv1(img)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputdef get_device():if torch.cuda.is_available():train_device = torch.device('cuda')else:train_device = torch.device('cpu')return train_devicedef get_data_loader(dat_path, bat_size, trans, to_train=False):dat_set = torchvision.datasets.MNIST(root=dat_path, train=to_train, transform=trans, download=True)if to_train is True:dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size, shuffle=True)else:dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size)return dat_set, dat_loaderdef show_part_of_image(dat_loader, row, col):iteration = enumerate(dat_loader)idx, (exam_img, exam_label) = next(iteration)fig = plt.figure(num=1)for i in range(row * col):plt.subplot(row, col, i + 1)plt.tight_layout()plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')plt.title('Number: {}'.format(exam_label[i]))plt.xticks([])plt.yticks([])plt.show()def train(network, dat_loader, device, epos, loss_function, optimizer):for epoch in range(1, epos + 1):network.train(mode=True)for idx, (train_img, train_label) in enumerate(dat_loader):train_img = train_img.to(device)train_label = train_label.to(device)outputs = network(train_img)optimizer.zero_grad()loss = loss_function(outputs, train_label)loss.backward()optimizer.step()if idx % 100 == 0:cnt = idx * len(train_img) + (epoch - 1) * len(dat_loader.dataset)print('epoch: {}, [{}/{}({:.0f}%)], loss: {:.6f}'.format(epoch,idx * len(train_img),len(dat_loader.dataset),(100 * cnt) / (len(dat_loader.dataset) * epos),loss.item()))print('------------------------------------------------')print('Training ended.')return networkdef test(network, dat_loader, device, loss_function):test_loss_avg, correct, total = 0, 0, 0test_loss = []network.train(mode=False)with torch.no_grad():for idx, (test_img, test_label) in enumerate(dat_loader):test_img = test_img.to(device)test_label = test_label.to(device)total += test_label.size(0)outputs = network(test_img)loss = loss_function(outputs, test_label)test_loss.append(loss.item())predictions = torch.argmax(outputs, dim=1)correct += torch.sum(predictions == test_label)test_loss_avg = np.average(test_loss)print('Total: {}, Correct: {}, Accuracy: {:.2f}%, AverageLoss: {:.6f}'.format(total, correct,correct / total * 100,test_loss_avg))def show_part_of_test_result(network, dat_loader, row, col):iteration = enumerate(dat_loader)idx, (exam_img, exam_label) = next(iteration)with torch.no_grad():outputs = network(exam_img)fig = plt.figure()for i in range(row * col):plt.subplot(row, col, i + 1)plt.tight_layout()plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')plt.title('Number: {}, Prediction: {}'.format(exam_label[i], outputs.data.max(1, keepdim=True)[1][i].item()))plt.xticks([])plt.yticks([])plt.show()batch_size, epochs = 64, 10
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
my_device = get_device()path = './data'
_, train_data_loader = get_data_loader(path, batch_size, transform, True)
print('Training data loaded.')show_part_of_image(train_data_loader, 3, 3)_, test_data_loader = get_data_loader(path, batch_size, transform)
print('Testing data loaded.')cnn = CNN()
loss_func = nn.CrossEntropyLoss()
optim = torch.optim.Adam(cnn.parameters(), lr=0.01)cnn = train(cnn, train_data_loader, my_device, epochs, loss_func, optim)
test(cnn, test_data_loader, my_device, loss_func)show_part_of_test_result(cnn, test_data_loader, 5, 2)torch.save(cnn, './cnn.pth')

3.1.2 MNIST 测试集测试结果

模型测试结果:
在这里插入图片描述
其中一些超参数如下:

  • batch_size: 64
  • epochs: 10

同时,采用交叉熵 CrossEntropyLoss 来计算 loss,Adam 来进行优化:
在这里插入图片描述
模型在测试集上的准确率达到了 97.32%,从右侧的测试集采样结果来看,正确率也相对较高;

3.2 使用自己的图片进行测试

另外,还在画图中做了 0-9 这 10 个数字代入模型进行识别。注意:在画图中做的图片必须要是 28 * 28 的大小(当然也可以用 python 进行裁剪,这里就偷个懒~)
还需要注意的是,MNIST 数据集中的图片是黑底白字的,而通过画图做出的图片是白底黑字的,因此若想得到准确结果的话,必须要对需要测试的图片进行像素反转的预处理操作;

3.2.1 测试图片预处理代码

注意:由于将模型保存进了 cnn.pth 文件,测试时直接 torch.load('./cnn.pth') 即可(当然也可以用官方推荐的只保存参数的方法);需要注意的是:记得把网络结构的定义复制过来,否则会报错;

import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as pltclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, input):output = self.conv1(input)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputmodel = torch.load('./cnn.pth')
model.eval()transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
unloader = transforms.ToPILImage()for k in range(10):infile = './testImgs/raw/' + '{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)# 像素反转for i in range(28):for j in range(28):img_array[i, j] = 255 - img_array[i, j]# print(img_array)img = Image.fromarray(img_array)# img.show()img = transform(img)img = torch.unsqueeze(img, 0)output = model(img)pred = torch.argmax(output, dim=1)image = torch.squeeze(img, 0)image = unloader(image)plt.subplot(5, 2, k + 1)plt.tight_layout()plt.imshow(image, cmap='gray', interpolation='none')plt.title("Number: {}, Prediction: {}".format(k, pred.item()))plt.xticks([])plt.yticks([])
plt.show()

3.2.2 测试图片结果

在这里插入图片描述
(虽然结果正确率挺高,但是那些图片看起来怎么是灰底呢!?)

4. 数据增强

由于我们需要对数据进行处理,因此需要单独将数据读取出来,再进行相应的处理后保存;

4.1 手动读取 MNIST 数据集

关于如何从 .gz 文件中读取图片和图片的 label,参考了这篇文章 手动读取 MNIST 数据集;
主要代码:

def load_mnist(folder, img_file_name, label_file_name):with gzip.open(os.path.join(folder, label_file_name), 'rb') as lbpath:y_set = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(os.path.join(folder, img_file_name), 'rb') as imgpath:x_set = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_set), 28, 28)return x_set, y_set

在这里插入图片描述

注意,offset 的0000-0003是 magic number,offset的0004-0007是items数目,所以跳过不读,因此将 offset 设置为 8 开始读取;同理:
在这里插入图片描述
将 offset 设置为 16,开始读取图片数据;

4.2 数据增强

4.2.1 像素反转

主要操作就是用 255 - 原像素,代码如下:

def all_divert(x, save_path):# 使 numpy 矩阵可以读写x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(len(x)):for pixel in np.nditer(x[i], op_flags=['readwrite']):pixel[...] = 255 - pixelsave_img = Image.fromarray(x[i])save_img = save_img.convert('L')save_img.save(save_path + '{}.jpg'.format(i))

4.2.2 图像旋转

4.2.2.1 图像类别统计

在对图像进行旋转的时候,需要做到等量均类,这两个条件缺一不可(因为你不可能让一个人看到一个陌生的动物却能准确说出这个动物是什么),因此首先对图片数据根据它们的 label 进行一个统计。

def classify_img(y):cnt = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}for i in range(len(y)):label = y[i]cnt[label].append(i)return cnt

这里返回的字典 cnt 中每个字典项保存有属于该 key(label) 的图像的编号;

4.2.2.2 根据类别进行等量均类划分

有了对每个 label 的统计,从中进行划分即可。此处是对全部图像进行 90°,180°,270° 这三类旋转,因此对于每个 label 都将其编号集合进行三等分

def all_rotate(x, cnt_seq, save_path):'''x: 图像数据集cnt_seq: 统计后的 cnt 字典save_path: 图像保存路径'''x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(10):# 将数据集分为 3 份data_len = int(len(cnt_seq[i]) / 3)for split in range(3):left = split * data_lenif split == 2:# 最后一份包含剩下的所有图像right = len(cnt_seq[i])else:right = (split + 1) * data_lenfor j in range(left, right):# split + 1 表示旋转 90° 的 (split + 1) 倍x[cnt_seq[i][j]] = np.rot90(x[cnt_seq[i][j]], split + 1)save_img = Image.fromarray(x[cnt_seq[i][j]])save_img = save_img.convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))

4.2.3 像素反转 + 图像旋转

就是上面两种操作的综合,只不过将原始数据集划分为 4 等分:

def divert_and_rotate(x, cnt_seq, save_path):'''x: 图像数据集cnt_seq: 统计后的 cnt 字典save_path: 图像保存路径'''x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(10):# 将数据集分为 4 份data_len = int(len(cnt_seq[i]) / 4)for split in range(4):left = split * data_lenif split == 3:right = len(cnt_seq[i])else:right = (split + 1) * data_lenif split == 0:# 第一等份进行像素反转for j in range(left, right):for pixel in np.nditer(x[cnt_seq[i][j]], op_flags=['readwrite']):pixel[...] = 255 - pixelsave_img = Image.fromarray(x[cnt_seq[i][j]]).convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))else:# 后面的进行图像旋转for j in range(left, right):x[cnt_seq[i][j]] = np.rot90(x[cnt_seq[i][j]], split)save_img = Image.fromarray(x[cnt_seq[i][j]]).convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))

4.2.4 选择加载不同的处理后的数据集

上面的三个函数可以实现将不同的处理方式处理后的数据集进行保存,需要注意的是:测试集进行了划分,训练集也要进行划分! 因此下面的 split_and_save() 函数用来选择不同的处理模式;

def split_and_save(x, y, save_img_path, to_divert=False, to_rotate=False):'''x: 图像数据; y: label 数据save_img_path: 图像保存路径to_divert: 是否进行像素反转to_rotate: 是否进行图像旋转'''count_seq = classify_img(y)if to_divert is True and to_rotate is False:all_divert(x, save_img_path)elif to_divert is False and to_rotate is True:all_rotate(x, count_seq, save_img_path)elif to_divert is True and to_rotate is True:divert_and_rotate(x, count_seq, save_img_path)else:return

接下来就是图像处理了:

root_path = './data/MNIST/raw'
# 加载训练集
img_file_path = 'train-images-idx3-ubyte.gz'
label_file_path = 'train-labels-idx1-ubyte.gz'raw_x, raw_y = load_mnist(root_path, img_file_path, label_file_path)save_root_path = './testImgs'
split_and_save(raw_x, raw_y, save_root_path + '/divert/', True, False)
split_and_save(raw_x, raw_y, save_root_path + '/rotate/', False, True)
split_and_save(raw_x, raw_y, save_root_path + '/divert_and_rotate/', True, True)# 将训练集 label 保存在 label_train.txt 中
with open(save_root_path + '/label_train.txt', 'w') as f:for label in raw_y:f.write(str(label))f.write('\n')
f.close()# 加载测试集
img_file_path = 't10k-images-idx3-ubyte.gz'
label_file_path = 't10k-labels-idx1-ubyte.gz'raw_x, raw_y = load_mnist(root_path, img_file_path, label_file_path)
split_and_save(raw_x, raw_y, save_root_path + '/divert_test/', True, False)
split_and_save(raw_x, raw_y, save_root_path + '/rotate_test/', False, True)
split_and_save(raw_x, raw_y, save_root_path + '/divert_and_rotate_test/', True, True)# 将测试集 label 保存在 label_test.txt 中
with open(save_root_path + '/label_test.txt', 'w') as f:for label in raw_y:f.write(str(label))f.write('\n')
f.close()

4.3 完整代码

import gzip
import os
import numpy as np
from PIL import Imagedef load_mnist(folder, img_file_name, label_file_name):with gzip.open(os.path.join(folder, label_file_name), 'rb') as lbpath:y_set = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(os.path.join(folder, img_file_name), 'rb') as imgpath:x_set = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_set), 28, 28)return x_set, y_setdef all_divert(x, save_path):x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(len(x)):for pixel in np.nditer(x[i], op_flags=['readwrite']):pixel[...] = 255 - pixelsave_img = Image.fromarray(x[i])save_img = save_img.convert('L')save_img.save(save_path + '{}.jpg'.format(i))def classify_img(y):cnt = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}for i in range(len(y)):label = y[i]cnt[label].append(i)return cntdef all_rotate(x, cnt_seq, save_path):x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(10):# 将数据集分为 3 份data_len = int(len(cnt_seq[i]) / 3)for split in range(3):left = split * data_lenif split == 2:right = len(cnt_seq[i])else:right = (split + 1) * data_lenfor j in range(left, right):x[cnt_seq[i][j]] = np.rot90(x[cnt_seq[i][j]], split + 1)save_img = Image.fromarray(x[cnt_seq[i][j]])save_img = save_img.convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))def divert_and_rotate(x, cnt_seq, save_path):x = np.require(x, dtype='f4', requirements=['O', 'W'])for i in range(10):# 将数据集分为 4 份data_len = int(len(cnt_seq[i]) / 4)for split in range(4):left = split * data_lenif split == 3:right = len(cnt_seq[i])else:right = (split + 1) * data_lenif split == 0:for j in range(left, right):for pixel in np.nditer(x[cnt_seq[i][j]], op_flags=['readwrite']):pixel[...] = 255 - pixelsave_img = Image.fromarray(x[cnt_seq[i][j]]).convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))else:for j in range(left, right):x[cnt_seq[i][j]] = np.rot90(x[cnt_seq[i][j]], split)save_img = Image.fromarray(x[cnt_seq[i][j]]).convert('L')save_img.save(save_path + '{}.jpg'.format(cnt_seq[i][j]))def split_and_save(x, y, save_img_path, to_divert=False, to_rotate=False):count_seq = classify_img(y)if to_divert is True and to_rotate is False:all_divert(x, save_img_path)elif to_divert is False and to_rotate is True:all_rotate(x, count_seq, save_img_path)elif to_divert is True and to_rotate is True:divert_and_rotate(x, count_seq, save_img_path)else:returnroot_path = './data/MNIST/raw'
img_file_path = 'train-images-idx3-ubyte.gz'
label_file_path = 'train-labels-idx1-ubyte.gz'raw_x, raw_y = load_mnist(root_path, img_file_path, label_file_path)save_root_path = './testImgs'
split_and_save(raw_x, raw_y, save_root_path + '/divert/', True, False)
split_and_save(raw_x, raw_y, save_root_path + '/rotate/', False, True)
split_and_save(raw_x, raw_y, save_root_path + '/divert_and_rotate/', True, True)with open(save_root_path + '/label_train.txt', 'w') as f:for label in raw_y:f.write(str(label))f.write('\n')
f.close()img_file_path = 't10k-images-idx3-ubyte.gz'
label_file_path = 't10k-labels-idx1-ubyte.gz'raw_x, raw_y = load_mnist(root_path, img_file_path, label_file_path)
split_and_save(raw_x, raw_y, save_root_path + '/divert_test/', True, False)
split_and_save(raw_x, raw_y, save_root_path + '/rotate_test/', False, True)
split_and_save(raw_x, raw_y, save_root_path + '/divert_and_rotate_test/', True, True)with open(save_root_path + '/label_test.txt', 'w') as f:for label in raw_y:f.write(str(label))f.write('\n')
f.close()

5. 模型再训练

模型再训练需要解决 2 个问题:

  • 怎么加载 split 后的数据?
  • 加载完成后怎么和原始数据合并,然后送入模型进行训练?

5.1 怎么加载 split 后的数据?

5.1.1 创建自己的 dataset 类

为了使自己的数据集和原始数据集进行合并,可以继承 torch.utils.data.Dataset 类开发自己的 my_dataset 类:

class my_dataset(torch.utils.data.Dataset):def __init__(self, img, label, transform=None):super(my_dataset, self).__init__()self.dataset = imgself.label = labelself.transform = transformdef __getitem__(self, item):data = self.dataset[item]lb = self.label[item]if self.transform is not None:data = self.transform(data)return data, lbdef __len__(self):return len(self.dataset)

这里,__init(self)__, __getitem(self, item)__, __len(self)__ 是必须实现的,当把之前分割好的 img, label 数据加载进来后,放入 my_dataset 类即可;

5.1.2 load 分割好的数据

加载分割好的数据,返回 my_dataset 对象;

def load(trans, to_divert=False, to_rotate=False, train=False):'''trans: torchvision.transforms 对象to_divert: 是否进行像素反转to_rotate: 是否进行图像旋转train: 是否是用于训练的数据'''x, y = [], []root_path = './testImgs/'# 加载训练数据if train is True:num = 6e4label_path = root_path + 'label_train.txt'if to_divert is True and to_rotate is False:load_path = root_path + 'divert/'elif to_divert is False and to_rotate is True:load_path = root_path + 'rotate/'elif to_divert is True and to_rotate is True:load_path = root_path + 'divert_and_rotate/'else:returnelse:num = 1e4label_path = root_path + 'label_test.txt'if to_divert is True and to_rotate is False:load_path = root_path + 'divert_test/'elif to_divert is False and to_rotate is True:load_path = root_path + 'rotate_test/'elif to_divert is True and to_rotate is True:load_path = root_path + 'divert_and_rotate_test/'else:returnfor i in range(int(num)):path = load_path + '{}.jpg'.format(i)img = Image.open(path).convert('L')x.append(img)# 加载 labelwith open(label_path, 'r') as f:for i in range(int(num)):label = f.readline()label = label.strip('\n')y.append(int(label))f.close()dataset = my_dataset(x, y, trans)return dataset

5.2 加载完成后怎么和原始数据合并,然后送入模型进行训练?

以原始训练集和数据增强后的训练集合并为例:

	path = './data'# get_data_loader() 就是基本实现中定义好的函数train_data_set, _ = get_data_loader(path, batch_size, transform, True)# 增强后的数据集enhanced_train_data_set = load(transform, True, True, True)# 采用 ConcatDataset() 进行连接train_data_set = torch.utils.data.dataset.ConcatDataset([enhanced_train_data_set, train_data_set])print('Training data loaded.')# 将 dataset 放入 loader 中train_data_loader = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True)

5.3 完整代码

完整代码和基本实现中的差不多,主要区别在于:

  • 增加了 my_dataset 类;
  • 加载自己的数据集的 load() 函数;
  • 主函数里面对数据进行了加载和合并;
import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Imageclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, img):output = self.conv1(img)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputclass my_dataset(torch.utils.data.Dataset):def __init__(self, img, label, transform=None):super(my_dataset, self).__init__()self.dataset = imgself.label = labelself.transform = transformdef __getitem__(self, item):data = self.dataset[item]lb = self.label[item]if self.transform is not None:data = self.transform(data)return data, lbdef __len__(self):return len(self.dataset)def get_device():if torch.cuda.is_available():train_device = torch.device('cuda')else:train_device = torch.device('cpu')return train_devicedef get_data_loader(dat_path, bat_size, trans, to_train=False):dat_set = torchvision.datasets.MNIST(root=dat_path, train=to_train, transform=trans, download=True)if to_train is True:dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size, shuffle=True)else:dat_loader = torch.utils.data.DataLoader(dat_set, batch_size=bat_size)return dat_set, dat_loaderdef show_part_of_image(dat_loader, row, col):iteration = enumerate(dat_loader)idx, (exam_img, exam_label) = next(iteration)fig = plt.figure(num=1)for i in range(row * col):plt.subplot(row, col, i + 1)plt.tight_layout()plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')plt.title('Number: {}'.format(exam_label[i]))plt.xticks([])plt.yticks([])plt.show()def train(network, dat_loader, device, epos, loss_function, optimizer):for epoch in range(1, epos + 1):network.train(mode=True)for idx, (train_img, train_label) in enumerate(dat_loader):train_img = train_img.to(device)train_label = train_label.to(device)outputs = network(train_img)optimizer.zero_grad()loss = loss_function(outputs, train_label)loss.backward()optimizer.step()if idx % 100 == 0:cnt = idx * len(train_img) + (epoch - 1) * len(dat_loader.dataset)print('epoch: {}, [{}/{}({:.0f}%)], loss: {:.6f}'.format(epoch,idx * len(train_img),len(dat_loader.dataset),(100 * cnt) / (len(dat_loader.dataset) * epos),loss.item()))print('------------------------------------------------')print('Training ended.')return networkdef test(network, dat_loader, device, loss_function):test_loss_avg, correct, total = 0, 0, 0test_loss = []network.train(mode=False)with torch.no_grad():for idx, (test_img, test_label) in enumerate(dat_loader):test_img = test_img.to(device)test_label = test_label.to(device)total += test_label.size(0)outputs = network(test_img)loss = loss_function(outputs, test_label)test_loss.append(loss.item())predictions = torch.argmax(outputs, dim=1)correct += torch.sum(predictions == test_label)test_loss_avg = np.average(test_loss)print('Total: {}, Correct: {}, Accuracy: {:.2f}%, AverageLoss: {:.6f}'.format(total, correct,correct / total * 100,test_loss_avg))def show_part_of_test_result(network, dat_loader, row, col):iteration = enumerate(dat_loader)idx, (exam_img, exam_label) = next(iteration)with torch.no_grad():outputs = network(exam_img)fig = plt.figure()for i in range(row * col):plt.subplot(row, col, i + 1)plt.tight_layout()plt.imshow(exam_img[i][0], cmap='gray', interpolation='none')plt.title('Number: {}, Prediction: {}'.format(exam_label[i], outputs.data.max(1, keepdim=True)[1][i].item()))plt.xticks([])plt.yticks([])plt.show()def load(trans, to_divert=False, to_rotate=False, train=False):x, y = [], []root_path = './testImgs/'if train is True:num = 6e4label_path = root_path + 'label_train.txt'if to_divert is True and to_rotate is False:load_path = root_path + 'divert/'elif to_divert is False and to_rotate is True:load_path = root_path + 'rotate/'elif to_divert is True and to_rotate is True:load_path = root_path + 'divert_and_rotate/'else:returnelse:num = 1e4label_path = root_path + 'label_test.txt'if to_divert is True and to_rotate is False:load_path = root_path + 'divert_test/'elif to_divert is False and to_rotate is True:load_path = root_path + 'rotate_test/'elif to_divert is True and to_rotate is True:load_path = root_path + 'divert_and_rotate_test/'else:returnfor i in range(int(num)):path = load_path + '{}.jpg'.format(i)img = Image.open(path).convert('L')x.append(img)with open(label_path, 'r') as f:for i in range(int(num)):label = f.readline()label = label.strip('\n')y.append(int(label))f.close()dataset = my_dataset(x, y, trans)return datasetif __name__ == '__main__':batch_size, epochs = 128, 10transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])my_device = get_device()path = './data'train_data_set, _ = get_data_loader(path, batch_size, transform, True)enhanced_train_data_set = load(transform, True, True, True)train_data_set = torch.utils.data.dataset.ConcatDataset([enhanced_train_data_set, train_data_set])print('Training data loaded.')train_data_loader = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True)show_part_of_image(train_data_loader, 3, 3)test_data_set, _ = get_data_loader(path, batch_size, transform)enhanced_test_data_set = load(transform, True, True, False)test_data_set = torch.utils.data.dataset.ConcatDataset([enhanced_test_data_set, test_data_set])print('Testing data loaded.')test_data_loader = torch.utils.data.DataLoader(test_data_set, batch_size=batch_size, shuffle=True)cnn = CNN()loss_func = nn.CrossEntropyLoss()optim = torch.optim.Adam(cnn.parameters(), lr=0.01)cnn = train(cnn, train_data_loader, my_device, epochs, loss_func, optim)test(cnn, test_data_loader, my_device, loss_func)show_part_of_test_result(cnn, test_data_loader, 5, 2)torch.save(cnn, './cnn2.pth')

5.4 训练结果

5.4.1 只进行像素反转

5.4.1.1 测试结果

在这里插入图片描述
其中一些超参数如下:

  • batch_size: 128
  • epochs: 10

模型在测试集上的准确率达到了 97.76%,从右侧的测试集采样结果来看,正确率也相对较高;

5.4.1.2 在自己的数据上测试

测试代码
import torch
# from test import CNN
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as pltclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, input):output = self.conv1(input)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputmodel = torch.load('./cnn2.pth')
model.eval()transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
unloader = transforms.ToPILImage()for k in range(10):infile = './testImgs/raw/' + '{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)img = Image.fromarray(img_array)# img.show()img = transform(img)img = torch.unsqueeze(img, 0)output = model(img)pred = torch.argmax(output, dim=1)image = torch.squeeze(img, 0)image = unloader(image)plt.subplot(5, 2, k + 1)plt.tight_layout()plt.imshow(image, cmap='gray', interpolation='none')plt.title("Number: {}, Prediction: {}".format(k, pred.item()))plt.xticks([])plt.yticks([])
plt.show()
测试结果

在这里插入图片描述
可以从右侧的结果中看出,准确率比较高(除了看起来比较讨厌的灰底?);

5.4.2 只进行图像旋转

5.4.2.1 测试结果

在这里插入图片描述
用到的 batch_size, epochs 和上面的一样。
模型在测试集上的准确率达到了 93.54%,从右侧的测试集采样结果来看,正确率也相对较高;

5.4.2.2 在自己的数据上测试

测试代码
import torch
# from test import CNN
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as pltclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, input):output = self.conv1(input)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputmodel = torch.load('./cnn2.pth')
model.eval()transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
unloader = transforms.ToPILImage()for k in range(10):infile = './testImgs/raw/' + 'r{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)# 注意进行需要是黑底白字的图片for i in range(28):for j in range(28):img_array[i, j] = 255 - img_array[i, j]img = Image.fromarray(img_array)img = transform(img)img = torch.unsqueeze(img, 0)output = model(img)pred = torch.argmax(output, dim=1)image = torch.squeeze(img, 0)image = unloader(image)plt.subplot(5, 2, k + 1)plt.tight_layout()plt.imshow(image, cmap='gray', interpolation='none')plt.title("Number: {}, Prediction: {}".format(k, pred.item()))plt.xticks([])plt.yticks([])
plt.show()
测试结果

在这里插入图片描述
比较有趣的就是 9 旋转 180° 就变成 6 了hhh

5.4.3 二者同时进行

5.4.3.1 测试结果

在这里插入图片描述
用到的 batch_size, epochs 同样是 128 和 10;
模型在测试集上的准确率达到了 95.38%,从右侧的测试集采样结果来看,正确率也相对较高;

5.4.3.2 在自己的数据上测试

测试代码
import torch
# from test import CNN
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import matplotlib.pyplot as pltclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(),)self.fullyConnected = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7 * 7 * 64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=10),)def forward(self, input):output = self.conv1(input)output = self.conv2(output)output = self.conv3(output)output = self.fullyConnected(output)return outputmodel = torch.load('./cnn2.pth')
model.eval()transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.1307], std=[0.3081])])
unloader = transforms.ToPILImage()# 前 3 张图片是像素反转后的图片,后 3 张是未作处理的图片,最后 4 张二者同时进行的图片(注意像素反转是指将黑底白字转换为白底黑字)
for k in range(10):if k < 3:infile = './testImgs/raw/' + '{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)elif 3 <= k < 6:infile = './testImgs/raw/' + 'r{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)for i in range(28):for j in range(28):img_array[i, j] = 255 - img_array[i, j]else:infile = './testImgs/raw/' + 'r{}.jpg'.format(k)img = Image.open(infile)img = img.convert('L')img_array = np.array(img)img = Image.fromarray(img_array)img = transform(img)img = torch.unsqueeze(img, 0)output = model(img)pred = torch.argmax(output, dim=1)image = torch.squeeze(img, 0)image = unloader(image)plt.subplot(5, 2, k + 1)plt.tight_layout()plt.imshow(image, cmap='gray', interpolation='none')plt.title("Number: {}, Prediction: {}".format(k, pred.item()))plt.xticks([])plt.yticks([])
plt.show()
测试结果

在这里插入图片描述

5.5 结果整合

原始数据只进行像素反转只进行图像旋转二者同时进行
batch_size, epochs64, 10128, 10128, 10128, 10
accuracy97.32%97.76%93.54%95.38%

结语

这样的想法来源于在对 MNIST 手写数字识别进行基本实现并利用自己做的图进行进行测试的时候,开始由于没有认识到黑底白字和白底黑字的问题,因此模型测试结果很差;然后就是写的数字必须比较端正,否则测试结果也很差;
因此在学长的启发下对数据集进行了拓展,使之能够应用于更广的场景中;
另外,在这里我只进行了 90°, 180°, 270° 这三种旋转,如果有兴趣的话可以尝试更多不同角度的旋转;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/565070.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python基础(13)之数组

目录 数组 一、访问数组的元素 二、数组的长度 三、修改数组 四、数组的其它操作 数组 Python 没有对数组的内置支持&#xff0c;但可以使用Python 列表代替。 例如&#xff1a; ben ["笨小孩1", "笨小孩2", "笨小孩3"]一、访问数组的元…

C语言归并排序(合并排序)

归并排序也称合并排序&#xff0c;其算法思想是将待排序序列分为两部分&#xff0c;依次对分得的两个部分再次使用归并排序&#xff0c;之后再对其进行合并。仅从算法思想上了解归并排序会觉得很抽象&#xff0c;接下来就以对序列A[0], A[l]…, A[n-1]进行升序排列来进行解说&a…

python基础(14)之 类和对象

目录 Python类和对象 一、创建类 二、创建对象 三、init() 函数 四、对象方法 五、自参数 六、对象及其属性更改 七、pass语句 Python类和对象 Python 类/对象。Python 是一种面向对象的编程语言。Python 中的几乎所有东西都是一个对象&#xff0c;有它的属性和方法。…

C语言顺序查找

顺序査找是一种简单的査找算法&#xff0c;其实现方法是从序列的起始元素开始&#xff0c;逐个将序列中的元素与所要查找的元素进行比较&#xff0c;如果序列中有元素与所要查找的元素相等&#xff0c;那么査找成功&#xff0c;如果査找到序列的最后一个元素都不存在一个元素与…

python基础(15)之 继承

目录 Python继承 一、创建父类 二、创建子类 三、添加 init() 函数 四、使用 super() 函数 五、添加属性 六、添加方法 Python继承 继承允许我们定义一个从另一个类继承所有方法和属性的类。父类是被继承的类&#xff0c;也称为基类。子类是从另一个类继承的类&#xff…

C语言二分查找(折半查找)

二分査找也称折半査找&#xff0c;其优点是查找速度快&#xff0c;缺点是要求所要査找的数据必须是有序序列。该算法的基本思想是将所要査找的序列的中间位置的数据与所要査找的元素进行比较&#xff0c;如果相等&#xff0c;则表示査找成功&#xff0c;否则将以该位置为基准将…

python基础(16)之 日期

目录 Python日期 一、日期输入输出 二、创建日期对象 三、strftime() 方法 Python日期 Python 中的日期不是它自己的数据类型&#xff0c;但我们可以导入一个名为的模块datetime来处理日期作为日期对象。 一、日期输入输出 导入 datetime 模块并显示当前日期&#xff1a;…

python基础(17)之 JSON

Python JSON JSON 是一种用于存储和交换数据的语法。JSON 是文本&#xff0c;用 JavaScript 对象表示法编写。 Python 有一个名为 的内置包json&#xff0c;可用于处理 JSON 数据。 导入 json 模块&#xff1a; import json一.从 JSON 转换为 Python 如果您有 JSON 字符串&am…

python基础(18)之 异常处理

目录 异常处理 一、异常处理 二、else搭配 三、finally语句 四、引发异常 异常处理 try块可让您测试代码块的错误。except块可让您处理错误。finally无论 try- 和 except 块的结果如何&#xff0c;该块都允许您执行代码。 一、异常处理 例如该try块将产生异常&#xff0…

python基础(19)之 输入输出

目录 用户输入 一、格式化输入输出 二、格式化字符串字面值 三、字符串 format() 方法 四、手动格式化字符串 五、旧式字符串格式化方法 用户输入 实在太简单了&#xff0c;就是使用一个input(),将输入后的值传递给另一个变量&#xff0c;相当于动态赋值、 例如&#xff…

C语言函数返回值详解

函数的返回值是指函数被调用之后&#xff0c;执行函数体中的代码所得到的结果&#xff0c;这个结果通过 return 语句返回。 return 语句的一般形式为&#xff1a; return 表达式;或者&#xff1a; return (表达式);有没有( )都是正确的&#xff0c;为了简明&#xff0c;一般…

机器学习之线性回归(python)

目录 一、基本概念 二、概念的数学形式表达 三、确定w和b 1.读取或输入数据 2.归一化、标准化 2.1 均值 2.2 归一化 2.3 标准化 3.求解w和b 1.直接解方程 2.最小二乘法&#xff08;least square method&#xff09;求解&#xff1a; 4. 评估回归模型 四、sklearn中…

C语言函数的调用

函数调用&#xff08;Function Call&#xff09;&#xff0c;就是使用已经定义好的函数。 函数调用的一般形式为&#xff1a; functionName(param1, param2, param3 ...);functionName 是函数名称&#xff0c;param1, param2, param3 …是实参列表。实参可以是常数、变量、表…

机器学习之线性回归(matlab)

目录 一、基本概念 二、概念的数学形式表达 三、确定w和b 1.读取或输入数据 2.归一化、标准化 2.1 均值 2.2 归一化 2.3 标准化 3.求解w和b 1.直接解方程 2.最小二乘法&#xff08;least square method&#xff09;求解&#xff1a; 4. 评估回归模型 四、regress线…

C语言函数声明以及函数原型

C语言代码由上到下依次执行&#xff0c;原则上函数定义要出现在函数调用之前&#xff0c;否则就会报错。但在实际开发中&#xff0c;经常会在函数定义之前使用它们&#xff0c;这个时候就需要提前声明。 函数声明&#xff08;Declaration&#xff09;&#xff0c;就是告诉编译…

python(20)之读写文件

目录 读写文件 1.简单介绍 2.从文件中读取单行数据 3.从文件中读取多行 4.把 string&#xff08;字符串&#xff09; 的内容写入文件 5.写入其他类型的对象 本节知识总结 mode 参数 file 对象 读写文件 1.简单介绍 最常用的参数有两个: open(filename, mode) f op…

C语言全局变量和局部变量深入

局部变量 定义在函数内部的变量称为局部变量&#xff08;Local Variable&#xff09;&#xff0c;它的作用域仅限于函数内部&#xff0c; 离开该函数后就是无效的&#xff0c;再使用就会报错。 示例 int f1(int a){int b,c; //a,b,c仅在函数f1()内有效return abc; } int ma…

python实例之 67,68

目录 67.题目&#xff1a;输入数组&#xff0c;最大的与第一个元素交换&#xff0c;最小的与最后一个元素交换&#xff0c;输出数组。 68.题目&#xff1a;有 n 个整数&#xff0c;使其前面各数顺序向后移 m 个位置&#xff0c;最后 m 个数变成最前面的 m 个数 今天这个不知道…

C语言块级变量

代码块&#xff0c;就是由{ }包围起来的代码。代码块在C语言中随处可见&#xff0c;例如函数体、选择结构、循环结构等。不包含代码块的C语言程序根本不能运行&#xff0c;即使最简单的C语言程序也要包含代码块。 C语言允许在代码块内部定义变量&#xff0c;这样的变量具有块级…

python实例 69,70

69.题目&#xff1a;有n个人围成一圈&#xff0c;顺序排号。从第一个人开始报数&#xff08;从1到3报数&#xff09;&#xff0c;凡报到3的人退出圈子&#xff0c;问最后留下的是原来第几号的那位。 先看一下第一种实现方式 nmax 50 n int(input(请输入总人数:)) num [] f…