Pytorch深度学习实践笔记8(b站刘二大人)

🎬个人简介:一个全栈工程师的升级之路!
📋个人专栏:pytorch深度学习
🎀CSDN主页 发狂的小花
🌄人生秘诀:学习的本质就是极致重复!

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili​

目录

1 Pytorch 数据加载

2 Dataset和DataLoader

3 程序


1 Pytorch 数据加载

  • epoch、Batch-size 、iteration


例如下图:
8个样本、shuffle是打乱样本的顺序,Batch-szie为2,iteration 就是 8 / 2 为4,epoch是训练集进行几个轮次的迭代。

 




2 Dataset和DataLoader

 




Dataset 是一个抽象类,使用时必须进行重写,from 在torch.utils.data Dataset
(1)重写时,需要根据数据来进行构造__init__(self,filepath)
(2)__getitem__(self,index)用来让数据可以进行索引操作
(3)__len__(self)用来获取数据集的大小
DataLoader 用来加载数据为mini-Batch ,支持Batch-size 的设置,shuffle支持数据的打乱顺序。

  • 参数说明:
from torch.utils.data import DataLoadertest_load = DataLoader(dataset=test_data, batch_size=4 , shuffle= True, num_workers=0,drop_last=False)


batch_size=4表示每次取四个数据
shuffle= True表示开启数据集随机重排,即每次取完数据之后,打乱剩余数据的顺序,然后再进行下一次取
num_workers=0表示在主进程中加载数据而不使用任何额外的子进程,如果大于0,表示开启多个进程,进程越多,处理数据的速度越快,但是会使电脑性能下降,占用更多的内存
drop_last=False表示不丢弃最后一个批次,假设我数据集有10个数据,我的batch_size=3,即每次取三个数据,那么我最后一次只有一个数据能取,如果设置为true,则不丢弃这个包含1个数据的子集数据,反之则丢弃

 

  • 数据转换为dataset形式,进行DataLoader的使用
x_data = torch.tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0],[7.0],[8.0],[9.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0],[14.0],[16.0],[18.0]])dataset = Data.TensorDataset(x_data,y_data)loader = Data.DataLoader(  dataset=dataset,  batch_size=BATCH_SIZE,  shuffle=True,  num_workers=0  
)

pytorch中的DataLoader_pytorch dataloader-CSDN博客​


3 程序


数据分为训练集和测试集:Adam 训练

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_splitimport matplotlib.pyplot as plt# 读取原始数据,并划分训练集和测试集
raw_data = np.loadtxt('./dataset/diabetes.csv.gz', delimiter=',', dtype=np.float32)
X = raw_data[:, :-1]
Y = raw_data[:, [-1]]
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,Y,test_size=0.1)
Xtest = torch.from_numpy(Xtest)
Ytest = torch.from_numpy(Ytest)# 将训练数据集进行批量处理
# prepare datasetclass DiabetesDataset(Dataset):def __init__(self, data,label):self.len = data.shape[0] # shape(多少行,多少列)self.x_data = torch.from_numpy(data)self.y_data = torch.from_numpy(label)def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lentrain_dataset = DiabetesDataset(Xtrain,Ytrain)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=0) #num_workers 多线程# design model using classclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 2)self.linear4 = torch.nn.Linear(2, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))x = self.sigmoid(self.linear4(x))return xmodel = Model()# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)epoch_list = []
loss_list = []# training cycle forward, backward, update
def train(epoch):for i, data in enumerate(train_loader, 0):inputs, labels = datay_pred = model(inputs)loss = criterion(y_pred, labels)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item()def test():with torch.no_grad():y_pred = model(Xtest)y_pred_label = torch.where(y_pred>=0.5,torch.tensor([1.0]),torch.tensor([0.0]))acc = torch.eq(y_pred_label, Ytest).sum().item() / Ytest.size(0)print("test acc:", acc)if __name__ == '__main__':for epoch in range(10000):loss_val = train(epoch)print("epoch: ",epoch," loss: ",loss_val)epoch_list.append(epoch)loss_list.append(loss_val)test()plt.plot(epoch_list,loss_list)plt.title("Adam")plt.xlabel("Epoch")plt.ylabel("Loss")plt.savefig("./data/pytorch7_1.png")



简单的程序
 

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader# prepare datasetclass DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0] # shape(多少行,多少列)self.x_data = torch.from_numpy(xy[:, :-1])self.y_data = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lendataset = DiabetesDataset('./dataset/diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) #num_workers 多线程# design model using classclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# training cycle forward, backward, update
if __name__ == '__main__':for epoch in range(100):for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batchinputs, labels = datay_pred = model(inputs)loss = criterion(y_pred, labels)print(epoch, i, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()

🌈我的分享也就到此结束啦🌈
如果我的分享也能对你有帮助,那就太好了!
若有不足,还请大家多多指正,我们一起学习交流!
📢未来的富豪们:点赞👍→收藏⭐→关注🔍,如果能评论下就太惊喜了!
感谢大家的观看和支持!最后,☺祝愿大家每天有钱赚!!!欢迎关注、关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/16694.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频汇聚管理安防监控平台EasyCVR程序报错“create jwtSecret del server class:0xf98b6040”的原因排查与解决

国标GB28181协议EasyCVR安防视频监控平台可以提供实时远程视频监控、视频录像、录像回放与存储、告警、语音对讲、云台控制、平台级联、磁盘阵列存储、视频集中存储、云存储等丰富的视频能力,平台支持7*24小时实时高清视频监控,能同时播放多路监控视频流…

java调用科大讯飞在线语音合成API --内附完整项目

科大讯飞语音开放平台基础环境搭建 1.用户注册 注册科大讯飞开放平台账号 2.注册好后先创建一个自己的应用 创建完成后进入应用可以看到我们开发需要的三个参数:APPID,APISecret,APIKey 3.因为平台提供的SDK中只支持了简单的中英两种语言语音…

Redis 可视化工具 RedisInsight 的保姆级安装以及使用(最新)

Redis 可视化工具 RedisInsight 的保姆级安装以及使用 一、下载 RedisInsight二、安装 RedisInsight三、使用 RedisInsight四、新建 Redis 连接 一、下载 RedisInsight 官网 https://redis.io/insight/填写基本信息之后点击 DOWNLOAD 二、安装 RedisInsight 双击安装包 点击下一…

cad角度如何精确到0.1

可以通过更改角度精度的方式把角度的标注精确到小数点后几位,具体方法如下: 1、打开一个CAD文档,在文档中画一个角,如下图: 文章源自设计学徒自学网-https://www.sx1c.com/47920.html 2、给此角进行角度的标注&#…

Java锁的策略

White graces&#xff1a;个人主页 &#x1f649;专栏推荐:Java入门知识&#x1f649; &#x1f649; 内容推荐:<多线程案例(线程池)>&#x1f649; &#x1f439;今日诗词:"你我推心置腹, 岂能相负"&#x1f439; 目录 锁的策略 乐观锁和悲观锁 轻量级锁…

[数据集][目标检测]森林火灾检测数据集VOC+YOLO格式362张1类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;362 标注数量(xml文件个数)&#xff1a;362 标注数量(txt文件个数)&#xff1a;362 标注类别…

四川音盛佳云电子商务有限公司铸就抖音电商新高度

在数字经济的浪潮中&#xff0c;抖音电商以其独特的魅力迅速崛起&#xff0c;成为新时代消费潮流的引领者。四川音盛佳云电子商务有限公司&#xff0c;作为抖音电商领域的佼佼者&#xff0c;凭借专业的团队和创新的理念&#xff0c;致力于为广大消费者提供优质、便捷的购物体验…

和可被k整除的子数组 ---- 前缀和

题目链接 题目: 分析: 补充知识 1. 同余定理: (a-b) % p 0即a-b能被p整除, > a % p b % p 2. c, java中 [负数 % 正数] 的结果是负数, 想要得到正确结果 > (a%pp)%p这道题和<和为k的子数组>类似, 利用前缀和的思想, 计算以i结尾的所有子数组, 前缀和为sum[i] …

探索编程逻辑中的“卡特牛(continue)”魔法

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、引言&#xff1a;卡特牛逻辑的魅力 二、卡特牛逻辑的解析 三、卡特牛逻辑的应用实例 …

sqlserver——查询(四)——连接查询

目录 一.连接查询 分类&#xff1a; 内连接&#xff1a; 1. select ... from A&#xff0c;B &#xff1b; 2. select ..from A&#xff0c;B where ..&#xff1b; 3.select ...,... from A join B on... 4. where 与 join...on 的区别 5. where位置的先后 导语&#xff1…

每日5题Day11 - LeetCode 51 - 55

每一步向前都是向自己的梦想更近一步&#xff0c;坚持不懈&#xff0c;勇往直前&#xff01; 第一题&#xff1a;51. N 皇后 - 力扣&#xff08;LeetCode&#xff09; class Solution {public List<List<String>> solveNQueens(int n) {List<List<String>…

如果查看svn的账号和密码

一、找到svn存放目录&#xff08;本地默认存放SVN用户信息的目录为&#xff1a;C:\Users\Administrator\AppData\Roaming\Subversion\auth\svn.simple&#xff09;每个人的电脑环境不一样&#xff0c;因人而异。 如果找不到直接搜索svn.simple 二、下载密码查看工具 链接: 百…

MySQL——MySQL目录结构

MySQL安装完成后&#xff0c;会在磁盘上生成一个目录&#xff0c;该目录被称为MySQL的安装目录。在MySQL的安装目录中包含了启动文件、配置文件、数据库文件和命令文件等。 下面对 MySQL 的安装目录进行详细讲解 (1)bin 目录 : 用于放置一些可执行文件,如 mysql.exe、mysqld. …

软件设计师基础知识难点总结

软件设计师基础知识难点 I/O设备管理软件一般分为4个层次&#xff0c;如下图所示。 用户进程与设备无关的系统软件设备驱动程序中断处理程序硬件 直接查询控制 分为有无条件传送和程序查询方式&#xff0c;都需要通过CPU执行程序来查询外设的状态&#xff0c;判断外设是否准备好…

C#多维数组不同读取方式的性能差异

背景 近来在优化一个图像显示程序&#xff0c;图像数据存储于一个3维数组data[x,y,z]中&#xff0c;三维数组为一张张图片数据的叠加而来&#xff0c;其中x为图片的张数&#xff0c;y为图片行&#xff0c;Z为图片的列&#xff0c;也就是说这个三维数组存储的为一系列图片的数据…

深度解析 Spring 源码:探秘 CGLIB 代理的奥秘

文章目录 一、CGLIB 代理简介1.1 CGLIB 代理的基本原理和特点1.2 分析 CGLIB 如何通过字节码技术创建代理类 二、深入分析 CglibAopProxy 类的结构2.1 CglibAopProxy 类结构2.2 CglibAopProxy 类源码 三、CGLIB 代理对象的创建过程3.1 配置 Enhancer 生成代理对象3.2 探讨如何通…

PageHelper分页查询时,count()查询记录总数与实际返回的数据数量不一致

目录 场景简介代码判断异常情况排查原因解决 场景简介 1、使用PageHelper进行分页查询 2、最终构建PageInfo对象时&#xff0c;total与实际数据量不符 代码判断 异常情况 排查 通过对比count()查询的SQL与查询记录的SQL&#xff0c;发现是PageHelper分页查询时省去了order b…

Linux系统编程——基础IO与文件描述符(管理已打开的内存文件)

目录 一&#xff0c;文件预备 二&#xff0c;C语言文件操作函数 2.1 默认打开的三个流 2.2 写文件 2.3 读文件 2.4 再次理解当前路径 三&#xff0c;Linux操作文件系统调用 3.1 open()和close() 3.1.1 第一个参数 3.1.2 *第二个参数 3.1.3 第三个参数 3.2 write(…

(2024,基于熵的激活函数动态优化,具有边界条件的最差激活函数,修正正则化 ReLU)寻找更优激活函数

A Method on Searching Better Activation Functions 公众号&#xff1a;EDPJ&#xff08;进 Q 交流群&#xff1a;922230617 或加 VX&#xff1a;CV_EDPJ 进 V 交流群&#xff09; 目录 0. 摘要 3. 动机 4. 方法论 4.1 问题设定 4.1.1 贝叶斯错误率和信息熵 4.1.2 激活…

host修改

前言 想要修改 hosts 文件&#xff0c;您需要具有对系统文件的适当访问权限&#xff0c;并且知道如何编辑文本文件。hosts 文件是一个用于域名解析的本地文件&#xff0c;它允许您为特定的 IP 地址指定主机名。 以下是在不同操作系统中修改 hosts 文件的步骤&#xff1a; 一、…