pytorch 中 drop_last与 nn.Parameter

1. drop_last

在使用深度学习,pytorch 的DataLoader 中,

from torch.utils.data import DataLoader# Define your dataset and other necessary configurations
# Create DataLoader
train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)

drop_last=True :DataLoader 中的此设置会删除不完整的最后一批(如果它小于指定的批量大小)。这确保了训练期间处理的每个批次包含相同数量的样本。

1.1 drop_last = True

dataset_size = 100
batch_size = 32
train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)

使用 drop_last=True ,DataLoader 确保每个批次包含 32 个样本,删除不完整的最终批次。例如,在这种情况下,训练期间将处理 3 个批次(32、32、32),其余 4 个样本将不会用于训练。

适用情况:

当网络模型的初始化中,需要用到batch size 时, 这种情况下, 需要注意的是此时, drop_last = False , 会影响网络模型结构, 由于模型的初始化过程中,使用了batch size 参数, 所有此时应该设置为 True;

1.2 drop_last = False

而当 drop_last = False, 当最后一个批次中, 剩余的样本个数不足 batch 样本数目时, 会保留这剩余的样本,使用剩余的样本进行训练。

当数据不均衡, 并且某一类中样本数量很少时, 此时 drop_last = True 会严重影响到模型的精度,此时应该使用 False;

原因是,本身的某个类别中训练集和测试集的数量就已经小于batch size 时, 此时使用 drop last, 会严重该类别的训练和测试效果。

如下面的情况:

遇到了这样的问题。一共16类,第15 16类的训练集数量是15、15,测试集分别为14、5。其他1-14类训练集分别有50个,测试集均为200左右。

当我在pytorch的dataloader中设置了drop_last=True时,无论怎么训练,使用怎么样的数据增强,第15 16类才测试集上的准确率永远为0.

原因分析:
当dataloader设置了drop_last=True时,在训练时如果数据总量无法整除batch_size,那么这个dataloader就会丢掉最后一个batch,也就是说训练的时候有部分数据是被丢掉的。而我遇到的情况可能是正好把第15 16类的测试数据给丢掉了部分,导致模型很好的学习到这两类的特征。

解决方案:
将drop_last改为False,即可解决该问题。

2. nn.Parameter

在深度学习训练过程中, 通常需要自己创建出一个初始化的张量, 并且希望通过模型训练过程中, 更新该张量。

torch.randn(bt, 3, 256)

而普通的使用torch 随机初始化的方式,如上面的这种方式,
在大多数情况下,随机初始化张量不会使其参数变得可学习。在没有任何相关学习过程或梯度更新的情况下随机初始化的张量在网络训练期间不会适应或改变。

2.1 可学习参数

为了使得创建的张量,在网络训练过程中,可以得到更新。

在 PyTorch 中, nn.Parameter 是一个继承自 torch.Tensor 的类。它允许您向框架指示该张量应被视为模型参数的一部分。当您将其分配为 nn.Module 中的属性时,它在优化过程中变得可训练。

import torch
import torch.nn as nn# Creating a tensor as a learnable parameter
param_tensor = nn.Parameter(torch.randn(1, 3))

param_tensor 将在训练过程中进行优化因为它们被视为模型可学习参数的一部分。

放到 cuda 设备上

 self.cuda_param = nn.Parameter(torch.randn(1, 2).cuda())

2.2 nn.ParameterList

同样, 当想创建一个列表都是可学习的参数时, 使用如下的方式;

self.parameters = nn.ParameterList([nn.Parameter(torch.randn(256)) for _ in range(5)])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/213710.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue项目列表跳转详情返回列表页保留搜索条件

需求 列表进入详情后,返回详情的时候保留搜索的条件,第几页进入的返回还在第几页 1.在详情页设置定义一个字段 mounted() {sessionStorage.setItem("msgInfo", true);},2.在获取列表数据的时候在mounted里面判断定义的字段 if (sessionStor…

【EI会议征稿】第二届纯数学、应用数学与计算数学国际学术会议(PACM 2024)

第二届纯数学、应用数学与计算数学国际学术会议(PACM 2024) 2024 2nd International Cnference on Pure, Applied and Computational Mathematics (PACM 2024) 第二届纯数学、应用数学计算数学国际学术会议 (PACM2024) 将于2024年1月19-21日在中国厦门隆…

报错:AttributeError: ‘DataFrame‘ object has no attribute ‘reshape‘

这个错误通常发生在你试图在 Pandas DataFrame 上直接使用 reshape 方法时。reshape 方法通常与 NumPy 数组相关联,而不是 Pandas DataFrame。 如果你正在使用 Pandas DataFrame 并希望重新塑造它,你应该使用 Pandas 的重塑函数,如 pivot、m…

linux常用命令大全50个Linux常用命令

Linux有许多常用的命令,这些命令可以用来管理文件、运行程序、查看系统状态等。以下是一些常用的Linux命令: pwd:显示当前所在的工作目录的全路径名称。cd:用于更改当前工作目录,例如,若要进入Documents目…

UE5 树叶飘落 学习笔记

一个Plane是由两个三角形构成的,所以World Position Offset,只会从中间这条线折叠 所有材质 这里前几篇博客有说这种逻辑,就是做一个对称的渐变数值 这里用粒子的A值来做树叶折叠的程度,当然你也可以用Dynamic Param 这样就可以让…

Android 11.0 长按按键切换SIM卡默认移动数据

Android 11.0 长按按键切换SIM卡默认移动数据 近来收到客户需求想要通过长按按键实现切换SIM卡默认移动数据的功能,该功能主要通过长按按键发送广播来实现,具体修改参照如下: 首先创建广播,具体修改参照如下: /vend…

麒麟KYLINOS上删除多余有线连接

原文链接:麒麟KYLINOS上删除多余网络有线连接 hello,大家好啊,今天我要给大家介绍的是在麒麟KYLINOS操作系统中,如何删除通过Parallels Desktop虚拟机安装时产生的多余有线连接。在使用Parallels Desktop虚拟机安装麒麟桌面操作系…

C/C++ 题目:给定字符串s1和s2,判断s1是否是s2的子序列

判断子序列一个字符串是否是另一个字符串的子序列 解释:字符串的一个子序列是原始字符串删除一些(也可以不删除)字符,不改变剩余字符相对位置形成的新字符串。 如,"ace"是"abcde"的一个子序…

服务器数据恢复—raid5少盘状态下新建raid5如何恢复原raid5数据?

服务器数据恢复环境: 一台服务器上搭建了一组由5块硬盘组建的raid5阵列,服务器上层存放单位重要数据,无备份文件。 服务器故障&分析: 服务器上raid5有一块硬盘掉线,外聘运维人员在没有了解服务器具体情况下&#x…

如何在linux中使用rpm管理软件

本章主要介绍使用rpm对软件包进行管理。 使用rpm查询软件的信息 使用rpm安装及卸载软件 使用rpm对软件进行更新 使用rpm对软件进行验证 rpm 全称是redhat package manager,后来改成rpm package manager,这是根据源 码包编译出来的包。先从光盘中拷贝一…

[算法每日一练]-双指针 (保姆级教程篇 1) #A-B数对 #求和 #元音字母 #最短连续子数组 #无重复字符的最长子串 #最小子串覆盖 #方块桶

目录 A-B数对 解法一:双指针 解法二:STL二分查找 解法三:map 求和 元音字母 最短连续子数组 无重复字符的最长子串 最小子串覆盖 方块桶 双指针特点:双指针绝不回头 A-B数对 解法一:双指针 先把数列排列成…

《C++新经典设计模式》之第8章 外观模式

《C新经典设计模式》之第8章 外观模式 外观模式.cpp 外观模式.cpp #include <iostream> #include <memory> using namespace std;// 中间层角色&#xff0c;隔离接口&#xff0c;两部分模块通过中间层打交道 // 提供简单接口&#xff0c;不与底层直接打交道 // 提…

Grounding DINO、TAG2TEXT、RAM、RAM++论文解读

提示&#xff1a;Grounding DINO、TAG2TEXT、RAM、RAM论文解读 文章目录 前言一、Grounding DINO: Marrying DINO with Grounded Pre-Training for Open-Set Object Detection1、摘要2、背景3、部分文献翻译4、贡献5、模型结构解读a.模型整体结构b.特征增强结构c.解码结构 6、实…

使用Sourcetrail解析C项目

阅读源码的工具很多&#xff0c;今天给大家推荐一款别具一格的源码阅读神器。 它就是 Sourcetrail&#xff0c;一个免费开源、跨平台的可视化源码探索项目 使用

释放深度学习的力量:使用 CUDA 和 Turing GPU 构建 AI

深度学习是一种人工智能的分支,它使用神经网络模拟人类大脑的学习过程,从大量的数据中学习特征和规律。深度学习已经彻底改变了无数领域,从图像和语音识别到自然语言处理和自动驾驶汽车。但是,要充分利用深度学习的强大功能,需要强大的工具,而 NVIDIA 的 Turing GPU 就是…

Faster R-CNN pytorch源码血细胞检测实战(二)数据增强

Faster R-CNN pytorch源码血细胞检测实战&#xff08;二&#xff09;数据增强 文章目录 Faster R-CNN pytorch源码血细胞检测实战&#xff08;二&#xff09;数据增强1. 资源&参考2. 数据增强2.1 代码运行2.2 文件存放 3 数据集划分4. 训练&测试5. 总结 1. 资源&参…

静态SOCKS5的未来发展趋势和新兴应用场景

随着网络技术的不断发展和进步&#xff0c;静态SOCKS5代理也在不断地完善和发展。未来&#xff0c;静态SOCKS5代理将会呈现以下发展趋势和新兴应用场景。 一、发展趋势 安全性更高&#xff1a;随着网络安全问题的日益突出&#xff0c;用户对代理服务器的安全性要求也越来越高…

AcWing 3425:小白鼠排队 ← 北京大学考研机试题

【题目来源】https://www.acwing.com/problem/content/3428/【题目描述】 N 只小白鼠&#xff0c;每只鼠头上戴着一顶有颜色的帽子。 现在称出每只白鼠的重量&#xff0c;要求按照白鼠重量从大到小的顺序输出它们头上帽子的颜色。 帽子的颜色用 red&#xff0c;blue 等字符串来…

c#下载微信跟支付宝交易账单

下载微信交易账单 //账单日期只能下载前一天的string datetime DateTime.Now.AddDays(-1).ToString("yyyy-MM-dd");string body "";string URL "/v3/bill/fundflowbill" "?bill_date" datetime;//生成签名认证var auth BuildAu…

nodejs 异步函数加 await 和不加 await 的区别

在 nodejs 中&#xff0c;异步函数加上 await 和不加 await 的区别在于函数的返回值。 当一个异步函数加上 await 时&#xff0c;它会暂停当前函数的执行&#xff0c;直到异步操作完成并返回结果。这意味着可以直接使用异步操作的结果&#xff0c;而不需要使用 .then() 方法或…