李沐深度学习记录4:12.权重衰减/L2正则化

权重衰减从零开始实现

#高维线性回归
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l#整个流程是,1.生成标准数据集,包括训练数据和测试数据
#          2.定义线性模型训练
#           模型初始化(函数)、包含惩罚项的损失(函数)
#           定义epochs进行训练,每训练5轮评估一次模型在训练集和测试集的损失,画图显示
#           训练结束后分别查看并比较是否添加范数惩罚项损失对应的训练结果w的L2范数
#生成数据集
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5  #训练数据样本数20,测试样本数100,数据维度200,批量大小5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05  #生成w矩阵(200,1),w值0.01,偏置b为0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train) #生成训练数据集X(20,200),y(20,1),y=Xw+b+噪声,train_data接收返回的X,y
train_iter = d2l.load_array(train_data, batch_size)  #传入数据集和批量大小,构造训练数据迭代器
test_data = d2l.synthetic_data(true_w, true_b, n_test) #生成测试数据集
test_iter = d2l.load_array(test_data, batch_size, is_train=False)  #构造测试数据迭代器#初始化模型参数
def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]#定义L2范数惩罚项
def l2_penalty(w):return torch.sum(w.pow(2)) / 2  #L2范数公式需要开平方根,但这里L2范数惩罚项是L2范数的平方,所以不需要开平方根了#训练代码
def train(lambd):  #输入λ超参数w, b = init_params()  #初始化模型参数net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss  #net线性模型torch.matmul(X, w) + b;loss是均方误差num_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):  #进行多次迭代训练for X, y in train_iter:  #每个epoch,取训练数据# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)  #loss计算加上了λ×范数惩罚项l.sum().backward()  #这里计算损失和,下面参数更新时会对梯度求平均再更新参数d2l.sgd([w, b], lr, batch_size)  #进行参数更新操作if (epoch + 1) % 5 == 0:  #每5次epoch训练,评估一次模型的训练损失和测试损失animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())  #训练结束后,计算w的L2范数(没有平方)
#λ为0,无正则化项,训练
train(lambd=0)
d2l.plt.show()

在这里插入图片描述

#λ为10,有正则化项,训练
train(lambd=5)
d2l.plt.show()

在这里插入图片描述

权重衰减的简洁实现

#权重衰减的简洁实现
def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))   #定义模型for param in net.parameters():   #初始化参数param.data.normal_()loss = nn.MSELoss(reduction='none')  #计算loss,这里不包含正则项num_epochs, lr = 100, 0.003# 偏置参数没有衰减#在参数优化部分,计算梯度时加入了权重衰减#所以是计算loss时没计算正则项,只是在计算梯度时加入了权重衰减吗?trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):   #训练100轮for X, y in train_iter:  #对于每轮,取数据训练trainer.zero_grad()   #梯度清零l = loss(net(X), y)  #计算lossl.mean().backward() #反向传播trainer.step()  #更新梯度if (epoch + 1) % 5 == 0:   #每5轮评估一次模型在测试集和训练集的损失animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
#没有进行权重衰减
train_concise(0)

在这里插入图片描述

#进行权重衰减
train_concise(5)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/100152.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

springboot 捕获特点异常信息并处理

前端获取效果图 springboot 捕获特点异常信息并处理 import com.one.utils.JSONResult; //JSONResult定义处理结果对象 import org.springframework.web.bind.annotation.ExceptionHandler

35.树与二叉树练习(1)(王道第5章综合练习)

【所用的树,队列,栈的基本操作详见上一节代码】 试题1(王道5.3.3节第3题): 编写后序遍历二叉树的非递归算法。 参考:34.二叉链树的C语言实现_北京地铁1号线的博客-CSDN博客https://blog.csdn.net/qq_547…

3D 生成重建005-NeRF席卷3D的表达形式

3D生成重建005-NeRF席卷3D的表达形式 文章目录 0 论文工作1 论文方法1.1 体渲染1.2 离散积分1.3位置编码1.4分层采样1.5 影响 2 效果 0 论文工作 NeRF(神经辐射场技术)最早2020年提出用于新视图合成任务,并在这个领域取得了优秀的效果。如下图所示,受到…

Springcloud笔记(2)-Eureka服务注册

Eureka服务注册 服务注册,发现。 在Spring Cloud框架中,Eureka的核心作用是服务的注册和发现,并实现服务治理。 Eureka包含两个组件:Eureka Server和Eureka Client。 Eureka Server提供服务注册服务,各个节点启动后…

mysql面试题31:一条SQL语句在MySQL中如何执行的

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:一条SQL语句在MySQL中如何执行的 以下是一条SQL语句在MySQL中的详细执行步骤: 语法分析:MySQL首先对SQL语句进行语法分析,检查SQL语句是否符合…

ARM-流水灯

.text .global _start _start: 1、设置GPIOE寄存器的时钟使能 RCC_MP_AHB$ENSETR[4]->1 0x50000a28LDR R0,0X50000A28 LDR R1,[R0] 从R0起始地址的4字节数据取出放在R1 ORR R1,R1,#(0X3<<4) 第4位设置为1 STR R1,[R0] 写回2、设置PE10、PE8、PF10管脚为输出模式 …

elasticSearch7.9数据占用磁盘存储空间情况

最近&#xff0c;在VMware Workstation虚拟机上安装了es7.9&#xff0c;单节点的es&#xff0c;不是集群&#xff0c;然后建了一个索引&#xff08;包含3个分片和一个副本&#xff09;&#xff0c;插入了500万条数据&#xff0c;占据磁盘空间17G。如下图&#xff1a; 索引的字…

在两个有序数组中找整体第k小的数

一、题目 给定两个已经排序的数组&#xff08;假设按照升序排列&#xff09;&#xff0c;然后找出第K小的数。比如数组A {1&#xff0c; 8&#xff0c; 10&#xff0c; 20}&#xff0c; B {5&#xff0c; 9&#xff0c; 22&#xff0c; 110}&#xff0c; 第 3 小的数是 8.。…

基于Springboot实现点餐平台网站管理系统项目【项目源码+论文说明】分享

基于Springboot实现点餐平台网站管理系统演示 摘要 随着现在网络的快速发展&#xff0c;网上管理系统也逐渐快速发展起来&#xff0c;网上管理模式很快融入到了许多商家的之中&#xff0c;随之就产生了“点餐平台网站”&#xff0c;这样就让点餐平台网站更加方便简单。 对于本…

解决远程git服务器路径改变导致本地无法push的问题

解决远程git服务器路径改变导致本地无法push的问题 &#xff08;1&#xff09;第一步&#xff1a;查看git配置 git config -l&#xff08;2&#xff09;第二步&#xff1a;删除远程git地址 git remote remove origin&#xff08;3&#xff09;第三步&#xff1a;再次查看git配…

ELK 处理 SpringCloud 日志

在排查线上异常的过程中&#xff0c;查询日志总是必不可缺的一部分。现今大多采用的微服务架构&#xff0c;日志被分散在不同的机器上&#xff0c;使得日志的查询变得异常困难。工欲善其事&#xff0c;必先利其器。如果此时有一个统一的实时日志分析平台&#xff0c;那可谓是雪…

Mall脚手架总结(三) —— MongoDB存储浏览数据

前言 通过Elasticsearch整合章节的学习&#xff0c;我们了解SpringData框架以及相应的衍生查询的方式操作数据读写的语法。MongoDB的相关操作也同样是借助Spring Data框架&#xff0c;因此这篇文章的内容比较简单&#xff0c;重点还是弄清楚MongoDB的使用场景以及如何通过Sprin…

【计算机网络】UDP协议编写群聊天室----附代码

UDP构建服务器 x 预备知识 认识UDP协议 此处我们也是对UDP(User Datagram Protocol 用户数据报协议)有一个直观的认识; 后面再详细讨论. 传输层协议无连接不可靠传输面向数据报 网络字节序 我们已经知道,内存中的多字节数据相对于内存地址有大端和小端之分, 磁盘文件中的…

扬尘监测:智能化解决方案让生活更美好

随着工业化和城市化的快速发展&#xff0c;扬尘污染问题越来越受到人们的关注。扬尘不仅影响城市环境&#xff0c;还会对人们的健康造成威胁。为了解决这一问题&#xff0c;扬尘监测成为了一个重要的手段。本文将介绍扬尘监测的现状、重要性以及智能化解决方案&#xff0c;帮助…

Python库学习(九):Numpy[续篇三]:数组运算

NumPy是用于数值计算的强大工具&#xff0c;提供了许多数组运算和数学函数&#xff0c;允许你执行各种操作&#xff0c;包括基本运算、统计计算、线性代数、元素级操作等 1.基本运算 1.1 四则运算 NumPy数组支持基本的四则运算&#xff08;加法、减法、乘法和除法&#xff09;…

Flink学习笔记(一):Flink重要概念和原理

文章目录 1、Flink 介绍2、Flink 概述3、Flink 组件介绍3.1、Deploy 物理部署层3.2、Runtime 核心层3.3、API&Libraries 层3.4、扩展库 4、Flink 四大基石4.1、Checkpoint4.2、State4.3、Time4.4、Window 5、Flink 的应用场景5.1、Event-driven Applications【事件驱动】5.…

初识 C语言文件操作

目录 前言&#xff1a; 为什么我们要使用文件&#xff1f; 什么是文件&#xff1f; 程序文件&#xff1a; 数据文件&#xff1a; 文件名&#xff1a; 文件的打开和关闭 文件指针&#xff1a; 流程&#xff1a; 文件路径&#xff1a; 文件的顺序读写&#xff1a; …

【Java 进阶篇】CSS语法格式详解

在前端开发中&#xff0c;CSS&#xff08;层叠样式表&#xff09;用于控制网页的样式和布局。了解CSS的语法格式是学习如何设计和美化网页的关键。本文将深入解释CSS的语法格式&#xff0c;包括选择器、属性和值等基本概念&#xff0c;同时提供示例代码以帮助初学者更好地理解。…

首批成员单位 | 聚铭网络受邀加入中国人工智能产业发展联盟数据委员会

近日&#xff0c;中国人工智能产业发展联盟(简称AIIA&#xff09;成立“数据委员会”&#xff0c;**聚铭网络受邀加入&#xff0c;成为首批成员单位&#xff0c;**与其他成员单位协同推动人工智能产业发展。 中国人工智能产业发展联盟是在国家发展和改革委员会、科学技术部、工…

Python Opencv实践 - 车辆识别(1)读取视频,移除背景,做预处理

示例中的图像的腐蚀、膨胀和闭运算等需要根据具体视频进行实验得到最佳效果。代码仅供参考。 import cv2 as cv import numpy as np#读取视频文件 video cv.VideoCapture("../../SampleVideos/Traffic.mp4") FPS 10 DELAY int(1000 / FPS) kernel cv.getStructu…