【pytorch练习】使用pytorch神经网络架构拟合余弦曲线

在本篇博客中,我们将通过一个简单的例子,讲解如何使用 PyTorch 实现一个神经网络模型来拟合余弦函数。本文将详细分析每个步骤,从数据准备到模型的训练与评估,帮助大家更好地理解如何使用 PyTorch 进行模型构建和训练。

一、背景

在机器学习中,拟合曲线是一个常见的任务,尤其是在函数预测和回归问题中。今天,我们使用一个简单的神经网络模型来拟合余弦曲线,具体步骤包括:

准备训练数据;
构建神经网络模型;
训练模型;
可视化预测结果与真实数据。
本例通过 PyTorch 实现了整个流程,我们将逐步展开。

二、代码解析

  1. 导入必要的库
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False

首先,我们导入了PyTorch相关的库 torch、torch.nn,以及用于数据加载的 DataLoader 和 TensorDataset。为了可视化结果,我们还引入了 matplotlib。

此外,为了避免某些系统环境下的警告信息,我们设置了 os.environ[“KMP_DUPLICATE_LIB_OK”] = “TRUE”,这有助于避免在多线程计算中遇到一些潜在的错误。

  1. 准备拟合数据
# 准备拟合数据
x = np.linspace(-2 * np.pi, 2 * np.pi, 400)  # 生成从 -2π 到 2π 的 400 个点
y = np.cos(x)  # 计算对应的余弦值# 绘制生成的数据的散点图
plt.figure(figsize=(7, 5), dpi=160)
plt.scatter(x, y, color='red', label='生成数据')
plt.title('x 和 cos(x) 数据散点图', fontsize=15)
plt.xlabel('x', fontsize=12)
plt.ylabel('cos(x)', fontsize=12)
plt.legend(fontsize=12)
plt.grid(True)
plt.show()

在这里插入图片描述
使用 numpy.linspace 生成一个包含 400 个点的 x 轴数据,范围从 -2π 到 2π,然后计算对应的 y 值,这里 y = cos(x)。

  1. 接下来,将数据整理成 PyTorch 能够接受的格式
# 将数据做成数据集的模样
X = np.expand_dims(x, axis=1)  # 使 X 变为二维数组
Y = y.reshape(400, -1)  # Y 为一列的数组
dataset = TensorDataset(torch.tensor(X, dtype=torch.float), torch.tensor(Y, dtype=torch.float))
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

通过 TensorDataset 将 x 和 y 数据捆绑成一个数据集,并使用 DataLoader 来批量加载数据,设置 batch_size=10,并启用数据打乱(shuffle=True)以增加模型训练的随机性。

  1. 构建神经网络
    接下来,我们将构建一个简单的神经网络来拟合这些数据。在这个例子中,我们使用了一个全连接的神经网络,并采用了 ReLU 激活函数。网络的结构如下:

输入层:1 个神经元(因为我们的输入是一个 1D 数值)。
隐藏层 1:10 个神经元,使用 ReLU 激活函数。
隐藏层 2:100 个神经元,使用 ReLU 激活函数。
隐藏层 3:10 个神经元,使用 ReLU 激活函数。
输出层:1 个神经元,输出拟合的结果。

import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.net = nn.Sequential(nn.Linear(in_features=1, out_features=10), nn.ReLU(),nn.Linear(10, 100), nn.ReLU(),nn.Linear(100, 10), nn.ReLU(),nn.Linear(10, 1))def forward(self, input: torch.FloatTensor):return self.net(input)# 创建模型实例
net = Net()
net 
Net((net): Sequential((0): Linear(in_features=1, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=100, bias=True)(3): ReLU()(4): Linear(in_features=100, out_features=10, bias=True)(5): ReLU()(6): Linear(in_features=10, out_features=1, bias=True))
)

这段代码定义了一个简单的神经网络类 Net,它继承自 nn.Module。通过 nn.Sequential 来堆叠多个层,使得网络的结构更加简洁和易于理解。每一层都紧跟着一个 ReLU 激活函数,用于引入非线性特征。

  1. 训练模型
    接下来,我们开始训练模型。我们选择 Adam 优化器,并使用均方误差(MSE)作为损失函数。在每个 epoch 中,我们都会迭代一次所有的训练数据,通过反向传播更新模型参数。
# 设置优化器和损失函数
optim = torch.optim.Adam(net.parameters(), lr=0.001)
Loss = nn.MSELoss()# 训练模型
for epoch in range(100):loss = Nonefor batch_x, batch_y in dataloader:# 前向传播y_predict = net(batch_x)# 计算损失loss = Loss(y_predict, batch_y)# 清空梯度optim.zero_grad()# 反向传播loss.backward()# 更新参数optim.step()# 每10步打印一次训练日志if (epoch + 1) % 10 == 0:print(f"训练步骤: {epoch+1}, 模型损失: {loss.item()}")
训练步骤: 10, 模型损失: 0.12506699562072754
训练步骤: 20, 模型损失: 0.024437546730041504
训练步骤: 30, 模型损失: 0.08189699053764343
训练步骤: 40, 模型损失: 0.03138166293501854
训练步骤: 50, 模型损失: 0.00651053711771965
训练步骤: 60, 模型损失: 0.0032562180422246456
训练步骤: 70, 模型损失: 0.00018047125195153058
训练步骤: 80, 模型损失: 0.005476313643157482
训练步骤: 90, 模型损失: 0.0014593529049307108
训练步骤: 100, 模型损失: 0.0008746677194721997
  1. 可视化
    训练完成后,我们可以使用训练好的模型来进行预测,并将预测结果与真实数据进行比较。
# 绘制真实数据与预测数据的对比
plt.figure(figsize=(12, 7), dpi=160)
plt.plot(x, y, label="实际值", marker="X")
plt.plot(x, predict.detach().numpy(), label="预测值", marker='o')
plt.xlabel("x", size=15)
plt.ylabel("cos(x)", size=15)
plt.xticks(size=15)
plt.yticks(size=15)
plt.legend(fontsize=15)
plt.show()

在这里插入图片描述
通过绘制图表,我们可以清楚地看到,训练好的神经网络已经很好地拟合了余弦函数,并且与真实数据非常接近。

** 通过本篇教程,我们了解了如何使用 PyTorch 从零开始构建神经网络,并使用该网络拟合一个简单的余弦曲线。我们逐步演示了数据准备、网络构建、模型训练以及预测可视化的过程。希望通过这篇文章,你能够掌握神经网络的基本操作,并能够将其应用于其他任务中。**

如果你有任何问题或建议,欢迎在评论区留言交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/66133.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式中的代理模式

在Java中,代理模式(Proxy Pattern)可以通过静态代理和动态代理两种主要方式实现。 一、静态代理模式 在编译时就已经确定了代理类和被代理类的关系。代理类和目标对象通常实现相同的接口或继承相同父类。缺点是对于每个需要代理的目标对象都…

编程入门(2)-2024年 RAD Studio version 12发布综述

随着2024年即将画上句号,我想借此机会回顾一下我们在这一年中发布的一些Embarcadero产品、行业趋势,并感谢我们尊贵的客户们对我们的产品一如既往的支持。这一年对我们来说充满了激动人心的变化和发展,我们非常高兴能与您一起踏上这段旅程。 …

使用LLM自回归与超级转义词表生成图像:超越传统扩散模型的新范式

引言 在人工智能领域,尤其是自然语言处理(NLP)和计算机视觉(CV),大型语言模型(LLM)的出现带来了前所未有的变革。随着技术的进步,研究人员开始探索如何将LLM应用于更多样…

visual studio 安全模式

一、安全模式: 在 Visual Studio 中,安全模式是一种启动方式,允许你在禁用所有扩展和自定义设置的情况下启动 Visual Studio。这个模式可以帮助排除插件或扩展引起的问题,特别是在 Visual Studio 无法正常启动时。 二、安全模式下…

RocketMQ消费者如何消费消息以及ack

1.前言 此文章是在儒猿课程中的学习笔记,感兴趣的想看原来的课程可以去咨询儒猿课堂 这篇文章紧挨着上一篇博客来进行编写,有些不清楚的可以看下上一篇博客: https://blog.csdn.net/u013127325/article/details/144934073 2.broker是如何…

现代光学基础5

总结自老师的讲义 yt5 开卷考试复习资料:光探测器与光伏技术 目录 光探测器(Photodetector) 工作原理二极管电路连接方式响应度(Responsivity)微弱光检测超导纳米线单光子探测光电二极管噪声 太阳能电池&#xff0…

EasyExcel自定义动态下拉框(附加业务对象转换功能)

全文直接复制粘贴即可,测试无误 一、注解类 1、ExcelSelected.java 设置下拉框 Documented Target({ElementType.FIELD})//用此注解用在属性上。 Retention(RetentionPolicy.RUNTIME)//注解不仅被保存到class文件中,jvm加载class文件之后&#xff0c…

【2025最新计算机毕业设计】基于Spring Boot+Vue影院购票系统(高质量源码,提供文档,免费部署到本地)

作者简介:✌CSDN新星计划导师、Java领域优质创作者、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流。✌ 主要内容:🌟Java项目、Python项目、前端项目、PHP、ASP.NET、人工智能…

信息科技伦理与道德1:研究方法

1 问题描述 1.1 讨论? 请挑一项信息技术,谈一谈为什么认为他是道德的/不道德的,或者根据使用场景才能判断是否道德。判断的依据是什么(自身的道德准则)?为什么你觉得你的道德准则是合理的,其他…

解读 C++23 std::expected 函数式写法

文章目录 std::expected 基础概念什么是 std::expected?优势与 std::optional 和 std::variant 的区别 函数式写法的功能和应用1. transform : 对"成功值"进行映射基本用法完全返回不同类型 2 and_then : 对"成功值"进行连续计算3 transform_error : 对&q…

Web安全扫盲

1、建立网络思维模型的必要 1 . 我们只有知道了通信原理, 才能够清楚的知道数据的交换过程。 2 . 我们只有知道了网络架构, 才能够清楚的、准确的寻找漏洞。 2、局域网的简单通信 局域网的简单通信(数据链路层) 一般局域网都通…

领域驱动设计(4)—绑定模型与实现

(4)—绑定模型与实现 模式:MODEL-DRIVEN DESIGN为什么模型对用户至关重要?模式:HANDS-ON MODELER 很多项目设计之初只考虑到模型如何设计,没有将模型如何实现、数据关系如何存储这些实现考虑在内,往往设计…

@MapperScan

简介: MapperScan注解是MyBatis框架在Spring Boot中的一个重要集成注解 作用: MapperScan主要作用是告诉Spring框架在启动时扫描指定的包路径,并将该路径下的所有MyBatis的Mapper接口批量注入到Spring容器中。这样,开发者就可以…

Linux驱动开发(18):linux驱动并发与竞态

并发是指多个执行单元同时、并行执行,而并发的执行单元对共享资源(硬件资源和软件上的全局变量、静态变量等)的访问 则很容易导致竞态。对于多核系统,很容易理解,由于多个CPU同时执行,多个CPU同时读、写共享资源时很容易造成竞态。…

009:传统计算机视觉之边缘检测

本文为合集收录,欢迎查看合集/专栏链接进行全部合集的系统学习。 合集完整版请参考这里。 本节来看一个利用传统计算机视觉方法来实现图片边缘检测的方法。 什么是边缘检测? 边缘检测是通过一些算法来识别图像中物体之间或者物体与背景之间的边界&…

QML使用Popup实现弹出Message

方案一:popup import QtQuick 2.15 import QtQuick.Controls 2.15 import QtQuick.Layouts 1.15ApplicationWindow {visible: truewidth: 640height: 480title: qsTr("Top Message Popup Example")ColumnLayout {anchors.centerIn: parentspacing: 10Butt…

idea java.lang.OutOfMemoryError: GC overhead limit exceeded

Idea build项目直接报错 java: GC overhead limit exceeded java.lang.OutOfMemoryError: GC overhead limit exceeded 设置 编译器 原先heap size 设置的是 700M , 改成 2048M即可

webpack5基础(上篇)

一、基本配置 在开始使用 webpack 之前,我们需要对 webpack 的配置有一定的认识 1、5大核心概念 1)entry (入口) 指示 webpack 从哪个文件开始打包 2)output(输出) 制视 webpack 打包完的…

boot-126网易邮件发送

【SpringBoot整合JavaMail发送邮件】 一 . Java Mail基本概念 1.SMTP Simple Mail Transfer Protocol:简单邮件传输协议,用于发送邮件的协议。 2.POP3 Post office Protocol 3:邮局通讯协议第三版,用于接收邮件的标准协议。 3.IMAP Internet Message Acc…

《学校一卡通管理系统》数据库MySQL的设计与实现

引言:学校一卡通管理系统旨在为学校提供一个高效的数字化管理平台,集中管理学生和教职工的账户、充值、消费、查询等日常事务。通过该系统,学生可以便捷地进行充值、消费及查看余额,管理员则可以高效地管理用户账户、充值记录、消费记录等数据。系统采用MySQL数据库,通过视…