动手学深度学习(Pytorch版)代码实践 -深度学习基础-10权重衰减

10权重衰减

"""
正则化是处理过拟合的常用方法:在训练集的损失函数中加入惩罚项,以降低学习到的模型的复杂度。
保持模型简单的一个特别的选择是使用L2惩罚的权重衰减。这会导致学习算法更新步骤中的权重衰减。
"""import torch
from torch import nn
from d2l import torch as d2l
import liliPytorch as lpn_train, n_test, num_input, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_input,1)) * 0.01, 0.05train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)#初始化模型参数
def init_params():w = torch.normal(0,1,size=(num_input,1), requires_grad=True)b = torch.zeros(1,requires_grad=True)return [w,b]#定义L2范数惩罚
def l2_penalty(w):return torch.sum(w.pow(2)) / 2def l1_penalty(w):return torch.sum(torch.abs(w))# 定义模型
def linreg(X, w, b):"""线性回归模型"""return torch.matmul(X, w) + b# 定义损失函数
def squared_loss(y_hat, y):"""均方损失函数"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2# 定义优化函数
def sgd(params, lr, batch_size):"""小批量随机梯度下降"""# 更新参数时不需要计算梯度with torch.no_grad():for param in params:param -= lr * param.grad / batch_size  # 参数更新param.grad.zero_()  # 梯度清零#定义训练代码实现
def train(lambd):w, b = init_params()net, loss = lambda X: linreg(X, w, b), squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())#忽略正则化直接训练¶
# train(lambd=0)
#w的L2范数是: 14.630496978759766# 使用权重衰减
# train(lambd=3)
# d2l.plt.show() #权重衰减-简洁实现
def train_concise(wd):net = nn.Sequential(nn.Linear(num_input, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = lp.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())train_concise(0)
d2l.plt.show() 
# w的L2范数是: 0.33992505073547363

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/33005.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

html--好看的手机充值单页

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>线上充值-首页</title><meta content"widthdevice-width,initial-scale1.0,maximum-scale1.0,user-scalable0" name"viewport&…

maya模型仓鼠制作

小仓鼠建模&#xff08;6&#xff09;_哔哩哔哩_bilibili 20240623作品---个人评价&#xff1a;第一次做的&#xff0c;虽然有点丑&#xff0c;但是还能看&#xff01;希望后面有些进步

论文阅读--Efficient Hybrid Zoom using Camera Fusion on Mobile Phones

这是谷歌影像团队 2023 年发表在 Siggraph Asia 上的一篇文章&#xff0c;主要介绍的是利用多摄融合的思路进行变焦。 单反相机因为卓越的硬件性能&#xff0c;可以非常方便的实现光学变焦。不过目前的智能手机&#xff0c;受制于物理空间的限制&#xff0c;还不能做到像单反一…

线程封装,互斥

文章目录 线程封装线程互斥加锁、解锁认识接口解决问题理解锁 线程封装 C/C代码混编引起的问题 此处pthread_create函数要求传入参数为void * func(void * )类型,按理来说ThreadRoutine满足,但是 这是在内类完成封装,所以ThreadRoutine函数实际是两个参数,第一个参数Thread* …

【建设方案】大数据湖一体化建设方案(ppt原件)

1、背景&#xff1a;大数据湖的发展背景与建设理念 2、体系&#xff1a;大数据湖体系规划与建设思路 3、生态圈&#xff1a;探索新兴业务入湖建设模式 4、共享&#xff1a;大数据湖统一访问共享规划 5、运营&#xff1a;大数据湖一体化运营管理建设 &#xff08;本方案及更多方…

Kafka~基础原理与架构了解

Kafka是什么 Kafka我们了解一直认为是一个消息队列&#xff0c;但是其设计初&#xff0c;是一个&#xff1a;分布式流式处理平台。流平台具有三个关键功能&#xff1a; 消息队列&#xff1a;发布和订阅消息流&#xff0c;这个功能类似于消息队列&#xff0c;这也是 Kafka 也被…

Comfyui-ChatTTS-OpenVoice 为ComfyUI添加语音合成、语音克隆功能

‍‍ 生成多人播客&#xff1a; Comfyui-ChatTTS是一个开源的GitHub项目&#xff0c;致力于为ComfyUI添加语音合成功能。该项目提供了一系列功能强大的节点和模型&#xff0c;支持用户创建和复用音色&#xff0c;支持多人对话模式的生成&#xff0c;并提供了导出音频字幕文件的…

“Jedis与Redis整合指南:实现高效的Java应用与Redis交互“

目录 #. 概念 1. 导入jedis依赖 2. 写一个类&#xff08;ping通redis&#xff09; 3. String字符串使用 3.1 set&#xff0c;get方法使用&#xff08;设值&#xff0c;取值&#xff09; 3.2 mset&#xff0c;mget方法使用&#xff08;设置多个值&#xff0c;取多个值&…

怎么在vscode里运行一个cpp文件

文章目录 1.需要下载g编译器&#xff0c;或clang&#xff08;快&#xff0c;但是优化效果没有g好&#xff09;2.新建文件夹和cpp文件&#xff08;tasks.json&#xff09;3.怎么在vscode里调试(launch.json)4.怎么设置让中断输出的字符是中文&#xff01;5.飞机大战 1.需要下载g…

iis下asp.netcore后台定时任务会取消

问题 使用BackgroundService或者IHostedService做后台定时任务的时候部署到iis会出现不定时定时任务取消的问题&#xff0c;原因是iis会定时的关闭网站 解决 应用程序池修改为AlwaysRunning 修改web.config <?xml version"1.0" encoding"utf-8"?…

Android studio登录Google账号超时的解决方法

确保自己已经打开了代理&#xff08;科学上网&#xff09;在设置-外观与行为-系统设置-HTTP代理 中打开“自动检测代理设置”&#xff1a; 再次重新尝试登录Google账号&#xff0c;登陆成功&#xff01; 学术会议征稿 想要了解国内主办的覆盖学科最全最广的学术会议&#xff0c…

代码-功能-python-爬取博客网标题作者发布时间

环境&#xff1a; python 3.8 代码&#xff1a; # 爬取博客园内容 # https://www.cnblogs.com/import re from lxml import etree import requests import json import threading from queue import Queue import pymysql import timeclass HeiMa:def __init__(self):# 请…

k8s 部署 ruoyi 前后端分离项目

本文视频版 https://www.bilibili.com/video/BV17ugkePEeN 参考 https://blog.csdn.net/qq_50247813/article/details/136934090 https://gitee.com/nasaa/RuoYi-Vue-cloud https://www.itsgeekhead.com/tuts/kubernetes-129-ubuntu-22-04-3/ https://kubernetes.io/docs/se…

【漏洞复现】畅捷通T+ keyEdit.aspx SQL漏洞

0x01 产品简介 畅捷通 T 是一款灵动&#xff0c;智慧&#xff0c;时尚的基于互联网时代开发的管理软件&#xff0c;主要针对中小型工贸与商贸企业&#xff0c;尤其适合有异地多组织机构(多工厂&#xff0c;多仓库&#xff0c;多办事处&#xff0c;多经销商)的企业&#xff0c;…

用户态协议栈06-TCP三次握手

最近由于准备软件工程师职称考试&#xff0c;然后考完之后不小心生病了&#xff0c;都没写过DPDK的博客了。今天开始在上次架构优化的基础上增加TCP的协议栈流程。 什么是TCP 百度百科&#xff1a;TCP即传输控制协议&#xff08;Transmission Control Protocol&#xff09;是…

LabVIEW程序退出后线程仍在运行问题

LabVIEW程序退出后&#xff0c;线程仍在运行的问题可能源于资源管理不当、未正确终止循环、事件结构未处理、并发编程错误以及外部库调用未结束等方面。本文将从这些角度详细分析&#xff0c;探讨可能的原因和解决方案&#xff0c;并提供预防措施&#xff0c;帮助开发者避免类似…

将知乎专栏文章转换为 Markdown 文件保存到本地

一、参考内容 参考知乎文章代码 | 将知乎专栏文章转换为 Markdown 文件保存到本地&#xff0c;利用代码为GitHub&#xff1a;https://github.com/chenluda/zhihu-download。 二、步骤 1.首先安装包flask、flask-cors、markdownify 2. 运行app.py 3.在浏览器中打开链接&…

已解决javax.management.BadStringOperationException异常的正确解决方法,亲测有效!!!

已解决javax.management.BadStringOperationException异常的正确解决方法&#xff0c;亲测有效&#xff01;&#xff01;&#xff01; 目录 问题分析 出现问题的场景 报错原因 解决思路 解决方法 分析错误日志 检查字符串值合法性 确认字符串格式 优化代码逻辑 增加…

Trimesh介绍及基本使用

Trimesh介绍及基本使用 Trimesh是一个纯Python 工具库&#xff08;支持3.7版本以上&#xff09;&#xff0c;用于加载和使用三角形Mesh网格&#xff0c;支持多种常见的三维数据格式&#xff0c;如二进制/文本格式的STL、Wavefront OBJ、二进制/文本格式的PLY、GLTF/GLB 2.0、3…

Leetcode 2713. 矩阵中严格递增的单元格数(DFS DP)

Leetcode 2713. 矩阵中严格递增的单元格数 DFS 容易想到&#xff0c;枚举每个点作为起点&#xff0c;向同行同列的可跳跃点dfs&#xff0c;维护全局变量记录可达的最远距离 超时&#xff0c;通过样例193 / 566 class Solution {int res 0;public void dfs(int[][] mat, in…