深度学习 —— 个人学习笔记6(权重衰减)

声明

  本文章为个人学习使用,版面观感若有不适请谅解,文中知识仅代表个人观点,若出现错误,欢迎各位批评指正。

十三、权重衰减

  使用以下公式为例做演示:

y = 0.05 + ∑ i = 1 d 0.01 x i + ε w h e r e ε ~ N ( 0 , 0.0 1 2 ) y = 0.05 + \sum_{i=1}^{d} 0.01x_i + \varepsilon \quad where \quad \varepsilon \; ~ \; N ( 0 , 0.01^2 ) y=0.05+i=1d0.01xi+εwhereεN(0,0.012)

  • 权重衰减的实现
import torch
from torch import nn
from d2l import torch as d2l
from IPython import displaydef synthetic_data(w, b, num_examples):"""生成 y = Xw + b + 噪声。"""X = torch.normal(0, 1, (num_examples, len(w))).cuda()                    # 均值为 0,方差为 1,有 num_examples 个样本,列数为 w 长度y = torch.matmul(X, w).cuda() + b                                        # y = Xw + by += torch.normal(0, 0.01, y.shape).cuda()                               # 随机噪音return X, y.reshape((-1, 1))                                             # x,y作为列向量返回class Animator:                                                                   # 定义一个在动画中绘制数据的实用程序类 Animator"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# Add multiple data points into the figureif not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)# 通过以下两行代码实现了在PyCharm中显示动图d2l.plt.draw()d2l.plt.pause(interval=0.001)display.clear_output(wait=True)d2l.plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)).cuda() * 0.01, 0.05
train_data = synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)##############    权重衰减的实现    #############
def init_params():""" 初始化参数 """w = torch.normal(0, 1, size=(num_inputs, 1)).cuda()b = torch.zeros(1).cuda()w.requires_grad_(True)b.requires_grad_(True)return [w, b]def l2_penalty(w):""" 定义 L2 范数惩罚 """return (torch.sum(w.pow(2)) / 2).cuda()def train(lambd):flag_button = "使用"w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 150, 0.005animator = Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了 L2 范数惩罚项,、# 广播机制使 l2_penalty(w) 成为一个长度为 batch_size 的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))# print('w的L2范数是:', torch.norm(w).item())if lambd == 0:flag_button = "禁用"d2l.plt.title(f"{flag_button}权重衰减 (lambda = {lambd})\nw 的 L2 范数是:{torch.norm(w).item()}")d2l.plt.show()train(lambd=0)train(lambd=15)


  • 权重衰减的简洁实现
import torch
from torch import nn
from d2l import torch as d2l
from IPython import displaydef synthetic_data(w, b, num_examples):"""生成 y = Xw + b + 噪声。"""X = torch.normal(0, 1, (num_examples, len(w))).cuda()                    # 均值为 0,方差为 1,有 num_examples 个样本,列数为 w 长度y = torch.matmul(X, w).cuda() + b                                        # y = Xw + by += torch.normal(0, 0.01, y.shape).cuda()                               # 随机噪音return X, y.reshape((-1, 1))                                             # x,y作为列向量返回class Animator:                                                                   # 定义一个在动画中绘制数据的实用程序类 Animator"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# Add multiple data points into the figureif not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)# 通过以下两行代码实现了在PyCharm中显示动图d2l.plt.draw()d2l.plt.pause(interval=0.001)display.clear_output(wait=True)d2l.plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)).cuda() * 0.01, 0.05
train_data = synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)##############    权重衰减的简洁实现    #############def train_concise(wd):flag_button = "使用"net = nn.Sequential(nn.Linear(num_inputs, 1)).cuda()for param in net.parameters():param.data.normal_().cuda()loss = nn.MSELoss(reduction='none').cuda()num_epochs, lr = 150, 0.005# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))# print('w的L2范数:', net[0].weight.norm().item())if wd == 0:flag_button = "禁用"d2l.plt.title(f"{flag_button}权重衰减 (lambda = {wd})\nw 的 L2 范数是:{net[0].weight.norm().item()}")d2l.plt.show()train_concise(0)train_concise(-2)  



  文中部分知识参考:B 站 —— 跟李沐学AI;百度百科

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/49244.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从PyTorch官方的一篇教程说开去(4 - Q-table来源及解决问题实例)

偷个懒,代码来自比很久之前看的书,当时还在用gym,我做了微调以升级到gymnasium当前版本,确保可以正常演示。如果小伙伴或者原作者看到了麻烦提一下,我好备注一下出处。 您的进步和反馈是我最大的动力,小伙…

语音识别 语音识别项目相关笔记内容

语音识别 语音识别项目相关笔记内容 语音识别应用范畴语音识别框架语音基本操作使用scipy.io.wavfile读取wav音频文件获取采样率、长度、通道数使用numpy读取pcm格式音频文件读取wav音频文件,并绘制图像读取双声道的wav音频文件,分别绘制不同声道的波形图读取一个采样率为16k…

【Docker】Docker Desktop - WSL update failed

问题描述 Windows上安装完成docker desktop之后,第一次启动失败,提示:WSL update failed 解决方案 打开Windows PowerShell 手动执行: wsl --set-default-version 2 wsl --update

使用 vue-element-plus-admin 框架遇到的问题记录

项目打包遇到的问题: 打包语句:pnpm run build:pro 报错信息: Error: [vite]: Rollup failed to resolve import "E:/workplace_gitee/xxx/node_modules/.pnpm/element-plus2.5.5_vue3.4.15/node_modules/element-plus/es/components…

【精品资料】数据安全治理解决方案(27页PPT)

引言:数据安全治理解决方案是一个综合性的体系,旨在通过策略、技术、流程和人力的有机结合,全面提升组织的数据安全防护能力,保障数据资产的安全与合规。 方案介绍:数据安全治理解决方案是组织为确保其数据资产的安全性…

Spark内核的设计原理

导读: 本期是DataFun深入浅出Apache Spark第一期的分享,主讲老师耿嘉安开场介绍了自己的从业经历,当前就职的数新网络与Spark相关的两款产品赛博数智引擎CyberEngine和赛博数据智能平台CyberData。 本次分享题目为《Spark内核的设计原理》&…

智能化一体闸门:助力行业发展

随着科技的飞速发展,智能化技术已经渗透到各个行业和领域,其中水利行业也不例外。智能化一体闸门以其高效、智能、便捷的特点,正助力着行业发展。 一、智能化一体闸门的定义与特点 智能化一体闸门,是集成了先进传感技术、自动控制…

Transformer之Swin-Transformer结构解读

写在最前面之如何只用nn.Linear实现nn.Conv2d的功能 很多人说,Swin-Transformer就是另一种Convolution,但是解释得真就是一坨shit,这里我郑重解释一下,这是为什么? 首先,Convolution是什么? Co…

什么是离线语音识别芯片?与在线语音识别的区别

离线语音识别芯片是一种不需要联网和其他外部设备支持,‌上电即可使用的语音识别系统。‌它的应用场合相对单一,‌主要适用于智能家电、‌语音遥控器、‌智能玩具等,‌以及车载声控和一部分智能家居。‌离线语音识别芯片的特点包括小词汇量、…

Python文件写入读取,文件复制以及一维,二维,多维数据存储

基础解释 在 Python 中,文件操作的模式除了 w (只写)、 a (追加写)、 r (只读)外,还有以下几种常见模式:- r :可读可写。该文件必须已存在,写操…

分类损失函数 (一) torch.nn.CrossEntropyLoss()

1、交叉熵 是一种用于衡量两个概率分布之间的距离或相似性的度量方法。机器学习中,交叉熵常用于损失函数,用于评估模型的预测结果和实际标签的差异。公式: y:真是标签的概率分布,y:模型预测的概率分布 …

数据库中的内、外、左、右连接

常用的数据库连表形式: 内连接 :inner join 外连接 :outer join 左外连接 :left outer join 左连接 :left join 右外连接 right outer join 右连接: right join 全连接 full join 、union 一、内连接 内…

企业私有云的部署都有哪些方式?

如今常见的企业私有云的部署方式有自建私有云、托管私有云、虚拟私有云、混合云、容器化私有云、本地数据中心部署等。如今,企业私有云的部署呈多样化趋势,以用来满足各个企业的具体需求。以下是RAK部落小编为大家汇总的企业私有云常见的部署方式&#x…

LeetCode 58.最后一个单词的长度 C++

LeetCode 58.最后一个单词的长度 C 思路🤔: 先解决当最后字符为空格的情况,如果最后字符为空格下标就往后移动,直到不为空格才停止,然后用rfind查询空格找到的就是最后一个单词的起始位置,最后相减就是单词…

C++ 正则库与HTTP请求

正则表达式的概念和语法 用于描述和匹配字符串的工具,通过特定的语法规则,灵活的定义复杂字符串匹配条件 常用语法总结 基本字符匹配 a:匹配字符aabc:匹配字符串abc 元字符(特殊含义的字符) .:匹…

1Panel面板配置java运行环境及网站的详细操作教程

本篇文章主要讲解,通过1Panel面板实现java运行环境,部署网站并加载的详细教程。 日期:2024年7月21日 作者:任聪聪 独立博客:https://rccblogs.com/501.html 一、实际效果 二、详细操作 步骤一、给我的项目进行打包&am…

在jsPsych中使用Vue

jspsych 介绍 jsPsych是一个非常好用的心理学实验插件,可以用来构建心理学实验。具体的就不多介绍了,大家可以去看官网:https://www.jspsych.org/latest/ 但是大家在使用时就会发现,这个插件只能使用js绘制界面,或者…

STM32自己从零开始实操10:PCB全过程

一、PCB总体分布 分布主要参考有: 方便供电布线。方便布信号线。方便接口。人体工学。 以下只能让大家看到各个模块大致分布在板子的哪一块,只能说每个人画都有自己的理由,我的理由如下。 还有很多没有表达出来的东西,我也不知…

PingCAP 王琦智:下一代 RAG,tidb.ai 使用知识图谱增强 RAG 能力

导读 随着 ChatGPT 的流行,LLMs(大语言模型)再次进入人们的视野。然而,在处理特定领域查询时,大模型生成的内容往往存在信息滞后和准确性不足的问题。如何让 RAG 和向量搜索技术在实际应用中更好地满足企业需求&#…

昇思25天学习打卡营第14天|计算机视觉

昇思25天学习打卡营第14天 文章目录 昇思25天学习打卡营第14天FCN图像语义分割语义分割模型简介网络特点数据处理数据预处理数据加载训练集可视化 网络构建网络流程 训练准备导入VGG-16部分预训练权重损失函数自定义评价指标 Metrics 模型训练模型评估模型推理总结引用 打卡记录…