模型选择实战

我们现在可以通过多项式拟合来探索这些概念。

import math
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

生成数据集

给定x,我们将使用以下三阶多项式来生成训练和测试数据的标签:

max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = np.zeros(max_degree)  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)

对模型进行训练和测试

首先让我们实现一个函数来评估模型在给定数据集上的损失。

def evaluate_loss(net, data_iter, loss):  #@save"""评估给定数据集上模型的损失"""metric = d2l.Accumulator(2)  # 损失的总和,样本数量for X, y in data_iter:out = net(X)y = y.reshape(out.shape)l = loss(out, y)metric.add(l.sum(), l.numel())return metric[0] / metric[1]

现在定义训练函数。

def train(train_features, test_features, train_labels, test_labels,num_epochs=400):loss = nn.MSELoss(reduction='none')input_shape = train_features.shape[-1]# 不设置偏置,因为我们已经在多项式中实现了它net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1,1)),batch_size)test_iter = d2l.load_array((test_features, test_labels.reshape(-1,1)),batch_size, is_train=False)trainer = torch.optim.SGD(net.parameters(), lr=0.01)animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',xlim=[1, num_epochs], ylim=[1e-3, 1e2],legend=['train', 'test'])for epoch in range(num_epochs):d2l.train_epoch_ch3(net, train_iter, loss, trainer)if epoch == 0 or (epoch + 1) % 20 == 0:animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),evaluate_loss(net, test_iter, loss)))print('weight:', net[0].weight.data.numpy())

三阶多项式函数拟合(正常)

我们将首先使用三阶多项式函数,它与数据生成函数的阶数相同。 结果表明,该模型能有效降低训练损失和测试损失。 学习到的模型参数也接近真实值w=[5,1.2,−3.4,5.6]。

# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:])

线性函数拟合(欠拟合)

让我们再看看线性函数拟合,减少该模型的训练损失相对困难。 在最后一个迭代周期完成后,训练损失仍然很高。 当用来拟合非线性模式(如这里的三阶多项式函数)时,线性模型容易欠拟合。

# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],labels[:n_train], labels[n_train:])

高阶多项式函数拟合(过拟合)

现在,让我们尝试使用一个阶数过高的多项式来训练模型。 在这种情况下,没有足够的数据用于学到高阶系数应该具有接近于零的值。 因此,这个过于复杂的模型会轻易受到训练数据中噪声的影响。 虽然训练损失可以有效地降低,但测试损失仍然很高。 结果表明,复杂模型对数据造成了过拟合。

# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],labels[:n_train], labels[n_train:], num_epochs=1500)

总结:

  • 欠拟合是指模型无法继续减少训练误差。过拟合是指训练误差远小于验证误差。
  • 由于不能基于训练误差来估计泛化误差,因此简单地最小化训练误差并不一定意味着泛化误差的减小。机器学习模型需要注意防止过拟合,即防止泛化误差过大。
  • 验证集可以用于模型选择,但不能过于随意地使用它。
  • 我们应该选择一个复杂度适当的模型,避免使用数量不足的训练样本。

借鉴:

4.4. 模型选择、欠拟合和过拟合 — 动手学深度学习 2.0.0 documentation (d2l.ai)

手学深度学习 2.0.0 documentation (d2l.ai)](https://zh-v2.d2l.ai/chapter_multilayer-perceptrons/underfit-overfit.html#id11)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/645936.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端面试题-(浏览器内核,CSS选择器优先级,盒子模型,CSS硬件加速,CSS扩展)

前端面试题-(浏览器内核,CSS选择器优先级,盒子模型,CSS硬件加速,CSS扩展) 常见的浏览器内核CSS选择器优先级盒子模型CSS硬件加速CSS扩展 常见的浏览器内核 内核描述Trident(IE内核)主要用在window系统中的IE浏览器中&…

分布式锁实现(mysql,以及redis)以及分布式的概念(续)redsync包使用

道生一,一生二,二生三,三生万物 这张尽量结合上一章进行使用:上一章 这章主要是讲如何通过redis实现分布式锁的 redis实现 这里我用redis去实现: 技术:golang,redis,数据结构 …

使用Python的pygame库实现自动追踪目标的Snake游戏

和上一期不同的目标追踪入门不同的是,这期是自动追踪科学游戏,话不多说,321上链接 一、项目背景 Snake游戏是一款经典的游戏,玩家需要控制一条蛇在屏幕上移动,吃掉食物并避免撞到自己的身体或墙壁。传统的Snake游戏通常…

校园跑腿小程序源码系统+代取快递+食堂超市代买+跑腿 带完整的安装代码包以及搭建教程

随着移动互联网的普及,人们越来越依赖于手机应用来解决日常生活中的各种问题。特别是在校园内,由于快递点距离宿舍较远、食堂排队人数过多等情况,学生对于便捷、高效的服务需求愈发强烈。在此背景下,校园跑腿小程序源码系统应运而…

蓝桥杯备赛 week 3 —— 高精度(C/C++,零基础,配图)

目录 🌈前言: 📁 高精度的概念 📁 高精度加法和其模板 📁 高精度减法和其模板 📁 高精度乘法和其模板 📁 高精度除法和其模板 📁 总结 🌈前言: 这篇文…

Linux/Academy

Enumeration nmap 首先扫描目标端口对外开放情况 nmap -p- 10.10.10.215 -T4 发现对外开放了22,80,33060三个端口,端口详细信息如下 结果显示80端口运行着http,且给出了域名academy.htb,现将ip与域名写到/et/hosts中,然后从ht…

【12.PWM输出】蓝桥杯嵌入式一周拿奖速成系列

系列文章目录 蓝桥杯嵌入式系列文章目录(更多此系列文章可见) PWM输出 系列文章目录一、STM32CUBEMX配置二、项目代码1.main.c --> PWMOutputProcess 总结 一、STM32CUBEMX配置 STM32CUBEMX PA6 ->TIM16_CH1; PA7-> TIM17_CH1 预分频设置为79,自动重装载设置999PWM输…

PyQtGraph 之PlotCurveItem 详解

PyQtGraph 之PlotCurveItem 详解 PlotCurveItem 是 PyQtGraph 中用于显示曲线的图形项。以下是 PlotCurveItem 的主要参数和属性: 创建 PlotCurveItem 对象 import pyqtgraph as pg# 创建一个 PlotCurveItem curve pg.PlotCurveItem()常用的参数和属性 setData(…

资源管理核心考点梳理

个人总结,仅供参考,欢迎加好友一起讨论 PMP - 资源管理核心考点梳理 资源管理包括人力资源和实物资源管理。学习的重点是人力资源的管理,这一章是考试的重点章节,在新考纲中,“人”这一模块在题目种的比例是42%。 01 …

14.块参照的旋转(BlockReference)

愿你出走半生,归来仍是少年! 环境:.NET FrameWork4.5、ObjectArx 2016 64bit、Entity Framework 6. 在排水管网数据的编图时,时常会遇见针对雨水箅等进行旋转。由于数据存储在数据库内,通过CAD自带的旋转功能只能变更图面而无法…

SVG 矩形 – SVG Rectangle (3)

简介 rect 元素用于创建 SVG 矩形和矩形图形的变体。有六个属性决定矩形在屏幕上的形状和位置 x, y – 矩形左上角的 x, y 坐标width、height – 矩形的宽度和高度rx、ry – 矩形角的 x 和 y 半径 如果没有设置 x 和 y 属性,则矩形的左上角放置在点 (0,0) 处。 如…

Python 中的多进程(01/2):简介

一、说明 本文简要而简明地介绍了 Python 编程语言中的多处理(多进程)。解释多处理的基本信息,如什么是多处理?为什么用多处理?在python中怎么办等。 二、什么是多处理? 多处理是指系统同时支持多个处理器的…

C语言第八弹---一维数组

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】 一维数组 1、数组的概念 2、⼀维数组的创建和初始化 2.1、数组创建 2.2、数组的初始化 2.3、数组的类型 3、⼀维数组的使用 3.1、数组下标 3.2、数组元素…

Vscode配置python代码开发

文章目录 1. 配置python运行环境2. 常用插件说明3. Vscode配置文件说明3.1 setting.json配置说明3.2 launch.json配置说明 4. 远程开发5. 其他配置 1. 配置python运行环境 安装python插件:点击VSCode左侧边栏中的扩展图标(或按 CtrlShiftX)&a…

从方法论到最佳实践,深度解析企业云原生 DevSecOps 体系构建

作者:匡大虎 引言 安全一直是企业上云关注的核心问题。随着云原生对云计算基础设施和企业应用架构的重定义,传统的企业安全防护架构已经不能够满足新时期下的安全防护要求。为此企业安全人员需要针对云原生时代的安全挑战重新进行系统性的威胁分析并构…

深度视觉目标跟踪进展综述-论文笔记

中科大学报上的一篇综述,总结得很详细,整理了相关笔记。 1 引言 目标跟踪旨在基于初始帧中指定的感兴趣目标( 一般用矩形框表示) ,在后续帧中对该目标进行持续的定位。 基于深度学习的跟踪算法,采用的框架包括相关滤波器、分类…

Rust 通用代码生成器莲花发布红莲尝鲜版二十视频,支持 Nodejs 21,18 和 14

Rust 通用代码生成器莲花发布红莲尝鲜版二十视频,支持 Nodejs 21,18 和 14 Rust 通用代码生成器莲花发布红莲尝鲜版二十视频。此版本开始支持 Nodejs21,18 加上原来支持的 Nodejs 14。现在莲花支持三种 Nodejs 环境。适应性大大增强,也给您的使用带来了…

IDEA配置Maven教程

1.Maven下载 首先我们进入maven官方网站Maven – Welcome to Apache Maven,进入网页后,点击Download去下载 下载免安装版,解压即可,解压至磁盘任意目录,尽量不要取中文名如下图: 2.配置Maven环境变量 复制Maven所在的…

cms中getshell的各种姿势

cms中getshell的各种姿势 wordpress----getshell 这里wordpress后台,外观,主题,编辑,修改其中的404模版,保存后就可拿到shell 直接访问,就可以成功连接 另外,在主题中,可以上传 …

[蓝桥杯]真题讲解:景区导游(DFS遍历、图的存储、树上前缀和与LCA)

蓝桥杯真题讲解&#xff1a; 一、视频讲解二、暴力代码三、正解代码 一、视频讲解 视频讲解 二、暴力代码 //暴力代码&#xff1a;DFS #include<bits/stdc.h> #define endl \n #define deb(x) cout << #x << " " << x << \n; #de…