【PyTorch】多项式回归

文章目录

  • 1. 模型与代码实现
    • 1.1. 模型
    • 1.2. 代码实现
      • 1.2.1. 完整代码
      • 1.2.2. 输出结果
  • 2. Q&A
    • 2.1. 欠拟合与过拟合

1. 模型与代码实现

1.1. 模型

  • 将多项式特征值预处理为线性模型的特征值。即
    y = w 0 + w 1 x + w 2 x 2 + ⋯ + w n x n y = w_0+w_1x+w_2x^2+\dots+w_nx^n y=w0+w1x+w2x2++wnxn变换为 y = w 0 + w 1 z 1 + w 2 z 2 + ⋯ + w n z n y=w_0+w_1z_1+w_2z_2+\dots+w_nz_n y=w0+w1z1+w2z2++wnzn
  • 为了避免指数值过大,可以将 x i x^i xi调整为 x i i ! \frac{x^i}{i!} i!xi,即 y = w 0 + w 1 x 1 ! + w 2 x 2 2 ! + ⋯ + w n x n n ! y = w_0+w_1\frac{x}{1!}+w_2\frac{x^2}{2!}+\dots+w_n\frac{x^n}{n!} y=w0+w11!x+w22!x2++wnn!xn

1.2. 代码实现

1.2.1. 完整代码

import os
import numpy as np
import math, torch
from d2l import torch as d2l
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tensorboardX import SummaryWriter
from rich.progress import trackdef evaluate_loss(dataloader):"""评估给定数据集上模型的损失"""metric.reset()with torch.no_grad():for X, y in dataloader:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)loss = criterion(net(X), y)metric.add(loss.sum(), loss.numel())return metric[0] / metric[1]def load_dataset(data_arrays):"""加载数据集"""dataset = TensorDataset(*data_arrays)return DataLoader(dataset, batch_size, shuffle=True, pin_memory=True,num_workers=num_workers, prefetch_factor=prefetch_factor)if __name__ == '__main__':# 全局参数设置learning_rate = 0.01device = torch.device("cuda" if torch.cuda.is_available() else "cpu")num_epochs = 400batch_size = 10num_workers = 0prefetch_factor = 2max_degree = 20             # 多项式最高阶数model_degree = 1           # 多项式模型阶数n_train, n_test = 100, 100  # 训练集和测试集大小true_w = np.zeros(max_degree+1)true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])# 创建记录器def get_logdir():root = 'runs'if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'runs/exp{order}'writer = SummaryWriter(get_logdir())# 生成数据集features = np.random.normal(size=(n_train + n_test, 1))np.random.shuffle(features)poly_features = np.power(features, np.arange(max_degree+1).reshape(1, -1))for i in range(max_degree+1):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!labels = np.dot(poly_features, true_w)labels += np.random.normal(scale=0.1, size=labels.shape)    # 加高斯噪声服从N(0, 0.01)poly_features, labels = [torch.as_tensor(x, dtype=torch.float32) for x in [poly_features, labels]]# 创建模型net = nn.Sequential(nn.Linear(model_degree+1, 1, bias=False)).to(device, non_blocking=True)def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)net.apply(init_weights)criterion = nn.MSELoss(reduction='none')optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)# 加载数据集features_train, labels_train = poly_features[:n_train, :model_degree+1], labels[:n_train].reshape(-1, 1)features_test, labels_test = poly_features[n_train:, :model_degree+1], labels[n_train:].reshape(-1, 1)dataloader_train = load_dataset((features_train, labels_train))dataloader_test = load_dataset((features_test, labels_test))# 训练循环metric = d2l.Accumulator(2)  # 损失的总和, 样本数量for epoch in track(range(num_epochs)):for X, y in dataloader_train:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()writer.add_scalars(f"{model_degree}-degree", {"train_loss": evaluate_loss(dataloader_train),"test_loss": evaluate_loss(dataloader_test),}, epoch)print("weights =", net[0].weight.data.cpu().numpy())writer.close()

1.2.2. 输出结果

  • 采用1阶多项式(线性模型)拟合:
    1degree

  • 采用3阶多项式拟合
    3degree

  • 采用20阶多项式拟合
    20degree

2. Q&A

2.1. 欠拟合与过拟合

数据集是按照3阶多项式生成的。使用1阶多项式去拟合,发现最后损失始终降不下去,这种情况称为欠拟合,说明模型复杂度不够;使用20阶多项式去拟合,发现测试损失最后还增长了,训练和测试损失总体也比3阶多项式模型的值高,这种情况称为过拟合,说明模型太复杂了,训练过程受到了噪声的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/202379.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开关电源超强总结

什么是Power Supply? 开关电源的元件构成 三种基本的非隔离开关电源 三种基本的隔离开关电源 反激变换器(Flyback)工作原理 (电流连续模式) 反激变换器(Flyback)工作原理 (电流断续模式&#x…

js中批量修改对象属性

首先,有这个对象 let a {id: 1,name: 张三,age: 18,sex: 0 } 需求:同时修改name,id,并添加一个新属性c 常规写法: a.id 2; a.name 李四; a.c 1; 方法1:使用Object.assign() Object.assign()常用来拷贝合并对象,相同属性…

信息化系列——企业信息化建设(3)

期待已久的对策,马上”出炉“,第一次看的朋友,建议现在主页看看(1)和(2),那咱们就废话少说了,开始今天的正题。 企业信息化建设对策 1、增强企业信息化意识 企业管理者…

【Python】Python读Excel文件生成xml文件

目录 ​前言 正文 1.Python基础学习 2.Python读取Excel表格 2.1安装xlrd模块 2.2使用介绍 2.2.1常用单元格中的数据类型 2.2.2 导入模块 2.2.3打开Excel文件读取数据 2.2.4常用函数 2.2.5代码测试 2.2.6 Python操作Excel官方网址 3.Python创建xml文件 3.1 xml语法…

简单选择排序显示第K趟

感悟&#xff1a;一定要小心细节&#xff0c;循环中注意要是否需要重新赋值 #include <stdio.h> int main() { int c 0; int b 0; int n 0; int k 0; int i 0; int j 0; int max 0; int z 0; int i1 0; int temp 0; …

PACS源码,医学影像传输系统源码,全院级应用,支持放射、超声、内窥镜、病理等影像科室,且具备多种图像处理及三维重建功能

​三维智能PACS系统源码&#xff0c;医学影像采集传输系统源码 PACS系统以大型关系型数据库作为数据和图像的存储管理工具&#xff0c;以医疗影像的采集、传输、存储和诊断为核心&#xff0c;集影像采集传输与存储管理、影像诊断查询与报告管理、综合信息管理等综合应用于一体的…

GUAVA 工具类

Guava是一个Google的开源Java库&#xff0c;常用的工具&#xff1a; 集合工具类&#xff0c;包括Lists&#xff08;创建&#xff1a;newArrayList、newLinkedList等&#xff09;、Sets&#xff08;创建&#xff1a;newHashSet、newLinkedHashSet等&#xff09;和Maps&#xff…

破阵子(三分+凸包旋转卡壳)

Description 平面上有n个点&#xff0c;每个点有各自的速度向量&#xff0c;现在给出0时刻&#xff0c;在同一时刻&#xff0c;平面点的最远距离叫做special dis 他们每个点的位置和每个点的速度向量&#xff0c;现在求在哪个时刻的时候&#xff0c;他们的special dis 最小&am…

postgres(pg)数据库使用建表语句创建数据表

一般创建数据表有两种方式&#xff0c;一种是使用建表语句&#xff0c;二是使用图形化工具建表&#xff08;如&#xff1a;pgadmin4、Navicat、DataGrip、dbeaver等等之类的工具&#xff09;&#xff1b; 1、使用建表语句创建数据表&#xff1a; -- 建立学生测试表语句如下&am…

Java问题和解决方案汇总

将其他类型转换成数值类型的解决方案 例&#xff1a;Integer转成Double类型 Double.parseDouble(a.toString()); 嵌套Map中&#xff0c;拿到里层Map的value(例&#xff1a;Map.get("xxx").get("xxx")) 主要的目的是为了得到第二个get&#xff0c;只要将第一…

接口测试:轻松掌握基础知识,快速提升测试技能!

1.client端和server端 开始接口测试之前&#xff0c;首先搞清楚client端与server端是什么&#xff0c;区别。 web前端&#xff0c;顾名思义&#xff0c;指用户可以直观操作和看到的界面&#xff0c;包括web页面的结构&#xff0c;web的外观视觉表现及web层面的交互实…

顶级设计师力荐的界面设计软件,设计新选择

即时设计 作为专业的在线协作UI设计软件&#xff0c;即时设计可以实现视觉效果、交互效果、体验效果一站成型&#xff0c;为你的目标用户创造流畅体验。 轻松绘制原型&#xff1a;借助社区设计资源和原型模板的即时设计&#xff0c;开始敏捷高效的工作。与产品经理分解用户需…

E. Good Triples

首先 如果产生进位的话是一定不对的&#xff0c;因为进位会给一个1&#xff0c;但是损失了10 然后可以按位直接考虑&#xff0c;转换成一个隔板法组合数问题 // Problem: E. Good Triples // Contest: Codeforces - Codeforces Round 913 (Div. 3) // URL: https://codeforces…

xShell快捷键

Xshell 是一个强大的终端仿真器&#xff0c;它支持多种Linux发行版的远程连接。Xshell提供了一系列的快捷键&#xff0c;以提高用户的操作效率。以下是一些Xshell中常用的快捷键&#xff1a; 新建会话窗口&#xff1a; Ctrl N 或 Ctrl Shift N 在现有会话中打开新标签&…

反序列化漏洞详解(二)

目录 pop链前置知识&#xff0c;魔术方法触发规则 pop构造链解释&#xff08;开始烧脑了&#xff09; 字符串逃逸基础 字符减少 字符串逃逸基础 字符增加 实例获取flag 字符串增多逃逸 字符串减少逃逸 延续反序列化漏洞(一)的内容 pop链前置知识&#xff0c;魔术方法触…

python 深度图转换为点云

一、概念 深度图是点云由3D点投影到2D平面的逆过程,其中每个像素值代表的是物体到相机xy平面的距离。深度图可以提供场景中某一点距离摄像机的远近信息。 二、python代码 import numpy as np import open3d as o3d import os# Depth Intrinsic Parameters fx_d = 7.8128789…

软件测试之python+requests接口自动化测试框架实例教程

前段时间由于公司测试方向的转型&#xff0c;由原来的web页面功能测试转变成接口测试&#xff0c;之前大多都是手工进行&#xff0c;利用postman和jmeter进行的接口测试&#xff0c;后来&#xff0c;组内有人讲原先web自动化的测试框架移驾成接口的自动化框架&#xff0c;使用的…

HTTPS安全防窃听、防冒充、防篡改三大机制原理

前言 本文内容主要对以下两篇文章内容整理过滤&#xff0c;用最直观的角度了解到HTTPS的保护机制&#xff0c;当然啦&#xff0c;如果想要深入了解HTTPS&#xff0c;本文是远远不够的&#xff0c;可以针对以下第一个链接中的文章拓展板块进行学习&#xff0c;希望大家通过本文能…

WAT、CP、FT的概念及周边名词解释

CP是把坏的Die挑出来&#xff0c;可以减少封装和测试的成本。可以更直接的知道Wafer 的良率。 FT是把坏的chip挑出来&#xff1b;检验封装的良率。 现在对于一般的wafer工艺&#xff0c;很多公司多把CP给省了&#xff0c;减少成本。 CP对整片Wafer的每个Die来测试&#xff0…

光伏系统方案设计的注意点

随着太阳能技术的日益发展&#xff0c;光伏系统已经成为一种重要的可再生能源解决方案。然而&#xff0c;设计一个高效、可靠的光伏系统需要考虑到许多因素。本文将探讨光伏系统方案设计的注意点&#xff0c;包括系统规模、地理位置、组件选择、系统布局和运维策略。 系统规模 …