【pytorch20】多分类问题

网络结构以及示例

在这里插入图片描述
该网络的输出不是一层或两层的,而是一个十层的代表有十分类
在这里插入图片描述

新建三个线性层,每个线性层都有w和b的tensor

首先输入维度是784,第一个维度是ch_out,第二个维度才是ch_in(由于后面要转置),没有经过softmax函数和sigmoid,即logits

上图已经完成了网络的参数的定义和网络的前向传播过程

在这里插入图片描述
nn.CrossEntropyLoss()F.cross_entropy()是一样的功能,都包含softmax和log和F.nll_loss

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transformsbatch_size = 200
learning_rate = 0.01
epochs = 10train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)w1, b1 = torch.randn(200, 784, requires_grad=True), \torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True), \torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True), \torch.zeros(10, requires_grad=True)def forward(x):x = x @ w1.t() + b1x = F.relu(x)x = x @ w2.t() + b2x = F.relu(x)x = x @ w3.t() + b3x = F.relu(x)return x# train
optimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criten = nn.CrossEntropyLoss()for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28 * 28)logits = forward(data)loss = criten(logits, target)optimizer.zero_grad()loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)logits = forward(data)test_loss += criten(logits, target).item()# 每一行的最大值对应的索引pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

在这里插入图片描述
创建的loss一直不变,为什么会出现这个问题,这个网络的结构层次并不深,数据集也比较简单,但这里出现了梯度弥散的情况,因为loss长时间得不到更新,因为梯度信息几乎接近于0

为什么会出现梯度为0?
影响训练的因素,除了有loss,学习率过大,还有初始化的问题,把初始化代码加上

torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)

为什么b不初始化,因为已经初始化为0了

但是w也初始化,只是它们使用的是高斯分布进行初始化,即使是用高斯分布初始化后,结果也不满意,所以用了何凯明的初始化
在这里插入图片描述
可以看出loss直接到0.4了,准确率也达到了80%,而且这里还没运行完,运行完效果会更好

可以看出对于分类问题,初始化参数非常关键

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/44408.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【利用GroundingDINO裁剪分类任务的数据集】及文本提示检测图像任意目标(Grounding DINO) 的使用

文章目录 背景1.Grounding DINO安装2.裁剪指定目标的脚本 背景 在处理公开数据集ImageNet-21k的时候发现里面有很多的数据有问题,比如,数据目标有很多背景,且部分类别有其他种类的图片。针对数据目标有很多背景,公开数据集ImageNe…

【数据库】Redis主从复制、哨兵模式、集群

目录 一、Redis的主从复制 1.1 主从复制的架构 1.2 主从复制的作用 1.3 注意事项 1.4 主从复制用到的命令 1.5 主从复制流程 1.6 主从复制实现 1.7 结束主从复制 1.8 主从复制优化配置 二、哨兵模式 2.1 哨兵模式原理 2.2 哨兵的三个定时任务 2.3 哨兵的结构 2.4 哨…

ArkUI开发学习随机——B站视频简介页面,美团购买界面

案例一:B站视频简介页面 代码: build() {Column(){Column(){Stack(){Image($r("app.media.genimpact")).width(200).height(125).borderRadius({topLeft:5,topRight:5})Row(){Image($r("app.media.bz_play")).height(24).fillColor…

【人工智能】Transformers之Pipeline(概述):30w+大模型极简应用

​​​​​​​ 目录 一、引言 二、pipeline库 2.1 概述 2.2 使用task实例化pipeline对象 2.2.1 基于task实例化“自动语音识别” 2.2.2 task列表 2.2.3 task默认模型 2.3 使用model实例化pipeline对象 2.3.1 基于model实例化“自动语音识别” 2.3.2 查看model与task…

IEC62056标准体系简介-4.IEC62056-53 COSEM应用层

为在通信介质中传输COSEM对象模型,IEC62056参照OSI参考模型,制定了简化的三层通信模型,包括应用层、数据链路层(或中间协议层)和物理层,如图6所示。COSEM应用层完成对COSEM对象的属性和方法的访问&#xff…

01MFC建立单个文件类型——画线

文章目录 选择模式初始化文件作用解析各初始化文件解析类导向创建鼠标按键按下抬起操作函数添加一个变量记录起始位置注意事项代码实现效果图虚实/颜色线选择模式 初始化文件作用解析 运行: 各初始化文件解析 MFC(Microsoft Foundation Classes)是一个C++类库,用于在Win…

防御课综合实验

实验拓扑: 实验要求: 1、DMZ区内的服务器,办公区仅能在办公时间内(9点到18点)可以访问,生产区的设备全天可以访问 2、生产区不允许访问互联网,办公区和游客区允许访问互联网 3、办公区设备10…

二叉平衡树(左单旋,右单旋,左右双旋、右左双旋)

一、AVL树(二叉平衡树:高度平衡的二叉搜索树) 0、二叉平衡树 左右子树高度差不超过1的二叉搜索树。 public class AVLTree{static class AVLTreeNode {public TreeNode left null; // 节点的左孩子public TreeNode right null; // 节点的…

基于Transformer的端到端的目标检测 | 读论文

本文正在参加 人工智能创作者扶持计划 提及到计算机视觉的目标检测,我们一般会最先想到卷积神经网络(CNN),因为这算是目标检测领域的开山之作了,在很长的一段时间里人们都折服于卷积神经网络在图像处理领域的优势&…

论文 | REACT: SYNERGIZING REASONING AND ACTING INLANGUAGE MODELS

本文首先认为,到目前为止,LLM 在语言理解方面令人印象深刻,它们已被用来生成 CoT(思想链)来解决一些问题,它们也被用于执行和计划生成。 尽管这两者是分开研究的,但本文旨在以交错的方式将推理…

JSP入门基础

JSP入门基础 软件开发环境这门课程的复习资料 Web开发技术概述 URL的组成部分 协议、主机DNS名或IP地址和文件名 Tomcat服务器 Tomcat服务器的默认端口号是8080 概念 软件开发环境是围绕着软件开发的一定目标而组织在一起的一组相关软件工具的有机集合 JSP和HTML的区别…

SPE连接器技术革新汽车制造业

概述 新的SPE标准在汽车制造业中的应用正日益受到重视,它不仅推动了汽车通信技术的革新,还对汽车性能测试方法产生了深远影响。本文将详细探讨SPE标准在汽车制造业中的应用案例分析,以及它对供应链的挑战与机遇。 SPE标准在汽车制造业中的应…

[leetcode]subarray-product-less-than-k 乘积小于K的子数组

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:int numSubarrayProductLessThanK(vector<int>& nums, int k) {if (k 0) {return 0;}int n nums.size();vector<double> logPrefix(n 1);for (int i 0; i < n; i) {logPrefix[i 1] …

揭秘!chatGPT核心技术应用

2022年11月30日&#xff0c;可能将成为一个改变人类历史的日子——美国人工智能开发机构OpenAI推出了聊天机器人ChatGPT-3.5&#xff0c;将人工智能的发展推向了一个新的高度。2023年11月7日&#xff0c;OpenAI首届开发者大会被称为“科技界的春晚”&#xff0c;吸引了全球广大…

prometheus回顾(2)--如何使用Grafana对接Prometheus数据源的详细过程,清晰易懂。

文章目录 Grafana简介什么是GrafanaGrafana 能做什么&#xff1f;什么时候我们会用到Grafana?Prometheus有图形化展示&#xff0c;为什么我们还要用Grafana? 环境操作步骤一、Grafana安装二、Grafana数据源Prometheus添加三、Grafana添加数据仪表盘补充、如何查找仪表盘 Graf…

在Linux下直接修改磁盘镜像文件的内容

背景 嵌入式Linux系统通常在调试稳定后&#xff0c;会对磁盘&#xff08;SSD、NVME、SD卡、TF卡&#xff09;做个镜像&#xff0c;通常是.img后缀的文件&#xff0c;以后组装新设备时&#xff0c;就将镜像文件烧录到新磁盘即可&#xff0c;非常简单。 这种方法有个不便之处&a…

Oracle学习笔记

Oracle 一、简介&#xff1a; 特点&#xff1a; 多用户、大事务量的事务处理 数据安全性和完整性控制 支持分布式数据处理 可以移植性 Oracle 19c 安装 登录甲骨文&#xff0c;安装Oracle 解压压缩包 安装 完毕 此处账户&#xff1a;qfedu 密码&#xff1a;wang8218.…

染色法判定二分图

什么是二分图&#xff1f; 二分图&#xff0c;也称作二部图&#xff0c;是图论中的一种特殊模型。在一个无向图G(V,E) 中&#xff0c;如果顶点集合 V 可以被分割成两个互不相交的子集 A 和 B&#xff0c;并且图中的每条边 (i,j) 关联的两个顶点 i 和 j 分别属于这两个不同的顶…

LeetCode(2)合并链表、环形链表的约瑟夫问题、链表分割

一、合并链表 . - 力扣&#xff08;LeetCode&#xff09; 题目描述&#xff1a; /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ typedef struct ListNode ListNode; struct ListNode* mergeTwoLists(struct …

C++入门基础篇(下)

目录 6.引用 6.1 引用的特性 6.2 const引用 7.指针和引用的关系 8.内联函数 9.nullptr 6.引用 引⽤不是新定义⼀个变量&#xff0c;⽽是给已存在变量取了⼀个别名&#xff0c;编译器不会为引⽤变量开辟内存空间&#xff0c; 它和它引⽤的变量共⽤同⼀块内存空间。比如&a…