从0书写一个softmax分类 李沐pytorch实战

 输出维度

在softmax 分类中 我们输出与类别一样多。 数据集有10个类别,所以网络输出维度为10。

 初始化权重和偏置

torch.norma 生成一个均值为 0,标准差为0.01,一个形状为size=(num_inputs, num_outputs)的张量

偏置生成一个num_outputs =10 的一维张量,并用0初始化 

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)###
b = torch.zeros(num_outputs, requires_grad=True)

requires_grad=True,PyTorch 会在后向传播过程中自动计算该张量的梯度,这对于优化模型参数非常重要。

sum 运算符工作机制:

X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
X.sum(0, keepdim=True), X.sum(1, keepdim=True)

sum = 0 张量按列求和 sum = 1 张量按行求和 

定义softmax函数 

def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制

测试

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
print(X_prob)
print(X_prob.sum(1))

定义sofrmax模型:

softmax 回归模型 定义

def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

 W.shape[0]表示权重的第一维大小

reshape 函数会根据原始张量 X 的元素总数和你提供的其他维度来计算出 -1 代表的维度

定义交叉熵损失函数:

回顾

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
print(y_hat[[0, 1], y])

y 张量是两个真实类别,第0类和第二类

y_hat 是对两个类别 在三种类别上的预测,真实的第0类预测结果为0.1,第2类预测结果为0.5

输出,取出y_hat的指定索引,

  • 对于第一个样本(索引 0),取 y_hat[0, 0],即 0.1
  • 对于第二个样本(索引 1),取 y_hat[1, 2],即 0.5
def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])print(f'交叉熵损失为{cross_entropy(y_hat, y)}')

交叉熵损失为tensor([2.3026, 0.6931])

定义分类精度:

计算出 正确预测数量与总预测数量之比

def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())
print(accuracy(y_hat, y) / len(y))
  • y_hat 的形状为 (2, 3),这意味着:
    • 第一维的大小是 2,表示有 2 个样本(行)。
    • 第二维的大小是 3,表示每个样本有 3 个类别的预测概率(列)。

因此,y_hat.shape[1] 返回的是第二维的大小,也就是 3。这个值表示每个样本的类别数。在这个例子中,y_hat 中的每一行包含了对应样本对 3 个类别的预测概率。

print :y_hat.argmax(axis=1) 预测值张量为[2,2],与y = torch.tensor([0, 2])做对比

将布尔张量转换为整型张量

将[False,True],转换为0,1形式

(cmp.type(y.dtype).sum())

报错RuntimeError: DataLoader worker (pid(s) 12452, 3084, 29000, 29444) exited unexpectedly解决方法:

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
test_iter.num_workers = 0
train_iter.num_workers = 0

再训练迭代器和测试迭代器后加入

定义评估模型准确率函数:

def evaluate_accuracy(net, data_iter):  #@save"""计算在指定数据集上模型的精度"""if isinstance(net, torch.nn.Module):net.eval()  # 将模型设置为评估模式metric = Accumulator(2)  # 正确预测数、预测总数with torch.no_grad():for X, y in data_iter:metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]

 if isinstance(net, torch.nn.Module)是一个Python内置函数,用于检查对象net是否是torch.nn.Module类的实例。

  • 创建一个Accumulator对象,用于累加正确预测的数量和预测的总数量。Accumulator类会存储两个值。
  • 调用模型net对输入X进行预测,得到预测结果,然后使用accuracy函数计算预测的准确数量。y.numel()返回标签y中的元素总数(即样本数量),这两个值一起传递给metric.add()进行累加。

代码分析:

    def add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]
  • zip函数将self.data(当前存储的累积值)和args(输入的多个参数)配对。假设self.data是一个包含n个元素的列表,而args也是一个包含n个元素的可变参数列表。
  • 例如,如果self.data = [3.0, 5.0],而args = (2, 1)zip会生成[(3.0, 2), (5.0, 1)]的迭代器。
  • [a + float(b) for a, b in zip(self.data, args)]是一个列表推导式,用于遍历zip生成的配对。在这个过程中:
    • aself.data中的当前元素。
    • bargs中的当前元素。
  • 对每一对(a, b),该表达式计算a + float(b),将b转换为浮点数并与a相加。

 预测:

def predict_ch3(net, test_iter, n=9):  #@save"""预测标签(定义见第3章)"""for X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true +'\n' + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
predict_ch3(net, test_iter)
d2l.plt.show()

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/879313.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kubernetes从零到精通(10-服务Service)

Service简介 Deployment这种工作负载能管理我们应用Pod的副本数,并实现动态的创建和销毁,所以Pod本身是临时资源(IP随时可能变化)。现在如果某组Pod A需要访问另一组Pod B,A就需要在应用的配置参数里动态跟踪并更改B的…

【数学建模】相关系数

第一部分:相关系数简介 总体与样本: 总体:指研究对象的全体,比如全国人口普查数据。样本:从总体中抽取的一部分个体,如通过问卷调查收集的学生数据。 皮尔逊相关系数: 总体皮尔逊相关系数&…

Linux 8250串口控制器

1 8250串口类型的识别 Intel HW都使用DesignWare 8250: drivers/mfd/intel-lpss-pci.c drivers/tty/serial/8250/8250_dw.c IIR寄存器的高2位bit7、bit6用来识别8250串口的类型: 0 - 8250,无FIFO 0 - 并且存在SCR(Scratch registe…

安科瑞Acrel-1000DP分布式光伏监控系统平台的设计与应用-安科瑞 蒋静

针对用户新能源接入后存在安全隐患、缺少有效监控、发电效率无法保证、收益计算困难、运行维护效率低等通点,提出的Acrel-1000DP分布式光伏监控系统平台,对整个用户电站全面监控,为用户实现降低能源使用成本、减轻变压器负载、余电上网&#…

如何构建大数据治理平台,助力企业数据决策

建设背景 (1)什么是数据资产 资产由企业及组织拥有和控制,能够提供增值服务、带来经济利益的重要资源。 资产不但需要管理, 更需要运营。 (2)数据资产运营中的问题 数据资产运营中存在的问题主要包括以下…

CANopen协议的理解

本文的重点是对CANopen协议的理解,不是编程实现 参考链接 canopen快速入门 1cia301协议介绍_哔哩哔哩_bilibili CANopen是什么? CANopen通讯基础(上)_哔哩哔哩_bilibili CANopen概述 图1. CAN报文标准帧的格式 CAN的报文可简单…

docker-compose 部署 flink

下载 flink 镜像 [rootlocalhost ~]# docker pull flink Using default tag: latest latest: Pulling from library/flink 762bedf4b1b7: Pull complete 95f9bd9906fa: Pull complete a880dee0d8e9: Pull complete 8c5deab9cbd6: Pull complete 56c142282fae: Pull comple…

Redis搭建集群

功能概述 Redis Cluster是Redis的自带的官方分布式解决方案,提供数据分片、高可用功能,在3.0版本正式推出。 使用Redis Cluster能解决负载均衡的问题,内部采用哈希分片规则: 基础架构图如下所示: 图中最大的虚线部分…

路由器WAN口和LAN口有什么不一样?

“ 路由器WAN口和LAN口的区别,WAN是广域网端口,LAN是本地网端口。WAN主要用于连接外部网络,而LAN用来连接家庭内部网络,两者主要会在标识上面有区别。以往大部分路由器的WAN只有一个,LAN口则有四个或以上,近…

《深度学习》—— 神经网络基本结构

前言 深度学习是一种基于神经网络的机器学习算法,其核心在于构建由多层神经元组成的人工神经网络,这些层次能够捕捉数据中的复杂结构和抽象特征。神经网络通过调整连接各层的权重,从大量数据中自动学习并提取特征,进而实现预测或…

Banana Pi BPI-SM9 AI 计算模组采用算能科技BM1688芯片方案设计

产品概述 香蕉派 Banana Pi BPI-SM9 16-ENC-A3 深度学习计算模组搭载算能科技高集成度处理器 BM1688,功耗低、算力强、接口丰富、兼容性好。支持INT4/INT8/FP16/BF16/FP32混合精度计算,可支持 16 路高清视频实时分析,灵活应对图像、语音、自…

Python面试宝典第48题:找丑数

题目 我们把只包含质因子2、3和5的数称作丑数(Ugly Number)。比如:6、8都是丑数,但14不是,因为它包含质因子7。习惯上,我们把1当做是第一个丑数。求按从小到大的顺序的第n个丑数。 示例 1: 输入…

基于MinerU的PDF解析API

基于MinerU的PDF解析API - MinerU的GPU镜像构建 - 基于FastAPI的PDF解析接口支持一键启动,已经打包到镜像中,自带模型权重,支持GPU推理加速,GPU速度相比CPU每页解析要快几十倍不等 主要功能 删除页眉、页脚、脚注、页码等元素&…

uniapp使用高德地图设置marker标记点,后续根据接口数据改变某个marker标记点,动态更新

最近写的一个功能属实把我难倒了,刚开始我请求一次数据获取所有标记点,然后设置到地图上,然后后面根据socket传来的数据对这些标记点实时更新,改变标记点的图片或者文字, 1:第一个想法是直接全量替换,事实证明这样不行,会很卡顿,有明显闪烁感,如果标记点比较少,就十几个可以用…

【网络安全】-rce漏洞-pikachu

rce漏洞包含命令执行漏洞与代码执行漏洞 文章目录 前言 什么是rce漏洞? 1.rce漏洞产生原因: 2.rce的分类: 命令执行漏洞: 命令拼接符: 常用函数: 代码执行漏洞: 常用函数: 分类&…

信号与线性系统综合实验

文章目录 一、实验目的二、实验内容及其结果分析(一)基础部分(二)拓展部分(三)应用设计部分 三、心得体会 一、实验目的 1、掌握连续时间信号与系统的时域、频域综合分析方法;   2、掌握运用M…

SAP B1 单据页面自定义 - 用户界面编辑字段

背景 接《SAP B1 基础实操 - 用户定义字段 (UDF)》,在设置完自定义字段后,如下图,通过打开【用户定义字段】可打开表单右侧的自定义字段页。然而再开打一页附加页面操作繁复,若是客户常用的定义字段,也可以把这些用户…

JMM 指令重排 volatile happens-before

在单线程程序中,操作系统会通过编译器优化重排序、指令级并行重排序、内存系统重排序三个步骤对源代码进行指令重排,提高代码执行的性能。 但是在多线程情况下,操作系统“盲目” 地进行指令重排可能会导致我们不想看到的问题,如经…

2024第三届大学生算法大赛 真题训练2 解题报告 | 珂学家 | FFT/NTT板子

前言 题解 D是FFT板子题,这么来看,其实处于ACM入门题,哭了T_T. D. 行走之谜 思路: FFT 如果你知道多项式乘法,继而知道FFT,那题纯粹就是板子题,可惜当时比赛的时候,无人AC。 这题来简单抽象…

物联网之PWM呼吸灯、脉冲、LEDC

MENU 前言原理硬件电路设计软件程序设计analogWrite()函数实现呼吸灯效果LEDC输出PWM信号 前言 学习制作呼吸灯,通过LED灯的亮度变化来验证PWM不同电压的输出。呼吸灯是指灯光在单片机的控制之下完成由亮到暗的逐渐变化,感觉好像是人在呼吸。 原理 脉冲宽…