使用PyTorch处理多维特征输入的完美指南

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
请添加图片描述

文章目录

  • 🥦引言
  • 🥦前期的回顾与准备
  • 🥦代码实现
  • 🥦总结

🥦引言

在机器学习和深度学习领域,我们经常会面对具有多维特征输入的问题。这种情况出现在各种应用中,包括图像识别、自然语言处理、时间序列分析等。PyTorch是一个强大的深度学习框架,它提供了丰富的工具和库,可以帮助我们有效地处理这些多维特征输入数据。在本篇博客中,我们将探讨如何使用PyTorch来处理多维特征输入数据。

🥦前期的回顾与准备

这里我们采用一组预测糖尿病的数据集,如下图
在这里插入图片描述
这里的每一行代表一个样本,同样的,每一列代表什么呢,代表一个特征,如下图。所以糖尿病的预测由下面这八个特征共同进行决定
在这里插入图片描述
按照过去的逻辑回归,应该是下图所示的,因为这是单特征值
在这里插入图片描述
但是现在由单特征值已经转变为多特征值了,所以我们需要对每个特征值进行处理,如下图
在这里插入图片描述
中间的特征值与权重的点乘可以从矩阵的形式进行表现
在这里插入图片描述
因为逻辑回归所以还有套一个Sigmoid函数,通常情况下我们将函数内的整体成为z(i)
在这里插入图片描述

注意: Sigmoid函数是一个按向量方式实现的

下面我们从矩阵相乘的形式进行展示,说明可以将一组方程合并为矩阵运算,可以想象为拼接哈。这样的目的是转化为并行运算,从而实现更快的运行速度。
在这里插入图片描述
所以从代码的角度去修改,我们只需要改变一下维度就行了

class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear = torch.nn.Linear(8, 1) self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear(x)) return x
model = Model()

这里的输入维度设置为8,就像上图中展示的x一样是N×8形式的矩阵,而 y ^ \hat{y} y^是一个N×1的矩阵。
这里我们将矩阵看做是一个空间变换的函数

我们可以从下图很好的展示多层神经网络的变换
在这里插入图片描述

从一开始的属于8维变为输出6维,再从输入的6维变为输出的4维,最后从输入的4维变为输出的1维。

如果从代码的角度去写,可以从下面的代码进行实现

class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6) self.linear2 = torch.nn.Linear(6, 4) self.linear3 = torch.nn.Linear(4, 1) self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x)) x = self.sigmoid(self.linear2(x)) x = self.sigmoid(self.linear3(x)) return x
model = Model()

这里我说明一下下面这条语句

  • self.sigmoid = torch.nn.Sigmoid():这一行创建了一个 Sigmoid 激活函数的实例,用于在神经网络的正向传播中引入非线性。

后面的前向计算就是一层的输出是另一层输入进行传,最后将 y ^ \hat{y} y^返回


同时我们的损失函数也没有变化,更新函数也没有变化,采用交叉熵和梯度下降
在这里插入图片描述

刘二大人这里没有使用Mini-Batch进行批量,后续的学习应该会更新
在这里插入图片描述

🥦代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn import datasets
from sklearn.model_selection import train_test_split
import numpy as np# 载入Diabetes数据集
diabetes = datasets.load_diabetes()# 将数据集拆分为特征和目标
X = diabetes.data  # 特征
y = diabetes.target  # 目标# 数据预处理
X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)  # 特征标准化# 拆分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 转换为PyTorch张量
X_train = torch.FloatTensor(X_train)
y_train = torch.FloatTensor(y_train).view(-1, 1)  # 将目标变量转换为列向量
X_test = torch.FloatTensor(X_test)
y_test = torch.FloatTensor(y_test).view(-1, 1)# 构建包含多个线性层的神经网络模型
class DiabetesModel(nn.Module):def __init__(self, input_size):super(DiabetesModel, self).__init__()self.fc1 = nn.Linear(input_size, 64)  # 第一个线性层self.fc2 = nn.Linear(64, 32)  # 第二个线性层self.fc3 = nn.Linear(32, 1)  # 最终输出线性层def forward(self, x):x = torch.relu(self.fc1(x))  # ReLU激活函数x = torch.relu(self.fc2(x))x = self.fc3(x)return x# 初始化模型
input_size = X_train.shape[1]
model = DiabetesModel(input_size)# 定义损失函数和优化器
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练模型
num_epochs = 1000
for epoch in range(num_epochs):# 前向传播outputs = model(X_train)loss = criterion(outputs, y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 在测试集上进行预测
model.eval()
with torch.no_grad():y_pred = model(X_test)# 计算性能指标
mse = nn.MSELoss()(y_pred, y_test)
print(f"均方误差 (MSE): {mse.item():.4f}")

运行结果如下
在这里插入图片描述

感兴趣的同学可以使用不同的激活函数一一测试一下

比如我使用tanh函数测试后得到的均方误差就小了许多
在这里插入图片描述

此链接是GitHub上的大佬做的可视化函数https://dashee87.github.io/deep%20learning/visualising-activation-functions-in-neural-networks/

🥦总结

这就是使用PyTorch处理多维特征输入的基本流程。当然,实际应用中,你可能需要更复杂的神经网络结构,更大的数据集,以及更多的调优和正则化技巧。但这个指南可以帮助你入门如何处理多维特征输入的问题,并利用PyTorch构建强大的深度学习模型。希望这篇博客对你有所帮助!

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/104330.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++ 模板同时使用默认参数和偏特化

C 模板同时使用默认参数和偏特化 正常的偏特化都很简单,但是如果和默认参数碰到一起就会复杂一点点 先讲答案(按照顺序): 如果显示指定了,首先看显示指定的如果没显示指定,就找是否有默认参数的版本如果都…

Element Plus阻止 el-dropdown、el-switch等冒泡事件

最近做vue3项目&#xff0c;使用Element Plus,又遇到坑了&#xff01; 问题点&#xff1a;组件中遇到事件冒泡问题了&#xff0c;el-checkbox 中 change事件要求阻止冒泡&#xff0c;如下代码中要求点击checkbox时不调用li标签的show方法 <li click"show()">…

STM32物联网基于ZigBee智能家居控制系统

实践制作DIY- GC0169-ZigBee智能家居 一、功能说明&#xff1a; 基于STM32单片机设计-ZigBee智能家居 二、功能介绍&#xff1a; 1个主机显示板&#xff1a;STM32F103C最小系统ZigBee无线模块OLED显示器 语音识别模块多个按键ESP8266-WIFI模块&#xff08;仅WIFI版本有&…

Spring Cloud Pipelines 入门实践

文章目录 1. 前言2. Spring Cloud Pipelines 是做什么的2.1. 预定义的流程2.2. 集成测试和契约测试2.3.部署策略 4. Spring Cloud Pipelines的使用示例4.1. 创建一个Spring Boot应用4.2. 将代码托管到GitHub仓库4.3. 添加Spring Cloud Pipelines依赖4.4. 配置Spring Cloud Pipe…

生成Android证书

前提&#xff1a;确保本机安装好了java1.8 第一步 打开CMD keytool -genkey -alias AAAA -keyalg RSA -keysize 2048 -validity 36500 -keystore D:\mygitee\keyStore\szxApp.jksD:\mygitee\keyStore\szxApp.jks是证书要生成的位置 szxApp.jks 是证书名称 D:\mygitee\keySto…

volatile为什么无法保证原子性

假设定义 volatile int i 0; 现在2个线程同时 i&#xff0c;为什么数据还可能会出错&#xff1f;一起来看下图&#xff0c;虽然volatile的机制是&#xff1a;如果volatile修饰的变量有修改&#xff0c;那么会将变更内容写回主内存&#xff0c;同时让其他线程工作内存的该变量缓…

Pytorvh之Vision Transformer图像分类

文章目录 前言一、Transformer1.Transformer概览2.Self-Attention3.Multi-head Attention4.Position-wise Feed-Forward Networks(位置前馈网络)5.残差连接和层归一化6.Positional Encodings(位置编码) 二、Vision Transformer1.Vision Transformer概览2.Embedding层结构&#…

C++中resize和reserve

1.reserve(n)对capacity操作 capacity < n : 扩容capacity > n : 不操作 2.resize(n, m)对size操作 size < n : size增加到n 增加的值为msize > n : size减小到ncapacity < n : 先增大容量至n 再增大size至n 增加的值为m

计算机网络 | OSI 参考模型

计算机网络 | OSI 参考模型 计算机网络 | OSI 参考模型应用层表示层会话层传输层网络层数据链路层物理层 参考视频&#xff1a;王道计算机考研 计算机网络 参考书&#xff1a;《2022年计算机网络考研复习指导》 计算机网络 | OSI 参考模型 OSI 参考模型自下而上分为7层&…

Text-to-SQL小白入门(八)RLAIF论文:AI代替人类反馈的强化学习

学习RLAIF论文前&#xff0c;可以先学习一下基于人类反馈的强化学习RLHF&#xff0c;相关的微调方法&#xff08;比如强化学习系列RLHF、RRHF、RLTF、RRTF&#xff09;的论文、数据集、代码等汇总都可以参考GitHub项目&#xff1a;GitHub - eosphoros-ai/Awesome-Text2SQL: Cur…

mysql5.7获取json数组中的某个对象

前言 表中的一个字段类型是字符串&#xff0c;存的是一个对象数据。 现在要根据对象中的某个属性&#xff0c;获取到整个对象信息。 如果是mysql8&#xff0c;则可以使用JSON_TABLE。 示例&#xff1a;https://blog.csdn.net/weixin_44071721/article/details/123347229 sele…

MIT 6.S081 Operating System/Fall 2020 macOS搭建risc-v与xv6开发调试环境

文章目录 本机配置安装环境Homebrew执行安装脚本查看安装是否成功 RISC-V tools执行brew的安装脚本 QEMUXV6 测试有用的参考链接&#xff08;感谢前辈&#xff09;写在结尾 本机配置 电脑型号&#xff1a;Apple M2 Pro 2023 操作系统&#xff1a;macOS Ventura 13.4 所以我的电…

C# 处理TCP数据的类(服务端)

using System; using System.Collections.Generic; using System.Net; using System.Net.Sockets; using System.Threading;namespace TestDemo {/// <summary>/// 处理TCP数据的类&#xff08;服务端&#xff09;/// </summary>public class TcpService{/// <s…

Java中ArrayList 和 LinkedList 的区别是什么?

Java中ArrayList 和 LinkedList 的区别是什么&#xff1f; ArrayList 和 LinkedList 都是Java中常用的集合类&#xff0c;它们用于存储和操作元素的容器&#xff0c;但在内部实现和使用方式上有很大的区别。以下是它们的区别、作用、优缺点以及示例说明&#xff1a; 区别&…

Linux 安全 - LSM hook点

文章目录 一、LSM file system hooks1.1 LSM super_block hooks1.2 LSM file hooks1.3 LSM inode hooks 二、LSM Task hooks三、LSM IPC hooks四、LSM Network hooks五、LSM Module & System hooks 一、LSM file system hooks 在VFS&#xff08;虚拟文件系统&#xff09;层…

JVM第六讲:JVM 基础 - Java 内存模型引入

JVM 基础 - Java 内存模型引入 很多人都无法区分Java内存模型和JVM内存结构&#xff0c;以及Java内存模型与物理内存之间的关系。本文是JVM第六讲&#xff0c;从堆栈角度引入JMM&#xff0c;然后介绍JMM和物理内存之间的关系, 为后面JMM详解, JVM 内存结构详解, Java 对象模型详…

蓝桥杯基础---切面条

切面条 一根高筋拉面&#xff0c;中间切一刀&#xff0c;可以得到2根面条。 如果先对折1次&#xff0c;中间切一刀&#xff0c;可以得到3根面条。 如果连续对折2次&#xff0c;中间切一刀&#xff0c;可以得到5根面条。 那么&#xff0c;连续对折10次&#xff0c;中间切一刀…

react数据管理之setState与Props

react数据管理之setState与Props setState调用原理 setState 是 React 中用于更新组件状态&#xff08;state&#xff09;的方法。它的调用原理可以分为以下几个步骤&#xff1a; 状态的改变&#xff1a;当调用 setState 时&#xff0c;React 会将新的状态对象与当前状态对象…

QT 网络编程 服务端 客户端 QTcpServer

服务端的创建 //创建服务端QTcpServer对象 server new QTcpServer(this);//设置服务端&#xff0c;端口&#xff0c;这里绑定的是主机的所有网卡&#xff0c; server->listen(QHostAddress::Any, 8080);//绑定连接信号与槽 connect(this->server, &QTcpServer::new…

[C++随想录] 继承

继承 继承的引言基类和子类的赋值转换继承中的作用域派生类中的默认成员函数继承与友元继承与静态成员多继承的结构棱形继承的结构棱形虚拟继承的结构继承与组合 继承的引言 概念 继承(inheritance)机制是面向对象程序设计使代码可以 复用的最重要的手段&#xff0c;它允许程序…