【PyTorch】模型选择、欠拟合和过拟合

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现
      • 2.2.1. 完整代码
      • 2.2.2. 输出结果

1. 理论介绍

  • 将模型在训练数据上拟合的比在潜在分布中更接近的现象称为过拟合, 用于对抗过拟合的技术称为正则化
  • 训练误差和验证误差都很严重, 但它们之间差距很小。 如果模型不能降低训练误差,这可能意味着模型过于简单(即表达能力不足),无法捕获试图学习的模式。 这种现象被称为欠拟合
  • 训练误差是指模型在训练数据集上计算得到的误差。
  • 泛化误差是指模型应用在同样从原始样本的分布中抽取的无限多数据样本时,模型误差的期望。我们永远不能准确地计算出泛化误差,在实际中,我们只能通过将模型应用于一个独立的测试集来估计泛化误差, 该测试集由随机选取的、未曾在训练集中出现的数据样本构成。
  • 影响模型泛化的因素
    • 可调整参数的数量。当可调整参数的数量(有时称为自由度)很大时,模型往往更容易过拟合。
    • 参数采用的值。当权重的取值范围较大时,模型可能更容易过拟合。
    • 训练样本的数量。即使模型很简单,也很容易过拟合只包含一两个样本的数据集,而过拟合一个有数百万个样本的数据集则需要一个极其灵活的模型。
  • 在机器学习中,我们通常在评估几个候选模型后选择最终的模型。 这个过程叫做模型选择。候选模型可能在本质上不同,也可能是不同的超参数设置下的同一类模型。
  • 为了确定候选模型中的最佳模型,我们通常会使用验证集。验证集与测试集十分相似,唯一的区别是验证集是用于确定最佳模型,测试集是用于评估最终模型的性能
  • K K K折交叉验证:当训练数据稀缺时,将原始训练数据分成 K K K个不重叠的子集。 然后执行 K K K次模型训练和验证,每次在 ( K − 1 ) (K-1) (K1)个子集上进行训练, 并在剩余的一个子集(在该轮中没有用于训练的子集)上进行验证。 最后,通过对 K K K次实验的结果取平均来估计训练和验证误差。
  • 引起过拟合的因素
    • 模型复杂度
      模型复杂度
    • 数据集大小
      • 训练数据集中的样本越少,我们就越有可能(且更严重地)过拟合。
      • 给出更多的数据,拟合更复杂的模型可能是有益的; 如果没有足够的数据,简单的模型可能更有用。

2. 实例解析

2.1. 实例描述

使用以下三阶多项式来生成训练和测试数据 y = 5 + 1.2 x − 3.4 x 2 2 ! + 5.6 x 3 3 ! + ϵ where  ϵ ∼ N ( 0 , 0. 1 2 ) . y = 5 + 1.2x - 3.4\frac{x^2}{2!} + 5.6 \frac{x^3}{3!} + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.1^2). y=5+1.2x3.42!x2+5.63!x3+ϵ where ϵN(0,0.12).并用1阶(线性模型)、3阶、20阶多项式拟合。

2.2. 代码实现

2.2.1. 完整代码

import os
import numpy as np
import math, torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tensorboardX import SummaryWriter
from rich.progress import trackdef evaluate_loss(dataloader, net, criterion):"""评估模型在指定数据集上的损失"""num_examples = 0loss_sum = 0.0with torch.no_grad():for X, y in dataloader:X, y = X.cuda(), y.cuda()loss = criterion(net(X), y)num_examples += y.shape[0]loss_sum += loss.sum()return loss_sum / num_examplesdef load_dataset(*tensors):"""加载数据集"""dataset = TensorDataset(*tensors)return DataLoader(dataset, batch_size, shuffle=True)if __name__ == '__main__':# 全局参数设置num_epochs = 400batch_size = 10learning_rate = 0.01# 创建记录器def log_dir():root = "runs"if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'{root}/exp{order}'writer = SummaryWriter(log_dir())# 生成数据集max_degree = 20             # 多项式最高阶数n_train, n_test = 100, 100  # 训练集和测试集大小true_w = np.zeros(max_degree+1)true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size=(n_train + n_test, 1))np.random.shuffle(features)poly_features = np.power(features, np.arange(max_degree+1).reshape(1, -1))for i in range(max_degree+1):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!labels = np.dot(poly_features, true_w)labels += np.random.normal(scale=0.1, size=labels.shape)    # 加高斯噪声服从N(0, 0.01)poly_features, labels = [torch.as_tensor(x, dtype=torch.float32) for x in [poly_features, labels.reshape(-1, 1)]]def loop(model_degree):# 创建模型net = nn.Linear(model_degree+1, 1, bias=False).cuda()nn.init.normal_(net.weight, mean=0, std=0.01)criterion = nn.MSELoss(reduction='none')optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)# 加载数据集dataloader_train = load_dataset(poly_features[:n_train, :model_degree+1], labels[:n_train])dataloader_test = load_dataset(poly_features[n_train:, :model_degree+1], labels[n_train:])# 训练循环for epoch in track(range(num_epochs), description=f'{model_degree}-degree'):for X, y in dataloader_train:X, y = X.cuda(), y.cuda()loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()writer.add_scalars(f"{model_degree}-degree", {"train_loss": evaluate_loss(dataloader_train, net, criterion),"test_loss": evaluate_loss(dataloader_test, net, criterion),}, epoch)print(f"{model_degree}-degree: weights =", net.weight.data.cpu().numpy())for model_degree in [1, 3, 20]:loop(model_degree)writer.close()

2.2.2. 输出结果

权重

  • 采用1阶多项式(线性模型)拟合
    1
  • 采用3阶多项式拟合
    3
  • 采用20阶多项式拟合
    20

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/204510.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算两个结构的差

平面上有6个点,以6a1的方式运动 1 1 1 1 - - - 1 - - - 1 现在有一个点逃逸,剩下的5个点将如何运动? 2 2 2 3 - - - 3 - - - 3 将6a1的6个点减去1个点,只有两种可能,或者变成5a2&#xff0c…

SpringBoot的配置加载优先级

目录 一、背景分析 二、学习资源 三、具体使用 四、一些小技巧 方式一 方式二 一、背景分析 SpringBoot项目在打包之后&#xff0c;其配置文件就在jar包内&#xff0c;如果没有<配置文件优先级>这个机制&#xff0c;那么项目打成jar包之后&#xff0c;如果启动项目…

视频监控管理平台/智能监测/检测系统EasyCVR智能地铁监控方案,助力地铁高效运营

近日&#xff0c;关于全国44座城市开通地铁&#xff0c;却只有5座城市赚钱的新闻冲上热搜。地铁作为城市交通的重要枢纽&#xff0c;是人们出行必不可少的一种方式&#xff0c;但随着此篇新闻的爆出&#xff0c;大家也逐渐了解到城市运营的不易&#xff0c;那么&#xff0c;如何…

目标检测——SPPNet算法解读

论文&#xff1a;Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition 作者&#xff1a;Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun 链接&#xff1a;https://arxiv.org/abs/1406.4729 目录 1、算法概述2、Deep Networks with Spatia…

适用于 Windows 的 10 个顶级分区管理器软件

您想要对硬盘驱动器或 USB 驱动器进行分区的原因可能有多种。许多用户希望对其外部和内部硬盘驱动器进行分区以有效地管理数据。为了处理分区&#xff0c;Windows为用户提供了一个内置的分区管理工具。 Windows 用户可以通过磁盘管理面板对任何驱动器进行分区。然而&#xff0…

哪种超声波清洗眼镜不伤镜片?超声波什么牌子好?眼镜清洗机推荐

戴眼镜的朋友就能够感同身受&#xff0c;眼镜佩戴一天真的特别容易脏&#xff0c;灰尘或者是脸部出汗出油等&#xff0c;让耳机惨不忍睹&#xff01;对于眼镜清洗的方式也有很多种&#xff0c;那么到底哪种清洗方式才是真正有效果并且不伤眼镜的呢&#xff1f;跟随笔者的脚步来…

安装Kuboard管理K8S集群

目录 第一章.安装Kuboard管理K8S集群 1.安装kuboard 2.绑定K8S集群&#xff0c;完成信息设定 3.内网安装 第二章.kuboard-spray安装K8S 2.1.先拉镜像下来 2.2.之后打开后&#xff0c;先熟悉功能&#xff0c;注意版本 2.3.打开资源包管理&#xff0c;选择符合自己服务器…

Python-Opencv图像处理的小坑

1.背景 最近在做一点图像处理的事情&#xff0c;在做处理时的cv2遇到一些小坑&#xff0c;希望大家遇到的相关的问题可以注意&#xff01;&#xff01; 2. cv2.imwrite保存图像 cv2.imwrite(filename, img, [params]) filename&#xff1a;需要写入的文件名&#xff0c;包括路…

快速排序并不难

快速排序的核心框架是“二叉树的前序遍历对撞型双指针”。我们在《一维数组》一章提到过”双指针思路“&#xff1a;在处理奇偶等情况时会使用两个游标&#xff0c;一个从前向后&#xff0c;一个是从后向前来比较&#xff0c;根据结果来决定继续移动还是停止等待。快速排序的每…

医院信息系统源码,采用JAVA编程,支持跨平台部署应用,满足一级综合医院(专科二级及以下医院500床)的日常业务应用

医院HIS系统源码&#xff0c;HIS系统全套源码&#xff0c;支持电子病历4级&#xff0c;自主版权 his医院信息系统内设门诊/住院医生工作站、门诊/住院护士工作站。各工作站主要功能依据职能要求进行研发。如医生工作站主要功能为编辑电子病历、打印、处理医嘱&#xff1b;护士工…

总结|哪些平台有大模型知识库的Web API服务

截止2023/12/6 笔者个人的调研&#xff0c;有三家有大模型知识库的web api服务&#xff1a; 平台类型文档数量文档上传并解析的结构api情况返回页码文心一言插件版多文档有问答api&#xff0c;文档上传是通过网页进行上传有&#xff0c;而且是具体的chunk id&#xff0c;需要设…

“消费增值:改变你的购物方式,让每一笔消费都变得更有价值“

你是否厌倦了仅仅购买物品或享受服务后便一无所有的消费方式&#xff1f;现在&#xff0c;消费增值的概念将彻底改变你的消费观念&#xff01;通过参与消费增值&#xff0c;你的每一笔消费都将变得更有价值&#xff01; 消费增值是一种全新的消费理念&#xff0c;它让你在购物的…

星闪的三层架构

在数字化转型的浪潮中&#xff0c;物联网技术正成为连接世界的纽带&#xff0c;将各种智能设备融为一个无缝的整体。而在这个大背景下&#xff0c;星闪崭露头角&#xff0c;将成为连接未来的关键枢纽。本文将介绍星闪系统的三层架构&#xff0c;包括基础应用层、基础服务层和星…

面向AI开发的六种最重要的编程语言

在AI开发界&#xff0c;你使用的编程语言很重要。每种语言有其独特的特性。选择合适的语言不是关乎个人偏好的问题&#xff0c;而是影响你如何构建和启动AI系统的关键决定。无论你在AI方面有无经验&#xff0c;选择一种合适的语言来学习至关重要。合适的语言将帮助你创建功能强…

一文带你快速了解Python史上最快Web框架

文章目录 1. 写在前面2. Sanic框架简介2.1 背景2.2 特征与优势 3. Sanic框架实战3.1. 安装Sanic3.2. Demo案例编写 【作者主页】&#xff1a;吴秋霖 【作者介绍】&#xff1a;Python领域优质创作者、阿里云博客专家、华为云享专家。长期致力于Python与爬虫领域研究与开发工作&a…

NVRAM相关

1. Modem NVRAM四个分区 nvdata&#xff1a;手机运行过程中&#xff0c;使用(读写)的NVRAM(除了存在protect_f和protect_s中的NVRAM)都是该分区的nvram文件。存储着普通NVRAM数据、 IMEI、barcode、Calibration数据等。对应的modem path是Z:\NVRAM。NVRAM目录下有CALIBRAT、NVD…

Goby 漏洞发布| Apache OFBiz webtools/control/xmlrpc 远程代码执行漏洞(CVE-2023-49070)

漏洞名称&#xff1a; Apache OFBiz webtools/control/xmlrpc 远程代码执行漏洞&#xff08;CVE-2023-49070&#xff09; English Name&#xff1a;Apache OFBiz webtools/control/xmlrpc Remote Code Execution Vulnerability (CVE-2023-49070) CVSS core: 9.8 影响资产数&…

2023新优化应用:RIME-CNN-LSTM-Attention超前24步多变量回归预测算法

程序平台&#xff1a;适用于MATLAB 2023版及以上版本。 霜冰优化算法是2023年发表于SCI、中科院二区Top期刊《Neurocomputing》上的新优化算法&#xff0c;现如今还未有RIME优化算法应用文献哦。RIME主要对霜冰的形成过程进行模拟&#xff0c;将其巧妙地应用于算法搜索领域。 …

外网的maven项目转移到内网操作的步骤

1、新起一个仓库路径testRep&#xff0c;idea 引用的maven里的setting.xml里仓库配置修改成刚才建的路径&#xff0c;目的把需要的jar全部下载到那个文件夹里 2、项目打压缩包&#xff0c;刚才仓库文件夹打压缩包&#xff0c;并复制到内网电脑 3、内网电脑idea引入项目 4、修改…

【离散数学】——期末刷题题库(等价关系与划分)

&#x1f383;个人专栏&#xff1a; &#x1f42c; 算法设计与分析&#xff1a;算法设计与分析_IT闫的博客-CSDN博客 &#x1f433;Java基础&#xff1a;Java基础_IT闫的博客-CSDN博客 &#x1f40b;c语言&#xff1a;c语言_IT闫的博客-CSDN博客 &#x1f41f;MySQL&#xff1a…