权重衰减(Weight Decay)

       在深度学习中,权重衰减(Weight Decay)是一种常用的正则化技术,旨在减少模型的过拟合现象。权重衰减通过向损失函数添加一个正则化项,以惩罚模型中较大的权重值。

一、权重衰减

       在深度学习中,模型的训练过程通常使用梯度下降法(或其变种)来最小化损失函数。梯度下降法的目标是找到损失函数的局部最小值,使得模型的预测能力最好。然而,当模型的参数(即权重)过多或过大时,容易导致过拟合问题,即模型在训练集上表现很好,但在测试集上表现较差。

       权重衰减通过在损失函数中引入正则化项来解决过拟合问题。正则化项通常使用L1范数或L2范数来度量模型的复杂度。L2范数正则化(也称为权重衰减)是指将模型的权重的平方和添加到损失函数中,乘以一个较小的正则化参数$ \lambda $这个额外的项迫使模型学习到较小的权重值,从而减少模型的复杂度。

       具体而言,对于一个深度学习模型的损失函数$L(w, b)$,其中$w,b$表示模型的参数(权重和偏置),权重衰减可以通过以下方式实现:

$ L'\left( w,b \right) =L\left( w,b \right) +\lambda \cdot \lVert w \rVert ^2 $

       其中,$ L'\left( w,b \right) $是添加了权重衰减的损失函数,$ \lVert w \rVert ^2 $表示参数的L2范数的平方和,$ \lambda $是正则化参数,用于控制正则化项的重要性。

       在训练过程中,梯度下降法将同时更新损失函数和权重。当计算梯度时,权重衰衰减的正则化项将被添加到梯度中,从而导致权重更新的幅度减小。这使得模型的权重趋向于减小,避免过拟合现象。

       需要注意的是,正则化参数$ \lambda $的选择对模型的性能有重要影响。较小的$ \lambda $值会导致较强的正则化效果,可能会使模型欠拟合。而较大的$ \lambda $值可能会减少正则化效果,使模型过拟合。因此,选择合适的正则化参数是权衡模型复杂度和泛化能力的关键。

       偏置(biases)在神经网络中起到平移激活函数的作用,通常不会像权重那样导致过度拟合。偏置的主要作用是调整激活函数的位置,使其更好地对应所需的输出。由于偏置的影响较小,因此将权重衰减应用于偏置通常不是常见的做法。

二、权重衰减数学解释

       L2范数正则化在解决过拟合问题方面具有一定的效果,这是因为它在损失函数中引入了权重的平方和作为正则化项。下面我将解释一下L2范数正则化的数学原理。

       在深度学习中,我们的目标是最小化损失函数,该函数包括两部分:经验误差和正则化项。对于L2范数正则化,我们将正则化项定义为权重的平方和的乘以一个正则化参数$ \lambda $

       针对损失函数$ L'\left( w,b \right)$,我们使用梯度下降法来最小化这个损失函数。在梯度下降的每一步中,我们计算损失函数的梯度,然后更新权重。对于L2范数正则化,梯度的计算中包含了正则化项的贡献。

       具体来说,我们计算损失函数对权重w的梯度,记为$ \nabla L\left( w,b \right) $。那么加入L2范数正则化后的梯度可以写为:

$ \nabla L'\left( w,b \right) =\nabla L\left( w,b \right) +2\lambda w $

       这里,$ 2\lambda w $是正则化项的梯度贡献,其中$ 2\lambda $是正则化参数$ \lambda $的倍数,$w$是权重的梯度。

       当我们使用梯度下降法更新权重时,梯度的负方向指示了损失函数下降的方向。由于L2范数正则化项的存在,权重的梯度会受到惩罚,从而导致权重的更新幅度减小。

       这种减小权重更新幅度的效果使得模型倾向于学习到较小的权重值,从而降低了模型的复杂度。通过减小权重的幅度,L2范数正则化可以有效地控制模型的过拟合,提高模型的泛化能力。

       总结起来,L2范数正则化通过引入权重的平方和作为正则化项,在梯度计算和权重更新中对权重进行惩罚,从而减小了模型的复杂度,防止过拟合现象的发生。

也可以参考李沐老师的课件:

三、代码从零开始实现

import torch
from torch import nn
from d2l import torch as d2l

1、生成数据

       首先,我们像以前一样生成一些数据,生成公式如下:

$y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2).$

       我们选择标签是关于输入的线性函数。标签同时被均值为0,标准差为0.01高斯噪声破坏。为了使过拟合的效果更加明显,我们可以将问题的维数增加到$d = 200$(w的长度为200),并使用一个只包含20个样本的小训练集。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5   # 训练集长度为20、验证机长度为100、权重参数有200个、批量大小为5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05   # 真实的权重和偏置
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

2、初始化模型参数

       我们将定义一个函数来随机初始化模型参数。

def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]

3、定义L2范数惩罚

       实现这一惩罚最方便的方法是对所有项求平方后并将它们求和。

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

4、定义训练代码实现

       下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。和之前线性回归一样,线性网络和平方损失没有变化,所以我们通过`d2l.linreg`和`d2l.squared_loss`导入它们。唯一的变化是损失现在包括了惩罚项。

def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())

5、忽略正则化直接训练

       我们现在用`lambd = 0`禁用权重衰减后运行这个代码。注意,这里训练误差有了减少,但测试误差没有减少,这意味着出现了严重的过拟合。

train(lambd=0)
w的L2范数是: 12.963241577148438

 

6、使用权重衰减

       下面,我们使用权重衰减来运行代码。注意,在这里训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。

train(lambd=3)
w的L2范数是: 0.3556520938873291

 

四、简洁实现

       由于权重衰减在神经网络优化中很常用,深度学习框架为了便于我们使用权重衰减,将权重衰减集成到优化算法中,以便与任何损失函数结合使用。此外,这种集成还有计算上的好处,允许在不增加任何额外的计算开销的情况下向算法中添加权重衰减。由于更新的权重衰减部分仅依赖于每个参数的当前值,因此优化器必须至少接触每个参数一次。

1、定义训练代码实现

       在下面的代码中,我们在实例化优化器时直接通过`weight_decay`指定weight decay超参数。默认情况下,PyTorch同时衰减权重和偏移。这里我们只为权重设置了`weight_decay`,所以偏置参数$b$不会衰减。

def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd}, {"params":net[0].bias}],lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())

2、忽略正则化直接训练

train_concise(0)
w的L2范数: 13.727912902832031

3、使用权重衰减

train_concise(3)
w的L2范数: 0.3890590965747833

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227023.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL基础:操作环境搭建

在上一节中,我们简单讲述了数据库和SQL的基本概念。 本节我们讲述一下环境搭建,为下一节讲表的基本操作做下铺垫。 环境搭建 具体到操作,我们就要准备一些环境了。如果不进行练习,我们学习的知识将很快被遗忘。 MySQL安装&…

【MySQL内置函数】

目录: 前言一、日期函数获取日期获取时间获取时间戳在日期上增加时间在日期上减去时间计算两个日期相差多少天当前时间案例:留言板 二、字符串函数查看字符串字符集字符串连接查找字符串大小写转换子串提取字符串长度字符串替换字符串比较消除左右空格案…

【话题】低代码123

目录 一、什么是低代码 二、低代码的优缺点 三、你认为低代码会替代传统编程吗? 四、有哪些低代码工具和框架 4.1 国外的平台 4.2 国内的平台 五、未来的软件研发 低代码,听着就过瘾的一个词。而且不是无代码,这说明,低代码…

计算机组成原理-函数调用的汇编表示(call和ret指令 访问栈帧 切换栈帧 传递参数和返回值)

文章目录 call指令和ret指令高级语言的函数调用x86汇编语言的函数调用call ret指令小结其他问题 如何访问栈帧函数调用栈在内存中的位置标记栈帧范围:EBP ESP寄存器访问栈帧数据:push pop指令访问栈帧数据:mov指令小结 如何切换栈帧函数返回时…

Spring入门

学习的最大理由是想摆脱平庸,早一天就多一份人生的精彩;迟一天就多一天平庸的困扰。各位小伙伴,如果您: 想系统/深入学习某技术知识点… 一个人摸索学习很难坚持,想组团高效学习… 想写博客但无从下手,急需…

Toyota Programming Contest 2023#8(AtCoder Beginner Contest 333)

A - Three Threes 题目大意:给你一个整数n,将这个数n输出n次。 呃呃 B - Pentagon 题目大意:给你一个正五边形ABCDE,给你任意两条边,判断是否相等 主要问题要判断一下内边:AD,AC,…

基于ssm图书商城网站的设计和开发论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本图书商城网站就是在这样的大环境下诞生,其可以帮助管理者在短时间内处理完毕庞大的数据信息&am…

Win11极速安装Tensorflow-gpu+CUDA+cudnn

文章目录 0.pip/conda换默认源1.Anacondapython虚拟环境2.安装CUDA以及cudnn 0.pip/conda换默认源 为了高效下载,建议先把默认源换了,很简单这里不再赘述。(我用梯子,所以没换源😋) 1.Anacondapython虚拟…

最棒的 7 款精选我的世界光影水反效果包

光影支持基础Mod下载 版本:1.12✔1.11.2✔1.10.2✔1.9.4✔1.8✔1.7.10✔ 下载和安装: GLSL Shaders Mod 或者 OptiFine 我的世界光影支持Mod可以帮助玩家提示游戏画面,加强你的游戏录制视频效果和实时游戏体验。有数量众多的光影效果包提供…

一文了解Tomcat

文章目录 1、Tomcat介绍2、Tomcat使用配置2.1、Tomcat下载启动2.2、Tomcat启动乱码2.3、Tomcat端口号修改 3、Tomcat项目部署4、IDEA中使用Tomcat方式 1、Tomcat介绍 什么是Tomcat ​ Tomcat是Apache软件基金会一个核心项目,是一个开源免费的轻量级web服务器&#x…

《Linux C编程实战》笔记:一些系统调用

目录 dup和dup2函数 fcntl函数 示例程序1 示例程序2 ioctl函数 dup和dup2函数 #include <unistd.h> int dup(int oldfd); int dup2(int oldfd, int newfd): dup 函数复制 oldfd 参数所指向的文件描述符。 参数&#xff1a; oldfd&#xff1a;要复制的文件描述符的…

[笔记] wsl 下使用 qemu/grub 模拟系统启动(单分区)

背景 最近在学习操作系统&#xff0c;需要从零开始搭建系统&#xff0c;由于教程中给的虚拟机搭建的方式感觉还是过于重量级&#xff0c;因此研究了一下通过 qemu 模拟器&#xff0c;配合 grub 完成启动系统的搭建。 qemu 介绍 qemu 是一款十分优秀的系统模拟器&#xff0c;…

@PostMapping接收String类型的参数

接口这样定义&#xff1a; PostMapping("/aaa") public void getById(String param)参数这样测试&#xff1a;

C++特殊类和类型转换剖析

目录 一、特殊类 1.1拒绝被拷贝的类 1.2 限制在堆上创建类 1.3 限制在栈上创建的类 1.4 不能被继承的类 二、类型转换 2.1 static_cast 2.2 reinterpret_cast 2.3 const_cast 2.4 dynamic_cast 一、特殊类 什么是特殊类&#xff1f;在普通类的设计基础上&#xff0c…

基于Java+vue的音乐网站设计与实现(源码+文档+数据库)

摘 要 在此基础上&#xff0c;提出了一种基于javavue的在线音乐排行榜系统的设计与实现方法。本系统分为两个大的功能&#xff0c;即&#xff1a;前端显示、后端管理。而在前台&#xff0c;则是播放不同的歌曲&#xff0c;让人可以在上面观看不同的歌曲&#xff0c;也可以观看…

CSS学习

CSS学习 1. 什么是css?2.css引入方式2.1 内嵌式2.2 外联式2.3 行内式2.4 引入方式特点 3. 基础选择器3.1 标签选择器3.2 类选择器3.3 id选择器3.4 通配符选择器 4. 文字基本样式4.1 字体样式4.1.1 字体大小4.1.2 字体粗细4.1.3 倾斜4.1.4 字体4.1.5 字体font相关属性连写 4.2 …

地图自定义省市区合并展示数据整合

需求一&#xff1a;将省级地图下的两个市合并成一个区域&#xff0c;中间的分割线隐藏。 1、访问下方地址&#xff0c;搜索并下载省级地图json文件。 地址&#xff1a;https://datav.aliyun.com/portal/school/atlas/area_selector 2、切换到边界生成器&#xff0c;上传刚刚下…

论文降重同义词替换的实践经验与改进建议 快码论文

大家好&#xff0c;今天来聊聊论文降重同义词替换的实践经验与改进建议&#xff0c;希望能给大家提供一点参考。 以下是针对论文重复率高的情况&#xff0c;提供一些修改建议和技巧&#xff0c;可以借助此类工具&#xff1a; 标题&#xff1a;论文降重同义词替换的实践经验与改…

Datawhale 12月组队学习 leetcode基础 day3 递归

这是一个新的专栏&#xff0c;主要是一些算法的基础&#xff0c;对想要刷leedcode的同学会有一定的帮助&#xff0c;如果在算法学习中遇到了问题&#xff0c;也可以直接评论或者私信博主&#xff0c;一定倾囊相助 进入正题&#xff0c;今天咱们要说的是递归&#xff0c;递归是是…

Qt中槽函数在那个线程执行的探索和思考

信号和槽是Qt的核心机制之一&#xff0c;通过该机制大大简化了开发者的开发难度。信号和槽属于观察者模式&#xff08;本质上是回调函数的应用&#xff09;。是函数就需要考虑其是在那个线程中执行&#xff0c;本文讨论的就是槽函数在那个线程中执行的问题。 目录 1. connect…