Loading data in PyTorch: CPU init, CPU getitem, GPU init

Contents

    • 1. (cpu, init) Load all images into CPU memory in __init__, then take patches in __getitem__
    • 2. (cpu, getitem) Load each image into CPU memory inside __getitem__
    • 3. (gpu, init) Load the images onto the GPU(s) in __init__

I was running spectral reconstruction (multispectral estimation) code, based on: https://github.com/caiyuanhao1998/MST-plus-plus
The original dataset class loads all of the images into CPU memory at once.

1. (cpu, init) Load all images into CPU memory in __init__, then take patches in __getitem__

This is the most common approach, and reading image patches during training is fast, but the CPU must have enough memory to hold the whole dataset.
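To get a rough sense of what "enough CPU memory" means here — assuming the NTIRE 2022 layout of 900 training scenes with 31-band 482×512 cubes (the band count is not visible in the code below) — the footprint is roughly:

# Back-of-the-envelope RAM estimate under the assumed shapes above
hyper_bytes = 31 * 482 * 512 * 4      # one float32 spectral cube, ~29 MiB
bgr_bytes = 3 * 482 * 512 * 4         # one float32 RGB image, ~3 MiB
total_gib = 900 * (hyper_bytes + bgr_bytes) / 1024 ** 3
print(f'{total_gib:.1f} GiB')         # about 28 GiB of CPU RAM for the raw arrays alone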

from torch.utils.data import Dataset
import numpy as np
import random
import cv2
import h5py
import torch

class TrainDataset(Dataset):
    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8):
        self.crop_size = crop_size
        self.hypers = []
        self.bgrs = []
        self.arg = arg
        h, w = 482, 512  # img shape
        self.stride = stride
        self.patch_per_line = (w - crop_size) // stride + 1
        self.patch_per_colum = (h - crop_size) // stride + 1
        self.patch_per_img = self.patch_per_line * self.patch_per_colum
        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'
        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            hyper_list = [line.replace('\n', '.mat') for line in fin]
            bgr_list = [line.replace('mat', 'jpg') for line in hyper_list]
        hyper_list.sort()
        bgr_list.sort()
        # hyper_list = hyper_list[:300]
        # bgr_list = bgr_list[:300]
        print(f'len(hyper) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr) of ntire2022 dataset:{len(bgr_list)}')
        for i in range(len(hyper_list)):
            hyper_path = hyper_data_path + hyper_list[i]
            if 'mat' not in hyper_path:
                continue
            with h5py.File(hyper_path, 'r') as mat:
                hyper = np.float32(np.array(mat['cube']))
            hyper = np.transpose(hyper, [0, 2, 1])
            bgr_path = bgr_data_path + bgr_list[i]
            assert hyper_list[i].split('.')[0] == bgr_list[i].split('.')[0], 'Hyper and RGB come from different scenes.'
            bgr = cv2.imread(bgr_path)
            if bgr2rgb:
                bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            bgr = np.float32(bgr)
            bgr = (bgr - bgr.min()) / (bgr.max() - bgr.min())
            bgr = np.transpose(bgr, [2, 0, 1])  # [3,482,512]
            self.hypers.append(hyper)
            self.bgrs.append(bgr)
            print(f'Ntire2022 scene {i} is loaded.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    def arguement(self, img, rotTimes, vFlip, hFlip):
        # Random rotation
        for j in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        # Random vertical flip
        for j in range(vFlip):
            img = img[:, :, ::-1].copy()
        # Random horizontal flip
        for j in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):
        stride = self.stride
        crop_size = self.crop_size
        img_idx, patch_idx = idx // self.patch_per_img, idx % self.patch_per_img
        h_idx, w_idx = patch_idx // self.patch_per_line, patch_idx % self.patch_per_line
        bgr = self.bgrs[img_idx]
        hyper = self.hypers[img_idx]
        bgr = bgr[:, h_idx * stride:h_idx * stride + crop_size, w_idx * stride:w_idx * stride + crop_size]
        hyper = hyper[:, h_idx * stride:h_idx * stride + crop_size, w_idx * stride:w_idx * stride + crop_size]
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr = self.arguement(bgr, rotTimes, vFlip, hFlip)
            hyper = self.arguement(hyper, rotTimes, vFlip, hFlip)
        return np.ascontiguousarray(bgr), np.ascontiguousarray(hyper)

    def __len__(self):
        return self.patch_per_img * self.img_num
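For reference, here is a minimal sketch of how this dataset is typically consumed with the standard DataLoader (the data_root, batch size and worker count are illustrative values, not taken from the original training script):

from torch.utils.data import DataLoader

train_data = TrainDataset(data_root='./dataset', crop_size=128, arg=True, bgr2rgb=True, stride=8)
train_loader = DataLoader(train_data, batch_size=20, shuffle=True,
                          num_workers=2, pin_memory=True, drop_last=True)

for bgr, hyper in train_loader:   # bgr: [B,3,128,128], hyper: [B,C,128,128]
    bgr, hyper = bgr.cuda(), hyper.cuda()
    # forward / backward pass goes here
    break

The default collate function converts the numpy arrays returned by __getitem__ into batched CPU tensors automatically.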

2. (cpu, getitem) Load each image into CPU memory inside __getitem__

This approach can handle large datasets — use it when all the images together would exceed the machine's RAM.
However, because image reading now happens inside __getitem__, data loading during training is noticeably slower.

class TrainDataset_single(Dataset):
    def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8):
        self.crop_size = crop_size
        self.hypers = []
        self.bgrs = []
        self.arg = arg
        self.bgr2rgb = bgr2rgb
        h, w = 482, 512  # img shape
        self.stride = stride
        self.patch_per_line = (w - crop_size) // stride + 1
        self.patch_per_colum = (h - crop_size) // stride + 1
        self.patch_per_img = self.patch_per_line * self.patch_per_colum
        hyper_data_path = f'{data_root}/Train_spectral/'
        bgr_data_path = f'{data_root}/Train_RGB/'
        with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:
            hyper_list = [line.replace('\n', '.mat') for line in fin]
            bgr_list = [line.replace('mat', 'jpg') for line in hyper_list]
        hyper_list.sort()
        bgr_list.sort()
        print(f'len(hyper) of ntire2022 dataset:{len(hyper_list)}')
        print(f'len(bgr) of ntire2022 dataset:{len(bgr_list)}')
        for i in range(len(hyper_list)):
            hyper_path = hyper_data_path + hyper_list[i]
            bgr_path = bgr_data_path + bgr_list[i]
            # Only the file paths are stored here; reading and decoding the
            # images is deferred to __getitem__ below.
            self.hypers.append(hyper_path)
            self.bgrs.append(bgr_path)
            print(f'Ntire2022 scene {i} is registered.')
        self.img_num = len(self.hypers)
        self.length = self.patch_per_img * self.img_num

    def arguement(self, img, rotTimes, vFlip, hFlip):
        # Random rotation
        for j in range(rotTimes):
            img = np.rot90(img.copy(), axes=(1, 2))
        # Random vertical flip
        for j in range(vFlip):
            img = img[:, :, ::-1].copy()
        # Random horizontal flip
        for j in range(hFlip):
            img = img[:, ::-1, :].copy()
        return img

    def __getitem__(self, idx):
        stride = self.stride
        crop_size = self.crop_size
        img_idx, patch_idx = idx // self.patch_per_img, idx % self.patch_per_img
        h_idx, w_idx = patch_idx // self.patch_per_line, patch_idx % self.patch_per_line
        bgr_path = self.bgrs[img_idx]
        hyper_path = self.hypers[img_idx]
        with h5py.File(hyper_path, 'r') as mat:
            hyper = np.float32(np.array(mat['cube']))
        hyper = np.transpose(hyper, [0, 2, 1])
        bgr = cv2.imread(bgr_path)
        if self.bgr2rgb:
            bgr = bgr[..., ::-1]  # same effect as cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        bgr = np.float32(bgr)
        bgr = (bgr - bgr.min()) / (bgr.max() - bgr.min())
        bgr = np.transpose(bgr, [2, 0, 1])  # [3,482,512]
        bgr = bgr[:, h_idx * stride:h_idx * stride + crop_size, w_idx * stride:w_idx * stride + crop_size]
        hyper = hyper[:, h_idx * stride:h_idx * stride + crop_size, w_idx * stride:w_idx * stride + crop_size]
        rotTimes = random.randint(0, 3)
        vFlip = random.randint(0, 1)
        hFlip = random.randint(0, 1)
        if self.arg:
            bgr = self.arguement(bgr, rotTimes, vFlip, hFlip)
            hyper = self.arguement(hyper, rotTimes, vFlip, hFlip)
        return np.ascontiguousarray(bgr), np.ascontiguousarray(hyper)

    def __len__(self):
        return self.patch_per_img * self.img_num
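To soften the per-item disk cost of this approach, the usual trick is to let the DataLoader prefetch with several worker processes that read and decode images in parallel. A minimal sketch (worker count, batch size and data_root are illustrative):

from torch.utils.data import DataLoader

train_data = TrainDataset_single(data_root='./dataset', crop_size=128, arg=True, bgr2rgb=True, stride=8)
# Several workers read and decode images concurrently, hiding part of the I/O latency.
train_loader = DataLoader(train_data, batch_size=20, shuffle=True,
                          num_workers=8, pin_memory=True, persistent_workers=True)

Because the h5py file is opened and closed inside __getitem__, every worker reads independently, so this parallel pattern is safe; it hides part of the I/O latency but does not remove it.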

3. (gpu, init) Load the images onto the GPU(s) in __init__

Use this when CPU memory is too small for method 1, and we do not want the slow loading of method 2.
When GPU memory is large enough, or several GPUs are available, the images can be loaded onto one or more GPUs inside the init function.
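Reusing the rough estimate from section 1 (about 28 GiB for 900 scenes under the assumed shapes), splitting the data evenly across two GPUs means roughly 14 GiB per card is taken by the dataset alone, before the model, activations, and optimizer states are counted.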

For example, in the code below the first 450 scenes are loaded onto gpu0 and the next 450 onto gpu1.
TrainDataset_gpu[i] then returns data that already sits on the GPU.

"""
数据在不同的gpu上,不能使用dataloader
"""    
class TrainDataset_gpu(Dataset):def __init__(self, data_root, crop_size, arg=True, bgr2rgb=True, stride=8):self.crop_size = crop_sizeself.hypers = []self.bgrs = []self.arg = argself.bgr2rgb = bgr2rgbh,w = 482,512  # img shapeself.stride = strideself.patch_per_line = (w-crop_size)//stride+1self.patch_per_colum = (h-crop_size)//stride+1self.patch_per_img = self.patch_per_line*self.patch_per_columhyper_data_path = f'{data_root}/Train_spectral/'bgr_data_path = f'{data_root}/Train_RGB/'with open(f'{data_root}/split_txt/train_list.txt', 'r') as fin:hyper_list = [line.replace('\n','.mat') for line in fin]bgr_list = [line.replace('mat','jpg') for line in hyper_list]hyper_list.sort()bgr_list.sort()# hyper_list = hyper_list[:100]# bgr_list = bgr_list[:100]print(f'len(hyper) of ntire2022 dataset:{len(hyper_list)}')print(f'len(bgr) of ntire2022 dataset:{len(bgr_list)}')for i in range(len(hyper_list)):hyper_path = hyper_data_path + hyper_list[i]bgr_path = bgr_data_path + bgr_list[i]if 'mat' not in hyper_path:continuewith h5py.File(hyper_path, 'r') as mat:hyper =np.float32(np.array(mat['cube']))hyper = np.transpose(hyper, [0, 2, 1])assert hyper_list[i].split('.')[0] ==bgr_list[i].split('.')[0], 'Hyper and RGB come from different scenes.'bgr = cv2.imread(bgr_path)if bgr2rgb:bgr = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)bgr = np.float32(bgr)bgr = (bgr-bgr.min())/(bgr.max()-bgr.min())bgr = np.transpose(bgr, [2, 0, 1])  # [3,482,512]if i < 450:device = torch.device('cuda:0')self.hypers.append(torch.from_numpy(hyper).to(device))self.bgrs.append(torch.from_numpy(bgr).to(device))elif i<900:device = torch.device('cuda:1')self.hypers.append(torch.from_numpy(hyper).to(device))self.bgrs.append(torch.from_numpy(bgr).to(device))# mat.close()print(f'Ntire2022 scene {i} is loaded.')self.img_num = len(self.hypers)self.length = self.patch_per_img * self.img_numdef arguement(self, img, hyper, rotTimes, vFlip, hFlip):# Random rotationif rotTimes:img = torch.rot90(img, rotTimes, [1, 2])hyper = torch.rot90(hyper, rotTimes, [1, 2])# Random vertical Flipif vFlip:#img = img[:, :, ::-1]img = torch.flip(img, dims=[1])hyper = torch.flip(hyper, dims=[1])# Random horizontal Flipif hFlip:#img = img[:, ::-1, :]img = torch.flip(img, dims=[2])hyper = torch.flip(hyper, dims=[2])return img, hyperdef __getitem__(self, idx):stride = self.stridecrop_size = self.crop_sizeimg_idx, patch_idx = idx//self.patch_per_img, idx%self.patch_per_imgh_idx, w_idx = patch_idx//self.patch_per_line, patch_idx%self.patch_per_linebgr = self.bgrs[img_idx]hyper = self.hypers[img_idx]bgr = bgr[:,h_idx*stride:h_idx*stride+crop_size, w_idx*stride:w_idx*stride+crop_size]hyper = hyper[:, h_idx * stride:h_idx * stride + crop_size,w_idx * stride:w_idx * stride + crop_size]rotTimes = random.randint(0, 3)vFlip = random.randint(0, 1)hFlip = random.randint(0, 1)if self.arg:bgr, hyper = self.arguement(bgr, hyper, rotTimes, vFlip, hFlip)return bgr, hyper # np.ascontiguousarray(bgr.cpu().numpy()), np.ascontiguousarray(hyper.cpu().numpy()) def __len__(self):return self.patch_per_img*self.img_num

However, once the data already lives on the GPU, the standard DataLoader apparently cannot be used during training — it tends to raise errors (the default worker processes and pin_memory machinery expect CPU tensors).

In that case, write your own batching function and shuffling, as in the snippet below.

# 1. Build the dataset (`opt` comes from the surrounding training script's options).
train_data = TrainDataset_gpu(data_root=opt.data_root, crop_size=opt.patch_size, bgr2rgb=True, arg=True, stride=opt.stride)

# 2. Shuffle the indices and trim them so the total is a multiple of batch_size.
inddd = np.arange(len(train_data))
l = len(inddd) - (len(inddd) % opt.batch_size)
inddd2 = np.random.permutation(inddd)[:l]
inddd2 = inddd2.reshape(-1, opt.batch_size)  # [batch num, batch size]
print(len(train_data), len(inddd) % opt.batch_size, inddd2.shape)

# 3. Fetch the samples of each batch.
for i in range(inddd2.shape[0]):
    t0 = time.time()
    # Gather batch_size samples and concatenate them into one batch.
    inddd3 = inddd2[i]
    # print('i, len, curlist:', i, len(inddd2), inddd3)
    images = []
    labels = []
    for j in inddd3:
        image, label = train_data[j]
        image = image[None, ...]
        label = label[None, ...]
        # print(i, j, image.shape, label.shape)
        # cv2.imwrite(f'{i:9d}_{j:4d}_image.png', (image[0].cpu().numpy().transpose(1,2,0)[...,[2,1,0]]*255).astype(np.uint8))
        # cv2.imwrite(f'{i:9d}_{j:4d}_label.png', (label[0].cpu().numpy().transpose(1,2,0)[...,[5,15,25]]*255).astype(np.uint8))
        images.append(image.cpu())
        labels.append(label.cpu())
    images = torch.cat(images, 0)
    labels = torch.cat(labels, 0)
    # print(images.shape, labels.shape)
    labels = labels.cuda()
    images = images.cuda()
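If you would rather keep the DataLoader interface, one possible alternative (a sketch under stated assumptions, not the original author's code) is to set num_workers=0 so that no worker subprocess ever touches the CUDA tensors, and supply a custom collate_fn that gathers every sample onto one target device before stacking:

def collate_to_cuda0(batch, device=torch.device('cuda:0')):
    # Each sample may live on cuda:0 or cuda:1; move them to one device, then stack.
    bgrs = torch.stack([b.to(device, non_blocking=True) for b, _ in batch], dim=0)
    hypers = torch.stack([h.to(device, non_blocking=True) for _, h in batch], dim=0)
    return bgrs, hypers

train_loader = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size, shuffle=True,
                                           num_workers=0, collate_fn=collate_to_cuda0)

This keeps the familiar shuffling and epoch machinery while avoiding the multiprocessing-plus-CUDA pitfalls described above; whether it beats the manual loop depends on how much of each batch has to be copied between the two GPUs.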
