【深度学习笔记】3_6 代码实现softmax-regression

注:本文为《动手学深度学习》开源内容,仅为个人学习记录,无抄袭搬运意图

3.6 softmax回归的从零开始实现

这一节我们来动手实现softmax回归。首先导入本节实现所需的包或模块。

import torch
import torchvision
import numpy as np
import sys
sys.path.append("..") # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l

3.6.1 获取和读取数据

我们将使用Fashion-MNIST数据集,并设置批量大小为256。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.6.2 初始化模型参数

跟线性回归中的例子一样,我们将使用向量表示每个样本。已知每个样本输入是高和宽均为28像素的图像。模型的输入向量的长度是 28 × 28 = 784 28 \times 28 = 784 28×28=784:该向量的每个元素对应图像中每个像素。由于图像有10个类别,单层神经网络输出层的输出个数为10,因此softmax回归的权重和偏差参数分别为 784 × 10 784 \times 10 784×10 1 × 10 1 \times 10 1×10的矩阵。

num_inputs = 784
num_outputs = 10W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)

同之前一样,我们需要模型参数梯度。

W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True) 

3.6.3 实现softmax运算

在介绍如何定义softmax回归之前,我们先描述一下对如何对多维Tensor按维度操作。在下面的例子中,给定一个Tensor矩阵X。我们可以只对其中同一列(dim=0)或同一行(dim=1)的元素求和,并在结果中保留行和列这两个维度(keepdim=True)。

X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(X.sum(dim=0, keepdim=True))
print(X.sum(dim=1, keepdim=True))

输出:

tensor([[5, 7, 9]])
tensor([[ 6],[15]])

下面我们就可以定义前面小节里介绍的softmax运算了。在下面的函数中,矩阵X的行数是样本数,列数是输出个数。为了表达样本预测各个输出的概率,softmax运算会先通过exp函数对每个元素做指数运算,再对exp矩阵同行元素求和,最后令矩阵每行各元素与该行元素之和相除。这样一来,最终得到的矩阵每行元素和为1且非负。因此,该矩阵每行都是合法的概率分布。softmax运算的输出矩阵中的任意一行元素代表了一个样本在各个输出类别上的预测概率。

def softmax(X):X_exp = X.exp()partition = X_exp.sum(dim=1, keepdim=True)return X_exp / partition  # 这里应用了广播机制

可以看到,对于随机输入,我们将每个元素变成了非负数,且每一行和为1。

X = torch.rand((2, 5))
X_prob = softmax(X)
print(X_prob, X_prob.sum(dim=1))

输出:

tensor([[0.2206, 0.1520, 0.1446, 0.2690, 0.2138],[0.1540, 0.2290, 0.1387, 0.2019, 0.2765]]) tensor([1., 1.])

3.6.4 定义模型

有了softmax运算,我们可以定义上节描述的softmax回归模型了。这里通过view函数将每张原始图像改成长度为num_inputs的向量。

def net(X):return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

3.6.5 定义损失函数

上一节中,我们介绍了softmax回归使用的交叉熵损失函数。为了得到标签的预测概率,我们可以使用gather函数。在下面的例子中,变量y_hat是2个样本在3个类别的预测概率,变量y是这2个样本的标签类别。通过使用gather函数,我们得到了2个样本的标签的预测概率。与3.4节(softmax回归)数学表述中标签类别离散值从1开始逐一递增不同,在代码中,标签类别的离散值是从0开始逐一递增的。

y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))

输出:

tensor([[0.1000],[0.5000]])

下面实现了3.4节(softmax回归)中介绍的交叉熵损失函数。

def cross_entropy(y_hat, y):return - torch.log(y_hat.gather(1, y.view(-1, 1)))

3.6.6 计算分类准确率

给定一个类别的预测概率分布y_hat,我们把预测概率最大的类别作为输出类别。如果它与真实类别y一致,说明这次预测是正确的。分类准确率即正确预测数量与总预测数量之比。

为了演示准确率的计算,下面定义准确率accuracy函数。其中y_hat.argmax(dim=1)返回矩阵y_hat每行中最大元素的索引,且返回结果与变量y形状相同。相等条件判断式(y_hat.argmax(dim=1) == y)是一个类型为ByteTensorTensor,我们用float()将其转换为值为0(相等为假)或1(相等为真)的浮点型Tensor

def accuracy(y_hat, y):return (y_hat.argmax(dim=1) == y).float().mean().item()

让我们继续使用在演示gather函数时定义的变量y_haty,并将它们分别作为预测概率分布和标签。可以看到,第一个样本预测类别为2(该行最大元素0.6在本行的索引为2),与真实标签0不一致;第二个样本预测类别为2(该行最大元素0.5在本行的索引为2),与真实标签2一致。因此,这两个样本上的分类准确率为0.5。

print(accuracy(y_hat, y))

输出:

0.5

类似地,我们可以评价模型net在数据集data_iter上的准确率。

# 本函数已保存在d2lzh_pytorch包中方便以后使用。该函数将被逐步改进:它的完整实现将在“图像增广”一节中描述
def evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum / n

因为我们随机初始化了模型net,所以这个随机模型的准确率应该接近于类别个数10的倒数即0.1。

print(evaluate_accuracy(test_iter, net))

输出:

0.0681

3.6.7 训练模型

训练softmax回归的实现跟3.2(线性回归的从零开始实现)一节介绍的线性回归中的实现非常相似。我们同样使用小批量随机梯度下降来优化模型的损失函数。在训练模型时,迭代周期数num_epochs和学习率lr都是可以调的超参数。改变它们的值可能会得到分类更准确的模型。

num_epochs, lr = 5, 0.1# 本函数已保存在d2lzh包中方便以后使用
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,params=None, lr=None, optimizer=None):for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).sum()# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward()if optimizer is None:d2l.sgd(params, lr, batch_size)else:optimizer.step()  # “softmax回归的简洁实现”一节将用到train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

输出:

epoch 1, loss 0.7878, train acc 0.749, test acc 0.794
epoch 2, loss 0.5702, train acc 0.814, test acc 0.813
epoch 3, loss 0.5252, train acc 0.827, test acc 0.819
epoch 4, loss 0.5010, train acc 0.833, test acc 0.824
epoch 5, loss 0.4858, train acc 0.836, test acc 0.815

3.6.8 预测

训练完成后,现在就可以演示如何对图像进行分类了。给定一系列图像(第三行图像输出),我们比较一下它们的真实标签(第一行文本输出)和模型预测结果(第二行文本输出)。

X, y = next(iter(test_iter))# 注意这里有坑,pytorch版本不一样next写法可能不一样,可根据自己所使用的版本修改true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]d2l.show_fashion_mnist(X[0:9], titles[0:9])

在这里插入图片描述

小结

  • 可以使用softmax回归做多类别分类。与训练线性回归相比,你会发现训练softmax回归的步骤和它非常相似:获取并读取数据、定义模型和损失函数并使用优化算法训练模型。事实上,绝大多数深度学习模型的训练都有着类似的步骤。

注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/700480.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT Widget自定义菜单

此文以设置QListWidget的自定义菜单为例,其他继承于QWidget的类也都可以按类似的方法去实现。 1、ui文件设置contextMenuPolicy属性为CustomContextMenu 2、添加槽函数 /*** brief onCustomContextMenuRequested 右键弹出菜单* param pos 右键的坐标*/void onCusto…

十一、Qt数据库操作

一、Sql介绍 Qt Sql模块包含多个类,实现数据库的连接,Sql语句的执行,数据获取与界面显示,数据与界面直接使用Model/View结构。1、使用Sql模块 (1)工程加入 QT sql(2)添加头文件 …

2023年的AI模型学习/部署/优化

可以的话,github上给点一个小心心,感谢观看。 LDC边缘检测的轻量级密集卷积神经网络: meiqisheng/LDC (github.com)https://github.com/meiqisheng/LDC segment-anything分割一切的图像分割算法模型: meiqisheng/segment-anyt…

群晖NAS DSM7.2.1安装宝塔之后无法登陆账号密码问题解决

宝塔的安装就不在这赘述了,只说下,启动之后默认账号密码无法登陆的问题。 按照上面给出的账号密码,无法登陆 然后点忘记密码,由于是docker安装的,根目录下没有/www/server/panel 。 也没有bt命令 要怎么修改呢。 既然…

【Java程序设计】【C00283】基于Springboot的校园志愿者管理系统(有论文)

基于Springboot的校园志愿者管理系统(有论文) 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于Springboot的校园志愿者管理系统 本系统分为系统功能模块、管理员功能模块以及志愿者功能模块。 系统功能模块:用户进入到系统…

应用中如何将单数据库升级为集群【数据库集群】【MySQL集群模式】

MySQL集群架构搭建以及多数据源管理实战 应用中如何将单数据库升级为集群1、搭建MySQL集群,实现服务和数据的高可用1>搭建基础MySQL服务。​ 2、启动MySQL服务​ 3、连接MySQL 2>搭建MySQL主从集群1》配置master服务2》配置slave从服务3》主从集群测试4》全库…

Github 2024-02-23 开源项目日报 Top10

根据Github Trendings的统计,今日(2024-02-23统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量非开发语言项目4Python项目3TypeScript项目1HTML项目1Dart项目1Rust项目1 从零开始构建你喜爱的技术 创建周…

智胜未来,新时代IT技术人风口攻略-第七版(弃稿)

文章目录 前言鸿蒙生态科普调研人员画像角色先行结论 - 市场下的增量蛋糕高校助力鸿蒙 - 掀起鸿蒙教育热潮高校鸿蒙课程开设占比 - 巨大需求背后是矛盾冲突教研力量并非唯一原因 - 看重教学成果复用与效率 企业布局规划 - 多元市场前瞻视野全盘接纳仍需一段时间 - 积极正向的一…

【新手易错点】golang中byte和rune

1 总体区别 在Golang中,byte和rune是两种不同类型的数据。简单来说,byte是一个8位的无符号整数类型,而rune则是一个32位的Unicode字符类型。 Byte: 在Golang中,byte类型实际上是uint8的别名,它用来表示8位的无符号整…

2.22日学习打卡----正则表达式

2.22日学习打卡 目录: 2.22日学习打卡正则表达式什么是正则表达式?正则表达式的作用正则表达式特点基础语法表格元字符Java 中正则表达式的使用正则表达式语法规则内容限定单个字符限定范围字符限定取反限定 长度限定长度限定符号预定义字符正则表达式的…

【MySQL】MySQL从0到0.9 - 持续更新ing

MySQL SQL 基础 DDL 语句 show databases; show tables; desc table_name; # 获取表信息 show create table 表名; // 查询指定表的建表语句 数据类型 char(10) 不满10个会用空格填充,性能好一点 varchar(10) 变长字符串,性能差一点 CREATE TABLE tabl…

力扣 187. 重复的DNA序列

1.题目 DNA序列 由一系列核苷酸组成,缩写为 A, C, G 和 T.。 例如,"ACGAATTCCG" 是一个 DNA序列 。 在研究 DNA 时,识别 DNA 中的重复序列非常有用。 给定一个表示 DNA序列 的字符串 s ,返回所有在 DNA 分子中出现不止一…

【word技巧】word文档如何设置限制编辑

Word文档中为了提高办公效率以及文档安全,我们可以考虑为word文档设置一个限制编辑起到保护文档的作用。今天介绍word文档设置限制编辑的方法。 打开word文档之后,点击功能栏中的【审阅】功能,选择【限制编辑】功能 这是我们勾选右边弹框中的…

【Spring MVC篇】简单案例分析

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【Spring MVC】 本专栏旨在分享学习Spring MVC的一点学习心得,欢迎大家在评论区交流讨论💌 目录 一、加法计算器二…

匿名+有名管道

管道 相关概念 4种情况 正常情况,如果管道没有数据,读端陷入等待,直到有数据才能唤醒正常情况,如果管道被写满,写端陷入等待,直到有空间才能唤醒写段关闭,读端一直读取,read返回0…

mac m1调试aarch64 android kernel最终方案

问题 这是之前的,调试android kernel的方案还是太笨重了 完美调试android-goldfish(linux kernel) aarch64的方法 然后,看GeekCon AVSS 2023 Qualifier,通过 sdk-repo-linux_aarch64-emulator-8632828.zip 进行启动 完整编译的aosp kernnl…

51单片机学习(4)-----独立按键进一步控制LED灯

前言:感谢您的关注哦,我会持续更新编程相关知识,愿您在这里有所收获。如果有任何问题,欢迎沟通交流!期待与您在学习编程的道路上共同进步。 目录 一. 独立按键灵活控制LED 程序一:单个独立按键控制多个…

Web性能优化-详细讲解与实用方法-MDN文档学习笔记

Web性能优化 查看更多学习笔记:GitHub:LoveEmiliaForever MDN中文官网 性能优良的网站能够提高访问者留存和用户满意度,减少客户端和服务器之间传输的数据量可降低各方的成本 不同的业务目标和用户需求需要不同的性能度量,要提高…

STM32CubeMX FOC工程配置(AuroraFOC)

一. 简介 哈喽,大家好,今天给大家带来基于AuroraFOC开发板的STM32CubeMX的工程配置,主要配置的参数如下: 1. 互补PWM输出 2. 定时器注入中断ADC采样 3. SPI配置 4. USB CDC配置 5. RT Thread配置 大家如果对这几部分感兴趣的话&#xff0c…

贪婪算法入门指南

想象一下,你在玩一款捡金币的游戏。在这个游戏里,地图中散布着各种大小不一的金币,而你的目标就是尽可能快地收集到最多的金币。你可能会采取一个直观的策略:每次都去捡最近的、看起来最大的金币。这种在每一步都采取局部最优解的…