【pytorch20】多分类问题

网络结构以及示例

在这里插入图片描述
该网络的输出不是一层或两层的,而是一个十层的代表有十分类
在这里插入图片描述

新建三个线性层,每个线性层都有w和b的tensor

首先输入维度是784,第一个维度是ch_out,第二个维度才是ch_in(由于后面要转置),没有经过softmax函数和sigmoid,即logits

上图已经完成了网络的参数的定义和网络的前向传播过程

在这里插入图片描述
nn.CrossEntropyLoss()F.cross_entropy()是一样的功能,都包含softmax和log和F.nll_loss

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transformsbatch_size = 200
learning_rate = 0.01
epochs = 10train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)w1, b1 = torch.randn(200, 784, requires_grad=True), \torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True), \torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True), \torch.zeros(10, requires_grad=True)def forward(x):x = x @ w1.t() + b1x = F.relu(x)x = x @ w2.t() + b2x = F.relu(x)x = x @ w3.t() + b3x = F.relu(x)return x# train
optimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criten = nn.CrossEntropyLoss()for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28 * 28)logits = forward(data)loss = criten(logits, target)optimizer.zero_grad()loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)logits = forward(data)test_loss += criten(logits, target).item()# 每一行的最大值对应的索引pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

在这里插入图片描述
创建的loss一直不变,为什么会出现这个问题,这个网络的结构层次并不深,数据集也比较简单,但这里出现了梯度弥散的情况,因为loss长时间得不到更新,因为梯度信息几乎接近于0

为什么会出现梯度为0?
影响训练的因素,除了有loss,学习率过大,还有初始化的问题,把初始化代码加上

torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)

为什么b不初始化,因为已经初始化为0了

但是w也初始化,只是它们使用的是高斯分布进行初始化,即使是用高斯分布初始化后,结果也不满意,所以用了何凯明的初始化
在这里插入图片描述
可以看出loss直接到0.4了,准确率也达到了80%,而且这里还没运行完,运行完效果会更好

可以看出对于分类问题,初始化参数非常关键

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/44408.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于STM32的HAL库的WS2812幻彩灯驱动程序

1、WS2812幻彩灯带有三条接线,5V、GND和控制信号线,利用单片机的SPI引脚可以发出不同频率协议的脉冲即可实现对幻彩灯带的各个灯珠的颜色和亮灭的控制。 2、幻彩灯带驱动代码如下: (1)ws2812.c #include "ws28…

【利用GroundingDINO裁剪分类任务的数据集】及文本提示检测图像任意目标(Grounding DINO) 的使用

文章目录 背景1.Grounding DINO安装2.裁剪指定目标的脚本 背景 在处理公开数据集ImageNet-21k的时候发现里面有很多的数据有问题,比如,数据目标有很多背景,且部分类别有其他种类的图片。针对数据目标有很多背景,公开数据集ImageNe…

【数据库】Redis主从复制、哨兵模式、集群

目录 一、Redis的主从复制 1.1 主从复制的架构 1.2 主从复制的作用 1.3 注意事项 1.4 主从复制用到的命令 1.5 主从复制流程 1.6 主从复制实现 1.7 结束主从复制 1.8 主从复制优化配置 二、哨兵模式 2.1 哨兵模式原理 2.2 哨兵的三个定时任务 2.3 哨兵的结构 2.4 哨…

ArkUI开发学习随机——B站视频简介页面,美团购买界面

案例一:B站视频简介页面 代码: build() {Column(){Column(){Stack(){Image($r("app.media.genimpact")).width(200).height(125).borderRadius({topLeft:5,topRight:5})Row(){Image($r("app.media.bz_play")).height(24).fillColor…

【人工智能】Transformers之Pipeline(概述):30w+大模型极简应用

​​​​​​​ 目录 一、引言 二、pipeline库 2.1 概述 2.2 使用task实例化pipeline对象 2.2.1 基于task实例化“自动语音识别” 2.2.2 task列表 2.2.3 task默认模型 2.3 使用model实例化pipeline对象 2.3.1 基于model实例化“自动语音识别” 2.3.2 查看model与task…

如何通过Java操作Redis?——Jedis!

简介 在redis命令行客户端中操作redis是否可行?可行,但不方便且不是主流的方式。最终还是要通过Java代码来操作~ Redis的底层通信是遵守RESP协议的,一些第三方的库就实现了这些协议,然后封装好API,程序猿通过封装好的…

IEC62056标准体系简介-4.IEC62056-53 COSEM应用层

为在通信介质中传输COSEM对象模型,IEC62056参照OSI参考模型,制定了简化的三层通信模型,包括应用层、数据链路层(或中间协议层)和物理层,如图6所示。COSEM应用层完成对COSEM对象的属性和方法的访问&#xff…

01MFC建立单个文件类型——画线

文章目录 选择模式初始化文件作用解析各初始化文件解析类导向创建鼠标按键按下抬起操作函数添加一个变量记录起始位置注意事项代码实现效果图虚实/颜色线选择模式 初始化文件作用解析 运行: 各初始化文件解析 MFC(Microsoft Foundation Classes)是一个C++类库,用于在Win…

昇思25天学习打卡营第16天|基于MindSpore通过GPT实现情感分类

今天的这个代码几乎没有任何解释,结合之前GPT生成文本摘要的代码。 大概记录一下 import numpy as np # 导入NumPy库def process_dataset(dataset, tokenizer, max_seq_len512, batch_size4, shuffleFalse): # 判断当前设备是否为Ascend,如果是ascen的…

防御课综合实验

实验拓扑: 实验要求: 1、DMZ区内的服务器,办公区仅能在办公时间内(9点到18点)可以访问,生产区的设备全天可以访问 2、生产区不允许访问互联网,办公区和游客区允许访问互联网 3、办公区设备10…

二叉平衡树(左单旋,右单旋,左右双旋、右左双旋)

一、AVL树(二叉平衡树:高度平衡的二叉搜索树) 0、二叉平衡树 左右子树高度差不超过1的二叉搜索树。 public class AVLTree{static class AVLTreeNode {public TreeNode left null; // 节点的左孩子public TreeNode right null; // 节点的…

Java 中的异常处理机制是如何工作的?请解释 try-catch-finally 的基本用法?

Java中的异常处理机制是确保程序稳健性的重要组成部分,它允许程序在遇到错误或异常情况时,能够优雅地处理问题,而不是直接崩溃。 这一机制的核心在于使用try-catch-finally结构,以及通过throw和throws关键字来抛出和声明异常。 …

基于Transformer的端到端的目标检测 | 读论文

本文正在参加 人工智能创作者扶持计划 提及到计算机视觉的目标检测,我们一般会最先想到卷积神经网络(CNN),因为这算是目标检测领域的开山之作了,在很长的一段时间里人们都折服于卷积神经网络在图像处理领域的优势&…

Vue2.0和Vue3.0的区别?

Vue.js 3.0 相较于 Vue.js 2.0 在多个方面进行了改进和优化,主要包括以下几点: 性能提升: Vue 3.0 使用了新的响应式系统,称为“Proxy-based”,相比于 Vue 2.0 的“Object.defineProperty”,更加高效。 编…

【深度学习基础】安装包报错——MAC M3-MAX芯片安装scikit-learn库报错。

目录 一、问题描述二、解决方法 一、问题描述 首先想安装scikit-learn库在mac终端显示顺利安装完成,但是测试的时候报错如下所示: /opt/anaconda3/envs/dtc/bin/python /Users/chenfaquan/PycharmProjects/TimeSeries/data_create.py Traceback (most…

论文 | REACT: SYNERGIZING REASONING AND ACTING INLANGUAGE MODELS

本文首先认为,到目前为止,LLM 在语言理解方面令人印象深刻,它们已被用来生成 CoT(思想链)来解决一些问题,它们也被用于执行和计划生成。 尽管这两者是分开研究的,但本文旨在以交错的方式将推理…

JSP入门基础

JSP入门基础 软件开发环境这门课程的复习资料 Web开发技术概述 URL的组成部分 协议、主机DNS名或IP地址和文件名 Tomcat服务器 Tomcat服务器的默认端口号是8080 概念 软件开发环境是围绕着软件开发的一定目标而组织在一起的一组相关软件工具的有机集合 JSP和HTML的区别…

range()用法

range(n):是Python中的函数,作用是可以生成 [0, n)之间的正数range(a,b) :生成[a,b)之间的正数数字,不包含brange(start, end, step):生成的数值规则--- [start, end) 步长是 step,默认 1 详见:http://t.csdnimg.cn/7…

科研入门笔记

自学参考: 沐神论文精读系列 如何读论文 通常,一篇论文的结构为: title标题abstract摘要introduction介绍method方法experiments实验conclusion结论 一篇论文可以考虑读1~3遍 第一遍 海选:标题、摘要、结论,选读方…

SPE连接器技术革新汽车制造业

概述 新的SPE标准在汽车制造业中的应用正日益受到重视,它不仅推动了汽车通信技术的革新,还对汽车性能测试方法产生了深远影响。本文将详细探讨SPE标准在汽车制造业中的应用案例分析,以及它对供应链的挑战与机遇。 SPE标准在汽车制造业中的应…