【深度学习笔记】08 欠拟合和过拟合

08 欠拟合和过拟合

      • 生成数据集
      • 对模型进行训练和测试
      • 三阶多项式函数拟合(正常)
      • 线性函数拟合(欠拟合)
      • 高阶多项式函数拟合(过拟合)

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

生成数据集

给定 x x x,我们将使用以下三阶多项式来生成训练和测试数据的标签:

y = 5 + 1.2 x − 3.4 x 2 2 ! + 5.6 x 3 3 ! + ϵ where  ϵ ∼ N ( 0 , 0. 1 2 ) . y = 5 + 1.2x - 3.4\frac{x^2}{2!} + 5.6 \frac{x^3}{3!} + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.1^2). y=5+1.2x3.42!x2+5.63!x3+ϵ where ϵN(0,0.12).

噪声项 ϵ \epsilon ϵ服从均值为0且标准差为0.1的正态分布。
在优化的过程中,我们通常希望避免非常大的梯度值或损失值。
这就是我们将特征从 x i x^i xi调整为 x i i ! \frac{x^i}{i!} i!xi的原因,
这样可以避免很大的 i i i带来的特别大的指数值。
我们将为训练集和测试集各生成100个样本。

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集的大小
true_w = np.zeros(max_degree)  # 分配大量空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size = (n_train + n_test, 1))
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train + n_test)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale = 0.1, size = labels.shape)

存储在poly_features中的单项式由gamma函数重新缩放,
其中 Γ ( n ) = ( n − 1 ) ! \Gamma(n)=(n-1)! Γ(n)=(n1)!
从生成的数据集中查看一下前2个样本,
第一个值是与偏置相对应的常量特征。

# NumPy ndarray转换为tensor
true_w, features, poly_features, labels = [torch.tensor(x, dtype=torch.float32) for x in [true_w, features, poly_features, labels]]
features[:2], poly_features[:2, :], labels[:2]
(tensor([[-1.2565],[-2.2676]]),tensor([[ 1.0000e+00, -1.2565e+00,  7.8936e-01, -3.3060e-01,  1.0385e-01,-2.6096e-02,  5.4648e-03, -9.8091e-04,  1.5406e-04, -2.1508e-05,2.7024e-06, -3.0868e-07,  3.2321e-08, -3.1238e-09,  2.8036e-10,-2.3484e-11,  1.8442e-12, -1.3630e-13,  9.5145e-15, -6.2919e-16],[ 1.0000e+00, -2.2676e+00,  2.5709e+00, -1.9433e+00,  1.1016e+00,-4.9960e-01,  1.8881e-01, -6.1164e-02,  1.7337e-02, -4.3681e-03,9.9049e-04, -2.0418e-04,  3.8583e-05, -6.7300e-06,  1.0901e-06,-1.6479e-07,  2.3354e-08, -3.1151e-09,  3.9243e-10, -4.6835e-11]]),tensor([ -1.1486, -17.5782]))

对模型进行训练和测试

首先实现一个函数来评估模型在给定数据集上的损失

def evaluate_loss(net, data_iter, loss):  #@save"""评估给定数据集上模型的损失"""metric = d2l.Accumulator(2)  # 损失的总和,样本数量for X, y in data_iter:out = net(X)y = y.reshape(out.shape)l = loss(out, y)metric.add(l.sum(), l.numel())return metric[0] / metric[1]

定义训练函数

def train(train_features, test_features, train_labels, test_labels,num_epochs=400):loss = nn.MSELoss()input_shape = train_features.shape[-1]# 不设置偏置,因为我们已经在多项式中实现了它net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',xlim=[1, num_epochs], ylim=[1e-3, 1e2],legend=['train', 'test'])for epoch in range(num_epochs):d2l.train_epoch_ch3(net, train_iter, loss, trainer)if epoch == 0 or (epoch + 1) % 20 == 0:animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),evaluate_loss(net, test_iter, loss)))print('weight:', net[0].weight.data.numpy())

三阶多项式函数拟合(正常)

首先使用三阶多项式函数,它与数据生成函数的阶数相同。
结果表明,该模型能有效降低训练损失和测试损失。
学习到的模型参数也接近真实值 w = [ 5 , 1.2 , − 3.4 , 5.6 ] w = [5, 1.2, -3.4, 5.6] w=[5,1.2,3.4,5.6]

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])
weight: [[ 5.016911   1.2140008 -3.4260864  5.581189 ]]

在这里插入图片描述

线性函数拟合(欠拟合)

当用于拟合非线性模式(如这里的三阶多项式函数)时,线性模型容易欠拟合。

# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])
weight: [[3.6582568 3.967709 ]]

在这里插入图片描述

高阶多项式函数拟合(过拟合)

尝试使用一个阶数过高的多项式来训练模型。在这种情况下,没有足够的数据用于学习高阶系数应该具有接近于零的值。因此,这个过于复杂的模型会轻易受到训练数据中噪声的影响。虽然训练损失可以有效地降低,但测试损失仍然很高。结果表明,复杂模型对数据造成了过拟合。

# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],labels[:n_train], labels[n_train:], num_epochs=1500)
weight: [[ 5.0074177   1.2689897  -3.3731365   5.1744385  -0.28380862  1.38800660.28433597  0.23886798  0.10456859 -0.01008706 -0.13444044 -0.00965116-0.09757714 -0.09527045  0.21342376  0.13767722 -0.08204057 -0.106664410.12475184  0.21017507]]

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/200265.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

公有云迁移研究——AWS Translate

大纲 1 什么是Translate2 Aws Translate是怎么运作的3 Aws Translate和Google Translate的区别4 迁移任务4.1 迁移原因 5 Aws Translate的Go demo6 迁移中遇到的问题6.1 账号和权限问题:6.2 小语种 1 什么是Translate Translate是一种文本翻译服务,它使…

xcode opencv

1、导入报错 Undefined symbols: linker command failed with exit code 1 (use -v to see invocation) 直接添加如下图内容即可

<JavaEE> synchronized关键字和锁机制 -- 锁的特点、锁的使用、锁竞争和死锁、死锁的解决方法

目录 一、synchronized 关键字简介 二、synchronized 的特点 -- 互斥 三、synchronized 的特点 -- 可重入 四、synchronized 的使用示例 4.1 修饰代码块 - 锁任意实例 4.2 修饰代码块 - 锁当前实例 4.3 修饰普通方法 - 锁方法所在实例 4.4 修饰代码块 - 锁指定类对象 …

【从零开始学习JVM | 第二篇】字节码文件的组成

前言: 字节码作为JAVA跨平台的主要原因,熟练的掌握JAVA字节码文件的组成可以帮助我们解决项目的各种问题,并且在面试中,关于字节码部分的内容却是一大考点和难点,因此我们在这里穿插讲解一下字节码文件的组成。 目录 …

16、观察者模式(Observer Pattern)

观察者(Observer Pattern) 定义对象间的一种一对多的依赖关系,当一个对象的状态发生改变时,所有依赖于它的对象都得到通知并被自动更新。主要解决一个对象状态改变给其他对象通知的问题,而且要考虑到易用和低耦合&…

你好!哈希表【JAVA】

1.初识🎶🎶🎶 它基本上是由一个数组和一个哈希函数组成的。哈希函数将每个键映射到数组的特定索引位置,这个位置被称为哈希码。当我们需要查找一个键时,哈希函数会计算其哈希码并立即返回结果,因此我们可以…

【OpenGauss源码学习 —— (RowToVec)算子】

VecToRow 算子 概述ExecInitRowToVec 函数ExecRowToVec 函数VectorizeOneTuple 函数 ExecEndRowToVec 函数总结 声明:本文的部分内容参考了他人的文章。在编写过程中,我们尊重他人的知识产权和学术成果,力求遵循合理使用原则,并在…

github首次将文件合到远端分支,发现名字不是master,而是main

暂存区和本地仓库的信息都存储在.git目录中其中 其中,暂存区和本地仓库的信息都存储在.git目录中 在自己的github上实践 1、刚开始,git clone gitgithub.com:lingze8678/my_github.git到本地 2、在克隆后的代码中加入一个pdf文件 3、在git bash中操作…

CentOS增加虚拟内存 (Linux增加内存)

前言 因为囊中羞涩不敢言,所以内存只有2G,项目在运行的时候,占用的内存已经报表,所以有的时候就会出现宕机的情况发生,后面发现可以通过使用增加虚拟内存空间,来增加内存容量。 下面进入正题,讲…

Selenium+Python自动化测试之验证码处理

两种方式: 验证码识别技术 (很难达到100%) 添加Cookie (*****五星推荐) 方式一:验证码识别技术 逻辑方式: 1:打开验证码所在页面,截图。获取验证码元素坐标,剪切出验证码图片&…

【MATLAB】辛几何模态分解分解+FFT+HHT组合算法

有意向获取代码,请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 辛几何模态分解(CEEMDAN)是一种处理非线性和非平稳信号的适应性信号分解方法。通过在信号中加入白噪声,并多次进行经验模态分解(EMD&#…

深度学习TensorFlow2基础知识学习前半部分

目录 测试TensorFlow是否支持GPU: 自动求导: 数据预处理 之 统一数组维度 定义变量和常量 训练模型的时候设备变量的设置 生成随机数据 交叉熵损失CE和均方误差函数MSE 全连接Dense层 维度变换reshape 增加或减小维度 数组合并 广播机制&#…

clickhouse的向量化执行

背景 clickhouse快的很大一部分原因来源于数据的向量化执行,本文就来看一下向量化执行和正常标量执行的区别 SIMD的向量化执行 从上图可知,clickhouse通过SIMD指令可以做到一个cpu周期操作两个向量的运算操作,比起普通的cpu指令效率提高了N…

Understanding Computer Hardware

文章目录 I. Input Devices1. Keyboard(1)Layout(2)Key Types(3)Functionality(4)Connectivity(5)Ergonomics(6)Multimedia Keys&…

【计算机组成体系结构】主存储器的基本组成

一、半导体元器件存储二进制0/1的原理 一个存储器逻辑上分为MAR,MDR和存储体,这三块在时序逻辑电路的控制下相互配合工作。 而存储体有多个存储单元构成,每个存储单元又由每个存储元构成。一个存储元可以存放一位的二进制的0/1。 一个存储元…

OWASP安全练习靶场juice shop-更新中

Juice Shop是用Node.js,Express和Angular编写的。这是第一个 完全用 JavaScript 编写的应用程序,列在 OWASP VWA 目录中。 该应用程序包含大量不同的黑客挑战 用户应该利用底层的困难 漏洞。黑客攻击进度在记分板上跟踪。 找到这个记分牌实际上是&#…

想考研到电子类,未来从事芯片设计,目前该怎么准备?

最近看不少天坑学子想考研微电子专业,但却不知道该怎么准备?接下来就带大家一起来具体了解一下~ 首先是目标院校的选择? 目前所设的微电子专业学校里,比较厉害的有北京大学、清华大学、中国科学院大学、复旦大学、上海交通大学、…

ROS2教程08 ROS2的功能包、依赖管理、工作空间配置与编译

ROS2的功能包、依赖管理、工作空间配置与编译 版权信息 Copyright 2023 Herman YeAuromix. All rights reserved.This course and all of its associated content, including but not limited to text, images, videos, and any other materials, are protected by copyrigh…

品牌是如何通过软文推广产品的?媒介盒子为您揭秘

需求是概念的、抽象的,产品是具象的,多维的。软文推广就是通过发现消费者的需求来促使消费者主动购买产品,今天媒介盒子就来和大家聊聊:品牌是如何通过软文推广产品的。 一、 差异化内容打出独特点 差异化内容指通过和竞品的分析…

基于Intel Ai Analytics Toolkit 及边缘计算的溶氧预测水产养殖监测方案

基于AI的淡水养殖水质溯源、优化系统方案 前言一、关键需求及方案概述二、方案设计预测机制LSTM 模型基于intel AI 的时序水质分析模型与分类模型优化 三、实战分析1、方案简述2、数据分析预处理特征类型处理特征分布分析 3、特征构造4、特征选择过滤法重要性排序 5.构建LSTM模…