基于Pytorch深度学习——Softmax回归

本文章来源于对李沐动手深度学习代码以及原理的理解,并且由于李沐老师的代码能力很强,以及视频中讲解代码的部分较少,所以这里将代码进行尽量逐行详细解释
并且由于pytorch的语法有些小伙伴可能并不熟悉,所以我们会采用逐行解释+小实验的方式来给大家解释代码

大家都知道二分类问题我们在机器学习里面使用到的是逻辑回归这个算法,但是针对于多分类问题,我们常用的是Softmax技术,大家不要被这个名字给迷惑了,softmax回归并不是一种回归技术,而是一种分类技术

tips:本文需要下载d2l包,大家可以按照下面的指令安装

pip install -U d2l

导入模块

import torch
from IPython import display
from d2l import torch as d2l

这里如果存在报错的话,可以自己pip install对应的模块,如果还出现了其他问题,可以在评论区中问出来

导入数据集

这里我们使用的是手写数字识别分类数据集mnist,李沐老师在d2l包李沐已经帮我们写好了下载并且导入这个数据集的函数,我们只需要指定batch_size,然后再导入对应的模块即可

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

这个代码里面load_data_fashion_mnist返回的是两个迭代器,如果对迭代器并不是特别了解的同学,可以先去看看python基础课

batch_size

相信初学者们肯定对batch_size这个这个超参数感到非常的疑惑(超参数这个名词我们在以后的文章也会解释到)。
batch_size的官方定义是深度学习模型在训练过程中一次性输入给模型的样本数量,因为深度学习通常使用的样本都是上万的数据量,如果一口气全部给模型进行训练或者推理,那么GPU/CPU肯定会吃不消,所以我们通常会分批次将数据输入给模型
当然有关于如何选择batch_size的问题也是一个值得探究的问题,一般而言,如果你的CPU/GPU的内存比较小的话,建议选择比较小的batch_size,如果你的内存比较大的话,可以选择较大的batch_size,较大的batch_size可以加快你的训练速度

初始化模型参数

在我们导入的数据集中,每个样本都是 28 × 28 28 \times 28 28×28的图像,但是在我们一般的处理中,我们会选择将展平每个图像,把它们看作长度为784的向量。
需要我们注意的是Softmax回归,我们的输出与输入的类别的数目是一样多的
因此,权重将构成一个 784 × 10 784 \times 10 784×10的矩阵,偏置将构成一个 1 × 10 1 \times 10 1×10的行向量。
于是我们可以得到下面的矩阵计算公式
X 1 × 784 × W 784 × 10 + b 1 × 10 = Y 1 × 10 (1.1) X_{1 \times 784} \times W_{784 ×10} + b_{1×10}= Y_{1×10} \tag{1.1} X1×784×W784×10+b1×10=Y1×10(1.1)
这个式子和我们的线性回归模型十分相似,但是存在一定的不一样,因为线性回归的输入类别和输出类别都是2

num_inputs = 784
num_outputs = 10W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

在这个代码里面,我们将W初始化为(784×10)的一个矩阵,并且里面的每个元素都服从均值为0,标准差为0.01的正态分布;将b初始化为(1×10)的一个全零向量

Softmax实现

想要实现softmax,我们需要完成三个步骤,这三个步骤是当我们已经计算出了Y的值之后
Step1:对Y的每个维度进行求e指数
Step2:需要对每一行进行求和,得到一个规范化的常数
Step3:将每一行都与规范化的常数进行相除,确定结果的和为1
用公式来表达就是下面的这个公式
s o f t m a x ( X ) i j = exp ⁡ ( X i j ) ∑ k exp ⁡ ( X i k ) . \mathrm{softmax}(\mathbf{X})_{ij} = \frac{\exp(\mathbf{X}_{ij})}{\sum_k \exp(\mathbf{X}_{ik})}. softmax(X)ij=kexp(Xik)exp(Xij).

def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制

在上面的这个定义中,其实都还比较好理解,唯一可能让初学者产生疑问的就是有关于sum函数里面的维度参数的使用,于是我们做一个小实验来看看这个参数它具体有什么作用

小实验/维度参数

我们可以先用torch.tensor创造一个tensor数据类型的二维数组

X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

为了我们后面好观察,我们在这里直观一点,将矩阵的形式用公式打出来:
1 2 3 4 5 6 (1) \begin{matrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{matrix} \tag1 142536(1)
首先,我们先不对维度参数做出赋值,看看sum会产生怎么样的结果

X.sum()
>>> tensor(21.)

从这个例子我们可以看出,如果不对维度参数赋值的话,sum就是把X里面的所有元素进行累加,也就是 1 + 2 + 3 + 4 + 5 + 6 = 21 1+2+3+4+5+6=21 1+2+3+4+5+6=21

然后我们再举一个例子,此时我们不赋值keepdim参数的值,我们只选择0和1参数

X.sum(0),X.sum(1)
>>>(tensor([5., 7., 9.]), tensor([ 6., 15.]))

我们先来看看X.sum(0)代表的是什么,根据矩阵(1)我们可以很直观的看到,[5,7,9]代表的是将原来的矩阵的每一列进行求和最后组成一个行向量
同理可得X.sum(1)就是把原来矩阵的每一行进行求和,最后组成一个行向量

最后我们添加上keepdim参数

X.sum(0, keepdim=True), X.sum(1, keepdim=True)
>>>(tensor([[5., 7., 9.]]),tensor([[ 6.],[15.]]))

我们不难看出,添加了keepdim参数之后,结果里面元素的值并没有发生变化,发生变化的是矩阵的维度,所以我们可以知道keepdim是用来保存原矩阵加和的维度的
也就是说,假如你使用0进行将每一列进行求和,那么你得到的就会是一个行向量,因为你相当于对着每一行进行叠加,这里和pandas的求和方法是一样的

定义模型

def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

这个代码的大致也很好理解,就是将公式(1.1)进行了计算机化表现,对于初学者而言,可能比较难理解的是有关于reshape的应用,下面我们还是使用几个小实验来进行探索

小实验/reshape

reshape操作是只改变行和列的维度,但是不会改变矩阵中总元素的个数,也就是说如果你想让3×5的矩阵变成4×4的矩阵,这样是做不到的
我们可以借用(1)矩阵来讲解,这里为了大家好观察,我再将矩阵(1)拿过来
1 2 3 4 5 6 (1) \begin{matrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{matrix} \tag1 142536(1)
我们可以先来试试把这个矩阵给展成一个行向量,需要注意的是reshape里面的参数第一个是你想要改变后的行数,第二个是列数

X.reshape(1,6)
>>>tensor([[1., 2., 3., 4., 5., 6.]])

然后我们可以试试把它改成3×2的矩阵

X.reshape(3,2)
>>>tensor([[1., 2.],[3., 4.],[5., 6.]])

接下来我们可以看看在代码里面的-1是什么意思

X.reshape(1,-1)
>>>tensor([[1., 2., 3., 4., 5., 6.]])

从这里我们可以看出来,如果reshape里面的参数是-1,这个函数就会自己计算合适的维度,让元素的总个数不产生变化

最后我们可以验证一下reshape操作并不可以改变元素的总个数

X.reshape(1,8)
>>>RuntimeError: shape '[1, 8]' is invalid for input of size 6

可以看到发生了报错

定义损失函数

与线性回归模型不同,线性回归使用的损失函数是MSE,但是在softmax分类里面我们采用的是交叉熵损失函数
李沐老师课上讲的函数表达式非常的简洁,但是并不是特别好理解,所以我么采用吴恩达老师讲述的交叉熵损失函数的公式,实际上他们是一个式子
L ( y , p ) = − ( y l o g ( p ) + ( 1 − y ) l o g ( 1 − p ) ) L(y,p) = -(ylog(p) + (1-y)log(1-p)) L(y,p)=(ylog(p)+(1y)log(1p))
李沐老师在代码实现这个函数的时候采用了一个非常巧妙的想法,下面我们来仔细的研究一下

def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])cross_entropy(y_hat, y)

我相信大家对y_hat[range(len(y_hat))的疑问很大

小实验/索引匹配

我们可以先自己定义一个y和y_hat来看看上面这个代码的表达是什么意思

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
>>> tensor([0.1000, 0.5000])

这个是什么意思呢?
我们可以来看最后面的y_hat[[0, 1], y],我们可以把这个代码理解为把前面的**[0,1]以及y都看作是一个迭代器**,并且这个迭代器满足索引匹配,也就是这个函数等价于y_hat[0,0]与y_hat[1,2]

定义优化器

lr = 0.1def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)

在这里我们还是采用随机梯度下降的方式来优化模型

训练模型

训练模型直接调用李沐老师写好的函数即可

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/5080.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习基础之《TensorFlow框架(16)—神经网络案例》

一、mnist手写数字识别 1、数据集介绍 mnist数据集是一个经典的数据集,其中包括70000个样本,包括60000个训练样本和10000个测试样本 2、下载地址:http://yann.lecun.com/exdb/mnist/ 3、文件说明 train-images-idx3-ubyte.gz: training s…

Hbase学习笔记

Hbase是什么 HBase是一个高可靠、高性能、面向列、可伸缩的分布式存储系统。它利用Hadoop HDFS作为其文件存储系统,并提供实时的读写的数据库系统。HBase的设计思想来源于Google的BigTable论文,是Apache的Hadoop项目的子项目。它适合于存储大表数据,并可以达到实时级别。HB…

【Redis 开发】Lua语言

Lua Lua语法 Lua语法 Lua是一种小巧的脚本语言,底层用C语言实现,为了嵌入式应用程序中 官网:https://www.lua.org/ 创建lua文件 touch hello.lua 运行lua文件 lua hello.lua 输出语句 print("Hello World!")数据类型 可以通过t…

一篇易懂的SPI通讯指南

SPI概念 SPI(Serial Peripheral interface, 串行外设接口)是微处理控制单元(MCU)和外围IC(如传感器、ADC、DAC、驱动芯片和外部存储设备等)之间进行通信的同步串行端口,其通信速率一般可以从几千bps到几百Mbps甚至更高…

QT httpServer多线程后台服务器的例子实现

1.需求 1.1 用户需要其他平台(web端)调用Qt平台的接口,获取想要的数据并实时显示在网页里,比如实时的温湿度,用户数据等 1.2 用户需要在其他平台(web端)调用Qt平台的接口,下发数据…

kettle从入门到精通 第五十五课 ETL之kettle Excel输入

想真正学习或者提升自己的ETL领域知识的朋友欢迎进群,一起学习,共同进步。 1、 Excel输入,Microsoft Excel输入步骤的作用是从Microsoft Excel中读取数据,如下图所示: 1)Excel输入步骤从文件D:\data\测试数…

Linux实现简单进度条(附原理解释和动图效果)

1&#xff0c;行缓冲区 先看下面的代码和运行结果&#xff0c; #include<stdio.h> #include<unistd.h> int main() {printf("你好\n");sleep(3);return 0; }只是一个简单的打印“你好”然后休眠三秒&#xff0c;最后程序结束 再看下面的代码和运行结果…

Docker-Compose单机多容器应用编排与管理

前言 Docker Compose 作为 Docker 生态系统中的一个重要组件&#xff0c;为开发人员提供了一种简单而强大的方式来定义和运行多个容器化应用。本文将介绍 Docker Compose 的使用背景、优劣势以及利用 Docker Compose 简化应用程序的部署和管理。 目录 一、Docker Compose 简…

数据库学习之常见的一些SQL命令

查看当前DBMS下所有数据库 show databases; 切换到某一个数据库 use 数据库名; 开启root的远程登录 update mysql.user set host"%" where user"root"; 刷新权限列表 flush privileges; 创建数据库 create database 数据库名称 删除数据库 drop…

CUDA架构介绍与设计模式解析

文章目录 **CUDA**架构介绍与设计模式解析**1** CUDA 介绍CUDA发展历程CUDA主要特性CUDA系统架构CUDA应用场景编程语言支持CUDA计算过程线程层次存储层次 **2** CUDA 系统架构分层架构并行计算模式生产-消费者模式工作池模式异步编程模式 **3** CUDA 中的设计模式工厂模式策略模…

django设计模式理解FBV和CBV

在 Web 开发中&#xff0c;FBV&#xff08;Function-Based Views&#xff09;和 CBV&#xff08;Class-Based Views&#xff09;是两种常见的视图设计模式&#xff0c;用于处理 HTTP 请求并生成相应的响应。下面是它们的简要解释&#xff1a; Function-Based Views (FBV) 在 …

闲话 Asp.Net Core 数据校验(三)EF Core 集成 FluentValidation 校验数据例子

前言 一个在实际应用中 EF Core 集成 FluentValidation 进行数据校验的例子。 Step By Step 步骤 创建一个 Asp.Net Core WebApi 项目 引用以下 Nuget 包 FluentValidation.AspNetCore Microsoft.AspNetCore.Identity.EntityFrameworkCore Microsoft.EntityFrameworkCore.Re…

SQL事前巡检插件

背景: 事故频发 •在工作过程中每年都会看到SQL问题引发的线上问题&#xff0c;一条有问题的SQL足以拖垮整个数据库 不易发觉 •对于SQL性能问题测试在预发环境不易发现&#xff08;数据量小&#xff09; •SAAS系统隔离字段在SQL条件中遗漏&#xff0c;造成越权风险 •业…

密链:openEuler20.03已安装的软件包列表

文章目录 openEuler20.03已安装的软件包列表 openEuler20.03已安装的软件包列表 rpm -qa[rootlocalhost tmp]# rpm -qa librelp-1.2.16-3.oe2003sp4.aarch64 p11-kit-trust-0.23.20-5.oe2003sp4.aarch64 python-setuptools-44.1.1-2.oe2003sp4.noarch luksmeta-9-5.oe2003sp4.…

【算法刷题 | 贪心算法08】4.29(划分字母区间、合并区间)

文章目录 14.划分字母区间14.1题目14.2解法&#xff1a;贪心14.2.1贪心思路14.2.2代码实现 15.合并区间15.1题目15.2解法&#xff1a;贪心15.2.1贪心思路15.2.2代码实现 14.划分字母区间 14.1题目 给你一个字符串 s 。我们要把这个字符串划分为尽可能多的片段&#xff0c;同一…

WebSocket 全面解析

&#x1f31f; 引言 WebSocket&#xff0c;一个让实时通信变得轻而易举的神器&#xff0c;它打破了传统HTTP协议的限制&#xff0c;实现了浏览器与服务器间的全双工通信。想象一下&#xff0c;即时消息、在线游戏、实时股票报价…这一切都离不开WebSocket的魔力&#x1f4ab;。…

LT6911UXE HDMI 2.0 至双端口 MIPI DSI/CSI,带音频 龙迅方案

1. 描述LT6911UXE 是一款高性能 HDMI2.0 至 MIPI DSI/CSI 转换器&#xff0c;适用于 VR、智能手机和显示应用。HDMI2.0 输入支持高达 6Gbps 的数据速率&#xff0c;可为4k60Hz视频提供足够的带宽。此外&#xff0c;数据解密还支持 HDCP2.3。对于 MIPI DSI / CSI 输出&#xff0…

零基础HTML教程(30)--迈入HTML5新时代

文章目录 1. 从H4时代到H5时代2. 属性值可以不用引号3. 标签使用大小写均可4. 部分属性值可以省略5. 浏览器支持情况6. 小结 1. 从H4时代到H5时代 之前讲的29篇HTML教程&#xff0c;内容基本都是H4时代就有的。 随着时代的发展&#xff0c;H4多少有点不够用&#xff0c;所以H…

历届试题 买不到的数目

历届试题 买不到的数目 时间限制&#xff1a;1.0s 内存限制&#xff1a;256.0MB 问题描述 小明开了一家糖果店。他别出心裁&#xff1a;把水果糖包成4颗一包和7颗一包的两种。糖果不能拆包卖。 小朋友来买糖的时候&#xff0c;他就用这两种包装来组合。当然有些糖果…

pyQt5 和 Qt Designer 实现登录注册案例

Qt Designer 设计页面: 通过 PyQt5 手写 1. 先引入用到的库 from PyQt5.QtWidgets import * import sys 2. 创建应用,窗口, 设置窗口 # 创建应用 app QApplication(sys.argv) # 创建窗口 w QWidget()# 设置窗口标题 w.setWindowTitle("注册登录")# 展示 w.sho…