pytorch-解决过拟合之regularization

目录

  • 1.解决过拟合的方法
  • 2. regularization
  • 2. regularization分类
  • 3. pytorch L2 regularization
  • 4. 自实现L1 regularization
  • 5. 完整代码

1.解决过拟合的方法

  • 更多的数据
  • 降低模型复杂度
    regularization
  • Dropout
  • 数据处理
  • 早停止

2. regularization

以二分类的cross entropy为例,就是在其公式后增加一项参数一范数累加和,并乘以一个超参数用来权衡参数配比。
模型优化是要使得前部分loss尽量小,那么同时也要后半部分范数接近于0,但是为了保持模型的表达能力还要保留比如 β 0 β_{0} β0+ β 1 β_{1} β1x+ β 2 β_{2} β2 x 2 x^2 x2,那么可能使得 β 0 β_{0} β0 β 2 β_{2} β2 β 3 β_{3} β3 = 0.01 而 β 4 β_{4} β4- β n β_{n} βn很小很小,比如0.0001,这样就使得比如f(x)= x 7 x^7 x7,退化为 β 0 β_{0} β0+ β 1 β_{1} β1x+ β 2 β_{2} β2 x 2 x^2 x2,这样即保证了模型的表达能力,也降低了模型的复杂度。从而防止过拟合。
在这里插入图片描述
下图是未增加regularization和增加了regularization的区别展示图
可以看出未增加regularization的时候,模型可以将噪点也拟合进去了,因此图形很不平滑,发生了过拟合。而增加regularization之后,图形变得很平滑。
在这里插入图片描述

2. regularization分类

regularization有两类分别是L1和L2,L1增加的是参数的一范数,L2增加的二范数
在这里插入图片描述
最常用的是L2regularization

3. pytorch L2 regularization

pytorch中L2 regularization叫weight_decay
在这里插入图片描述

4. 自实现L1 regularization

在这里插入图片描述

5. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsfrom visdom import Visdombatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),# transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),# transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)
criteon = nn.CrossEntropyLoss().to(device)viz = Visdom()viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
viz.line([[0.0, 0.0]], [0.], win='test', opts=dict(title='test loss&acc.',legend=['loss', 'acc.']))
global_step = 0for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()global_step += 1viz.line([loss.item()], [global_step], win='train_loss', update='append')if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()viz.line([[test_loss, correct / len(test_loader.dataset)]],[global_step], win='test', update='append')viz.images(data.view(-1, 1, 28, 28), win='x')viz.text(str(pred.detach().cpu().numpy()), win='pred',opts=dict(title='pred'))test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2582.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

上海亚商投顾:沪指缩量调整 有色、煤炭等周期股集体大跌

上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。 一.市场情绪 沪指昨日缩量调整,午后一度跌近1%,黄白二线走势分化,微盘股指数涨超3%。军…

[图解敏捷口号]普天之下皆我妈-01-新手一次走两步

0 00:00:00,830 --> 00:00:03,750 今天我们来看一句敏捷口号 1 00:00:04,030 --> 00:00:05,660 后面我们会 2 00:00:06,300 --> 00:00:09,570 列一些比较幼稚的口号 3 00:00:09,970 --> 00:00:11,145 一句一句 4 00:00:11,145 --> 00:00:12,790 我们来剖析一…

SpringBoot 启动控制台 --banner.txt实现打印炫酷控制台图案

文章目录 目录 文章目录 安装流程 小结 概要安装流程技术细节小结 概要 分析源代码,banner.txt实现打印控制台 控制台图案生成网址:Ascii艺术字实现个性化Spring Boot启动banner图案,轻松修改更换banner.txt文件内容,收集了丰富…

SSL证书安装失败怎么办?

在互联网时代,SSL(Secure Sockets Layer)证书已成为保障网站数据传输安全、提升用户信任度的重要工具。然而,在实际操作过程中,SSL证书的安装并非总能一帆风顺,有时会遇到各种导致安装失败的问题。本文将详…

munge服务启动异常问题记录

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、问题一:cannot canonicalize "/var/run/munge"二、问题二:Failed to create "/var/run/munge/munge.socket.2.lock": Perm…

美硕科技授权世强硬创代理,继电器具备控制功率小、电磁干扰小特点

受工业自动化、智能制造、物联网以及可再生能源等领域发展的推动,全球继电器市场在过去几年中持续增长,预计未来几年将继续保持这一趋势。 为满足日益增长的市场需求,世强先进(深圳)科技股份有限公司(下称…

jvm中的引用类型

Java中的引用类型 1.强引用 一个对象A被局部变量、静态变量引用了就产生了强引用。因为局部变量、静态变量都是被GC Root对象关联上的,所以被引用的对象A,就在GC Root的引用链上了。只要这一层关系存在,对象A就不会被垃圾回收器回收。所以只要…

Linux shell编程学习笔记47:lsof命令

0 前言 今天国产电脑提示磁盘空间已耗尽,使用用df命令检查文件系统情况,发现/dev/sda2已使用100%。 Linux shell编程学习笔记39:df命令https://blog.csdn.net/Purpleendurer/article/details/135577571于是开始清理磁盘空间。 第一步是查看…

第二篇、SD真人视频转卡通动画 学习笔记

接着第一篇 2K转4K 生成玩卡通视频后,如何转换成更高分辨率的视频 1、将第一篇生成的工作目录下的output目录改成output-old,新建一个output目录 2、进入0,1子目录,把EbSynth生成的Outputxxx都删掉,frames和keys下…

IP5306 2.1A充电2.4 A放电电高集成度移动电源SOC IC,为移动电源提供完美电源解决方案

IP5306是一款集成升压转换器、锂电池充电管 理、电池电量指示的多功能电源管理 SOC,为移动 电源提供完整的电源解决方案。 IP5306的高集成度与丰富功能,使其在应用时 仅需极少的外围器件,并有效减小整体方案的尺寸, 降低 BOM 成本…

Unity射击游戏开发教程:(5)使用 GetComponent 在 Unity 中进行脚本通信

我认为脚本通信是刚开始使用 Unity 时较难掌握的概念之一,我将继续讨论这个概念。在本文中,我将介绍如何在游戏对象发生碰撞时使用 GetComponent 来访问另一个脚本。 在这个游戏场景中,我有两个游戏对象,它们都有自己的脚本,需要进行通信。我们有玩家脚本和敌人脚本。Enem…

CC++的内存管理

C&C的内存管理 栈:即用即销毁 堆:有需求再申请空间,手动销毁 注意:const 修饰可以使变量有常性,但是变量存储的域与没有const修饰是相同的。 即: 在 main函数中, const int a 0; int b…

1个月,从估值3.5亿美元到卷款3000万,ZKasino做了什么?

项目rug跑路,对于加密圈的人而言,并不少见。 但rug得这么理直气壮,甚至在圈内掀起了一波对投资机构和KOL的口诛笔伐的项目,ZKasino,也算是头几个。 短短一个月时间,从估值3.5亿美元、众人吹捧的明星级项目&…

深入了解Redis内存淘汰策略中的LRU算法应用

LRU算法简析 LRU(Least Recently Used,最近最少使用)算法是一种常见的内存淘汰策略,它根据数据的访问时间来决定哪些数据会被淘汰。LRU算法的核心思想是:最久未被访问的数据,被认为是最不常用的数据&#…

基于Tensorflow完成mnist数据集的数字手写体识别

基于Tensorflow完成mnist数据集的数字手写体识别 关于知识背景CNNFCNN 关于数据集新的改变 关于知识背景 CNN 卷积神经网络(Convolutional Neural Networks,简称CNN)是一种具有局部连接、权值共享等特点的深层前馈神经网络(Feed…

【大数据】LSM树,专为海量数据读写而生的数据结构

目录 1.什么是LSM树? 2.LSM树的落地实现 1.什么是LSM树? LSM树(Log-Structured Merge Tree)是一种专门针对大量写操作做了优化的数据存储结构,尤其适用于现代大规模数据处理系统,如NoSQL数据库&#xff…

C# winform OpenProtocol中数据中的UI是什么类型?

C# winform OpenProtocol中数据中的UI是什么类型?

vue2项目升级到vue3经历分享

依据vue官方文档,vue2在2023年12月31日终止维护。因此决定将原来的岁月云记账升级到vue3,预计工作量有点大,于是想着把过程记录下来。 原系统使用的技术栈 "dependencies": {"axios": "^0.21.1","babel-…

C++-DAY1

思维导图 有以下定义,说明哪些量可以改变哪些不可以改变? const char *p; const (char *) p; char *const p; const char* const p; char const *p; (char *) const p; char const* const p; const char *p:指针 p 所指向的内容不可改…

【嵌入式】Arduino IDE + ESP32开发环境配置

一 背景说明 最近想捣鼓一下ESP32的集成芯片,比较了一下,选择Arduino IDE并添加ESP32支持库的方式来开发,下面记录一下安装过程以及安装过程中遇到的坑。 二 下载准备 【1】Arduino IDE ESP32支持一键安装包(非常推荐&#xff0…