机器学习深度学习——softmax回归的简洁实现

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——softmax回归从零开始实现
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

继续使用Fashion-MNIST数据集,并保持批量大小为256:

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

softmax回归的简洁实现

  • 初始化模型参数
  • 重新审视softmax的实现
    • 数学推导
    • 交叉熵函数
  • 优化算法
  • 训练

初始化模型参数

softmax的输出层是一个全连接层,因此,为了实现模型,我们只需要在Sequential中添加一个带有10个输出的全连接层。当然这里的Sequential并不是必要的,但是他是深度模型的基础。我们仍旧以均值为0,标准差为0.01来随机初始化权重。

# pytorch不会隐式地调整输入的形状
# 因此在线性层前就定义了展平层flatten,来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)  # 给net每一层跑一次init_weights函数

重新审视softmax的实现

数学推导

在之前的例子里,我们计算了模型的输出,然后将此输出送入交叉熵损失。看似合理,但是指数级计算可能会造成数值的稳定性问题。
回想一下之前的softmax函数:
y ^ j = e x p ( o j ) ∑ k e x p ( o k ) 其中 y ^ j 是预测的概率分布, o j 是未规范化的第 j 个元素 \hat{y}_j=\frac{exp(o_j)}{\sum_kexp(o_k)}\\ 其中\hat{y}_j是预测的概率分布,o_j是未规范化的第j个元素 y^j=kexp(ok)exp(oj)其中y^j是预测的概率分布,oj是未规范化的第j个元素
由于o中的一些数值会非常大,所以可能会让其指数值上溢,使得分子或分母变成inf,最后得到的预测值可能变成的0、inf或者nan。此时我们无法得到一个明确的交叉熵值。
提出解决这个问题的一个技巧:在继续softmax计算之前,先从所有的o中减去max(o),修改softmax函数的构造且不改变其返回值:
y ^ j = e x p ( o j − m a x ( o k ) ) e x p ( m a x ( o k ) ) ∑ k e x p ( o j − m a x ( o k ) ) e x p ( m a x ( o k ) ) \hat{y}_j=\frac{exp(o_j-max(o_k))exp(max(o_k))}{\sum_kexp(o_j-max(o_k))exp(max(o_k))} y^j=kexp(ojmax(ok))exp(max(ok))exp(ojmax(ok))exp(max(ok))
这样操作以后,可能会使得一些分子的exp(o-max(o))有接近0的值,即为下溢。这些值可能会四舍五入为0,这样就会使得预测值为0,那么此时要是取对数以后就会变为-inf。要是这样反向传播几步,我们可能会发现自己屏幕有一堆的nan。
尽管我们需要计算指数函数,但是我们最终会在计算交叉熵损失的时候会取他们的对数。尽管通过将softmax和交叉熵结合在一起,可以避免反向传播过程中可能会困扰我们的数值稳定性问题。如下面的式子:
l o g ( y ^ j ) = l o g ( e x p ( o j − m a x ( o k ) ) ∑ k e x p ( o k − m a x ( o k ) ) ) = l o g ( e x p ( o j − m a x ( o k ) ) ) − l o g ( ∑ k e x p ( o k − m a x ( o k ) ) ) = o j − m a x ( o k ) − l o g ( ∑ k e x p ( o k − m a x ( o k ) ) ) log(\hat{y}_j)=log(\frac{exp(o_j-max(o_k))}{\sum_kexp(o_k-max(o_k))})\\ =log(exp(o_j-max(o_k)))-log(\sum_kexp(o_k-max(o_k)))\\ =o_j-max(o_k)-log(\sum_kexp(o_k-max(o_k))) log(y^j)=log(kexp(okmax(ok))exp(ojmax(ok)))=log(exp(ojmax(ok)))log(kexp(okmax(ok)))=ojmax(ok)log(kexp(okmax(ok)))
通过上式,我们避免了计算单独的exp(o-max(o)),而是直接使用o-max(o)。
因此,我们计算交叉熵函数的时候,传递的不是未规范化的预测o,而不是softmax。
但是我们也希望保留传统的softmax函数,以备我们要评估通过模型输出的概率。

交叉熵函数

在这里介绍一下交叉熵函数,以用于上面推导所需的需求:

torch.nn.CrossEntropyLoss(weight=None,ignore_index=-100,reduction='mean')

交叉熵函数是将LogSoftMax和NLLLoss集成到一个类中,通常用于多分类问题。其参数使用情况:

ignore_index:指定被忽略且对输入梯度没有贡献的目标值。
reduction:string类型的可选项,可在[none,mean,sum]中选。none表示不降维,返回和target一样的形状;mean表示对一个batch的损失求均值;sum表示对一个batch的损失求和。
weight:是一个一维的张量,包含n个元素,分别代表n类的权重,在训练样本不均衡时很有用,默认为None:
(1)当weight=None时,损失函数计算方式为
loss(x,class)=-log(exp(x[class])/Σexp(x[j]))=-x[class]+log(Σexp(x[j])
(2)当weight被指定时,损失函数计算方式为:
loss(x,class)=weight[class]×(-x[class]+log(Σexp(x[j]))

# 在交叉熵损失函数中传递未归一化的预测,并同时计算softmax及其导数
loss = nn.CrossEntropyLoss(reduction='none')

优化算法

# 优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

训练

调用之前定义的训练函数来训练模型:

# 调用之前的训练函数来训练模型
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/10592.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java 设计模式——原型模式

目录 1.概述2.结构3.实现3.1.浅拷贝3.2.深拷贝3.2.1.通过对象序列化实现深拷贝(推荐)3.2.2.重写 clone() 方法来实现深拷贝 4.优缺点5.使用场景 1.概述 (1)原型模式 (Prototype Pattern) 是一种创建型设计模式,是通过…

DNS协议详解

DNS协议详解 DNS协议介绍DNS解析过程DNS查询的方式递归查询迭代查询区别 DNS协议介绍 DNS 协议是一个应用层协议,它建立在 UDP 或 TCP 协议之上,默认使用 53 号端口。该协议的功能就是将人类可读的域名 (如,www.qq.com) 转换为机器可读的 IP…

uniapp 微信小程序 页面+组件的生命周期顺序

uniapp 微信小程序 页面组件的生命周期顺序 首页页面父组件子组件完整顺序参考资料 首页 首页只提供了一个跳转按钮。 <template><view><navigator url"/pages/myPage/myPage?namejerry" hover-class"navigator-hover"><button ty…

智慧工厂4G+蓝牙+UWB+RTK人员定位系统解决方案

人员定位在智慧工厂的应用正逐渐受到重视&#xff0c;通过使用现代化的技术和智能终端设备&#xff0c;工厂管理者能够实时定位和跟踪员工的位置&#xff0c;方便进行人员调度管理和监督人员的工作情况&#xff1b;人员遇到紧急情况&#xff0c;可通过定位设备一键报警求救&…

vue 快速自定义分页el-pagination

vue 快速自定义分页el-pagination template <div style"text-align: center"><el-paginationbackground:current-page"pageObj.currentPage":page-size"pageObj.page":page-sizes"pageObj.pageSize"layout"total,prev,…

安全文件传输:如何降低数据丢失的风险

在当今数字化时代&#xff0c;文件传输是必不可少的一项工作。但是&#xff0c;数据丢失一直是一个令人头疼的问题。本文将探讨一些减少数据丢失风险的方法&#xff0c;包括加密、备份和使用可信的传输协议等。采取这些措施将有助于保护数据免受意外丢失的危险。 一、加密保护数…

24考研数据结构-栈

目录 第三章 栈和队列3.1栈&#xff08;stack&#xff09;3.1.1栈的基本概念栈的基本概念知识回顾 3.1.2 栈的顺序存储上溢与下溢栈的顺序存储知识回顾 3.1.3栈的链式存储链栈的基本操作 第三章 栈和队列 3.1栈&#xff08;stack&#xff09; 3.1.1栈的基本概念 栈的定义 栈…

通过ETL自动化同步飞书数据到本地数仓

一、飞书数据同步到数据库需求 使用飞书的企业都有将飞书的数据自动同步到本地数据库、数仓以及其他业务系统表的需求&#xff0c;主要是为了实现飞书的数据与业务系统进行流程拉通或数据分析时使用&#xff0c;以下是一些具体的同步场景示例&#xff1a; 组织架构同步&#…

9.NIO非阻塞式网络通信入门

highlight: arduino-light Selector 示意图和特点说明 一个 I/O 线程可以并发处理 N 个客户端连接和读写操作&#xff0c;这从根本上解决了传统同步阻塞 I/O 一连接一线程模型。架构的性能、弹性伸缩能力和可靠性都得到了极大的提升。 服务端流程 1、当客户端连接服务端时&…

ADS仿真低噪声放大器学习笔记

ADS仿真低噪声放大器 文章目录 ADS仿真低噪声放大器1. 安装晶体管的库文件2. 直流分析DC Tracing3. 偏置电路的设计4. 稳定性分析5. 输入匹配和输出匹配 设计要求&#xff1a; 工作频率&#xff1a;2.4~2.5GHz ISM频段 噪声系数&#xff1a;NF < 0.7 增益&#xff1a;Gain &…

分享200+个关于AI的网站

分享200个关于AI的网站 欢迎大家访问&#xff1a;https://tools.haiyong.site/ai 快速导航 AI 应用AI 写作AI 编程AI 设计AI 作图AI 训练模型AI 影音编辑AI 效率助手 AI 应用 文心一言: https://yiyan.baidu.com/ 百度出品的人工智能语言模型 ChatGPT: https://chat.openai.c…

人脸检测实战-insightface

目录 简介 一、InsightFace介绍 二、安装 三、快速体验 四、代码实战 1、人脸检测 2、人脸识别 五、代码及示例图片链接 简介 目前github有非常多的人脸识别开源项目&#xff0c;下面列出几个常用的开源项目&#xff1a; 1、deepface 2、CompreFace 3、face_recogn…

【Python 实战】---- 批量识别图片中的文字,存入excel中【使用百度的通用文字识别】

分析 1. 获取信息图片示例 2. 运行实例 3. 运行结果 4. 各个文件的位置 实现 1. 需求分析 识别图片中的文字【采用百度的通用文字识别】;文字筛选,按照分类获取对应的文本;采用 openpyxl 实现将数据存入 excel 中。2. 获取 access_token 获取本地缓存的

网络安全大厂面试题

自我介绍 有没有挖过src&#xff1f; 平时web渗透怎么学的&#xff0c;有实战吗&#xff1f;有过成功发现漏洞的经历吗&#xff1f; 做web渗透时接触过哪些工具 xxe漏洞是什么&#xff1f;ssrf是什么&#xff1f; 打ctf的时候负责什么方向的题 为什么要搞信息安全&#xff0c;对…

数据结构之顺序表

一、概念及结构 顺序表是用一段物理地址连续的存储单元依次存储数据元素的线性结构&#xff0c;一般情况下采用数组存 储。在数组上完成数据的增删查改。 顺序表一般可以分为&#xff1a; 1. 静态顺序表&#xff1a;使用定长数组存储元素。 2. 动态顺序表&#xff1a;使用动…

django学习笔记(1)

django创建项目 先创建一个文件夹用来放django的项目&#xff0c;我这里是My_Django_it 之后打开到该文件下&#xff0c;并用下面的指令来创建myDjango1项目 D:\>cd My_Django_itD:\My_Django_it>"D:\zzu_it\Django_learn\Scripts\django-admin.exe" startpr…

Websocket协议-http协议-tcp协议区别和相同点

通讯形式 单工通讯-数据只能单向传送一方来发送数据&#xff0c;另一方来接收数据 半双工通讯-数据能双向传送但不能同时双向传送 全双工通讯-数据能够同时双向传送和接受 注&#xff1a;http的通讯方式是分版本 http1.0&#xff1a;单工。因为是短连接&#xff0c;客户端…

malloc(1) 会分配多大的虚拟内存?

malloc() 分配的是虚拟内存。 如果分配后的虚拟内存没有被访问的话&#xff0c;虚拟内存是不会映射到物理内存的&#xff0c;这样就不会占用物理内存了。 只有在访问已分配的虚拟地址空间的时候&#xff0c;操作系统通过查找页表&#xff0c;发现虚拟内存对应的页没有在物理内…

TEE GP(Global Platform)技术委员会及中国任务小组

TEE之GP(Global Platform)认证汇总 一、TEE GP技术委员会 二、GP中国任务小组 参考&#xff1a; GlobalPlatform Certification - GlobalPlatform

MultipartFile类型接收上传文件报出的UncheckedIOException以及删除tomcat临时文件失败源码探索

1、描述异常背景&#xff1a; 因为需要分析数据&#xff0c;待处理excel文件的数据行数太大&#xff0c;手动太累&#xff0c;花半小时写了一个定制的数据入库工具&#xff0c;改成了通用的&#xff0c;整个项目中的万级别数据都在工具上分析&#xff0c;写SQL进行分析&#x…