【PyTorch】多项式回归

文章目录

  • 1. 模型与代码实现
    • 1.1. 模型
    • 1.2. 代码实现
      • 1.2.1. 完整代码
      • 1.2.2. 输出结果
  • 2. Q&A
    • 2.1. 欠拟合与过拟合

1. 模型与代码实现

1.1. 模型

  • 将多项式特征值预处理为线性模型的特征值。即
    y = w 0 + w 1 x + w 2 x 2 + ⋯ + w n x n y = w_0+w_1x+w_2x^2+\dots+w_nx^n y=w0+w1x+w2x2++wnxn变换为 y = w 0 + w 1 z 1 + w 2 z 2 + ⋯ + w n z n y=w_0+w_1z_1+w_2z_2+\dots+w_nz_n y=w0+w1z1+w2z2++wnzn
  • 为了避免指数值过大,可以将 x i x^i xi调整为 x i i ! \frac{x^i}{i!} i!xi,即 y = w 0 + w 1 x 1 ! + w 2 x 2 2 ! + ⋯ + w n x n n ! y = w_0+w_1\frac{x}{1!}+w_2\frac{x^2}{2!}+\dots+w_n\frac{x^n}{n!} y=w0+w11!x+w22!x2++wnn!xn

1.2. 代码实现

1.2.1. 完整代码

import os
import numpy as np
import math, torch
from d2l import torch as d2l
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tensorboardX import SummaryWriter
from rich.progress import trackdef evaluate_loss(dataloader):"""评估给定数据集上模型的损失"""metric.reset()with torch.no_grad():for X, y in dataloader:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)loss = criterion(net(X), y)metric.add(loss.sum(), loss.numel())return metric[0] / metric[1]def load_dataset(data_arrays):"""加载数据集"""dataset = TensorDataset(*data_arrays)return DataLoader(dataset, batch_size, shuffle=True, pin_memory=True,num_workers=num_workers, prefetch_factor=prefetch_factor)if __name__ == '__main__':# 全局参数设置learning_rate = 0.01device = torch.device("cuda" if torch.cuda.is_available() else "cpu")num_epochs = 400batch_size = 10num_workers = 0prefetch_factor = 2max_degree = 20             # 多项式最高阶数model_degree = 1           # 多项式模型阶数n_train, n_test = 100, 100  # 训练集和测试集大小true_w = np.zeros(max_degree+1)true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])# 创建记录器def get_logdir():root = 'runs'if not os.path.exists(root):os.mkdir(root)order = len(os.listdir(root)) + 1return f'runs/exp{order}'writer = SummaryWriter(get_logdir())# 生成数据集features = np.random.normal(size=(n_train + n_test, 1))np.random.shuffle(features)poly_features = np.power(features, np.arange(max_degree+1).reshape(1, -1))for i in range(max_degree+1):poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!labels = np.dot(poly_features, true_w)labels += np.random.normal(scale=0.1, size=labels.shape)    # 加高斯噪声服从N(0, 0.01)poly_features, labels = [torch.as_tensor(x, dtype=torch.float32) for x in [poly_features, labels]]# 创建模型net = nn.Sequential(nn.Linear(model_degree+1, 1, bias=False)).to(device, non_blocking=True)def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)net.apply(init_weights)criterion = nn.MSELoss(reduction='none')optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)# 加载数据集features_train, labels_train = poly_features[:n_train, :model_degree+1], labels[:n_train].reshape(-1, 1)features_test, labels_test = poly_features[n_train:, :model_degree+1], labels[n_train:].reshape(-1, 1)dataloader_train = load_dataset((features_train, labels_train))dataloader_test = load_dataset((features_test, labels_test))# 训练循环metric = d2l.Accumulator(2)  # 损失的总和, 样本数量for epoch in track(range(num_epochs)):for X, y in dataloader_train:X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)loss = criterion(net(X), y)optimizer.zero_grad()loss.mean().backward()optimizer.step()writer.add_scalars(f"{model_degree}-degree", {"train_loss": evaluate_loss(dataloader_train),"test_loss": evaluate_loss(dataloader_test),}, epoch)print("weights =", net[0].weight.data.cpu().numpy())writer.close()

1.2.2. 输出结果

  • 采用1阶多项式(线性模型)拟合:
    1degree

  • 采用3阶多项式拟合
    3degree

  • 采用20阶多项式拟合
    20degree

2. Q&A

2.1. 欠拟合与过拟合

数据集是按照3阶多项式生成的。使用1阶多项式去拟合,发现最后损失始终降不下去,这种情况称为欠拟合,说明模型复杂度不够;使用20阶多项式去拟合,发现测试损失最后还增长了,训练和测试损失总体也比3阶多项式模型的值高,这种情况称为过拟合,说明模型太复杂了,训练过程受到了噪声的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/202379.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开关电源超强总结

什么是Power Supply? 开关电源的元件构成 三种基本的非隔离开关电源 三种基本的隔离开关电源 反激变换器(Flyback)工作原理 (电流连续模式) 反激变换器(Flyback)工作原理 (电流断续模式&#x…

信息化系列——企业信息化建设(3)

期待已久的对策,马上”出炉“,第一次看的朋友,建议现在主页看看(1)和(2),那咱们就废话少说了,开始今天的正题。 企业信息化建设对策 1、增强企业信息化意识 企业管理者…

【Python】Python读Excel文件生成xml文件

目录 ​前言 正文 1.Python基础学习 2.Python读取Excel表格 2.1安装xlrd模块 2.2使用介绍 2.2.1常用单元格中的数据类型 2.2.2 导入模块 2.2.3打开Excel文件读取数据 2.2.4常用函数 2.2.5代码测试 2.2.6 Python操作Excel官方网址 3.Python创建xml文件 3.1 xml语法…

PACS源码,医学影像传输系统源码,全院级应用,支持放射、超声、内窥镜、病理等影像科室,且具备多种图像处理及三维重建功能

​三维智能PACS系统源码,医学影像采集传输系统源码 PACS系统以大型关系型数据库作为数据和图像的存储管理工具,以医疗影像的采集、传输、存储和诊断为核心,集影像采集传输与存储管理、影像诊断查询与报告管理、综合信息管理等综合应用于一体的…

接口测试:轻松掌握基础知识,快速提升测试技能!

1.client端和server端 开始接口测试之前,首先搞清楚client端与server端是什么,区别。 web前端,顾名思义,指用户可以直观操作和看到的界面,包括web页面的结构,web的外观视觉表现及web层面的交互实…

顶级设计师力荐的界面设计软件,设计新选择

即时设计 作为专业的在线协作UI设计软件,即时设计可以实现视觉效果、交互效果、体验效果一站成型,为你的目标用户创造流畅体验。 轻松绘制原型:借助社区设计资源和原型模板的即时设计,开始敏捷高效的工作。与产品经理分解用户需…

E. Good Triples

首先 如果产生进位的话是一定不对的,因为进位会给一个1,但是损失了10 然后可以按位直接考虑,转换成一个隔板法组合数问题 // Problem: E. Good Triples // Contest: Codeforces - Codeforces Round 913 (Div. 3) // URL: https://codeforces…

反序列化漏洞详解(二)

目录 pop链前置知识,魔术方法触发规则 pop构造链解释(开始烧脑了) 字符串逃逸基础 字符减少 字符串逃逸基础 字符增加 实例获取flag 字符串增多逃逸 字符串减少逃逸 延续反序列化漏洞(一)的内容 pop链前置知识,魔术方法触…

软件测试之python+requests接口自动化测试框架实例教程

前段时间由于公司测试方向的转型,由原来的web页面功能测试转变成接口测试,之前大多都是手工进行,利用postman和jmeter进行的接口测试,后来,组内有人讲原先web自动化的测试框架移驾成接口的自动化框架,使用的…

HTTPS安全防窃听、防冒充、防篡改三大机制原理

前言 本文内容主要对以下两篇文章内容整理过滤,用最直观的角度了解到HTTPS的保护机制,当然啦,如果想要深入了解HTTPS,本文是远远不够的,可以针对以下第一个链接中的文章拓展板块进行学习,希望大家通过本文能…

WAT、CP、FT的概念及周边名词解释

CP是把坏的Die挑出来,可以减少封装和测试的成本。可以更直接的知道Wafer 的良率。 FT是把坏的chip挑出来;检验封装的良率。 现在对于一般的wafer工艺,很多公司多把CP给省了,减少成本。 CP对整片Wafer的每个Die来测试&#xff0…

光伏系统方案设计的注意点

随着太阳能技术的日益发展,光伏系统已经成为一种重要的可再生能源解决方案。然而,设计一个高效、可靠的光伏系统需要考虑到许多因素。本文将探讨光伏系统方案设计的注意点,包括系统规模、地理位置、组件选择、系统布局和运维策略。 系统规模 …

onnx检测推理

起因:当我想把检测的onnx模型转换到特定的设备可以使用的模型时,报错do not support dimension size > 4,onnx中有些数据的维度是五维,如图。本文使用的是edgeyolo,它使用的是yolox的head,最后的输出加上…

gmid方法设计五管OTA二级远放

首先给出第一级是OTA,第二级是CS的二级运放电路图: gmid的设计方法可以根据GBW、Av、CL来进行电路设计,因此在设计电路之前需要以上的参数要求。 1、为了满足电路的相位裕度至少60,需要对GBW、主极点、零点进行分析。 首先给出其…

应用程序无法找到xinput1_3.dll怎么办,xinput1_3.dll 丢失的解决方法

当电脑系统或特定应用程序无法找到或访问到 xinput1_3.dll 文件时,便会导致错误消息的出现,例如“找不到 xinput1_3.dll”、“xinput1_3.dll 丢失”等。这篇文章将大家讨论关于 xinput1_3.dll 文件的内容、xinput1_3.dll丢失问题的解决方法,以…

查收查引(通过文献检索开具论文收录或引用的检索证明)

开具论文收录证明的 专业术语为 查收查引,是高校图书馆、情报机构或信息服务机构提供的一项有偿服务。 因检索需要一定的时间,提交委托时请预留足够的检索时间。 一般需要提供:论文题目、作者、期刊名称、发表年代、卷期、页码。 目录 一、查…

MIT6.5840-2023-Lab1: MapReduce

前置知识 MapReduce:Master 将一个 Map 任务或 Reduce 任务分配给一个空闲的 worker。 Map阶段:被分配了 map 任务的 worker 程序读取相关的输入数据片段,生成并输出中间 k/v 对,并缓存在内存中。 Reduce阶段:所有 ma…

QT 中基于 TCP 的网络通信 (备查)

基础 基于 TCP 的套接字通信需要用到两个类: 1)QTcpServer:服务器类,用于监听客户端连接以及和客户端建立连接。 2)QTcpSocket:通信的套接字类,客户端、服务器端都需要使用。 这两个套接字通信类…

《当代家庭教育》期刊论文投稿发表简介

《当代家庭教育》杂志是家庭的参谋和助手,社会的桥梁和纽带,人生的伴侣和知音,事业的良师益友。 国家新闻出版总署批准的正规省级教育类G4期刊,知网、维普期刊网收录。安排基础教育相关稿件,适用于评职称时的论文发表…

今天面试招了个25K的测试员,从腾讯出来的果然都有两把刷子···

公司前段时间缺人,也面了不少测试,前面一开始瞄准的就是中级的水准,也没指望来大牛,提供的薪资在15-25k,面试的人很多,但平均水平很让人失望。看简历很多都是4年工作经验,但面试中,不…