动手学深度学习——softmax分类

1. 分类问题

回归与分类的区别:

  • 回归可以用于预测多少的问题, 比如"预测房屋被售出价格",它是个单值输出。
  • softmax可以用来预测分类问题,例如"某个图片中是猫、鸡还是狗?",这是一个多值输出,输出个数等于类别个数,输出的第i个值表示预测为第i类别的概率。

两者的区别在于是问多少还是问哪一个?

分类可以用来描述下面两个问题:

  1. 样本属于哪个类别
  2. 样本属于每个类别的概率

比较经典的分类问题有:

  1. MNIST数据集,手写数字识别,有0-9十个类别。
  2. ImageNet数据集,从一百万张图片中识别自然物体,有1000个类别。
  3. kaggle上的恶意软件类别识别。
  4. 区分淘宝商品的评论是正面还是负面评论。

2. 分类编码

由于自然语言表示的类别不方便运算,所以为了计算的需要,有必要对类别进行编码。

对于分类问题,最常用的编码方式为一位有效编码,也称为独热编码(one-hot encoding)。它可以表示为一个向量,长度等于类别数量,向量中只有一个特征为1,其它特征均为0。

这里我们以一个图像分类问题为例来讨论, 假设要预测一张图片是猫、鸡还是狗,那么我们对这三种类别进行一位有效编码的形式如下:

  • (1,0, 0)对应于“猫”
  • (0,1,0)对应于“鸡”
  • (0,0,1)对应于“狗”
  1. 正确类别对应的分量设置为1,其它所有分量均为0.
  2. 类别数量等于分量数量(这里的分量是指向量在具体一个维度上的值)

分类问题对模型的要求:正确类的置信度要远远大于非正确类的置信度,即Oy > Oi。

相比具体每个类别的预测值大小,我们更关心正确类别的预测值是否远大于其它非正确类别的预测值,只有这样,才能表明模型能真正区分出正确类别。

3. 网络架构

与线性回归一样,softmax回归也是一个单层神经网络。

接着上面的例子,假设每次输入是一个2*2的灰度图像,我们可以用一个标量表示每个像素值,每个图像对应四个特征[x1,x2,x3,x4]。

我们可以定义输出向量y=[o1,o2,o3], 其中o1、o2、o3分别表示输入i是猫、鸡、狗的预测值大小。

由于我们有4个特征和3个可能的输出类别, 我们将需要12个标量来表示权重w, 3个标量来表示偏置b。则每个类别的计算可以表示为:
在这里插入图片描述

由于计算每个输出o1、o2和o3取决于 所有输入x1、x2、x3和x4, 所以softmax回归的输出层也是全连接层。
在这里插入图片描述
如同线性回归一样,可以将计算公式简洁表示,o = Wx + b。这是将所有权重放到一个W矩阵中。 对于给定数据样本的特征x, 我们的输出y是由权重W与输入特征x进行矩阵-向量乘法再加上偏置b得到。

4. 输出概率化

对于分类问题,我们希望模型的输出yj可以视为它属于类别j的概率,然后只需要选择具有最大输出值的类别argmax(xj,yj) 作为我们的预测即可, 这样能同时方便人脑理解和算术运算。

例如,如果为猫、鸡和狗的概率分别为0.1、0.8和0.1, 因为0.8概率最大,所以我们预测的类别是2,在我们的例子中代表“鸡”。

这里之所以要进行标准化概率计算,而不直接将预测o作为输出,其原因在于将线性层的输出o视作概率会存在一些问题:

  1. 线性层输出没有限制输出数字的总和为1,不符合概率分布。
  2. 根据输入的不同,线性层的输出是可以为负值的,会影响我们的计算。

要将输出视为概率,我们必须保证以下两点:

  1. 在任何类别上的输出都是非负
  2. 所有类别的预测值总和为1。

而softmax函数则正好能够将未规范化的预测变换为非负数并且总和为1,同时让模型保持可导的性质。它的作法为:

  • 对每个未规范化的预测求幂(指数),这样可以确保输出非负
  • 让每个求幂后的结果除以它们的总和,就能确保最终输出的概率值总和为1
    在这里插入图片描述通过对输出向量o进行softmax运算后,预测值就是一个概率分布。

而真实的值经过独热编码后也符合这个特征,因为它也符合概率的特性:

  1. 非负数:只有0和1两种值;
  2. 和为1:只有一个值为1,其它均为0;

这样就得到两个概率:预测值概率和真实值概率。接下来,就可以比较两个概率来作为损失。

5. 损失计算

交叉熵损失:用来衡量两个概率分布之间的差异。

对于分类问题,我们不关心非正确类别的预测值,只关心对正确类的预测值置信度有多大。

假设模型对每个类别的预测概率分别是0.7、0.2和0.1,实际该样本属于第一个类别。交叉熵损失会根据模型对第一个类别的预测概率和实际概率来计算一个损失值。用数学表示如下:

H(p, q) = -Σ p(x) * log(q(x))
  • p(x)表示实际的概率分布,q(x)表示模型预测的概率分布。
  • 前面加负号的目的是为了保证交叉熵为正值。log(q(x))的值通常是小于0的(小于1时,对数为负数),p(x)是一个概率值,介于0和1之间。
  • 交叉熵越小,表示两个概率分布越接近,模型的预测效果越好。

可以把交叉熵H(P,Q)想象为“主观概率为Q
的观察者在看到根据概率P生成的数据时的意外程度”。 当P=Q时,这种意外程度降到最低。

训练的目的:最小化交叉熵来优化模型的参数,使得模型的预测结果更接近于实际标签。

由于真实值p(x)是一个独热编码向量,只有一项为1,其它项均为0,所以这里的交叉熵又可以简写成:
在这里插入图片描述
所以,对于分类问题来说,我们不关心非正确类别的预测值,只关心正确类别的预测值有多大。

而梯度则是预测概率与真实概率之间的差异,损失函数对输出o求导为:
在这里插入图片描述

softmax回归模型训练的目标:给出任何样本特征,我们可以预测每个输出类别的概率。 通常我们使用预测概率最高的类别作为输出类别。 如果预测与实际类别(标签)一致,则预测是正确的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/5373.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入理解正则表达式:从入门到精通

title: 深入理解正则表达式:从入门到精通 date: 2024/4/30 18:37:21 updated: 2024/4/30 18:37:21 tags: 正则Python文本分析日志挖掘数据清洗模式匹配工具推荐 第一章:正则表达式入门 介绍正则表达式的基本概念和语法 正则表达式是一种用于描述字符串…

Bert基础(二十一)--Bert实战:文本摘要

一、介绍 1.1 文本摘要简介 文本摘要(Text Summarization),作为自然语言处理(NLP)领域的一个分支,其核心目标是从长篇文档中提取关键信息,并生成简短的摘要,以提供对原始内容的高度…

Go语言map

map 概念 在Go语言中,map 是一种内建的数据结构,它提供了一种关联式的存储机制,允许你以键值对的形式存储数据。每个键都是唯一的,并且与一个值相关联。你可以通过键来查找、添加、更新和删除值,这类似于其他编程语言…

MAKEFILE 从易到难

相信一个简单的makefile, 只要用过C语言的都能写出来。 但是如果工程中包含了几十个文件夹, 上万个文件, 那用一般的方式就搞不定了。 在用dpdk 的时候, 会经常修改makefile要适配我们的工程。 最开始也是用dpdk中自带的makefil…

wpf 树形结构

Simplifying the WPF TreeView by Using the ViewModel Pattern - CodeProject 【原创】WPF TreeView带连接线样式的优化(WinFrom风格) - iDream2016 - 博客园 (cnblogs.com)

Android 音视频播放器 Demo(二)—— 音频解码与音视频同步

音视频编解码系列目录: Android 音视频基础知识 Android 音视频播放器 Demo(一)—— 视频解码与渲染 Android 音视频播放器 Demo(二)—— 音频解码与音视频同步 RTMP 直播推流 Demo(一)—— 项目…

selenium截屏代码

六、截屏应用场景:失败截图,让错误看的更直观方法: driver.get_screenshot_as_file(imgepath)参数:imagepath:为图片要保存的目录地址及文件名称如: 当前目录 ./test.png上一级目录 ../test.png扩展&#x…

Qt+Ubuntu20.04:打包qt

打包程序 参考 qt项目在Linux平台上面发布成可执行程序.run_qt.run不是虚拟机的配置文件-CSDN博客 Linux下Qt程序的打包发布(1)-不使用第三方工具 - 知乎 (zhihu.com) 过程 1、Release编译 先将你的程序在release下编译通过,保证下面打包的程序是你最新的。 2…

C#调用skiasharp操作并绘制图片

之前学习ViewFaceCore时采用Panel控件和GDI将图片及识别出的人脸方框和关键点绘制出来,本文将其修改为基于SKControl和SKCanvas实现相同的显示效果并支持保存为本地图片。   新建Winform项目,在Nuget包管理器中搜索并安装一下SkiaSharp和ViewFaceCore…

【AI工具合集】图片、文本、音视频工具与A I岗位面试资料

1、AI 工具集合 全球最新热门 Al 工具, AI 工具整合包,可以下载并在 Windows 系统私有化本地化运行,包括图片、文本、视频、音频等工具资源,按照功能、业务和行业来分类。 1.1 AI 图片工具 MoneyPrinter:一键生成短…

HTTP 多个版本

了解一下各个版本的HTTP。 上个世纪90年代初期,蒂姆伯纳斯-李(Tim Berners-Lee)及其 CERN的团队共同努力,制定了互联网的基础,定义了互联网的四个构建模块: 超文本文档格式(HTML) …

Linux基础——Linux开发工具(上)_vim

前言:在了解完Linux基本指令和Linux权限后,我们有了足够了能力来学习后面的内容,但是在真正进入Linux之前,我们还得要学会使用Linux中的几个开发工具。而我们主要介绍的是以下几个: yum, vim, gcc / g, gdb, make / ma…

【初识Redis】

初识Redis Redis(Remote Dictionary Server)是一个开源的内存数据库,它提供了一个高性能的键值存储系统,并且支持多种数据结构,包括字符串、哈希、列表、集合和有序集合等。Redis的特点包括: 内存存储&…

bottom-up-attention.pytorch

环境 torch1.5cu 101cp38 on 2080ti # clone the repository inclduing Detectron2(be792b9) $ git clone --recursive https://github.com/MILVLG/bottom-up-attention.pytorch$ cd detectron2 $ pip install -e . $ cd .. detectron2直接克隆有问题,需要把det…

C语言实验-数组、字符串以及指针

一&#xff1a; 求一个NN矩阵主、次对角线上所有元素之和。矩阵输入、矩阵输出、矩阵对角线求和分别用三个子函数实现。&#xff08;N的值由用户从键盘输入&#xff09; #define _CRT_SECURE_NO_WARNINGS #include<stdio.h> #include<stdlib.h>void print(int(*arr…

有哪些好用的局域网电脑监控系统软件?

企业员工不好管理&#xff1f;&#xff1f;&#xff1f; 局域网已成为企业日常运营不可或缺的一部分。 然而&#xff0c;随着网络技术的普及&#xff0c;员工在局域网中的不当行为也日益增多&#xff0c;如滥用网络资源、泄露敏感信息、消极怠工等。 为了解决这些问题&#x…

植物大战僵尸杂交版

1.感谢作者潜艇伟伟迷 2.大小大概110M&#xff0c;下载链接在下方 链接&#xff1a;https://pan.baidu.com/s/1Ew6iTg0_d_Ut8N9_18KGLw 提取码&#xff1a;yspa 3.祝大家玩的开心

嵌入式学习——C语言基础——day13

1. 结构体类型的定义 struct 类型名 { 数据类型1 成员变量1; 数据类型2 成员变量2; 数据类型3 成员变量3; ... }; 定义结构体中可以使用的数据类型有 1.基本数据类型&#xff1a;int long short char doub…

C++-10

1.C一个程序&#xff0c;实现两个类&#xff0c;分别存放输入的字符串中的数字和字母&#xff0c;并按各自的顺序排列&#xff0c; 类中实现-一个dump函数&#xff0c;调C用后输出类中当前存放的字符串结果。 例如&#xff0c;输入1u4y2a3d,输出:存放字母的类&#xff0c;输出a…

Mybatis-plus对单表操作的封装

MyBatis-Plus单表操作详解及拓展 MyBatis-Plus是一个基于MyBatis的增强工具&#xff0c;它提供了丰富的CRUD操作和分页查询等功能&#xff0c;极大地简化了开发人员的数据库操作。本文将详细介绍MyBatis-Plus官方已经写好的单表操作&#xff0c;并提供一些拓展内容。 1. 引言…