动手学深度学习——从零实现softmax分类模型

1. 数据集

fashion mnist是一个由10个类别图像组成的服装分类数据集,共包含60000张训练集图像和10000张测试集图像, 前者用于训练模型参数,后者用于评估模型性能。

2.1 数据集下载

先进行依赖库导入:

%matplotlib inline       # jupyter魔法命令,用于显示matplotlib生成的图形。
import torch             # 用于构建和训练深度学习模型。
import torchvision       # pytorch视觉工具库,用于处理图像数据。
from torch.utils import data       # 一些数据处理的工具类
from torchvision import transforms # 图像转换和增强
from d2l import torch as d2ld2l.use_svg_display()              # 使用svg来显示图片,清晰度更高

接下来使用框架内置函数来下载数据集并读取到内存中,数据集大概在100MB左右。

# ToTensor:图像预处理,将图像数据转为tensor格式
trans = transforms.ToTensor()
# 从网上下载训练数据集,并通过transform转换
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
# 从网上下载验证数据集,并通过tranform转换为张量
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

数据集下载和解析的过程如下,以train开头的为训练集,以t10k开头的为测试集:

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Using downloaded and verified file: ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Using downloaded and verified file: ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100.0%
Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/rawDownloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100.0%
Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw

每张图像28*28像素,全部为灰度图像,通道数为1,形状如下:

len(mnist_train), len(mnist_test), mnist_train[0][0].shape, mnist_test[0][0].shape> (60000, 10000, torch.Size([1, 28, 28]), torch.Size([1, 28, 28]))

数据图形示例如下:
在这里插入图片描述

1.2 数据读取

同前面的线性回归一样,我们采用小批量数据读取来训练和测试模型,所以需要封装一个小批量数据读取的迭代器。

batch_size = 256
workers = 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=workers)
test_iter = data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=workers))
  • batch_size: 分批的批次大小
  • shuffle: 置为True可以打乱样本顺序,随机读取
  • num_workers: 使用多少个进程来并发读取数据

train_iter和test_iter都是一个数据迭代器,可以理解为集合中的iterator,只不过每次迭代的不是一条数据,而是batch_size大小的小批量数据集。

以train_iter为例输出下形状:

for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)break
> torch.Size([256, 1, 28, 28]) torch.float32 torch.Size([256]) torch.int64

读数据是常见的性能瓶颈,训练之前最好先测试下数据读取速度。

timer = d2l.Timer()
for X, y in train_iter:continue
f'{timer.stop():.2f} sec'# 使用1个进程读取数据
> '10.63 sec'
# 使用4个进程读取数据
> '5.77 sec'

到这里,已经准备好Fashion-MNIST数据集,下面可以有它来训练和评估分类算法性能。

2. 模型

2.1 初始化模型参数

原始数据集中的每个样本都是28x28的图像,每个图像都有784个像素,可以理解为784个特征,我们可以把输入数据都看作长度为784的向量。

前文提到过,在softmax回归中,输出与类别一样多。 因为我们的数据集有10个类别,所以网络模型的输出维度为10。 因此,权重W将构成一个784x10的矩阵, 偏置b将构成一个长度为10的行向量。

num_inputs = 784
num_outputs = 10# 与线性回归一样,使用正态分布初始化我们的权重W,偏置初始化为0。
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

2.2 定义softmax操作

参考前文,实现softmax由三个步骤组成:

  • 对每个项求幂(使用exp);
  • 对每一行求和(小批量中每个样本是一行),得到每个样本的规范化常数;
    将每一行除以其规范化常数,确保结果的和为1。

数学表达式如下:
在这里插入图片描述
代码实现:

def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)  # 这里的1表示坐标轴1,即每一行的所有列求和return X_exp / partition  # 这里应用了广播机制

接下来验证是否正确,主要在于两方面:

  • 所有元素是否为正
  • 每一行的和是否为1
X = torch.normal(0, 1, (2, 5))   # 均值为0,标准差为1,2行5列的元素
X_prob = softmax(X)
X_prob, X_prob.sum(1)> (tensor([[0.1686, 0.4055, 0.0849, 0.1064, 0.2347],[0.0217, 0.2652, 0.6354, 0.0457, 0.0321]]),tensor([1.0000, 1.0000]))

2.3 定义模型

模型定义了如何将输入数据通过网络映射到输出。

def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
  • X.reshape((-1, W.shape[0])): 将输入X的形状由4维矩阵[256, 1, 28, 28]调整为2维矩阵[256, 784],0维为批量大小,1维为向量W的0维长度784
  • 与线性回归一样,使用torch.matmul来计算矩阵X与向量W的矩阵向量积,再加上偏置b就是线性输出
  • 对线性输出softmax就得到各个类别的预测概率

2.4 定义损失函数

前文提到,交叉熵可以认为是真实标签的预测概率的负对数。那在计算交叉熵之前要先拿到真实标签的预测概率。

拿下面的样本数据来说明,y_hat是一个包含2个样本在3个类别的预测概率, y是对应的真实标签,采用下标来表示类别。

y = torch.tensor([0, 2, 1])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5], [0.075, 0.88, 0.045]])

样本1中,第一类是正确的预测,预测概率为0.1;
样本2中,第三类是正确的预测,预测概率为0.5;
样本3中,第二类是正确的预测,预测概率为0.88;

方法一:采用循环:

result = []
for i in range(len(y)):result.append(y_hat[i, y[i]])torch.tensor(result)> tensor([0.1000, 0.5000, 0.8800])

方法二:直接将y作为y_hat中概率的索引,因为y中存放的正确类别下标与y_hat中是对应的。

y_hat[[0, 1, 2], y]> tensor([0.1000, 0.5000, 0.8800])
  • y_hat[[0, 1, 2], y] 本质上与常规二维数组索引方式y_hat[i, j]形式相同,不同点在于i、j不再是具体的数字,因为要一次性取多个样本的预测值;
  • i = [0, 1, 2]表示行方向上取第0、1、2三个样本;
  • j = y表示三个样本列方向分别取第0, 2, 1个元素;
  • 最终取出的元素是y_hat张量中第0行的第0列、第1行的第2列和第2行的第1列;

方法二比方法一要简单很多,由于是python内置语法,运行效率也更高。这样只需一行代码就可以实现交叉熵损失函数。

def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])cross_entropy(y_hat, y)> tensor([2.3026, 0.6931, 0.1278])
  • 第1个正确值的概率只有0.1,所以计算出来交叉熵损失2.3026就比较大;
  • 第2个正确值的概率有0.5,所以交叉熵0.6931也有所收敛;
  • 第3个正确值 的概率较高0.88, 所以交叉熵0.1278就比较小;

2.4 分类精度

给定预测概率分布y_hat,当我们必须输出预测类别时,我们通常会选择预测概率最高的类别来作为预测结果,但预测概率高的类别有时候不一定是正确预测,这时候就产生了错误预测。

就如同上面第一个样本数据中,预测概率最高的0.6并非正确类别,实际正确类别的预测概率只有0.1。

我们需要一个指标来衡量模型预测的正确率,称之为分类精度,它是正确预测数量与总预测数量之比。

以上面的y和y_hat示例数据为例,可以通过如下步骤来计算分类精度:

  1. 使用argmax获得每行中最大元素的索引来获得预测类别。
  2. 将预测类别与真实y元素进行等值比较,比较前需要将y_hat的数据类型转换为与y的数据类型一致,因为等式运算符“==”对数据类型很敏感,
  3. 结果是一个包含0(错)和1(对)的张量,进行求和就可以得到正确预测的数量。

代码实现如下:

def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())

以上面的数据来测试:

accuracy(y_hat, y) / len(y)> 0.6666666666666666
  • 第一个样本的预测错误,预测概率最大的索引2(概率0.6)与正确标签0不一致。
  • 第二个样本的预测正确,预测概率最大的索引2(概率0.5)与正确标签2一致。
  • 第三个样本的预测正确,预测概率最大的索引1(概率0.88)与正确标签1一致。

由于我们采用的是小批量多轮迭代训练,会有产生多轮预测数据,所以我们需要封装一个能支持多轮迭代的精度计算函数(主要用于训练后的精度测试)。

# @param net: 网络模型,用于对输入数据X进行类别预测,输出预测概率
# @param data_iter: 数据迭代器,每一轮迭代都包含输入数据X和对应的标签y
def evaluate_accuracy(net, data_iter):  #@save"""计算在指定数据集上模型的精度"""metric = Accumulator(2)  # 2个元素的累加器,用于统计正确预测数、预测总数;with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]

步骤解读:

  1. 先用模型net对输入X进行类别预测,得到预测概率;
  2. 再使用accuracy对预测结果和真实标签计算精度,并把精度和标签数量进行累加;
  3. 返回模型在数据集上的精度,正确预测数与总预测数的比值。

3. 训练

3.1 定义参数更新函数

这里我们复用线性回归中定义的参数优化函数sgd(小批量随机梯度下降),学习率设为0.1。

lr = 0.1def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)

3.2 定义单轮迭代训练流程

def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 长度为3的累加器,分别累加训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 使用模型来计算得到预测概率y_hat = net(X)# 计算损失l = loss(y_hat, y)# 反向累积计算梯度l.sum().backward()# 更新优化参数updater(X.shape[0])# 累加损失、精度、样本数metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]

3.3 定义整体训练流程

整体训练流程比较简单,就是循环执行多轮训练,每轮训练后参数都会得到更新,再拿测试数据集基于更新的参数去执行模型当前的表现,得到一个精度值。

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save"""训练模型(定义见第3章)"""for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)print(f"epoch: {epoch + 1}, loss: {train_metrics[0]}, test_acc: {test_acc}")

3.4 运行训练

基于前面定义的模型,进行10次迭代训练:

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
  • num_epochs: 迭代训练次数
  • net: 网络模型
  • train_iter: 训练数据集
  • test_iter: 测试数据集,用于测试模型训练后的性能
  • cross_entropy: 损失函数
  • updater: 参数优化器

整个训练过程中的损失和测试精度变化:

epoch: 1, loss: 0.7857203146616618, test_acc: 0.7882
epoch: 2, loss: 0.5686315283457438, test_acc: 0.7985
epoch: 3, loss: 0.5252757650375366, test_acc: 0.8192
epoch: 4, loss: 0.5007046510060629, test_acc: 0.8231
epoch: 5, loss: 0.4856935443242391, test_acc: 0.8196
epoch: 6, loss: 0.4738648806254069, test_acc: 0.8249
epoch: 7, loss: 0.46540179011027016, test_acc: 0.8299
epoch: 8, loss: 0.45916082598368324, test_acc: 0.8271
epoch: 9, loss: 0.45219682502746583, test_acc: 0.833
epoch: 10, loss: 0.4484250022888184, test_acc: 0.8328

可以看出,随着训练的不断迭代,损失在持续减小,测试精度虽然有略微起伏,但总体上也是在不断提升。

4. 预测

使用训练好的模型对图像进行分类预测,比较图像的实际标签和模型预测是否相同:

def predict_ch3(net, test_iter, n=6):  #@save"""预测标签(定义见第3章)"""for X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true +'\n' + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])predict_ch3(net, test_iter)

结果如下:
在这里插入图片描述

总结

本文softmax分类模型与前面线性回归模型的整体训练过程比较相似:先读取数据,再定义模型和损失函数,然后使用优化算法训练模型。大多数常见的深度学习模型都有类似的训练过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/5166.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

图像识别应用技术

⚠申明: 未经许可,禁止以任何形式转载,若要引用,请标注链接地址。 全文共计3077字,阅读大概需要3分钟 🌈更多学习内容, 欢迎👏关注👀【文末】我的个人微信公众号&#xf…

Akamai 分布式“云+边缘”,打造下一代数字化基座

当下,数字化基础设施正逐步向分布式部署演化,云计算与边缘计算正在成为两大技术支柱。Gartner 数据显示,云服务占 IT 整体支出比例连年上涨,在过去一年已增长至12.1%;IDC 报告显示,截至2021年已有超过500亿…

Grafana页面嵌入自建Web应用页面

目录 一、应用场景 二、实现方式 1、修改Grafana配置文件 2、获取监控页面url 3、隐藏左侧和顶部菜单 一、应用场景 需要将Grafana监控页面嵌入自建Web应用页面,使Grafana监控页面成为自建Web应用的一部分。 二、实现方式 总体思路:修改Grafana配…

C++之STL-list+模拟实现

目录 一、list的介绍和基本使用的方法 1.1 list的介绍 1.2 list的基本使用方法 1.2.1 构造方法 1.2.2 迭代器 1.2.3 容量相关的接口 1.2.4 增删查改的相关接口 1.3 关于list迭代器失效的问题 二、模拟实现list 2.1 节点类 2.2 迭代器类 2.3 主类list类 2.3.1 成员变…

多线程同步

1.多线程并发 1).多线程并发引例 #include <stdio.h> #include <stdlib.h> #include <unistd.h> #include <assert.h> #include <pthread.h>int wg0; void *fun(void *arg) {for(int i0;i<1000;i){wg;printf("wg%d\n",wg);} } in…

tp6.0 rabbitmq死信队列

rabbitMq交换机&#xff0c;队列情况&#xff0c;先手动创建 1. 创建普通交换机exchange&#xff0c;普通队列order_queue_expire&#xff0c;队列设置属性&#xff1a; 消息过期时间&#xff1a;60000毫秒&#xff0c;过期绑定dead_exchange交换机&#xff0c;routing_key:de…

web前端学习笔记2

2. 网页穿上美丽外衣 2.0 代码地址 https://gitee.com/qiangge95243611/java118/tree/master/web/day02 2.1 什么是CSS CSS (Cascading Style Sheets,层叠样式表),是一种用来为结构化文档(如 HTML 文档或 XML 应用)添加样式(字体、间距和颜色等)的计算机语言,CSS 文…

文件系统学习

软连接&#xff1a;可以跨不同的磁盘块&#xff0c;创建出不同的inode节点 应连接&#xff1a;相同的inode节点&#xff0c;不同的文件名字记录在父亲节点目录中 分区(fdisk)&#xff0c;格式化(mkfs)&#xff0c;挂载(mount)&#xff0c;大于2T分区&#xff08;parted&#…

FSNotes for Mac v6.7.1中文激活版:强大的笔记管理工具

FSNotes for Mac是一款功能强大的文本处理与笔记管理工具&#xff0c;为Mac用户提供了一个直观、高效的笔记记录和整理平台。 FSNotes for Mac v6.7.1中文激活版下载 FSNotes支持Markdown语法&#xff0c;使用户能够轻松设置笔记格式并添加链接、图像等元素&#xff0c;实现笔记…

基于H.264的RTP打包中的组合封包以及分片封包结构图简介及抓包分析;FU-A FU-B STAP-A STAP-B简介;

H.264视频流的RTP封装类型分析&#xff1a; 前言&#xff1a; 1.RTP打包原则&#xff1a; RTP的包长度必须要小于MTU(最大传输单元)&#xff0c;IP协议中MTU的最大长度为1500字节。除去IP报头&#xff08;20字节&#xff09;、UDP报头&#xff08;8字节&#xff09;、RTP头&a…

C#编程模式之装饰模式

创作背景&#xff1a;朋友们&#xff0c;我们继续C#编程模式的学习&#xff0c;本文我们将一起探讨装饰模式。装饰模式也是一种结构型设计模式&#xff0c;它允许你通过在运行时向对象添加额外的功能&#xff0c;从而动态的修改对象的行为。装饰模式本质上还是继承的一种替换方…

设计模式 基本认识

文章目录 设计模式的作用设计模式三原则设计模式与类图设计模式的分类 设计模式的作用 设计模式是在软件设计过程中针对常见问题的解决方案的一种通用、可重用的解决方案。设计模式提供了一种经过验证的方法&#xff0c;可以帮助开发人员解决特定类型的问题&#xff0c;并在软…

【论文阅读】IPT:Pre-TrainedImageProcessingTransformer

Pre-TrainedImageProcessingTransformer 论文地址摘要1. 简介2.相关作品2.1。图像处理2.2。 Transformer 3. 图像处理3.1. IPT 架构3.2 在 ImageNet 上进行预训练 4. 实验4.1. 超分辨率4.2. Denoising 5. 结论与讨论 论文地址 1、论文地址 2、源码 摘要 随着现代硬件的计算能…

mybatis工程需要的pom.xml,以及@Data 、@BeforeEach、@AfterEach 的使用,简化mybatis

对 “mybatis - XxxMapper.java接口中方法的参数 和 返回值类型&#xff0c;怎样在 XxxMapper.xml 中配置的问题” 这篇文章做一下优化 这个pom.xml文件&#xff0c;就是上面说的这篇文章的父工程的pom.xml&#xff0c;即&#xff1a;下面这个pom.xml 是可以拿来就用的 <?…

7天入门Android开发之第1天——初识Android

一、Android系统 1.Linux内核层&#xff1a; 这是安卓系统的底层&#xff0c;它提供了基本的系统功能&#xff0c;如内存管理、进程管理、驱动程序模型等。安卓系统构建在Linux内核之上&#xff0c;借助于Linux的稳定性和安全性。 2.系统运行库层&#xff1a; 这一层包括了安卓…

GITEE本地项目上传到远程

由于需要&#xff0c;我这边将本地的仓库上传至GITEE。之前在网上搜索了相关的文档&#xff0c;但是步骤很繁琐&#xff0c;我这边介绍一个非常简单的。 一、在GITEE新建仓库 跟着指引一步步新建。 二、打开本地仓库&#xff0c;删除.git文件 默认情况下不会有这个.git文件&a…

【保姆级讲解如何安装与配置Xcode】

&#x1f308;个人主页: 程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

在kuboard中添加k8s集群

1.登录kuboard后&#xff0c;点击添加集群面板 系统会跳转到k8s集群添加页面&#xff0c;按照页面提示输入自身的集群信息即可&#xff0c;此处没有什么难点。 添加成功后&#xff0c;点击集群面板&#xff0c;然后点击集群概要信息&#xff0c;就可以查看集群节点信息。 集群节…

ssm092基于Tomcat技术的车库智能管理平台+jsp

车库智能管理平台设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本车库智能管理平台就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短…

Java字符缓冲区

字符缓冲区是在计算机编程中非常重要的一种数据结构&#xff0c;它主要用于存储和高效地操作字符序列。 在 Java 中&#xff0c;StringBuffer类就是典型的字符缓冲区实现。与String类不同&#xff0c;StringBuffer具有动态可变性&#xff0c;这意味着我们可以在原有的字符序列…