基于神经网络的弹弹堂类游戏弹道快速预测

目录

一、 目的... 1

1.1 输入与输出.... 1

1.2 隐网络架构设计.... 1

1.3 激活函数与损失函数.... 1

二、 训练... 2

2.1 数据加载与预处理.... 2

2.2 训练过程.... 2

2.3 训练参数与设置.... 2

三、 测试与分析... 2

3.1     性能对比.... 2

3.2     训练过程差异.... 3

四、 训练过程中的损失变化... 3

五、 代码... 6


一、目的

在机器学习中,神经网络是解决回归和分类问题的强大工具。本文通过对比全连接神经网络(SimpleNN)在不同激活函数下的表现,探索不同激活函数对模型训练过程和最终性能的影响。本实验通过使用PyTorch框架,首先使用ReLU激活函数,之后将激活函数切换为tanh,分析这两种激活函数在回归问题中的差异。

1.1 输入与输出

本实验中的神经网络模型输入的是来自MATLAB文件(data.mat)的数据集,其中包括4个输入特征和1个输出标签。数据通过标准化处理后输入神经网络,网络模型通过学习特征和标签之间的关系来预测输出。最终网络输出为一个连续值,即回归问题中的预测值。

1.2 隐网络架构设计

SimpleNN模型:
本实验使用了一个简单的前馈神经网络模型,包含一个输入层、一个隐藏层和一个输出层。输入层的节点数与特征数量相同,输出层的节点数与标签数量相同。隐藏层的节点数设置为10。激活函数用于隐藏层的神经元,以增加模型的非线性表达能力。

在此实验中,我们首先使用了ReLU激活函数进行训练,然后将激活函数替换为tanh进行对比分析。

1.3 激活函数与损失函数

  1. 激活函数选择
  • ReLU(Rectified Linear Unit):
    是一种常用的激活函数,其输出为正输入或零。ReLU有助于缓解梯度消失问题,并加速神经网络的训练。

  • tanh(双曲正切函数):
    是一种平滑的非线性激活函数,其输出范围为-1到1。与ReLU相比,tanh的输出范围较小,并且存在梯度消失的风险,但它能够处理负值输入,适用于某些回归任务。

  1. 损失函数选择
    本实验使用均方误差(MSE)作为损失函数,用于回归任务中度量模型预测与真实输出之间的差异。

二、训练

2.1 数据加载与预处理

数据集来自MATLAB的.mat文件。输入特征(4个)和输出标签(1个)首先被提取,并通过MinMaxScaler进行归一化处理。数据集被随机分割为训练集和测试集,其中50个样本用于测试,剩余的用于训练。

2.2 训练过程

网络通过3000次迭代进行训练。在每一次迭代中,模型使用训练数据进行前向传播,计算预测结果与真实标签之间的损失。然后进行反向传播,更新网络的参数。训练的停止条件为损失低于设定阈值(1e-14)。

2.3 训练参数与设置

训练过程中使用的主要参数如下:

  • 学习率: 0.001
  • 训练轮次: 最大3000次,或提前停止
  • 损失函数: 均方误差(MSE)损失函数
  • 优化器: Adam

三、测试与分析

3.1 性能对比

  • 使用ReLU激活函数时:
    在训练过程中,模型的损失函数逐渐下降,表现出良好的学习效果。最终损失值趋近于0,表明网络能够较好地拟合训练数据。测试时,模型能够有效地预测测试集的数据,偏差较小。

  • 使用tanh激活函数时:
    与ReLU相比,使用tanh激活函数时,损失下降的速度较慢,且网络训练的初期出现较大的波动。这可能与tanh的输出范围(-1到1)有关,导致梯度消失问题,尤其是在多层网络中。

3.2 训练过程差异

  1. 收敛速度
  • ReLU: 在训练初期收敛较快,且表现出较好的梯度更新能力。在训练过程中,模型的准确性和损失函数下降速度较为平稳。

  • tanh: 收敛速度较慢,且在训练初期存在较大的梯度波动。由于其在负输入下的饱和特性,可能导致梯度更新较慢,尤其是在深层网络中。

  1. 偏差分析
  • 使用ReLU时: 偏差较小,模型预测与实际值之间的差异较少,说明模型具有较好的预测能力。

  • 使用tanh时: 偏差稍大,尤其是在某些测试样本上。虽然损失函数已经较低,但由于tanh的输出范围限制,模型在某些输入上可能无法达到完全准确的预测。


四、训练过程中的损失变化

图 1: ReLU训练损失曲线
图 2: ReLU测试数据集结果图
图 3: tanh训练损失曲线
图 4: tanh测试数据集结果图

1 relu训练损失曲线

2 relu测试数据集结果图

3 tanh训练损失曲线

4 tanh测试数据集结果图

  • 代码

import numpy as np

import torch

import torch.nn as nn

import torch.optim as optim

from sklearn.preprocessing import MinMaxScaler

import matplotlib.pyplot as plt

import scipy.io

# 加载 .mat 文件(替换为实际的路径)

data = scipy.io.loadmat('E:\\Learn Project\\matlab_pjt\\data.mat')

# 获取数据

data = data['data']

# 输入和输出数据

inputs = data[:, :4# 输入特征

outputs = data[:, 4:]  # 输出标签

# 随机分割数据为训练集和测试集

test_size = 50  # 测试集大小

indices = np.random.permutation(len(inputs))

train_indices = indices[test_size:]

test_indices = indices[:test_size]

input_train = inputs[train_indices]

output_train = outputs[train_indices]

input_test = inputs[test_indices]

output_test = outputs[test_indices]

# 数据归一化

scaler_input = MinMaxScaler()

scaler_output = MinMaxScaler()

input_train_scaled = scaler_input.fit_transform(input_train)

output_train_scaled = scaler_output.fit_transform(output_train)

input_test_scaled = scaler_input.transform(input_test)

output_test_scaled = scaler_output.transform(output_test)

# 转换为 PyTorch 张量

X_train_tensor = torch.tensor(input_train_scaled, dtype=torch.float32)

y_train_tensor = torch.tensor(output_train_scaled, dtype=torch.float32)

X_test_tensor = torch.tensor(input_test_scaled, dtype=torch.float32)

y_test_tensor = torch.tensor(output_test_scaled, dtype=torch.float32)

# 定义简单的神经网络

class SimpleNN(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):

        super(SimpleNN, self).__init__()

        self.layer1 = nn.Linear(input_size, hidden_size)

        self.layer2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):

        x = torch.relu(self.layer1(x))  # 激活函数 ReLU

        x = self.layer2(x)  # 输出层

        return x

# 网络参数

input_size = input_train.shape[1]

hidden_size = 10

output_size = output_train.shape[1]

# 创建模型

model = SimpleNN(input_size, hidden_size, output_size)

# 损失函数和优化器

criterion = nn.MSELoss()  # 均方误差损失

optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型

epochs = 3000

loss_history = []  # 保存损失变化

for epoch in range(epochs):

    optimizer.zero_grad()  # 清空梯度

    output = model(X_train_tensor)  # 前向传播

    loss = criterion(output, y_train_tensor)  # 计算损失

    loss.backward()  # 反向传播

    optimizer.step()  # 更新参数

    # 记录损失

    loss_history.append(loss.item())

    # 停止条件

    if loss.item() < 1e-14:

        print(f"训练提前停止,当前迭代:{epoch}")

        break

# 绘制训练损失图

plt.plot(loss_history)

plt.xlabel('Epoch')

plt.ylabel('Loss (MSE)')

plt.title('Training Loss History')

plt.show()

# 测试模型

with torch.no_grad():

    model.eval()  # 设置模型为评估模式

    y_test_pred_scaled = model(X_test_tensor)  # 预测

    y_test_pred = scaler_output.inverse_transform(y_test_pred_scaled.numpy())  # 反归一化

# 计算每个样本的偏差

deviation = np.sqrt(np.sum((output_test - y_test_pred) ** 2, axis=1))  # 欧几里得距离

# 绘制偏差图

plt.plot(deviation, marker='o', color='red')

plt.xlabel('Sample Index')

plt.ylabel('Deviation')

plt.title('Test Deviation')

plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/888653.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Xlsxwriter生成Excel文件时TypeError异常处理

在使用 XlsxWriter 生成 Excel 文件时&#xff0c;如果遇到 TypeError&#xff0c;通常是因为尝试写入的值或格式与 XlsxWriter 的限制或要求不兼容。 1、问题背景 在使用 Xlsxwriter 库生成 Excel 文件时&#xff0c;出现 TypeError: “expected string or buffer” 异常。此…

MATLAB期末复习笔记(下)

目录 五、数据和函数的可视化 1.MATLAB的可视化对象 2.二维图形的绘制 3.图形标识 4.多子图绘图 5.直方图的绘制 &#xff08;1&#xff09;分类 &#xff08;2&#xff09;垂直累计式 &#xff08;3&#xff09;垂直分组式 &#xff08;4&#xff09;水平分组式 &…

操作系统学习

问题&#xff1a; 因为想用傲梅来给系统盘扩容&#xff0c;导致无法进入操作系统&#xff0c;报错如下&#xff1a; 无法加载应用程序或操作系统&#xff0c;原因是所需文件丢失或包含错误. 文件:Windowslsystem32lwinload.efi错误代码: 0xc000007b 你需要使用恢复工具。如果…

【环境搭建】Python、PyTorch与cuda的版本对应表

一个愿意伫立在巨人肩膀上的农民...... 在深度学习的世界里&#xff0c;选择合适的工具版本是项目成功的关键。CUDA、PyTorch和Python作为深度学习的三大支柱&#xff0c;它们的版本匹配问题不容忽视。错误的版本组合可能导致兼容性问题、性能下降甚至项目失败。因此&#xff0…

No.26 笔记 | 信息收集与工具实践指南

渗透测试的第一步&#xff1a;信息收集背后的“侦察艺术” 在网络安全的世界里&#xff0c;信息就是武器。 无论是追踪隐藏的漏洞&#xff0c;还是找到不被注意的入口&#xff0c;信息收集就像一场现代化的“谍战片”。而作为渗透测试的开场白&#xff0c;信息收集不仅考验技…

计算机网络 第5章 运输层

计算机网络 &#xff08;第8版&#xff09; 第 5 章 传输层5.4 可靠传输的原理5.4.1 停止等待协议5.4.2 连续ARQ协议 5.5 TCP报文段的首部格式5.6 TCP可靠传输的实现5.6.1 以字节为单位的滑动窗口5.6.2 超时重传时间的选择 5.7 TCP的流量控制5.7.1 利用滑动窗口实现流量控制 5.…

股指期货基差的影响因素有哪些?

在股指期货交易中&#xff0c;有一个重要的概念叫做“基差”。简单来说&#xff0c;基差就是股指期货价格与其对应的现货价格之间的差异。比如&#xff0c;我们现在有IC2401股指期货&#xff0c;它挂钩的是中证500指数。如果IC2401的价格是5244&#xff0c;而中证500指数的价格…

智能社区服务小程序+ssm(lw+演示+源码+运行)

摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了智能社区服务小程序的开发全过程。通过分析智能社区服务小程序管理的不足&#xff0c;创建了一个计算机管理智能社区服务小程序的方案。文章介绍了智能社区服务…

用人话讲计算机:Python篇!(十一)相对路径与绝对路径

目录 一、计算机中的路径 &#xff08;1&#xff09;什么叫路径 &#xff08;2&#xff09;绝对路径 &#xff08;3&#xff09;相对路径 二、Python中的路径 &#xff08;1&#xff09;绝对路径 &#xff08;2&#xff09;相对路径 &#xff08;3&#xff09;总结 一、…

基于VTX356语音识别合成芯片的智能语音交互闹钟方案

一、方案概述 本方案旨在利用VTX356语音识别合成芯片强大的语音处理能力&#xff0c;结合蓝牙功能、APP或小程序&#xff0c;打造一款功能全面且智能化程度高的闹钟产品。除了基本的时钟显示和闹钟提醒功能外&#xff0c;还拥有正计时、倒计时、日程安排、重要日提醒以及番茄钟…

MFC图形函数学习13——在图形界面输出文字

本篇是图形函数学习的最后一篇&#xff0c;相关内容暂告一段落。 在图形界面输出文字&#xff0c;涉及文字字体、大小、颜色、背景、显示等问题&#xff0c;完成这些需要系列函数的支持。下面做简要介绍。 一、输出文本函数 原型&#xff1a;virtual BOOL te…

【CANoe示例分析】Basic UDP Multicast(CAPL)

1、工程路径 C:\Users\Public\Documents\Vector\CANoe\Sample Configurations 16.6.2\Ethernet\Simulation\UDPBasicCAPLMulticast 在CANoe软件上也可以打开此工程:File|Sample Configurations|Ethernet - Simulation of Ethernet ECUs|Basic UDP Multicast(CAPL) 2、示例目…

【动手学电机驱动】STM32-FOC(10)使用旋钮调节电机转速

STM32-FOC&#xff08;1&#xff09;STM32 电机控制的软件开发环境 STM32-FOC&#xff08;2&#xff09;STM32 导入和创建项目 STM32-FOC&#xff08;3&#xff09;STM32 三路互补 PWM 输出 STM32-FOC&#xff08;4&#xff09;IHM03 电机控制套件介绍 STM32-FOC&#xff08;5&…

最新,Vue 性能提升 400%

最近&#xff0c;Vue 团队核心成员 Johnson Chu 开源一个全新的信号库&#xff1a;alien-signals&#xff0c;这是一个基于 Vue 3.4 响应式系统重写的研究型信号库&#xff0c;可以使 Vue 3.4 的响应式系统性能提升 400%。目前&#xff0c;alien-signals 是所有信号库中最快的实…

springboot mvn 打包,jar和资源文件分离打包

默认打包方式如下&#xff1a; <build><finalName>${project.artifactId}</finalName><plugins><plugin><groupId>org.springframework.boot</groupId><artifactId>spring-boot-maven-plugin</artifactId><execution…

OpenHarmony-3.HDF框架(2)

OpenHarmony HDF 平台驱动 1.平台驱动概述 系统平台驱动框架是系统驱动框架的重要组成部分&#xff0c;它基于HDF驱动框架、操作系统适配层(OSAL, operating system abstraction layer)以及驱动配置管理机制&#xff0c;为各类平台设备驱动的实现提供标准模型。 系统平台驱动(…

BT1120接口自学笔记

一、技术简介 1.1名词解释 BT.1120协议是一种广泛应用的高清数字视频传输协议,能够把取样结构为4:4:4和4:4:2的视频数据编码成内嵌同步定时基准码的视频数据流进行传输。也可以用于ITU-R BT.709建议书和ITU-R BT.2100建议书规定的像素阵列为1 920*1080视屏数据传输。 经常听…

pdf转word/markdown等格式——MinerU的部署:2024最新的智能数据提取工具

一、简介 MinerU是开源、高质量的数据提取工具&#xff0c;支持多源数据、深度挖掘、自定义规则、快速提取等。含数据采集、处理、存储模块及用户界面&#xff0c;适用于学术、商业、金融、法律等多领域&#xff0c;提高数据获取效率。一站式、开源、高质量的数据提取工具&…

探索前端世界的无限可能:玩转Excel文件

&#x1f31f; 前言 欢迎来到我的技术小宇宙&#xff01;&#x1f30c; 这里不仅是我记录技术点滴的后花园&#xff0c;也是我分享学习心得和项目经验的乐园。&#x1f4da; 无论你是技术小白还是资深大牛&#xff0c;这里总有一些内容能触动你的好奇心。&#x1f50d; &#x…

MySQL两阶段提交目的

阶段提交的过程 事务执行阶段&#xff1a;事务开始执行&#xff0c;InnoDB执行SQL语句的具体操作&#xff0c;如数据修改、删除等&#xff0c;并将这些操作记录在内存中。写入Redo Log&#xff08;准备阶段&#xff09;&#xff1a;事务即将提交时&#xff0c;首先将事务相关的…