Pytorch:torch.nn.utils.clip_grad_norm_梯度截断_解读

torch.nn.utils.clip_grad_norm_函数主要作用:

  神经网络深度逐渐增加,网络参数量增多的时候,容易引起梯度消失和梯度爆炸。对于梯度爆炸问题,解决方法之一便是进行梯度剪裁torch.nn.utils.clip_grad_norm_(),即设置一个梯度大小的上限

注:旧版为torch.nn.utils.clip_grad_norm()

函数参数:

官网链接:https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html

torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False, foreach=None)

“Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.”

“对一组可迭代(网络)参数的梯度范数进行裁剪。效果如同将所有参数连接成单个向量来计算范数。梯度原位修改。”

Parameters

  • parameters (Iterable[Tensor] or Tensor) – 实施梯度裁剪的可迭代网络参数
    an iterable of Tensors or a single Tensor that will have gradients normalized(一个由张量或单个张量组成的可迭代对象(模型参数),将梯度归一化)

  • max_norm (float) – 该组网络参数梯度的范数上限
    max norm of the gradients(梯度的最大值)

  • norm_type (float) –范数类型
    type of the used p-norm. Can be ‘inf’ for infinity norm.(所使用的范数类型。默认为L2范数,可以是无穷大范数(‘inf’))

  • error_if_nonfinite (bool)
    if True, an error is thrown if the total norm of the gradients from parameters is nan, inf, or -inf. Default: False (will switch to True in the future)

  • foreach (bool)
    use the faster foreach-based implementation. If None, use the foreach implementation for CUDA and CPU native tensors and silently fall back to the slow implementation for other device types. Default: None

源码解读:

参考:https://blog.csdn.net/Mikeyboi/article/details/119522689
(建议大家看看源码,更好理解函数意义,有注释)

def clip_grad_norm_(parameters, max_norm, norm_type=2):# 处理传入的三个参数。# 首先将parameters中的非空网络参数存入一个列表,# 然后将max_norm和norm_type类型强制为浮点数。if isinstance(parameters, torch.Tensor):parameters = [parameters]parameters = list(filter(lambda p: p.grad is not None, parameters))max_norm = float(max_norm)norm_type = float(norm_type)#对无穷范数进行了单独计算,即取所有网络参数梯度范数中的最大值,定义为total_normif norm_type == inf:total_norm = max(p.grad.data.abs().max() for p in parameters)# 对于其他范数,计算所有网络参数梯度范数之和,再归一化,# 即等价于把所有网络参数放入一个向量,再对向量计算范数。将结果定义为total_normelse:total_norm = 0for p in parameters:param_norm = p.grad.data.norm(norm_type)total_norm += param_norm.item() ** norm_type # norm_type=2 求平方(二范数)total_norm = total_norm ** (1. / norm_type) # norm_type=2 等价于 开根号# 最后定义了一个“裁剪系数”变量clip_coef,为传入参数max_norm和total_norm的比值(+1e-6防止分母为0的情况)。# 如果max_norm > total_norm,即没有溢出预设上限,则不对梯度进行修改。# 反之则以clip_coef为系数对全部梯度进行惩罚,使最后的全部梯度范数归一化至max_norm的值。# 注意该方法返回了一个 total_norm,实际应用时可以通过该方法得到网络参数梯度的范数,以便确定合理的max_norm值。clip_coef = max_norm / (total_norm + 1e-6)if clip_coef < 1:for p in parameters:p.grad.data.mul_(clip_coef)return total_norm

使用方法及分析:

应用逻辑为:

  1. 先计算梯度;
  2. 裁剪梯度(在函数内部会判断是否需要裁剪,具体看源码解读);
  3. 最后更新网络参数。

因此 torch.nn.utils.clip_grad_norm_() 的使用应该在loss.backward() 之后,optimizer.step() 之前,

在U-Net中如下:

optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
grad_scaler.step(optimizer)
grad_scaler.update()

参考:https://blog.csdn.net/zhaohongfei_358/article/details/122820992

注意:

  • 从上面文章可以看到,clip_grad_norm 最后就是对所有的梯度乘以一个 clip_coefp.grad.data.mul_(clip_coef)),而且乘的前提是clip_coef一定是小于1的,所以,clip_grad_norm 只解决梯度爆炸问题,不解决梯度消失问题
  • clip_coef的定义**clip_coef = max_norm / (total_norm + 1e-6)** 可以知道:max_norm越大,对于梯度爆炸的解决越柔和,max_norm越小,对梯度爆炸的解决越狠

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/238024.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CMD中文名称修改

会引发的问题 1. jupyter notebook运行出现Bad file descriptor (bundled\zeromq\src\epoll.cpp:100) 2. 用Anaconda或pycharm运行jupyter notebook时候&#xff0c;创建ipynb文件没一会儿就开始报错,且没法运行代码 3. 使用opencv时报错&#xff0c;opencv不支持中文路径&…

Python 爬虫之下载视频(二)

爬取某Y的视频链接和标题 文章目录 爬取某Y的视频链接和标题前言一、基本思路二、程序解析阶段三、程序处理阶段总结 前言 这篇内容就简单给大家写个如何从网页上爬取某B主 主页 页面上所有的视频链接和视频标题。 这篇是基础好好看&#xff0c;下篇会根据这篇的结果做一个批…

如何开发专属花店展示平台小程序?

如今&#xff0c;微信小程序已经成为了花店行业拓展客户资源的重要工具。通过开发一个专属花店小程序&#xff0c;你可以为自己的花店带来更多的曝光和客户资源。那么&#xff0c;如何开发一个专属花店小程序呢&#xff1f;接下来&#xff0c;我们将一步步为你详细讲解。 首先&…

产能过剩的今天,企业的方向在哪里?

随着经济的发展和技术的进步&#xff0c;许多行业都面临着产能过剩的问题。在产能过剩的背景下&#xff0c;企业如何找到新的发展方向&#xff0c;成为了一个亟待解决的问题。本文将探讨产能过剩时代下&#xff0c;企业应该如何寻找新的发展之路。 接下来我们就来看看当今的产…

共建还是对抗?BTC 铭文风波中开发者、矿工与社区的平衡艺术

近期&#xff0c;比特币铭文正加速进入一场争议与危机的漩涡。12 月 6 日&#xff0c;比特币核心开发人员 Luke Dashjr 在 X 表示&#xff0c;铭文&#xff08;Inscriptions&#xff09;正在利用比特币核心客户端 Bitcoin Core 的一个漏洞向区块链发送垃圾信息&#xff0c;Bitc…

MUX VLAN配置

MUX VLAN简介 产生背景 MUX VLAN&#xff08;Multiplex VLAN&#xff09;提供了一种通过VLAN进行网络资源控制的机制。 例如&#xff0c;在企业网络中&#xff0c;企业员工和企业客户可以访问企业的服务器。对于企业来说&#xff0c;希望企业内部员工之间可以互相交流&#…

使用工具类Exectors创建线程池

大型并发项目 不能使用Executors 通过ThreadPoolExector的方式 核心线程配置方式: 计算密集型的任务 核心线程数量 CPU的核数 1 IO密集型的任务 核心线程数量 CPU的核数*2 演示: Callable import java.util.concurrent.Callable;public class MyCallable implements Callab…

R语言贝叶斯网络模型、INLA下的贝叶斯回归、R语言现代贝叶斯统计学方法、R语言混合效应(多水平/层次/嵌套)模型

目录 ㈠ 基于R语言的贝叶斯网络模型的实践技术应用 ㈡ R语言贝叶斯方法在生态环境领域中的高阶技术应用 ㈢ 基于R语言贝叶斯进阶:INLA下的贝叶斯回归、生存分析、随机游走、广义可加模型、极端数据的贝叶斯分析 ㈣ 基于R语言的现代贝叶斯统计学方法&#xff08;贝叶斯参数估…

提升Elasticsearch性能的一些经验

分片查询缓存(Shard Request Cache) ES 层面的缓存实现,封装在 IndicesRequestCache 类中。缓存的 Key 是整个客户端请求,缓存内容为单个分片的查询结果。主要作用是对聚合的缓存,查询结果中被缓存的内容主要包括:Aggregations(聚合结果)、Hits.total、以及 Suggestion…

基于ssm房屋租赁平台的设计与开发论文

摘 要 目前对于在外的人员来说租赁房屋是最基本的问题。对于房屋的租赁可以选择直接找房东、找专业的房屋租赁公司和自己在网上找房屋。自己找房东的问题在于需要时间&#xff0c;而且对于需要提前租赁房屋的需要多次跑到小区&#xff0c;找中介租赁房屋的问题在于费用问题&am…

数据结构---算法的时间复杂度

文章目录 前言计算机重要存储数据结构与算法数据结构概念算法 数据库概念 算法的复杂度时间复杂度概念为什么有时间复杂度大O渐进表示法时间复杂度实例实例1&#xff1a;时间复杂度&#xff1a;O&#xff08;N&#xff09;实例2&#xff1a;这里输入参数是不确定的所以 时间复杂…

养老院自助饮水机(字符设备驱动)

目录 1、项目背景 2、驱动程序 2.1 三层架构 2.2 驱动三要素 2.3 字符设备驱动 2.3.1 驱动模块 2.3.2 应用层 3、设计实现 3.1 项目设计 3.2 项目实现 3.2.1 驱动模块代码 3.2.2 用户层代码 4、功能特性 5、技术分析 6. 总结与未来展望 1、项目背景 养老院的老人…

【算法题】3. 无重复字符的最长子串

​题目 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的 最长子串 的长度。 例子1&#xff1a; 输入: s "abcabcbb" 输出: 3 解释: 因为无重复字符的最长子串是 "abc"&#xff0c;所以其长度为 3。 例子2&#xff1a; 输入: s "b…

研究生课程 |《数值分析》复习

搭配往年真题册食用最佳。

Java_集合进阶Map集合

一、Map集合 1.1 Map概述体系 各位同学&#xff0c;前面我们已经把单列集合学习完了&#xff0c;接下来我们要学习的是双列集合。首先我们还是先认识一下什么是双列集合。 所谓双列集合&#xff0c;就是说集合中的元素是一对一对的。Map集合中的每一个元素是以keyvalue的形式…

mask rcnn训练基于labelme生成的数据集

1.下载mask rcnn源码 此处使用的mask rcnn源码来自于B站博主霹雳吧啦Wz 2.安装labelme sudo apt install python3-pyqt5 pip install labelme如果运行出现QT的错误&#xff0c;可能是与我一样遇到自己装了C版本的QT 解决&#xff1a;运行命令 unset LD_LIBRARY_PATH2.使用lab…

redis主从复制(在虚拟机centos的docker下)

1.安装docker Docker安装(CentOS)简单使用-CSDN博客 2.编辑3个redis配置 cd /etc mkdir redis-ms cd redis-ms/ vim redis6379.conf vim redis6380.conf vim redis6381.conf# master #端口号 port 6379#设置客户端连接后进行任何其他指定前需要使用的密码 requirepass 12345…

C++基础-拷贝构造函数详解

目录 一、概述 二、拷贝构造函数调用时机 三、构造函数调用规则 四、浅拷贝与深拷贝

leetcode454. 四数相加 II

题目描述 给你四个整数数组 nums1、nums2、nums3 和 nums4 &#xff0c;数组长度都是 n &#xff0c;请你计算有多少个元组 (i, j, k, l) 能满足&#xff1a; 0 < i, j, k, l < n nums1[i] nums2[j] nums3[k] nums4[l] 0示例 1&#xff1a; 输入&#xff1a;nums1…

springMVC-处理json和HttpMessageConverter<T>

细节说明&#xff1a;目标方法正常返回JSON需要的数据&#xff0c;可以是一个对象&#xff0c;也可以是一个集合&#xff0c;这里我们返回的是一个Dog对象>转成Json数据格式 示例案例&#xff1a; 在springmve中&#xff0c;如果我们返回一个集合List等&#xff0c;或者返回…