3.Softmax回归

回归和分类

回归估计一个连续值

分类预测一个离散类别

Softmax回归实际是一个分类问题

在这里插入图片描述

从回归到多类分类

对类别进行一位有效编码

y = [ y 1 , y 2 , ⋯ , y n ] T y=[y_1,y_2,\cdots,y_n]^T y=[y1,y2,,yn]T,如果是第i类,则值为1,否则为0

使用均方损失训练,最大值预测为(即softmax函数)
y ^ = a r g m a x i o i \hat y = argmax_i\ o_i y^=argmaxi oi
需要更置信的识别正确类(大余量):

o y − o i ≥ Δ ( y , i ) o_y -o_i\ge \Delta(y,i) oyoiΔ(y,i)

校验比例

输出匹配概率(非负,和为1)
y ^ = s o f t m a x ( o ) y ^ i = e x p ( o i ) ∑ k e x p ( o k ) \hat y = softmax(o)\\ \hat y_i =\frac{exp(o_i)}{\sum_k exp(o_k)} y^=softmax(o)y^i=kexp(ok)exp(oi)
概率 y y y y ^ \hat y y^的区别作为损失

交叉熵损失

交叉熵用来衡量两个概率的区别 H ( p , q ) = ∑ i − p i l o g ( q i ) H(p,q)=\sum_i - p_ilog(q_i) H(p,q)=ipilog(qi)

将它作为损失函数:
l ( y , y ^ ) = − ∑ i y i l o g y ^ i = − l o g y ^ y (假设是第 y 类) l(y,\hat y)=-\sum_i y_ilog\hat y_i = -log \hat y_y (假设是第y类) l(y,y^)=iyilogy^i=logy^y(假设是第y类)
​ 关心正确类的预测值

其梯度是真实概率和预测概率的区别
∂ o i l ( y , y ^ ) = s o f t m a x ( o ) i − y i \partial_{o_i}l(y,\hat y) =softmax(o)_i -y_i oil(y,y^)=softmax(o)iyi

损失函数

均方损失(L2 Loss)


l ( y , y ′ ) = 1 2 ( y − y ′ ) 2 l(y,y')=\frac 12 (y-y')^2 l(y,y)=21(yy)2
​ 在梯度下降时,预测值与真实值相差较远时,梯度会较大,但在离原点比较远时,可能并不希望有较大的梯度,这种情况下可以使用L1 Loss。

绝对值损失(L1 Loss)

l ( y , y ′ ) = ∣ y − y ′ ∣ l(y,y')=|y-y'| l(y,y)=yy

​ 好处就是,无论离原点多远,梯度下降时的导数都是正负1,但在比较接近时,可能就出现振荡了

Huber’s Robust Loss

​ 结合两种的好处
KaTeX parse error: Unknown column alignment: * at position 32: … \begin{array}{*̲*lr**} |y-y'|-\…

读取多类分类的数据集

图像分类数据集

​ 使用Fashion-MNIST数据集

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l# 看一下图片的形状def get_fashion_mnist_labels(labels):"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):"画图"figsize = (num_cols * scale, num_rows * scale)fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 是图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])d2l.plt.show()  # 加上show图片才会显示return axesdef get_dataloader_workers():'''使用4个进程来读取数据'''return 4def load_data_fashion_mnist(batch_size, resize=None):  #resize可以改变图片的大小"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]# 将图片转换成tensor
# 将图片下载,train表示是训练数集,transform表示是tensor而不是图片,download表示从网上下载if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)# 将图片下载,train表示是训练数集,transform表示是tensor而不是图片,download表示从网上下载mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=trans, download=True)# 训练数据集的下载,则train是Falsemnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=trans, download=True)print(len(mnist_train))print(len(mnist_test))print(mnist_train[0][0].shape)  # 黑白图片,所以channel为1,train[0]表示取第一个元素,第二个[0]表示是取图片,[1]表示取标签return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))d2l.use_svg_display()  # 使用svg来显示图片
# 通过ToTenseor实例将图像数据从PIL类型变换成32位浮点数格式
# 并除以255使得所有像素的值均在0到1之间# 将数据集放进dataloader里面,指定一个batch_size,我们就可以得到一个批次的数据
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
# show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))batch_size = 256train_iter = data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=get_dataloader_workers())
timer = d2l.Timer()
for X, y in train_iter:continueprint(f'{timer.stop():.2f} seconds')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/45313.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用户对生活的需求,是三翼鸟创新的起点

这两天又长知识了,学到了一个网络新梗:City不City。 它源自于一种新的打卡方式,用于表达对某个城市的态度或感受。比如你跟朋友在城市游荡时,就可以随口问句City不City啊?通常被释义为“洋不洋气”“ 时髦不时髦”。 …

cpp的cbp

.cbp 文件是 Code::Blocks 的项目文件。Code::Blocks 是一个开源的跨平台集成开发环境(IDE),主要用于 C、C 以及 Fortran 编程。.cbp 文件包含有关项目的所有配置信息,包括文件路径、编译选项、链接器设置等。 以下是 .cbp 文件的…

部署YUM仓库及NFS共享功能

目录 一、YUM仓库服务 1、YUM仓库概述 2、准备安装源 2.1、软件仓库的提供方式 2.2、 RPM软件包的来源 3、YUM主配置文件 4、软件卸载 5、YUM源的提供方式 5.1、配置本地YUM源仓库 5.2、配置ftp源 5.2.1、服务端配置 5.2.2、客户端配置 二、NFS共享存储 1、NFS基…

Git 删除包含敏感数据的历史记录及敏感文件

环境 Windows 10 Git 2.41.0 首先备份你需要删除的文件(如果还需要的话),因为命令会将本地也删除将项目中修改的内容撤回或直接提交到仓库中(有修改内容无法提交) 会提示Cannot rewrite branches: You have unstaged …

免费流程图工具 Draw.io Integration安装使用

Draw.io Integration 是 VS Code 上的一个插件,允许用户在 VS Code 中直接创建、编辑和查看 Draw.io 图表,如流程图、UML 图等。以下是 Draw.io Integration 插件在 VS Code 中的安装步骤: 安装步骤 确保 VS Code 已安装: 如果你…

YOLOv10训练自己的数据集(交通标志检测)

YOLOv10训练自己的数据集(交通标志检测) 前言相关介绍前提条件实验环境安装环境项目地址LinuxWindows 使用YOLOv10训练自己的数据集进行交通标志检测准备数据进行训练进行预测进行验证 参考文献 前言 由于本人水平有限,难免出现错漏&#xff…

每日一道算法题 204. 计数质数

题目 204. 计数质数 - 力扣(LeetCode) Python class Solution:def countPrimes(self, n: int) -> int:"""质数又称为素数,是一个大于1的自然数,除了1和它自身外,不能被其他自然数整除的数叫做质数…

【C++题解】1156 - 排除异形基因

问题:1156 - 排除异形基因 类型:数组基础 题目描述: 神舟号飞船在完成宇宙探险任务回到地球后,宇航员张三感觉身体不太舒服,去了医院检查,医生诊断结果:张三体内基因已被改变,原有…

Vscode连接存在私钥的远程服务器

编辑配置文件 # Read more about SSH config files: https://linux.die.net/man/5/ssh_configHost 172.17.x.xxxHostName 172.17.x.xxxUser xxxIdentityFile C:\Users\xxx\.ssh\xxx.pem会出现报错: Permissions 0644 for xxxx are too open. It is required that …

XML Schema 指示器

XML Schema 指示器 1. 引言 XML Schema 是一种用于定义 XML 文档结构和内容的语言。它提供了一种强大的方式来描述 XML 文档中允许的元素、属性和数据类型。XML Schema 指示器是在 XML Schema 定义中使用的一些特殊元素和属性,它们用于指示 XML 处理器如何解析和验证 XML 文…

vue-router路由路径

在配置 vue-router 路由时,path: ‘search’ 和 path: ‘/search’ 有不同的行为: 1.path: ‘search’: 这是一个相对路径。相对路径意味着这个路径是相对于父路径的。如果父路径是 /emergency,那么这个路由的完整路径是 /emergency/search…

QT 报错C2872: “byte“: 不明确的符号

这个错误提示是因为 byte 这个符号不明确,这种情况是由于代码中同时包含了多个同名符号的定义,编译器无法区分,从而导致错误。在这个问题中,可能是由于使用了 Winsock2.h 头文件中定义的 byte 宏与其他地方定义的 byte 符号重名&a…

Android Bitmap

在Android开发中,位图(Bitmap)是一个非常重要的图形处理对象,它用于在内存中存储图像数据。以下是关于Android中位图使用的一些关键点和方法: 一、获取位图 从资源文件中获取: 使用BitmapFactory类&#…

头歌资源库(24)插入加号

一、 问题描述 二、算法思想 可以使用动态规划来解决这个问题。 首先将数字串拆分为多个数字,用一个数组nums来存储每个数字。例如,数字串79846会被拆分为数组[7, 9, 8, 4, 6]。 然后定义一个二维数组dp,其中dp[i][j]表示在前i个数字中插入…

Java异常体系、UncaughtExceptionHandler、Spring MVC统一异常处理、Spring Boot统一异常处理

概述 所有异常都是继承自java.lang.Throwable类,Throwable有两个直接子类,Error和Exception。 Error用来表示程序底层或硬件有关的错误,这种错误和程序本身无关,如常见的NoClassDefFoundError。这种异常和程序本身无关&#xff0…

Java网络模型全扫盲

概述 讲述ava层面NIO基础知识,用作基础回顾所用 1. NIO概述 ​ 在Java中,NIO(Non-blocking I/O 或 New I/O)是Java SE 1.4及后续版本中引入的一套新的输入/输出操作API。 ​ 它与传统的IO模型相比,提供了更高的效率和…

【算法】二叉树-迭代法实现前后中序遍历

递归的实现就是:每一次递归调用都会把函数的局部变量,参数值和返回地址等压入调用栈中,然后递归返回的时候,从栈顶弹出上一次递归的各项参数,这就是递归为什么可以返回上一层位置的原因 可以用栈实现二叉树的前中后序遍历 1. 前序…

FastAPI 学习之路(四十四)WebSockets

我们之前的分析都是基于http的请求,那么如果是websockets可以支持吗,答案是可以的,我们来看下是如何实现的。 from fastapi import WebSocket, FastAPI from fastapi.responses import HTMLResponseapp FastAPI()html """&…

k8s NetworkPolicy

Namespace 隔离 默认情况下,所有 Pod 之间是全通的。每个 Namespace 可以配置独立的网络策略,来 隔离 Pod 之间的流量。 v1.7 版本通过创建匹配所有 Pod 的 Network Policy 来作为默认的网络策略 默认拒绝所有 Pod 之间 Ingress 通信 apiVersion: …

【趣味数学】求阴影部分面积

题 解法1: 中位线法 既然是中点,就可以用起来,横着不行,竖着来,扩展做辅助线 E是中点S(AED) 1/4 S(ABCD) 6 做图中辅助延长线,因为E中点,所以S(MEB)S(AED) 6 同理E也是…