李沐深度学习-softmax从零开始

import torch
import torchvision
import numpy as np
import syssys.path.append("路径")
import d2lzh_pytorch as d2l'''
1. 获取和读取数据
2. 初始化参数和模型
3. 定义softmax运算
4. 定义模型
5. 定义损失函数:交叉熵损失函数
6. 定义分类准确率
7. 训练模型
8. 预测
''''''
----------------- !!分类计算中,每个样本都要进行标签中类别数个预测,来判断该样本属于那种分类的概率大!!!!
''''''
-----------------------------------------------------------获取和读取数据
'''
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)'''
----------------------------------------------------------初始化模型参数
'''
num_inputs = 784  # 一个图像是28x28大小,输入特征个数就是784个  w分10类,一类也需要784个值,b只需要10个分类偏置
num_outputs = 10  # 分类个数
w = torch.tensor(np.random.normal(0, 0.01, size=(num_inputs, num_outputs)), dtype=torch.float, )  # 归一化生成w:784x10
b = torch.zeros(num_outputs, dtype=torch.float)
w.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)'''
----------------------------------------------------------定义softmax运算
'''
# 如何对多维Tensor按维度进行操作
# 在以下的操作中和对其中同一列(dim=0)或同一行(dim=1)的元素进行求和,并在结果中保留行和列这两个维度(keepdim=True)
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(X.sum(dim=0, keepdim=True))
print(X.sum(dim=1, keepdim=True))def softmax(X):  # 这里X不是样本,而是经过线性运算之后的结果,行数代表样本数,列数代表种类数/输出个数X_exp = X.exp()  # 先对X中的每个元素幂指数化partition = X_exp.sum(dim=1, keepdim=True)  # 每一行进行幂指数求和,得到分母return X_exp / partition  # 这里运用了广播机制X = torch.rand(2, 5)
result = softmax(X)
print(result)  # 经过softmax处理后,得到了预测输出在每个类别上的预测概率分布'''
-------------------------------------------------------------------------定义模型
'''def net(X):return softmax(torch.mm(X.view(-1, num_inputs), w) + b)  # 这里是进行线性矢量计算,可以是单个样本也可以是批量样本计算'''
------------------------------------------------------------------------定义损失函数:交叉熵损失函数
'''
# 为了得到标签的预测概率,可以使用gather函数,下面y_hat是2个样本在3个类别的预测概率,y是这2个样本的标签类别
# 通过使用gather函数,可以得到2个样本的标签的预测概率。 在代码中,标签类别的离散值是从0开始逐一递增的
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])  # 这是标签的类别
y_hat.gather(1, y.view(-1, 1))  # 应该根据y中标签类别确定y_hat中标签类别的对应位置然后取出这个类别的概率,这里应该取出# y_hat中的 (0,0),(1,2)位置处的值,因为这里是和y中类别0,2所对应的概率位置def cross_entroy(y_hat, y):return -torch.log(y_hat.gather(1, y.view(-1, 1)))  # 交叉熵只关心对正确类别的预测概率,通过使用gather函数,可以得到2个样本的标签的预测概率。tensor_0 = torch.arange(3, 12).view(3, 3)
index = torch.tensor([[2, 1, 0]]).t()
output = tensor_0.gather(1, index)
print(output)
'''-------------------------------------------------------------------计算分类准确率
'''# y_hat是一个预测概率分布,把分布中预测概率最大的作为输出类别,如果与真实类别y一直,则表示预测准确
# 分类预测率=正确预测个数与总预测数量之比def accuracy(y_hat, y):return (y_hat.argmax(dim=1) == y).float().mean().item(), (y_hat.argmax(dim=1) == y).float()# y_hat.argmax(dim=1)返回矩阵y_hat每行中最大的元素索引,且返回结果与变量y形状相同# 上述判断式是一个类型为ByteTensor的Tensor,使用float()将其转换为值为0(相等为假)或1(相等为真)的浮点型Tensorprint(accuracy(y_hat, y))  # 50%准确率,第一个预测错误,第二个预测正确
print(d2l.evaluate_accuracy(test_iter, net))  # 预测了mnist的test数据集softmax运算精度'''
-------------------------------------------------------------------------------------训练模型
'''
num_epochs, lr = 5, 0.1
result = d2l.train_ch3(net, train_iter, test_iter, cross_entroy, num_epochs, batch_size, [w, b], lr)'''
-------------------------------------------------------------------------------------预测
'''
X, y = next(iter(test_iter))  # test_iter返回的是一个迭代器对象,需要使用next()函数进行调用
true_labels = d2l.get_fashion_mnist_labels(y.numpy())  # 获取了test数据集中的真实标签并进行转义
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())  # 将预测函数net返回的y_hat取其每行最大值的下标索引
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]  # 列表中使用for循环# 一次next()只访问了一个批量元组,X就是一个列表
d2l.show_fashion_mnist(X[0:9], titles[0:9])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/633656.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

为什么要选择“零代码”开发的智慧能源管理平台?

全球低代码市场发展较早,集中度逐渐凸显,零代码市场尙未形成市场格局,很多企业出现“业务部门不懂技术,技术部门不懂业务”的现象往往会制约软件的开发进度,如何快速搭建软件系统应用,助力业务增长与效率提…

京东云开发者DDD妙文欣赏(1)

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 京东云开发者原文链接:DDD落地实践-架构师眼中的餐厅>>,以下简称《餐厅》。 我截图时,阅读量有6044,在同类文章中已经算是热文了…

初始Spring(适合新手)

一、Spring核心概念(IOC) 控制反转IOC:Inversion of control 控制对象产生的权利反转到spring ioc 依赖注入DI:Dependency injection 依赖spring ioc注入对象 最少jar包: spring-beans-.jar spring-context-.jar spring-core-.jar spring-ex…

山西电力市场日前价格预测【2024-01-20】

日前价格预测 预测说明: 如上图所示,预测明日(2024-01-20)山西电力市场全天平均日前电价为304.16元/MWh。其中,最高日前电价为486.22元/MWh,预计出现在18:15。最低日前电价为87.43元/MWh,预计出…

Qt 状态机框架:The State Machine Framework (一)

传送门: Qt 状态机框架:The State Machine Framework (一) Qt 状态机框架:The State Machine Framework (二) 一、什么是状态机框架 状态机框架提供了用于创建和执行状态图/表[1]的类。这些概念和表示法基于Harel的Statecharts:一种复杂系统的可视化形式,也是UML状态图的基…

adb 配对+无线连接

配对 打开手机开发者选项-无线调试-使用配对码配对设备 出现ip端口和配对码后,电脑输入命令: adb pair ip:端口 eg:adb pair 192.168.137.244:39683 提示输入配对码:就按照手机上的输入。 此时配对成功 连接 再使用命令adb connect ip:port…

Java工具类:将xml转为Json

目录 一、场景二、工具类三、测试类四、测试结果 一、场景 在对接第三方接口时,由于接口返回的并不是常见的Json,而是XML,所以需要将XML转为Json,方便后续处理 二、工具类 package com.xxx.util;import org.apache.commons.lang…

力扣 | 15. 三数之和

暴力解法import java.util.*;public class _15_ThreeSum1 {public List<List<Integer>> threeSum(int[] nums) {if (nums null || nums.length < 3)return new ArrayList<>();Set<List<Integer>> res new HashSet<>();Arrays.sort(nu…

Linux的常用命令

查看命令的帮助 命令名 --help 切换目录命令cd cd app 切换到app目录 cd .. 切换到上一层目录 cd / 切换到系统根目录 cd ~ 切换到用户主目录 cd - 切换到上一个所在目录 使用tab键来补全文件路径 列出文件列表&#xff1a;ls ll ls(list)是一个非常有用的命令&…

网页内容包含敏感字该怎么办?

嗨&#xff0c;大家好&#xff01;今天咱们来聊聊一个非常重要的话题——网页内容包含敏感字的危害。这可不是小事&#xff0c;影响可大了&#xff01; 首先&#xff0c;得搞明白什么是敏感字。这指的是那些可能引起不适或冒犯的词汇&#xff0c;可能涉及到政治、宗教、性别等方…

排序:计数排序

目录 思想&#xff1a; 操作步骤&#xff1a; 思路&#xff1a; 注意事项&#xff1a; 优缺点&#xff1a; 代码解析&#xff1a; 完整代码展示&#xff1a; 思想&#xff1a; 计数排序又称为鸽巢原理&#xff0c;是对哈希直接定址法的变形应用。 操作步骤&#xff…

基于 Hologres+Flink 的曹操出行实时数仓建设

本文整理自曹操出行实时计算负责人林震基于 HologresFlink 的曹操出行实时数仓建设的分享&#xff0c;内容主要分为以下六部分&#xff1a; 曹操出行业务背景介绍曹操出行业务痛点分析HologresFlink 构建企业级实时数仓曹操出行实时数仓实践曹操出行业务成果分析未来展望 一、曹…

AI新势力|将创业当作修行的BookGPT

近期&#xff0c;科技慢半拍联合AIGC开放社区采访了AI创业产品BootGPT的创始人陆再谋。陆总分享了他的创业之旅&#xff0c;从贵州到北京&#xff0c;再回到贵州的整段创业经历&#xff0c;从最初的困难到逐渐取得的成果&#xff0c;打造出了BookGPT这款创业产品。 在本次访谈中…

c++学习笔记-STL案例-机房预约系统4-管理员模块

前言 衔接上一篇“c学习笔记-STL案例-机房预约系统3-登录模块”&#xff0c;本文主要设计管理员模块&#xff0c;从管理员登录和注销、添加账号、显示账号、查看机房、清空预约五个功能进行分析和实现。 目录 7 管理员模块 7.1 管理员登录和注销 7.1.1 构造函数 ​编辑7.1.2…

加速电压对扫描电子显微镜成像的影响

扫描电子显微镜&#xff08;SEM&#xff09;是一种利用聚焦电子束扫描样品表面&#xff0c;通过激发和收集二次电子、特征X射线等信号&#xff0c;获得样品表面形貌和成分信息的分析仪器。在SEM成像过程中&#xff0c;加速电压是一个关键参数&#xff0c;对成像效果具有重要影响…

【概述版】悲剧先于解析:在大型语言模型的新时代,历史重演了

这篇论文探讨了大型语言模型&#xff08;LLM&#xff09;的成功对自然语言处理&#xff08;NLP&#xff09;领域的影响&#xff0c;并提出了在这一新时代中继续做出有意义贡献的方向。作者回顾了2005年机器翻译中大型语法模型的第一个时代&#xff0c;并从中汲取教训和经验。他…

运动型蓝牙耳机推荐哪款?2024运动耳机排行榜最新

​运动耳机在运动爱好者的装备清单中占有重要地位&#xff0c;要求舒适佩戴、卓越音质和环境适应性。市面上的运动耳机琳琅满目&#xff0c;选择合适的可能令人犹豫。那么都有哪些运动耳机值得入手呢&#xff1f;今天来跟大家聊聊运动耳机推荐哪款。 1.南卡开放式耳机&#xff…

数据结构:链式栈

stack.h /* * 文件名称&#xff1a;stack.h * 创 建 者&#xff1a;cxy * 创建日期&#xff1a;2024年01月18日 * 描 述&#xff1a; */ #ifndef _STACK_H #define _STACK_H#include <stdio.h> #include <stdlib.h>typedef struct stack{int data…

环形链表问题2(返回链表开始入环的第一个节点)

环形链表问题2&#xff08;返回链表开始入环的第一个节点&#xff09; 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream…

妇幼保健院污水处理需要哪些工艺设备

妇幼保健院作为医疗机构&#xff0c;在日常运营中会产生大量的污水&#xff0c;因此污水处理是一个非常重要的环节。为了保证污水得到有效处理&#xff0c;并达到相关的排放标准&#xff0c;妇幼保健院污水处理工艺设备是必不可少的。 首先&#xff0c;妇幼保健院污水处理需要一…