使用PyTorch高效读取二进制数据集进行训练

使用pickle制作类cifar10二进制格式的数据集

使用pytorch框架来训练(以猫狗大战数据集为例)

此方法是为了实现阿里云PAI studio上可视化训练模型时使用的数据格式。

一、制作类cifar10二进制格式数据

"""Build a CIFAR-10-style binary dataset from a labelled image list.

Reads "<image_name> <label>" lines from `file_list`, loads each image from
`data_path` resized to 128x128, and pickles everything into `save_path`.
"""
# Explicit imports instead of the original wildcard `from ... import *`.
# NOTE(review): the original also imported `os` and `cv2` here, but this
# script never uses them directly (the helper modules import their own).
from load_data import read_data
from pickled import pickled

data_path = './data_n/test'      # folder holding the raw images
file_list = './data_n/test.txt'  # text file: "<image_name> <label>" per line
save_path = './bin'              # output folder for the pickled batch files

if __name__ == '__main__':
    # Resize every image to 128x128 and serialise the whole set as one bin.
    data, label, lst = read_data(file_list, data_path, shape=128)
    pickled(save_path, data, label, lst, bin_num=1)

read_data模块

import cv2
import os
import numpy as np

# Layout constants for the default 128x128 RGB images stored CIFAR-style
# (all R bytes, then all G, then all B, row-major).
DATA_LEN = 49152     # 3 * 128 * 128 bytes per image
CHANNEL_LEN = 16384  # 128 * 128 bytes per colour channel
SHAPE = 128          # default square image side


def imread(im_path, shape=None, color="RGB", mode=cv2.IMREAD_UNCHANGED):
    """Read one image, optionally convert BGR->RGB and resize to (shape, shape).

    im_path (str): path of the image file
    shape (int|None): target square side; None keeps the original size
    color (str): "RGB" converts from OpenCV's native BGR ordering
    mode: accepted for API compatibility; cv2.IMREAD_UNCHANGED is always
        used, exactly as in the original implementation
    """
    im = cv2.imread(im_path, cv2.IMREAD_UNCHANGED)
    if color == "RGB":
        # OpenCV loads BGR; the dataset is stored as RGB.
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    if shape is not None:
        # Explicit raise instead of `assert` (asserts vanish under -O).
        if not isinstance(shape, int):
            raise TypeError("shape must be an int, got %r" % (shape,))
        im = cv2.resize(im, (shape, shape))
    return im


def read_data(filename, data_path, shape=None, color='RGB'):
    """Read all images listed in `filename` into one planar uint8 array.

    filename (str): list file, one "image_name label" pair per line
    data_path (str): folder containing the image files
    shape (int|None): square image side; defaults to SHAPE (128)
    color (str): passed through to imread (the original ignored it)
    return: (data, label, lst) -- n x (3*shape*shape) uint8 array,
        list of int labels, list of file names
    raises FileNotFoundError: when `filename` does not exist.  The original
        printed a message and fell through returning None, which made the
        caller crash on tuple unpacking; failing loudly is clearer.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError("Can't find data file: %s" % filename)
    with open(filename) as f:  # `with` guarantees the handle is closed
        lines = f.read().splitlines()
    count = len(lines)
    s = SHAPE if shape is None else shape
    c = s * s  # bytes per colour channel (CHANNEL_LEN when s == 128)
    data = np.zeros((count, 3 * c), dtype=np.uint8)
    lst = [ln.split(' ')[0] for ln in lines]
    label = [int(ln.split(' ')[1]) for ln in lines]
    for idx, ln in enumerate(lines):
        fname, lab = ln.split(' ')
        im = imread(os.path.join(data_path, fname), shape=s, color=color)
        # Planar layout: channel 0 bytes, then channel 1, then channel 2.
        data[idx, :c] = np.reshape(im[:, :, 0], c)
        data[idx, c:2 * c] = np.reshape(im[:, :, 1], c)
        data[idx, 2 * c:] = np.reshape(im[:, :, 2], c)
        label[idx] = int(lab)
    return data, label, lst

pickled模块

import os
import pickle

BIN_COUNTS = 5  # default number of batch files


def pickled(savepath, data, label, fnames, bin_num=BIN_COUNTS, mode="train"):
    """Split (data, label, fnames) into `bin_num` CIFAR-style pickle files.

    savepath (str): existing output directory
    data (ndarray): n x 3072-style uint8 image array (one row per image)
    label (list): n integer labels
    fnames (list[str]): n image names
    bin_num (int): how many batch files to write
    mode (str): {'train', 'test'} -- only affects the batch_label text

    Files are written as `batch_tt0`, `batch_tt1`, ... inside `savepath`.
    """
    # Explicit raises instead of the original `assert` statements, which
    # disappear when Python runs with -O.
    if not os.path.isdir(savepath):
        raise NotADirectoryError(savepath)
    total_num = len(fnames)
    # Float on purpose: spreads a non-divisible remainder across bins.
    samples_per_bin = total_num / bin_num
    if samples_per_bin <= 0:
        raise ValueError("more bins (%d) than samples (%d)" % (bin_num, total_num))
    for i in range(bin_num):
        start = int(i * samples_per_bin)
        # min() replaces the original duplicated if/else slice branches.
        end = min(int((i + 1) * samples_per_bin), total_num)
        batch = {  # renamed from `dict`, which shadowed the builtin
            'data': data[start:end, :],
            'labels': label[start:end],
            'filenames': fnames[start:end],
        }
        tag = "training" if mode == "train" else "testing"
        batch['batch_label'] = "{} batch {} of {}".format(tag, i, bin_num)
        with open(os.path.join(savepath, 'batch_tt' + str(i)), 'wb') as fi:
            pickle.dump(batch, fi)


def unpickled(filename):
    """Load and return one batch dictionary written by `pickled`."""
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)
    with open(filename, 'rb') as fo:
        return pickle.load(fo)

测试生成的二进制数据


import os
import pickle
import numpy as np
import cv2

CHANNEL_LEN = 16384  # 128 * 128 bytes per colour plane


def load_batch(fpath):
    """Load one pickled batch file and return (data, labels)."""
    with open(fpath, 'rb') as f:
        d = pickle.load(f)
    return d["data"], d["labels"]


def load_data(dirname, one_hot=False):
    """Sanity-check a generated batch: reshape it and display one image.

    dirname (str): folder containing the `batch_tt0` file
    one_hot: accepted but unused; kept for interface compatibility.

    NOTE(review): the original also contained a `for i in range(0)` loop
    feeding unused X_train/Y_train lists; it could never execute and has
    been removed as dead code.
    """
    ftpath = os.path.join(dirname, 'batch_tt0')
    X_test, Y_test = load_batch(ftpath)
    # Planar (R|G|B) rows -> interleaved HWC images.
    c = CHANNEL_LEN
    X_test = np.dstack((X_test[:, :c], X_test[:, c:2 * c], X_test[:, 2 * c:]))
    X_test = np.reshape(X_test, [-1, 128, 128, 3])
    print(X_test.shape)
    xx_test = np.transpose(X_test, (0, 3, 1, 2))  # NCHW view, for inspection
    print(xx_test.shape)
    img = X_test[2:4][1]  # i.e. image index 3, as in the original slicing
    print(img.shape)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # back to BGR for cv2.imshow
    cv2.imshow('img', img)
    cv2.waitKey(0)


if __name__ == '__main__':
    dirname = 'test'
    load_data(dirname)

二、使用制作好的数据训练

import torch
import os
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import pickle
import numpy as np

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_batch(fpath):
    """Load one pickled batch file and return (data, labels)."""
    with open(fpath, 'rb') as f:
        d = pickle.load(f)
    return d["data"], d["labels"]


def load_data(dirname, one_hot=False, train=False):
    """Load train or test images from pickled batch files in `dirname`.

    Returns (images, labels) with images reshaped to [n, 128, 128, 3]
    uint8 HWC -- transforms.ToTensor() later converts to CHW, so no
    transpose is done here.
    one_hot: accepted but unused; kept for interface compatibility.

    NOTE(review): this reads `data_batch_<i>` / `test_batch_0`, while the
    pickling code in this article writes `batch_tt<i>` -- confirm the files
    are renamed accordingly before training.
    """
    print(dirname)
    if train:
        X_train = []
        Y_train = []
        for i in range(1):  # single training batch file
            fpath = os.path.join(dirname, 'data_batch_' + str(i))
            print(fpath)
            data, labels = load_batch(fpath)
            if i == 0:
                X_train = data
                Y_train = labels
            else:
                X_train = np.concatenate([X_train, data], axis=0)
                Y_train = np.concatenate([Y_train, labels], axis=0)
        # Planar (R|G|B) rows -> interleaved HWC images.
        X_train = np.dstack((X_train[:, :16384], X_train[:, 16384:32768],
                             X_train[:, 32768:]))
        X_train = np.reshape(X_train, [-1, 128, 128, 3])
        return X_train, Y_train
    else:
        ftpath = os.path.join(dirname, 'test_batch_0')
        print(ftpath)
        X_test, Y_test = load_batch(ftpath)
        X_test = np.dstack((X_test[:, :16384], X_test[:, 16384:32768],
                            X_test[:, 32768:]))
        X_test = np.reshape(X_test, [-1, 128, 128, 3])
        # Data stays [n, h, w, c]; ToTensor handles the CHW conversion.
        return X_test, Y_test


class MyDataset(torch.utils.data.Dataset):
    """Dataset over the pickled binary batches (see load_data)."""

    def __init__(self, namedir, transform=None, train=False):
        super().__init__()
        self.namedir = namedir
        self.transform = transform
        self.train = train
        # All images are loaded into memory up front.
        self.datas, self.labels = load_data(self.namedir, train=self.train)

    def __getitem__(self, index):
        imgs = self.datas[index]
        img_labes = int(self.labels[index])
        if self.transform is not None:
            imgs = self.transform(imgs)
        return imgs, img_labes

    def __len__(self):
        return len(self.labels)


class MyDataset_s(torch.utils.data.Dataset):
    """Dataset over a list file of "image_path label" lines (loads via PIL)."""

    def __init__(self, datatxt, transform=None):
        super().__init__()
        imgs = []
        # `with` fixes the original's leaked file handle (open without close).
        with open(datatxt, 'r') as fh:
            for line in fh:
                words = line.rstrip().split()
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)


# Per-channel normalisation statistics used by the transform pipeline below.
mean = [0.5071, 0.4867, 0.4408]
stdv = [0.2675, 0.2565, 0.2761]

transform = transforms.Compose([
    # transforms.Resize([224, 224]),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=stdv)])

train_data = MyDataset(namedir='data\\train\\', transform=transform, train=True)
trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=4, shuffle=True)
test_data = MyDataset(namedir='data\\val\\', transform=transform, train=False)
# NOTE(review): shuffle=True on the test loader is harmless for accuracy
# counting but unnecessary; kept to preserve the original behaviour.
testloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=4, shuffle=True)

classes = ('cat', 'dog')


class Net(nn.Module):
    """Small 4-conv CNN for 3x128x128 input, 2 output classes."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 2)

    def forward(self, x):                       # (n, 3, 128, 128)
        x = self.pool(F.relu(self.conv1(x)))    # (n, 16, 64, 64)
        x = self.pool(F.relu(self.conv2(x)))    # (n, 32, 32, 32)
        x = self.pool(F.relu(self.conv3(x)))    # (n, 64, 16, 16)
        x = self.pool(F.relu(self.conv4(x)))    # (n, 32, 8, 8)
        # Flatten per sample; safer than view(-1, 32*8*8) if the spatial
        # size ever changes (would raise instead of silently reshaping).
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class VGG16(nn.Module):
    """VGG16-style network.

    NOTE(review): the layer-size comments assume 3x224x224 input and
    fc1 expects 512*7*7 features, but the pipeline above produces
    128x128 images -- this class would fail on those inputs; it appears
    to be kept for reference only. Confirm before using.
    """

    def __init__(self):
        super(VGG16, self).__init__()
        # 3 * 224 * 224
        self.conv1_1 = nn.Conv2d(3, 64, 3)  # 64 * 222 * 222
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1))  # 64 * 222 * 222
        self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1))  # pooling 64 * 112 * 112
        self.conv2_1 = nn.Conv2d(64, 128, 3)  # 128 * 110 * 110
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=(1, 1))  # 128 * 110 * 110
        self.maxpool2 = nn.MaxPool2d((2, 2), padding=(1, 1))  # pooling 128 * 56 * 56
        self.conv3_1 = nn.Conv2d(128, 256, 3)  # 256 * 54 * 54
        # Kernel size fixed from 2 to 3: with padding 1 a 2-kernel yields
        # 55x55, contradicting the documented 54x54 and every sibling layer.
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=(1, 1))  # 256 * 54 * 54
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=(1, 1))  # 256 * 54 * 54
        self.maxpool3 = nn.MaxPool2d((2, 2), padding=(1, 1))  # 256 * 28 * 28
        self.conv4_1 = nn.Conv2d(256, 512, 3)  # 512 * 26 * 26
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # 512 * 26 * 26
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # 512 * 26 * 26
        self.maxpool4 = nn.MaxPool2d((2, 2), padding=(1, 1))  # pooling 512 * 14 * 14
        self.conv5_1 = nn.Conv2d(512, 512, 3)  # 512 * 12 * 12
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # 512 * 12 * 12
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=(1, 1))  # 512 * 12 * 12
        self.maxpool5 = nn.MaxPool2d((2, 2), padding=(1, 1))  # pooling 512 * 7 * 7
        # view
        self.fc1 = nn.Linear(512 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 64)
        self.fc3 = nn.Linear(64, 2)

    def forward(self, x):
        # x.size(0) is the batch size
        in_size = x.size(0)
        out = self.conv1_1(x)  # 222
        out = F.relu(out)
        out = self.conv1_2(out)  # 222
        out = F.relu(out)
        out = self.maxpool1(out)  # 112
        out = self.conv2_1(out)  # 110
        out = F.relu(out)
        out = self.conv2_2(out)  # 110
        out = F.relu(out)
        out = self.maxpool2(out)  # 56
        out = self.conv3_1(out)  # 54
        out = F.relu(out)
        out = self.conv3_2(out)  # 54
        out = F.relu(out)
        out = self.conv3_3(out)  # 54
        out = F.relu(out)
        out = self.maxpool3(out)  # 28
        out = self.conv4_1(out)  # 26
        out = F.relu(out)
        out = self.conv4_2(out)  # 26
        out = F.relu(out)
        out = self.conv4_3(out)  # 26
        out = F.relu(out)
        out = self.maxpool4(out)  # 14
        out = self.conv5_1(out)  # 12
        out = F.relu(out)
        out = self.conv5_2(out)  # 12
        out = F.relu(out)
        out = self.conv5_3(out)  # 12
        out = F.relu(out)
        out = self.maxpool5(out)  # 7
        # flatten
        out = out.view(in_size, -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)
        # out = F.log_softmax(out, dim=1)
        return out


class VGG8(nn.Module):
    """Shallow VGG variant (one conv per stage).

    NOTE(review): like VGG16 above, the sizes assume 224x224 input;
    confirm before wiring it into the 128x128 pipeline.
    """

    def __init__(self):
        super(VGG8, self).__init__()
        # 3 * 224 * 224
        self.conv1_1 = nn.Conv2d(3, 64, 3)  # 64 * 222 * 222
        self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1))  # pooling 64 * 112 * 112
        self.conv2_1 = nn.Conv2d(64, 128, 3)  # 128 * 110 * 110
        self.maxpool2 = nn.MaxPool2d((2, 2), padding=(1, 1))  # pooling 128 * 56 * 56
        self.conv3_1 = nn.Conv2d(128, 256, 3)  # 256 * 54 * 54
        self.maxpool3 = nn.MaxPool2d((2, 2), padding=(1, 1))  # 256 * 28 * 28
        self.conv4_1 = nn.Conv2d(256, 512, 3)  # 512 * 26 * 26
        self.maxpool4 = nn.MaxPool2d((2, 2), padding=(1, 1))  # pooling 512 * 14 * 14
        self.conv5_1 = nn.Conv2d(512, 512, 3)  # 512 * 12 * 12
        self.maxpool5 = nn.MaxPool2d((2, 2), padding=(1, 1))  # pooling 512 * 7 * 7
        # view
        self.fc1 = nn.Linear(512 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, 64)
        self.fc3 = nn.Linear(64, 2)

    def forward(self, x):
        # x.size(0) is the batch size
        in_size = x.size(0)
        out = self.conv1_1(x)  # 222
        out = F.relu(out)
        out = self.maxpool1(out)  # 112
        out = self.conv2_1(out)  # 110
        out = F.relu(out)
        out = self.maxpool2(out)  # 56
        out = self.conv3_1(out)  # 54
        out = F.relu(out)
        out = self.maxpool3(out)  # 28
        out = self.conv4_1(out)  # 26
        out = F.relu(out)
        out = self.maxpool4(out)  # 14
        out = self.conv5_1(out)  # 12
        out = F.relu(out)
        out = self.maxpool5(out)  # 7
        # flatten
        out = out.view(in_size, -1)
        out = self.fc1(out)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)
        # out = F.log_softmax(out, dim=1)
        return out


net = Net()
# net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)


def _evaluate():
    """Count correct predictions over the test loader (no gradients)."""
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


if __name__ == '__main__':
    for epoch in range(11):
        running_loss = 0.0
        for step, (inputs, labels) in enumerate(trainloader):
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if step % 100 == 99:  # report the mean loss every 100 mini-batches
                print('[%d, %5d] loss: %.3f' % (epoch + 1, step + 1, running_loss / 100))
                running_loss = 0.0
        if epoch % 2 == 0:  # evaluate on every other epoch
            correct, total = _evaluate()
            print('Accuracy of the network on the 1000 test images: %d %%' % (100 * correct / total))
    print('finished !!!')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/38991.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

非常疑惑文章变成了仅VIP可读

关于博客发布的一些感想 挺久没上 CSDN 了&#xff0c;平时遇到问题都是问 ChatGPT&#xff0c;自行查阅资料的时间也不多了&#xff0c;写博文的频率也随之降低。偶尔会记些笔记自用&#xff0c;也没有再发布出来。 今天在谷歌查了个问题&#xff0c;突然想发个博客&#xf…

微信小程序渲染层与逻辑层交互原理

1. 网页开发与小程序开发有何不同&#xff1f; 2. 小程序运行环境 3. 页面渲染技术选型 1. 纯客户端技术&#xff1b; 2. 纯Web技术&#xff1b; 3. 用客户端原生技术与Web技术结合的混合技术&#xff08;Hybrid&#xff09;&#xff0c;小程序就是使用的这种技术&#xff1…

零基础学MySQL:从入门到实践的完整指南

引言&#xff1a; MySQL&#xff0c;作为全球最受欢迎的开源关系型数据库管理系统之一&#xff0c;以其高性能、易用性和灵活性&#xff0c;在Web开发、数据分析等领域占据着举足轻重的地位。如果你是一位编程新手&#xff0c;想要踏入数据库管理的大门&#xff0c;本文将从零…

MacBook关闭谷歌浏览器双指左右移动(扫动)前进后退功能

这个功能真的很反人类&#xff0c;正常上下滑动页面的时候很容易误操作&#xff0c;尤其是当你在一个页面上做了很多的编辑工作后误触发了此手势&#xff0c;那真叫一个崩溃&#xff01; 其实这应该是 Macbook 触控板提供的一个快捷操作&#xff0c;跟浏览器本身估计没关系&am…

基于大模型构建企业私有智能知识库落地的简单实践

​ 随着人工智能技术的飞速发展&#xff0c;大模型在企业知识管理中的应用日益广泛。下文是作者围绕如何基于大模型技术构建企业私有知识库&#xff0c;以提升企业的知识管理效率和创新能力的一些思考和简单实践。 ​ 本文对企业知识库的落地场景暂不作广泛的展开&#xff0c;…

Oracle PL / SQL数据类型

PL / SQL是SQL的过程语言扩展&#xff0c;它支持与SQL对数据库相同的数据类型。 PL / SQL可以处理任何数据库数据类型&#xff0c;并且还有自己的数据类型。 VARCHAR2&#xff1a;我们将使用数据类型VARCHAR2处理PL / SQL中的字符串。 PL / SQL VARCHAR2最多可容纳32,767个字…

成都百洲文化传媒有限公司网上开店怎么样?

在电商的浪潮中&#xff0c;每一个品牌都在寻找自己的定位与突破。而成都百洲文化传媒有限公司&#xff0c;正是这场浪潮中的一股强大助力&#xff0c;凭借其专业的电商服务&#xff0c;助力无数品牌实现了飞跃发展。 一、专业铸就品质&#xff0c;服务成就品牌 成都百洲文化传…

Runway:Gen-3 Alpha 文生视频现已开放

Runway 自 6 月 17 号公布 Gen-3 Alpha 快半个月了, 现在终于对所有人开放了，当然前提是你至少订阅了标准版（12 美刀/月), 传送门：runwayml.com

中原汉族与北方游牧民族舞蹈文化在这段剧中表现得淋漓尽致,且看!

中原汉族与北方游牧民族舞蹈文化在这段剧中表现得淋漓尽致&#xff0c;且看&#xff01; 《神探狄仁杰》之使团喋血记是一部深入人心的历史侦探剧&#xff0c;不仅以其曲折离奇的案情和狄仁杰的睿智形象吸引观众&#xff0c;更以其对唐代文化的精准再现而备受赞誉。#李秘书讲写…

引力波信号的连续小波变换(Python)

提到引力波&#xff0c;就要提到引力波天文学。引力波天文学是观测天文学的一个新兴分支&#xff0c;主要利用引力波&#xff08;微小时空扭曲&#xff09;观测发出引力辐射的天体系统&#xff0c;比如中子星和黑洞等波源、超新星等事件以及大爆炸后不久的早期宇宙演化过程。 …

Java代码基础算法练习-计算平均身高-2024.07.02

任务描述&#xff1a; n个同学站成一排&#xff0c;求它们的平均身高 解决思路&#xff1a; 输入的学生人数为 for 循环次数&#xff0c;循环中每输入一个值就添加在总数中&#xff0c;循环结束总数除以对应的学生人数得到平均身高 代码示例&#xff1a; package a4_2024_07;…

泽州县和美环保科技有限公司——绿色环保的践行者

在环保产业蓬勃发展的今天&#xff0c;泽州县和美环保科技有限公司以其卓越的技术和强大的实力&#xff0c;成为山西省危废综合处置领域的翘楚。作为雅居乐环保集团的全资子公司&#xff0c;和美环保科技有限公司紧跟集团发展战略&#xff0c;致力于为社会提供全方位的环境服务…

html之内联样式

内联样式&#xff08;inline styles&#xff09;是在HTML元素的style属性中直接定义的CSS样式。与外部样式表或内部样式表不同&#xff0c;内联样式仅应用于特定的HTML元素。使用内联样式时&#xff0c;可以在HTML标签中直接添加样式&#xff0c;而无需通过外部或内部的CSS文件…

JavaSE多线程线程池

文章目录 1. 多线程入门1.1 多线程相关概念1.2 什么是多线程1.3 多线程的创建方式1.3.1 继承 Thread 的方式1.3.2 实现 Runnable 接口的方式1.3.3 实现 Callable 接口的方式1.3.4 Thread 类中常用方法1.3.5 sleep() 方法 和 wait() 方法区别&#xff1a; 2. 线程安全2.1 线程安…

项目实战--Spring Boot + Minio文件切片上传下载

1.搭建环境 引入项目依赖 <!-- 操作minio的java客户端--> <dependency><groupId>io.minio</groupId><artifactId>minio</artifactId><version>8.5.2</version> </dependency> <!-- jwt鉴权相应依赖--> &…

Linux下编程之内存检查

前言 我们在进行编程时&#xff0c;有时不免会无意中写出一些容易导致内存问题&#xff08;可能一时表象上正常&#xff09;的代码&#xff0c;导致的后果肯定是不好的&#xff0c;就像一颗颗“哑弹”&#xff0c;令人心慌。网上推荐的辅助工具很多&#xff0c;此篇文章…

Unity 功能 之 创建 【Unity Package】 Manager 自己自定义管理的包的简单整理

Unity 功能 之 创建 【Unity Package】 Manager 自己自定义管理的包的简单整理 目录 Unity 功能 之 创建 【Unity Package】 Manager 自己自定义管理的包的简单整理 一、简单介绍 二、Unity Package 的目录结构 三、package.json 说明 四、程序集定义 1、程序集定义说明 …

在C#/Net中使用Mqtt

net中MQTT的应用场景 c#常用来开发上位机程序&#xff0c;或者其他一些跟设备打交道比较多的系统&#xff0c;所以会经常作为拥有数据的终端&#xff0c;可以用来采集上传数据&#xff0c;而MQTT也是物联网常用的协议&#xff0c;所以下面介绍在C#开发中使用MQTT。 安装MQTTn…

使用 Mybatis 时,调用 DAO接口时是怎么调用到 SQL 的?

Mybatis 是一个流行的 Java 持久层框架&#xff0c;它提供了一种半自动的 SQL 映射方式&#xff0c;允许开发者在 Java 代码中以一种更加直观和灵活的方式来操作数据库。当你使用 Mybatis 调用 DAO 接口时&#xff0c;背后的工作流程大致如下&#xff1a; 接口定义&#xff1a;…

科普文:一文搞懂jvm实战(二)Cleaner回收jvm资源

概叙 在JDK9中新增了Cleaner类&#xff0c;该类的作用是用于替代finalize方法&#xff0c;更有效地释放资源并避免内存泄漏。 在JEP260提案中&#xff0c;封装了大部分Sun包内部的API之余&#xff0c;还引入了一些新的API&#xff0c;其中就包含着Cleaner这个工具类。Cleaner承…