动手学深度学习(三)线性神经网络—softmax回归

分类任务是对离散变量预测,通过比较分类的概率来判断预测的结果。

softmax回归和线性回归一样也是将输入特征与权重做线性叠加,但是softmax回归的输出值个数等于标签中的类别数,这样就可以用于预测分类问题。

分类问题和线性回归的区别:分类任务通常有多个输出,作为不同类别的置信度。

一、softmax回归

1.1 网络架构

为了解决线性模型的分类问题,我们需要和输出一样多的仿射函数,每个输出对应它自己的仿射函数。

与线性回归一样,softmax回归也是一个单层神经网络。

在softmax回归中,输出层的输出值大小就代表其所属类别的置信度大小,置信度最大的那个类别我们将其作为预测。

1.2 softmax运算

首先,分类任务的目标是通过比较每个类别的置信度大小来判断预测的结果。但是,我们不能选择未规范化的最大输出值的 o_i 的类别作为我们的预测,原因有两点:

1. 输出值 o_i的总和不一定为1

2. 输出值 o_i有可能为负数。

这违反了概率论基本公理,很难判断所预测的类别是否真符合真实值。

softmax函数通过如下公式,解决了以上问题

softmax函数确保了输出值的非负,和为1,这一种规范手段。

1.3 交叉熵损失函数

交叉熵损失常用来衡量两个概率之间的差别

根据公式推断, 交叉熵损失函数的偏导数是我们softmax函数分配的概率与实际发生的情况之间的差距,换句话来说,其梯度是真实概率 y 和预测概率 \hat{y} 之间的差距。

二、图像分类数据集

MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。我们将使用类似但更复杂的Fashion-MNIST数据集。

2.1 导包

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()  # 用SVG显示图片

2.2 创建数据集

通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型转化成32位的浮点数格式
# 并除以255使得所有像素的数值均在0到1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

查看Fashion-MNIST训练集和测试集大小,分别包含60000,10000张图片。

print(len(mnist_train), len(mnist_test))

查看图片分辨率,图片分辨率大小为[1, 28, 28]。

print(mnist_train[0][0].shape)

补充

torchvision.datasets 是Torchvision提供的标准数据集。

torchvision.transforms是包含一系列常用图像变换方法的包,可用于图像预处理、数据增强等工作。

torchvision.transforms.ToTensor()把一个取值范围是[0,255]PIL.Image或者shape(H,W,C)numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]torch.FloadTensor(浮点型的tensor)。

 

2.3 可视化数据集函数

# 可视化数据集函数
def get_fashion_mnist_labels(labels):"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) # 创建绘制num_rows*num_cols个子图的位置区域axes = axes.flatten() # 降维成一维数组for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy()) # 负责对图像进行处理,并存入内存,并不显示else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False) #不显示y轴ax.axes.get_yaxis().set_visible(False) #不显示x轴if titles:ax.set_title(titles[i])return axes

可视化展示训练集中前18个图片。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
d2l.plt.show() # 将plt.imshow()处理后的数据显示出来

plt.subplots(num_rows, num_cols, figsize):创建绘制num_rows*num_cols个子图的位置区域,其中子图大小为figsize。

enumerate():获取可迭代对象的每个元素的索引值及该元素值。

zip():用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。

imshow():负责对图像进行处理,并存入内存,并不显示。

plt.show():将plt.imshow()处理后的数据显示出来。

2.4 读取小批量

使用4个进程,以批量大小为256,来读取数据集。

# 读取小批量
batch_size = 256def get_dataloader_workers(): """使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

2.5 整合所有组件

这个函数包含了以上所有工作。

def load_data_fashion_mnist(batch_size, resize=None):"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize)) #修改图片大小trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/31107.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[YAPI]导出API文档

1.登录点击进去,点击项目2.点击接口,点击编辑,划到最下面,开启开放接口3.点击数据管理, 选择你要的数据导出格式,点击公开接口, 导出完别忘记关闭,防止别人导的时候将你开启的 也一并下载下来

opencv 基础54-利用形状场景算法比较轮廓-cv2.createShapeContextDistanceExtractor()

注意:新版本的opencv 4 已经没有这个函数 cv2.createShapeContextDistanceExtractor() 形状场景算法是一种用于比较轮廓或形状的方法。这种算法通常用于计算两个形状之间的相似性或差异性,以及找到最佳的匹配方式。 下面是一种基本的比较轮廓的流程&…

点的复合运动

一、问题所在 对于复合运动中的牵连运动一直很蒙,之前做题的时候都是靠经验,比如圆盘选择圆心做动系原点、连杆选择牵连点做原点等,今天重新整理了一下。 牵连运动的定义是动系相对于定系的运动,这个定义就很模糊。如果是指动系…

调整项目符号/项目编号与文本的距离

百度知道多年前的答案是调整标尺,我的PPT里没有标尺 调节悬挂缩进即可

flutter开发实战-实现左右来回移动的按钮引导动画效果

flutter开发实战-实现左右来回移动的按钮引导动画效果 最近开发过程中需要实现左右来回移动的按钮引导动画效果 一、动画 AnimationController用来控制一个或者多个动画的正向、反向、停止等相关动画操作。在默认情况下AnimationController是按照线性进行动画播放的。Animati…

【Vue】input 事件

input 事件是在用户输入内容时触发的事件。它适用于包含文本输入框&#xff08;例如 <input> 或 <textarea>&#xff09;的元素&#xff0c;以及可编辑的内容区域&#xff08;例如 <div contenteditable>&#xff09;。 当用户在输入框中输入文本、复制粘贴…

Vite 创建 Vue项目之后,eslint 错误提示的处理

使用 npm create vuelatest创建 vue 项目&#xff08;TS&#xff09;之后&#xff0c;出现了一些 eslint 错误提示&#xff0c;显然&#xff0c;不是代码真实的错误&#xff0c;而是提示搞错了。 vuejs/create-vue: &#x1f6e0;️ The recommended way to start a Vite-pow…

勘探开发人工智能技术:机器学习(3)

0 提纲 4.1 logistic回归 4.2 支持向量机(SVM) 4.3 PCA 1 logistic回归 用超平面分割正负样本, 考虑所有样本导致的损失. 1.1 线性分类器 logistic 回归是使用超平面将空间分开, 一边是正样本, 另一边是负样本. 因此, 它是一个线性分类器. 如图所示, 若干样本由两个特征描…

Ubuntu 20.04 中安装docker一键安装脚本

直接上脚本&#xff0c;依次执行如下命令即可 wget http://apollo-pkg-beta.bj.bcebos.com/docker_install.sh bash docker_install.shdocker install docker operation system Ubuntu 18.04 直接上脚本&#xff0c;依次执行如下命令即可 ways1 : wget https://github.com…

FPGA应用学习-----FIFO双口ram解决时钟域+asic样机的时钟选通

60m写入异步ram&#xff0c;再用100M从ram中读出 写地址转换为格雷码后&#xff0c;打两拍和读地址判断是否空产生。相反读地址来判断是否满产生。 分割同步模块 asic时钟的门控时钟&#xff0c;fpga是不推荐采用门控时钟的&#xff0c;有很多方法移除fpga的时钟选通。 如果是a…

Plugin 插件

Plugin 插件 插件是 webpack 的支柱功能。插件目的在于解决 loader 无法实现的其他事。Webpack 提供很多开箱即用的插件。 常用插件 clean-webpack-plugin 自动清理输出目录 html-webpack-plugin 自动生成使用 bundle.js 的 HTML copy-webpack-plugin 拷贝文件到输出目…

天花板级,Python接口自动化测试-接口关联封装调用(实例)

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 流程相关的接口&a…

docker基本使用方法

docker使用 1. Docker 介绍 Docker 可以让开发者打包他们的应用以及依赖包到一个轻量级、可移植的容器中&#xff0c;然后发布到任何流行的 Linux 机器上&#xff0c;也可以实现虚拟化。Docker 使您能够将应用程序与基础架构分开&#xff0c;从而可以快速交付软件。通过利用 …

Zabbix 6.0 监控其他

文章目录 一、Zabbix 监控 Windows 系统1&#xff09;下载 Windows 客户端 Zabbix agent 22&#xff09;安装客户端&#xff0c;配置3&#xff09;在服务端 Web 页面添加主机&#xff0c;关联模板 二、Zabbix 监控 java 应用1&#xff09;客户端开启 java jmxremote 远程监控功…

2500、删除每行中最大值在IDEA中调试Java

leetcode:2500、删除每行中最大值在IDEA中调试&#xff0c;使用Java实现 题目描述&#xff1a; 给你一个 m x n 大小的矩阵 grid &#xff0c;由若干正整数组成。 执行下述操作&#xff0c;直到 grid 变为空矩阵&#xff1a; 从每一行删除值最大的元素。如果存在多个这样的…

MySQL_多表查询

多表查询 概述&#xff1a;多表查询就是多张表之间的查询。 回顾&#xff1a;SELECT * FROM table_name 多表查询 from 后面就得跟多张表。如&#xff1a;select * from emp,dept 笛卡尔积&#xff1a;笛卡尔积在数学中&#xff0c;表示两个集合&#xff0c;集合 A 和集合 …

Django实现音乐网站 ⑽

使用Python Django框架制作一个音乐网站&#xff0c; 本篇主要是后台对歌曲类型、歌单功能原有功能进行部分功能实现和显示优化。 目录 歌曲类型功能优化 新增编辑 优化输入项标题显示 父类型显示改为下拉菜单 列表显示 父类型显示名称 过滤器增加父类型 歌单表功能优化…

# X11、Xlib、XFree86、Xorg、GTK、Qt、Gnome和KDE之间的关系

X11、Xlib、XFree86、Xorg、GTK、Qt、Gnome和KDE之间的关系 很多人对于他们是啥是傻傻分不清的&#xff0c;我做了个表格供大家参考。 摘抄&#xff1a; X11是X Window System Protocol, Version 11&#xff08;RFC1013&#xff09;&#xff0c;是X server和X client之间的通…

Android多渠道打包+自动签名工具 [原创]

多渠道打包自动签名工具 [原创] github源码&#xff1a;github.com/G452/apk-packer 如果觉得有帮助可以点个小星星支持一下&#xff0c;万分感谢&#xff01; 使用步骤&#xff1a; 1、在apk-packer.exe目录内放入打包需要的配置&#xff1a; 1&#xff09;签名文件.jks2&am…

各种变形链表(循环链表、双向链表、带头结点的链表等)的表示和基本操作的实现

目录 双向链表双链表的插入操作双链表的删除操作 循环链表循环双链表静态链表 双向链表 单链表节节点中只有一个指向其后继的指针&#xff0c;使得单链表只能从头结点一次顺序的向后遍历。要访问某个记得点的前驱结点&#xff08;插入、删除操作时&#xff09;&#xff0c;只能…