《PyTorch深度学习实践》第八讲加载数据集

一、

1、DataSet 是抽象类,不能实例化对象,主要是用于构造我们的数据集

2、DataLoader 需要获取DataSet提供的索引[i]和len;用来帮助我们加载数据,比如说做shuffle(提高数据集的随机性),batch_size,能拿出Mini-Batch进行训练。它帮我们自动完成这些工作。DataLoader可实例化对象。DataLoader is a class to help us loading data in Pytorch.

3、__getitem__目的是为支持下标(索引)操作
 

二、

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader# prepare datasetclass DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0] # shape(多少行,多少列)self.x_data = torch.from_numpy(xy[:, :-1])self.y_data = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lendataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) #num_workers 多线程# design model using classclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# training cycle forward, backward, update
if __name__ == '__main__':for epoch in range(100):for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batchinputs, labels = datay_pred = model(inputs)loss = criterion(y_pred, labels)print(epoch, i, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()

1、需要mini_batch 就需要import DataSet和DataLoader

2、继承DataSet的类需要重写init,getitem,len魔法函数。分别是为了加载数据集,获取数据索引,获取数据总量。

3、DataLoader对数据集先打乱(shuffle),然后划分成mini_batch。

4、len函数的返回值 除以 batch_size 的结果就是每一轮epoch中需要迭代的次数。

5、inputs, labels = data中的inputs的shape是[32,8],labels 的shape是[32,1]。也就是说mini_batch在这个地方体现的

6、diabetes.csv数据集老师给了下载地址,该数据集需和源代码放在同一个文件夹内。

问题:loss没有收敛

网友解决:

做了两个实验:(1)输出每批次的loss,不收敛,loss在0.6上下浮动(2)每个epoch都不分批,把所有样本都输入,收敛,最后结果在0.6附近。所以猜测:小样本之间的loss差距相对于0.6而言有点大,所以看着像是没收敛,实际上从总loss来看已经收敛了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/712642.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows10环境下MongoDB安装配置

1. 下载对应MongoDB安装包 进入官网:MongoDB官网 如果不连接外网则在官网下载较慢,这里给出下载好的安装包,版本为4.2.25:百度网盘 选择你需要的版本,推荐选择Package的格式为zip(解压即可) Pa…

[VNCTF2024]-PWN:preinit解析(逆向花指令,绕过strcmp,函数修改,机器码)

查看保护: 查看ida: 这边其实看反汇编没啥大作用,需要自己动调。 但是前面的绕过strcmp还是要看一下的。 解题: 这里是用linux自带的产生随机数的文件urandom来产生一个随机密码,然后让我们输入密码,用st…

k8s 存储卷详解与动静部署详解

目录 一、Volume 卷 1.1 卷类型 emptyDir : hostPath: persistentVolumeClaim (PVC): configMap 和 secret: 二、 emptyDir存储卷 2.1 特点 2.2 用途: 2.3 示例 三、 hostPath存储卷 3.1 特点 3.2 用途 …

前端mock数据 —— 使用Apifox mock页面所需数据

前端mock数据 —— 使用Apifox 一、使用教程二、本地请求Apifox所mock的接口 一、使用教程 在首页进行新建项目: 新建项目名称: 新建接口: 创建json: 请求方法: GET。URL: api/basis。响应类型&#xff1…

可以用numpy为for加速

Numpy除了用于科学计算,还有一个功能是可以代替某些for循环,进行同样的功能实现,有于是向量矩阵运算,碰到复杂的for时,计算速度可以提高,从而提高程序性能。以下是一些常用的NumPy函数和操作,可…

Socket网络编程(六)——简易聊天室案例

目录 聊天室数据传输设计客户端、服务器数据交互数据传输协议服务器、多客户端模型客户端如何发送消息到另外一个客户端2个以上设备如何交互数据? 聊天室消息接收实现代码结构client客户端重构server服务端重构自身描述信息的构建重构TCPServer.java基于synchronize…

Nginx多次代理后获取真实的用户IP访问地址

需求:记录用户操作记录,类似如下表格的这样 PS: 注意无论你的服务是Http访问还是Https 访问的都是可以的,我们服务之前是客户只给开放了一个端口,但是既要支持https又要支持http协议,nginx 是可以通过stream 模块配置双…

2023中国PostgreSQL数据库生态大会:洞察前沿趋势,探索无限可能(附核心PPT资料下载)

随着数字化浪潮的推进,数据库技术已成为支撑各行各业数字化转型的核心力量。2023中国PostgreSQL数据库生态大会的召开,无疑为业界提供了一个深入交流、共同探索PostgreSQL数据库技术未来发展趋势的平台。本文将带您走进这场盛会,解析大会的亮…

k8s Pod基础(概念,容器功能及分类,镜像拉取和容器重启策略)

目录 pod概念 Kubernetes设计Pod概念和特殊组成结构的用意 Pod内部结构: 网络共享: 存储共享: pause容器主要功能 pod创建方式 pod使用方式 pod分类 pod的容器分类 基础容器(infrastructure container)&…

加密和签名的区别及应用场景

原文网址:加密和签名的区别及应用场景_IT利刃出鞘的博客-CSDN博客 简介 本文介绍加密和签名的区别及应用场景。 RSA是一种非对称加密算法, 可生成一对密钥(私钥和公钥)。(RSA可以同时支持加密和签名)。 …

元宇宙3D虚拟场景制作深圳华锐视点免费试用

随着元宇宙兴起,3D线上展厅得到了越来越多的关注和应用。基于VR虚拟现实技术的元宇宙3D线上展厅在线编辑系统,更是为企业在展览展示领域带来了前所未有的辅助。 高效便捷: 元宇宙3D线上展厅在线编辑无需复杂的施工和搭建过程,只需…

报错问题解决django.db.utils.OperationalError: (1049, “Unknown database ‘ mxshop‘“)

开发环境:ubuntu22.04 pycharm 功能:django连接使用mysql数据库,各项配置看似正常 报错: django.db.utils.OperationalError: (1049, "Unknown database mxshop") 分析检查原因: Setting的配置文件内&…

gcd+线性dp,[蓝桥杯 2018 国 B] 矩阵求和

一、题目 1、题目描述 经过重重笔试面试的考验,小明成功进入 Macrohard 公司工作。 今天小明的任务是填满这么一张表: 表有 �n 行 �n 列,行和列的编号都从 11 算起。 其中第 �i 行第 �j 个元素…

GRPC 错误码表

code数描述OK0不是错误;成功返回。CANCELLED1操作通常由调用方取消。UNKNOWN2未知错误。例如,当从另一个地址空间接收的值属于此地址空间中未知的错误空间时,可能会返回此错误。此外,未返回足够错误信息的 API 引发的错误可能会转换为此错误。…

ggplot去除背景

在ggplot2中去除背景,通常指的是去除图表的灰色背景和网格线,使图表背景变为透明或白色,以及去除或简化坐标轴的背景。这可以通过调整主题(theme)来实现。ggplot2提供了多种主题设置,可以用来调整图表的外观…

Spring MVC 和 Spring Cloud Gateway不兼容性问题

当启动SpringCloudGateway网关服务的时候,没注意好依赖问题,出现了这个问题: Spring MVC found on classpath, which is incompatible with Spring Cloud Gateway. 解决办法就是:删除SpringMVC的依赖,即下列依赖。 &…

ChatGPT/GPT4科研应用与AI绘图及论文高效写作

原文:ChatGPT/GPT4科研应用与AI绘图及论文高效写作 第一:2024年AI领域最新技术 1.OpenAI新模型-GPT-5 2.谷歌新模型-Gemini Ultra 3.Meta新模型-LLama3 4.科大讯飞-星火认知 5.百度-文心一言 6.MoonshotAI-Kimi 7.智谱AI-GLM-4 第二:…

【C++从0到王者】第四十六站:图的深度优先与广度优先

文章目录 一、图的遍历二、广度优先遍历1.思想2.算法实现3.六度好友 三、深度优先遍历1.思想2.代码实现 四、其他问题 一、图的遍历 对于图而言,我们的遍历一般是遍历顶点,而不是边,因为边的遍历是比较简单的,就是邻接矩阵或者邻接…

《汇编语言》第3版 (王爽)检测点3.1解析

第三章 检测点3.1 (1).在Debug中,用“d 0:0 1f”查看内存,结果如下。 下面的程序执行前,AX 0,BX 0,写出每条汇编指令执行完后相关寄存器中的值。 mov ax,1 ;将1放入AX寄存器中,…

GC如何判定对象已死

GC判定对象已死的2种方法 引用计数法 给对象中添加一个引用计数器,每当有一个地方引用它时,计数器值就加1;当引用失效时,计数器值就减1;Java语言中没有选用引用计数算法来管理内存,其中最主要的原因是它很…