【深度学习实验】前馈神经网络(七):批量加载数据(直接加载数据→定义类封装数据)

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

 0. 导入必要的工具包

1. 直接加载鸢尾花数据集

a. 加载数据集

b. 数据归一化

c. 洗牌操作

d. 打印数据

2. 定义类封装数据

a. __init__(构造函数:用于初始化数据集对象)

b. __getitem__(获取指定索引处的样本)

c. __len__(获取数据集的长度)

3. 构建数据集(批量加载训练、验证、测试集)

4. 代码整合


一、实验介绍

        在本系列先前的代码中,借助深度学习框架的帮助,已经完成了前馈神经网络的大部分功能。本文将基于鸢尾花数据集构建一个数据迭代器,以便在每次迭代时从全部数据集中获取指定数量的数据。(借助深度学习框架中的Dataset类和DataLoader类来实现此功能)
 

【深度学习】Pytorch 系列教程(十三):PyTorch数据结构:5、数据加载器(DataLoader)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/132924381?spm=1001.2014.3001.5502

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png

 0. 导入必要的工具包

import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader
  • DatasetDataLoader类用于处理数据集和数据加载

1. 直接加载鸢尾花数据集

        加载鸢尾花数据进行归一化并可选地进行洗牌操作,以便于后续的深度学习任务。

import torch
from sklearn.datasets import load_irisdef load_data(shuffle=True):x = torch.tensor(load_iris().data)y = torch.tensor(load_iris().target)# 数据归一化x_min = torch.min(x, dim=0).valuesx_max = torch.max(x, dim=0).valuesx = (x - x_min) / (x_max - x_min)if shuffle:idx = torch.randperm(x.shape[0])x = x[idx]y = y[idx]return x, y

a. 加载数据集

  • 调用load_iris().data函数加载数据,并使用torch.tensor将数据转换为PyTorch张量,将结果赋值给变量x

  • 调用load_iris().target函数加载目标变量,并使用torch.tensor将数据转换为PyTorch张量,将结果赋值给变量y

b. 数据归一化

  • 计算矩阵x每列的最小值。

    • torch.min函数的dim参数设置为0表示按列计算最小值,.values属性获取最小值的张量。

  • 计算矩阵x每列的最大值。

    • torch.max函数的dim参数设置为0表示按列计算最大值,.values属性获取最大值的张量。

  • x = (x-x_min)/(x_max-x_min):对矩阵x进行归一化处理,将每个元素减去最小值,然后除以最大值与最小值之差。这样可以将数据缩放到0和1之间

c. 洗牌操作

  • if shuffle::如果shuffle参数为True,执行以下代码块。

    • idx = torch.randperm(x.shape[0]):生成一个随机排列的索引,范围从0到x的行数减1。torch.randperm函数返回一个随机排列的整数序列。

    • x = x[idx]:根据生成的随机索引对矩阵x进行行重排,打乱数据的顺序。

    • y = y[idx]:根据生成的随机索引对向量y进行行重排,保持目标变量与输入数据的对应关系

  • return x, y:返回处理后的输入特征矩阵x和目标变量向量y

d. 打印数据

x, y = load_data()
print("Input features (x):")
print(x)
print("Target variables (y):")
print(y)

2. 定义类封装数据

        创建一个用于处理鸢尾花数据集的自定义数据集(继承自Dataset类),该自定义数据集类可以用于创建鸢尾花数据集的训练集、验证集或测试集对象,并提供给__getitem____len__方法,以便能够使用DataLoader类进行数据加载和批处理操作

class IrisDataset(Dataset):def __init__(self, mode='train', num_train=120, num_dev=15):super(IrisDataset,self).__init__()x, y = load_data(shuffle=True)if mode == 'train':self.x, self.y = x[:num_train], y[:num_train]elif mode == 'dev':self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]else:self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]def __getitem__(self, idx):return self.x[idx], self.y[idx]def __len__(self):return len(self.x)

  • class IrisDataset(Dataset)::定义了一个名为IrisDataset的类,该类继承自Dataset类,表示一个自定义的数据集。

a. __init__(构造函数:用于初始化数据集对象)

  • 该函数接受三个参数:

    • mode表示数据集模式(默认为'train')

    • num_train表示训练样本的数量(默认为120)

    • num_dev表示验证样本的数量(默认为15)。

  • super(IrisDataset, self).__init__()调用父类Dataset的构造函数,确保正确地初始化基类。

  • x, y = load_data(shuffle=True):调用之前定义的load_data函数加载数据集

  • 如果数据集模式为'train':

    • 将前num_train个训练样本赋值给类的成员变量self.xself.y,表示训练数据集

  • 如果数据集模式为'dev':

    • 将从第num_train个样本开始的num_dev个样本赋值给类的成员变量self.xself.y,表示验证数据集

  • 如果数据集模式不是'train'也不是'dev':

    • 将从第num_train + num_dev个样本开始的剩余样本赋值给类的成员变量self.xself.y,表示测试数据集

b. __getitem__(获取指定索引处的样本)

  • return self.x[idx], self.y[idx]:根据索引idx返回对应位置的输入特征和目标变量。

c. __len__(获取数据集的长度)

  • return len(self.x):返回数据集的长度,即样本数量。

3. 构建数据集(批量加载训练、验证、测试集)

batch_size = 16# 分别构建训练集、验证集和测试集
train_dataset = IrisDataset(mode='train')
dev_dataset = IrisDataset(mode='dev')
test_dataset = IrisDataset(mode='test')train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
  • 使用自定义的数据封装类加载鸢尾花数据集的训练集、验证集和测试集,并使用DataLoader进行批量加载。
    • train_dataset是要加载的数据集对象,batch_size是批量大小,表示每个批次的样本数量,shuffle=True表示在每个迭代周期中对数据进行随机洗牌。
    • 将验证集数据集加载到dev_loader中,未指定shuffle参数,默认为False,不进行洗牌。
    • 将测试集数据集加载到test_loader中,将batch_size设置为1,表示每个批次只包含一个样本,同时指定shuffle=True,在每个迭代周期中对数据进行随机洗牌。

4. 代码整合

# 导入必要的工具包
import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader# 此函数用于加载鸢尾花数据集
def load_data(shuffle=True):x = torch.tensor(load_iris().data)y = torch.tensor(load_iris().target)# 数据归一化x_min = torch.min(x, dim=0).valuesx_max = torch.max(x, dim=0).valuesx = (x - x_min) / (x_max - x_min)if shuffle:idx = torch.randperm(x.shape[0])x = x[idx]y = y[idx]return x, y# 构建自己的数据集,继承自Dataset类
class IrisDataset(Dataset):def __init__(self, mode='train', num_train=120, num_dev=15):super(IrisDataset, self).__init__()x, y = load_data(shuffle=True)if mode == 'train':self.x, self.y = x[:num_train], y[:num_train]elif mode == 'dev':self.x, self.y = x[num_train:num_train + num_dev], y[num_train:num_train + num_dev]else:self.x, self.y = x[num_train + num_dev:], y[num_train + num_dev:]def __getitem__(self, idx):return self.x[idx], self.y[idx]def __len__(self):return len(self.x)batch_size = 16# 分别构建训练集、验证集和测试集
train_dataset = IrisDataset(mode='train')
dev_dataset = IrisDataset(mode='dev')
test_dataset = IrisDataset(mode='test')train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/85089.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为OD机试 - 构成正方形的数量 - 数据结构map(Java 2023 B卷 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、Java算法源码五、效果展示1、输入2、输出3、说明 华为OD机试 2023B卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷)》。 …

mysql 半同步复制模式使用详解

目录 一、前言 二、mysql主从架构简介 2.1 mysql主从复制架构概述 2.2 为什么使用主从架构 2.2.1 提高数据可用性 2.2.2 提高数据可靠性 2.2.3 提升数据读写性能 2.3 主从架构原理 2.4 主从架构扩展 2.4.1 双机热备(AB复制) 2.4.2 级联复制 2…

Qt核心:元对象系统、属性系统、对象树、信号槽

一、元对象系统 1、Qt 的元对象系统提供的功能有:对象间通信的信号和槽机制、运行时类型信息和动态属性系统等。 2、元对象系统是 Qt 对原有的 C进行的一些扩展,主要是为实现信号和槽机制而引入的, 信号和槽机制是 Qt 的核心特征。 3、要使…

当网络设置为自动获取dns时而实际nds是8.8.8.8,1.1.1.1的解决方法

笔记本换网络环境后,网络设置的是自动获取IP和自动获取dns。但使用命令:config/all命令时发现dns总是8.8.8.8,1.1.1.1。导致csdn上不了。 8.8.8.8,1.1.1.1:是谷歌的dns。 解决办法: 在支行中输入regedit打开注册表后&#xff0…

windows下载虚拟机virtualBox

链接:Downloads – Oracle VM VirtualBox 进入链接这样点击: 直接下载即可

Java“牵手”速卖通商品列表页数据采集+速卖通商品价格数据排序,速卖通API接口申请指南

速卖通是阿里巴巴旗下的面向国际市场打造的跨境电商平台,被称为国际版淘宝,速卖通面向海外买家客户,通过支付宝国际账户进行担保交易,并使用国际物流渠道运输发货,是全球第三大英文在线购物网站。 速卖通商品列表数据…

关于IDEA没有显示日志输出?IDEA控制台没有显示Tomcat Localhost Log和Catalina Log 怎么办?

问题描述: 原因是;CATALINA_BASE里面没有相关的文件配置。而之前学习IDEA的时候,把这个文件的位置改变了。导致,最后输出IDEA的时候,不会把日志也打印出来。 检查IDEA配置; D:\work_soft\tomcat_user\Tomcat10.0\bin 在此目录下&…

如何在没有第三方.NET库源码的情况,调试第三库代码?

大家好,我是沙漠尽头的狼。 本方首发于Dotnet9,介绍使用dnSpy调试第三方.NET库源码,行文目录: 安装dnSpy编写示例程序调试示例程序调试.NET库原生方法总结 1. 安装dnSpy dnSpy是一款功能强大的.NET程序反编译工具,…

STM32 Cubemx 通用定时器 General-Purpose Timers同步

文章目录 前言简介cubemx配置 前言 持续学习stm32中… 简介 通用定时器是一个16位的计数器,支持向上up、向下down与中心对称up-down三种模式。可以用于测量信号脉宽(输入捕捉),输出一定的波形(比较输出与PWM输出&am…

activemq部署

目录 1.下载 2.java环境 3.解压启动 4.访问测试 5.问题记录 5.1.无法启动成功问题 5.2.其他服务器无法访问 1.下载 ActiveMQ 2.java环境 需要注意要求的jdk版本,否则启动不会成功 3.解压启动 tar -zxvf apache-activemq-5.18.2-bin.tar.gz 进入到目录下执行…

使用递归思想遍历二叉树

二叉树的遍历主要有两种方式:深度优先遍历和广度优先遍历 这篇主要讲使用深度优先遍历来遍历二叉树 深度优先遍历有以下三种 前、中、后序遍历,这三种遍历方式的主要区别是中间节点的位置所在的顺序 前序遍历:中间节点在叶子节点前面 中序遍历…

Flink--4、DateStream API(执行环境、源算子、基本转换算子)

星光下的赶路人star的个人主页 注意力的集中,意象的孤立绝缘,便是美感的态度的最大特点 文章目录 1、DataStream API1.1 执行环境(Execution Environment)1.1.1 创建执行环境 1.2 执行模式(Execution Mode)…

Linux学习记录——이십구 网络基础(2)

文章目录 1、理解网络间通信2、理解协议3、网络字节序4、socket编程接口和sockaddr结构 1、理解网络间通信 宏观上,是主机与主机在发送接收消息,但主机怎么去发送消息?主机间的通信是通过进程完成的,这个进程就是用户发起的进程&…

终于把量化入门了,实盘权限已开,学习这件事也不难

多数人18岁就死了,但直到75岁才埋。 ——网易云热评《杀死那个石家庄人》 猫猫挺喜欢这句话的,为什么只活动75岁,于是我查了查现如今78.6岁,大差不差的。 那扣扣减减,人生短短57年,唯一十八岁那年&#xff…

鼠标移入展示字体操作

鼠标移入展示字体 点击删除实行删除操作&#xff0c;点击图片文字跳转产品详情的逻辑实现 <div class"allProduct-content"><template v-for"(item, index) in obj.product" :key"index"><!-- <img :src"item.image&qu…

Mac 上安装yt-dlp 和下载视频的操作

安装 打开终端&#xff0c;在终端输入 cd python的路径&#xff0c;然后输入pip3 install yt-dlp&#xff0c;如下图&#xff1b; 出现 如Successfully installed yt-dlp-2023.7.6 的时候&#xff0c;说明下载成功 下载 下载命令&#xff1a; yt-dlp --list-formats https…

【00】FISCO BCOS区块链简介

官方文档&#xff1a;https://fisco-bcos-documentation.readthedocs.io/zh_CN/latest/docs/introduction.html FISCO BCOS是由国内企业主导研发、对外开源、安全可控的企业级金融联盟链底层平台&#xff0c;由金链盟开源工作组协作打造&#xff0c;并于2017年正式对外开源。 F…

【Unity基础】4.动画Animation

【Unity基础】4.动画Animation 大家好&#xff0c;我是Lampard~~ 欢迎来到Unity基础系列博客&#xff0c;所学知识来自B站阿发老师~感谢 &#xff08;一&#xff09;Unity动画编辑器 &#xff08;1&#xff09;Animation组件 这一张我们要学习如何在unity编辑器中&…

HarmonyOS/OpenHarmony应用开发-DevEco Studio新建项目的整体说明

一、文件-新建-新建项目 二、传统应用形态与IDE自带的模板可供选用与免安装的元服与IDE中自带模板的选择 三、以元服务&#xff0c;远程模拟器为例说明IDE整体结构 1区是工程目录结构&#xff0c;是最基本的配置与开发路径等的认知。 2区是代码开发与修改区&#xff0c;是开发…

TexStudio报错 Class: No Found

\classdocument[preprint,review,fleqn,sort&compress,3p]{elsarticle}这里常见导入的类&#xff08;class&#xff09;文件有article.cls&#xff0c;elsarticle.cls&#xff0c;sn-jnl.cls等 一般来说这些文件都应该和我们的源文件document.tex在同一个目录下。如果不在…