机器学习深度学习——softmax回归的简洁实现

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——softmax回归从零开始实现
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

继续使用Fashion-MNIST数据集,并保持批量大小为256:

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

softmax回归的简洁实现

  • 初始化模型参数
  • 重新审视softmax的实现
    • 数学推导
    • 交叉熵函数
  • 优化算法
  • 训练

初始化模型参数

softmax的输出层是一个全连接层,因此,为了实现模型,我们只需要在Sequential中添加一个带有10个输出的全连接层。当然这里的Sequential并不是必要的,但是他是深度模型的基础。我们仍旧以均值为0,标准差为0.01来随机初始化权重。

# pytorch不会隐式地调整输入的形状
# 因此在线性层前就定义了展平层flatten,来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights)  # 给net每一层跑一次init_weights函数

重新审视softmax的实现

数学推导

在之前的例子里,我们计算了模型的输出,然后将此输出送入交叉熵损失。看似合理,但是指数级计算可能会造成数值的稳定性问题。
回想一下之前的softmax函数:
y ^ j = e x p ( o j ) ∑ k e x p ( o k ) 其中 y ^ j 是预测的概率分布, o j 是未规范化的第 j 个元素 \hat{y}_j=\frac{exp(o_j)}{\sum_kexp(o_k)}\\ 其中\hat{y}_j是预测的概率分布,o_j是未规范化的第j个元素 y^j=kexp(ok)exp(oj)其中y^j是预测的概率分布,oj是未规范化的第j个元素
由于o中的一些数值会非常大,所以可能会让其指数值上溢,使得分子或分母变成inf,最后得到的预测值可能变成的0、inf或者nan。此时我们无法得到一个明确的交叉熵值。
提出解决这个问题的一个技巧:在继续softmax计算之前,先从所有的o中减去max(o),修改softmax函数的构造且不改变其返回值:
y ^ j = e x p ( o j − m a x ( o k ) ) e x p ( m a x ( o k ) ) ∑ k e x p ( o j − m a x ( o k ) ) e x p ( m a x ( o k ) ) \hat{y}_j=\frac{exp(o_j-max(o_k))exp(max(o_k))}{\sum_kexp(o_j-max(o_k))exp(max(o_k))} y^j=kexp(ojmax(ok))exp(max(ok))exp(ojmax(ok))exp(max(ok))
这样操作以后,可能会使得一些分子的exp(o-max(o))有接近0的值,即为下溢。这些值可能会四舍五入为0,这样就会使得预测值为0,那么此时要是取对数以后就会变为-inf。要是这样反向传播几步,我们可能会发现自己屏幕有一堆的nan。
尽管我们需要计算指数函数,但是我们最终会在计算交叉熵损失的时候会取他们的对数。尽管通过将softmax和交叉熵结合在一起,可以避免反向传播过程中可能会困扰我们的数值稳定性问题。如下面的式子:
l o g ( y ^ j ) = l o g ( e x p ( o j − m a x ( o k ) ) ∑ k e x p ( o k − m a x ( o k ) ) ) = l o g ( e x p ( o j − m a x ( o k ) ) ) − l o g ( ∑ k e x p ( o k − m a x ( o k ) ) ) = o j − m a x ( o k ) − l o g ( ∑ k e x p ( o k − m a x ( o k ) ) ) log(\hat{y}_j)=log(\frac{exp(o_j-max(o_k))}{\sum_kexp(o_k-max(o_k))})\\ =log(exp(o_j-max(o_k)))-log(\sum_kexp(o_k-max(o_k)))\\ =o_j-max(o_k)-log(\sum_kexp(o_k-max(o_k))) log(y^j)=log(kexp(okmax(ok))exp(ojmax(ok)))=log(exp(ojmax(ok)))log(kexp(okmax(ok)))=ojmax(ok)log(kexp(okmax(ok)))
通过上式,我们避免了计算单独的exp(o-max(o)),而是直接使用o-max(o)。
因此,我们计算交叉熵函数的时候,传递的不是未规范化的预测o,而不是softmax。
但是我们也希望保留传统的softmax函数,以备我们要评估通过模型输出的概率。

交叉熵函数

在这里介绍一下交叉熵函数,以用于上面推导所需的需求:

torch.nn.CrossEntropyLoss(weight=None,ignore_index=-100,reduction='mean')

交叉熵函数是将LogSoftMax和NLLLoss集成到一个类中,通常用于多分类问题。其参数使用情况:

ignore_index:指定被忽略且对输入梯度没有贡献的目标值。
reduction:string类型的可选项,可在[none,mean,sum]中选。none表示不降维,返回和target一样的形状;mean表示对一个batch的损失求均值;sum表示对一个batch的损失求和。
weight:是一个一维的张量,包含n个元素,分别代表n类的权重,在训练样本不均衡时很有用,默认为None:
(1)当weight=None时,损失函数计算方式为
loss(x,class)=-log(exp(x[class])/Σexp(x[j]))=-x[class]+log(Σexp(x[j])
(2)当weight被指定时,损失函数计算方式为:
loss(x,class)=weight[class]×(-x[class]+log(Σexp(x[j]))

# 在交叉熵损失函数中传递未归一化的预测,并同时计算softmax及其导数
loss = nn.CrossEntropyLoss(reduction='none')

优化算法

# 优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

训练

调用之前定义的训练函数来训练模型:

# 调用之前的训练函数来训练模型
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/10592.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java 设计模式——原型模式

目录 1.概述2.结构3.实现3.1.浅拷贝3.2.深拷贝3.2.1.通过对象序列化实现深拷贝(推荐)3.2.2.重写 clone() 方法来实现深拷贝 4.优缺点5.使用场景 1.概述 (1)原型模式 (Prototype Pattern) 是一种创建型设计模式,是通过…

在Windows 10/11 上安装GNS3模拟器

文章目录 在Windows 10/11 上安装GNS3模拟器简介支持的操作系统最低要求推荐配置要求最佳配置要求下载GNS3 all-in-one 安装文件安装GNS3在Windows 10/11 上安装GNS3模拟器 简介 本文档解释了如何在Windows环境中安装GNS3。你将学习如何: 下载所需的软件安装前提条件和可选软…

【C++】 VS2020 vector+template的案例

如果对博主其他文章感兴趣可以通过【CSDN文章】博客文章索引找到。 # include <iostream> # include <vector> using namespace std;template<class T> // 用class或者typename均可 void my_print(T& v, const string msg) // v前面不允许加const, 加…

SpiderFlow爬虫平台(爬虫学习)

申明 作为自己学习的记录,方面后期查阅 官网 SpiderFlow官网 简介 spider-flow 是一个爬虫平台&#xff0c;以图形化方式定义爬虫流程&#xff0c;无需代码即可实现一个爬虫 是使用springboot开发的项目,后端代码直接运行即可使用

Codeforces Round 888 (Div. 3)(A-F)

文章目录 ABCDEF A 题意&#xff1a; 就是有一个m步的楼梯。每一层都有k厘米高&#xff0c;现在A的身高是H&#xff0c;给了你n个人的身高问有多少个人与A站在不同层的楼梯高度相同。 思路&#xff1a; 我们只需要去枚举对于A来说每一层和他一样高&#xff08;人的身高和楼…

DNS协议详解

DNS协议详解 DNS协议介绍DNS解析过程DNS查询的方式递归查询迭代查询区别 DNS协议介绍 DNS 协议是一个应用层协议&#xff0c;它建立在 UDP 或 TCP 协议之上&#xff0c;默认使用 53 号端口。该协议的功能就是将人类可读的域名 (如&#xff0c;www.qq.com) 转换为机器可读的 IP…

uniapp 微信小程序 页面+组件的生命周期顺序

uniapp 微信小程序 页面组件的生命周期顺序 首页页面父组件子组件完整顺序参考资料 首页 首页只提供了一个跳转按钮。 <template><view><navigator url"/pages/myPage/myPage?namejerry" hover-class"navigator-hover"><button ty…

数据结构---顺序栈、链栈

特点 typedef struct Stack { int* base; //栈底 int* top;//栈顶 int stacksize //栈的容量; }SqStack; typedef struct StackNode { int data;//数据域 struct StackNode* next; //指针域 }StackNode,*LinkStack; 顺序栈 #define MaxSize 100 typedef struct Stack { int*…

PyTorch quantization observer

文章目录 PyTorch quantization observerbasic classstandard observersubstandard observer PyTorch quantization observer basic class nameinheritdescribeObserverBaseABC, nn.ModuleBase observer ModuleUniformQuantizationObserverBaseObserverBase standard observ…

智慧工厂4G+蓝牙+UWB+RTK人员定位系统解决方案

人员定位在智慧工厂的应用正逐渐受到重视&#xff0c;通过使用现代化的技术和智能终端设备&#xff0c;工厂管理者能够实时定位和跟踪员工的位置&#xff0c;方便进行人员调度管理和监督人员的工作情况&#xff1b;人员遇到紧急情况&#xff0c;可通过定位设备一键报警求救&…

vue 快速自定义分页el-pagination

vue 快速自定义分页el-pagination template <div style"text-align: center"><el-paginationbackground:current-page"pageObj.currentPage":page-size"pageObj.page":page-sizes"pageObj.pageSize"layout"total,prev,…

安全文件传输:如何降低数据丢失的风险

在当今数字化时代&#xff0c;文件传输是必不可少的一项工作。但是&#xff0c;数据丢失一直是一个令人头疼的问题。本文将探讨一些减少数据丢失风险的方法&#xff0c;包括加密、备份和使用可信的传输协议等。采取这些措施将有助于保护数据免受意外丢失的危险。 一、加密保护数…

24考研数据结构-栈

目录 第三章 栈和队列3.1栈&#xff08;stack&#xff09;3.1.1栈的基本概念栈的基本概念知识回顾 3.1.2 栈的顺序存储上溢与下溢栈的顺序存储知识回顾 3.1.3栈的链式存储链栈的基本操作 第三章 栈和队列 3.1栈&#xff08;stack&#xff09; 3.1.1栈的基本概念 栈的定义 栈…

通过ETL自动化同步飞书数据到本地数仓

一、飞书数据同步到数据库需求 使用飞书的企业都有将飞书的数据自动同步到本地数据库、数仓以及其他业务系统表的需求&#xff0c;主要是为了实现飞书的数据与业务系统进行流程拉通或数据分析时使用&#xff0c;以下是一些具体的同步场景示例&#xff1a; 组织架构同步&#…

9.NIO非阻塞式网络通信入门

highlight: arduino-light Selector 示意图和特点说明 一个 I/O 线程可以并发处理 N 个客户端连接和读写操作&#xff0c;这从根本上解决了传统同步阻塞 I/O 一连接一线程模型。架构的性能、弹性伸缩能力和可靠性都得到了极大的提升。 服务端流程 1、当客户端连接服务端时&…

ADS仿真低噪声放大器学习笔记

ADS仿真低噪声放大器 文章目录 ADS仿真低噪声放大器1. 安装晶体管的库文件2. 直流分析DC Tracing3. 偏置电路的设计4. 稳定性分析5. 输入匹配和输出匹配 设计要求&#xff1a; 工作频率&#xff1a;2.4~2.5GHz ISM频段 噪声系数&#xff1a;NF < 0.7 增益&#xff1a;Gain &…

__init__函数用法

__init__是Python类中的一个特殊方法&#xff08;special method&#xff09;&#xff0c;也称为构造函数。它在类实例化&#xff08;创建对象&#xff09;的过程中自动被调用&#xff0c;用于初始化对象的属性和执行其他必要的设置。 构造函数的完整命名是__init__()&#xf…

北斗gps卫星授时服务器(NTP)应用于防火墙场景

北斗gps卫星授时服务器&#xff08;NTP&#xff09;应用于防火墙场景 北斗gps卫星授时服务器&#xff08;NTP&#xff09;应用于防火墙场景 作为网络建设中不可或缺的两方面&#xff0c;在保证网络安全稳定以及时间同步精确性方面&#xff0c;防火墙和NTP服务器都极为重要。而防…

分享200+个关于AI的网站

分享200个关于AI的网站 欢迎大家访问&#xff1a;https://tools.haiyong.site/ai 快速导航 AI 应用AI 写作AI 编程AI 设计AI 作图AI 训练模型AI 影音编辑AI 效率助手 AI 应用 文心一言: https://yiyan.baidu.com/ 百度出品的人工智能语言模型 ChatGPT: https://chat.openai.c…

Matlab遍历文件及直方图统计

参考链接&#xff1a; 使用MATLAB遍历文件 strtrim用法 strsplit用法 cell单元数据使用{} close all; dir_path C:/Users/; fileFolder ls(dir_path); fileNum length(fileFolder(:,1)) - 2; for i 3:(3fileNum-1)file_path strcat(dir_path, strtrim(fileFolder(i,:)))…