【深度学习笔记】09 权重衰减

09 权重衰减

    • 范数和权重衰减
    • 利用高维线性回归实现权重衰减
    • 权重衰减的简洁实现

范数和权重衰减

在训练参数化机器学习模型时,权重衰减(decay weight)是最广泛应用的正则化技术之一,它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,
因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0
在某种意义上是最简单的。

一种简单的方法是通过线性函数
f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx
中的权重向量的某个范数来度量其复杂性,
例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
要保证权重向量比较小,
最常用方法是将其范数作为惩罚项加到最小化损失的问题中。
将原来的训练目标最小化训练标签上的预测损失,
调整为最小化预测损失和惩罚项之和。

损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

x ( i ) \mathbf{x}^{(i)} x(i)是样本 i i i的特征,
y ( i ) y^{(i)} y(i)是样本 i i i的标签,
( w , b ) (\mathbf{w}, b) (w,b)是权重和偏置参数。

为了惩罚权重向量的大小,
必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
我们通过正则化常数 λ \lambda λ来描述这种权衡,
这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ = 0 \lambda = 0 λ=0,我们恢复了原来的损失函数。
对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。
这里我们仍然除以 2 2 2:当我们取一个二次函数的导数时,
2 2 2 1 / 2 1/2 1/2会抵消。

通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和。
这使得惩罚的导数很容易计算:导数的和等于和的导数。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w
然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。
这就是为什么这种方法有时被称为权重衰减
我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。
与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。
较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w
而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。

是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,
在神经网络的不同层中也会有所不同。
通常,网络输出层的偏置项不会被正则化。

利用高维线性回归实现权重衰减

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

首先生成数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

选择标签是关于输入的线性函数。
标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200
并使用一个只包含20个样本的小训练集。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

初始化模型参数

定义一个函数来随机初始化模型参数

def init_params():w = torch.normal(0, 1, size = (num_inputs, 1), requires_grad = True)b = torch.zeros(1, requires_grad = True)return [w, b]

定义 L 2 L_2 L2范数惩罚

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

定义训练代码实现

下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。

def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())

忽略正则化直接训练

用lamdb=0禁用权重衰减后运行代码。此时训练误差有所减少,但测试误差没有减少,这意味着出现了严重的过拟合。

train(lambd = 0)
w的L2范数是: 14.971677780151367

在这里插入图片描述

使用权重衰减

使用权重衰减来运行代码。此时训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。

train(lambd = 3)
w的L2范数是: 0.34405317902565

在这里插入图片描述

权重衰减的简洁实现

在实例化优化器时直接通过weight_decay指定weight decay超参数。默认情况下,PyTorch同时衰减权重和便宜。这里只为权重设置了weight_decay,所以偏置参数 b b b不会衰减。

def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)
w的L2范数: 13.416662216186523

在这里插入图片描述

train_concise(3)
w的L2范数: 0.39273694157600403

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/198774.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

golang开发之个微机器人开发

请求URL: http://域名地址/sendFile 请求方式: POST 请求头Headers: Content-Type:application/jsonAuthorization:login接口返回 参数: 参数名必选类型说明wId是string登录实例标识wcId是string接收…

it资产管理系统

it资产管理系统这个词组初听有些陌生,再听却别有一种科技感。 先来看下it资产管理系统的定义: 它是一种针对企业IT资产进行全面管理和监控的工具,它可以帮助企业实现对IT资源的有效利用和合理配置,提高企业的运营效率和市场竞争力…

mysql安装环境

安装mysql https://www.mysql.com/downloads/ 如果操作系统版本比较低,还需要安装NET Framework4.5.2 搭建环境变量 MySQL可视化界面: 破解navicat:

如何入驻抖音本地生活服务商,门槛太高怎么办?

随着抖音本地生活服务市场的逐渐成熟,越来越多平台开始涉及本地生活服务领域,而本地生活服务商成了一个香窝窝,为了保护用户权益和平台生态,对入驻入驻抖音本地生活服务商的条件及审核也越来越严格,这让很多想成为抖音…

【页面】表格展示

展示 Dom <template><div class"srch-result-container"><!--左侧--><div class"left"><div v-for"(item,index) in muneList" :key"index" :class"(muneIndexitem.mm)?active:"click"pa…

360公司-2019校招笔试-Windows开发工程师客观题合集解析

360公司-2019校招笔试-Windows开发工程师客观题合集 API无法实现进程间数据的相互传递是PostMessage2.以下代码执行后,it的数据为(异常) std::list<int> temp; std::list<int>::iterator it = temp.begin(); it = --it; 3.API在失败时的返回值跟其他不一样是 …

微信小程序自定义tabBar简易实现

文章目录 1.app.json设置custom为true开启自定义2.根目录创建自定义的tab文件3.app.js全局封装一个设置tabbar选中的方法4.在onshow中使用选中方法最终效果预览 1.app.json设置custom为true开启自定义 2.根目录创建自定义的tab文件 index.wxml <view class"tab-bar&quo…

自动升降压稳压电源模块输入3v~24V输出3.3/4.2/5/9/12V芯片

自动升降压稳压电源模块是一种高效、高稳定性的电源解决方案&#xff0c;广泛应用于各种需要稳定电压输出的场合。该模块采用宽电压低功耗方案&#xff0c;能够自动升降压&#xff0c;适应不同的输入电压范围&#xff0c;同时具有关断功能&#xff0c;确保设备的安全运行。 该电…

想要更高效的文件传输?这些aspera替代方案可以助你一臂之力

随着数字化时代的不断推进&#xff0c;数据传输已成为各行各业、各类企业所必需的核心能力。而在文件传输方面&#xff0c;传统的方式往往面临着诸多问题&#xff0c;例如文件大小限制、传输速度过慢、不稳定、不安全等问题。为此&#xff0c;许多企业开始寻找更可靠、更高效的…

网工学习10-IP地址

一、IP地址概念 IP地址是一个32位的二进制数&#xff0c;它由网络ID和主机ID两部份组成&#xff0c;用来在网络中唯一的标识的一台计算机。网络ID用来标识计算机所处的网段&#xff1b;主机ID用来标识计算机在网段中的位置。IP地址通常用4组3位十进制数表示&#xff0c;中间用…

scipy笔记:scipy.interpolate.interp1d

1 主要使用方法 class scipy.interpolate.interp1d(x, y, kindlinear, axis-1, copyTrue, bounds_errorNone, fill_valuenan, assume_sortedFalse) 2 主要函数 x一维实数值数组&#xff0c;代表插值的自变量y N维实数值数组&#xff0c;其中沿着插值轴的 y 长度必须等于 x 的…

gitlab注册无中国区电话验证问题

众所周知gitlab对中国区不友好&#xff0c;无法直接注册&#xff0c;页面无法选择86的手机号进行验证码发送。 Google上众多的方案是修改dom&#xff0c;而且时间大约是21年以前。 修改dom&#xff0c;对于现在的VUE、React框架来说是没有用的&#xff0c;所以不用尝试。 直接看…

postman参数为D:\\audio\\test.mp3请求报错

报错信息 报错 java.lang.IllegalArgumentException: Invalid character found in the request target [/v1/audio/transcriptions?audioPathD:\\audio\\test.mp3 ]. The valid characters are defined in RFC 7230 and RFC 3986 解决方式 yml文件上放行指定字符 relaxed-pa…

安装获取mongodb

目录 本地安装 获取云上资源 获取Atlas免费数据库 本地连接数据库 在Atlas中连接数据库 本文适合初学者或mongodb感兴趣的同学来准备学习测试环境&#xff0c;或本地临时开发环境。mongodb是一个对用户非常友好的数据库。这种友好&#xff0c;不仅仅体现在灵活的数据结构和…

评论功能实现方案

构建高效且安全的评论功能&#xff1a;实现方案探讨。 1、分析 我们以b站的评论为例&#xff0c;用下图来解释我们评论的分级。 我们可以抽出存储评论的数据表属性 评论id父级id评论作者id被回复用户ID评论帖子ID评论内容创建时间 可以设计如下的数据表 其中pid表示父id。 …

考研失利后,我是如何零基础转行测试开发 ,成功拿下独角兽公司offer?

想当年&#xff0c;从一个什么都不懂的非科班测试小白&#xff0c;考研失利后&#xff0c;转行到K12教育知名互联网公司做测试开发工程师&#xff0c;我用了大概半年的时间。 这个过程中我自己也摸索出了一条学习路线&#xff0c;在这里想给大家分享一下我的学习路线&#xff…

Hadoop学习笔记(HDP)-Part.16 安装HBase

目录 Part.01 关于HDP Part.02 核心组件原理 Part.03 资源规划 Part.04 基础环境配置 Part.05 Yum源配置 Part.06 安装OracleJDK Part.07 安装MySQL Part.08 部署Ambari集群 Part.09 安装OpenLDAP Part.10 创建集群 Part.11 安装Kerberos Part.12 安装HDFS Part.13 安装Ranger …

vue3 vue-router过渡动效 滚动行为 (四)

文章目录 一、过渡动效1.1安装animate.css1.2 利用元信息存储过渡名称1.3 在组件中使用 二、滚动行为2.1 始终滚动到顶部2.2 相对于某个元素的偏移量2.3 保持之前的滚动位置 一、过渡动效 1.1安装animate.css npm install animate.css --save1.2 利用元信息存储过渡名称 {pa…

ROS opencv PCL Ceres-solver之间版本对应关系

ROS1 : neotic Opencv : 4.6.0 Ceres-solver : 2.0.0

ABAP 报表工具栏缺少小计按钮

解决方案&#xff1a; 在sap标准程序 SAPLKKBL 中有多个标准的的状态栏 都有小计按钮 复制过来之后却不显示&#xff0c;调试发现&#xff0c; 在 pf_status_alv里面做了excluding &#xff0c;需要把小计排除 调试RT_EXTAB. 说明程序默认给隐藏了 不显示&#xff0c;删除调…