【PyTorch】权重衰减

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现

1. 理论介绍

  • 通过对模型过拟合的思考,人们希望能通过某种工具调整模型复杂度,使其达到一个合适的平衡位置。
  • 权重衰减(又称 L 2 L_2 L2正则化)通过为损失函数添加惩罚项,用来惩罚权重的 L 2 L_2 L2范数,从而限制模型参数值,促使模型参数更加稀疏或更加集中,进而调整模型的复杂度,即: L ( w , b ) + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2 L(w,b)+2λw2其中 λ \lambda λ权重衰减的超参数
  • L p L_p Lp范数: ∥ x ∥ p = ( ∑ i = 1 n ∣ x i ∣ p ) 1 / p \|\mathbf{x}\|_p = \left(\sum_{i=1}^n \left|x_i \right|^p \right)^{1/p} xp=(i=1nxip)1/p
    p = 1 p=1 p=1时称为 L 1 L_1 L1范数;当 p = 2 p=2 p=2时称为 L 2 L_2 L2范数。
    惩罚 L 1 L_1 L1范数会导致模型将权重集中在一小部分特征上, 而将其他权重清除为零, 这称为特征选择;惩罚 L 2 L_2 L2范数会导致模型在大量特征上均匀分布权重,使得模型对单个变量的观测误差更为稳定。
  • 通常不建议对偏置进行正则化,因为偏置的取值并不像权值那样会随着训练过程而变化,因此对偏置进行正则化对于控制模型的复杂度影响较小;另外,对偏置进行正则化可能会导致对数据中的偏移进行过度拟合,而减弱了模型对其他特征的学习。

2. 实例解析

2.1. 实例描述

使用以下公式生成包含20个样本的小训练集和100个样本的测试集,并用线性网络进行拟合: y = 0.05 + ∑ i = 1 200 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^{200} 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=12000.01xi+ϵ where ϵN(0,0.012).

2.2. 代码实现

  • 主要代码
optimizer = optim.SGD([{"params": net.weight,"weight_decay": weight_decay},{"params": net.bias}], lr=lr)
  • 完整代码
import os
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from tensorboardX import SummaryWriter
from rich.progress import trackdef data_generator(w, b, num):"""为线性模型生成数据"""X = torch.randn(num, len(w))y = torch.sum(X @ w, dim=1) + by += torch.normal(0, 0.01, y.shape)return X, y.reshape(-1, 1)def load_dataset(*tensors):"""加载数据集"""dataset = TensorDataset(*tensors)return DataLoader(dataset, batch_size, shuffle=True)def evaluate_loss(dataloader, net, criterion):"""评估模型在指定数据集上的损失"""num_examples = 0loss_sum = 0.0with torch.no_grad():for X, y in dataloader:X, y = X.cuda(), y.cuda()loss = criterion(net(X), y)num_examples += y.shape[0]loss_sum += loss.sum()return loss_sum / num_examplesif __name__ == '__main__':# 全局参数设置lr = 0.003num_epochs = 100batch_size = 5# 创建记录器def log_dir():root = "runs"if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'{root}/exp{order}'writer = SummaryWriter(log_dir=log_dir())# 合成数据集num_inputs = 200n_train, n_test = 20, 100true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05X, y = data_generator(true_w, true_b, n_train + n_test)# 加载数据集dataloader_train = load_dataset(X[:n_train], y[:n_train])dataloader_test = load_dataset(X[n_train:], y[n_train:])def loop(weight_decay):# 定义模型net = nn.Linear(num_inputs, 1).cuda()nn.init.normal_(net.weight)nn.init.constant_(net.bias, 0)criterion = nn.MSELoss(reduction='none')optimizer = optim.SGD([{"params": net.weight,"weight_decay": weight_decay},{"params": net.bias}], lr=lr)# 训练循环for epoch in track(range(num_epochs), description=f'wd={weight_decay}'):for X, y in dataloader_train:X, y = X.cuda(), y.cuda()loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()writer.add_scalars(f'wd={weight_decay}', {'train_loss': evaluate_loss(dataloader_train, net, criterion),'test_loss': evaluate_loss(dataloader_test, net, criterion),}, epoch)for weight_decay in [0, 3]:loop(weight_decay)writer.close()
  • 输出结果
    • weight_decay = 0
      0
    • weight_decay = 3
      3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/203184.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uni-app 微信小程序之加载行政区图

文章目录 1. 实现效果2. 实现步骤 1. 实现效果 2. 实现步骤 使用三方组件 ucharts echarts 高性能跨全端图表组件页面导入引入的三方组件 组件demo代码 <template><view class"qiun-columns"><view class"cu-bar bg-white margin-top-xs"…

文献阅读:基于改进ConvNext的玉米叶片病害分类

文献阅读&#xff1a;基于改进ConvNext的玉米叶片病害分类 CBAM注意力机制模块&#xff1a; 1&#xff1a;通道注意力模块&#xff0c;对输入进来的特征层分别进行全局平均池化&#xff08;AvgPool&#xff09;和全局最大池化&#xff08;MaxPool&#xff09;&#xff08;两个…

可视化监控云平台/智能监控平台EasyCVR国标设备开启音频没有声音是什么原因?

视频云存储/安防监控EasyCVR视频汇聚平台基于云边端智能协同&#xff0c;支持海量视频的轻量化接入与汇聚、转码与处理、全网智能分发、视频集中存储等。GB28181视频平台EasyCVR拓展性强&#xff0c;视频能力丰富&#xff0c;具体可实现视频监控直播、视频轮播、视频录像、云存…

ROS-ROS通信机制-参数服务器

文章目录 一、基础理论知识二、C实现三、Python实现 一、基础理论知识 参数服务器在ROS中主要用于实现不同节点之间的数据共享。参数服务器相当于是独立于所有节点的一个公共容器&#xff0c;可以将数据存储在该容器中&#xff0c;被不同的节点调用&#xff0c;当然不同的节点…

多人聊天UDP

服务端 package 多人聊天;import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.io.PrintStream; import java.net.ServerSocket; import java.net.Socket; import java.util.ArrayList;…

零基础小白怎么准备蓝桥杯-蓝桥杯竞赛经验分享

零基础小白怎么准备蓝桥杯-蓝桥杯竞赛经验分享 前言竞赛简介竞赛目的如何备战1.基础学习2.实战训练&#xff08;非常重要&#xff09; 资料分享 前言 博主在蓝桥杯中获得过十四届Java B 组的省一国二&#xff0c;本文为大家介绍一下蓝桥杯并分享一下自己的参赛经验。 竞赛简介…

封装校验-----Vue3+ts项目

登录校验页面 <script setup lang"ts"> import { ref } from vue import { mobileRules, passwordRules } from /utils/rules const mobile ref() const password ref() </script><!-- 表单 --><van-form autocomplete"off">&l…

数位统计DP

数位DP 数位&#xff08;digit&#xff09;指的是一个数字中的每一位。例如&#xff0c;对于整数1234来说&#xff0c;它有四个数位&#xff0c;分别是1、2、3和4。在数位统计 DP 中&#xff0c;我们通常将数字拆解成各个数位&#xff0c;并使用这些数位进行状态定义和转移。通…

电脑设置代理IP,上网怎么使用代理

在我们的日常生活中&#xff0c;代理IP的使用越来越常见。当我们需要隐藏自己的真实IP地址时&#xff0c;代理IP就成为了我们的不二选择。那么&#xff0c;如何设置代理IP来固定上网地址呢&#xff1f;本文将详细介绍代理IP的设置方法以及如何使用代理IP固定上网地址。 一、代…

Failed to connect to github.com port 443 after 21055 ms: Timed out

目前自己使用了梯*子还是会报这样的错误&#xff0c;连接不到的github。 查了一下原因&#xff1a; 是因为这个请求没有走代理。 解决方案&#xff1a; 设置 -> 网络和Internet -> 代理 -> 编辑 记住这个IP和端口 使用以下命令&#xff1a; git config --global h…

Kafka集群调优

一、前言 我们需要对4个规格的kafka能力进行探底&#xff0c;即其可以承载的最大吞吐&#xff1b;4个规格对应的单节点的配置如下&#xff1a; 标准版&#xff1a; 2C4G铂金版&#xff1a; 4C8G专业版&#xff1a; 8C16G企业版&#xff1a; 16C32G 另外&#xff0c;一般来讲…

git 分支的创建与删除

一 创建本地分支 git checkout -b codetwo //创建本地分支 codetwo git branch newcode //创建本地分支newcode创建的分支如下图&#xff1a; 用checkout的方式创建&#xff0c;只是创建的同时还切换到了这个本地分支 二 创建远程分支 git branch newcode //创…

ALPHA开发板烧录工具MfgTool烧写原理

一. 简介 MfgTool 工具是 NXP 提供的专门用于给 I.MX 系列 CPU 烧写系统的软件&#xff0c;可以在 NXP 官网下载到。运行在windows下。可以烧写uboot.imx、zImage、dtb&#xff0c;rootfs。通过 USB口进行烧写。 上一篇文章简单了解了 烧录工具MfgTool &#xff08;针对ALPH…

TrustZone之数据、指令和统一缓存(unified caches)

在Arm架构中,data caches是物理标记(physically tagged)的。物理地址包括该行来自哪个地址空间,如下所示: 对于NP:0x800000的缓存查找永远不会命中使用SP:0x800000标记的缓存行。这是因为NP:0x800000和SP:0x800000是不同的地址。 这也影响缓存维护操作。考虑前面图表中的示…

人工智能学习8(集成学习之xgboost)

编译工具&#xff1a;PyCharm 文章目录 编译工具&#xff1a;PyCharm 集成学习XGBoost(Extreme Gradient Boosting)极端梯度提升树1.最优模型的构建方法XGBoost目标函数案例1&#xff1a;泰坦尼克号案例2&#xff1a;对奥拓集团差评进行正确分类。数据准备&#xff1a;1.第一种…

基于深度学习yolov5行人社交安全距离监测系统

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介系统工作原理主要组成部分技术实现优势和特点应用场景和前景 二、功能三、系统四. 总结 一项目简介 基于深度学习 YOLOv5 的行人社交安全距离监测系统是一种…

数据仓库与数据挖掘复习资料

一、题型与考点[第一种] 1、解释基本概念(中英互译解释简单的含义)&#xff1b; 2、简答题(每个10分有两个一定要记住)&#xff1a; ① 考时间序列Time series(第六章)的基本概念含义解释作用&#xff08;序列模式挖掘的作用&#xff09;&#xff1b; ② 考聚类(第五章)重点考…

自动化定时发送天气提醒邮件

&#x1f388; 博主&#xff1a;一只程序猿子 &#x1f388; 博客主页&#xff1a;一只程序猿子 博客主页 &#x1f388; 个人介绍&#xff1a;爱好(bushi)编程&#xff01; &#x1f388; 创作不易&#xff1a;如喜欢麻烦您点个&#x1f44d;或者点个⭐&#xff01; &#x1f…

配置端口安全示例

组网需求 如图1所示&#xff0c;用户PC1、PC2、PC3通过接入设备连接公司网络。为了提高用户接入的安全性&#xff0c;将接入设备Switch的接口使能端口安全功能&#xff0c;并且设置接口学习MAC地址数的上限为接入用户数&#xff0c;这样其他外来人员使用自己带来的PC无法访问公…

华为配置风暴控制示例

组网需求 如下图所示&#xff0c;SwitchA作为二层网络到三层路由器的衔接点&#xff0c;需要防止二层网络转发的广播、未知组播或未知单播报文产生广播风 配置思路 用如下的思路配置风暴控制。 通过在GE0/0/1接口视图下配置风暴控制功能&#xff0c;实现防止二层网络转发的…