Pytorch深度学习实践笔记8(b站刘二大人)

🎬个人简介:一个全栈工程师的升级之路!
📋个人专栏:pytorch深度学习
🎀CSDN主页 发狂的小花
🌄人生秘诀:学习的本质就是极致重复!

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili​

目录

1 Pytorch 数据加载

2 Dataset和DataLoader

3 程序


1 Pytorch 数据加载

  • epoch、Batch-size 、iteration


例如下图:
8个样本、shuffle是打乱样本的顺序,Batch-szie为2,iteration 就是 8 / 2 为4,epoch是训练集进行几个轮次的迭代。

 




2 Dataset和DataLoader

 




Dataset 是一个抽象类,使用时必须进行重写,from 在torch.utils.data Dataset
(1)重写时,需要根据数据来进行构造__init__(self,filepath)
(2)__getitem__(self,index)用来让数据可以进行索引操作
(3)__len__(self)用来获取数据集的大小
DataLoader 用来加载数据为mini-Batch ,支持Batch-size 的设置,shuffle支持数据的打乱顺序。

  • 参数说明:
from torch.utils.data import DataLoadertest_load = DataLoader(dataset=test_data, batch_size=4 , shuffle= True, num_workers=0,drop_last=False)


batch_size=4表示每次取四个数据
shuffle= True表示开启数据集随机重排,即每次取完数据之后,打乱剩余数据的顺序,然后再进行下一次取
num_workers=0表示在主进程中加载数据而不使用任何额外的子进程,如果大于0,表示开启多个进程,进程越多,处理数据的速度越快,但是会使电脑性能下降,占用更多的内存
drop_last=False表示不丢弃最后一个批次,假设我数据集有10个数据,我的batch_size=3,即每次取三个数据,那么我最后一次只有一个数据能取,如果设置为true,则不丢弃这个包含1个数据的子集数据,反之则丢弃

 

  • 数据转换为dataset形式,进行DataLoader的使用
x_data = torch.tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0],[7.0],[8.0],[9.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0],[8.0],[10.0],[12.0],[14.0],[16.0],[18.0]])dataset = Data.TensorDataset(x_data,y_data)loader = Data.DataLoader(  dataset=dataset,  batch_size=BATCH_SIZE,  shuffle=True,  num_workers=0  
)

pytorch中的DataLoader_pytorch dataloader-CSDN博客​


3 程序


数据分为训练集和测试集:Adam 训练

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_splitimport matplotlib.pyplot as plt# 读取原始数据,并划分训练集和测试集
raw_data = np.loadtxt('./dataset/diabetes.csv.gz', delimiter=',', dtype=np.float32)
X = raw_data[:, :-1]
Y = raw_data[:, [-1]]
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,Y,test_size=0.1)
Xtest = torch.from_numpy(Xtest)
Ytest = torch.from_numpy(Ytest)# 将训练数据集进行批量处理
# prepare datasetclass DiabetesDataset(Dataset):def __init__(self, data,label):self.len = data.shape[0] # shape(多少行,多少列)self.x_data = torch.from_numpy(data)self.y_data = torch.from_numpy(label)def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lentrain_dataset = DiabetesDataset(Xtrain,Ytrain)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=0) #num_workers 多线程# design model using classclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 2)self.linear4 = torch.nn.Linear(2, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))x = self.sigmoid(self.linear4(x))return xmodel = Model()# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)epoch_list = []
loss_list = []# training cycle forward, backward, update
def train(epoch):for i, data in enumerate(train_loader, 0):inputs, labels = datay_pred = model(inputs)loss = criterion(y_pred, labels)optimizer.zero_grad()loss.backward()optimizer.step()return loss.item()def test():with torch.no_grad():y_pred = model(Xtest)y_pred_label = torch.where(y_pred>=0.5,torch.tensor([1.0]),torch.tensor([0.0]))acc = torch.eq(y_pred_label, Ytest).sum().item() / Ytest.size(0)print("test acc:", acc)if __name__ == '__main__':for epoch in range(10000):loss_val = train(epoch)print("epoch: ",epoch," loss: ",loss_val)epoch_list.append(epoch)loss_list.append(loss_val)test()plt.plot(epoch_list,loss_list)plt.title("Adam")plt.xlabel("Epoch")plt.ylabel("Loss")plt.savefig("./data/pytorch7_1.png")



简单的程序
 

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader# prepare datasetclass DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0] # shape(多少行,多少列)self.x_data = torch.from_numpy(xy[:, :-1])self.y_data = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lendataset = DiabetesDataset('./dataset/diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) #num_workers 多线程# design model using classclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# training cycle forward, backward, update
if __name__ == '__main__':for epoch in range(100):for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batchinputs, labels = datay_pred = model(inputs)loss = criterion(y_pred, labels)print(epoch, i, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()

🌈我的分享也就到此结束啦🌈
如果我的分享也能对你有帮助,那就太好了!
若有不足,还请大家多多指正,我们一起学习交流!
📢未来的富豪们:点赞👍→收藏⭐→关注🔍,如果能评论下就太惊喜了!
感谢大家的观看和支持!最后,☺祝愿大家每天有钱赚!!!欢迎关注、关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/16694.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue 状态管理深入研究:Vuex 和 Pinia 的原理与实践对比

推荐一个AI网站,免费使用豆包AI模型,快去白嫖👉海鲸AI 👋 引言 在 Vue.js 应用程序中,状态管理是一个至关重要的方面。它有助于集中管理应用的状态,使组件之间的数据共享更加高效和可维护。Vuex 和 Pinia …

视频汇聚管理安防监控平台EasyCVR程序报错“create jwtSecret del server class:0xf98b6040”的原因排查与解决

国标GB28181协议EasyCVR安防视频监控平台可以提供实时远程视频监控、视频录像、录像回放与存储、告警、语音对讲、云台控制、平台级联、磁盘阵列存储、视频集中存储、云存储等丰富的视频能力,平台支持7*24小时实时高清视频监控,能同时播放多路监控视频流…

java调用科大讯飞在线语音合成API --内附完整项目

科大讯飞语音开放平台基础环境搭建 1.用户注册 注册科大讯飞开放平台账号 2.注册好后先创建一个自己的应用 创建完成后进入应用可以看到我们开发需要的三个参数:APPID,APISecret,APIKey 3.因为平台提供的SDK中只支持了简单的中英两种语言语音…

Redis 可视化工具 RedisInsight 的保姆级安装以及使用(最新)

Redis 可视化工具 RedisInsight 的保姆级安装以及使用 一、下载 RedisInsight二、安装 RedisInsight三、使用 RedisInsight四、新建 Redis 连接 一、下载 RedisInsight 官网 https://redis.io/insight/填写基本信息之后点击 DOWNLOAD 二、安装 RedisInsight 双击安装包 点击下一…

SQLite 3.4.60 版本发布,带来优化器和函数增强!

SQLite 开发团队于 2024 年 05 月 23 日发布了 SQLite 3.46.0 版本,带来了不少优化器和函数相关的增强,我们来了解一下新版本的改进功能。 优化数据库 新版本增强了 PRAGMA optimize 指令,简化了它的使用,具体包括: …

cad角度如何精确到0.1

可以通过更改角度精度的方式把角度的标注精确到小数点后几位,具体方法如下: 1、打开一个CAD文档,在文档中画一个角,如下图: 文章源自设计学徒自学网-https://www.sx1c.com/47920.html 2、给此角进行角度的标注&#…

Java锁的策略

White graces&#xff1a;个人主页 &#x1f649;专栏推荐:Java入门知识&#x1f649; &#x1f649; 内容推荐:<多线程案例(线程池)>&#x1f649; &#x1f439;今日诗词:"你我推心置腹, 岂能相负"&#x1f439; 目录 锁的策略 乐观锁和悲观锁 轻量级锁…

C\C++ 中大端存储及小端存储的代码判定

● 大端存储&#xff08;Big Endian&#xff09;还是小端存储&#xff08;Little Endian&#xff09; 大端小端的说法来源于一个具有讽刺意味的小说。说的是有两个党派争论从哪一头敲破鸡蛋更好&#xff0c;一党坚持从小端敲破鸡蛋更好&#xff0c;即 Little Endian&#xff0c…

[数据集][目标检测]森林火灾检测数据集VOC+YOLO格式362张1类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;362 标注数量(xml文件个数)&#xff1a;362 标注数量(txt文件个数)&#xff1a;362 标注类别…

四川音盛佳云电子商务有限公司铸就抖音电商新高度

在数字经济的浪潮中&#xff0c;抖音电商以其独特的魅力迅速崛起&#xff0c;成为新时代消费潮流的引领者。四川音盛佳云电子商务有限公司&#xff0c;作为抖音电商领域的佼佼者&#xff0c;凭借专业的团队和创新的理念&#xff0c;致力于为广大消费者提供优质、便捷的购物体验…

和可被k整除的子数组 ---- 前缀和

题目链接 题目: 分析: 补充知识 1. 同余定理: (a-b) % p 0即a-b能被p整除, > a % p b % p 2. c, java中 [负数 % 正数] 的结果是负数, 想要得到正确结果 > (a%pp)%p这道题和<和为k的子数组>类似, 利用前缀和的思想, 计算以i结尾的所有子数组, 前缀和为sum[i] …

探索编程逻辑中的“卡特牛(continue)”魔法

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、引言&#xff1a;卡特牛逻辑的魅力 二、卡特牛逻辑的解析 三、卡特牛逻辑的应用实例 …

sqlserver——查询(四)——连接查询

目录 一.连接查询 分类&#xff1a; 内连接&#xff1a; 1. select ... from A&#xff0c;B &#xff1b; 2. select ..from A&#xff0c;B where ..&#xff1b; 3.select ...,... from A join B on... 4. where 与 join...on 的区别 5. where位置的先后 导语&#xff1…

每日5题Day11 - LeetCode 51 - 55

每一步向前都是向自己的梦想更近一步&#xff0c;坚持不懈&#xff0c;勇往直前&#xff01; 第一题&#xff1a;51. N 皇后 - 力扣&#xff08;LeetCode&#xff09; class Solution {public List<List<String>> solveNQueens(int n) {List<List<String>…

如果查看svn的账号和密码

一、找到svn存放目录&#xff08;本地默认存放SVN用户信息的目录为&#xff1a;C:\Users\Administrator\AppData\Roaming\Subversion\auth\svn.simple&#xff09;每个人的电脑环境不一样&#xff0c;因人而异。 如果找不到直接搜索svn.simple 二、下载密码查看工具 链接: 百…

MySQL——MySQL目录结构

MySQL安装完成后&#xff0c;会在磁盘上生成一个目录&#xff0c;该目录被称为MySQL的安装目录。在MySQL的安装目录中包含了启动文件、配置文件、数据库文件和命令文件等。 下面对 MySQL 的安装目录进行详细讲解 (1)bin 目录 : 用于放置一些可执行文件,如 mysql.exe、mysqld. …

BioTech - 使用 CombFold 算法 实现 大型蛋白质复合物结构 的组装过程

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/139242199 CombFold 是用于预测大型蛋白质复合物结构的组合和分层组装算法,利用 AlphaFold2 预测的亚基之间的成对相互作用。 CombFold 算法的关键特点包括: 组合和…

工厂模式(Factory Pattern)简介

工厂模式是一种创建对象的设计模式&#xff0c;它通过定义一个创建对象的接口&#xff0c;让子类决定实例化哪一个类。工厂模式使得一个类的实例化延迟到其子类。 在Android开发中&#xff0c;工厂模式可以用于创建不同类型的对象&#xff0c;尤其是在需要创建复杂对象或对象创…

软件设计师基础知识难点总结

软件设计师基础知识难点 I/O设备管理软件一般分为4个层次&#xff0c;如下图所示。 用户进程与设备无关的系统软件设备驱动程序中断处理程序硬件 直接查询控制 分为有无条件传送和程序查询方式&#xff0c;都需要通过CPU执行程序来查询外设的状态&#xff0c;判断外设是否准备好…

Vue3学习-vue-router之路由传参

传参方案一&#xff1a;RouterLink 字符串 //传值 <RouterLink to"/page?a1&b2">{{ RouterLink 字符串传参 }}</RouterLink> //取值 import { toRefs } from vue import { useRoute } from vue-router const { query} toRefs(useRoute()) console.…