【深度学习笔记】09 权重衰减

09 权重衰减

    • 范数和权重衰减
    • 利用高维线性回归实现权重衰减
    • 权重衰减的简洁实现

范数和权重衰减

在训练参数化机器学习模型时,权重衰减(decay weight)是最广泛应用的正则化技术之一,它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,
因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0
在某种意义上是最简单的。

一种简单的方法是通过线性函数
f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx
中的权重向量的某个范数来度量其复杂性,
例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
要保证权重向量比较小,
最常用方法是将其范数作为惩罚项加到最小化损失的问题中。
将原来的训练目标最小化训练标签上的预测损失,
调整为最小化预测损失和惩罚项之和。

损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

x ( i ) \mathbf{x}^{(i)} x(i)是样本 i i i的特征,
y ( i ) y^{(i)} y(i)是样本 i i i的标签,
( w , b ) (\mathbf{w}, b) (w,b)是权重和偏置参数。

为了惩罚权重向量的大小,
必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
我们通过正则化常数 λ \lambda λ来描述这种权衡,
这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ = 0 \lambda = 0 λ=0,我们恢复了原来的损失函数。
对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。
这里我们仍然除以 2 2 2:当我们取一个二次函数的导数时,
2 2 2 1 / 2 1/2 1/2会抵消。

通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和。
这使得惩罚的导数很容易计算:导数的和等于和的导数。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w
然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。
这就是为什么这种方法有时被称为权重衰减
我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。
与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。
较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w
而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。

是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,
在神经网络的不同层中也会有所不同。
通常,网络输出层的偏置项不会被正则化。

利用高维线性回归实现权重衰减

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

首先生成数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

选择标签是关于输入的线性函数。
标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200
并使用一个只包含20个样本的小训练集。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

初始化模型参数

定义一个函数来随机初始化模型参数

def init_params():w = torch.normal(0, 1, size = (num_inputs, 1), requires_grad = True)b = torch.zeros(1, requires_grad = True)return [w, b]

定义 L 2 L_2 L2范数惩罚

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

定义训练代码实现

下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。

def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())

忽略正则化直接训练

用lamdb=0禁用权重衰减后运行代码。此时训练误差有所减少,但测试误差没有减少,这意味着出现了严重的过拟合。

train(lambd = 0)
w的L2范数是: 14.971677780151367

在这里插入图片描述

使用权重衰减

使用权重衰减来运行代码。此时训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。

train(lambd = 3)
w的L2范数是: 0.34405317902565

在这里插入图片描述

权重衰减的简洁实现

在实例化优化器时直接通过weight_decay指定weight decay超参数。默认情况下,PyTorch同时衰减权重和便宜。这里只为权重设置了weight_decay,所以偏置参数 b b b不会衰减。

def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)
w的L2范数: 13.416662216186523

在这里插入图片描述

train_concise(3)
w的L2范数: 0.39273694157600403

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/198774.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python中获取函数签名(参数)

说明 有些时候我们不清楚python的函数的具体签名的时候,调用可能会报错,这里就是介绍一种简单的方法来获取函数的签名参数。 语法 获取函数的参数签名可以使用 组件.函数.__signature__示例: 我想要获取streamlit.text_input函数的参数签…

golang开发之个微机器人开发

请求URL: http://域名地址/sendFile 请求方式: POST 请求头Headers: Content-Type:application/jsonAuthorization:login接口返回 参数: 参数名必选类型说明wId是string登录实例标识wcId是string接收…

使用Redis实现购物车后端处理

本文中心思想:实现购物车的后端处理逻辑。 本文将教会你掌握:1.存储商品信息,2.存储购物车信息,3.获取购物车信息。 存储商品信息 商品包含多个属性,例如:名字&#x…

it资产管理系统

it资产管理系统这个词组初听有些陌生,再听却别有一种科技感。 先来看下it资产管理系统的定义: 它是一种针对企业IT资产进行全面管理和监控的工具,它可以帮助企业实现对IT资源的有效利用和合理配置,提高企业的运营效率和市场竞争力…

查看php进程占用内存

要查看每个PHP中Swoole进程占用的内存,您可以使用Linux的一些工具和命令来实现。 使用ps命令列出正在运行的进程,可以通过进程名或进程ID(PID)过滤结果: ps aux | grep php这将显示与PHP相关的进程列表。 通过top命令…

Python 3 读写 json 文件

使用函数 json.dumps()、loads() 、json.dump()、json.load() 。 1 json.dumps() 将python对象编码成Json字符串 import jsondata {name:Alice,age:25,gender:F,city:北京 }# 1 json.dumps() 将python对象编码成Json字符串 print(type(data)) json_str json.dumps(data…

mysql安装环境

安装mysql https://www.mysql.com/downloads/ 如果操作系统版本比较低,还需要安装NET Framework4.5.2 搭建环境变量 MySQL可视化界面: 破解navicat:

spark不同结构Dataset合并

1.先将hdfs(或本地)存储的csv文件加载为Dataset 先在本地C盘准备两个csv文件 test.csv client_id,behives,del,normal_status,cust_type,no_trd_days 7056,zl,1,hy,个人,2 7057,cf,1,hy,个人,12 7058,hs,2,hy,个人,1200 212121,0,sj,hy,个人,1100 212122,1,yx,hy,个人,100 212…

如何入驻抖音本地生活服务商,门槛太高怎么办?

随着抖音本地生活服务市场的逐渐成熟,越来越多平台开始涉及本地生活服务领域,而本地生活服务商成了一个香窝窝,为了保护用户权益和平台生态,对入驻入驻抖音本地生活服务商的条件及审核也越来越严格,这让很多想成为抖音…

ES常用操作语句

ES常用操作语句 注:本文中的操作语句基于ES5.5和7.7的版本,版本不同操作语句上可能有细微差别,如5.5版本有索引类型,7.7版本已废弃,查询不应该带索引类型 新增 # 添加字段,并设置字段类型 PUT /索引/_map…

LeetCode56. 合并区间

&#x1f517;:【贪心算法&#xff0c;合并区间有细节&#xff01;LeetCode&#xff1a;56.合并区间-哔哩哔哩】 class Solution { public:vector<vector<int>> merge(vector<vector<int>>& intervals) {if(intervals.size()0){return intervals;…

【页面】表格展示

展示 Dom <template><div class"srch-result-container"><!--左侧--><div class"left"><div v-for"(item,index) in muneList" :key"index" :class"(muneIndexitem.mm)?active:"click"pa…

GORM gorm.DB 对象剖析

文章目录 1.GORM 简介2.gorm.DB 简介2.1 定义2.2 初始化2.3 查询方法2.4 事务支持2.5 模型关联2.6 钩子&#xff08;Hooks&#xff09;2.7 自定义数据类型 3.为什么不同请求可以共用一个 gorm.DB 对象&#xff1f;4.链式调用与方法5.小结参考文献 1.GORM 简介 GORM 是一个流行…

基于c 实现 FIFO

功能&#xff1a; 1、读和写长度不限制 2、数据操作 和 指针操作分开&#xff08;如先操作数据&#xff0c;再操作指针&#xff09; 适用场景&#xff1a; 单向通信模式&#xff0c;一方写、一方读&#xff0c;可用于任务间单向通信&#xff08;无需锁&#xff09; 如&…

7-HDFS的文件管理

单选题 题目1&#xff1a;下列哪个属性是hdfs-site.xml中的配置&#xff1f; 选项: A fs.defaultFS B dfs.replication C mapreduce.framework.name D yarn.resourcemanager.address 答案&#xff1a;B ------------------------------ 题目2&#xff1a;HDFS默认备份数量&…

fastboot常用命令

fastboot常用命令 显示fastboot设备&#xff1a;fastboot devices 获取手机相关信息&#xff1a;fastboot getvar all 重启手机&#xff1a;fastboot reboot 重启到bootloader&#xff1a;fastboot reboot-bootloader 擦除分区&#xff1a;fastboot erase (分区名) 例&…

代码随想录算法训练营第四十三天 _ 动态规划_1049.最后一块石头的重量II、494.目标和、474.一和零。

学习目标&#xff1a; 动态规划五部曲&#xff1a; ① 确定dp[i]的含义 ② 求递推公式 ③ dp数组如何初始化 ④ 确定遍历顺序 ⑤ 打印递归数组 ---- 调试 引用自代码随想录&#xff01; 60天训练营打卡计划&#xff01; 学习内容&#xff1a; 1049.最后一块石头的重量II 该题…

360公司-2019校招笔试-Windows开发工程师客观题合集解析

360公司-2019校招笔试-Windows开发工程师客观题合集 API无法实现进程间数据的相互传递是PostMessage2.以下代码执行后,it的数据为(异常) std::list<int> temp; std::list<int>::iterator it = temp.begin(); it = --it; 3.API在失败时的返回值跟其他不一样是 …

微信小程序自定义tabBar简易实现

文章目录 1.app.json设置custom为true开启自定义2.根目录创建自定义的tab文件3.app.js全局封装一个设置tabbar选中的方法4.在onshow中使用选中方法最终效果预览 1.app.json设置custom为true开启自定义 2.根目录创建自定义的tab文件 index.wxml <view class"tab-bar&quo…

自动升降压稳压电源模块输入3v~24V输出3.3/4.2/5/9/12V芯片

自动升降压稳压电源模块是一种高效、高稳定性的电源解决方案&#xff0c;广泛应用于各种需要稳定电压输出的场合。该模块采用宽电压低功耗方案&#xff0c;能够自动升降压&#xff0c;适应不同的输入电压范围&#xff0c;同时具有关断功能&#xff0c;确保设备的安全运行。 该电…