【d2l动手学深度学习】 Lesson 10 多层感知机 + 代码实现 试验结果对比

文章目录

  • 1. 介绍
  • 2. 单层Softmax回归
    • 2.1 手写Softmax
      • 训练效果
    • 2.2 调用pytorch内置的softmax回归层实现
      • 调用pytorch内置softmax实验结果
      • 总结
  • 3. 一层感知机(MLP)+ Softmax
    • 实验结果
  • Reference
  • 写在最后


1. 介绍

在第十节课 多层感知机 的代码实现部分,做的小实验,介绍了对FashionMNIST(衣物)数据集进行十分类的神经网络实现效果,主要展示的是训练的Loss以及准确度的训练批次图,10个epoch


2. 单层Softmax回归

2.1 手写Softmax

Softmax函数,输入是一个二维张量

  1. 首先,对输入的矩阵进行求指数exp(X) , 不会改变矩阵大小
  2. 接着,对dim_1 进行求和,得到每个example的指数总和(下面公式的分母部分)
  3. 最后,将整个矩阵相除,得到指数归一化的输出(不改变矩阵形状)

softmax求和

softmax函数内部求和

def softmax(X):X_exp = torch.exp(X) # 1. 求指数partition = X_exp.sum(1, keepdim=True) # 2. 求和return X_exp / partition # keepdim to boardcast 3. 输出指数归一化矩阵 

这个函数在自定义的网络中net( X )调用

def net(X):# -1 means convert the dimension atomatically# the input shape X (256, 1, 28, 28) --> (28*28, 256)return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

最后,写一个epoch迭代
⚠️:这里面没有写梯度更新的函数,直接调用torch中的自动求道实现,也就是下文代码中的updater.step()

for X, y in train_iter:y_hat = net(X) # 将batch传入进去l = loss(y_hat, y) # 计算loss值if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward() # 求梯度updater.step() # 根据梯度更新权重参数矩阵Wmetric.add(float(l) * len(y), accuracy(y_hat, y),y.size().numel())else:l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())

训练效果

手动实现的Softmax训练效果

手动实现的Softmax训练效果


2.2 调用pytorch内置的softmax回归层实现

⚠️:nn.CrossEntropyLoss()会在输出的时候自动应用Softmax进行求Loss(公式如下图所示)

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))  loss = nn.CrossEntropyLoss() # 损失函数 内嵌了Softmax函数trainer = torch.optim.SGD(net.parameters(), lr=0.1) # 梯度下降优化器选择

在这里插入图片描述

Pytorch文档.CROSSENTROPYLOSS

调用pytorch内置softmax实验结果

内置Softmax实现

内置Softmax实现10分类的回归效果

总结

可以看到,调用torch内部实现的Softmax函数经过优化之后,相比手写的Softmax函数,迭代的过程更加稳定(抖动更小)


3. 一层感知机(MLP)+ Softmax

W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True))
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True))
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))# function ReLU
def relu(X):a = torch.zeros_like(X)return torch.max(X, a)# model
def net(X):X = X.reshape((-1, num_inputs)) # 行数自动调整(batch_size),列数规定H = relu(X @ W1 + b1)return (H @ W2 + b2)loss = nn.CrossEntropyLoss() # 隐式实现 Softmax

实验结果

MLP实现

MLP实现的效果跟前面用单层的Softmax实现的效果差不多


Reference

  1. 李沐老师的课程网站地址 课程网址
  2. 动手学深度学习B站课程地址

写在最后

各位看官,都看到这里了,麻烦动动手指头给博主来个点赞8,您的支持作者最大的创作动力哟!
才疏学浅,若有纰漏,恳请斧正
本文章仅用于各位作为学习交流之用,不作任何商业用途,若涉及版权问题请速与作者联系,望悉知

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/100814.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习1:k 近邻算法

k近邻算法(k-Nearest Neighbors, k-NN)是一种常用的分类和回归算法。它基于一个简单的假设:如果一个样本的k个最近邻居中大多数属于某一类别,那么该样本也很可能属于这个类别。 k近邻算法的步骤如下: 输入&#xff1a…

JVM第二讲:JVM 基础 - 字节码详解

JVM 基础 - 字节码详解 本文是JVM第二讲,JVM 基础-字节码详解。源代码通过编译器编译为字节码,再通过类加载子系统进行加载到JVM中运行。 文章目录 JVM 基础 - 字节码详解1、多语言编译为字节码在JVM运行2、Java字节码文件2.1、Class文件的结构属性2.2、…

Linux shell编程学习笔记10:expr命令 和 算术运算

Linux Shell 脚本编程和其他编程语言一样,支持算数、关系、布尔、字符串、文件测试等多种运算。上节我们研究了 Linux shell编程 中的 字符串运算,今天我们研究 Linux shell编程的算术运算 ,为了方便举例,我们同时对expr命令进行…

TomCat关键技术

一、Tomcat 是什么 Tomcat 是一个 HTTP 服务器。通过前面的学习,我们知道HTTP 协议就是 HTTP 客户端和 HTTP 服务器之间的交互数据的格式,同时也通过 ajax 和 Java Socket 分别构造了 HTTP 客户端。HTTP 服务器我们也同样可以通过 Java Socket 来实现. 而 Tomcat 就是基于 J…

hive add columns 后查询不到新字段数据的问题

分区表add columns 查询不到新增字段数据的问题; 5.1元数据管理 (1)基本架构 Hive的2个重要组件:hiveService2 和metastore,一个负责转成MR进行执行,一个负责元数据服务管理 beeline-->hiveService2/spar…

优思学院|八大浪费深度剖析

在工作流程中消除浪费是精益思想的目标。在深入探讨八大浪费之前,了解浪费的定义至关重要。浪费是指工作流程中的任何行动或步骤,这些行动或步骤不为客户增加价值。换句话说,浪费是客户不愿意为其付费的任何过程。 最初的七大浪费&#xff0…

竞赛选题 深度学习 python opencv 火焰检测识别

文章目录 0 前言1 基于YOLO的火焰检测与识别2 课题背景3 卷积神经网络3.1 卷积层3.2 池化层3.3 激活函数:3.4 全连接层3.5 使用tensorflow中keras模块实现卷积神经网络 4 YOLOV54.1 网络架构图4.2 输入端4.3 基准网络4.4 Neck网络4.5 Head输出层 5 数据集准备5.1 数…

KdMapper扩展实现之SOKNO S.R.L(speedfan.sys)

1.背景 KdMapper是一个利用intel的驱动漏洞可以无痕的加载未经签名的驱动,本文是利用其它漏洞(参考《【转载】利用签名驱动漏洞加载未签名驱动》)做相应的修改以实现类似功能。需要大家对KdMapper的代码有一定了解。 2.驱动信息 驱动名称spee…

Excel恢复科学技术法显示的数据

Excel中输入位数较大的数据时,软件会自动使用科学计数法显示。很多时候并不需要这样的计数格式,所以需要把它转变为普通的数字格式 操作方法 选中单元格/列/行》右键》设置单元格式 在打开的窗口中,切换到“数字”选项卡,点击“自…

引领创新浪潮:“Polygon探寻新技术、新治理、新代币的未来之路!“

熊市是用来建设的,Polygon Labs一直在利用这漫长的几个月来做到这一点。 Polygon 是最常用的区块链之一,每周约有 150 万用户,每天超过 230 万笔交易,以及数千个 DApp,Polygon 最近面临着日益激烈的竞争。虽然从交易数…

BUUCTF [BJDCTF2020]JustRE 1

查看文件信息 使用IDA打开 shift F12搜索字符串 发现类似flag的字符串 点进去 一路跟踪到汇编窗口,然后F5 sprintf将格式化后的字符串输出到String中 最终String的值为 printf("BJD{%d%d2069a45792d233ac}",19999,0);也就是 BJD{1999902069a45792d…

【解决问题思路分析】记录hutool默认使用服务端上次返回cookie的问题解决思路

背景: 本服务需要调用第三方接口获取数据,首先调用public-key接口获取公钥,然后用公钥加密密码,将用户名和密码传入/ticket接口,获取Cookie和response body中的token。 排查思路 由于是调用第三方接口出现问题&…

Typora for Mac:优雅的Markdown文本编辑器,提升你的写作体验

Typora是一款强大的Markdown文本编辑器,专为Mac用户设计。无论你是写作爱好者,还是专业作家或博客作者,Typora都能为你提供无与伦比的写作体验。 1. 直观的界面设计 Typora的界面简洁明了,让你专注于写作,而不是被复…

BC v1.2充电规范

1 JEITA Reference to https://www.mianbaoban.cn/blog/post/169964 符合 JEITA 规范的锂离子电池充电器解决方案 2 Battery Fuel Gauge 2.1 Cycle Count(充放电循环次数) 此指令回传一只读字段,代表电芯组已经历的完整充放电循环数。当放电容…

【力扣】单调栈:901. 股票价格跨度

【力扣】单调栈:901. 股票价格跨度 文章目录 【力扣】单调栈:901. 股票价格跨度1. 题目介绍2. 思路3. 解题代码参考 1. 题目介绍 设计一个算法收集某些股票的每日报价,并返回该股票当日价格的 跨度 。 当日股票价格的 跨度 被定义为股票价格…

PicGo+Gitee+Typora搭建云图床

🙈作者简介:练习时长两年半的Java up主 🙉个人主页:程序员老茶 🙊 ps:点赞👍是免费的,却可以让写博客的作者开心好久好久😎 📚系列专栏:Java全栈,…

数据结构 堆——详细动画图解,形象理解

作者主页 📚lovewold少个r博客主页 ​➡️栈和队列博客传送门 🌳参天大树充满生命力,其根深叶茂,分枝扶疏,为我们展示了数据分治的生动形态 目录 🌳 树 树的常见概念 📒树的表示 二叉树 一…

探索乡村新风貌:VR全景记录乡村发展,助力乡村振兴

引言: 中国乡村正经历着巨大变革,长期以来,乡村地区一直面临着人口外流、资源匮乏等问题。然而,近年来,政府的政策支持以及新兴技术的崭露头角,如虚拟现实(VR)全景记录,…

随着 ChatGPT 凭借 GPT-4V(ision) 获得关注,多模态 AI 不断发展

原创 | 文 BFT机器人 在不断努力让人工智能更像人类的过程中,OpenAI的GPT模型不断突破界限GPT-4现在能够接受文本和图像的提示。 生成式人工智能中的多模态表示模型根据输入生成文本、图像或音频等各种输出的能力。这些模型经过特定数据的训练,学习底层模…

【photoshop学习】用 Photoshop 做的 15 件创意事

用 Photoshop 做的 15 件创意事 每个人总是谈论 Photoshop 的无限可能。您可以使用该程序做很多事情,列表几乎是无穷无尽的。 嘿,我是卡拉!如果您花过一些时间使用 在线ps,您可能见过我(并且注意到我提到了这一点&am…