pytorch学习——正则化技术——权重衰减

一、概念介绍

 

        权重衰减(Weight Decay)是一种常用的正则化技术,它通过在损失函数中添加一个惩罚项来限制模型的复杂度,从而防止过拟合。 

        在训练参数化机器学习模型时, 权重衰减(weight decay)是最广泛使用的正则化的技术之一, 它通常也被称为L2正则化。

1.1理解:

权重衰减(weight_decay)本质上是一个L2正则化系数

那什么是参数的正则化?从我的理解上,就是让参数限定在一定范围,目的是为了不让模型对训练集过拟合。

注:应对过拟合最好的方法还是扩大有效样本(但成本过高)

1.2如何控制模型容量?

1.将模型变得比较小,减少里面参数的数量

2.缩小参数的取值范围

注:权重衰退就是通过限制参数的取值来实现

1.3硬性限制

即使得w的每个项的平方都小于θ这个值,最强情况下就是θ等于0,即所有w都等于0

1.4柔性限制

 即损失函数后面加了一个非负项,为了使损失函数最小化,就得使得后面项足够小——起到限制w的作用,相比于硬性限制,柔性限制并没有将w的值限制在一个固定范围内。

1.5图解对最优解的影响

 

 上式为不加限制条件的最优解,即图中的绿色中心点,但该点会使得||w||^2这一项较大,其和并不是最优解。

而加上限制的最优点即为图中两曲线的交叉点

1.6更新参数法则

 

 1.7总结

   ~权重衰减是通过L2正则项使得模型参数不会过大,从而控制复杂度

   ~正则项权重是控制模型复杂度的超参数

二、示例演示

2.1模型构造

生成公式如下:

# 导入需要的库
import torch
from torch import nn
from d2l import torch as d2l# 定义训练和测试数据集的大小,输入特征的维度和批次大小
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5# 定义真实的权重true_w和偏差true_b,并将其初始化为0.01和0.05
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05# 使用d2l.synthetic_data函数生成训练数据train_data和测试数据test_data
# 生成的数据是通过真实的权重和偏差加上一些噪声生成的
train_data = d2l.synthetic_data(true_w, true_b, n_train)
test_data = d2l.synthetic_data(true_w, true_b, n_test)# 使用d2l.load_array函数将训练数据train_data和测试数据test_data
# 转换为数据迭代器train_iter和test_iter
train_iter = d2l.load_array(train_data, batch_size)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

2.2初始化模型参数

def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]
# 初始化模型参数w和b
# w的形状为(num_inputs, 1),从正态分布中随机生成
# b初始化为0
# 参数需要计算梯度,requires_grad参数被设置为True
# 返回一个包含w和b的列表

2.3定义L2范数

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

2.4定义训练代码实现

        下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。

函数的具体实现如下:

  1. 首先通过init_params()函数初始化模型参数w和b。

  2. 定义net函数为线性回归模型,loss为平方损失函数。

  3. 设置训练的轮数num_epochs和学习率lr,同时创建一个可视化工具animator,用于可视化训练过程中的损失值。

  4. 在每个epoch中,遍历训练数据集train_iter,对每个小批量数据(X, y)进行如下操作:

    • 计算模型的输出net(X),并计算损失函数loss(net(X), y)。

    • 加上L2范数惩罚项lambd * l2_penalty(w),其中l2_penalty(w)为权重w的L2范数。

    • 对损失函数进行反向传播,并使用SGD来更新模型参数w和b。

  5. 每5个epoch,计算训练集和测试集上的损失值,并使用animator将损失值可视化。

  6. 训练结束后,输出模型参数w的L2范数。

# 带有L2正则化的线性回归训练过程
# lambd表示L2正则化的强度# 初始化模型参数w和b
w, b = init_params()# 定义线性回归模型net和平方损失函数loss
net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss# 设置训练的轮数num_epochs和学习率lr
# 创建一个可视化工具animator,用于可视化训练过程中的损失值
num_epochs, lr = 100, 0.003
animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])# 在每个epoch中,遍历训练数据集train_iter,对每个小批量数据(X, y)进行如下操作:
for epoch in range(num_epochs):for X, y in train_iter:# 计算模型的输出net(X),并计算损失函数loss(net(X), y)# 加上L2范数惩罚项lambd * l2_penalty(w),其中l2_penalty(w)为权重w的L2范数# 对损失函数进行反向传播,并使用SGD来更新模型参数w和bl = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)# 每5个epoch,计算训练集和测试集上的损失值,并使用animator将损失值可视化if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))# 训练结束后,输出模型参数w的L2范数
print('w的L2范数是:', torch.norm(w).item())

 2.5训练结果展示

        在这段代码中,lambd是一个超参数,表示L2正则化的强度。在每个小批量数据的损失函数中,会加上L2范数惩罚项,以控制模型的复杂度和防止过拟合。L2正则化的强度由超参数lambd控制,lambd越大,模型的复杂度就越小,对训练数据的拟合程度就越差,但是可以更好地控制过拟合。反之,lambd越小,模型的复杂度就越大,对训练数据的拟合程度就越好,但是可能会过拟合。在模型训练过程中,我们通常会使用交叉验证等技术来选择最优的超参数lambd。

2.5.1忽略正则化直接训练

        其中用lambd = 0禁用权重衰减后运行这个代码。 注意,虽然训练误差有了减少,但测试误差没有减少, 这意味着出现了严重的过拟合。

 2.5.2使用权重衰减

        下面,我们使用权重衰减来运行代码。 注意,在这里训练误差增大,但测试误差减小。 得到预期效果。

 三.简洁实现代码

# 导入需要的库
import torch
from torch import nn
from d2l import torch as d2ldef train_concise(wd):# 定义训练和测试数据集的大小,输入特征的维度和批次大小n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5# 使用nn.Sequential定义了一个单层全连接神经网络net# 并将其参数使用param.data.normal_()方法初始化为随机值net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()# 使用nn.MSELoss定义平方损失函数loss# 该损失函数的reduction参数设置为'none',表示不对损失值进行降维loss = nn.MSELoss(reduction='none')# 设置训练的轮数num_epochs和学习率lr# 使用torch.optim.SGD定义一个优化器trainer,该优化器的参数包括网络的权重和偏差,以及权重衰减系数wdnum_epochs, lr = 100, 0.003trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)# 创建一个可视化工具animator,用于可视化训练过程中的损失值animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])# 在每个epoch中,遍历训练数据集train_iter,对每个小批量数据(X, y)进行如下操作:for epoch in range(num_epochs):for X, y in train_iter:# 将优化器trainer的梯度清零# 计算模型的输出net(X),并计算损失函数loss(net(X), y)# 对损失函数进行反向传播,并使用优化器trainer来更新模型参数trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()# 每5个epoch,计算训练集和测试集上的损失值,并使用animator将损失值可视化。if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)    #lambd设置为0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/15864.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

消防应急照明设置要求在炼钢车间电气室的应用

摘 要:文章以GB51309—2018《消防应急照明和疏散指示系统技术标准》为设计依据,结合某炼钢车间转炉项目的设计过程,在炼钢车间电气室的疏散照明和备用照明的设计思路、原则和方法等方面进行阐述。通过选择合理的消防应急疏散照明控制系统及灯具供配电方案…

国内办公协作系统评测:5 款软件推荐

办公协作系统是现代信息化办公的必备工具之一,对于企业来说,选择一款好用的办公协作系统非常重要。然而,在众多的办公协作系统中,哪个好用是一个让人头痛的问题。总体而言,国内的办公协作系统已经相当成熟和完善&#…

golang interface类型的nil

golang中interface变量,底层两个对象来存,一个是type、一个是value,只有type、value都为nil时,interface变量才是nil package mainimport ("fmt""reflect" )type People interface {Show() }type Student str…

ALLEGRO之Logic

本文主要讲述ALLEGRO的Logic菜单。 (1)Net Logic:暂不清楚; (2)Net Schedule:暂不清楚; (3)AssignDifferential Pair:暂不清楚; &a…

AIGC(Artificial Intelligence Generated Content)和 Web3对比,未来发展

一、AIGC(Artificial Intelligence Generated Content)行业 历史背景 AIGC(Artificial Intelligence Generated Content)是指利用人工智能技术生成的内容。随着人工智能技术的不断发展,AIGC 行业逐渐兴起。早期的 AIG…

时间计算:时间戳加减指定的分钟数__Niyyy_记

在开发中多多少少会遇到时间的计算,以下只是一个简单的例子。 将时间戳加减指定的分钟数,并将结果转换为年月日时分秒格式: function addMinutes(timestamp, minutes) {var date new Date(timestamp);date.setTime(date.getTime() minute…

layui框架学习(35:数据表格_列参数设置)

Layui中的table数据表格模块支持对表格及列进行基础参数设置以提高数据的可视化及可操作性,本文学习并记录与列相关的主要基础参数的用法及效果。   基础参数field设置待显示到列中的数据的字段名,主要针对数据表格url属性中返回的数据集合或data属性设…

Linux-MySQL安装

配置: 1.VMware workstation pro 2.MySQL 3.centOS 5.7版本 1.配置yum仓库 #更新密钥 rpm --import https://repo.mysql.com/RPM-GPG-KEY-mysql-2022#安装 rpm --Uvh https://repo.mysql.com//mysql57-community-release-el7-7.noarch.rpm 2.安装mysql yum …

《向量数据库指南》——Milvus Cloud 2.3 和 2.4 版本的重要变化

Milvus Cloud2.3 和 2.4 版本的重要变化。 首先是 Milvus Cloud2.3 将支持 Json 数据类型,在此基础上亦会支持 Schemaless。此前,用户在使用 Milvus Cloud的过程中会先定一个静态 Schema,此时,如果在实际业务层面如果多了几个 feature 或者 Metadata,就意味着数据需要重新…

echarts柱状图横坐标文字过长的解决办法

背景:echarts图中横坐标显示的文字过长,导致字都堆积在一块如下图所示 解决办法 一:可以尝试修改‘axisLabel’的‘rotate’和‘interval’参数,‘rotate’参数可以设置标签的旋转角度,可以避免标签之间的重叠&#x…

uni-app之微信小程序实现‘下载+保存至本地+预览’功能

目录 一、H5如何实现下载功能 二、微信小程序实现下载资源功能方面与H5有很大的不同 三、 微信小程序实现文件(doc,pdf等格式,非图片)下载(下载->保存->预览)功能 四、图片预览、保存、转发、收藏&#xff1…

windowSoftInputMode设置stateHidden,DIALOG dismiss后,键盘再次显示

windowSoftInputMode设置stateHidden,DIALOG dismiss后,键盘再次显示 解决1 把windowSoftInputMode中的stateHidden属性去掉 android:windowSoftInputMode"adjustPan"解决2 在//在super.dismiss();前添加键盘隐藏方法,避免wind…

Stephen Wolfram:神经网络

Neural Nets 神经网络 OK, so how do our typical models for tasks like image recognition actually work? The most popular—and successful—current approach uses neural nets. Invented—in a form remarkably close to their use today—in the 1940s, neural nets …

SpringBoot 整合 MongoDB 连接 阿里云MongoDB

注:spring-boot-starter-data-mongodb 2.7.5;jdk 1.8 阿里云MongoDB是副本集实例的 在网上查找了一番,大多数都是教连接本地mongodb或者linux上的mongodb 阿里云上有java版连接教程,但它不是SpringBoot方法配置的,是手…

linux安装oracle

oracle安装 基于linux系统安装 Linux安装oracle12C Centos7.6 内存8GB 硬盘:50GB 可视化图形界面 yum groupinstall "GNOME Desktop" -y 可视化后续安装命令 1、软件环境包安装 yum -y install binutils compat-libcap1 compat-libstdc-33 gcc-c glib…

django自定义app,创建子应用

1.工程里创建apps包 ; 2.创建子应用,pycharm terminal 运行:python ./nanage.py startapp app名称; 3.子应用移动到apps包里; 4.settings.py里设置INSTALLED_APPS如“apps.users”,该名字跟子应用apps.py文…

红宝石阅读笔记

第七章 迭代器与生成器 7.3 生成器 乍一看理解,仔细想没理解,然后自己让n2,还原nTimes,等价于 function* nTimes() {if (true) {yield* (function* A() {if (true) {yield* (function* B() { })();yield 0;}})();yield 1;} } 最…

NO1.使用命令行创建Maven工程

①在工作空间目录下打开命令窗口 ②使用命令行生成Maven工程 mvn archetype:generate 运行 MVN 原型:生成命令,下面根据提示操作 选择一个数字或应用过滤器(格式:[groupId:]artifactId,区分大小写包含)&a…

Jquery笔记

DOM对象通过jquery获取 所有的代码都是基于引入jquery.js文件 var mydiv $(#div);//直接获取到DOM对象元素id var mydiv$(.div);//通过class获取DOM对象,如果有同名class只会获取第一个 var mysapn$(span);//通过元素的标签名获取DOM对象 var divarr$(…

Spring源码:Spring运行环境Environment

Spring运行环境 Spring在创建容器时,会创建Environment环境对象,用于保存spring应用程序的运行环境相关的信息。在创建环境时,需要创建属性源属性解析器,会解析属性值中的占位符,并进行替换。 创建环境时&#xff0c…