【d2l动手学深度学习】 Lesson 10 多层感知机 + 代码实现 试验结果对比

文章目录

  • 1. 介绍
  • 2. 单层Softmax回归
    • 2.1 手写Softmax
      • 训练效果
    • 2.2 调用pytorch内置的softmax回归层实现
      • 调用pytorch内置softmax实验结果
      • 总结
  • 3. 一层感知机(MLP)+ Softmax
    • 实验结果
  • Reference
  • 写在最后


1. 介绍

在第十节课 多层感知机 的代码实现部分,做的小实验,介绍了对FashionMNIST(衣物)数据集进行十分类的神经网络实现效果,主要展示的是训练的Loss以及准确度的训练批次图,10个epoch


2. 单层Softmax回归

2.1 手写Softmax

Softmax函数,输入是一个二维张量

  1. 首先,对输入的矩阵进行求指数exp(X) , 不会改变矩阵大小
  2. 接着,对dim_1 进行求和,得到每个example的指数总和(下面公式的分母部分)
  3. 最后,将整个矩阵相除,得到指数归一化的输出(不改变矩阵形状)

softmax求和

softmax函数内部求和

def softmax(X):X_exp = torch.exp(X) # 1. 求指数partition = X_exp.sum(1, keepdim=True) # 2. 求和return X_exp / partition # keepdim to boardcast 3. 输出指数归一化矩阵 

这个函数在自定义的网络中net( X )调用

def net(X):# -1 means convert the dimension atomatically# the input shape X (256, 1, 28, 28) --> (28*28, 256)return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

最后,写一个epoch迭代
⚠️:这里面没有写梯度更新的函数,直接调用torch中的自动求道实现,也就是下文代码中的updater.step()

for X, y in train_iter:y_hat = net(X) # 将batch传入进去l = loss(y_hat, y) # 计算loss值if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward() # 求梯度updater.step() # 根据梯度更新权重参数矩阵Wmetric.add(float(l) * len(y), accuracy(y_hat, y),y.size().numel())else:l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())

训练效果

手动实现的Softmax训练效果

手动实现的Softmax训练效果


2.2 调用pytorch内置的softmax回归层实现

⚠️:nn.CrossEntropyLoss()会在输出的时候自动应用Softmax进行求Loss(公式如下图所示)

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))  loss = nn.CrossEntropyLoss() # 损失函数 内嵌了Softmax函数trainer = torch.optim.SGD(net.parameters(), lr=0.1) # 梯度下降优化器选择

在这里插入图片描述

Pytorch文档.CROSSENTROPYLOSS

调用pytorch内置softmax实验结果

内置Softmax实现

内置Softmax实现10分类的回归效果

总结

可以看到,调用torch内部实现的Softmax函数经过优化之后,相比手写的Softmax函数,迭代的过程更加稳定(抖动更小)


3. 一层感知机(MLP)+ Softmax

W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True))
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True))
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))# function ReLU
def relu(X):a = torch.zeros_like(X)return torch.max(X, a)# model
def net(X):X = X.reshape((-1, num_inputs)) # 行数自动调整(batch_size),列数规定H = relu(X @ W1 + b1)return (H @ W2 + b2)loss = nn.CrossEntropyLoss() # 隐式实现 Softmax

实验结果

MLP实现

MLP实现的效果跟前面用单层的Softmax实现的效果差不多


Reference

  1. 李沐老师的课程网站地址 课程网址
  2. 动手学深度学习B站课程地址

写在最后

各位看官,都看到这里了,麻烦动动手指头给博主来个点赞8,您的支持作者最大的创作动力哟!
才疏学浅,若有纰漏,恳请斧正
本文章仅用于各位作为学习交流之用,不作任何商业用途,若涉及版权问题请速与作者联系,望悉知

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/100814.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习1:k 近邻算法

k近邻算法(k-Nearest Neighbors, k-NN)是一种常用的分类和回归算法。它基于一个简单的假设:如果一个样本的k个最近邻居中大多数属于某一类别,那么该样本也很可能属于这个类别。 k近邻算法的步骤如下: 输入&#xff1a…

JVM第二讲:JVM 基础 - 字节码详解

JVM 基础 - 字节码详解 本文是JVM第二讲,JVM 基础-字节码详解。源代码通过编译器编译为字节码,再通过类加载子系统进行加载到JVM中运行。 文章目录 JVM 基础 - 字节码详解1、多语言编译为字节码在JVM运行2、Java字节码文件2.1、Class文件的结构属性2.2、…

Linux shell编程学习笔记10:expr命令 和 算术运算

Linux Shell 脚本编程和其他编程语言一样,支持算数、关系、布尔、字符串、文件测试等多种运算。上节我们研究了 Linux shell编程 中的 字符串运算,今天我们研究 Linux shell编程的算术运算 ,为了方便举例,我们同时对expr命令进行…

centos 安装svn

卸载 yum remove subversion安装 yum -y install subversion仓库目录 mkdir -p /home/svn/project版本目录 svnadmin create /home/svn/project主目录切换 cd /home/svn/project/conf服务配置 vim svnserve.confanon-access read auth-access write …

TomCat关键技术

一、Tomcat 是什么 Tomcat 是一个 HTTP 服务器。通过前面的学习,我们知道HTTP 协议就是 HTTP 客户端和 HTTP 服务器之间的交互数据的格式,同时也通过 ajax 和 Java Socket 分别构造了 HTTP 客户端。HTTP 服务器我们也同样可以通过 Java Socket 来实现. 而 Tomcat 就是基于 J…

hive add columns 后查询不到新字段数据的问题

分区表add columns 查询不到新增字段数据的问题; 5.1元数据管理 (1)基本架构 Hive的2个重要组件:hiveService2 和metastore,一个负责转成MR进行执行,一个负责元数据服务管理 beeline-->hiveService2/spar…

优思学院|八大浪费深度剖析

在工作流程中消除浪费是精益思想的目标。在深入探讨八大浪费之前,了解浪费的定义至关重要。浪费是指工作流程中的任何行动或步骤,这些行动或步骤不为客户增加价值。换句话说,浪费是客户不愿意为其付费的任何过程。 最初的七大浪费&#xff0…

竞赛选题 深度学习 python opencv 火焰检测识别

文章目录 0 前言1 基于YOLO的火焰检测与识别2 课题背景3 卷积神经网络3.1 卷积层3.2 池化层3.3 激活函数:3.4 全连接层3.5 使用tensorflow中keras模块实现卷积神经网络 4 YOLOV54.1 网络架构图4.2 输入端4.3 基准网络4.4 Neck网络4.5 Head输出层 5 数据集准备5.1 数…

KdMapper扩展实现之SOKNO S.R.L(speedfan.sys)

1.背景 KdMapper是一个利用intel的驱动漏洞可以无痕的加载未经签名的驱动,本文是利用其它漏洞(参考《【转载】利用签名驱动漏洞加载未签名驱动》)做相应的修改以实现类似功能。需要大家对KdMapper的代码有一定了解。 2.驱动信息 驱动名称spee…

Excel恢复科学技术法显示的数据

Excel中输入位数较大的数据时,软件会自动使用科学计数法显示。很多时候并不需要这样的计数格式,所以需要把它转变为普通的数字格式 操作方法 选中单元格/列/行》右键》设置单元格式 在打开的窗口中,切换到“数字”选项卡,点击“自…

【Github】将本地仓库同步到github上

许久没有用GitHub了,怎么传仓库都忘记了。在这里记录一下 If you have a local folder on your machine and you want to transform it into a GitHub repository, follow the steps below: 1. Install Git (if not already installed) Make sure you have Git in…

引领创新浪潮:“Polygon探寻新技术、新治理、新代币的未来之路!“

熊市是用来建设的,Polygon Labs一直在利用这漫长的几个月来做到这一点。 Polygon 是最常用的区块链之一,每周约有 150 万用户,每天超过 230 万笔交易,以及数千个 DApp,Polygon 最近面临着日益激烈的竞争。虽然从交易数…

cartographer(2)-launch-lua的配置

1.了解bag 1roscore2rosbag info rslidar-outdoor-gps.bag了解bag中topic的名称与类型duration: 3:33s types: geometry_msgs?QuaternionStamped nav_msgs_Odometry sensor_msgs/Imu sensor_msgs/IaserScan sensor_msgs/NavSatFix sensor_msgs/PointCloud2 tf2 msgs/TFMe…

BUUCTF [BJDCTF2020]JustRE 1

查看文件信息 使用IDA打开 shift F12搜索字符串 发现类似flag的字符串 点进去 一路跟踪到汇编窗口,然后F5 sprintf将格式化后的字符串输出到String中 最终String的值为 printf("BJD{%d%d2069a45792d233ac}",19999,0);也就是 BJD{1999902069a45792d…

【解决问题思路分析】记录hutool默认使用服务端上次返回cookie的问题解决思路

背景: 本服务需要调用第三方接口获取数据,首先调用public-key接口获取公钥,然后用公钥加密密码,将用户名和密码传入/ticket接口,获取Cookie和response body中的token。 排查思路 由于是调用第三方接口出现问题&…

Final Cut Pro 10.6.10中文用法儿

Final Cut Pro是一款专业视频编辑软件,主要用于影片的后期剪辑、调色、特效、音频处理等方面。 Final Cut Pro for Mac(fcpx视频剪辑) 10.6.10中文版 以下是一些基本的使用方法和快捷键: 添加素材: 在检视器中,可以使用E快捷键把所选素材片…

php的短信验证的流程,如何实现前端js加后端php

目录 PHP的短信验证流程通常涉及以下步骤: 实现PHP短信验证的流程通常需要以下参数: 如何实现前段加后端php: DEMO: PHP的短信验证流程通常涉及以下步骤: 获取短信验证码: 用户提供手机号码。服务器生成随机的验证码,通常是4-6位数字。将验证码与手机…

Typora for Mac:优雅的Markdown文本编辑器,提升你的写作体验

Typora是一款强大的Markdown文本编辑器,专为Mac用户设计。无论你是写作爱好者,还是专业作家或博客作者,Typora都能为你提供无与伦比的写作体验。 1. 直观的界面设计 Typora的界面简洁明了,让你专注于写作,而不是被复…

BC v1.2充电规范

1 JEITA Reference to https://www.mianbaoban.cn/blog/post/169964 符合 JEITA 规范的锂离子电池充电器解决方案 2 Battery Fuel Gauge 2.1 Cycle Count(充放电循环次数) 此指令回传一只读字段,代表电芯组已经历的完整充放电循环数。当放电容…

【力扣】单调栈:901. 股票价格跨度

【力扣】单调栈:901. 股票价格跨度 文章目录 【力扣】单调栈:901. 股票价格跨度1. 题目介绍2. 思路3. 解题代码参考 1. 题目介绍 设计一个算法收集某些股票的每日报价,并返回该股票当日价格的 跨度 。 当日股票价格的 跨度 被定义为股票价格…