模型选择实战

我们现在可以通过多项式拟合来探索这些概念。

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

生成数据集

给定x,我们将使用以下三阶多项式来生成训练和测试数据的标签:

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)

对模型进行训练和测试

首先让我们实现一个函数来评估模型在给定数据集上的损失。

def evaluate_loss(net, data_iter, loss):  #@save"""评估给定数据集上模型的损失"""metric = d2l.Accumulator(2)  # 损失的总和,样本数量for X, y in data_iter:out = net(X)y = y.reshape(out.shape)l = loss(out, y)metric.add(l.sum(), l.numel())return metric[0] / metric[1]

现在定义训练函数。

def train(train_features, test_features, train_labels, test_labels,num_epochs=400):loss = nn.MSELoss(reduction='none')input_shape = train_features.shape[-1]# 不设置偏置,因为我们已经在多项式中实现了它net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',xlim=[1, num_epochs], ylim=[1e-3, 1e2],legend=['train', 'test'])for epoch in range(num_epochs):d2l.train_epoch_ch3(net, train_iter, loss, trainer)if epoch == 0 or (epoch + 1) % 20 == 0:animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),evaluate_loss(net, test_iter, loss)))print('weight:', net[0].weight.data.numpy())

三阶多项式函数拟合(正常)

我们将首先使用三阶多项式函数,它与数据生成函数的阶数相同。 结果表明,该模型能有效降低训练损失和测试损失。 学习到的模型参数也接近真实值w=[5,1.2,−3.4,5.6]。

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])

线性函数拟合(欠拟合)

让我们再看看线性函数拟合,减少该模型的训练损失相对困难。 在最后一个迭代周期完成后,训练损失仍然很高。 当用来拟合非线性模式(如这里的三阶多项式函数)时,线性模型容易欠拟合。

# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])

高阶多项式函数拟合(过拟合)

现在,让我们尝试使用一个阶数过高的多项式来训练模型。 在这种情况下,没有足够的数据用于学到高阶系数应该具有接近于零的值。 因此,这个过于复杂的模型会轻易受到训练数据中噪声的影响。 虽然训练损失可以有效地降低,但测试损失仍然很高。 结果表明,复杂模型对数据造成了过拟合。

# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],labels[:n_train], labels[n_train:], num_epochs=1500)

总结:

  • 欠拟合是指模型无法继续减少训练误差。过拟合是指训练误差远小于验证误差。
  • 由于不能基于训练误差来估计泛化误差,因此简单地最小化训练误差并不一定意味着泛化误差的减小。机器学习模型需要注意防止过拟合,即防止泛化误差过大。
  • 验证集可以用于模型选择,但不能过于随意地使用它。
  • 我们应该选择一个复杂度适当的模型,避免使用数量不足的训练样本。

借鉴:

4.4. 模型选择、欠拟合和过拟合 — 动手学深度学习 2.0.0 documentation (d2l.ai)

手学深度学习 2.0.0 documentation (d2l.ai)](https://zh-v2.d2l.ai/chapter_multilayer-perceptrons/underfit-overfit.html#id11)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/645936.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何系统的自学Python

1、官方文档 Python 的官方文档是最权威和详尽的学习资源。在官方文档中,你可以找到 Python 的语法规则、内置函数和模块、标准库等信息。如果你想深入学习 Python,官方文档是必不可少的参考资料。 Python 的官方文档分为两个版本,分别是 P…

前端面试题-(浏览器内核,CSS选择器优先级,盒子模型,CSS硬件加速,CSS扩展)

前端面试题-(浏览器内核,CSS选择器优先级,盒子模型,CSS硬件加速,CSS扩展) 常见的浏览器内核CSS选择器优先级盒子模型CSS硬件加速CSS扩展 常见的浏览器内核 内核描述Trident(IE内核)主要用在window系统中的IE浏览器中&…

BTC交易模式 - UXTO - 工具整理

UXTO 相关工具分析 https://mempool.space/signet/ 测试网浏览器https://bitcoin.org/zh_CN/choose-your-wallet BTC钱包 正文链接:BTC交易模式 - UXTO

分布式锁实现(mysql,以及redis)以及分布式的概念(续)redsync包使用

道生一,一生二,二生三,三生万物 这张尽量结合上一章进行使用:上一章 这章主要是讲如何通过redis实现分布式锁的 redis实现 这里我用redis去实现: 技术:golang,redis,数据结构 …

使用Python的pygame库实现自动追踪目标的Snake游戏

和上一期不同的目标追踪入门不同的是,这期是自动追踪科学游戏,话不多说,321上链接 一、项目背景 Snake游戏是一款经典的游戏,玩家需要控制一条蛇在屏幕上移动,吃掉食物并避免撞到自己的身体或墙壁。传统的Snake游戏通常…

校园跑腿小程序源码系统+代取快递+食堂超市代买+跑腿 带完整的安装代码包以及搭建教程

随着移动互联网的普及,人们越来越依赖于手机应用来解决日常生活中的各种问题。特别是在校园内,由于快递点距离宿舍较远、食堂排队人数过多等情况,学生对于便捷、高效的服务需求愈发强烈。在此背景下,校园跑腿小程序源码系统应运而…

JAVA 学习 面试(九)Lambda表达式与泛型

Lambda表达式 // 使用 Lambda 表达式计算两个数的和 MathOperation addition (a, b) -> a b; // 调用 Lambda 表达式 int result addition.operation(5, 3); // MathOperation 是一个函数式接口,它包含一个抽象方法 operation,Lambda 表达式 (a, …

this.$copyText;vue-clipboard2作用;vue-clipboard2剪切板

1.安装 npm install --save vue-clipboard2 2.在main.js中引用 import Vue from vue import VueClipBoard from vue-clipboard2 Vue.use(VueClipBoard) 3.代码中使用 <button click"Copy">复制</button> Copy() { this.$copyText(this.value).then…

蓝桥杯备赛 week 3 —— 高精度(C/C++,零基础,配图)

目录 &#x1f308;前言&#xff1a; &#x1f4c1; 高精度的概念 &#x1f4c1; 高精度加法和其模板 &#x1f4c1; 高精度减法和其模板 &#x1f4c1; 高精度乘法和其模板 &#x1f4c1; 高精度除法和其模板 &#x1f4c1; 总结 &#x1f308;前言&#xff1a; 这篇文…

css Media媒体查询常用属性

使用@media规则声明媒体查询,主要用于控制在不同的设备上显示不同的效果 媒体类型: screen 适用于电脑屏幕、平板电脑、智能手机等 print 适用于打印预览 特性 width 可视区域的宽度 orientation 视窗的旋转方向(横屏landscape,默认竖屏模式)。 运算符: and 并且 , 或…

Linux/Academy

Enumeration nmap 首先扫描目标端口对外开放情况 nmap -p- 10.10.10.215 -T4 发现对外开放了22,80,33060三个端口&#xff0c;端口详细信息如下 结果显示80端口运行着http&#xff0c;且给出了域名academy.htb&#xff0c;现将ip与域名写到/et/hosts中&#xff0c;然后从ht…

Mysql 文件导入与导出

i/o 一、导出(mysqldump)<一>、导出sql文件<二>、导出csv文件 二、导入(load)三、常见报错The Mysql server is running with the --secure-file-priv option so it cannot execute this statement 一、导出(mysqldump) <一>、导出sql文件 1、整库 mysqld…

【12.PWM输出】蓝桥杯嵌入式一周拿奖速成系列

系列文章目录 蓝桥杯嵌入式系列文章目录(更多此系列文章可见) PWM输出 系列文章目录一、STM32CUBEMX配置二、项目代码1.main.c --> PWMOutputProcess 总结 一、STM32CUBEMX配置 STM32CUBEMX PA6 ->TIM16_CH1; PA7-> TIM17_CH1 预分频设置为79,自动重装载设置999PWM输…

PyQtGraph 之PlotCurveItem 详解

PyQtGraph 之PlotCurveItem 详解 PlotCurveItem 是 PyQtGraph 中用于显示曲线的图形项。以下是 PlotCurveItem 的主要参数和属性&#xff1a; 创建 PlotCurveItem 对象 import pyqtgraph as pg# 创建一个 PlotCurveItem curve pg.PlotCurveItem()常用的参数和属性 setData(…

资源管理核心考点梳理

个人总结&#xff0c;仅供参考&#xff0c;欢迎加好友一起讨论 PMP - 资源管理核心考点梳理 资源管理包括人力资源和实物资源管理。学习的重点是人力资源的管理&#xff0c;这一章是考试的重点章节&#xff0c;在新考纲中&#xff0c;“人”这一模块在题目种的比例是42%。 01 …

在uvm中,以svi结尾和sv结尾文件的区别

在UVM&#xff08;Universal Verification Methodology&#xff09;中&#xff0c;.sv和.svi文件扩展名通常是SystemVerilog文件的标准扩展名。它们都用来标识SystemVerilog源代码文件。然而&#xff0c;不同项目或团队可能会采用不同的命名约定来区分不同类型的SystemVerilog文…

14.块参照的旋转(BlockReference)

愿你出走半生,归来仍是少年&#xff01; 环境&#xff1a;.NET FrameWork4.5、ObjectArx 2016 64bit、Entity Framework 6. 在排水管网数据的编图时&#xff0c;时常会遇见针对雨水箅等进行旋转。由于数据存储在数据库内&#xff0c;通过CAD自带的旋转功能只能变更图面而无法…

YOLOv8改进 | Conv篇 | 利用轻量化PartialConv提出一种全新的结构CSPPC (参数量下降约100W)

一、本文介绍 本文给大家带来的改进机制是由我独家研制的,我结合了DualConv的思想并根据PartialConv提出了一种全新的结构CSPPC用来替换网络中的C2f,将其替换我们网络中的C2f参数量后直接下降约百万,计算量GFLOPs降低至6.0GFLOPs同时,其中的PartialConv作为一种具有高速推…

SVG 矩形 – SVG Rectangle (3)

简介 rect 元素用于创建 SVG 矩形和矩形图形的变体。有六个属性决定矩形在屏幕上的形状和位置 x, y – 矩形左上角的 x, y 坐标width、height – 矩形的宽度和高度rx、ry – 矩形角的 x 和 y 半径 如果没有设置 x 和 y 属性&#xff0c;则矩形的左上角放置在点 (0,0) 处。 如…

Python 中的多进程(01/2):简介

一、说明 本文简要而简明地介绍了 Python 编程语言中的多处理&#xff08;多进程&#xff09;。解释多处理的基本信息&#xff0c;如什么是多处理&#xff1f;为什么用多处理&#xff1f;在python中怎么办等。 二、什么是多处理&#xff1f; 多处理是指系统同时支持多个处理器的…