【深度学习笔记】3_7 softmax回归的简介实现

注:本文为《动手学深度学习》开源内容,仅为个人学习记录,无抄袭搬运意图

3.7 softmax回归的简洁实现

我们在3.3节(线性回归的简洁实现)中已经了解了使用Pytorch实现模型的便利。下面,让我们再次使用Pytorch来实现一个softmax回归模型。首先导入所需的包或模块。

import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l

3.7.1 获取和读取数据

我们仍然使用Fashion-MNIST数据集和上一节中设置的批量大小。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.7.2 定义和初始化模型

在3.4节(softmax回归)中提到,softmax回归的输出层是一个全连接层,所以我们用一个线性模块就可以了。因为前面我们数据返回的每个batch样本x的形状为(batch_size, 1, 28, 28), 所以我们要先用view()x的形状转换成(batch_size, 784)才送入全连接层。

num_inputs = 784
num_outputs = 10class LinearNet(nn.Module):def __init__(self, num_inputs, num_outputs):super(LinearNet, self).__init__()self.linear = nn.Linear(num_inputs, num_outputs)def forward(self, x): # x shape: (batch, 1, 28, 28)y = self.linear(x.view(x.shape[0], -1))return ynet = LinearNet(num_inputs, num_outputs)

我们将对x的形状转换的这个功能自定义一个FlattenLayer并记录在d2lzh_pytorch中方便后面使用。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
class FlattenLayer(nn.Module):def __init__(self):super(FlattenLayer, self).__init__()def forward(self, x): # x shape: (batch, *, *, ...)return x.view(x.shape[0], -1)

这样我们就可以更方便地定义我们的模型:

from collections import OrderedDictnet = nn.Sequential(# FlattenLayer(),# nn.Linear(num_inputs, num_outputs)OrderedDict([('flatten', FlattenLayer()),('linear', nn.Linear(num_inputs, num_outputs))])
)

然后,我们使用均值为0、标准差为0.01的正态分布随机初始化模型的权重参数。

init.normal_(net.linear.weight, mean=0, std=0.01)
init.constant_(net.linear.bias, val=0) 

3.7.3 softmax和交叉熵损失函数

如果做了上一节的练习,那么你可能意识到了分开定义softmax运算和交叉熵损失函数可能会造成数值不稳定。因此,PyTorch提供了一个包括softmax运算和交叉熵损失计算的函数。它的数值稳定性更好。

loss = nn.CrossEntropyLoss()

3.7.4 定义优化算法

我们使用学习率为0.1的小批量随机梯度下降作为优化算法。

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

3.7.5 训练模型

接下来,我们使用上一节中定义的训练函数来训练模型。

num_epochs = 5
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

输出:

epoch 1, loss 0.0031, train acc 0.745, test acc 0.790
epoch 2, loss 0.0022, train acc 0.812, test acc 0.807
epoch 3, loss 0.0021, train acc 0.825, test acc 0.806
epoch 4, loss 0.0020, train acc 0.832, test acc 0.810
epoch 5, loss 0.0019, train acc 0.838, test acc 0.823

小结

  • PyTorch提供的函数往往具有更好的数值稳定性。
  • 可以使用PyTorch更简洁地实现softmax回归。

注:本节除了代码之外与原书基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/702616.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

每日五道java面试题之spring篇(六)

目录: 第一题 ApplicationContext通常的实现是什么?第二题 什么是Spring的依赖注入?第三题 依赖注入的基本原则第四题 依赖注入有什么优势?第五题 有哪些不同类型的依赖注入实现方式? 第一题 ApplicationContext通常的…

d3dcompiler_47.dll是什么,电脑出现d3dcompiler_47.dll丢失如何解决

当打开软件时提示“d3dcompiler_47.dll丢失”时,用户通常会看到类似于以下的错误消息: “无法启动此程序,因为计算机中丢失了d3dcompiler_47.dll。尝试重新安装该程序以解决此问题。” “找不到d3dcompiler_47.dll文件,因此应用…

分布式一致性软件-zookeeper

在我们进行软件开发过程中,为了实现某个功能可能借助多个软件,如存储数据的数据库软件:MySQL,Redis;消息中间件:rocketMq,kafka等。那么在分布式系统中,如果想实现数据一致性&#x…

[git] 根据master更新本地分支

git 当前分支 t1 拉取远程master分支, 并更新将master分支合并到当前分支t1 git fetch origin master git merge origin/master t1第一行命令 git fetch origin master 会从远程仓库中获取最新的 master 分支代码,将其保存到本地的 origin/master 分支。 第二行命…

数据库原理及应用(MySQL版在线实训版)

学习视频:学生学习页面 (xueyinonline.com) 头歌实验:虚拟教研室-数据库原理及应用 (educoder.net) 习题答案:数据库原理及应用(MySQL版在线实训版)-习题答案 陈业斌 - 道客巴巴 (doc88.com)

[C++]虚函数用法

讲虚函数之前先讲讲面向对象的三大特性:封装、继承、多态。 1、封装 封装是指将数据(属性)和操作数据的方法(函数)封装在一个单元中,这个单元就是类。封装的主要目的是隐藏类的内部实现细节,只…

Java内部类的使用与应用

内部类的使用与应用 1. 内部类用法 普通内部类: 实例化内部类对象需要先实例化外部类对象,然后再通过OuterClassName.new InnerClassName()方式实例化内部类。内部类对象在创建后会与外部类对象秘密链接,因此无法独立于外部类创建内部类对象…

JWT学习笔记

了解 JWT Token 释义及使用 | Authing 文档 JSON Web Token Introduction - jwt.io JSON Web Token (JWT,RFC 7519 (opens new window)),是为了在网络应用环境间传递声明而执行的一种基于 JSON 的开放标准((RFC 7519)。该 token 被设计为紧凑…

微服务-微服务Spring Security OAuth 2实战

1. Spring Authorization Server 是什么 Spring Authorization Server 是一个框架,它提供了 OAuth 2.1 和 OpenID Connect 1.0 规范以及其他相关规范的实现。它建立在 Spring Security 之上,为构建 OpenID Connect 1.0 身份提供者和 OAuth2 授权服务器产…

高等数学(极限)

目录 一、数列 二、极限 2.1 讲解 2.2 例题 一、数列 按照一定次数排列的一列数: 其中 叫做通项。 对于数列,如果当n无限增大时,其通项无限接近于一个常数A,则称该数列以A为极限或称数列收敛于A,否则称数列为发散…

C语言所有标准头文件汇总及功能说明

文章目录 #include <assert.h>#include <ctype.h>#include <errno.h>#include <float.h>#include <fenv.h>#include <inttypes.h>#include <limits.h>#include <locale.h>#include <math.h>#include <process.h>#…

OpenCV(2)

1.OpenCV的模块 其中core、highgui、imgproc是最基础的模块&#xff0c;该课程主要是围绕这几个模块展开的&#xff0c;分别介绍如下&#xff1a; core模块实现了最核心的数据结构及其基本运算&#xff0c;如绘图函数、数组操作相关函数等。highgui模块实现了视频与图像的读取…

【JVM】计数器引用和可达性分析

&#x1f4dd;个人主页&#xff1a;五敷有你 &#x1f525;系列专栏&#xff1a;JVM ⛺️稳中求进&#xff0c;晒太阳 C/C的内存管理 在C/C这类没有自动垃圾回收机制的语言中&#xff0c;一个对象如果不再使用&#xff0c;需要手动释放&#xff0c;否则就会出现内存泄漏…

一文get,最容易碰上的接口自动化测试问题汇总

本篇文章分享几个接口自动化用例编写过程遇到的问题总结&#xff0c;希望能对初次探索接口自动化测试的小伙伴们解决问题上提供一小部分思路。 sql语句内容出现错误 空格&#xff1a;由于有些字段判断是变量&#xff0c;需要将sql拼接起来&#xff0c;但是在拼接字符串时没有…

对象池模板

概述 对象池的引入也是嵌入式开发的常用方法&#xff0c;也是内存预分配的一种&#xff0c;主要是用来隐藏全局对象的跟踪&#xff0c;通常预内存分配是通过数组来实现。 CMake配置 cmake_minimum_required(VERSION 3.5.1)project(objpool)add_executable(objpool objpool.cp…

C语言《数据结构与算法》安排教学计划课设

背景&#xff1a; 10、安排教学计划 (1) 问题描述。 学校每学期开设的课程是有先后顺序的&#xff0c;如计算机专业&#xff1a;开设《数据结构》课程之前&#xff0c;必须先开设《C语言程序设计》和《离散数学》课程&#xff0c;这种课程开设的先后顺序称为先行、后继课程关…

在使用nginx的时候快速测试配置文件,并重新启动

小技巧 Nginx修改配置文件后需要重新启动&#xff0c;常规操作是启动在任务管理器中关闭程序然后再次双击nginx.exe启动&#xff0c;但是使用命令行就可以快速的完成操作。 将cmd路径切换到nginx的安装路径 修改完成配置文件后 使用 nginx -t校验nginx 的配置文件是否出错 …

【Swift学习路线讲解】

Swift学习路线讲解 Swift学习路线 Swift学习路线 Swift 是 Apple 开发的一种强大的编程语言&#xff0c;专门为 iOS、macOS、watchOS 和 tvOS 应用程序设计&#xff0c;如果你想成为一名 Swift 开发者&#xff0c;以下是一个推荐的学习路线&#xff1a; 基础概念&#xff1a; …

海豚调度DolphinScheduler入门学习

DS简介&#xff1a; DolphinScheduler 是一款分布式的、易扩展的、高可用的数据处理平台&#xff0c;主要包含调度中心、元数据管理、任务编排、任务调度、任务执行和告警等模块。其技术架构基于 Spring Boot 和 Spring Cloud 技术栈&#xff0c;采用了分布式锁、分布式任务队列…

vue3 实现 el-pagination页面分页组件的封装以及调用

示例图 一、组件代码 <template><el-config-provider :locale"zhCn"><el-pagination background class"lj-paging" layout"prev, pager, next, jumper" :pager-count"5" :total"total":current-page"p…