Pytorch机器学习/深度学习代码笔记

代码步骤笔记

  • 导入模块
  • 设置参数
  • 数据预处理
    • 定义数据集
      • 1.Dataset
      • 2.ImageFolder
    • 加载数据集
      • DataLoader
  • torchvision--数据预处理要使用的库
    • torchvision.datasets
    • torchvision.models
    • torchvision.transforms
  • 训练网络参数
    • 训练前的准备
      • 设置指定的训练设备(GPU、CPU)
      • 定义损失函数
      • 定义优化器
    • 训练过程
    • 验证/测试过程
    • 运行

导入模块

import torch
from tensorboardX import SummaryWriter  //可视化

设置参数

batch_size=64
works=4
epochs=20
train_path="train"
val_path="val"

数据预处理

流程:先定义数据集,再将定义的数据集导入数据载入器(Dataloader)来读取数据。

定义数据集有两种方式,一种是自定义Dataset包装类,和DataLoader类一样,它是torch.utils.data的里的一个类,另一种是直接调用ImageFolder函数,它是torchvision.datasets里的函数。

定义数据集

1.Dataset

Dataset是一个抽象类,可以自定义数据集,为了能够方便的读取,需要将要使用的数据包装为Dataset类。
自定义的Dataset需要继承它并且实现两个成员方法:
1.__getitem__():该方法定义用索引(0到len(self))获取一条数据或一个样本。
2.__len__()方法返回数据集的总长度。
模板如下:

import torch.utils.data
#定义一个数据集
class CaptionDataset(Dataset):""" 数据集演示 """def __init__(self,transform=None):  """实现初始化方法,在初始化的时候将数据读载入"""....(包括加载数据路径)def __getitem__(self):return self....def __len__(self):return len(...)# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。    
train_data= CaptionDataset(transform=transform) #transform需自己定义(见下面torchvision.transforms)

2.ImageFolder

ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:

import torchvision.datasets
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

各参数含义:

root:在root指定的路径下寻找图片

transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象

target_transform:对label的转换

loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

label:按照文件夹名顺序排序后存成字典,即{类名:类序号(从0开始)}

举例如下:

import torchvision.datasets
#此处transform需自己定义(见下面torchvision.transforms),其他参数为默认值
train_data=torchvision.datasets.ImageFolder(root=train_path,transform=transform) 

加载数据集

DataLoader

DataLoader是一个数据加载器类,实现了对数据集进行随机采样和多轮次迭代的功能。在训练过程中,可以非常方便地实现多轮次小批量随机梯度下降训练。
常用参数有:Dataset数据集实例,batch_size(每个batch的大小,shuffle(是否进行搅乱操作),num_workers(加载数据的时候使用几个子进程),返回一个可迭代对象。

import torch.utils.data
train_loader = torch.utils.data.DataLoader(CaptionDataset(train_data, transform=transform),batch_size=batch_size, shuffle=True, num_workers=workers)

详细有关参数见博客:PyTorch 中的数据类型 torch.utils.data.DataLoader

torchvision–数据预处理要使用的库

torchvision是Pytorch中专门用来处理图像的库。
提供了常用图片数据集(datasets);
训练好的模型(models);
一般的图像转换操作类(transforms),

torchvision.datasets

torchvision.datasets可以理解为PyTorch团队自定义的dataset,这些dataset帮我们提前处理好了很多的图片数据集,我们拿来就可以直接使用:

  • MNIST
  • COCO
  • Captions
  • Detection
  • LSUN
  • ImageFolder
  • Imagenet-12
  • CIFAR
  • STL10
  • SVHN
  • PhotoTour
    以上我们可以直接用(其他的只能通过自己自定义数据集),示例如下:
import torchvision.datasets as datasets
trainset = datasets.MNIST(root='./data', # 表示 MNIST 数据的加载的目录train=True,  # 表示是否加载数据库的训练集,false的时候加载测试集download=True, # 表示是否自动下载 MNIST 数据集transform=None) # 表示是否需要对数据进行预处理,none为不进行预处理

torchvision.models

torchvision提供了训练好的模型,可以加载后直接使用(见下面代码),或者在进行迁移学习torchvision.models模块的子模块中包含以下模型结构:

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
#导入预训练模型
import torchvision.models
model = torchvision.models.vgg16(pretrained=True) #True代表已经训练好的模型

torchvision.transforms

transform模块提供了一般的图像转换操作类,用作数据处理和数据增强。
主要提供了对PIL Image对象和Tensor对象的常用操作。

对PIL Image对象的常用操作有:

  • Resize:调整图片尺寸
  • CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片
  • Pad:填充
  • ToTensor:将PIL Image对象转成Tensor,会自动将[0,255]归一化至[0,1]

对Tensor对象的常用操作有:

  • Normalize:标准化,即减均值,除以标准差
  • ToPILImage:将Tensor转为PIL Image对象。
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.RandomCrop(32, padding=4),  #先四周填充0,在把图像随机裁剪成32*32transforms.RandomHorizontalFlip(),  #图像一半的概率翻转,一半的概率不翻转transforms.RandomRotation((-45,45)), #随机旋转transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), #R,G,B每层的归一化用到的均值和方差
])

详细有关transforms的用法见博客:PyTorch 学习笔记(三):transforms的二十二个方法

训练网络参数

训练前的准备

设置指定的训练设备(GPU、CPU)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

定义损失函数

torch.nn模块中定义了很多标准地损失函数。

import torch.nn as nn
xentropy=nn.CrossEntropyLoss() #此处定义一个交叉熵损失函数对象,该对象可以调用backward()方法实现误差反向传播。

定义优化器

torch.optim模块提供了很多优化算法类,
比如:torch.optim.SGD,torch.optim.Adam,torch.optim.RMSprop。这里以SGD为例。

#import torch.optim
net=CNN().to(device) #使用gpu构造一个CNN对象
optimizer=torch.optim.SGD(params=net.parameters(),lr=0.01,momentum=0.9) 
#上式参数依次为:需要网络模型的参数、学习率、动量参数

详细参数见博客:torch.optim.SGD()各参数的解释

训练过程

神经网络训练过程的一步迭代包含四个主要步骤:

  • 前向运算,计算给定输入的预测结果
  • 计算损失函数值
  • 反向传播(BP),计算参数梯度(计算之前要先梯度清零)
  • 使用梯度下降法更新参数值

详细代码如下:

def train(net,optimizer,loss_fn,num_epoch,data_loader,device):
'''参数分别为网络模型、损失函数(对应之前的xentropy)、epoch总次数、数据加载器、训练设备'''net.train() #进入训练模型for epoch in range(num_epoch):print('Epoch {}/{}'.format(epoch+1, num_epochs))running_loss=0running_corrects=0for i,data in enumerate(data_loader):inputs=data[0].to(device)  #输入labels=data[1].to(device)  #真实值标签#下面优化过程optimizer.zero_grad() #先把前一步的梯度清除,设置梯度值为0outputs=net(inputs)  #前向运算,计算网络模型在inputs上的输出outputsloss=loss_fn(outputs,labels) #计算损失函数值loss.backward() #进行反向传播,计算梯度optimizer.step() #使用优化器的step()方法,进行梯度下降,更新模型参数#可以输出两种loss,loss为每次迭代的loss,running_loss为每个epoch的loss,之后再取平均值。running_loss+=loss.item() #计算每个epoch的loss总值_, preds = torch.max(outputs, 1)running_corrects += torch.sum(preds == labels).item()epoch_loss=running__loss/len(train_data) #计算每个epoch的平均lossepoch_acc = running_corrects / len(train_data)print('{} Loss: {:.4f} Acc: {:.4f}'.format('train', epoch_loss, epoch_acc))				

验证/测试过程

测试和验证集过程不用反向传播,也不用更新梯度。

def evaluate(net,loss_fn,data_load,device):net.eval() #进入模型评估模式,验证和测试都是这个running_loss=0correct=0.0total=0for data in data_loader:inputs=data[0].to(device)  #输入labels=data[1].to(device)  #真实值标签with torch.no_grad():	outputs=net(inputs)loss=loss_fn(outputs,labels)running_loss+=loss.item()_,predicted=torch.max(outputs.data,1)total+=labels.size(0) #另一种计算总数的方法correct+=(predicted==labels).sum().item() #计算预测对的数epoch_loss = running_loss/len(val_data)acc=correct/total #计算准确率print('{} Loss: {:.4f} Acc: {:.4f}'.format('valid', epoch_loss, acc))	

运行

有两种方式:

  • 1.设立一个主函数main(),将for epoch in epochs:以及train函数和test函数放到main()里运行就可以了。
  • 2.将for epoch in epochs:和test函数放入train函数,再直接运行train()函数就可以了。
    完整代码实例:pytorch实现图像分类代码实例

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/333905.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

maven安装教程安装教程_Maven教程之春

maven安装教程安装教程1.简介 在这篇文章中,我们将演示如何针对非常特定的用例对Spring使用Maven依赖项。 我们使用的所有库的最新版本都可以在Maven Central上找到。 对于一个有效的构建周期来说,了解Maven依赖项的工作方式以及如何对其进行管理很重要…

如何完全卸载 Sublime Text

只是在应用程序删除软件是不够,你还必须把下面这个目录删除掉才行: /Users/liaowenxiong/Library/Application Support/Sublime Text /Users/liaowenxiong/Library/Preferences/Sublime Text /Users/liaowenxiong/Library/Caches/Sublime Text不这么干&…

5位随机数重复的概率 php_PHP产生不重复随机数的5个方法总结

无论是Web应用,还是WAP或者移动应用,随机数都有其用武之地。在最近接触的几个小项目中,我也经常需要和随机数或者随机数组打交道,所以,对于PHP如何产生不重复随机数常用的几种方法小结一下(ps:方法1、4、5是…

pytorch实现图像分类代码实例

图像多标签分类例子 import os import torch import torch.nn as nn import torchvision.transforms as transforms import torchvision.datasets as datasets import torchvision.models as models import matplotlib.pyplot as plt from matplotlib.ticker import MultipleL…

空调吸气和排气_吸气剂和二传手被认为有害

空调吸气和排气Java程序员习惯性地用“ getters”和“ setters”来修饰类,这种做法根深蒂固,以至于几乎没有人质疑为什么这样做或是否应该这样做。 最近,我认为最好不要这样做,并且我开始在编写的Java代码中避免使用它。 在这篇博…

Sublime Text for Mac 如何格式化代码

文章目录格式化 HTML/CSS/JS格式化 Java/C/C格式化 HTML/CSS/JS 格式化 HTML/CSS/JS,请安装插件:html-css-js prettify 格式化的快捷键:Shift Cmd H html-css-js prettify 的简介: Usage Tools -> Command Palette (CmdS…

ic启动器我的世界_hmcl启动器下载-我的世界HMCL启动器下载 v3.3.172官方最新版--pc6下载站...

我的世界HMCL启动器是我的世界游戏玩家必备的游戏启动器,是三年来超过使用3亿次的老牌启动器,不需要其他任何设置,操作非常方便,本站提供现在最新版本下载。我的世界HMCL启动器是我的世界游戏玩家必备的游戏启动器,是三…

C++核心编程笔记

C核心编程1 内存分区模型1.1 程序运行前1.2 程序运行后1.3 new操作符2 引用2.1 引用的基本使用2.2 引用注意事项2.3 引用做函数参数2.4 引用做函数返回值2.5 引用的本质2.6 常量引用3 函数提高3.1 函数默认参数3.2 函数占位参数3.3 函数重载3.3.1 函数重载概述3.3.2 函数重载注…

Sublime Text 如何设置组合快捷键

Sublime 有个功能叫再次缩进(Reindent),我就以这个功能为例讲下如何设置快捷键,这个功能的菜单路径是:Edit ➠ Line ➠ Reindent,有人说这个再次缩进可以格式化代码,扯淡,缩进两下也…

朝着理想坚实迈进_坚实原则:开放/封闭原则

朝着理想坚实迈进先前我们讨论了单一责任原则。 关于实体原则首字母缩写, 打开/关闭原则是该行中的第二个原则。 “软件实体(类,模块,功能等)应打开以进行扩展,但应关闭以进行修改” 通过采用该原理&…

协程asyncio_Python 异步模块 asyncio 中的协程与任务

协程(Coroutine)是允许执行被挂起、恢复、以及取消的程序。Python 3 中最初是使用 asyncio.coroutine 装饰器和 yield from 关键字组合来实现协程。单词 yield 在这里并非在生成器(Generator)中所表示的“产出”,而是交…

ie8兼容性视图灰色修复_IE8网页显示不正常 用”兼容性视图”搞定

网页显示不正常,出现图片错位,文字跑远……等等,别急,试试IE8自带的”兼容性视图”功能吧!其实出现网页显示问题,一般不是您的电脑或者浏览器有问题,而是由于各网站开发标准不同,所以在不同的浏…

GAN对抗生成网络原始论文理解笔记

文章目录论文:Generative Adversarial Nets符号意义生成器(Generator)判别器(Discriminator)生成器和判别器的关系GAN的训练流程简述论文中的生成模型和判别模型GAN的数学理论最大似然估计转换为最小化KL散度问题定义PGP_GPG​全局最优论文:Generative A…

php cdi_CDI和lambda的策略模式

php cdi策略设计模式在运行时动态选择一种实现算法,一种策略。 该模式可用于根据情况选择不同的业务算法。 我们可以将不同的算法实现定义为单独的类。 或者,我们利用Java SE 8 lambda和函数,这些lambda和函数在此处用作轻量级策略实现。 C…

Linux 命令之 cp -- 复制文件或目录

文章目录一、命令介绍二、常用选项三、命令示例(一)复制某个目录到某个目录下(二)复制文件(三)复制文件到目标目录下,若存在文件则备份(四)复制某个目录的全部文件到某个…

向上累积频数怎么算_excel数据分析向上累计和向下累计怎么做呢

2016-07-08 00:25赵飞虎 客户经理一、Excel在分析性测试、复核中的运用注册会计师在分析审计风险确定重点审计领域、重要性水平和重大异常经济业务事项时,常常要对被审计单位的会计报表进行分析性测试和复核。在执行具体审计程序时,也常常要对本期数和上…

okta使用_使用Okta的单点登录保护您的Vert.x服务器

okta使用“我喜欢编写身份验证和授权代码。” 〜从来没有Java开发人员。 厌倦了一次又一次地建立相同的登录屏幕? 尝试使用Okta API进行托管身份验证,授权和多因素身份验证。 Vert.x是Spring生态系统中增长最快的元素之一,保护Vert.x服务器可…

Linux 命令之 make -- GNU的工程化编译工具

文章目录一、命令介绍二、常用选项三、命令示例(一)指定命令 make 的工作目录一、命令介绍 make 命令是 GNU 的工程化编译工具,用于编译众多相互关联的源代码文件,还可以编辑内核或模块,以实现工程化的管理&#xff0…

SDL2笔记

SDL2基本操作头文件主函数初始化创建窗口窗口暂停以及事件讲解销毁窗口(释放指针)并退出加载bmp图片新加载图片的方法(使用渲染、纹理)加载其他格式的图片头文件 #include "SDL.h" #include "SDL_image.h"主函数 int main(int argc,char* argv[]) //一定…

操作系统时间片轮换_《操作系统_时间片轮转RR进程调度算法》

转自:https://blog.csdn.net/houchaoqun_xmu/article/details/55540250时间片轮转RR进程调度算法一、概念介绍和案例解析时间片轮转法 - 基本原理:在早期的时间片轮转法中,系统将所有的就绪进程按先来先服务的原则排成一个队列,每次调度时&am…