快速上手笔记,PyTorch模型训练实用教程(附代码)

前言

自 2017 年 1 月 PyTorch 推出以来,其热度持续上升,一度有赶超 TensorFlow 的趋势。PyTorch 能在短时间内被众多研究人员和工程师接受并推崇是因为其有着诸多优点,如采用 Python 语言、动态图机制、网络构建灵活以及拥有强大的社群等。因此,走上学习 PyTorch 的道路已刻不容缓。

本教程以实际应用、工程开发为目的,着重介绍模型训练过程中遇到的实际问题和方法。如上图所示,在机器学习模型开发中,主要涉及三大部分,分别是数据、模型和损失函数优化器。本文也按顺序的依次介绍数据、模型和损失函数优化器,从而给大家带来清晰的机器学习结构。

通过本教程,希望能够给大家带来一个清晰的模型训练结构。当模型训练遇到问题时,需要通过可视化工具对数据、模型、损失等内容进行观察,分析并定位问题出在数据部分?模型部分?还是优化器?只有这样不断的通过可视化诊断你的模型,不断的对症下药,才能训练出一个较满意的模型。

为什么写此教程

前几年一直在用 Caffe 和 MatConvNet,近期转 PyTorch。当时只想快速地用上 PyTorch 进行模型开发,然而搜了一圈 PyTorch 的教程,并没有找到一款适合的。很多 PyTorch 教程是从学习机器学习 (深度学习) 的角度出发,以 PyTorch 为工具进行编写,里面介绍很多模型,并且附上模型的 demo。

然而,工程应用开发中所遇到的问题并不是跑一个模型的 demo 就可以的,模型开发需要对数据的预处理、数据增强、模型定义、权值初始化、模型 Finetune、学习率调整策略、损失函数选取、优化器选取、可视化等等。鉴于此,我只能自己对着官方文档,一步一步地学习。

起初,只是做了一些学习笔记,后来觉得这些内容应该对大家有些许帮助,毕竟在互联网上很难找到这类内容的分享,于是此教程就诞生了。

本教程内容及结构

本教程内容主要为在 PyTorch 中训练一个模型所可能涉及到的方法及函数,并且对 PyTorch 提供的数据增强方法(22 个)、权值初始化方法(10 个)、损失函数(17 个)、优化器(6 个)及 tensorboardX 的方法(13 个)进行了详细介绍。

本教程分为四章,结构与机器学习三大部分一致:

  • 第一章,介绍数据的划分,预处理,数据增强;

  • 第二章,介绍模型的定义,权值初始化,模型 Finetune;

  • 第三章,介绍各种损失函数优化器

  • 第四章,介绍可视化工具,用于监控数据、模型权及损失函数的变化。

本教程适用读者:

  1. 想熟悉 PyTorch 使用的朋友;

  2. 想采用 PyTorch 进行模型训练的朋友;

  3. 正采用 PyTorch,但无有效机制去诊断模型的朋友;

干货直达:

1.6 transforms 的二十二个方法

2.2 权值初始化的十种方法

3.1 PyTorch 的十七个损失函数

3.3 PyTorch 的十个优化器

3.4 PyTorch 的六个学习率调整方法

4.1 TensorBoardX

项目代码:https://github.com/tensor-yu/PyTorch_Tutorial

为了展示该教程的内容,读者可试读第二章的第一小节,了解PyTorch如何搭建模型:

第二章 模型

第二章介绍关于网络模型的一系列内容,包括模型的定义,模型参数初始化方法,模型的保存和加载,模型的 finetune(本质上还是模型权值初始化),首先介绍模型的定义。

2.1 模型的搭建

2.1.1 模型定义的三要

首先,必须继承 nn.Module 这个类,要让 PyTorch 知道这个类是一个 Module。

其次,在__init__(self) 中设置好需要的「组件"(如 conv、pooling、Linear、BatchNorm 等)。

最后,在 forward(self, x) 中用定义好的「组件」进行组装,就像搭积木,把网络结构搭建出来,这样一个模型就定义好了。

接下来,请看代码,在/Code/main_training/main.py 中可以看到定义了一个类 class Net(nn.Module),先看__init__(self) 函数:

def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)

第一行是初始化,往后定义了一系列组件,如由 Conv2d 构成的 conv1,有 MaxPool2d 构成的 poo1l,这些操作均由 torch.nn 提供,torch.nn 中的操作可查看文档:https://PyTorch.org/docs/stable/nn.html#。

当这些组件定义好之后,就可以定义 forward() 函数,用来搭建网络结构,请看代码:

def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
  • x 为模型的输入,第一行表示,x 经过 conv1,然后经过激活函数 relu,再经过 pool1 操作;

  • 第二行于第一行一样;第三行,表示将 x 进行 reshape,为了后面做为全连接层的输入;

  • 第四,第五行的操作都一样,先经过全连接层 fc,然后经过 relu;

  • 第六行,模型的最终输出是 fc3 输出。

至此,一个模型定义完毕,接着就可以在后面进行使用。例如,实例化一个模型 net = Net(),然后把输入 inputs 扔进去,outputs = net(inputs),就可以得到输出 outputs。

2.1.2 模型定义多说两句

上面只是介绍了模型定义的要素和过程,但是在工程应用中会碰到各种各样的网络模型,这时,我们就需要一些实用工具来帮助我们定义模型了。

这里以 Resnet34 为例介绍「复杂」模型的定义,这部分代码从 github 上获取。

地址:https://github.com/yuanlairuci110/PyTorch-best-practice-master/blob/master/models/ResNet34.py

class ResidualBlock(nn.Module):'''实现子module: Residual Block'''def __init__(self, inchannel, outchannel, stride=1, shortcut=None):super(ResidualBlock, self).__init__()self.left = nn.Sequential(nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),nn.BatchNorm2d(outchannel) )self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)

class ResNet34(BasicModule):
    ‘’’
    实现主module:ResNet34
    ResNet34包含多个layer,每个layer又包含多个Residual block
    用子module来实现Residual block,用_make_layer函数来实现layer
    ‘’’
    def init(self, num_classes=2):
        super(ResNet34, self).init()
        self.model_name = ‘resnet34’

        # 前几层: 图像转换
        self.pre = nn.Sequential(
                nn.Conv2d(3, 64, 7, 2, 3, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(3, 2, 1))

        # 重复的layer,分别有3,4,6,3个residual block
        self.layer1 = self._make_layer( 64, 128, 3)
        self.layer2 = self._make_layer( 128, 256, 4, stride=2)
        self.layer3 = self._make_layer( 256, 512, 6, stride=2)
        self.layer4 = self._make_layer( 512, 512, 3, stride=2)

        #分类用的全连接
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self,  inchannel, outchannel, block_num, stride=1):
        ‘’’
        构建layer,包含多个residual block
        ‘’’
        shortcut = nn.Sequential(
                nn.Conv2d(inchannel,outchannel,1,stride, bias=False),
                nn.BatchNorm2d(outchannel))

        layers = []
        layers.append(ResidualBlock(inchannel, outchannel, stride, shortcut))

        for i in range(1, block_num):
            layers.append(ResidualBlock(outchannel, outchannel))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.pre(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = F.avg_pool2d(x, 7)
        x = x.view(x.size(0), -1)
        return self.fc(x)

还是从三要素出发看看是怎么定义 Resnet34 的。

  • 首先,继承 nn.Module;

  • 其次,看__init__() 函数,在__init__() 中,定义了这些组件,self.pre,self.layer1-4, self.fc ;

  • 最后,看 forward(),分别用了在__init__() 中定义的一系列组件,并且用了 torch.nn.functional.avg_pool2d 这个操作。

至此,网络定义完成。

以为就完了?怎么可能,init() 函数中的组件是怎么定义的,在__init__() 中出现了 torch.nn.Sequential。

组件定义还调用函数_make_layer(),其中也用到了 torch.nn.Sequential,其中还调用了 ResidualBlock(nn.Module),在 ResidualBlock(nn.Module) 中有一次调用了 torch.nn.Sequential。

torch.nn.Sequential 到底是什么呢?为什么都在用呢?

2.1.3 nn.Sequetial

torch.nn.Sequential 其实就是 Sequential 容器,该容器将一系列操作按先后顺序给包起来,方便重复使用。例如 Resnet 中有很多重复的 block,就可以用 Sequential 容器把重复的地方包起来。

官方文档中给出两个使用例子:

# Example of using Sequential
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)

# Example of using Sequential with OrderedDict
model = nn.Sequential(OrderedDict([
          (‘conv1’, nn.Conv2d(1,20,5)),
          (‘relu1’, nn.ReLU()),
          (‘conv2’, nn.Conv2d(20,64,5)),
          (‘relu2’, nn.ReLU())
        ]))

小结:

模型的定义就是先继承,再构建组件,最后组装。

其中基本组件可从 torch.nn 中获取,或者从 torch.nn.functional 中获取,同时为了方便重复使用组件,可以使用 Sequential 容器将一系列组件包起来,最后在 forward() 函数中将这些组件组装成你的模型。

获取方式一:

获取方式二:

链接: https://pan.baidu.com/s/11hvPGusAopXNwuCsuLilCA
提取码: anw5 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/480848.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

领域应用 | 中医临床术语系统V2.0在线发布啦!

本文转载自公众号:中医药知识组织与标准。中医临床术语系统V2.0在线发布中医临床术语系统(Traditional Chinese Medicine Clinical Terminological Systems, TCMCTS)是由中国中医科学院中医药信息研究所研制的,用来描述健康状况和…

NLP Subword三大算法原理:BPE、WordPiece、ULM

Subword算法如今已经成为了一个重要的NLP模型性能提升方法。自从2018年BERT横空出世横扫NLP界各大排行榜之后,各路预训练语言模型如同雨后春笋般涌现,其中Subword算法在其中已经成为标配。且与传统空格分隔tokenization技术的对比有很大的优势~~ E.g. 模…

【小程序】微信小程序开发实践

版权声明&#xff1a;本文为博主原创文章&#xff0c;未经博主允许不得转载。 https://blog.csdn.net/diandianxiyu/article/details/53068012 </div><link rel"stylesheet" href"https://csdnimg.cn/release/phoenix/template/css/ck…

技术人如何提升自己的核心竞争力

互联网行业是一个发展非常快&#xff0c;变化也快的行业&#xff0c;在这个行业&#xff0c;总是让人感觉既兴奋又不安。 兴奋的是你总能看到无数新奇的事物&#xff0c;甚至亲身参与到一场变革中去&#xff0c;而不安的则是&#xff0c;任凭你如何NB&#xff0c;你也无法保证哪…

AAAI 2018经典论文获奖者演讲:本体论的昨天和今天

本文转自公众号&#xff1a;AI科技评论。AI 科技评论按&#xff1a;正在美国新奥尔良召开的 AAAI 2018 的经典论文奖颁给了《Algorithm and Tool for Automated Ontology Merging and Alignment》。这篇论文发表在 2000 年的第 17 届 AAAI 大会上。这次颁奖是为了表彰这篇论文在…

ICLR2020 | 如何判断两个神经网络学到的知识是否一致

人工智能顶会 ICLR 2020 将于 4 月 26 日于埃塞俄比亚首都亚的斯亚贝巴举行。在最终提交的 2594 篇论文中&#xff0c;有 687 篇被接收&#xff0c;接收率为 26.5%。本文介绍了上海交通大学张拳石团队的一篇接收论文——《Knowledge Consistency between Neural Networks and B…

7张图学会SQL

第1周&#xff1a;SQL入门 学习SQL语句的书写语法和规则从零学会SQL&#xff1a;入门​www.zhihu.com 第2周&#xff1a;查询基础 Select查询语句是SQL中最基础也是最重要的语句&#xff0c;这周我们就来利用Select来对表中的数据进行查询。从零学会SQL&#xff1a;简单查询​w…

大公司稳定工作和创业之间如何选择?

“ 是留在大公司&#xff0c;还是加入小型创业公司&#xff0c;还是自己创业&#xff0c;面对房价每年高涨的趋势&#xff0c;面对未来的不确定&#xff0c;应该怎样选择。 作为一个亲历者&#xff0c;希望你看完后能有所启发。 本文作者&#xff0c;陈睿 优知学院创始人 优知…

论文浅尝 |「知识表示学习」专题论文推荐

本文转载自公众号&#xff1a;PaperWeekly。本期论文清单来自清华大学博士生韩旭和北师大本科生曹书林&#xff0c;涵盖了近年知识表示学习方向的重要论文。[ 综述类 ]■ 论文 | Representation Learning: A Review and New Perspectives■ 链接 | https://www.paperweekly.sit…

如何选择一家公司

不管是刚毕业的大学生还是工作几年的职场朋友&#xff0c;每个人都会面临选择公司和行业的困扰&#xff0c;我也相信每个人都还记忆犹新你的第一份工作以及让你无比难忘的一家公司。有时候我们也盲目的所求&#xff0c;其实&#xff0c;偶尔停下来思考下你真想去的地方&#xf…

LightGBM最强解析,从算法原理到代码实现~

1 LightGBM简介 GBDT (Gradient Boosting Decision Tree) 是机器学习中一个长盛不衰的模型&#xff0c;其主要思想是利用弱分类器&#xff08;决策树&#xff09;迭代训练以得到最优模型&#xff0c;该模型具有训练效果好、不易过拟合等优点。GBDT不仅在工业界应用广泛&#xf…

数据分析师基本技能——SQL

我们做数据分析工作时&#xff0c;多数数据来源于数据库&#xff0c;SQL非常方便我们访问和查询数据库。 SQL 作为数据分析师的基本技能&#xff0c;那么需要掌握哪些SQL核心技能 理解数据库SQL基础重点知识&#xff1a;查询&#xff0c;更新&#xff0c;提取&#xff0c;插入&…

论文浅尝 | 基于置信度的知识图谱表示学习框架

本文转载自公众号&#xff1a;PaperWeekly。作者丨谢若冰单位丨腾讯微信搜索应用部研究方向丨知识表示学习知识图谱被广泛地用来描述世界上的实体和实体之间的关系&#xff0c;一般使用三元组&#xff08;h,r,t&#xff09;&#xff08;head entity, relation, trail entity&am…

史上最强Java架构师的13大技术能力讲解! | 附架构师能力图谱

从程序员进阶成为架构师&#xff0c;并非一蹴而就&#xff0c;需要系统化、阶段性地学习&#xff0c;在实战项目中融会贯通&#xff0c;这如同打怪通关&#xff0c;我们得一关一关突破&#xff0c;每攻破一个关口&#xff0c;就能得到更精良的装备&#xff0c;技能值也随之不断…

写给运营同学和初学者的SQL入门教程

作者简介 多肉&#xff0c;饿了么资深python工程师。曾在17年担任饿了么即时配送众包系统的研发经理&#xff0c;这篇文章最早的版本就诞生于那段时间&#xff0c;目前负责配送相关业务系统的整体稳定性建设。个人比较喜欢c和python&#xff0c;最近有点迷rust&#xff0c;同时…

强化学习,路在何方?

▌一、深度强化学习的泡沫 2015年&#xff0c;DeepMind的Volodymyr Mnih等研究员在《自然》杂志上发表论文Human-level control through deep reinforcement learning[1]&#xff0c;该论文提出了一个结合深度学习&#xff08;DL&#xff09;技术和强化学习&#xff08;RL&…

论文浅尝 | 基于神经网络的实体识别和关系抽取联合学习

本文转载自公众号&#xff1a;PaperWeekly。作者丨罗凌学校丨大连理工大学博士生研究方向丨深度学习&#xff0c;文本分类&#xff0c;实体识别联合学习&#xff08;Joint learning&#xff09;一词并不是一个最近才出现的术语&#xff0c;在自然语言处理领域&#xff0c;很早就…

一篇文章搞懂架构师的核心技能

“ 这是架构师系列的第一篇&#xff1a;核心技能&#xff0c;希望这个系列能完全揭示架构师这个职位&#xff1a;我先从核心技能开始&#xff0c;后续还有架构师之路&#xff0c;架构实战等架构师系列文章。 本文作者 陈睿 优知学院创始人&#xff0c;前携程定制旅游CTO,在互联…

史上最全的分词算法与工具介绍

分词&#xff08;word tokenization&#xff09;&#xff0c;也叫切词&#xff0c;即通过某种方式将句子中的各个词语识别并分离开来&#xff0c;使得文本从“字序列”的表示升级为“词序列”表示。分词技术不仅仅适用于中文&#xff0c;对于英文、日文、韩文等语言也同样适用。…

论文解读:Attention is All you need

论文解读:Attention is All you need习翔宇​北京大学 软件工程博士在读​关注他192 人赞同了该文章Attention机制最早在视觉领域提出&#xff0c;2014年Google Mind发表了《Recurrent Models of Visual Attention》&#xff0c;使Attention机制流行起来&#xff0c;这篇论文采…