深度学习(五)softmax 回归之:分类算法介绍,如何加载 Fashion-MINIST 数据集

Softmax 回归

基本原理

回归和分类,是两种深度学习常用方法。回归是对连续的预测(比如我预测根据过去开奖列表下次双色球号),分类是预测离散的类别(手写语音识别,图片识别)。

1699720169075

现在我们已经对回归的处理有一定的理解了,如何过渡到分类呢?

假设我们有 n 类,首先我们要编码这些类让他们变成数据。所有类变成一个列向量。

y = [ y 1 , y 2 , . . . y n ] T y=[y_1,y_2,...y_n]^T y=[y1,y2,...yn]T

有一个数据属于第 i 类,那么他的列向量就是:

y = [ 0 , 0 , . . . , 1 , . . . , 0 , 0 ] T y=[0,0,...,1,...,0,0]^T y=[0,0,...,1,...,0,0]T

也就是只有他所在的那个类的元素=1.

可以用均方损失训练,通过概率判断最终选用哪一个。

Softmax 回归就是一种分类方式(回归问题在多分类上的推广)。首先确定输入特征数和输出类别数。比如上图中我们有4个特征和3个可能的类别,那么计算各自概率的公式包括3个线性回归:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可以看出 Softmax 是全连接的单层神经网络。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们让所有输出结果归一化后,从中选择出最大可能的,置信度最高的分类结果。

image-20231112100423488

采用 e 的指数可以让值全变为非负。

用真实的概率向量-我们预测得到的概率向量就是损失。真实值就是只有一个1的列向量。

交叉熵损失:

image-20231112101259670

可见**分类问题,我们不关心对非正确的预测值,只关心正确预测值是否足够大。**因为正确值是只有一个元素为1的列向量。

常用的损失函数

L2 Loss:均方损失。

image-20231112101555142

L1 Loss:绝对值损失。

image-20231112101829868

L2 梯度是一条倾斜直线,对于梯度下降算法等更为合适;L1 是一个跳变,梯度要么 -1 要么 1. 如图是 L1 L2 的梯度。

image-20231112102551104

我们可以结合两者,得到一个新的损失函数(鲁棒损失 Huber Robust):

KaTeX parse error: {equation} can be used only in display mode.

image-20231112102721527

图像分类数据集

MINIST 是一个常用图像分类数据集,但是过于简单。后来的 upgrade 版叫 Fashion-MINIST(服装分类).

首先,我们研究研究怎么加载训练数据集,以便后面测试算法用。

# 导包
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()d2l.use_svg_display()# 下载数据集并读取到内存
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)		# 训练数据集
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)	# 测试数据集用于评估性能# 定义函数用于返回对应索引的标签
def get_fashion_mnist_labels(labels):  #@save"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 图像可视化,让结果看着更直观,比如下面那个绿色图的样子
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes# 我们先读一点数据集看看啥样的
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

1699980345931

# 通过内置数据加载器读取一批量数据,自动随机打乱读取,不需要我们自己定义
batch_size = 256def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

测量以上用时基本2-3s。

总结整合以上数据读取过程,代码如下:

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

加载图像还可以调整其大小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/146831.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java多线程下使用TransactionTemplate控制事务

简介 本文展示了在Java的多线程环境下使用Spring的TransactionTemplate控制事务的提交与回滚,当任何一个子线程出现异常时,所有子线程都将回滚 环境 JDK:1.8.0_211 SpringBoot:2.5.10 说明 本文通过同时写入用户(User)和用户详细…

zookeperkafka学习

1、why kafka 优点 缺点kafka 吞吐量高,对批处理和异步处理做了大量的设计,因此Kafka可以得到非常高的性能。延迟也会高,不适合电商场景。RabbitMQ 如果有大量消息堆积在队列中,性能会急剧下降每秒处理几万到几十万的消息。如果…

记录一些涉及到界的题

文章目录 coppersmith的一些相关知识题1 [N1CTF 2023] e2Wrmup题2 [ACTF 2023] midRSA题3 [qsnctf 2023]浅记一下 coppersmith的一些相关知识 上界 X c e i l ( 1 2 ∗ N β 2 d − ϵ ) X ceil(\frac{1}{2} * N^{\frac{\beta^2}{d} - \epsilon}) Xceil(21​∗Ndβ2​−ϵ) …

【Proteus仿真】【STM32单片机】防火防盗GSM智能家居设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真STM32单片机控制器,使用声光报警模块、LCD1602显示模块、DS18B20温度、烟雾传感器模块、按键模块、PCF8591 ADC模块、红外检测模块等。 主要功能: 系统运行…

Mac电脑好用的窗口管理软件 Magnet 中文for mac

Magnet是一款用于Mac操作系统的窗口管理工具,它可以帮助您快速和方便地组织和管理应用程序窗口,以提高您的工作效率和多任务处理能力。 以下是Magnet的一些主要功能和特点: 窗口自动调整:Magnet允许您通过简单的拖放操作或使用快…

微信小程序 限制字数文本域框组件封装

微信小程序 限制字数文本域框 介绍&#xff1a;展示类组件 导入 在app.json或index.json中引入组件 "usingComponents": {"text-field":"/pages/components/text-field/index"}代码使用 <text-field maxlength"500" bindtabsIt…

IDEA创建文件添加作者及时间信息

前言 当使用IDEA进行软件开发时&#xff0c;经常需要在代码文件中添加作者和时间信息&#xff0c;以便更好地维护和管理代码。 但是如果每次都手动编辑 以及修改那就有点浪费时间了。 实践 其实我们可以将注释日期 作者 配置到 模板中 同时配置上动态获取内容 例如时间 这样…

PaddleClas学习2——使用PPLCNet模型对车辆朝向进行识别(python)

使用PPLCNet模型对车辆朝向进行识别 1. 配置PaddlePaddle,PaddleClas环境2. 准备数据2.1 标注数据格式2.2 标注数据3. 模型训练3.1 修改配置文件3.2 训练、评估4 模型预测1. 配置PaddlePaddle,PaddleClas环境 安装:请先参考文档 环境准备 配置 PaddleClas 运行环境。 2. 准…

美国服务器:全面剖析其主要优点与潜在缺点

​  服务器是网站搭建的灵魂。信息化的今天&#xff0c;我们仍需要它来为网站和应用程序提供稳定的运行环境。而美国作为全球信息技术靠前的国家之一&#xff0c;其服务器市场备受关注。那么&#xff0c;美国服务器究竟有哪些主要优点和潜在缺点呢? 优点 数据中心基础设施&a…

Vue 路由缓存 防止路由切换数据丢失

在切换路由的时候&#xff0c;如果写好了一丢数据在去切换路由在回到写好的数据的路由去将会丢失&#xff0c;这时可以使用路由缓存技术进行保存&#xff0c;这样两个界面来回换数据也不会丢失 在 < router-view >展示的内容都不会被销毁&#xff0c;路由来回切换数据也…

Linux上编译和安装SOFA23.06

前言 你可以直接使用编译安装好的SOFA版本Installing from all-included binaries (v23.06.00)&#xff1a; 如果你想自己编译&#xff0c;可以看我下面写的内容&#xff0c;不过绝大多数是从官网来的&#xff0c;如果和官网有出入&#xff0c;建议还是以官网为准。 在Linux下…

PVE Win平台虚拟机下如何安装恢复自定义备份Win系统镜像ISO文件(已成功实现)

环境: Virtual Environment 7.3-3 Win s2019 UltraISO9.7 USM6.0 NTLite_v2.1.1.7917 问题描述: PVE Win平台虚拟机下如何安装恢复自定义备份Win系统镜像ISO文件 本次目标 主要是对虚拟机里面Win系统备份做成可安装ISO文件恢复至别的虚拟机或者实体机上 解决方案: …

【数据库】数据库连接池导致系统吞吐量上不去-复盘

在实际的开发中&#xff0c;我们会使用数据库连接池&#xff0c;但是如果不能很好的理解其中的含义&#xff0c;那么就可以出现生产事故。 HikariPool-1 - Connection is not available, request timed out after 30001ms.当系统的调用量上去&#xff0c;就出现大量这样的连接…

NET8 ORM 使用AOT SqlSugar

.NET AOT8 基本上能够免强使用了, SqlSugar ORM也支持了CRUD 能在AOT下运行了 Nuget安装 SqlSugarCore 具体代码 StaticConfig.EnableAot true;//启用AOT 程序启动执行一次就好了//用SqlSugarClient每次都new,不要用单例模式 var db new SqlSugarClient(new ConnectionC…

Os-ByteSec

Os-ByteSec 一、主机发现和端口扫描 主机发现&#xff0c;靶机地址192.168.80.144 端口扫描&#xff0c;开放了80、139、445、2525端口 二、信息收集 访问80端口 路径扫描 dirsearch -u "http://192.168.80.144/" -e *访问扫描出来的路径&#xff0c;没有发现…

Idea安装完成配置

目录&#xff1a; 环境配置Java配置Maven配置Git配置 基础设置编码级设置File Header自动生成序列化编号配置 插件安装MyBtisPlusRestfulTooklkit-fix 环境配置 Java配置 Idea右上方&#xff0c;找到Project Settings. 有些版本直接有&#xff0c;有些是在设置下的二级菜单下…

LLM实现RPA

“PROAGENT: 从机器人流程自动化到代理流程自动化”这篇论文有几个创新点是比较有意思的&#xff1a;1.通过描述方式生成执行链&#xff0c;执行链通过代码方式生成保证执行链的稳健、可约束2.对执行过程抽取出数据结构&#xff0c;数据结构也通过代码生成方式来约束3.整个过程…

Web之CSS笔记

Web之HTML、CSS、JS 二、CSS&#xff08;Cascading Style Sheets层叠样式表&#xff09;CSS与HTML的结合方式CSS选择器CSS基本属性CSS伪类DIVCSS轮廓CSS边框盒子模型CSS定位 Web之HTML笔记 二、CSS&#xff08;Cascading Style Sheets层叠样式表&#xff09; Css是种格式化网…

【论文阅读】基于隐蔽带宽的汽车控制网络鲁棒认证(一)

文章目录 Abstract第一章 引言1.1 问题陈述1.2 研究假设1.3 贡献1.4 大纲 第二章 背景和相关工作2.1 CAN安全威胁2.1.1 CAN协议设计2.1.2 CAN网络攻击2.1.3 CAN应用攻击 2.2 可信执行2.2.1 软件认证2.2.2 消息身份认证2.2.3 可信执行环境2.2.4 Sancus2.2.5 VulCAN 2.3 侧信道攻…

竞赛 题目:基于深度学习卷积神经网络的花卉识别 - 深度学习 机器视觉

文章目录 0 前言1 项目背景2 花卉识别的基本原理3 算法实现3.1 预处理3.2 特征提取和选择3.3 分类器设计和决策3.4 卷积神经网络基本原理 4 算法实现4.1 花卉图像数据4.2 模块组成 5 项目执行结果6 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基…