【pytorch练习】使用pytorch神经网络架构拟合余弦曲线

在本篇博客中,我们将通过一个简单的例子,讲解如何使用 PyTorch 实现一个神经网络模型来拟合余弦函数。本文将详细分析每个步骤,从数据准备到模型的训练与评估,帮助大家更好地理解如何使用 PyTorch 进行模型构建和训练。

一、背景

在机器学习中,拟合曲线是一个常见的任务,尤其是在函数预测和回归问题中。今天,我们使用一个简单的神经网络模型来拟合余弦曲线,具体步骤包括:

准备训练数据;
构建神经网络模型;
训练模型;
可视化预测结果与真实数据。
本例通过 PyTorch 实现了整个流程,我们将逐步展开。

二、代码解析

  1. 导入必要的库
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False

首先,我们导入了PyTorch相关的库 torch、torch.nn,以及用于数据加载的 DataLoader 和 TensorDataset。为了可视化结果,我们还引入了 matplotlib。

此外,为了避免某些系统环境下的警告信息,我们设置了 os.environ[“KMP_DUPLICATE_LIB_OK”] = “TRUE”,这有助于避免在多线程计算中遇到一些潜在的错误。

  1. 准备拟合数据
# 准备拟合数据
x = np.linspace(-2 * np.pi, 2 * np.pi, 400)  # 生成从 -2π 到 2π 的 400 个点
y = np.cos(x)  # 计算对应的余弦值# 绘制生成的数据的散点图
plt.figure(figsize=(7, 5), dpi=160)
plt.scatter(x, y, color='red', label='生成数据')
plt.title('x 和 cos(x) 数据散点图', fontsize=15)
plt.xlabel('x', fontsize=12)
plt.ylabel('cos(x)', fontsize=12)
plt.legend(fontsize=12)
plt.grid(True)
plt.show()

在这里插入图片描述
使用 numpy.linspace 生成一个包含 400 个点的 x 轴数据,范围从 -2π 到 2π,然后计算对应的 y 值,这里 y = cos(x)。

  1. 接下来,将数据整理成 PyTorch 能够接受的格式
# 将数据做成数据集的模样
X = np.expand_dims(x, axis=1)  # 使 X 变为二维数组
Y = y.reshape(400, -1)  # Y 为一列的数组
dataset = TensorDataset(torch.tensor(X, dtype=torch.float), torch.tensor(Y, dtype=torch.float))
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

通过 TensorDataset 将 x 和 y 数据捆绑成一个数据集,并使用 DataLoader 来批量加载数据,设置 batch_size=10,并启用数据打乱(shuffle=True)以增加模型训练的随机性。

  1. 构建神经网络
    接下来,我们将构建一个简单的神经网络来拟合这些数据。在这个例子中,我们使用了一个全连接的神经网络,并采用了 ReLU 激活函数。网络的结构如下:

输入层:1 个神经元(因为我们的输入是一个 1D 数值)。
隐藏层 1:10 个神经元,使用 ReLU 激活函数。
隐藏层 2:100 个神经元,使用 ReLU 激活函数。
隐藏层 3:10 个神经元,使用 ReLU 激活函数。
输出层:1 个神经元,输出拟合的结果。

import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.net = nn.Sequential(nn.Linear(in_features=1, out_features=10), nn.ReLU(),nn.Linear(10, 100), nn.ReLU(),nn.Linear(100, 10), nn.ReLU(),nn.Linear(10, 1))def forward(self, input: torch.FloatTensor):return self.net(input)# 创建模型实例
net = Net()
net 
Net((net): Sequential((0): Linear(in_features=1, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=100, bias=True)(3): ReLU()(4): Linear(in_features=100, out_features=10, bias=True)(5): ReLU()(6): Linear(in_features=10, out_features=1, bias=True))
)

这段代码定义了一个简单的神经网络类 Net,它继承自 nn.Module。通过 nn.Sequential 来堆叠多个层,使得网络的结构更加简洁和易于理解。每一层都紧跟着一个 ReLU 激活函数,用于引入非线性特征。

  1. 训练模型
    接下来,我们开始训练模型。我们选择 Adam 优化器,并使用均方误差(MSE)作为损失函数。在每个 epoch 中,我们都会迭代一次所有的训练数据,通过反向传播更新模型参数。
# 设置优化器和损失函数
optim = torch.optim.Adam(net.parameters(), lr=0.001)
Loss = nn.MSELoss()# 训练模型
for epoch in range(100):loss = Nonefor batch_x, batch_y in dataloader:# 前向传播y_predict = net(batch_x)# 计算损失loss = Loss(y_predict, batch_y)# 清空梯度optim.zero_grad()# 反向传播loss.backward()# 更新参数optim.step()# 每10步打印一次训练日志if (epoch + 1) % 10 == 0:print(f"训练步骤: {epoch+1}, 模型损失: {loss.item()}")
训练步骤: 10, 模型损失: 0.12506699562072754
训练步骤: 20, 模型损失: 0.024437546730041504
训练步骤: 30, 模型损失: 0.08189699053764343
训练步骤: 40, 模型损失: 0.03138166293501854
训练步骤: 50, 模型损失: 0.00651053711771965
训练步骤: 60, 模型损失: 0.0032562180422246456
训练步骤: 70, 模型损失: 0.00018047125195153058
训练步骤: 80, 模型损失: 0.005476313643157482
训练步骤: 90, 模型损失: 0.0014593529049307108
训练步骤: 100, 模型损失: 0.0008746677194721997
  1. 可视化
    训练完成后,我们可以使用训练好的模型来进行预测,并将预测结果与真实数据进行比较。
# 绘制真实数据与预测数据的对比
plt.figure(figsize=(12, 7), dpi=160)
plt.plot(x, y, label="实际值", marker="X")
plt.plot(x, predict.detach().numpy(), label="预测值", marker='o')
plt.xlabel("x", size=15)
plt.ylabel("cos(x)", size=15)
plt.xticks(size=15)
plt.yticks(size=15)
plt.legend(fontsize=15)
plt.show()

在这里插入图片描述
通过绘制图表,我们可以清楚地看到,训练好的神经网络已经很好地拟合了余弦函数,并且与真实数据非常接近。

** 通过本篇教程,我们了解了如何使用 PyTorch 从零开始构建神经网络,并使用该网络拟合一个简单的余弦曲线。我们逐步演示了数据准备、网络构建、模型训练以及预测可视化的过程。希望通过这篇文章,你能够掌握神经网络的基本操作,并能够将其应用于其他任务中。**

如果你有任何问题或建议,欢迎在评论区留言交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/66133.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

编程入门(2)-2024年 RAD Studio version 12发布综述

随着2024年即将画上句号,我想借此机会回顾一下我们在这一年中发布的一些Embarcadero产品、行业趋势,并感谢我们尊贵的客户们对我们的产品一如既往的支持。这一年对我们来说充满了激动人心的变化和发展,我们非常高兴能与您一起踏上这段旅程。 …

visual studio 安全模式

一、安全模式: 在 Visual Studio 中,安全模式是一种启动方式,允许你在禁用所有扩展和自定义设置的情况下启动 Visual Studio。这个模式可以帮助排除插件或扩展引起的问题,特别是在 Visual Studio 无法正常启动时。 二、安全模式下…

RocketMQ消费者如何消费消息以及ack

1.前言 此文章是在儒猿课程中的学习笔记,感兴趣的想看原来的课程可以去咨询儒猿课堂 这篇文章紧挨着上一篇博客来进行编写,有些不清楚的可以看下上一篇博客: https://blog.csdn.net/u013127325/article/details/144934073 2.broker是如何…

EasyExcel自定义动态下拉框(附加业务对象转换功能)

全文直接复制粘贴即可,测试无误 一、注解类 1、ExcelSelected.java 设置下拉框 Documented Target({ElementType.FIELD})//用此注解用在属性上。 Retention(RetentionPolicy.RUNTIME)//注解不仅被保存到class文件中,jvm加载class文件之后&#xff0c…

【2025最新计算机毕业设计】基于Spring Boot+Vue影院购票系统(高质量源码,提供文档,免费部署到本地)

作者简介:✌CSDN新星计划导师、Java领域优质创作者、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流。✌ 主要内容:🌟Java项目、Python项目、前端项目、PHP、ASP.NET、人工智能…

信息科技伦理与道德1:研究方法

1 问题描述 1.1 讨论? 请挑一项信息技术,谈一谈为什么认为他是道德的/不道德的,或者根据使用场景才能判断是否道德。判断的依据是什么(自身的道德准则)?为什么你觉得你的道德准则是合理的,其他…

Web安全扫盲

1、建立网络思维模型的必要 1 . 我们只有知道了通信原理, 才能够清楚的知道数据的交换过程。 2 . 我们只有知道了网络架构, 才能够清楚的、准确的寻找漏洞。 2、局域网的简单通信 局域网的简单通信(数据链路层) 一般局域网都通…

Linux驱动开发(18):linux驱动并发与竞态

并发是指多个执行单元同时、并行执行,而并发的执行单元对共享资源(硬件资源和软件上的全局变量、静态变量等)的访问 则很容易导致竞态。对于多核系统,很容易理解,由于多个CPU同时执行,多个CPU同时读、写共享资源时很容易造成竞态。…

009:传统计算机视觉之边缘检测

本文为合集收录,欢迎查看合集/专栏链接进行全部合集的系统学习。 合集完整版请参考这里。 本节来看一个利用传统计算机视觉方法来实现图片边缘检测的方法。 什么是边缘检测? 边缘检测是通过一些算法来识别图像中物体之间或者物体与背景之间的边界&…

QML使用Popup实现弹出Message

方案一:popup import QtQuick 2.15 import QtQuick.Controls 2.15 import QtQuick.Layouts 1.15ApplicationWindow {visible: truewidth: 640height: 480title: qsTr("Top Message Popup Example")ColumnLayout {anchors.centerIn: parentspacing: 10Butt…

idea java.lang.OutOfMemoryError: GC overhead limit exceeded

Idea build项目直接报错 java: GC overhead limit exceeded java.lang.OutOfMemoryError: GC overhead limit exceeded 设置 编译器 原先heap size 设置的是 700M , 改成 2048M即可

boot-126网易邮件发送

【SpringBoot整合JavaMail发送邮件】 一 . Java Mail基本概念 1.SMTP Simple Mail Transfer Protocol:简单邮件传输协议,用于发送邮件的协议。 2.POP3 Post office Protocol 3:邮局通讯协议第三版,用于接收邮件的标准协议。 3.IMAP Internet Message Acc…

【ArcGISPro/GeoScenePro】检查多光谱影像的属性并优化其外观

数据 https://arcgis.com/sharing/rest/content/items/535efce0e3a04c8790ed7cc7ea96d02d/data 操作 其他数据 检查影像的属性 熟悉检查您正在使用的栅格属性非常重要。

音视频入门基础:MPEG2-PS专题(4)——FFmpeg源码中,判断某文件是否为PS文件的实现

一、引言 通过FFmpeg命令: ./ffmpeg -i XXX.ps 可以判断出某个文件是否为PS文件: 所以FFmpeg是怎样判断出某个文件是否为PS文件呢?它内部其实是通过mpegps_probe函数来判断的。从《FFmpeg源码:av_probe_input_format3函数和AVI…

[Python学习日记-74] 面向对象实战2——选课系统

[Python学习日记-74] 面向对象实战2——选课系统 简介 开发要求 实现:选课系统 简介 在前面的《年会答题系统》当中我们介绍了面向对象软件开发的一些流程,当然这一流程只是涵括了大部分的,目前在业界也没有一个统一的标准,每个…

用户注册模块(芒果头条项目进度4)

1 创建⽤户模块⼦应⽤ 1.1 在项⽬包⽬录下 创建apps的python包。 1.2 在apps包下 创建应⽤userapp $ cd 项⽬包⽬录/apps $ python ../../manage.py startapp userapp 1.3 配置导包路径 默认情况下导包路径指向项⽬根⽬录 # 通过下⾯语句可以打印当前导包路径 print(sys.pa…

Elasticsearch:利用 AutoOps 检测长时间运行的搜索查询

作者:来自 Elastic Valentin Crettaz 了解 AutoOps 如何帮助你调查困扰集群的长期搜索查询以提高搜索性能。 AutoOps 于 11 月初在 Elastic Cloud Hosted 上发布,它通过性能建议、资源利用率和成本洞察、实时问题检测和解决路径显著简化了集群管理。 Au…

关于Flutter应用国际化语言的设置

目录 1. Locale配置 2. 用户切换/启动自动加载缓存里面的locale 由于最近在开发app国际化设置的时候遇到一些问题,所以做出一些总结。 1. Locale配置 具体的初始化配置可以参考文档:i18n | Flutter 中文文档 - Flutter 中文开发者网站 - Flutter 值得…

基层医联体医院患者历史检验检查数据的快速Python编程分析

​​​​​​​ 一、引言 1.1 研究背景与意义 在当今数字化医疗时代,医疗数据呈爆炸式增长,涵盖患者的基本信息、病史、检验检查结果、治疗方案等各个维度。这些海量且复杂的数据蕴含着巨大价值,为精准医疗决策提供了关键依据。通过对患者历史检验检查数据的深入对比分析…

如何使用OpenCV进行抓图-多线程

前言 需求: 1、如何使用OpenCV捕抓Windows电脑上USB摄像头的流、 2、采用多线程 3、获知当前摄像头的帧率。 这个需求,之前就有做了,但是由于出现了一个问题,人家摄像头的帧率目前都可以达到60帧/s 了,而我的程序…