动手学深度学习11 权重衰退

动手学深度学习11 权重衰退

  • 1. 权重衰退
  • 2. 代码实现
  • 3. QA

视频: https://www.bilibili.com/video/BV1UK4y1o7dy/?spm_id_from=autoNext&vd_source=eb04c9a33e87ceba9c9a2e5f09752ef8
电子书: ttps://zh-v2.d2l.ai/chapter_multilayer-perceptrons/weight-decay.html
课件: https://courses.d2l.ai/zh-v2/assets/pdfs/part-0_16.pdf

1. 权重衰退

在这里插入图片描述
在这里插入图片描述
为了便于讨论,我们假设训练的模型中只有w1和w2两个参数
但是我们觉得100这个数还是太大了,怎么办?我们在损失函数中添加一项,这一项是w1的平方+w2的平方
罚的项是L2正则项,因为离原点越近,正则项越小
绿线是损失函数的取值, 黄线是惩罚项的取值, 两者都是圈越大取值越大
两个圆锥的交点可能是最优点!!!!!!!
这里的平衡点可以看看kkt条件来理解,但其实增广拉格朗日方程本身和subject to ||w||^2<theta的模型是等价的,在这里用subject to 的模型来理解可能更简单
其实就是新的损失函数由两项组成,此时求导后,梯度有两项了,一项将w向绿线中心拉,一项将w向原点拉进,最后将在w*点达到一个平衡
在这里插入图片描述
在这里插入图片描述
每次都在权重更新之前对w做了一次放小(ηλ<1), 所以叫做权重衰退。
lambda是控制模型参数的超参数
在这里插入图片描述

2. 代码实现

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2ln_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)
print(train_data[0], train_data[1])def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]def l2_penalty(w):return torch.sum(w.pow(2)) / 2  # w.pow(2) 幂函数  w的2次幂def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.03animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log', xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项# 广播机制使L2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w,b], lr, batch_size)if (epoch+1)%5 == 0:animator.add(epoch+1, (d2l.evaluate_loss(net, train_iter, loss), d2l.evaluate_loss(net, test_iter, loss)))print("w的L2范数是:", torch.norm(w).item())train(lambd=0)
train(lambd=3)
w的L2范数是: 12.036040306091309  # lambda=0
w的L2范数是: 0.03226672485470772 # lambda=3

lambd = 0
在这里插入图片描述
lambd = 3
在这里插入图片描述

3. QA

  1. pytorch是否支持复数神经网络,nn输入输出权重激活函数都是复数,loss则是一个复数到实数的映射。
    不支持。复数是把一个数做到二维,加一个第二维实现效果,不一定要用复数。

  2. 为什么参数不过大模型复杂度就低呢?
    限制整个模型在优化的时候,是在很小的范围取参数。只能在在一些比较平滑的模型曲线上取参数,这样就学不出一个很复杂的模型。

  3. 如果是L1范数的话如何更新权重?
    把上面的l2_penalty()函数内容,换成torch.abs(w)尝试
    w的L1范数是: 0.7089653015136719

    return torch.sum(torch.abs(w))
    ......
    train(lambd=3)
    

    在这里插入图片描述

  4. 实践中权重衰减的值一般设置为多少好?有时感觉权重衰减的效果并不好
    一般取1e-2, 1e-3, 1e-4 (0.01, 0.001, 0.0001)。 权重衰退有一点点,但是不要太指望这个方法。如果模型很复杂,权重衰退没有很好的效果。可以试下 1e-3,效果不好换别的方法。

  5. 损失函数正则项中的2为什么使用上标而不是下标?之前介绍L2范数使用的是下标,是相同的概念?不太理解不同的数学记法。
    上标是平方的意思L2的平方项,L2范数是在下标有2的,但是L2是默认的范数,所以一般都是省略的。

  6. 为什么要把w往小的啦?如果最优解的w就是比较大的数,那权重衰减是不是会有反作用?
    数据是有噪音的,可能学到的不是真正的最优解,lambda过大过小都会和最优解离得比较远,所以选择的lambda值要合适。

  7. L2 norm理解是让w的值变得更平均,没有突出的值为什么这样调整可以使得拟合更好呢?
    不是让w更平均,而是让值更小一点。当没有lambda学到的范数很大的情况下,可以用lambda往回拉。但模型没有overfitting的时候,往回拉这种操作是没用的。

  8. weight_decay的值一般怎么选择?有哪些实践经验?
    试试1e-3这些值,没效果换方法。

  9. 实际应用中,lambda作为超参数是一次次在训练后调整优化了吗?调整到什么时候达到满意的效果?有什么方法论或最佳实践吗?
    不知道什么时候最优。看看训练集和测试集曲线的差距,调下参数。
    10 在解释数据噪音的时候,说如果噪音越大,w就比较大,是经验所得还是可以证明?
    可以证明。噪音越大,学的w就越大,可以尝试一下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/825922.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL 数据宝典】【内存结构】- 004 自适应哈希索引

自适应哈希索引 https://developer.aliyun.com/article/1230086 什么是自适应哈希索引&#xff1f; 自适应哈希索引是MySQL InnoDB存储引擎中的一种索引结构&#xff0c;用于加速查询。它根据查询模式和数据分布动态地调整自身的大小&#xff0c;以提高性能。 上图就是通过…

Redis中的订阅发布和事务(一)

订阅发布 PUBSUB NUMSUB PUBSUB NUMSUB [channel-1 channel-2… channel-n]子命令接受任意多个频道作为输入参数&#xff0c;并返回这些频道的订阅者数量。 这个子命令是通过pubsub_channels字典中找到频道对应的订阅者链表&#xff0c;然后返回订阅者链表的长度来实现的(订阅…

Nuclei 减少漏报的使用小技巧

在最近工作的渗透测试项目中发现Nuclei存在一个问题&#xff0c;就是相同的网站连续扫描多次会出现漏报的情况&#xff0c;此前没有注意过这个情况&#xff0c;所以写篇文章记录一下。 在此之前我的常用命令都是一把梭&#xff0c;有就有没有就继续其他测试 $ nuclei -u htt…

代码随想录算法训练营第四十五天| 70. 爬楼梯 (进阶),322. 零钱兑换 ,279.完全平方数

题目与题解 70. 爬楼梯 &#xff08;进阶&#xff09; 题目链接&#xff1a;70. 爬楼梯 &#xff08;进阶&#xff09; 代码随想录题解&#xff1a;70. 爬楼梯 &#xff08;进阶&#xff09; 解题思路&#xff1a; 这道题要求每次可以爬1-m层的楼梯&#xff0c;最终爬到n&…

微服务架构中的业务数据可视化设计

目录 1.概要设计 1.1明确可视化目标 1.2数据整合与标准化 1.3选择合适的数据可视化工具 1.4设计可视化界面 1.5 实时更新与优化 2.技术实现 2.1数据采集与整合 2.2数据处理与转换 2.3数据存储 2.4 数据可视化 2.5 实时数据更新 2.6 安全性与权限控制 2.7 监控与日…

【ZZULIOJ】1072: 青蛙爬井(Java)

目录 题目描述 输入 输出 样例输入 Copy 样例输出 Copy 提示 code 题目描述 有一口深度为high米的水井&#xff0c;井底有一只青蛙&#xff0c;它每天白天能够沿井壁向上爬up米&#xff0c;夜里则顺井壁向下滑down米&#xff0c;若青蛙从某个早晨开始向外爬&#xff0c…

患者关系管理系统功能详解

脉购健康管理系统&#xff08;软件&#xff09;包含&#xff1a;客户开卡、健康档案、问卷调查、问诊表、自动设置标签、自动随访、健康干预、健康调养、历年指标趋势分析、疾病风险评估、饮食/运动/心理健康建议、同步检查报告数据、随访记录、随访电话录音、健康阶段总结、打…

Java - 阿里巴巴命名规范

文章目录 前言一、编程规约(一) 命名风格(二) 常量定义(三) 代码格式(四) OOP 规约(五) 日期时间(六) 集合处理(七) 并发处理(八) 控制语句(九) 注释规约(十) 前后端规约(十一) 其他 二、异常日志(一) 错误码(二) 异常处理(三) 日志规约 三、单元测试四、安全规约五、MySQL 数据…

2024面试软件测试,常见的面试题(上)

一、综合素质 1、自我介绍 面试官您好&#xff0c;我叫XXX&#xff0c;一直从事车载软件测试&#xff0c;负责最多的是中控方面。 以下是我的一些优势&#xff1a; 车载的测试流程我是熟练掌握的&#xff0c;且能够独立编写测试用例。 平时BUG提交会使用到Jira&#xff0c;类似…

postgis源码编译安装-实操成功

依赖环境安装 sqlite3安装 https://www.sqlite.org/2024/sqlite-autoconf-3450200.tar.gz tar xvf sqlite-autoconf-3450200.tar cd sqlite-autoconf-3450200 mkdir -p /home/postgres/app/postgis/sqlite3 ./configure --prefix=/home/postgres/app/postgis/sqlite3 ma…

电缆检测仪的正确使用方法有哪些步骤?

电缆检测仪的正确使用方法是&#xff1a;首先&#xff0c;确保检测仪电源充足&#xff0c;设备完好无损&#xff1b;其次&#xff0c;根据电缆类型和故障类型选择合适的测试模式和参数&#xff1b;接着&#xff0c;将检测仪与电缆正确连接&#xff0c;确保接触良好&#xff1b;…

深入挖掘C语言 ----动态内存分配

开篇备忘录: "自给自足的光, 永远都不会暗" 目录 1. malloc和free1.1 malloc1.2 free 2. calloc和realloc2.1 calloc2.2 realloc 3. 总结C/C中程序内存区域划分 正文开始 1. malloc和free 1.1 malloc C语言提供了一个动态开辟内存的函数; void* malloc (size_t s…

python处理IP对应城市省份

python处理IP对应城市省份 IP地理地址库geoip2用法 数据包下载 数据包下载地址&#xff08;需要注册&#xff09; https://www.maxmind.com/en/accounts/258630/geoip/downloads 考虑到注册麻烦&#xff0c;可以到下面这个github的链接去直接下载 https://github.com/Hackl0…

AItoolchain相关技术学习

AItoolchain主要模块包括&#xff1a; 模型转换&#xff1a;将深度学习模型转换为特定硬件平台可以识别和执行的格式。嵌入式运行环境&#xff1a;提供异构模型的运行库支持&#xff0c;确保模型在目标设备上的运行效率。性能验证&#xff1a;包括静态和动态性能评估&#xff…

2024-9.python文件操作

文件操作 引言 到目前为止&#xff0c;我们做的一切操作&#xff0c;都是在内存里进行的&#xff0c;这样会有什么问题吗&#xff1f;如果一旦断电或发生意外关机了&#xff0c;那么你辛勤的工作成果将瞬间消失。是不是感觉事还挺大的呢&#xff1f;现在你是否感觉你的编程技…

【Java EE】依赖注入DI详解

文章目录 &#x1f334;什么是依赖注入&#x1f340;依赖注入的三种方法&#x1f338;属性注入(Field Injection)&#x1f338;构造方法注入&#x1f338;Setter注入&#x1f338;三种注入优缺点分析 &#x1f333;Autowired存在的问题&#x1f332;解决Autowired对应多个对象问…

动态库静态库linux

动态库静态库 静态库 静态库必须包含在可执行文件里&#xff0c;整个都要包含 缺点&#xff1a;消耗系统大&#xff0c;每个使用静态库的程序都要复制静态库&#xff08;浪费内存&#xff09; 影响使用场景&#xff1a; 在静态库内存小的时候&#xff0c;可以用来提升速度 制…

redis在Windows下设置静默启动

redis在Windows下设置静默启动 下载windows版redis,解压cmd命令行有窗口启动(这种启动方式&#xff0c;这个界面就不能关闭才会生效 注册成为服务&#xff0c;设置成开机启动或者手动启动(静默启动)清除缓存本地清除&#xff0c;直接打开redis-cli.exe本地远程连接清除缓存 下载…

投影矩阵(Projection Matrix)

在机器学习和数据分析中&#xff0c;投影矩阵是一个非常重要的工具&#xff0c;它主要用于将高维数据降维或者变换到新的坐标系中。这个过程通常被称为线性变换或投影。 过程&#xff1a; 假设我们有一个原始的高维数据集X&#xff0c;其中每一列代表一个特征&#xff0c;每一行…

Scala 03 —— Scala Puzzle 拓展

Scala 03 —— Scala Puzzle 拓展 文章目录 Scala 03 —— Scala Puzzle 拓展一、占位符二、模式匹配的变量和常量模式三、继承 成员声明的位置结果初始化顺序分析BMember 类BConstructor 类 四、缺省初始值与重载五、Scala的集合操作和集合类型保持一致性第一部分代码解释第二…