pytorch-解决过拟合之regularization

目录

  • 1.解决过拟合的方法
  • 2. regularization
  • 2. regularization分类
  • 3. pytorch L2 regularization
  • 4. 自实现L1 regularization
  • 5. 完整代码

1.解决过拟合的方法

  • 更多的数据
  • 降低模型复杂度
    regularization
  • Dropout
  • 数据处理
  • 早停止

2. regularization

以二分类的cross entropy为例,就是在其公式后增加一项参数一范数累加和,并乘以一个超参数用来权衡参数配比。
模型优化是要使得前部分loss尽量小,那么同时也要后半部分范数接近于0,但是为了保持模型的表达能力还要保留比如 β 0 β_{0} β0+ β 1 β_{1} β1x+ β 2 β_{2} β2 x 2 x^2 x2,那么可能使得 β 0 β_{0} β0 β 2 β_{2} β2 β 3 β_{3} β3 = 0.01 而 β 4 β_{4} β4- β n β_{n} βn很小很小,比如0.0001,这样就使得比如f(x)= x 7 x^7 x7,退化为 β 0 β_{0} β0+ β 1 β_{1} β1x+ β 2 β_{2} β2 x 2 x^2 x2,这样即保证了模型的表达能力,也降低了模型的复杂度。从而防止过拟合。
在这里插入图片描述
下图是未增加regularization和增加了regularization的区别展示图
可以看出未增加regularization的时候,模型可以将噪点也拟合进去了,因此图形很不平滑,发生了过拟合。而增加regularization之后,图形变得很平滑。
在这里插入图片描述

2. regularization分类

regularization有两类分别是L1和L2,L1增加的是参数的一范数,L2增加的二范数
在这里插入图片描述
最常用的是L2regularization

3. pytorch L2 regularization

pytorch中L2 regularization叫weight_decay
在这里插入图片描述

4. 自实现L1 regularization

在这里插入图片描述

5. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsfrom visdom import Visdombatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),# transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),# transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)
criteon = nn.CrossEntropyLoss().to(device)viz = Visdom()viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
viz.line([[0.0, 0.0]], [0.], win='test', opts=dict(title='test loss&acc.',legend=['loss', 'acc.']))
global_step = 0for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()global_step += 1viz.line([loss.item()], [global_step], win='train_loss', update='append')if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()viz.line([[test_loss, correct / len(test_loader.dataset)]],[global_step], win='test', update='append')viz.images(data.view(-1, 1, 28, 28), win='x')viz.text(str(pred.detach().cpu().numpy()), win='pred',opts=dict(title='pred'))test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2582.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

上海亚商投顾:沪指缩量调整 有色、煤炭等周期股集体大跌

上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。 一.市场情绪 沪指昨日缩量调整,午后一度跌近1%,黄白二线走势分化,微盘股指数涨超3%。军…

GLM4——Function calling(函数调用)

作用: Function Calling可以根据用户的输入自行判断何时需要调用哪些函数,并且可以根据目标函数的描述生成符合要求的请求参数。开发人员可以使用函数调用能力,通过GPT实现: 在进行自然语言交流时,通过调用外部工具回答问题&…

[图解敏捷口号]普天之下皆我妈-01-新手一次走两步

0 00:00:00,830 --> 00:00:03,750 今天我们来看一句敏捷口号 1 00:00:04,030 --> 00:00:05,660 后面我们会 2 00:00:06,300 --> 00:00:09,570 列一些比较幼稚的口号 3 00:00:09,970 --> 00:00:11,145 一句一句 4 00:00:11,145 --> 00:00:12,790 我们来剖析一…

超级牛逼 专业的 js 汉字拼音转换库

pinyin-pro 是一个专业的 js 汉字拼音转换库,功能丰富、准确率高、性能优异。 🎨 特色功能 支持拼音/声母/韵母/首字母/音调/全部信息支持人名姓氏模式支持文本和拼音匹配支持自定义拼音支持获取带拼音汉字的 HTML 字符串支持获取汉字的所有拼音支持拼音…

SpringBoot 启动控制台 --banner.txt实现打印炫酷控制台图案

文章目录 目录 文章目录 安装流程 小结 概要安装流程技术细节小结 概要 分析源代码,banner.txt实现打印控制台 控制台图案生成网址:Ascii艺术字实现个性化Spring Boot启动banner图案,轻松修改更换banner.txt文件内容,收集了丰富…

SSL证书安装失败怎么办?

在互联网时代,SSL(Secure Sockets Layer)证书已成为保障网站数据传输安全、提升用户信任度的重要工具。然而,在实际操作过程中,SSL证书的安装并非总能一帆风顺,有时会遇到各种导致安装失败的问题。本文将详…

munge服务启动异常问题记录

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、问题一:cannot canonicalize "/var/run/munge"二、问题二:Failed to create "/var/run/munge/munge.socket.2.lock": Perm…

微前端集成模式:独立部署与共享依赖

微前端是一种将复杂的前端应用程序拆分为多个独立的、可独立部署的子应用的架构模式。在微前端中,每个子应用都可以独立开发、测试和部署,而不会影响其他子应用的运行。 在微前端中,有两种常见的集成模式:独立部署和共享依赖。 独…

美硕科技授权世强硬创代理,继电器具备控制功率小、电磁干扰小特点

受工业自动化、智能制造、物联网以及可再生能源等领域发展的推动,全球继电器市场在过去几年中持续增长,预计未来几年将继续保持这一趋势。 为满足日益增长的市场需求,世强先进(深圳)科技股份有限公司(下称…

jvm中的引用类型

Java中的引用类型 1.强引用 一个对象A被局部变量、静态变量引用了就产生了强引用。因为局部变量、静态变量都是被GC Root对象关联上的,所以被引用的对象A,就在GC Root的引用链上了。只要这一层关系存在,对象A就不会被垃圾回收器回收。所以只要…

Linux shell编程学习笔记47:lsof命令

0 前言 今天国产电脑提示磁盘空间已耗尽,使用用df命令检查文件系统情况,发现/dev/sda2已使用100%。 Linux shell编程学习笔记39:df命令https://blog.csdn.net/Purpleendurer/article/details/135577571于是开始清理磁盘空间。 第一步是查看…

第二篇、SD真人视频转卡通动画 学习笔记

接着第一篇 2K转4K 生成玩卡通视频后,如何转换成更高分辨率的视频 1、将第一篇生成的工作目录下的output目录改成output-old,新建一个output目录 2、进入0,1子目录,把EbSynth生成的Outputxxx都删掉,frames和keys下…

IP5306 2.1A充电2.4 A放电电高集成度移动电源SOC IC,为移动电源提供完美电源解决方案

IP5306是一款集成升压转换器、锂电池充电管 理、电池电量指示的多功能电源管理 SOC,为移动 电源提供完整的电源解决方案。 IP5306的高集成度与丰富功能,使其在应用时 仅需极少的外围器件,并有效减小整体方案的尺寸, 降低 BOM 成本…

Unity射击游戏开发教程:(5)使用 GetComponent 在 Unity 中进行脚本通信

我认为脚本通信是刚开始使用 Unity 时较难掌握的概念之一,我将继续讨论这个概念。在本文中,我将介绍如何在游戏对象发生碰撞时使用 GetComponent 来访问另一个脚本。 在这个游戏场景中,我有两个游戏对象,它们都有自己的脚本,需要进行通信。我们有玩家脚本和敌人脚本。Enem…

MP:There is no getter for property named ‘null‘ in ‘class XXX‘异常

在使用主键进行更新或者删除的时候,报下面错误 There is no getter for property named ‘null’ in class 。。。 代码如下 Builder Data public class Course implements Serializable {private static final long serialVersionUID -16929324809307129L;privat…

CC++的内存管理

C&C的内存管理 栈:即用即销毁 堆:有需求再申请空间,手动销毁 注意:const 修饰可以使变量有常性,但是变量存储的域与没有const修饰是相同的。 即: 在 main函数中, const int a 0; int b…

1个月,从估值3.5亿美元到卷款3000万,ZKasino做了什么?

项目rug跑路,对于加密圈的人而言,并不少见。 但rug得这么理直气壮,甚至在圈内掀起了一波对投资机构和KOL的口诛笔伐的项目,ZKasino,也算是头几个。 短短一个月时间,从估值3.5亿美元、众人吹捧的明星级项目&…

深入了解Redis内存淘汰策略中的LRU算法应用

LRU算法简析 LRU(Least Recently Used,最近最少使用)算法是一种常见的内存淘汰策略,它根据数据的访问时间来决定哪些数据会被淘汰。LRU算法的核心思想是:最久未被访问的数据,被认为是最不常用的数据&#…

基于Tensorflow完成mnist数据集的数字手写体识别

基于Tensorflow完成mnist数据集的数字手写体识别 关于知识背景CNNFCNN 关于数据集新的改变 关于知识背景 CNN 卷积神经网络(Convolutional Neural Networks,简称CNN)是一种具有局部连接、权值共享等特点的深层前馈神经网络(Feed…

【大数据】LSM树,专为海量数据读写而生的数据结构

目录 1.什么是LSM树? 2.LSM树的落地实现 1.什么是LSM树? LSM树(Log-Structured Merge Tree)是一种专门针对大量写操作做了优化的数据存储结构,尤其适用于现代大规模数据处理系统,如NoSQL数据库&#xff…