pytorch学习day2

1 数据加载Dataset

PyTorch的数据读取机制主要依赖于DatasetDataLoader这两个核心组件。它们用于加载和处理数据,以便在训练模型时进行高效的数据流动和处理。

Dataset

Dataset是一个抽象类,用户可以继承这个类并重载以下两个方法来创建自定义的数据集:

  1. __init__ 方法:

    • csv_file:指向包含图像路径和标签的CSV文件路径。
    • root_dir:包含所有图像的根目录路径。
    • transform:一个可选的变换,用于在返回样本之前处理数据。

    在初始化过程中,读取CSV文件并存储在self.data_frame中,还设置了图像的根目录和可选的变换。

  2. __len__ 方法:

    • 返回数据集中样本的数量,即CSV文件中记录的行数。
  3. __getitem__ 方法:

    • 接收一个索引 idx,从CSV文件中获取对应的图像路径和标签。
    • 使用PIL库打开图像文件,并将其转换为RGB格式。
    • 如果定义了变换,则将其应用到图像。
    • 返回处理后的图像和对应的标签。

自定义Dataset示例

import torch
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __len__(self):return len(self.data)def __getitem__(self, index):sample = self.data[index]label = self.labels[index]return sample, label# 示例数据
data = torch.randn(100, 3)  # 100个样本,每个样本3个特征
labels = torch.randint(0, 2, (100,))  # 100个标签# 创建自定义数据集
dataset = CustomDataset(data, labels)

2 可迭代的数据装载器DataLoader

DataLoader 是 PyTorch 中一个非常重要的类,用于构建可迭代的数据装载器。它能够有效地加载数据并在训练模型时提供数据批次。下面我们详细介绍 DataLoader 的各个参数和使用方法。

DataLoader 的功能

DataLoader 主要用于在训练过程中,每个 for 循环中从数据集中获取一个指定大小(batch_size)的数据批次。

参数解释

  1. dataset:

    • 类型:Dataset 类实例
    • 功能:决定数据从哪里读取以及如何读取。Dataset 类定义了数据集的具体内容及访问方式。
  2. batch_size:

    • 类型:整数
    • 功能:每个数据批次的大小。例如,batch_size=32 表示每次从数据集中获取32个样本。
  3. num_workers:

    • 类型:整数
    • 功能:决定使用多少个子进程来加载数据。更多的进程数可以加快数据加载速度,但过多的进程数可能会导致系统资源不足,建议设置为 4、8、16 等。
  4. shuffle:

    • 类型:布尔值
    • 功能:决定每个 epoch 开始时是否打乱数据顺序。打乱数据可以增加训练过程的随机性,通常设置为 True
  5. drop_last:

    • 类型:布尔值
    • 功能:如果数据集中的样本数不能被 batch_size 整除,决定是否舍弃最后一个不完整的数据批次。设置为 True 表示舍弃。

重要概念

  1. Epoch:

    • 定义:所有训练样本都已输入到模型中,称为一个 epoch。
  2. Iteration:

    • 定义:一个批次的样本输入到模型中,称为一次 iteration。
  3. Batch Size:

    • 定义:批大小,决定一个 epoch 中有多少次 iteration。
# 创建 DataLoader 实例
dataloader = DataLoader(dataset=dataset,       # 自定义数据集batch_size=32,         # 每批次32个样本shuffle=True,          # 每个epoch开始时打乱数据num_workers=4,         # 使用4个子进程加载数据drop_last=True         # 当样本数不能被batch_size整除时,舍弃最后一批数据
)# 训练循环示例
for epoch in range(num_epochs):for batch_idx, (data, labels) in enumerate(dataloader):# 模型训练代码pass

3 图像预处理transforms

在PyTorch中,transforms是一个用于图像预处理的模块。transforms提供了一组常用的图像变换方法,可以对图像进行数据增强、归一化、裁剪、缩放等操作。transforms主要用于将图像数据转换成适合模型输入的格式。

常用的Transforms

以下是一些常用的transforms操作:

  1. transforms.Compose:将多个变换组合起来。
  2. transforms.Resize:调整图像大小。
  3. transforms.CenterCrop:从图像中心裁剪。
  4. transforms.RandomCrop:随机裁剪图像。
  5. transforms.RandomHorizontalFlip:随机水平翻转图像。
  6. transforms.ToTensor:将PIL图像或Numpy数组转换为张量,并将像素值归一化到[0, 1]。
  7. transforms.Normalize:用均值和标准差归一化张量。
  8. transforms.ColorJitter:随机改变图像的亮度、对比度和饱和度。
  9. transforms.RandomRotation:随机旋转图像。
from torchvision import transforms
from PIL import Image# 定义图像预处理变换
transform = transforms.Compose([transforms.Resize((128, 128)),             # 调整图像大小transforms.RandomHorizontalFlip(),         # 随机水平翻转transforms.RandomRotation(10),             # 随机旋转10度transforms.ColorJitter(brightness=0.5),    # 随机改变亮度transforms.ToTensor(),                     # 转换为张量并归一化到[0, 1]transforms.Normalize((0.5,), (0.5,))       # 用均值0.5和标准差0.5归一化
])# 加载图像
image = Image.open("path_to_image.jpg").convert("RGB")# 应用预处理变换
transformed_image = transform(image)# 检查变换后的图像
print(transformed_image.size())

如果现有的transforms无法满足需求,可以自定义变换。只需实现__call__方法即可

import torchclass CustomTransform:def __call__(self, sample):# 自定义变换逻辑,例如将图像转换为灰度图return transforms.functional.rgb_to_grayscale(sample)# 使用自定义变换
transform = transforms.Compose([transforms.Resize((128, 128)),CustomTransform(),transforms.ToTensor()
])image = Image.open("path_to_image.jpg").convert("RGB")
transformed_image = transform(image)
print(transformed_image.size())

4 综合数据读取和数据预处理

以下是一个综合示例,展示如何定义数据集并使用各种transforms进行图像预处理和数据增强。

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import pandas as pdclass CustomCSVImageDataset(Dataset):def __init__(self, csv_file, root_dir, transform=None):self.data_frame = pd.read_csv(csv_file)self.root_dir = root_dirself.transform = transformdef __len__(self):return len(self.data_frame)def __getitem__(self, idx):if torch.is_tensor(idx):idx = idx.tolist()img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx, 0])image = Image.open(img_name).convert('RGB')label = self.data_frame.iloc[idx, 1]if self.transform:image = self.transform(image)return image, label# 定义图像预处理和数据增强
transform = transforms.Compose([transforms.Resize((128, 128)),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 示例数据
csv_file = './data/labels.csv'
root_dir = './data/images'# 创建数据集
dataset = CustomCSVImageDataset(csv_file=csv_file, root_dir=root_dir, transform=transform)# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)# 迭代DataLoader
for batch_idx, (data, labels) in enumerate(dataloader):print(f"Batch {batch_idx}:")print("数据大小:", data.size())print("标签大小:", labels.size())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/19470.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

吴恩达深度学习个人笔记

1. 上一个视频提到的房地产领域,我们不就使用了一个普遍标准神经网络架构吗 而对于图像识别处理问题,我们则要使用卷积神经网络(Convolution Neural Network),即CNN。 对于序列数据,例如音频,有一个时间组件,随着时间的推移,音频被播放出来, 所以音频是最自然的表…

Mac下载docker

先安装homebrew Mac下载Homebrew-CSDN博客 然后输入以下命令安装docker brew install --cask --appdir/Applications docker 期间需要输入密码。输入完等待即可

Kubernetes (K8s) 普及指南

在当今的云计算和微服务时代,Kubernetes(简称K8s)已经成为容器编排的标准工具。它帮助开发者和运维人员管理和部署应用程序,实现高可用性、可伸缩性和自我修复。本文将详细介绍Kubernetes的基本概念、核心组件、工作原理及其优势。…

刷机 iPhone 进入恢复模式

文章目录 第 1 步:确保你有一台电脑(Mac 或 PC)第 2 步:将 iPhone 关机第 3 步:将 iPhone 置于恢复模式第 4 步:使用 Mac 或 PC 恢复 iPhone需要更多协助? 本文转载自:如果你忘记了 …

【AI】llama-fs的 安装与运行

pip install -r .\requirements.txt Windows PowerShell Copyright (C) Microsoft Corporation. All rights reserved.Install the latest PowerShell for new features and improvements! https://aka.ms/PSWindows(venv) PS D:\XTRANS\pythonProject>

微服务 feign-gateway

早期微服务跨集群调用 使用的是Eureka 和RestTemplate,这种写法虽然可以解决服务之间的调用问题 ,但是随着服务的增多,实例变动,早期的写法相当于把请求方式,请求地址,参数写死了,耦合度太高,参…

山东大学软件学院2021级编译原理回忆版

一、判断题 1、正则文法可以表示一般的高级程序语言,构成其语法成分和生成句子() 2、NFA的状态和符号有且只有一条边,因此看起来更直观() 3、DFA无法表示这样的语言{anbn,n>1}() …

HackTheBox-Machines--Nibbles

Nibbles 测试过程 1 信息收集 NMAP 80 端口 网站出了打印出“Hello world!”外,无其他可利用信息,但是查看网页源代码时,发现存在一个 /nibbleblog 文件夹 检查了 http://10.129.140.63/nibbleblog/ ,发现了 /index.p…

Windows环境下 postgresql16 增量备份及恢复

修改postgresql.conf isten_addresses * wal_level replica archive_mode on archive_command copy /V "%p" C:\\backup\\wal_files\\%f 注意写法,这里有大坑 restore_command copy c:\\backup\\wal_files\\%f "%p" recov…

探索无限可能:API平台引领数据驱动的新时代

在数字化浪潮的推动下,数据已成为推动商业创新和增长的核心动力。然而,数据的获取、整合和应用并非易事,需要跨越技术、安全和效率等多重挑战。幸运的是,API(应用程序接口)平台的出现,为我们打开…

pom文件中,Maven导入依赖出现 Dependency not found

解决方案: 1、首先看一下自己的Maven是否配置好了 2、再检查一下镜像是否正确 3、如果上面都没有问题,看 dependencyManagement 标签 我这个出错,爆一大片红就是因为 这个标签 dependencyManagement 解决方法:在父工程中进行依…

代码随想录算法训练营第36期DAY44

DAY44 闫氏DP 2 01背包问题 用滚动数组来优化空间&#xff0c;从后向前&#xff08;大到小&#xff09;遍历j #include<iostream>using namespace std;const int N1010;int n,m;int v[N],w[N];int f[N][N];//所有只考虑前i个物品&#xff0c;**且总体积不超过j**的选法…

趋势分析:2024年 2D CAD 在工业工程软件中的市场现状

文章概览 CAD发展趋势 一、现状 二、2D CAD在工业工程规划软件中的作用 三、工业工程师使用什么软件&#xff1f; 四、DraftSight&#xff1a;功能强大的工业工程软件 实际工业工程应用 一、ERIKS&#xff1a;使用 DraftSight 管理大量 2D 图纸 二、Sealed Air&#xff1…

过敏者的福音:猫毛克星大揭秘!使用宠物空气净化器效果如何?

对于猫毛过敏者来说&#xff0c;家中爱宠的陪伴与过敏的困扰并存&#xff0c;给他们的日常生活带来了极大的不便。猫毛过敏者常常因为与猫咪接触后出现打喷嚏、鼻塞、眼睛发痒等症状而苦恼&#xff0c;严重时甚至可能影响到他们的呼吸健康。 然而&#xff0c;这并不意味着猫毛…

JavaWeb笔记整理+图解——Filter过滤器

欢迎大家来到这一篇章——Filter过滤器 监听器和过滤器都是JavaWeb服务器三大组件(Servlet、监听器、过滤器)之一,他们对于Web开发起到了不可缺少的作用。 ps:想要补充Java知识的同学们可以移步我已经完结的JavaSE笔记,里面整理了大量详细的知识点和图解,可以帮你快速掌…

Java 并发编程面试二

目录 一、并发编程三要素? 二、实现可见性的方法有哪些? 三、多线程的价值? 四、创建线程的有哪些方式? 五、创建线程的三种方式的对比? 六、Java 线程具有五中基本状态 七、什么是线程池?有哪几种创建方式 八、四种线程池的创建 九、线程池的优点? 十、常用的…

精益管理|AIRIOT智慧变电站管理解决方案

随着社会电气化进程的加速&#xff0c;电力需求与日俱增&#xff0c;变电站作为电网的关键节点&#xff0c;其稳定性和智能化管理水平直接关系到整个电力系统的高效运作。传统变电站管理平台难以适应现代电力系统复杂管理需求&#xff0c;存在如下痛点&#xff1a; 数据收集与…

【机器学习】深入探索机器学习:利用机器学习探索股票价格预测的新路径

❀机器学习 &#x1f4d2;1. 引言&#x1f4d2;2. 多种机器学习算法的应用&#x1f4d2;3. 机器学习在股票价格预测中的应用现状&#x1f389;数据收集与预处理&#x1f389;模型构建与训练&#x1f308;模型评估与预测&#x1f31e;模型评估&#x1f319;模型预测⭐注意事项 &…

请问Java8进阶水平中,常用的设计模式有哪些?

设计模式通常被分为三大类&#xff1a;创建型&#xff08;Creational&#xff09;、结构型&#xff08;Structural&#xff09;和行为型&#xff08;Behavioral&#xff09;。以下是这20个设计模式的分类&#xff1a; 创建型&#xff08;Creational&#xff09;设计模式&#…

Linux Centos内网环境中安装mysql5.7详细安装过程

一、下载安装包 下载地址&#xff08;可下载历史版本&#xff09;&#xff1a; https://downloads.mysql.com/archives/community 二、解压到安装路径 tar -zxvf mysql-5.7.20-linux-glibc2.12-x86_64.tar.gz三、重命名 mv /usr/local/mysql-5.7.20-linux-glibc2.12-x86_64 …