Softmax回归

一、Softmax回归关键思想

1、回归问题和分类问题的区别

       Softmax回归虽然叫“回归”,但是它本质是一个分类问题。回归是估计一个连续值,而分类是预测一个离散类别。

2、Softmax回归模型

       Softmax回归跟线性回归一样将输入特征与权重做线性叠加。与线性回归的一个主要不同在于,Softmax回归的输出值个数等于标签里的类别数。比如一共有4种特征和3种输出动物类别(猫、狗、猪),则权重包含12个标量(带下标的$w$),偏差包含3个标量(带下标的$b$),且对每个输入计算$ O_1,O_2,O_3 $这三个输出:

$ \begin{aligned} o_1 &= x_1 w_{11} + x_2 w_{12} + x_3 w_{13} + x_4 w_{14} + b_1,\\ o_2 &= x_1 w_{21} + x_2 w_{22} + x_3 w_{23} + x_4 w_{24} + b_2,\\ o_3 &= x_1 w_{31} + x_2 w_{32} + x_3 w_{33} + x_4 w_{34} + b_3. \end{aligned} $

最后,再对这些输出值进行Softmax函数运算

       softmax回归同线性回归一样,也是一个单层神经网络。由于每个输出$ O_1,O_2,O_3 $的计算都要依赖于所有的输入$ X_1,X_2,X_3,X_4 $,所以softmax回归的输出层也是一个全连接层。

3、Softmax函数

       Softmax用于多分类过程中,它将多个神经元的输出(比如$ O_1,O_2,O_3 $)映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类!它通过下式将输出值变换成值为正且和为1的概率分布:

$\widehat{y_1},\widehat{y_2},\widehat{y_3} = \mathrm{softmax}(o_1,o_2,o_3)$

其中:

$ \widehat{y}_j=\frac{\exp \left( o_1 \right)}{\sum\limits_{i=1}^3{\exp \left( o_i \right)}} $, $ \widehat{y}_j=\frac{\exp \left( o_2 \right)}{\sum\limits_{i=1}^3{\exp \left( o_i \right)}} $, $ \widehat{y}_j=\frac{\exp \left( o_3 \right)}{\sum\limits_{i=1}^3{\exp \left( o_i \right)}} $

       容易看出 $ \widehat{y_1}+\widehat{y_2}+\widehat{y_3}=1 $ 且 $ \widehat{y_1}+\widehat{y_2}+\widehat{y_3}=1 $,因此 $ \widehat{y_1},\widehat{y_2},\widehat{y_3} $ 是一个合法的概率分布。此外,我们注意到:

$ arg\max\text{\ }o_i=arg\max\text{\ }\widehat{y_i} $

 因此softmax运算不改变预测类别输出。

       下图可以更好的理解Softmax函数,其实就是取自然常数e的指数相加后算比例,由于自然常数的指数($ e^x $)在$ \left( -\infty ,+\infty \right) $单调递增,因此softmax运算不改变预测类别输出。

4、交叉熵损失函数

       假设我们希望根据图片动物的轮廓、颜色等特征,来预测动物的类别,有三种可预测类别:猫、狗、猪。假设我们当前有两个模型(参数不同),这两个模型都是通过sigmoid/softmax的方式得到对于每个预测结果的概率值:

模型1:

模型1
预测真实是否正确
0.30.30.4001正确
0.30.40.3010正确
0.10.20.7100错误

       模型评价:模型1对于样本1和样本2以非常微弱的优势判断正确,对于样本3的判断则彻底错误。

模型2:

模型2
预测真实是否正确
0.10.20.7001正确
0.10.70.2010正确
0.30.40.3100错误

       模型评价:模型2对于样本1和样本2判断非常准确,对于样本3判断错误,但是相对来说没有错得太离谱。

       好了,有了模型之后,我们需要通过定义损失函数来判断模型在样本上的表现了,那么我们可以定义哪些损失函数呢?我们可以先尝试使用以下几种损失函数,然后讨论哪种效果更好。

(1)Classification Error(分类错误率)

       最为直接的损失函数定义为:

$ classification\ error=\frac{count\ of\ error\ items}{count\ of\ all\ items} $

模型1:$ classification\ error=\frac{1}{3} $

模型2:$ classification\ error=\frac{2}{3} $

       我们知道,模型1模型2虽然都是预测错了1个,但是相对来说模型2表现得更好,损失函数值照理来说应该更小,但是,很遗憾的是,classification error 并不能判断出来,所以这种损失函数虽然好理解,但表现不太好。

(2)Mean Squared Error(均方误差MSE)

       均方误差损失也是一种比较常见的损失函数,其定义为:

$ MSE=\frac{1}{n}\sum_i^n{\left( \widehat{y_i}-y_i \right) ^2} $

模型1:

对所有样本的loss求平均:

模型2:

对所有样本的loss求平均:

       我们发现,MSE能够判断出来模型2优于模型1,那为什么不采样这种损失函数呢?主要原因是在分类问题中,使用sigmoid/softmx得到概率,配合MSE损失函数时,采用梯度下降法进行学习时,会出现模型一开始训练时,学习速率非常慢的情况(损失函数 | Mean-Squared Loss - 知乎)。

       有了上面的直观分析,我们可以清楚的看到,对于分类问题的损失函数来说,分类错误率和均方误差损失都不是很好的损失函数,下面我们来看一下交叉熵损失函数的表现情况。

(3)Cross Entropy Loss Function(交叉熵损失函数)

其中:

$M$:类别的数量

$ y_{ic} $:符号函数(0或1),如果样本 i 的真实类别等于 c 取 1,否则取 0

$ p_{ic} $:观测样本 i 属于类别 c 的预测概率

$N$:样本的数量

现在我们利用这个表达式计算上面例子中的损失函数值:

模型1

对所有样本的loss求平均:

模型2:

对所有样本的loss求平均:

       可以发现,交叉熵损失函数可以捕捉到模型1和模型2预测效果的差异,因此对于Softmax回归问题我们常用交叉熵损失函数。

      下面两图可以很清晰的反应整个Softmax回归算法的流程:

二、图像分类数据集

       MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。我们将使用类似但更复杂的Fashion-MNIST数据集。

       在这里我们定义一些函数用于数据的读取与显示,这些函数已经在Python包d2l中定义好了,但为了便于大家理解,这里没有直接调用d2l中的函数。

1、读取数据集

       我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

       Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像和测试数据集(test dataset)中的1000张图像组成。因此,训练集和测试集分别包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。

print(len(mnist_train), len(mnist_test))
60000 10000

       每个输入图像的高度和宽度均为28像素。数据集由灰度图像组成,其通道数为1。为了简洁起见,本书将高度$h$像素、宽度$w$像素图像的形状记为$h \times w$($h$,$w$)。接下来我们可以打印一下mnist_train的类型和mnist_train的第一个元素。

print(type(mnist_train))
print(type(mnist_train[0]))
print(mnist_train[0])
print(mnist_train[0][0].shape)

       可以看出mnist_train的类型为<class 'torchvision.datasets.mnist.FashionMNIST'>。mnist_train的第一个元素的类型是<class 'tuple'>,是一个元组,元组第一个元素是转化为tensor后的灰度值,第二个元素是图像所属类别index,这里是9。因为是灰度图,因此channel数量为1,图片长和宽都是28,因此形状是(1,28,28)。

       Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)

       以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):   # labels:mnist_train和mnist_test里面图像的类别index(数字)"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]    # 根据index返回文本标签列表('t-shirt', 'trouser'...)

       我们现在可以创建一个函数来可视化这些样本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save"""绘制图像列表""""""imgs: tensor向量num_rows: 画图时的行数num_cols: 画图时的列数titles: 每张图片的标题scales: 因为要将num_rows*num_cols张图片画到一张图上,并且还要添加一些文字,因此需要对大图进行一定的缩放才能保证每张小图之间的间隙"""figsize = (num_cols * scale, num_rows * scale)# figsize = (num_cols, num_rows)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

       以下是训练数据集中前18个样本的图像及其相应的标签。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))

2、读取小批量数据

       为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。在每次迭代中,数据加载器每次都会读取一小批量数据,大小为`batch_size`。通过内置数据迭代器,我们可以随机打乱所有样本,从而无偏见地读取小批量。

batch_size = 256def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

3、整合所有组件

       现在我们定义`load_data_fashion_mnist`函数,用于获取和读取Fashion-MNIST数据集。这个函数返回训练集和验证集的数据迭代器。此外,这个函数还接受一个可选参数`resize`,用来将图像大小调整为另一种形状。

def load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]    # 此时的trans是一个列表if resize:trans.insert(0, transforms.Resize(resize))    # 如果提供了resize参数,则在转换链中插入Resize操作trans = transforms.Compose(trans)    # 将一系列的图像转换操作组合成一个转换链。# trans是一个由多个图像转换操作组成的列表。它按照列表中的顺序依次应用这些转换操作。# 这样可以将多个转换操作组合在一起,以便在加载数据时一次性应用它们。mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

       下面,我们通过指定`resize`参数来测试`load_data_fashion_mnist`函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

三、softmax回归的从零开始实现

...

参考文献

[1]  损失函数|交叉熵损失函数

[2]  深度学习模型系列一——多分类模型——Softmax 回归-CSDN博客

[3]  Softmax 回归_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/214479.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux安装Nginx并部署Vue项目

今天部署了一个Vue项目到阿里云的云服务器上&#xff0c;现记录该过程。 1. 修改Vue项目配置 我们去项目中发送axios请求的文件里更改一下后端的接口路由&#xff1a; 2. 执行命令打包 npm run build ### 或者 yarn build 打包成功之后&#xff0c;我们会看到一个dist包&a…

[MySQL]SQL优化之索引的使用规则

&#x1f308;键盘敲烂&#xff0c;年薪30万&#x1f308; 目录 一、索引失效 &#x1f4d5;最左前缀法则 &#x1f4d5;范围查询> &#x1f4d5;索引列运算&#xff0c;索引失效 &#x1f4d5;前模糊匹配 &#x1f4d5;or连接的条件 &#x1f4d5;字符串类型不加 …

110. 平衡二叉树(Java)

给定一个二叉树&#xff0c;判断它是否是高度平衡的二叉树。 本题中&#xff0c;一棵高度平衡二叉树定义为&#xff1a; 一个二叉树每个节点 的左右两个子树的高度差的绝对值不超过 1 。 示例 1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;t…

如何通过SPI控制Peregrine的数控衰减器

概要 Peregrine的数控衰减器PE4312是6位射频数字步进衰减器(DSA,Digital Step Attenuator)工作频率覆盖1MHz~4GHz,插入损耗2dB左右,衰减步进0.5dB,最大衰减量为31.5dB,高达59dBm的IIP3提供了良好的动态性能,切换时间0.5微秒,供电电源2.3V~5.5V,逻辑控制兼容1.8V,20…

​如何使用https://www.krea.ai/来实现文生图,图生图,

网址&#xff1a;https://www.krea.ai/apps/image/realtime Krea.ai 是一个强大的人工智能艺术生成器&#xff0c;可用于创建各种创意内容。它可以用来生成文本描述的图像、将图像转换为其他图像&#xff0c;甚至写博客文章。 文本描述生成图像 要使用 Krea.ai 生成文本描述…

【conda】利用Conda创建虚拟环境,Pytorch各版本安装教程(Ubuntu)

TOC conda 系列&#xff1a; 1. conda指令教程 2. 利用Conda创建虚拟环境&#xff0c;安装Pytorch各版本教程(Ubuntu) 1. 利用Conda创建虚拟环境 nolonolo:~/sun/SplaTAM$ conda create -n splatam python3.10查看结果&#xff1a; (splatam) nolonolo:~/sun/SplaTAM$ cond…

Linux系统编程(一):基本概念

参考引用 Unix和Linux操作系统有什么区别&#xff1f;一文带你彻底搞懂posix Linux系统编程&#xff08;文章链接汇总&#xff09; 1. Unix 和 Linux 1.1 Unix Unix 操作系统诞生于 1969 年&#xff0c;贝尔实验室发布了一个用 C 语言编写的名为「Unix」的操作系统&#xff0…

【基于LSTM的电商评论情感分析:Flask与Sklearn的完美结合】

基于LSTM的电商评论情感分析&#xff1a;Flask与Sklearn的完美结合 引言数据集与爬取数据处理与可视化情感分析模型构建Flask应用搭建词云展示创新点结论 引言 在当今数字化时代&#xff0c;电商平台上涌现出大量的用户评论数据。了解和分析这些评论对于企业改进产品、服务以及…

以太坊:前世今生与未来

一、引言 以太坊&#xff0c;这个在区块链领域大放异彩的名字&#xff0c;似乎已经成为了去中心化应用&#xff08;DApps&#xff09;的代名词。从初期的萌芽到如今的繁荣发展&#xff0c;以太坊经历了一段曲折而精彩的旅程。让我们一起回顾一下以太坊的前世今生&#xff0c;以…

树实验代码

哈夫曼树 #include <stdio.h> #include <stdlib.h> #define MAXLEN 100typedef struct {int weight;int lchild, rchild, parent; } HTNode;typedef HTNode HT[MAXLEN]; int n;void CreatHFMT(HT T); void InitHFMT(HT T); void InputWeight(HT T); void SelectMi…

【Qt开发流程】之UI风格、预览及QPalette使用

概述 一个优秀的应用程序不仅要有实用的功能&#xff0c;还要有一个漂亮美腻的外观&#xff0c;这样才能使应用程序更加友善、操作性良好&#xff0c;更加符合人体工程学。作为一个跨平台的UI开发框架&#xff0c;Qt提供了强大而且灵活的界面外观设计机制&#xff0c;能够帮助…

利用Rclone将阿里云对象存储迁移至雨云对象存储的教程,对象存储数据迁移教程

使用Rclone将阿里云对象存储(OSS)的文件全部迁移至雨云对象存储(ROS)的教程&#xff0c;其他的对象存储也可以参照本教程。 Rclone简介 Rclone 是一个用于和同步云平台同步文件和目录命令行工具。采用 Go 语言开发。 它允许在文件系统和云存储服务之间或在多个云存储服务之间…

STM32-EXTI外部中断

目录 一、中断系统 二、STM32中断 三、NVIC&#xff08;嵌套中断向量控制器&#xff09;基本结构 四、NVIC优先级分组 五、EXTI外部中断 5.1 外部中断基本知识 5.2 外部中断&#xff08;EXTI&#xff09;基本结构 ​编辑 5.2.1开发步骤&#xff1a; 5.3 AFIO复用IO口…

ADAudit Plus:强大的网络安全卫士

随着数字化时代的不断发展&#xff0c;企业面临着越来越复杂和多样化的网络安全威胁。在这个信息爆炸的时代&#xff0c;保护组织的敏感信息和确保网络安全已经成为企业发展不可或缺的一环。为了更好地管理和监控网络安全&#xff0c;ADAudit Plus应运而生&#xff0c;成为网络…

Redis分布式锁有什么缺陷?

Redis分布式锁有什么缺陷&#xff1f; Redis 分布式锁不能解决超时的问题&#xff0c;分布式锁有一个超时时间&#xff0c;程序的执行如果超出了锁的超时时间就会出现问题。 1.Redis容易产生的几个问题&#xff1a; 2.锁未被释放 3.B锁被A锁释放了 4.数据库事务超时 5.锁过期了…

centos 7 卸载图形化界面步骤记录

centos7 服务器操作系统&#xff0c;挺小一配置&#xff0c;装了图形化界面&#xff0c;现在运行程序的时候跑不动了&#xff0c;我想这图形界面也没啥用&#xff0c;卸载了算了&#xff01; 卸载步骤 yum grouplist 查询已经安装的组件 可以看到 图形化界面 等是以分组存在的…

TCP数据粘包的处理

TCP数据粘包的处理 背锅侠TCP解决方案2.1 发送端2.2 接收端 背锅侠TCP 在前面介绍套接字通信的时候说到了TCP是传输层协议&#xff0c;它是一个面向连接的、安全的、流式传输协议。因为数据的传输是基于流的所以发送端和接收端每次处理的数据的量&#xff0c;处理数据的频率可…

Qt练习题

1.使用手动连接&#xff0c;将登录框中的取消按钮使用qt4版本的连接到自定义的槽函数中&#xff0c;在自定义的槽函数中调用关闭函数 将登录按钮使用qt5版本的连接到自定义的槽函数中&#xff0c;在槽函数中判断ui界面上输入的账号是否为"admin"&#xff0c;密码是否…

【Angular开发】Angular 16发布:发现前7大功能

Angular 于2023年5月3日发布了主要版本升级版Angular 16。作为一名Angular开发人员&#xff0c;我发现这次升级很有趣&#xff0c;因为与以前的版本相比有一些显著的改进。 因此&#xff0c;在本文中&#xff0c;我将讨论Angular 16的前7个特性&#xff0c;以便您更好地理解。…

机器学习基础介绍

百度百科&#xff1a; 机器学习是一门多领域交叉学科&#xff0c;涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为&#xff0c;以获取新的知识或技能&#xff0c;重新组织已有的知识结构使之不断改善自身的性能。 …