《PyTorch深度学习实践》第八讲加载数据集

一、

1、DataSet 是抽象类,不能实例化对象,主要是用于构造我们的数据集

2、DataLoader 需要获取DataSet提供的索引[i]和len;用来帮助我们加载数据,比如说做shuffle(提高数据集的随机性),batch_size,能拿出Mini-Batch进行训练。它帮我们自动完成这些工作。DataLoader可实例化对象。DataLoader is a class to help us loading data in Pytorch.

3、__getitem__目的是为支持下标(索引)操作
 

二、

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader# prepare datasetclass DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0] # shape(多少行,多少列)self.x_data = torch.from_numpy(xy[:, :-1])self.y_data = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lendataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) #num_workers 多线程# design model using classclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# training cycle forward, backward, update
if __name__ == '__main__':for epoch in range(100):for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batchinputs, labels = datay_pred = model(inputs)loss = criterion(y_pred, labels)print(epoch, i, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()

1、需要mini_batch 就需要import DataSet和DataLoader

2、继承DataSet的类需要重写init,getitem,len魔法函数。分别是为了加载数据集,获取数据索引,获取数据总量。

3、DataLoader对数据集先打乱(shuffle),然后划分成mini_batch。

4、len函数的返回值 除以 batch_size 的结果就是每一轮epoch中需要迭代的次数。

5、inputs, labels = data中的inputs的shape是[32,8],labels 的shape是[32,1]。也就是说mini_batch在这个地方体现的

6、diabetes.csv数据集老师给了下载地址,该数据集需和源代码放在同一个文件夹内。

问题:loss没有收敛

网友解决:

做了两个实验:(1)输出每批次的loss,不收敛,loss在0.6上下浮动(2)每个epoch都不分批,把所有样本都输入,收敛,最后结果在0.6附近。所以猜测:小样本之间的loss差距相对于0.6而言有点大,所以看着像是没收敛,实际上从总loss来看已经收敛了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/712642.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows10环境下MongoDB安装配置

1. 下载对应MongoDB安装包 进入官网:MongoDB官网 如果不连接外网则在官网下载较慢,这里给出下载好的安装包,版本为4.2.25:百度网盘 选择你需要的版本,推荐选择Package的格式为zip(解压即可) Pa…

[VNCTF2024]-PWN:preinit解析(逆向花指令,绕过strcmp,函数修改,机器码)

查看保护: 查看ida: 这边其实看反汇编没啥大作用,需要自己动调。 但是前面的绕过strcmp还是要看一下的。 解题: 这里是用linux自带的产生随机数的文件urandom来产生一个随机密码,然后让我们输入密码,用st…

k8s 存储卷详解与动静部署详解

目录 一、Volume 卷 1.1 卷类型 emptyDir : hostPath: persistentVolumeClaim (PVC): configMap 和 secret: 二、 emptyDir存储卷 2.1 特点 2.2 用途: 2.3 示例 三、 hostPath存储卷 3.1 特点 3.2 用途 …

前端mock数据 —— 使用Apifox mock页面所需数据

前端mock数据 —— 使用Apifox 一、使用教程二、本地请求Apifox所mock的接口 一、使用教程 在首页进行新建项目: 新建项目名称: 新建接口: 创建json: 请求方法: GET。URL: api/basis。响应类型&#xff1…

Socket网络编程(六)——简易聊天室案例

目录 聊天室数据传输设计客户端、服务器数据交互数据传输协议服务器、多客户端模型客户端如何发送消息到另外一个客户端2个以上设备如何交互数据? 聊天室消息接收实现代码结构client客户端重构server服务端重构自身描述信息的构建重构TCPServer.java基于synchronize…

Nginx多次代理后获取真实的用户IP访问地址

需求:记录用户操作记录,类似如下表格的这样 PS: 注意无论你的服务是Http访问还是Https 访问的都是可以的,我们服务之前是客户只给开放了一个端口,但是既要支持https又要支持http协议,nginx 是可以通过stream 模块配置双…

2023中国PostgreSQL数据库生态大会:洞察前沿趋势,探索无限可能(附核心PPT资料下载)

随着数字化浪潮的推进,数据库技术已成为支撑各行各业数字化转型的核心力量。2023中国PostgreSQL数据库生态大会的召开,无疑为业界提供了一个深入交流、共同探索PostgreSQL数据库技术未来发展趋势的平台。本文将带您走进这场盛会,解析大会的亮…

k8s Pod基础(概念,容器功能及分类,镜像拉取和容器重启策略)

目录 pod概念 Kubernetes设计Pod概念和特殊组成结构的用意 Pod内部结构: 网络共享: 存储共享: pause容器主要功能 pod创建方式 pod使用方式 pod分类 pod的容器分类 基础容器(infrastructure container)&…

元宇宙3D虚拟场景制作深圳华锐视点免费试用

随着元宇宙兴起,3D线上展厅得到了越来越多的关注和应用。基于VR虚拟现实技术的元宇宙3D线上展厅在线编辑系统,更是为企业在展览展示领域带来了前所未有的辅助。 高效便捷: 元宇宙3D线上展厅在线编辑无需复杂的施工和搭建过程,只需…

报错问题解决django.db.utils.OperationalError: (1049, “Unknown database ‘ mxshop‘“)

开发环境:ubuntu22.04 pycharm 功能:django连接使用mysql数据库,各项配置看似正常 报错: django.db.utils.OperationalError: (1049, "Unknown database mxshop") 分析检查原因: Setting的配置文件内&…

gcd+线性dp,[蓝桥杯 2018 国 B] 矩阵求和

一、题目 1、题目描述 经过重重笔试面试的考验,小明成功进入 Macrohard 公司工作。 今天小明的任务是填满这么一张表: 表有 �n 行 �n 列,行和列的编号都从 11 算起。 其中第 �i 行第 �j 个元素…

Spring MVC 和 Spring Cloud Gateway不兼容性问题

当启动SpringCloudGateway网关服务的时候,没注意好依赖问题,出现了这个问题: Spring MVC found on classpath, which is incompatible with Spring Cloud Gateway. 解决办法就是:删除SpringMVC的依赖,即下列依赖。 &…

ChatGPT/GPT4科研应用与AI绘图及论文高效写作

原文:ChatGPT/GPT4科研应用与AI绘图及论文高效写作 第一:2024年AI领域最新技术 1.OpenAI新模型-GPT-5 2.谷歌新模型-Gemini Ultra 3.Meta新模型-LLama3 4.科大讯飞-星火认知 5.百度-文心一言 6.MoonshotAI-Kimi 7.智谱AI-GLM-4 第二:…

【C++从0到王者】第四十六站:图的深度优先与广度优先

文章目录 一、图的遍历二、广度优先遍历1.思想2.算法实现3.六度好友 三、深度优先遍历1.思想2.代码实现 四、其他问题 一、图的遍历 对于图而言,我们的遍历一般是遍历顶点,而不是边,因为边的遍历是比较简单的,就是邻接矩阵或者邻接…

《汇编语言》第3版 (王爽)检测点3.1解析

第三章 检测点3.1 (1).在Debug中,用“d 0:0 1f”查看内存,结果如下。 下面的程序执行前,AX 0,BX 0,写出每条汇编指令执行完后相关寄存器中的值。 mov ax,1 ;将1放入AX寄存器中,…

【零基础SRC】成为漏洞赏金猎人的第一课:加入玲珑安全漏洞挖掘班。

我们是谁 你是否对漏洞挖掘充满好奇?零基础或有基础但想更进一步?想赚取可观的漏洞赏金让自己有更大的自由度? 那么,不妨了解下我们《玲珑安全团队》。 玲珑安全团队,拥有多名实力讲师,均就职于互联网头…

一线互联网大厂中高级Android面试真题收录,记一次字节跳动Android社招面试

在开始回答前,先简单概括性地说说Linux现有的所有进程间IPC方式: 1. **管道:**在创建时分配一个page大小的内存,缓存区大小比较有限; 2. 消息队列:信息复制两次,额外的CPU消耗;不合…

指针与malloc动态内存申请,堆和栈的差异

定义了两个函数print_stack()和print_malloc(),分别演示了两种不同的内存分配方式:栈内存和堆内存。然后在main()函数中调用这两个函数,并将它们返回的指针打印出来。 由于print_stack()中的数组c是在栈上分配的,当函数返回后&…

Python装饰器的使用详解

目录 1、函数装饰器 1.1、闭包函数 1.2、装饰器语法 1.3、装饰带参数的函数 1.4、被装饰函数的身份问题 1.4.1、解决被装饰函数的身份问题 1.5、装饰器本身携带/传参数 1.6、嵌套多个装饰器 2、类装饰器 装饰器顾名思义作为一个装饰的作用,本身不改变被装…

Acwing 周赛135 解题报告 | 珂学家 | 反悔堆贪心

前言 整体评价 VP了这场比赛, T3挺有意思的,反悔贪心其实蛮套路的。 A. 买苹果 思路: 签到 n, x list(map(int, input().split())) print (n // x)B. 牛群 思路: 分类讨论 from collections import Counters input() cnt Counter(s)lists sorte…