预训练模型应用工具 PaddleHub情感分析、对话情绪识别文本相似度

文章目录

    • 1. 预训练模型的应用背景
      • 1.1 多任务学习与迁移学习
      • 1.2 自监督学习
    • 2. 快速使用PaddleHub
      • 2.1 通过Python代码调用方式 使用PaddleHub
        • 2.1.1 CV任务
        • 原图展示
        • 人像扣图
        • 人体部位分割
        • 人脸检测
        • 关键点检测
        • 2.1.2 NLP 任务
      • 2.2 通过命令行调用方式 使用PaddleHub
    • 3. PaddleHub提供的预训练模型
    • 4. 使用自己的数据Fine-tune PaddleHub预训练模型
      • 4.1 安装PaddleHub
      • 4.2 数据准备
      • 4.3 模型准备
      • 4.4 训练准备
      • 4.5 组建Fine-tune Task
      • 4.6 启动Fine-tune
    • 5. 相关参考链接

十行代码能干什么? 相信多数人的答案是可以写个“Hello world”,或者做个简易计算器,本章将告诉你另一个答案,还可以实现人工智能算法应用。基于PaddleHub,可以轻松使用十行代码完成所有主流的人工智能算法应用,比如目标检测、人脸识别、语义分割等任务。

PaddleHub是飞桨预训练模型应用工具,集成了最优秀的算法模型,旨在帮助开发者使用最简单的代码快速完成复杂的深度学习任务,另外,PaddleHub提供了方便的Fine-tune API,开发者可以使用高质量的预训练模型结合Fine-tune API快速完成模型迁移到部署的全流程工作。

图1是2020年疫情期间,PaddleHub提供的十行代码即可完成根据肺部影像诊断病情的任务,以及检测人像是否佩戴口罩的任务。


图1:PaddleHub产业应用

运行如下代码,快速体验一下

  1. 安装PaddleHub并升级到最新版本。

    # 下载安装paddlehub到最新版本,仅第一次运行项目时执行此命令
    !pip install paddlehub==1.6.1 -i https://pypi.tuna.tsinghua.edu.cn/simple #指定版本安装PaddleHub,使用清华源更稳定、更迅速
    !pip install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple #升级到最新版本,使用清华源更稳定、更迅速
    
    • 1
    • 2
    • 3
  2. 使用Paddlehub实现口罩人脸检测,只需要几行命令。其中,test_mask_detection.jpg是一张测试图片。

    ! wget  https://paddlehub.bj.bcebos.com/resources/test_mask_detection.jpg #下载测试图片
    ! hub install pyramidbox_lite_mobile_mask==1.3.0 #加载预训练模型
    ! hub run pyramidbox_lite_mobile_mask --input_path test_mask_detection.jpg #运行预测结果
    
    • 1
    • 2
    • 3

本节将从如下几个方面介绍PaddleHub

  • 预训练模型的应用背景;
  • PaddleHub的快速使用方法和PaddleHub支持的模型列表;
  • 通过一个完整的案例,介绍如何使用自己的数据Fine-tune PaddleHub的预训练模型。

1. 预训练模型的应用背景

众所周知,深度学习任务依赖较多的数据完成神经网络的训练。在实际场景中,数据量的大小与成本成正比,常遇到语料数据或者图像数据较少,不足以支持完成神经网络模型训练的场景。

经过不断的探索,人们发现有两种思路可以解决训练数据不足的问题。

1.1 多任务学习与迁移学习

人们发现处理很多任务所依赖的信息特征是相通的,比如从图片中框选出一只猫的任务与识别一个生物是不是猫的任务,均需要提取出标识猫的有效特征。这是符合认知的,人类处理一件任务也会不自觉的运用上从其他任务上学习到的知识和方法,比如我们学习英语的时候,也会代入已经掌握的很多中文语法习惯。

基于迁移学习的思想,我们可以将模型先在数据丰富的任务上学习,再使用新任务的小数据量做Fine-tune(网络参数的微调,继承了从数据丰富任务上学习到的知识),最终达到较好的效果。

图2展示了对于不同的自然语言任务,很多本质的信息和知识是可以共享的。词性标注、句子句法成分划分、命名实体识别、语义角色标注等NLP任务适合采用多任务学习来解决。PaddleHub提供了预训练好的语义表示库ERNIE,它是这方面的佼佼者。


图2:多任务学习与迁移学习

1.2 自监督学习

通过一些巧妙的方法,我们可以将一些无监督的数据样本转变成监督学习,来学习数据中的知识。如图3所示,按照通常的理解,一张无标签的图片和一段自然语言文本是无监督的数据。但我们可以将部分图像进行遮挡,未遮挡的部分作为监督模型的输入,遮挡的部分作为模型需要预测的输出。同样的,也可以将一段文本中的部分短语遮挡,未遮挡的部分作为监督模型的输入,遮挡的部分作为模型需要预测的输出。


图3:自监督学习

PaddleHub中预置了大量的预训练模型,均采用了上述两种技术,并结合了百度在互联网领域海量的独有数据积累,数十种广受开发者欢迎的模型均是PaddleHub独有的。

2. 快速使用PaddleHub

既然PaddleHub的使用如此简单,功能又如此强大,那么读者们是否迫不及待了呢?下面我们就展示下快速使用PaddleHub的两种方式:Python代码调用命令行调用

2.1 通过Python代码调用方式 使用PaddleHub

首先以计算机视觉任务为例,我们选用一张测试图片test.jpg,分别实现如下四项功能:

  • 人像扣图(deeplabv3p_xception65_humanseg)
  • 人体部位分割(ace2p)
  • 人脸检测(ultra_light_fast_generic_face_detector_1mb_640)
  • 关键点检测(human_pose_estimation_resnet50_mpii)

注:有关调用的模型名字参考官方文档。

2.1.1 CV任务

原图展示

# 待预测图片
test_img_path = ["./test.jpg"]

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

img = mpimg.imread(test_img_path[0])

# 展示待预测图片
plt.figure(figsize=(10,10))
plt.imshow(img)
plt.axis(‘off’)
plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gMNhrP8y-1594442261635)(C:\Users\ADMINI~1\AppData\Local\Temp\1594438697714.png)]

人像扣图

#安装预训练模型
!hub install deeplabv3p_xception65_humanseg==1.1.0
  • 1
  • 2
Downloading deeplabv3p_xception65_humanseg
[==================================================] 100.00%
Uncompress /home/aistudio/.paddlehub/tmp/tmpybp717db/deeplabv3p_xception65_humanseg
[==================================================] 100.00%
Successfully installed deeplabv3p_xception65_humanseg-1.1.0
  • 1
  • 2
  • 3
  • 4
  • 5
import paddlehub as hub
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

module = hub.Module(name=“deeplabv3p_xception65_humanseg”)
res = module.segmentation(paths = ["./test.jpg"], visualization=True, output_dir=‘humanseg_output’)

res_img_path = ‘humanseg_output/test.png’
img = mpimg.imread(res_img_path)
plt.figure(figsize=(10, 10))
plt.imshow(img)
plt.axis(‘off’)
plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8NlreER9-1594442261638)(C:\Users\ADMINI~1\AppData\Local\Temp\1594438787148.png)]

人体部位分割

#安装预训练模型
!hub install ace2p==1.1.0
  • 1
  • 2
Downloading ace2p
[==================================================] 100.00%
Uncompress /home/aistudio/.paddlehub/tmp/tmpses1557l/ace2p
[==================================================] 100.00%
Successfully installed ace2p-1.1.0
  • 1
  • 2
  • 3
  • 4
  • 5
import paddlehub as hub
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

module = hub.Module(name=“ace2p”)
res = module.segmentation(paths = ["./test.jpg"], visualization=True, output_dir=‘ace2p_output’)

res_img_path = ‘./ace2p_output/test.png’
img = mpimg.imread(res_img_path)
plt.figure(figsize=(10, 10))
plt.imshow(img)
plt.axis(‘off’)
plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XgSbI5sH-1594442261640)(C:\Users\ADMINI~1\AppData\Local\Temp\1594438880710.png)]

人脸检测

#安装预训练模型
! hub install ultra_light_fast_generic_face_detector_1mb_640==1.1.2
  • 1
  • 2
Downloading ultra_light_fast_generic_face_detector_1mb_640
[==================================================] 100.00%
Uncompress /home/aistudio/.paddlehub/tmp/tmpstkqzi19/ultra_light_fast_generic_face_detector_1mb_640
[==================================================] 100.00%
Successfully installed ultra_light_fast_generic_face_detector_1mb_640-1.1.2
  • 1
  • 2
  • 3
  • 4
  • 5
import paddlehub as hub
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

module = hub.Module(name=“ultra_light_fast_generic_face_detector_1mb_640”)
res = module.face_detection(paths = ["./test.jpg"], visualization=True, output_dir=‘face_detection_output’)

res_img_path = ‘./face_detection_output/test.jpg’
img = mpimg.imread(res_img_path)
plt.figure(figsize=(10, 10))
plt.imshow(img)
plt.axis(‘off’)
plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JeMJKreB-1594442261645)(C:\Users\ADMINI~1\AppData\Local\Temp\1594438957906.png)]

关键点检测

#安装预训练模型
!hub install human_pose_estimation_resnet50_mpii==1.1.0
  • 1
  • 2
File human_pose_estimation_resnet50_mpii_1.1.0.tar.gz already existed
Wait to check the MD5 value
MD5 check failed!
Delete invalid file.
Downloading human_pose_estimation_resnet50_mpii_1.1.0.tar.gz
[==================================================] 100.00%
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
import paddlehub as hub
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

module = hub.Module(name=“human_pose_estimation_resnet50_mpii”)
res = module.keypoint_detection(paths = ["./test.jpg"], visualization=True, output_dir=‘keypoint_output’)

res_img_path = ‘./keypoint_output/test.jpg’
img = mpimg.imread(res_img_path)
plt.figure(figsize=(10, 10))
plt.imshow(img)
plt.axis(‘off’)
plt.show()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-dqm5dE72-1594442261646)(C:\Users\ADMINI~1\AppData\Local\Temp\1594439040844.png)]

2.1.2 NLP 任务

对于自然语言处理任务,下面以中文分词和情感分类的任务为例,待处理的数据以函数参数的形式传入。

#安装预训练模型
!hub install lac==2.1.1
  • 1
  • 2
Downloading lac
[==================================================] 100.00%
Uncompress /home/aistudio/.paddlehub/tmp/tmpt0t3qlhr/lac
[==================================================] 100.00%
Successfully installed lac-2.1.1
  • 1
  • 2
  • 3
  • 4
  • 5
import paddlehub as hub

lac = hub.Module(name=“lac”)
test_text = [“1996年,曾经是微软员工的加布·纽维尔和麦克·哈灵顿一同创建了Valve软件公司。他们在1996年下半年从id software取得了雷神之锤引擎的使用许可,用来开发半条命系列。”]
res = lac.lexical_analysis(texts = test_text)
print(“中文词法分析结果:”, res)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
[2020-05-25 10:40:44,605] [    INFO] - Installing lac module
[2020-05-25 10:40:44,610] [    INFO] - Module lac already installed in /home/aistudio/.paddlehub/modules/lac
中文词法分析结果: [{'word': ['1996年', ',', '曾经', '是', '微软', '员工', '的', '加布·纽维尔', '和', '麦克·哈灵顿', '一同', '创建', '了', 'Valve软件公司', '。', '他们', '在', '1996年下半年', '从', 'id', ' ', 'software', '取得', '了', '雷神之锤', '引擎', '的', '使用', '许可', ',', '用来', '开发', '半条命', '系列', '。'], 'tag': ['TIME', 'w', 'd', 'v', 'ORG', 'n', 'u', 'PER', 'c', 'PER', 'd', 'v', 'u', 'ORG', 'w', 'r', 'p', 'TIME', 'p', 'nz', 'w', 'n', 'v', 'u', 'n', 'n', 'u', 'vn', 'vn', 'w', 'v', 'v', 'n', 'n', 'w']}]
  • 1
  • 2
  • 3
#安装预训练模型
! hub install senta_bilstm==1.1.0
  • 1
  • 2
Downloading senta_bilstm
[==================================================] 100.00%
Uncompress /home/aistudio/.paddlehub/tmp/tmpjh2vgjcx/senta_bilstm
[==================================================] 100.00%
Successfully installed senta_bilstm-1.1.0
  • 1
  • 2
  • 3
  • 4
  • 5
import paddlehub as hub

senta = hub.Module(name=“senta_bilstm”)
test_text = [“味道不错,确实不算太辣,适合不能吃辣的人。就在长江边上,抬头就能看到长江的风景。鸭肠、黄鳝都比较新鲜。”]
res = senta.sentiment_classify(texts = test_text)
print(“中文词法分析结果:”, res)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
[2020-05-25 10:43:00,885] [    INFO] - Installing senta_bilstm module
[2020-05-25 10:43:00,949] [    INFO] - Module senta_bilstm already installed in /home/aistudio/.paddlehub/modules/senta_bilstm
[2020-05-25 10:43:02,875] [    INFO] - Installing lac module
[2020-05-25 10:43:02,877] [    INFO] - Module lac already installed in /home/aistudio/.paddlehub/modules/lac
中文词法分析结果: [{'text': '味道不错,确实不算太辣,适合不能吃辣的人。就在长江边上,抬头就能看到长江的风景。鸭肠、黄鳝都比较新鲜。', 'sentiment_label': 1, 'sentiment_key': 'positive', 'positive_probs': 0.9775, 'negative_probs': 0.0225}]
  • 1
  • 2
  • 3
  • 4
  • 5

2.2 通过命令行调用方式 使用PaddleHub

PaddleHub在设计时,为模型的管理和使用提供了命令行工具,也提供了通过命令行调用PaddleHub模型完成预测的方式。比如,上面人像分割和文本分词的任务也可以通过命令行调用的方式实现。

#通过命令行方式实现人像分割任务
! hub run deeplabv3p_xception65_humanseg --input_path test.jpg
  • 1
  • 2
[{'data': array([[-226.66667, -226.66667, -226.66667, ..., -226.66667, -226.66667,-226.66667],[-226.66667, -226.66667, -226.66667, ..., -226.66667, -226.66667,-226.66667],[-226.66667, -226.66667, -226.66667, ..., -226.66667, -226.66667,-226.66667],...,[-226.66667, -226.66667, -226.66667, ..., -226.66667, -226.66667,-226.66667],[-226.66667, -226.66667, -226.66667, ..., -226.66667, -226.66667,-226.66667],[-226.66667, -226.66667, -226.66667, ..., -226.66667, -226.66667,-226.66667]], dtype=float32)}]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
#通过命令行方式实现文本分词任务
!hub run lac --input_text "今天是个好日子"
  • 1
  • 2
[{'word': ['今天', '是', '个', '好日子'], 'tag': ['TIME', 'v', 'q', 'n']}]
  • 1

上面的命令中包含四个部分,分别是:

  • hub 表示PaddleHub的命令。
  • run 调用run执行模型的预测。
  • deeplabv3p_xception65_humanseglac表示要调用的算法模型。
  • --input_path/--input_text 表示模型的输入数据,图像和文本的输入方式不同。

PaddleHub的命令行工具在开发时借鉴了AnacondaPIP等软件包管理的理念,可以方便快捷的完成模型的搜索、下载、安装、升级、预测等功能。
可点击Github的网址了解详情。
目前,PaddleHub的命令行工具支持以下12个命令:

  • install:用于将Module安装到本地,默认安装在{HUB_HOME}/.paddlehub/modules目录下;
  • uninstall:卸载本地Module
  • show:用于查看本地已安装Module的属性或者指定目录下确定的Module的属性,包括其名字、版本、描述、作者等信息;
  • download:用于下载百度提供的Module
  • search:通过关键字在服务端检索匹配的Module,当想要查找某个特定模型的Module时,使用search命令可以快速得到结果,例如hub search ssd命令,会查找所有包含了ssd字样的Module,命令支持正则表达式,例如hub search ^s.\* 搜索所有以s开头的资源;
  • list:列出本地已经安装的Module
  • run:用于执行Module的预测;
  • version:显示PaddleHub版本信息;
  • help:显示帮助信息;
  • clearPaddleHub在使用过程中会产生一些缓存数据,这部分数据默认存放在${HUB_HOME}/.paddlehub/cache目录下,用户可以通过clear命令来清空缓存;
  • autofinetune:用于自动调整Fine-tune任务的超参数,具体使用详情参考PaddleHub AutoDL Finetuner使用教程;
  • config:用于查看和设置Paddlehub相关设置,包括对server地址、日志级别的设置;
  • serving:用于一键部署Module预测服务,详细用法见PaddleHub Serving一键服务部署。

PaddleHub的产品理念是模型即软件,通过Python API或命令行实现模型调用,可快速体验或集成飞桨特色预训练模型。

此外,当用户想用少量数据来优化预训练模型时,PaddleHub也支持迁移学习,通过Fine-tune API,内置多种优化策略,只需少量代码即可完成预训练模型的Fine-tuning

3. PaddleHub提供的预训练模型

为了更好的应用PaddleHub的各种能力,我们需要知道PaddleHub集成了哪些模型。PaddleHub提供的预训练模型涵盖了图像分类、目标检测、视频分类、图像生成、图像分割、关键点检测、词法分析、语义模型、情感分析、文本审核等主流模型。PaddleHub的资源已有100多个分布在各领域的预训练模型,其中各领域均有百度独有数据训练或独有技术积累的模型,即只能在PaddleHub中找到的强大预训练模型,如 图4 所示。


图4:PaddleHub特色预训练模型

PaddleHub中集成的模型列表如下(持续扩充中):

  • NLP模型列表
    • 语义模型:word2vec_skipgram、simnet_bow、rbtl3、rbt3、Ernie_v2_eng_large、ernie_v2_eng_base、ernie_tiny、ERNIE、chinese-roberta-wwm-ext-large、chinese-roberta-wwm-ext、chinese-electra-small、chinese-electra-base、chinese-bert-wwm-ext、chinese-bert-wwm
    • 文本审核: porn_detection_lstm、 porn_detection_gru、 porn_detection_cnn
    • 词法分析:lac
    • 情感分析:senta_lstm、senta_gru、senta_cnn、senta_bow、senta_bilstm、emotion_detection_textcnn
  • CV模型列表
    • 图像分类:vgg、xception、shufflenetv2、se_resnet、resnet、resnet_vd、resnet_v2、pnasnet、mobilenet、inception_v4、Googlenet、efficientent、dpn、densent、darknet、alexnet
    • 关键点检测:pose_resnet50_mpii、face_landmark_localization
    • 目标检测:yolov3、ssd、Pyramidbox、faster_rcnn
    • 图像生成: StyleProNNet、stgan、cyclegan、attgan
    • 图像分割:deeplabv3、ace2p
    • 视频分类:TSN、TSM、stnet、nonlocal

4. 使用自己的数据Fine-tune PaddleHub预训练模型

果农需要根据水果的不同大小和质量进行产品的定价,所以每年收获的季节有大量的人工对水果分类的需求。基于人工智能模型的方案,收获的大堆水果会被机械放到传送带上,模型会根据摄像头拍到的图片,控制仪器实现水果的自动分拣,节省了果农大量的人力。


图5:水果在工厂传送带上自动分类

下面我们就看看如果采集到少量的桃子数据,如何基于PaddleHubImageNet数据集上预训练模型进行Fine-tune,得到一个更有效的模型。桃子分类数据集取自AI Studio公开数据集桃脸识别,该桃脸识别数据集中已经将所有桃子的图片分为2个文件夹,一个是训练集一个是测试集;每个文件夹中有4个分类,分别是B1M2R0S3


图6:自动分类结果示意

使用PaddleHub中的模型进行迁移学习的步骤如 图7 所示:


图7:PaddleHub模型迁移学习步骤

实现迁移学习,包括如下步骤:

  1. 安装PaddleHub
  2. 数据准备
  3. 模型准备
  4. 训练准备
  5. 组建Fine-tune Task
  6. 启动Fine-tune

在迁移学习的过程中,除了指定迁移学习的问题类型之外(通过选择模型的方式),还可以选择迁移学习的策略,以及对新收集样本做出数据增强的方法。

4.1 安装PaddleHub

paddlehub安装可以使用pip完成安装,如下:

# 安装并升级PaddleHub,使用清华源更稳定、更迅速
pip install paddlehub==1.6.1 -i https://pypi.tuna.tsinghua.edu.cn/simple 
pip install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
  • 1
  • 2
  • 3

4.2 数据准备

在本次教程提供的数据文件中,已经提供了分割好的训练集、验证集、测试集的索引和标注文件。如果用户利用PaddleHub迁移CV类任务使用自定义数据,则需要自行切分数据集,将数据集切分为训练集、验证集和测试集。需要三个文本文件来记录对应的图片路径和标签,此外还需要一个标签文件用于记录标签的名称。相关方法可参考用户自定义PaddleHub的数据格式。

├─data: 数据目录	├─train_list.txt:训练集数据列表	├─test_list.txt:测试集数据列表	├─validate_list.txt:验证集数据列表	├─label_list.txt:标签列表	└─……	
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

训练集、验证集和测试集的数据列表文件的格式如下,列与列之间以空格键分隔。

图片1路径 图片1标签	
图片2路径 图片2标签	
...
  • 1
  • 2
  • 3

label_list.txt的格式如下:

分类1名称	
分类2名称	
...	
  • 1
  • 2
  • 3
  • 4

准备好数据后即可使用PaddleHub完成数据读取器的构建,实现方法如下所示:构建数据读取Python类,并继承BaseCVDataset这个类完成数据读取器构建。只要按照PaddleHub要求的数据格式放置数据,均可以用这个数据读取器完成数据读取工作。

!unzip -q -o ./data/data34445/peach.zip -d ./work
  • 1
import paddlehub as hub
from paddlehub.dataset.base_cv_dataset import BaseCVDataset  #加载图像类自定义数据集,仅需要继承基类BaseCVDatast,修改数据集存放地址即可

class DemoDataset(BaseCVDataset):
def init(self):
# 数据集存放位置
self.dataset_dir = “./work/peach-classification” #dataset_dir为数据集实际路径,需要填写全路径
super(DemoDataset, self).init(
base_path=self.dataset_dir,
train_list_file=“train_list.txt”,
validate_list_file=“validate_list.txt”,
test_list_file=“test_list.txt”,
#predict_file=“predict_list.txt”, #如果还有预测数据(没有文本类别),可以将预测数据存放在predict_list.txt文件
label_list_file=“label_list.txt”,
# label_list=[“数据集所有类别”] #如果数据集类别较少,可以不用定义label_list.txt,可以选择定义label_list=[“数据集所有类别”]
)
dataset = DemoDataset()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

4.3 模型准备

我们要在PaddleHub中选择合适的预训练模型来Fine-tune,由于桃子分类是一个图像分类任务,这里采用Resnet50模型,并且是采用ImageNet数据集Fine-tune过的版本。这个预训练模型是在图像任务中的一个“万金油”模型,Resnet是目前较为有效的处理图像的网络结构,50层是一个精度和性能兼顾的选择,而ImageNet又是计算机视觉领域公开的最大的分类数据集。所以,在不清楚选择什么模型好的时候,可以优先以这个模型作为baseline

使用PaddleHub,不需要重新手写Resnet50网络,可以通过一行代码实现模型的调用。

#安装预训练模型
! hub install resnet_v2_50_imagenet
  • 1
  • 2
Downloading resnet_v2_50_imagenet
[==================================================] 100.00%
Uncompress /home/aistudio/.paddlehub/tmp/tmpn1e1enxo/resnet_v2_50_imagenet
[==================================================] 100.00%
Successfully installed resnet_v2_50_imagenet-1.0.1
  • 1
  • 2
  • 3
  • 4
  • 5
import paddlehub as hub

module = hub.Module(name=“resnet_v2_50_imagenet”) #加载Hub提供的图像分类的预训练模型resnet_v2_50_imagenet

  • 1
  • 2
  • 3
[2020-05-25 10:45:16,692] [    INFO] - Installing resnet_v2_50_imagenet module
[2020-05-25 10:45:16,710] [    INFO] - Module resnet_v2_50_imagenet already installed in /home/aistudio/.paddlehub/modules/resnet_v2_50_imagenet
  • 1
  • 2

将训练数据输入模型之前,我们通常还需要对原始数据做一些数据处理的工作,比如数据格式的规范化处理,或增加一些数据增强策略

构建图像分类模型的数据读取器(Reader),负责将桃子dataset的数据进行预处理,以特定格式组织并输入给模型进行训练。

如下数据处理策略,只做了两种操作:

  1. 指定输入图片的尺寸,并将所有样本数据统一处理成该尺寸。
  2. 对所有输入图片数据进行归一化处理。其中,需要通过参数指定上一步的dataset来链接到具体数据集,相当于在第一步的数据读取器上又包了一层处理策略。
data_reader = hub.reader.ImageClassificationReader(image_width=module.get_expected_image_width(),   #预期桃子图片经过reader处理后的图像宽度image_height=module.get_expected_image_height(), #预期桃子图片经过reader处理后的图像高度images_mean=module.get_pretrained_images_mean(), #进行桃子图片标准化处理时所减均值。默认为Noneimages_std=module.get_pretrained_images_std(),   #进行桃子图片标准化处理时所除标准差。默认为Nonedataset=dataset)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
[2020-05-25 10:45:19,933] [    INFO] - Dataset label map = {'R0': 0, 'B1': 1, 'M2': 2, 'S3': 3}
  • 1

4.4 训练准备

定义好模型,也设定好数据读取器后,我们就可以开始设置训练的策略。训练的配置使用hub.RunConfig函数完成,包括配置Fine-tune的轮数、Batchsize、评估的间隔等等,实现如下:

# Setup runing config for PaddleHub Finetune API
config = hub.RunConfig(use_cuda=True,                                                #是否使用GPU训练,默认为False;num_epoch=1,                                                  #Fine-tune的轮数;checkpoint_dir="cv_finetune_turtorial_demo",                  #模型checkpoint保存路径, 若用户没有指定,程序会自动生成;batch_size=32,                                                #训练的批大小,如果使用GPU,请根据实际情况调整batch_size;eval_interval=50,                                             #模型评估的间隔,默认每100个step评估一次验证集;strategy=hub.finetune.strategy.DefaultFinetuneStrategy())     #Fine-tune优化策略;
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
[2020-05-25 10:45:22,634] [    INFO] - Checkpoint dir: cv_finetune_turtorial_demo
  • 1

4.5 组建Fine-tune Task

有了合适的预训练模型,并准备好要迁移的数据集后,我们开始组建一个Task。在PaddleHub中,Task代表了一个Fine-tune的任务。任务中包含了执行该任务相关的Program、数据读取器Reader、运行配置等内容。PaddleHub预置了常见任务的Task,每种Task都有特定的应用场景并提供了对应的度量指标,满足用户的不同需求。在这里可以找到图像分类任务的对应说明ImageClassifierTask。

由于桃子分类是一个四分类的任务,而我们下载的分类module是在ImageNet数据集上训练的1000分类模型。所以需要对模型进行简单的微调,即将最后一层1000分类全连接层改成4分类的全连接层,并重新训练整个网络。实现方案如下:

  1. 获取modulePaddleHub的预训练模型)的上下文环境,包括输入和输出的变量,以及Paddle Program(可执行的模型格式)。
  2. 从预训练模型的输出变量中找到特征图提取层feature_map,在feature_map后面接入一个全连接层,如下代码中通过hub.ImageClassifierTaskfeature_map参数指定。
  3. 网络的输入层保持不变,依然从图像输入层开始,如下代码中通过hub.ImageClassifierTask的参数feed_list变量指定。

hub.ImageClassifierTask就是通过这两个参数明确我们的截取骨干网络的要求,按照这样的配置,我们截取的网络是从输入层“image”一直到特征提取的最后一层“feature_map”

input_dict, output_dict, program = module.context(trainable=True) #获取module的上下文信息包括输入、输出变量以及paddle program

img = input_dict[“image”] #待传入图片格式

feature_map = output_dict[“feature_map”] #从预训练模型的输出变量中找到最后一层特征图,提取最后一层的feature_map

feed_list = [img.name] #待传入的变量名字列表

task = hub.ImageClassifierTask(
data_reader=data_reader, #提供数据的Reader
feed_list=feed_list, #待feed变量的名字列表
feature=feature_map, #输入的特征矩阵
num_classes=dataset.num_labels, #分类任务的类别数量
config=config) #运行配置

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
[2020-05-25 10:45:26,755] [    INFO] - 267 pretrained paramaters loaded by PaddleHub
  • 1

4.6 启动Fine-tune

最后,使用Finetune_and_eval函数可以同时完成训练和评估。在Fine-tune的过程中,控制台会周期性打印模型评估的效果,以便我们了解整个训练过程的精度变化。

run_states = task.finetune_and_eval() #通过众多finetune API中的finetune_and_eval接口,可以边训练,边打印结果
  • 1
[2020-05-25 10:45:34,892] [    INFO] - Strategy with slanted triangle learning rate, L2 regularization, 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/executor.py:804: UserWarning: There are no operators in the program to be executed. If you pass Program manually, please use fluid.program_guard to ensure the current Program is being used.warnings.warn(error_info)
[2020-05-25 10:45:37,003] [    INFO] - Try loading checkpoint from cv_finetune_turtorial_demo/ckpt.meta
[2020-05-25 10:45:37,952] [    INFO] - PaddleHub model checkpoint loaded. current_epoch=2, global_step=188, best_score=0.99760
[2020-05-25 10:45:37,953] [    INFO] - PaddleHub finetune start
[2020-05-25 10:45:37,954] [    INFO] - PaddleHub finetune finished.
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

Fine-tune完成后,我们使用模型来进行预测,实现如下:

import numpy as np

data = ["./work/peach-classification/test/M2/0.png"] #传入一张测试M2类别的桃子照片
task.predict(data=data,return_result=True) #使用PaddleHub提供的API实现一键结果预测,return_result默认结果是False

  • 1
  • 2
  • 3
  • 4
[2020-05-25 10:45:40,748] [    INFO] - PaddleHub predict start
[2020-05-25 10:45:40,749] [    INFO] - Load the best model from cv_finetune_turtorial_demo/best_model
[2020-05-25 10:45:42,838] [    INFO] - PaddleHub predict finished.
['M2']
  • 1
  • 2
  • 3
  • 4

以上为加载模型后实际预测结果(这里只测试了一张图片),返回的是预测的实际效果,可以看到我们传入待预测的是M2类别的桃子照片,经过Fine-tune之后的模型预测的效果也是M2,由此成功完成了桃子分类的迁移学习。

5. 相关参考链接

  • PaddleHub 官网链接:https://www.paddlepaddle.org.cn/hub
  • PaddleHub Github链接:https://github.com/PaddlePaddle/PaddleHub
  • PaddleHub 课程链接:https://aistudio.baidu.com/aistudio/course/introduce/1070
                                </div>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/479762.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NIPS’20 Spotlight | 精准建模用户兴趣,广告CTR预估准确率大幅提升!

源 | 京东零售技术在以人工智能技术为支持的推荐、搜索、广告等业务中&#xff0c;点击率预估&#xff08;CTR&#xff09;一直是技术攻坚的核心&#xff0c;同时也是人工智能技术在业务落地中最难实现的技术方向之一。第一期介绍了视觉信息使用帮助提高点击率预估的准确度&…

史上最强大型分布式架构详解:高并发+数据库+缓存+分布式+微服务+秒杀

分布式架构设计是成长为架构师的必备技能&#xff0c;涵盖的内容很广&#xff0c;今天一次打包分享&#xff0c;文末有&#xff1a;最全分布式架构设计资料获取方式~ 负载均衡 负载均衡的原理和分类 负载均衡架构和应用场景 分布式缓存 常见分布式缓存比较&#xff1a;memcac…

论文浅尝 | 面向多语言语义解析的神经网络框架

论文笔记整理&#xff1a;杜昕昱&#xff0c;东南大学本科生。来源&#xff1a;ACL2017链接&#xff1a;https://aclweb.org/anthology/P17-2007论文训练了一个多语言模型&#xff0c;将现有的Seq2Tree模型扩展到一个多任务学习框架&#xff0c;该框架共享用于生成语义表示的解…

LeetCode 46. 全排列(回溯)

文章目录1. 题目信息2. 解题2.1 利用hash map解决2.2 改用bool数组判断是否出现过1. 题目信息 给定一个没有重复数字的序列&#xff0c;返回其所有可能的全排列。 示例:输入: [1,2,3] 输出: [[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1] ]来源&#xff1a;力扣&#xf…

谷歌师兄的刷题笔记分享!

高畅现在是谷歌无人车部门&#xff08;Waymo&#xff09;的工程师&#xff0c;从事计算机视觉和机器学习方向。他在美国卡内基梅隆大学攻读硕士学位时&#xff0c;为了准备实习秋招&#xff0c;他从夏天开始整理某 code 上的题目&#xff0c;几个月的时间&#xff0c;刷了几百道…

【深度揭秘】百度、阿里、腾讯内部岗位级别和薪资结构,附带求职建议!

“ 最近很忙&#xff0c;文章没有及时更新。。 最近被问得最多就是想进入BAT等一线互联网公司&#xff0c;应该怎么办&#xff1f; 我先从BAT等这样的公司看看他们的招聘需求谈起&#xff0c;再结合这样的公司需要对技术的要求是什么&#xff0c;最后结合我的建议&#xff0…

LeetCode 47. 全排列 II(回溯+搜索剪枝)

文章目录1. 题目信息2. 解题1. 题目信息 给定一个可包含重复数字的序列&#xff0c;返回所有不重复的全排列。 示例:输入: [1,1,2] 输出: [[1,1,2],[1,2,1],[2,1,1] ]来源&#xff1a;力扣&#xff08;LeetCode&#xff09; 链接&#xff1a;https://leetcode-cn.com/problem…

会议 | 2019 全国知识图谱与语义大会 (CCKS 2019)

会议注册:http://www.ccks2019.cn/?page_id53会议地址与住宿&#xff1a;http://www.ccks2019.cn/?page_id366OpenKG开放知识图谱&#xff08;简称 OpenKG&#xff09;旨在促进中文知识图谱数据的开放与互联&#xff0c;促进知识图谱和语义技术的普及和广泛应用。点击阅读原文…

jieba分词太慢,怎么办?找jieba_fast

原文链接&#xff1a;https://www.rtnzero.com/archives/272.html 有时候感觉处理一个几十M的文本&#xff0c;要一分钟才能好&#xff0c;然后调试时各种心焦&#xff01; 下面举个例子&#xff1a; 归零有一个11.9M的文本文件&#xff0c;是一些抓取到的Python长尾关键词&am…

DGL_图的打印

首先要安装 networkx import matplotlib.pyplot as plt import networkx as nx import dgl import numpy as np def build_karate_club_graph():src np.array([1, 2, 2, 3, 3])dst np.array([0, 0, 1, 0, 1])u np.concatenate([src, dst])v np.concatenate([dst, src])ret…

闲鱼账号被封怎么办?解封看这里!

怎样避免宝贝被屏蔽、限流解封账号&#xff1f;首先我们要学会规避封号的风险 不要频繁的更改账号&#xff0c;不要多账号单手机操作&#xff0c;一机一号才是正确。 不要连续给人商品点赞或是我想要&#xff0c;连续的操作容易被封 不要发布违禁品&#xff0c;违禁品具体可…

推荐系统顶会RecSys’20亮点赏析

文 | banana源 | 知乎RecSys 2020原计划是在南美洲巴西举办&#xff0c;因为疫情的原因不得不改到线上。虽说线上举办会议&#xff0c;参会效果会打折扣&#xff0c;但也为远在北京的我提供了参会便利。得益于各方的努力和软件的应用&#xff0c;整体来看此次参会的效果高于我对…

技术研讨会 | 2019 恒生技术开放日产业链知识图谱专场开始报名

知识图谱旨在采用图结构 (Graph Structure) 来建模和记录世界万物之间的关联关系和知识&#xff0c;是互联网时代的知识工程方法&#xff0c;能够对纷繁复杂、多源异构的金融资讯大数据进行加工整合&#xff0c;提升决策分析的效率&#xff0c;已经得到金融行业从业人士的普遍认…

01.神经网络和深度学习 W1.深度学习概论

文章目录1. 什么是神经网络2. 使用神经网络进行监督学习3. 神经网络的兴起4. 练习题1. 什么是神经网络 它是一个强大的学习算法&#xff0c;类似于人脑的工作方式。 例子1. 单个神经网络 给定房地产市场上房屋大小的数据&#xff0c;预测其价格。这是一个线性回归问题。 …

中文任务型对话系统中的领域分类

大规模跨领域中文任务导向多轮对话数据集及模型CrossWOZ&#xff1a;项目地址&#xff1a;https://gitee.com/yh14232988/CrossWOZ?_fromgitee_search 具体介绍&#xff1a;https://cloud.tencent.com/developer/article/1617197 北邮张庆恒&#xff1a;如何基于 rasa 搭建一…

互联网热门职位薪酬报告

“ 很多同学毕业后想进入互联网领域&#xff0c;当前有什么热门的互联网工作机会&#xff0c;薪资结构怎么样&#xff1f;看图说话&#xff0c;我简短给 大家做一个回报。 互联网职位需求最热的TOP20 mikechen&#xff1a;我个人比较看好旅游、金融板块、医疗健康板块&#x…

算法岗面试前怎样高效刷题?

如果不是为了面试AI工程师刷题有用吗&#xff1f;把时间都放在项目上不香嘛&#xff1f;作为一个战五渣&#xff0c;我特地去观察和询问了身边很多精通此道的大神&#xff0c;他们对于“刷题”还是保持着认可的态度&#xff1a;很清晰地理解问题的本质&#xff0c;并进行合理的…

征稿 | JIST 2019 Regular Technical Papers

JIST 2019: The 9th Joint International Semantic Technology ConferenceNov. 25-27, 2019, Hangzhou, China.http://jist2019.openkg.cn/第 9 届国际语义技术联合会议 JIST 2019 将于今年 11 月在美丽的杭州召开&#xff0c;投稿截止日期临近 (Abstract submission: 23:59 (H…

DGL_子图

用途一&#xff1a;数据集太大&#xff0c;无法画图&#xff0c;取子图看看是有向图/无向图 import dgl import matplotlib.pyplot as plt import networkx as nx G dgl.DGLGraph() G.add_nodes(5) # G.add_edges([0, 1, 2, 3, 4], [1, 2, 3, 4, 0]) # 有向图 G.add_edges(…

史上最全互联网八大技术岗位详解

“互联网技术岗位详解&#xff0c;涉及到前段开发、后端开发、移动端开发、大数据、项目管理、测试、运维、技术管理等八大领域。 架构师 每个产品线都有架构师&#xff0c;在技术平台部门也需要技术平台的架构师。 架构师负责设计系统整体架构&#xff0c;从需求到设计的每个…