《动手学深度学习(PyTorch版)》笔记4.5

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过。

Chapter4 Multilayer Perceptron

4.5 Weight Decay

前一节我们描述了过拟合的问题,本节我们将介绍一些正则化模型的技术。我们总是可以通过去收集更多的训练数据来缓解过拟合,但这可能成本很高,耗时颇多,或者完全超出我们的控制,因而在短期内不可能做到。假设我们已经拥有尽可能多的高质量数据,我们便可以将重点放在正则化技术上。

在多项式回归的例子中,我们可以通过调整拟合多项式的阶数来限制模型的容量。实际上,限制特征的数量是缓解过拟合的一种常用技术。然而,简单地丢弃特征对这项工作来说可能过于生硬。我们继续思考多项式回归的例子,考虑高维输入可能发生的情况。多项式对多变量数据的自然扩展称为单项式(monomials),也可以说是变量幂的乘积。单项式的阶数是幂的和。例如, x 1 2 x 2 x_1^2 x_2 x12x2 x 3 x 5 2 x_3 x_5^2 x3x52都是3次单项式。

注意,随着阶数 d d d的增长,带有阶数 d d d的项数迅速增加。 给定 k k k个变量,阶数为 d d d的项的个数为 ( k − 1 + d k − 1 ) {k - 1 + d} \choose {k - 1} (k1k1+d),即 C k − 1 + d k − 1 = ( k − 1 + d ) ! ( d ) ! ( k − 1 ) ! C^{k-1}_{k-1+d} = \frac{(k-1+d)!}{(d)!(k-1)!} Ck1+dk1=(d)!(k1)!(k1+d)!。因此即使是阶数上的微小变化,比如从 2 2 2 3 3 3,也会显著增加我们模型的复杂性。仅仅通过简单的限制特征数量(在多项式回归中体现为限制阶数),可能仍然使模型在过简单和过复杂中徘徊,我们需要一个更细粒度的工具来调整函数的复杂性,使其达到一个合适的平衡位置。

4.5.1 Norm and Weight Decay

权重衰减是最广泛使用的正则化的技术之一在训练参数化机器学习模型时,权重衰减(weight decay)是最广泛使用的正则化的技术之一,它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0)在某种意义上是最简单的。但是我们应该如何精确地测量一个函数和零之间的距离呢?没有一个准确的答案。
一种简单的方法是通过线性函数 f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx中的权重向量的某个范数来度量其复杂性,例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2。要保证权重向量比较小,最常用方法是将其范数作为惩罚项加到最小化损失的问题中。将原来的训练目标最小化训练标签上的预测损失,调整为最小化预测损失和惩罚项之和。现在,如果我们的权重向量增长的太大,我们的学习算法可能会更集中于最小化权重范数 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2。这正是我们想要的。3.1节的线性回归例子里,损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

为了惩罚权重向量的大小,我们必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2,但是模型应该如何平衡这个新的额外惩罚的损失?实际上,我们通过正则化常数 λ \lambda λ来描述这种权衡,这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ = 0 \lambda = 0 λ=0,我们恢复了原来的损失函数。对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。这里我们仍然除以 2 2 2,是因为我们取一个二次函数的导数时, 2 2 2 1 / 2 1/2 1/2会抵消。通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和,这使得惩罚的导数很容易计算。

L 2 L_2 L2正则化线性模型构成经典的岭回归(ridge regression)算法, L 1 L_1 L1正则化线性回归是统计学中类似的基本模型,通常被称为套索回归(lasso regression)。我们通常使用 L 2 L_2 L2范数的一个原因是它对权重向量的大分量施加了巨大的惩罚。这使得我们的学习算法偏向于在大量特征上均匀分布权重的模型。在实践中,这可能使它们对单个变量中的观测误差更为稳定。相比之下, L 1 L_1 L1惩罚会导致模型将权重集中在一小部分特征上,而将其他权重清除为零。这称为特征选择(feature selection),这可能是其他场景下需要的。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

根据之前章节所讲的,我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w。然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。这就是为什么这种方法有时被称为权重衰减。我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w,而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,在神经网络的不同层中也会有所不同。通常,网络输出层的偏置项不会被正则化。

4.5.2 High Dimensional Linear Regression

我们通过一个简单的例子来演示权重衰减。首先,我们像以前一样生成一些数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

我们选择标签是关于输入的线性函数。标签同时被均值为0,标准差为0.01高斯噪声破坏。为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200,并使用一个只包含20个样本的小训练集。

import matplotlib.pyplot as plt
import torch
from d2l import torch as d2l
from torch import nnn_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)#初始化模型参数
def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]#定义L2范数惩罚
def l2_penalty(w):return torch.sum(w.pow(2)) / 2#训练
def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())#忽略正则化直接训练
train(lambd=0)
plt.show()#使用正则化后进行训练
train(lambd=3)
plt.show()#简洁实现
def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())train_concise(0)
train_concise(3)
plt.show()

无正则化效果(过拟合):
在这里插入图片描述

正则化效果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/649613.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

掌握可视化大屏:提升数据分析和决策能力的关键(下)

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

(大众金融)SQL server面试题(3)-客户已用额度总和

今天,面试了一家公司,什么也不说先来三道面试题做做,第三题。 那么,我们就开始做题吧,谁叫我们是打工人呢。 题目是这样的: DEALER_INFO经销商授信协议号码经销商名称经销商证件号注册地址员工人数信息维…

three.js 鼠标选中模型弹出标签

效果&#xff1a;请关注抖音 代码&#xff1a; <template><div><el-container><el-main><div class"box-card-left"><div id"threejs" style"border: 1px solid red;position: relative;"></div><…

202410读书笔记|《半小时漫画青春期》——成为自己世界的星星,这才是最要紧的事儿

202410读书笔记|《半小时漫画青春期&#xff1a;心理篇》——成为自己世界的星星&#xff0c;这才是最要紧的事儿 一、一到考试就焦虑&#xff0c;怎么办&#xff1f;二、以前情绪挺淡定&#xff0c;现在咋动不动就爆发&#xff1f;三、追星那么开心&#xff0c;为啥还要我小心…

ajax点击搜索返回所需数据

html 中body设置&#xff08;css设置跟进自身需求&#xff09; <p idsearch_head>学生信息查询表</p> <div id"div_1"> <div class"search_div"> <div class"search_div_item"> …

数据库设计的一些原则

文章目录 数据库设计原则表之间的关系一对一关系&#xff08;了解&#xff09;一对多&#xff08;多对一&#xff09;多对多联合主键和复合主键 数据库设计准则-范式1、函数依赖2、完全函数依赖3、部分函数依赖4、传递函数依赖5、码 第一范式第二范式第三范式第三范式 数据库设…

【原神游戏开发日志3】登录和注册有何区别?

版权声明&#xff1a; ● 本文为“优梦创客”原创文章&#xff0c;您可以自由转载&#xff0c;但必须加入完整的版权声明 ● 文章内容不得删减、修改、演绎 ● 本文视频版本&#xff1a;见文末 ● 相关学习资源&#xff1a;见文末 前言 ● 这是我们原神游戏开发日记的第三期 ●…

TensorFlow2实战-系列教程1:回归问题预测

&#x1f9e1;&#x1f49b;&#x1f49a;TensorFlow2实战-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Jupyter Notebook中进行 本篇文章配套的代码资源已经上传 1、环境测试 import tensorflow as tf import numpy as np tf.__version__打印结果 ‘…

2024年材料、控制工程与制造技术国际学术会议(ICMCEMT 2024)

2024年材料、控制工程与制造技术国际学术会议(ICMCEMT 2024) 2024 International Conference on Materials, Control Engineering, and Manufacturing Technology (ICMCEMT 2024) 会议简介&#xff1a; 2024年材料、控制工程与制造技术国际学术会议(ICMCEMT 2024)定于2024年在…

分布式因果推断在美团履约平台的探索与实践

美团履约平台技术部在因果推断领域持续的探索和实践中&#xff0c;自研了一系列分布式的工具。本文重点介绍了分布式因果树算法的实现&#xff0c;并系统地阐述如何设计实现一种分布式因果树算法&#xff0c;以及因果效应评估方面qini_curve/qini_score的不足与应对技巧。希望能…

ERROR Failed to get response from https://registry.npm.taobao.org/ 错误的解决

这个问题最近才出现的。可能跟淘宝镜像的证书到期有关。 解决方式一&#xff1a;更新淘宝镜像&#xff08;本人测试无效&#xff0c;但建议尝试&#xff09; 虽然无效&#xff0c;但感觉是有很大关系的。还是设置一下比较好。 淘宝镜像的地址&#xff08;registry.npm.taobao…

【计算机网络】协议,电路交换,分组交换

定义了在两个或多个通信实体之间交换的报文格式和次序,以及报文发送和/或接收一个报文或其他事件所采取的动作.网络边缘: 端系统 (因为处在因特网的边缘) 主机 端系统 客户 client服务器 server今天大部分服务器都属于大型数据中心(data center)接入网(access network) 指将端…

Visual Studio 2022 C++ 生成dll或so文件在windows或linux下用C#调用

背景 开发中我们基本使用windows系统比较快捷&#xff0c;但是部署的时候我们又希望使用linux比较便宜&#xff0c;硬件产商还仅提供了c sdk&#xff01;苦了我们做二次开发的码农。 方案 需要确认一件事&#xff0c;目前c这门语言不是跨平台的 第一个问题【C生成dll在window…

nav02 学习03 机器人传感器

机器人传感器 移动机器人配备了大量传感器&#xff0c;使它们能够看到和感知周围的环境。这些传感器获取的信息可用于构建和维护环境地图、在地图上定位机器人以及查看环境中的障碍物。这些任务对于能够安全有效地在动态环境中导航机器人至关重要。 机器人的传感器类似人的感官…

二极管漏电流对单片机ad采样偏差的影响

1&#xff0c;下图是常规的单片机采集电压电路&#xff0c;被测量电压经过电阻分压&#xff0c;给到mcu采集&#xff0c;反向二极管起到钳位作用&#xff0c;避免高压打坏mcu。 2&#xff0c;该电路存在的问题 二极管存在漏电流&#xff0c;会在100k电阻上产生叠加电压&#x…

qt 坦克大战游戏 GUI绘制

关于本章节中使用的图形绘制类&#xff0c;如QGraphicsView、QGraphicsScene等的详细使用说明请参见我的另一篇文章&#xff1a; 《图形绘制QGraphicsView、QGraphicsScene、QGraphicsItem、Qt GUI-CSDN博客》 本文将模仿坦克大战游戏&#xff0c;目前只绘制出一辆坦克&#…

Oracle RAC 集群的安装(保姆级教程)

文章目录 一、安装前的规划1、系统规划2、网络规划3、存储规划 二、主机配置1、Linux主机安装&#xff08;rac01&rac02&#xff09;2、配置yum源并安装依赖包&#xff08;rac01&rac02&#xff09;3、网络配置&#xff08;rac01&rac02&#xff09;4、存储配置&#…

c语言实现—动态通讯录

一.前言 上次带大家认识了一下顺序表&#xff0c;其实我们可以在顺序表的基础上实现一个通讯录的小项目&#xff0c;通讯录的本质仍然是顺序表&#xff0c;所以如果下面的代码你有问题的话&#xff0c;先去看看我的上篇文章哦~。 通讯录的功能大家应该都知道吧&#xff0c;这次…

chroot: failed to run command ‘/bin/bash’: No such file or directory

1. 问题描述及原因分析 在busybox的环境下&#xff0c;执行 cd rootfs chroot .报错如下&#xff1a; chroot: failed to run command ‘/bin/bash’: No such file or directory根据报错应该rootfs文件系统中缺少/bin/bash&#xff0c;进入查看确实默认是sh&#xff0c;换成…

【微信小程序】浮动按钮拖动功能

在开发过程中无意间想到了这个功能。一番查询之后找到这个功能相关的代码片段。拷贝过来之后各种报错&#xff0c;经过自己的整改以可以使用。 功能图片&#xff1a; 中间的微信按钮可以拖动 wxml&#xff1a;页面代码 <button catchtouchmove"buttonMove" cat…