Pytorch指定数据加载器使用子进程

torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)

num_workers 参数是 DataLoader 类的一个参数,它指定了数据加载器使用的子进程数量。通过增加 num_workers 的数量,可以并行地读取和预处理数据,从而提高数据加载的速度。

通常情况下,增加 num_workers 的数量可以提高数据加载的效率,因为它可以使数据加载和预处理工作在多个进程中同时进行。然而,当 num_workers 的数量超过一定阈值时,增加更多的进程可能不会再带来更多的性能提升,甚至可能会导致性能下降。

这是因为增加 num_workers 的数量也会增加进程间通信的开销。当 num_workers 的数量过多时,进程间通信的开销可能会超过并行化所带来的收益,从而导致性能下降。

此外,还需要考虑到计算机硬件的限制。如果你的计算机 CPU 核心数量有限,增加 num_workers 的数量也可能会导致性能下降,因为每个进程需要占用 CPU 核心资源。

因此,对于 num_workers 参数的设置,需要根据具体情况进行调整和优化。通常情况下,一个合理的 num_workers 值应该在 2 到 8 之间,具体取决于你的计算机硬件配置和数据集大小等因素。在实际应用中,可以通过尝试不同的 num_workers 值来找到最优的配置。

综上所述,当 num_workers 的值从 4 增加到 8 时,如果你的计算机硬件配置和数据集大小等因素没有发生变化,那么两者之间的性能差异可能会很小,或者甚至没有显著差异。

测试代码如下

import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import timeif __name__ == '__main__':mp.freeze_support()train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available. Training on CPU...')else:print('CUDA is available! Training on GPU...')device = torch.device("cuda" if torch.cuda.is_available() else "cpu")batch_size = 4# 设置数据预处理的转换transform = torchvision.transforms.Compose([torchvision.transforms.Resize((512,512)),  # 调整图像大小为 224x224torchvision.transforms.ToTensor(),  # 转换为张量torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化])dataset = torchvision.datasets.ImageFolder('C:\\Users\\ASUS\\PycharmProjects\\pythonProject1\\cats_and_dogs_train',transform=transform)val_ratio = 0.2val_size = int(len(dataset) * val_ratio)train_size = len(dataset) - val_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)val_dataset = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)model = models.resnet18()num_classes = 2for param in model.parameters():param.requires_grad = Falsemodel.fc = nn.Sequential(nn.Dropout(),nn.Linear(model.fc.in_features, num_classes),nn.LogSoftmax(dim=1))optimizer = optim.Adam(model.parameters(), lr=0.001)criterion = nn.CrossEntropyLoss().to(device)model.to(device)filename = "recognize_cats_and_dogs.pt"def save_checkpoint(epoch, model, optimizer, filename):checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}torch.save(checkpoint, filename)num_epochs = 3train_loss = []for epoch in range(num_epochs):running_loss = 0correct = 0total = 0epoch_start_time = time.time()for i, (inputs, labels) in enumerate(train_dataset):# 将数据放到设备上inputs, labels = inputs.to(device), labels.to(device)# 前向计算outputs = model(inputs)# 计算损失和梯度loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()# 更新模型参数optimizer.step()# 记录损失和准确率running_loss += loss.item()train_loss.append(loss.item())_, predicted = torch.max(outputs.data, 1)correct += (predicted == labels).sum().item()total += labels.size(0)accuracy_train = 100 * correct / total# 在测试集上计算准确率with torch.no_grad():running_loss_test = 0correct_test = 0total_test = 0for inputs, labels in val_dataset:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss_test += loss.item()_, predicted = torch.max(outputs.data, 1)correct_test += (predicted == labels).sum().item()total_test += labels.size(0)accuracy_test = 100 * correct_test / total_test# 输出每个 epoch 的损失和准确率epoch_end_time = time.time()epoch_time = epoch_end_time - epoch_start_timeprint("Epoch [{}/{}], Time: {:.4f}s, Loss: {:.4f}, Train Accuracy: {:.2f}%, Loss: {:.4f}, Test Accuracy: {:.2f}%".format(epoch + 1, num_epochs,epoch_time,running_loss / len(val_dataset),accuracy_train, running_loss_test / len(val_dataset), accuracy_test))save_checkpoint(epoch, model, optimizer, filename)plt.plot(train_loss, label='Train Loss')# 添加图例和标签plt.legend()plt.xlabel('Epochs')plt.ylabel('Loss')plt.title('Training Loss')# 显示图形plt.show()

不同num_workers的结果如下

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/118811.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何将音频与视频分离

您一定经历过这样的情况:当你非常喜欢视频中的背景音乐时,希望将音频从视频中分离出来,以便你可以在音乐播放器中收听音乐。有没有一种有效的方法可以帮助您快速从视频中提取音频呢?当然是有的啦,在下面的文章中&#…

根据输入类型来选择函数不同的实现方法functools.singledispatch

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 根据输入类型来选择函数不同的实现方法 functools.singledispatch 输入6后,下列输出正确的是? from functools import singledispatch singledispatch def calcu…

树莓派系统文件解析

title: “树莓派系统文件分析” date: 2023-10-25 permalink: /posts/2023/10/blog-post-5/ tags: 树莓派 本篇blog来分析和总结下树莓派系统文件以及他们的作用。使用的系统是Raspberry Pi OS with desktop System: 64-bitKernel version: 6.1Debian version: 12 (bookworm)…

经典链表试题(二)

文章目录 一、移除链表元素1、题目介绍2、思路讲解3、代码实现 二、反转链表1、题目介绍2、思路讲解3、代码实现 三、相交链表1、题目介绍2、思路讲解3、代码实现 四、链表的中间结点1、题目介绍2、思路讲解3、代码实现 五、设计循环队列1、题目介绍2、思路讲解3、代码实现 六、…

2023高频前端面试题-http

1. HTTP有哪些⽅法? HTTP 1.0 标准中,定义了3种请求⽅法:GET、POST、HEAD HTTP 1.1 标准中,新增了请求⽅法:PUT、PATCH、DELETE、OPTIONS、TRACE、CONNECT 2. 各个HTTP方法的具体作用是什么? 方法功能G…

『C语言进阶』动态内存管理

🔥博客主页: 小羊失眠啦. 🔖系列专栏: C语言、Linux、Cpolar ❤️感谢大家点赞👍收藏⭐评论✍️ 文章目录 前言一、动态内存函数的介绍1.1 malloc和free函数1.2 calloc函数1.3 realloc函数 二、常见的动态内存错误2.1 …

行业模型应该如何去拆解?

行业模型应该如何去拆解? 拆解行业模型是一个复杂的过程,涉及对整个行业的深入分析和理解。下面是一些步骤和方法,可以帮助你系统地拆解行业模型: 1. 确定行业范围 定义行业:明确你要分析的行业是什么,包括…

React中的Virtual DOM(看这一篇就够了)

文章目录 前言了解Virtual DOMreact创建虚拟dom的方式React Element虚拟dom的流程虚拟dom和真实dom的对比后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:react合集 🐱‍👓博主在前端领域还有很多知识和技术需要掌…

在pycharm中创建python模板文件

File——>Setting——>File and Code Templates——>Python Scripts 在文本框中输入模板内容

vue首页多模块布局(标题布局)

<template><div class"box"><div class"content"><div class"box1" style"background-color: rgb(245,23,156)">第一个</div><div class"box2" style"background-color: rgb(12,233,…

windows下使用FFmpeg开源库进行视频编解码完整步聚

最终解码效果: 1.UI设计 2.在控件属性窗口中输入默认值 3.复制已编译FFmpeg库到工程同级目录下 4.在工程引用FFmpeg库及头文件 5.链接指定FFmpeg库 6.使用FFmpeg库 引用头文件 extern "C" { #include "libswscale/swscale.h" #include "libavdevic…

composer安装thinkphp6报错

composer安装thinkphp6报错&#xff0c; 查看是否安装了对应的PHP扩展&#xff0c;我这边使用的是宝塔的环境&#xff0c;全程可以可视化操作 这样就可以安装完成了

【AIGC】百度文库文档助手之 - 一键生成PPT

百度文库文档助手之 - 一键生成PPT 引言一、文档助手&#xff1a;体验一键生成PPT二、文档助手&#xff1a;进阶用法三、其它生成PPT的方法3.1 ChatGPT3.2 文心一言 引言 就在上个月百度文库升级为一站式智能文档平台&#xff0c;开放四大AI能力&#xff1a;智能PPT、智能总结、…

【Ansible自动化运维工具 第一部分】Ansible常用模块详解(附各模块应用实例和Ansible环境安装部署)

Ansible常用模块 一、Ansible1.1 简介1.2 工作原理1.3 Ansible的特性1.3.1 特性一&#xff1a;Agentless&#xff0c;即无Agent的存在1.3.2 特性二&#xff1a;幂等性 1.4 Ansible的基本组件 二、Ansible环境安装部署2.1 安装ansible2.2 查看基本信息2.3 配置远程主机清单 三、…

计算机中了mallox勒索病毒怎么办,勒索病毒解密,数据恢复

最近一段时间&#xff0c;云天数据恢复中心陆续收到很多企业的求助&#xff0c;企业的计算机服务器遭到了mallox勒索病毒攻击&#xff0c;导致企业的数据库无法正常使用&#xff0c;严重影响了企业的正常生产生活&#xff0c;为此&#xff0c;云天数据恢复中心的工程师通过对此…

【Java笔记+踩坑】设计模式——原型模式

导航&#xff1a; 【Java笔记踩坑汇总】Java基础JavaWebSSMSpringBootSpringCloud瑞吉外卖/黑马旅游/谷粒商城/学成在线设计模式面试题汇总性能调优/架构设计源码-CSDN博客​ 目录 零、经典的克隆羊问题&#xff08;复制10只属性相同的羊&#xff09; 一、传统方案&#xff1…

酒类商城小程序怎么做

随着互联网的快速发展&#xff0c;线上购物越来越普及。酒类商品也慢慢转向线上销售&#xff0c;如何搭建一个属于自己的酒类小程序商城呢&#xff1f;下面就让我们一起来看看吧&#xff01; 一、登录乔拓云平台 首先&#xff0c;我们需要进入乔拓云平台的后台&#xff0c;点击…

使用boost.mysql来操作mysql 数据库

准备条件 1. visual studio 2019 2. boost库 3. 安装本地的mysql 服务器&#xff0c;boost.mysql对mysql有版本要求最好8.0&#xff0c;具体参考官方文档 安装 使用Nuget安装boost 要安装 openssl&#xff0c;否则的话编译其他项目会产生依赖ssl的错误 安装mysql 省略 …

Java操作Excel

一、Java操作Excel 二、Excel根据单元格状态自动变更行背景颜色 用excel记录和跟进工作的时候&#xff0c;设置背景颜色&#xff0c;以此让记录更加突出&#xff0c;方便查看&#xff0c;但是手动修改背景颜色一来麻烦容易漏&#xff0c;也可能颜色不统一&#xff0c;我们可以试…

PHP 数据库交互优化,根据传参查询

接上文 修改以下内容 将查询的 uid 改为 username&#xff0c;同时在 user 和 message 两张表中查询 $sql "select m.id,u.username,m.title,m.content from user u,message m where u.idm.uid;"根据 message 中的 id 查询&#xff0c;形式为 http://127.0.0.1/m…