深度学习——权重衰减(weight_decay)

深度学习——权重衰减(weight_decay)

文章目录

  • 前言
  • 一、权重衰减
    • 1.1. 范数与权重衰减
    • 1.2. 高维线性回归
    • 1.3. 从零开始实现
      • 1.3.1.初始化模型参数
      • 1.3.2. 定义L₂范数惩罚
      • 1.3.3. 定义训练代码实现
      • 1.3.4. 不管正则化直接训练
      • 1.3.5. 使用权重衰减
    • 1.4. 简洁实现
  • 总结


前言

上一章描述了过拟合的问题,本章我们将介绍一些正则化模型的技术。如权重衰减

参考书:
《动手学深度学习》


一、权重衰减

1.1. 范数与权重衰减

在训练参数化机器学习模型时,权重衰减(weight decay)是最广泛使用的正则化的技术之一, 它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,

因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0),在某种意义上是最简单的。

但是我们应该如何精确地测量一个函数和零之间的距离呢?

一种简单的方法是通过线性函数
f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx 中的权重向量的某个范数来度量其复杂性,
例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2

要保证权重向量比较小,最常用方法是将其范数作为惩罚项加到最小化损失的问题中
即将原来的训练目标最小化训练标签上的预测损失,调整为最小化预测损失和惩罚项之和

现在,如果我们的权重向量增长的太大,我们的学习算法可能会更集中于最小化权重范数 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2。这正是我们想要的。

我们的损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

为了惩罚权重向量的大小,我们必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2

但是模型应该如何平衡这个新的额外惩罚的损失?
实际上,我们通过正则化常数 λ \lambda λ来描述这种权衡,这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。

为什么在这里我们使用平方范数而不是标准范数(即欧几里得距离)?

我们这样做是为了便于计算。通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和。
这使得惩罚的导数很容易计算:导数的和等于和的导数。

此外,为什么我们首先使用 L 2 L_2 L2范数,而不是 L 1 L_1 L1范数。

L 2 L_2 L2正则化线性模型构成经典的岭回归(ridge regression)算法,
L 1 L_1 L1正则化线性回归是统计学中类似的基本模型,通常被称为套索回归(lasso regression)。

使用 L 2 L_2 L2范数的一个原因是它对权重向量的大分量施加了巨大的惩罚。这使得我们的学习算法偏向于在大量特征上均匀分布权重的模型。在实践中,这可能使它们对单个变量中的观测误差更为稳定

相比之下, L 1 L_1 L1惩罚会导致模型将权重集中在一小部分特征上,
而将其他权重清除为零
。这称为特征选择(feature selection),可能是其他场景下需要的。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w。然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。
这就是为什么这种方法有时被称为权重衰减。我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。

与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。 较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w,而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。

是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,
在神经网络的不同层中也会有所不同。通常,网络输出层的偏置项不会被正则化。

1.2. 高维线性回归

我们通过一个简单的例子来演示权重衰减。

首先,我们像以前一样生成一些数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

我们选择标签是关于输入的线性函数。标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200
并使用一个只包含20个样本的小训练集。

import torch
from d2l import torch as d2l
from torch import nnn_train,n_test,num_inputs,batch_size = 20,100,200,5
true_w,true_b = torch.ones((num_inputs,1))*0.01,0.05"""
使用d2l.synthetic_data函数生成了训练数据和测试数据,并使用d2l.load_array函数将数据加载为迭代器。
"""
train_data = d2l.synthetic_data(true_w,true_b,n_train)
train_iter = d2l.load_array(train_data,batch_size)test_data = d2l.synthetic_data(true_w,true_b,n_test)
test_iter = d2l.load_array(test_data,batch_size,is_train= False)
#这里设置is_train=False表示测试数据不用于模型训练,只用于评估模型的性能。

1.3. 从零开始实现

下面我们将从头开始实现权重衰减,只需将 L 2 L_2 L2的平方惩罚添加到原始目标函数中。

1.3.1.初始化模型参数

#初始化模型参数
#我们将定义一个函数来随机初始化模型参数
def init_params():w = torch.normal(0,1,size=(num_inputs,1),requires_grad= True)b = torch.zeros(1,requires_grad=True)return [w,b]

1.3.2. 定义L₂范数惩罚


#定义L2范数惩罚(实现这一惩罚最方便的方法是对所有项求平方后并将它们求和)
def l2_penalty(w):return torch.sum(w.pow(2))/2 #将权重w的平方和除以2,除以2是为了方便计算梯度

1.3.3. 定义训练代码实现

#定义训练代码实现
def train(lambd):w,b  = init_params()net,loss = lambda x: d2l.linreg(x,w,b),d2l.squared_lossnum_epochs,lr = 100,0.003animator = d2l.Animator(xlabel="epochs",ylabel="loss",yscale="log",xlim= [5,num_epochs],legend=["train","test"])for epoch in range(num_epochs):for x,y in train_iter:#增加了L2范数惩罚项#广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(x),y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w,b],lr,batch_size)if (epoch+1)%5 ==0:animator.add(epoch+1,(d2l.evaluate_loss(net,train_iter,loss),d2l.evaluate_loss(net,test_iter,loss)))print("w的L2范数是:",torch.norm(w).item())

1.3.4. 不管正则化直接训练

#现在用`lambd = 0`禁用权重衰减后运行这个代码。
#注意,这里训练误差有了减少,但测试误差没有减少,这意味着出现了严重的过拟合。train(lambd= 0)#结果:
w的L2范数是: 13.981727600097656

在这里插入图片描述

1.3.5. 使用权重衰减

#使用权重衰减来运行代码。
#注意,在这里训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。train(lambd= 3)#结果:
w的L2范数是: 0.3319331705570221d2l.plt.show()

在这里插入图片描述

1.4. 简洁实现

深度学习框架为了便于我们使用权重衰减,将权重衰减集成到优化算法中,以便与任何损失函数结合使用。

#在下面的代码中,我们在实例化优化器时直接通过`weight_decay`指定weight decay超参数。
#默认情况下,PyTorch同时衰减权重和偏移。
#这里我们只为权重设置了`weight_decay`,所以偏置参数$b$不会衰减。def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs,1))for param in net.parameters():param.data.normal_() #使用正态分布随机初始化参数loss = nn.MSELoss(reduction="none") #定义损失函数为均方误差损失num_epochs,lr = 100,0.003#偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,"weight_decay":wd},{"params":net[0].bias}],lr = lr)  #net[0].weight表示模型的权重参数,net[0].bias表示模型的偏置参数。weight_decay参数用于设置权重衰减的强度。animator = d2l.Animator(xlabel="epochs",ylabel="loss",yscale="log",xlim=[5,num_epochs],legend=["train","test"])for epoch in range(num_epochs):for x,y in train_iter:trainer.zero_grad() #清零梯度,以防止梯度累积l = loss(net(x),y)l.mean().backward() #计算损失的平均值,并进行反向传播,计算梯度trainer.step() #更新模型的参数,执行一步优化器的更新if (epoch+1)%5 == 0:animator.add(epoch+1,(d2l.evaluate_loss(net,train_iter,loss),d2l.evaluate_loss(net,test_iter,loss)))print("w的L2范数:", net[0].weight.norm().item())  #打印模型权重的L2范数,用于评估模型的复杂度。train_concise(0)
train_concise(3)
d2l.plt.show()#结果:
w的L2范数: 13.411089897155762
w的L2范数: 0.3319282829761505

在这里插入图片描述

在这里插入图片描述


总结

为了有效防止模型的过拟合,降低模型的复杂度,提高泛化能力,本章简单记录了一种常见的正则化技术:权重衰减。简单来说权重衰减是通过在损失函数中添加一个正则化项来实现的。这个正则化项通常是模型参数的L2范数(平方和)或L1范数(绝对值和),通过限制模型参数的大小来防止过拟合。

我独泊兮其未兆,如婴儿之未孩,傫傫(lèi lèi)兮,若无所归。

–2023-10-2 进阶篇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/96600.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue 项目打包性能分析插件 webpack-bundle-analyzer

webpack-bundle-analyzer 是 webpack 的插件,需要配合 webpack 和 webpack-cli 一起使用。这个插件可以读取输出文件夹(通常是 dist)中的 stats.json 文件,把该文件可视化展现,生成代码分析报告,可以直观地…

Leetcode901-股票价格跨度

一、前言 本题基于leetcode901股票价格趋势这道题,说一下通过java解决的一些方法。并且解释一下笔者写这道题之前的想法和一些自己遇到的错误。需要注意的是,该题最多调用 next 方法 10^4 次,一般出现该提示说明需要注意时间复杂度。 二、解决思路 ①…

神经网络中的知识蒸馏

多分类交叉熵损失函数:每个样本的标签已经给出,模型给出在三种动物上的预测概率。将全部样本都被正确预测的概率求得为0.70.50.1,也称为似然概率。优化的目标就是希望似然概率最大化。如果样本很多,概率不断连乘,就会造…

关于丢失msvcp71.dll的5个解决办法,msvcp71.dll丢失原因分析

计算机已经成为我们生活和工作中不可或缺的一部分,在使用计算机的过程中,我们经常遇到各种软件或应用程序崩溃的情况。其中,一个常见的错误提示是“MSVCP71.dll丢失”。这个错误通常出现在运行使用Visual C Redistributable for Visual Studi…

数据结构——多重链表的实现

//多重列表的实现 #include<stdio.h> #include<stdlib.h> struct lnode {int row,col,value; }; //没有用到down指针 //没有用到tag和next指针 typedef struct node {int tag;//区分头结点(0)和非零元素结点(1)struct node* right;struct node* down;//共用体与结…

Django基础讲解-路由控制器和视图(Django-02)

一 路由控制器 参考链接&#xff1a; Django源码阅读&#xff1a;路由&#xff08;二&#xff09; - 知乎 Route路由, 是一种映射关系&#xff01;路由是把客户端请求的 url路径与视图进行绑定 映射的一种关系。 这个/timer通过路由控制器最终匹配到myapp.views中的视图函数 …

抄写Linux源码(Day14:从 MBR 到 C main 函数 (3:研究 head.s) )

回忆我们需要做的事情&#xff1a; 为了支持 shell 程序的执行&#xff0c;我们需要提供&#xff1a; 1.缺页中断(不理解为什么要这个东西&#xff0c;只是闪客说需要&#xff0c;后边再说) 2.硬盘驱动、文件系统 (shell程序一开始是存放在磁盘里的&#xff0c;所以需要这两个东…

山西电力市场日前价格预测【2023-10-08】

日前价格预测 预测说明&#xff1a; 如上图所示&#xff0c;预测明日&#xff08;2023-10-08&#xff09;山西电力市场全天平均日前电价为258.40元/MWh。其中&#xff0c;最高日前电价为496.19元/MWh&#xff0c;预计出现在18:45。最低日前电价为0.00元/MWh&#xff0c;预计出…

【FreeRTOS】内存管理简单介绍

有没有想过什么移植FreeRTOS时&#xff0c;为什么有多种的内存文件&#xff0c;我们工程只使用Heap_4&#xff0c;其他的有什么用&#xff1f;每个的区别是什么&#xff1f;FreeRTOS是一种流行的实时操作系统&#xff0c;广泛应用于嵌入式系统开发中。在嵌入式系统中&#xff0…

(三)行为模式:8、状态模式(State Pattern)(C++示例)

目录 1、状态模式&#xff08;State Pattern&#xff09;含义 2、状态模式的UML图学习 3、状态模式的应用场景 4、状态模式的优缺点 &#xff08;1&#xff09;优点 &#xff08;2&#xff09;缺点 5、C实现状态模式的实例 1、状态模式&#xff08;State Pattern&#x…

光伏发电预测(LSTM、CNN_LSTM和XGBoost回归模型,Python代码)

运行效果&#xff1a;光伏发电预测&#xff08;LSTM、CNN_LSTM和XGBoost回归模型&#xff0c;Python代码&#xff09;_哔哩哔哩_bilibili 运行环境库的版本 光伏太阳能电池通过互连形成光伏模块&#xff0c;以捕捉太阳光并将太阳能转化为电能。因此&#xff0c;当光伏模块暴露…

深入探究 C++ 编程中的资源泄漏问题

目录 1、GDI对象泄漏 1.1、何为GDI资源泄漏&#xff1f; 1.2、使用GDIView工具排查GDI对象泄漏 1.3、有时可能需要结合其他方法去排查 1.4、如何保证没有GDI对象泄漏&#xff1f; 2、进程句柄泄漏 2.1、何为进程句柄泄漏&#xff1f; 2.2、创建线程时的线程句柄泄漏 …

成功解决@Async注解不生效的问题,异步任务处理问题

首先&#xff0c;有这样一个异步监听方法 然后配置好了异步线程池 package com.fdw.study.config;import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Conf…

国产1.8V低电压输入,可用于驱动步进电机;H 桥驱动电路单元可以直接驱动IR-CUT

D6212是专为安防摄像头系统设计的驱动电路&#xff0c;电路由八路达林顿管 阵列和 H 桥驱动电路两个单元组成。八路达林顿管阵列均带有续流二极 管&#xff0c;可用于驱动步进电机&#xff1b;H 桥驱动电路单元可以直接驱动IR-CUT。单个 达林顿管在输入电压低至 1.8V 状态下支持…

口袋参谋:如何提升宝贝流量?这三种方法超实用!

​你的店铺能不能出爆款&#xff1f;提升单品流量是关键。 对于新手卖家来说&#xff0c;是缺乏运营技巧和运营经验的&#xff0c;运营技巧主要体现在标题写作、各种图片和视频制作等。 由于新手买家没有经验&#xff0c;习惯于直接使用数据包上传&#xff0c;导致宝贝没有展…

Java卷上天,可以转行干什么?

小刚是某名企里的一位有5年经验的高级Java开发工程师&#xff0c;每天沉重的的工作让他疲惫不堪&#xff0c;让他萌生出想换工作的心理&#xff0c;但是转行其他工作他又不清楚该找什么样的工作 因为JAVA 这几年的更新实在是太太太……快了&#xff0c;JAVA 8 都还没用多久&am…

cpp primer plus笔记01-注意事项

cpp尽量以int main()写函数头而不是以main()或者int main(void)或者void main()写。 cpp尽量上图用第4行的注释而不是用第5行注释。 尽量不要引用命名空间比如:using namespace std; 函数体内引用的命名空间会随着函数生命周期结束而失效&#xff0c;放置在全局引用的命名空…

PX4仿真添加world模型文件,并使用yolov8进行跟踪

前言 目的:我们是为了在无人机仿真中使用一个汽车模型,然后让仿真的无人机能够识别到这个汽车模型。所以我们需要在无人机仿真的环境中添加汽车模型。 无人机仿真中我们默认使用的empty.world文件,所以只需要将我们需要的模型添加到一起写进这个empty.world文件中去就可以…

jmeter添加断言(详细图解)

先创建一个线程组&#xff0c;再创建一个http请求。 为了方便观察&#xff0c;我们添加两个监听器&#xff0c;察看结果树和断言结果。 添加断言&#xff1a;响应断言&#xff0c;响应断言也是比较常用的一个断言 设置响应断言&#xff1a;正常情况下响应代码是200。选择响应代…

Multisim:JFET混频器设计(含完整程序)

目录 前言实验内容一、先看作业题目要求二、作业正文IntroductionPre-lab work3.13.2 Experiment Work4.1(2)circuit setup4.1(3)add 12V DC4.1(4)set input x1 and x24.1(5)4.1(6)4.1(7)4.2(1)(2)4.2(3)4.2(4)4.3(1)(2)4.3(3) Conclusion 三、资源包内容 前言 花了好大心血完成…