【机器学习】036_权重衰退

一、范数

· 定义:向量的范数表示一个向量有多大(分量的大小)

L1范数:

        · 即向量元素绝对值之和,用符号 ‖ v ‖ 1 表示。

        · 公式:\left \| x \right \|_1 = \sum_{n}^{i=1}|x_i|

L2范数:

        · 即向量的模,向量各元素绝对值的平方之和再开根号,用符号 ‖ v ‖ 2 表示。

        · 公式:\left \| x \right \|_2=\sqrt{\sum_{n}^{i=1}x_i^2}

Lp范数:

        · 即向量范数的一般形式,各元素绝对值的p次幂之和再开p次根号,用符号 ‖ v ‖ p 表示。

        · 公式:\left \| x \right \|_p = (\sqrt[p]{\sum_{n}^{i=1}|x|^p})

二、权重衰减(L2正则化)

模型(函数)复杂度的度量:

· 一般通过线性函数 f(x) = w^Tx 中的权重向量的某个范数(如 \left \| w \right \|^2)来度量其复杂度

要想避免模型的过拟合,就要控制模型容量,使模型的权重向量尽可能小

· 通过限制参数值的选择范围来控制模型容量

衰减方法:

借助损失函数,将权重范数作为惩罚项添加到最小化损失中;使得损失函数的作用变为“最小化预测损失和惩罚项之和”。

损失函数公式如下:

J(w,b)=L(w,b)+\frac{\lambda }{2}\left \| w \right \|^2

· 其中,L(w,b) 是模型原本的损失函数,\frac{\lambda }{2}\left \| w \right \|^2 是新添加的惩罚项。

· 正则化常数 \lambda 用来描绘这种权衡,其为一个非负超参数。

· \lambda 的值越大,表示对 w 的约束较大;反之 \lambda 的值越小,表示对 w 的约束较小。

※为何选用平方范数而不是标准范数:

        · 便于计算。平方范数可以去掉平方根使得导数更容易计算,利于反向传播过程。

        · 使用L2范数是因为它会对权重向量的大分量施加巨大的惩罚,使各权重均匀分布。

        · L1范数惩罚会导致权重集中在某一小部分特征上,其它权重被清除为0(特征选择)。

使用该损失函数,就可以使梯度下降的优化算法在训练的每一步都衰减权重,避免过拟合发生。

如上图所示,现在模型的损失函数同时受两项影响,一是误差项,二是惩罚项。

        现在在等高线图上,梯度下降最终收敛的位置不再是某一个项所造成的最低点,因为在这时,可能误差项达到最小了,但是惩罚项很大,使得惩罚项拉着损失函数再向另一个方向移动。

        只有当达到了两个项共同作用下的一个平衡点时,损失函数才具有最小值,这个时候的模型往往复杂度也降低了,虽然有可能造成训练损失增大,但是测试损失会减小。

三、代码实现权重衰减

从零实现代码如下:

import matplotlib
import torch
from torch import nn
from d2l import torch as d2l# 训练数据集、测试数据集、输入值、训练批次
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
# 初始化w和b的真实值
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# 拿到训练数据
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)# 初始化模型参数w和b
def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]
# 定义L2范数惩罚项
def l2_penalty(w):return torch.sum(w.pow(2)) / 2
# 实现训练代码,读入参数为兰姆达(正则化参数)
def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())
# 使用权重进行训练
train(lambd=3)

简洁实现代码如下:

import torch
from torch import nn
from d2l import torch as d2l# 训练数据集、测试数据集、输入值、训练批次
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
# 初始化w和b的真实值
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# 拿到训练数据
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())train_concise(3)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/152595.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

适合您的智能手机的 7 款优秀手机数据恢复软件分享

如今,我们做什么都用手机;从拍照到录音,甚至作为 MP3 播放器,我们已经对手机变得非常依恋。这导致我们在手机上留下了很多珍贵的回忆。 不幸的是,我们有可能会丢失手机上的部分甚至全部数据。幸运的是,这不…

1. hadoop环境准备

环境准备 准备三台虚拟机&#xff0c;配置最好是 2C 4G 以上 本文准备三台机器的内网ip分别为 172.17.0.10 172.17.0.11 172.17.0.12本机配置/etc/hosts cat >> /etc/hosts<<EOF 172.17.0.10 hadoop01 172.17.0.11 hadoop02 172.17.0.12 hadoop03 EOF本机设置与…

队列的实现和OJ练习

目录 概念 队列的实现 利用结构体存放队列结构 为什么单链表不使用这种方法&#xff1f; 初始化队列 小提示&#xff1a; 队尾入队列 队头出队列 获取队头元素 获取队尾元素 获取队列中有效元素个数 检测队列是否为空 销毁队列 最终代码 循环队列 队列的OJ题 …

MobaXterm如何连接CentOS7的Linux虚拟机?Redis可视化客户端工具如何连接Linux版Redis?

一、打开Lunix虚拟机,进入虚拟机中,在终端中输入ifconfig,得到以下信息&#xff0c;红框中为ip地址 二、打开MobaXterm&#xff0c;点击session 选择SSH&#xff0c;在Remote host中输入linux得到的IP地址&#xff0c;Specify username中可起一个任意的连接名称。 输入密码 四、…

【洛谷 P3743】kotori的设备 题解(二分答案+递归)

kotori的设备 题目背景 kotori 有 n n n 个可同时使用的设备。 题目描述 第 i i i 个设备每秒消耗 a i a_i ai​ 个单位能量。能量的使用是连续的&#xff0c;也就是说能量不是某时刻突然消耗的&#xff0c;而是匀速消耗。也就是说&#xff0c;对于任意实数&#xff0c;…

60 权限提升-MYMSORA等SQL数据库提权

目录 数据库应用提权在权限提升中的意义WEB或本地环境如何探针数据库应用数据库提权权限用户密码收集等方法目前数据库提权对应的技术及方法等 演示案例Mysql数据库提权演示-脚本&MSF1.UDF提权知识点: (基于MYSQL调用命令执行函数&#xff09;读取数据库存储或备份文件 (了…

GaussDB新特性Ustore存储引擎介绍

1、 Ustore和Astore存储引擎介绍 Ustore存储引擎&#xff0c;又名In-place Update存储引擎&#xff08;原地更新&#xff09;&#xff0c;是openGauss 内核新增的一种存储模式。此前的版本使用的行存储引擎是Append Update&#xff08;追加更新&#xff09;模式。相比于Append…

在网络攻击之前、期间和之后应采取的步骤

在当今复杂的威胁形势下&#xff0c;网络攻击是不可避免的。 恶意行为者变得越来越复杂&#xff0c;出于经济动机的攻击变得越来越普遍&#xff0c;并且每天都会发现新的恶意软件系列。 这使得对于各种规模和跨行业的组织来说&#xff0c;制定适当的攻击计划变得更加重要。 …

【Linux】进程间通信 -- 管道

对于进程间通信的理解 首先&#xff0c;进程间通信的本质是&#xff0c;让不同的进程看到同一份资源&#xff08;这份资源不能隶属于任何一个进程&#xff0c;即应该是共享的&#xff09;。而进程间通信的目的是为了实现多进程之间的协同。 但由于进程运行具有独立性&#xff…

密码加密解密之路

1.背景 做数据采集&#xff0c;客户需要把他们那边的数据库连接信息存到我们系统里&#xff0c;那我们系统就要尽可能的保证这部分数据安全&#xff0c;不被盗。 2.我的思路 1.需要加密的地方有两处&#xff0c;一个是新增的时候前端传给后端的时候&#xff0c;一个是存到数…

异步爬取+多线程+redis构建一个运转丝滑且免费http-ip代理池 (三)

内容提要: 如果说,爬取网页数据的时候,我们使用了异步,那么将数据放入redis里面,其实也需要进行异步;当然,如果使用多线程或者redis线程池技术也是可以的,但那会造成冗余; 因此,在测试完多线程redis搭配异步爬虫的时候,我发现效率直接在redis这里被无限拉低下来! 因此: 最终的r…

从0开始学习JavaScript--JavaScript中的集合类

JavaScript中的集合类是处理数据的关键&#xff0c;涵盖了数组、Set、Map等多种数据结构。本文将深入研究这些集合类的创建、操作&#xff0c;以及实际应用场景&#xff0c;并通过丰富的示例代码&#xff0c;帮助大家更全面地了解和应用这些概念。 数组&#xff08;Array&…

SystemVerilog学习 (11)——覆盖率

目录 一、概述 二、覆盖率的种类 1、概述 2、分类 三、代码覆盖率 四、功能覆盖率 五、从功能描述到覆盖率 一、概述 “验证如果没有量化&#xff0c;那么就意味着没有尽头。” 伴随着复杂SoC系统的验证难度系数成倍增加&#xff0c;无论是定向测试还是随机测试&#xff…

安全框架springSecurity+Jwt+Vue-1(vue环境搭建、动态路由、动态标签页)

一、安装vue环境&#xff0c;并新建Vue项目 ①&#xff1a;安装node.js 官网(https://nodejs.org/zh-cn/) 2.安装完成之后检查下版本信息&#xff1a; ②&#xff1a;创建vue项目 1.接下来&#xff0c;我们安装vue的环境 # 安装淘宝npm npm install -g cnpm --registryhttps:/…

软件测试/测试开发/人工智能丨基于Spark的分布式造数工具:加速大规模测试数据构建

随着软件开发规模的扩大&#xff0c;测试数据的构建变得越来越复杂&#xff0c;传统的造数方法难以应对大规模数据需求。本文将介绍如何使用Apache Spark构建分布式造数工具&#xff0c;以提升测试数据构建的效率和规模。 为什么选择Spark&#xff1f; 分布式计算&#xff1a;…

easyExcel注解详情

前言11个注解字段注解 类注解基础综合示例补充颜色总结 11个注解 ExcelProperty ColumnWith 列宽 ContentFontStyle 文本字体样式 ContentLoopMerge 文本合并 ContentRowHeight 文本行高度 ContentStyle 文本样式 HeadFontStyle 标题字体样式 HeadRowHeight 标题高度 HeadStyle…

Python将原始数据集和标注文件进行数据增强(随机仿射变换),并生成随机仿射变换的数据集和标注文件

Python将原始数据集和标注文件进行数据增强&#xff08;随机仿射变换&#xff09;&#xff0c;并生成随机仿射变换的数据集和标注文件 前言前提条件相关介绍实验环境生成随机仿射变换的数据集和标注文件代码实现输出结果 前言 由于本人水平有限&#xff0c;难免出现错漏&#x…

OpenCV快速入门:图像滤波与边缘检测

文章目录 前言一、噪声种类与生成1.1 椒盐噪声1.2 高斯噪声1.3 彩色噪声 二、卷积操作2.1 卷积基本原理2.2 卷积操作代码实现 三、线性滤波3.1 均值滤波均值滤波原理均值滤波公式均值滤波代码实现 3.2 方框滤波方框滤波原理方框滤波公式方框滤波代码实现 3.3 高斯滤波高斯滤波原…

redis非关系型数据库(缓存型数据库)——中间件

【重点】redis为什么这么快&#xff1f;&#xff08;应届&#xff09; ①redis是纯内存结构&#xff0c;避免磁盘I/O的耗时 ②redis核心模块是一个单进程&#xff0c;减少线程切换和回收线程资源时间 ③redis采用的是I/O的多路复用机制&#xff08;每一个执行线路可以同时完…

npm install 下载不下来依赖解决方案

背景 最近在构建 前端自动化部署 的方案中发现了一个问题&#xff0c;就是我在npm install的时候&#xff0c;有时候成功&#xff0c;有时候不成功&#xff0c;而且什么代码也没发生更改&#xff0c;报错也就是那么几个错&#xff0c;所以在此也整理了一下遇到这种情况&#xf…