李沐深度学习-多项式函数拟合试验

d2lzh_pytorch模块跳转连接

import torch
import numpy as np
import syssys.path.append("路径")
import d2lzh_pytorch as d2l'''
-----------------------------生成人工数据集
样本数n=200
特征数=3
三阶多项式y=1.2x-3.4x^2+5.6x^3+5+ε
'''
n_train, n_test, true_w, true_b = 100, 100, [1.2, -3.4, 5.6], 5
sample_features = torch.randn(n_train + n_test, 1)  # 200x1   单算的一个特征
poly_sample_features = torch.cat((sample_features, torch.pow(sample_features, 2), torch.pow(sample_features, 3)),dim=1)  # 组合成3个特征
# 因为poly_features取列的元素时,没有对列加[]限制,所以取出来的值不保有原来的维度,而是成为了一维张量,所以labels相应的也是一维张量
labels = true_w[0] * poly_sample_features[:, 0] + true_w[1] * poly_sample_features[:, 1] + true_w[2] * poly_sample_features[:, 2] + true_b
labels += torch.tensor(np.random.normal(0, 0.01, (labels.size())), dtype=torch.float)  # 以上是为了得到真实labels# print(poly_sample_features, '\n', poly_sample_features[:2], '\n', poly_sample_features[:, 2],
# '\n', poly_sample_features[:, :2])'''
-----------------------------------------------------定义,训练和测试模型
'''
# 尝试使用不同复杂度的模型来拟合生成的数据集
num_epochs, loss = 100, torch.nn.MSELoss()'''
以下函数思路设计:1. 参数传入训练数据集样本,测试数据集样本,训练标签,测试标签2. 设计网络,网络输入特征计算格式3. 设计数据读取,读取训练数据集4. 循环更新迭代步骤,目的是为了优化w,b5. 循环外使用全批量训练数据集和测试数据集,在已经更新好的w,b的基础上进行损失计算6. 画图
'''def fit_and_plot(train_features, test_features, train_labels, test_labels, label):net = torch.nn.Linear(train_features.shape[-1], 1)# Linear 自动初始化了模型参数batch_size = min(10, train_labels.shape[0])dataset = torch.utils.data.TensorDataset(train_features, train_labels)train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)optimizer = torch.optim.SGD(net.parameters(), lr=0.01)train_ls, test_ls = [], []for _ in range(num_epochs):for X, y in train_iter:  # 里面的循环作用只是为了更迭模型参数y_hat = net(X)l = loss(y_hat, y.view(y_hat.size()))  # 计算每一批的平均损失optimizer.zero_grad()l.backward()optimizer.step()train_labels = train_labels.view(-1, 1)test_labels = test_labels.view(-1, 1)  # 做这一步形状改变,是为了接下来的循环外的损失计算,损失计算时要求形状相同train_ls.append(loss(net(train_features), train_labels).item())  # 一个循环周期后记录一次更新后的参数的损失表现test_ls.append(loss(net(test_features), test_labels).item())  # 直接一整个没有分批量就进行了损失计算,使用的是最新的w,bprint(f'final epoch: train loss', train_ls[-1], 'test loss', test_ls[-1])d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss', label,range(1, num_epochs + 1), test_ls, ['train', 'test'])print('weight:', net.weight.data,'\nbias:', net.bias.data)'''
-----------------------------------------------------------------三阶多项式函数拟合(正常)
'''
fit_and_plot(poly_sample_features[:n_train, :], poly_sample_features[n_train:, :], labels[:n_train], labels[n_train:],'正常')'''
----------------------------------------------------------------线性函数拟合(欠拟合)
'''
# 使用的三阶多项式生成的数据标签,但是训练用的数据集是单特征的数据集,而不是三特征的数据集,这样模型就是线性模型,而不是非线性模型
# labels 还是一维张量形状
fit_and_plot(sample_features[:n_train, :], sample_features[n_train:, :], labels[:n_train], labels[n_train:], '欠拟合')'''
----------------------------------------------------------------训练样本不足(过拟合)
'''
fit_and_plot(poly_sample_features[0:2, :], poly_sample_features[:n_train, :], labels[0:2], labels[:n_train], '过拟合')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/637271.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

免费三款备受推崇的爬虫软件

在信息爆炸的时代,爬虫软件成为了数据采集、信息挖掘的得力工具。为了解决用户对优秀爬虫软件的需求,本文将专心分享三款备受推崇的爬虫软件,其中特别突出推荐147采集软件,为您开启爬虫软件的奇妙世界。 一、爬虫软件的重要性 爬…

使用OpenCV绘制图形

使用OpenCV绘制图形 绘制黄色的线: # 绘制一个黑色的背景画布 canvas np.zeros((300, 300, 3), np.uint8) # 在画布上,绘制一条起点坐标为(150, 50)、终点坐标为(150, 250),黄色的,线条宽度为20的线段 canvas cv2.line(canvas,…

迭代器模式介绍

目录 一、迭代器模式介绍 1.1 迭代器模式定义 1.2 迭代器模式原理 1.2.1 迭代器模式类图 1.2.2 模式角色说明 1.2.3 示例代码 二、迭代模式的应用 2.1 需求说明 2.2 需求实现 2.2.1 抽象迭代类 2.2.2 抽象集合类 2.2.3 主题类 2.2.4 具体迭代类 2.2.5 具体集合类 …

模拟外卖平台商家菜品上架系统

目的:模仿平台商品上架,完成外卖商家系统 需求:1.需要完成商家菜品上架操作;2.需要完成所有菜品信息的浏览; 分析: 步骤: 1.确定操作对象,并创建对象类以及对象操作类;…

tcp/ip协议2实现的插图,数据结构7 (27 - 章)

(166) 166 二七1 TCP的函数 函tcp_drain,tcp_drop (167) (168)

Windows WSL2 占用磁盘空间清理释放

目前工作中时常用到WSL2(Ubuntu20.04),在使用一段时间后会发现WSL2所占用磁盘空间越来越多,体现在WSL2之上安装Linux分发对应的vhdx虚拟磁盘文件体积越来越大,会占用Windows自身空间,即使手动清理了Linux分…

GD32E230C8T6《调试篇》之 (软件) IIC通信(主机接收从机) + GN1650驱动芯片 + 按键 + 4位8段数码管显示 (成功)

GD32E230C8T6《调试篇》之 (软件) IIC通信 GN1650驱动芯片 4位8段数码管显示(成功) IIC是什么IIC简介1)IIC总线物理连接2)IIC时序协议 按键扫描代码1)DIG2短按只一次,长按超过1s 一…

若依微服务框架,富文本加入图片保存时出现JSON parse error: Unexpected character (‘/‘ (code 47)):...

若依微服务框架,富文本加入图片保存时出现JSON parse error: Unexpected character 一、问题二、解决1.修改网关配置2、对数据进行加密解密2.1安装插件2.2vue页面加密使用2.3后台解密存储 一、问题 若依微服务项目在使用富文本框的时候,富文本加入图片进…

【Java程序员面试专栏 专业技能篇】MySQL核心面试指引(一):基础知识考察

关于MySQL部分的核心知识进行一网打尽,包括三部分:基础知识考察、核心机制策略、性能优化策略,通过一篇文章串联面试重点,并且帮助加强日常基础知识的理解,全局思维导图如下所示 本篇Blog为第一部分:基础知识考察,子节点表示追问或同级提问 基本概念 包括一些核心问…

Python中的卷积神经网络(CNN)入门

卷积神经网络(Convolutional Neural Networks, CNN)是一类特别适用于处理图像数据的深度学习模型。在Python中,我们可以使用流行的深度学习库TensorFlow和Keras来创建和训练一个CNN模型。在本文中,我们将介绍如何使用Keras创建一个…

ARMv8-AArch64 的异常处理模型详解之异常类型 Exception types

异常类型详解 Exception types 一, 什么是异常二,同步异常(synchronous exceptions)2.1 无效的指令和陷阱异常(Invalid instructions and trap exceptions)2.2 内存访问产生的异常2.3 产生异常的指令2.4 调…

构建 aarch64 以及 riscv64 交叉编译工具链(裸机)

构建 aarch64 以及 riscv64 交叉编译工具链(裸机) 因为我的需求是构建裸机的程序,所以我选择了裸机相关的交叉工具链 其他工具链也类似,在给出的两个官方链接中提供了所有的交叉工具链,选择合适的工具构建即可 一、…

基于JavaWeb+SSM+Vue智能社区服务小程序系统的设计和实现

基于JavaWebSSMVue智能社区服务小程序系统的设计和实现 滑到文末获取源码Lun文目录前言主要技术系统设计功能截图订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 滑到文末获取源码 Lun文目录 目录 1系统概述 1 1.1 研究背景 1 1.2研究目的 1 1.3系统设计思想 1 2相…

【排序算法】六、快速排序(C/C++)

「前言」文章内容是排序算法之快速排序的讲解。(所有文章已经分类好,放心食用) 「归属专栏」排序算法 「主页链接」个人主页 「笔者」枫叶先生(fy) 目录 快速排序1.1 原理1.2 Hoare版本(单趟)1.3 快速排序完整代码&…

Excel 根据日期按月汇总公式

Excel 根据日期按月汇总公式 数据透视表日期那一列右击,选择“组合”,步长选择“月” 参考 Excel 根据日期按月汇总公式Excel如何按着日期来做每月求和

Linux内存管理:(九)内存规整

文章说明: Linux内核版本:5.0 架构:ARM64 参考资料及图片来源:《奔跑吧Linux内核》 Linux 5.0内核源码注释仓库地址: zhangzihengya/LinuxSourceCode_v5.0_study (github.com) 1. 引言 伙伴系统以页面为单位来管…

leetcode:每日温度---单调栈

题目: 给定一个整数数组 temperatures ,表示每天的温度,返回一个数组 answer ,其中 answer[i] 是指对于第 i 天,下一个更高温度出现在几天后。如果气温在这之后都不会升高,请在该位置用 0 来代替。 示例&…

js数组的截取和合并

在JavaScript中,你可以使用slice()方法来截取数组,使用concat()方法来合并数组。 截取数组 slice()方法返回一个新的数组对象,这个对象是一个由原数组的一部分浅复制而来。它接受两个参数,第一个参数是开始截取的位置&#xff08…

代码随想录day24

回溯算法 回溯的本质是穷举,穷举所有可能,然后选出我们想要的答案,如果想让回溯法高效一些,可以加一些剪枝的操作。 回溯法,一般可以解决如下几种问题: 1、组合问题:N个数里面按一定规则找出k个…

天龙八部资源提取工具(提取+添加+修改+查看+教程)

可以提取,添加,修改,查看天龙八部里面的数据。非常好用。 天龙八部资源提取工具(提取添加修改查看教程) 下载地址: 链接:https://pan.baidu.com/s/1XOMJ1xvsbD-UUQOv3QfHPQ?pwd0kd0 提取码&…