《pytorch深度学习实战》学习笔记第2章

第2章 预训练网络

讨论3种常用的预训练模型:

        1、根据内容对图像进行标记(识别)

        2、从真实图像中生成新图像(GAN)

        3、使用正确的英语句子来描述图像内容(自然语言)

2.1 获取一个预训练好的网络用于图像识别

ImageNet数据集,用于大规模视觉识别挑战赛。

所有预训练好的模型都在TorchVision中。

2.1.1 导入已有的模型

所有模型都在torchvison的models中。导入并查看。

from torchvision import models
dir(models)

输出的是所有torchvison里面集成的模型框架。其中首字母大写的是一些流行的模型小写的名字是快捷函数,返回实例化模型函数

1.1.1 AlexNet模型

实例化AlexNet。

alexnet=models.AlexNet()
alexnet

可以像函数一样调用它。给alexnet输入数据,就会通过正向传播(forward pass)得到输出。比如output=alexnet(input)。由于网络没有初始化,没有经过训练。所以一般先要将模型从头训练或者加载训练好的网络。然后再调用。

1.1.2 Resnet模型

(1)加载在ImageNet数据集上训练好的权重,来实例化ResNet101
resnet=models.resnet101(pretrained=True)
resnet

然后就开始下载,下载完成后查看resnet101的结构。

神经网络由许多模块构成,包含过滤器和非线性函数,fc层结束,输出每个类的分数。

预训练好的模型可以跟函数一样调用,并输入图片实现预测。

(2)定义预处理函数:
from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
])

预处理包括:图像缩放到256*256像素,围绕中心裁剪到224*224像素,转为张量,归一化处理,使用定义的均值和标准差。

(3)导入图片并进行预处理

导入一张狗的照片并显示

from PIL import Image
img = Image.open('bobby.jpg')
img

待用预处理函数对图片进行预处理。

img_t=preprocess(img)
img_t.shape

输出为一个3维的张量。

给张量前面再增加一个维度。

import torch
batch_t = torch.unsqueeze(img_t,0)#加一个维度,数字0代表增加在第0维前面,如果为1就代表维度1前面
batch_t.shape

输出在第0维前面增加了一个1.

(4)运行模型

在新数据上运行训练过的模型的过程被称为推理(inference),为了推理需要先将网络放到eval模式。执行代码:

resnet.eval()

进行推理:

out=resnet(batch_t)
out

产生了一个1000分类的向量,每个ImageNet对应一个分数。

(5)查看预测结果

加载定义好的ImageNet标签。

with open('imagenet_classes.txt') as f:labels = [line.strip() for line in f.readlines()]

需要找出out输出在labels标签中的索引。可以利用max()函数输出张量中最大值以及最大值的索引。代码如下:

_,index=torch.max(out,1)
index

输出的索引不是一个数字,而是一个一维张量。

使用index[0]获得实际的数字作为标签列表的索引,用torch.nn.functional.softmax()将输出归一化到[0,1],然后除以总和。可以求出模型在预测中的置信度。

代码:

percentage = torch.nn.functional.softmax(out,dim=1)[0]*100
labels[index[0]],percentage[index[0]].item()

输出:('golden retriever', 96.29335021972656)

分类结果维金毛犬,置信度为96%。

也可以对预测结果的其它值进行排序输出。比如输出前5个。

_,indices = torch.sort(out,descending=True)
[(labels[idx],percentage[idx].item()) for idx in indices[0][:5]]

输出:

2.2 一个足以以假乱真的预训练模型

GAN是生成式对抗网络(generative adversarial network)的缩写。

cycleGAN是循环生成式对抗网络的缩写,可以将一个领域的图像转换为另一个领域的图像。

2.2.1 将马变为斑马的网络

CycleGAN从ImageNet数据集中提取的马和斑马的数据集进行训练。该网络学习获取一匹或多匹马的图像,并将它们全部变成斑马,图像的其余部分尽可能不被修改。

使用预训练好的CycleGAN将使我们有机会更进一步了解网络是如何实现的,对于本例就是生成器。

(1)以ResNet为例,定义一个ResNetGenerator类。

import torch
import torch.nn as nnclass ResNetBlock(nn.Module): # <1>def __init__(self, dim):super(ResNetBlock, self).__init__()self.conv_block = self.build_conv_block(dim)def build_conv_block(self, dim):conv_block = []conv_block += [nn.ReflectionPad2d(1)]conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),nn.InstanceNorm2d(dim),nn.ReLU(True)]conv_block += [nn.ReflectionPad2d(1)]conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),nn.InstanceNorm2d(dim)]return nn.Sequential(*conv_block)def forward(self, x):out = x + self.conv_block(x) # <2>return outclass ResNetGenerator(nn.Module):def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3> assert(n_blocks >= 0)super(ResNetGenerator, self).__init__()self.input_nc = input_ncself.output_nc = output_ncself.ngf = ngfmodel = [nn.ReflectionPad2d(3),nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),nn.InstanceNorm2d(ngf),nn.ReLU(True)]n_downsampling = 2for i in range(n_downsampling):mult = 2**imodel += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,stride=2, padding=1, bias=True),nn.InstanceNorm2d(ngf * mult * 2),nn.ReLU(True)]mult = 2**n_downsamplingfor i in range(n_blocks):model += [ResNetBlock(ngf * mult)]for i in range(n_downsampling):mult = 2**(n_downsampling - i)model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=2,padding=1, output_padding=1,bias=True),nn.InstanceNorm2d(int(ngf * mult / 2)),nn.ReLU(True)]model += [nn.ReflectionPad2d(3)]model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]model += [nn.Tanh()]self.model = nn.Sequential(*model)def forward(self, input): # <3>return self.model(input)

(2)实例化

netG = ResNetGenerator()

权重为随机权重。

(3)将预训练好的权重添加到ReNet Generator中。

model_path='horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)

执行后,netG就获得了训练中需要的所有知识。

(4)推理

netG.eval()

输出:

程序是将一匹或多匹马逐像素修改。

导入随机马的图像进行测试。

导入需要的库:

from PIL import Image
from torchvision import transforms

定义预处理函数:

preprocess = transforms.Compose([transforms.Resize(256),transforms.ToTensor()
])

导入马的图片:

img = Image.open('horse.jpg')
img

对图片预处理:

img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t,0)

将变量传递给模型:

batch_out = netG(batch_t)

将生成器的输出转换为图像。

out_t = (batch_out.data.squeeze()+1.0)/2.0
out_img = transforms.ToPILImage()(out_t)
out_img

2.6 练习题

1.将金毛猎犬的图像输入马-斑马模型中。

参考资料:

1. 预训练网络 · 深度学习与PyTorch(中文版) (paper2fox.github.io)

4. PyTorch深度学习 Deep Learning with PyTorch ch.2, p2_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/797619.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通过 Cookie、Redis共享Session 和 Spring 拦截器技术,实现对用户登录状态的持有和清理(三)

本篇内容对应 “2.4 生成验证码” 小节 和 “4.7 优化登陆模块”小节 视频链接 1 Kaptcha介绍 Kaotcga是一个生成验证码的工具。 你的网站验证码是什么&#xff1f; 在我们这个牛客论坛项目&#xff0c;验证码分为两部分 给用户看的是图片&#xff0c;用户根据图片上显示的…

什么牌子开放式耳机好用?优选五大高分好物真诚分享

对于习惯长时间佩戴耳机的朋友来说&#xff0c;入耳式耳机固然能够提供较优质的音质体验。但是&#xff0c;由于其较为封闭的设计以及对耳洞的压迫&#xff0c;舒适感较差&#xff0c;长时间佩戴可能会对听力造成一定的影响。因此&#xff0c;开放式耳机的出现为音乐发烧友们提…

Leetcode 684. 冗余连接

心路历程&#xff1a; 这道题属于图论的经典连通问题&#xff0c;这道题翻译过来就是&#xff0c;找到破开图中环的一条边&#xff1b;再翻译过来就是&#xff0c;从后往前遍历edges&#xff0c;依次连接边&#xff0c;当发现新连接的边已经有相同父节点时&#xff08;已经马上…

基于单片机风力发电机迎风面对风向的追踪系统设计

**单片机设计介绍&#xff0c;基于单片机风力发电机迎风面对风向的追踪系统设计 文章目录 一 概要二、功能设计三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机风力发电机迎风面对风向的追踪系统设计是一个涉及单片机编程、传感器技术、机械控制等多个领域的综…

java日志框架简介

文章目录 概要常用日志框架常见框架有以下&#xff1a;slf4j StaticLoggerBinder绑定过程&#xff08;slf4j-api-1.7.32 &#xff09;JCL 运行时动态查找过程&#xff1a;&#xff08;commons-logging-1.2&#xff09;使用桥接修改具体日志实现 一行日志的打印过程开源框架日志…

面试算法-153-旋转图像

题目 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像&#xff0c;这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 示例 1&#xff1a; 输入&#xff1a;matrix [[1,2,3],[4,5,6],[7,8,…

Java项目:基于Springboot+vue实现的医院住院管理系统设计与实现(源码+数据库+开题报告+任务书+毕业论文)

一、项目简介 本项目是一套基于Springbootvue实现的医院住院管理系统设 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&#xff0c;eclipse或者idea 确保可以运行&#xff01; 该系统功能完善、界面美观、操作简…

基于Springboot+Vue实现前后端分离社团管理系统

一、&#x1f680;选题背景介绍 &#x1f4da;推荐理由&#xff1a; 21世纪时信息化的时代&#xff0c;几乎任何一个行业都离不开计算机&#xff0c;将计算机运用于社团管理也是十分常见的。过去使用手工的管理方式对大学生社团进行管理&#xff0c;造成了管理繁琐、难以维护等…

基于java+SpringBoot+Vue的房屋租赁系统设计与实现

基于javaSpringBootVue的房屋租赁系统设计与实现 开发语言: Java 数据库: MySQL技术: Spring Boot JSP工具: IDEA/Eclipse、Navicat、Maven 系统展示 前台展示 房源浏览模块&#xff1a;展示可租赁的房源信息&#xff0c;用户可以根据条件筛选房源。 预约看房模块&#…

java项目基于Springboot和Vue的高校心理教育辅导系统的设计与实现

今天要和大家聊的是基于Springboot和Vue的高校心理教育辅导系统的设计与实现 &#xff01;&#xff01;&#xff01; 有需要的小伙伴可以通过文章末尾名片咨询我哦&#xff01;&#xff01;&#xff01; &#x1f495;&#x1f495;作者&#xff1a;李同学 &#x1f495;&…

springboot实战---5.最简单最高效的后台管理系统开发

&#x1f388;个人主页&#xff1a;靓仔很忙i &#x1f4bb;B 站主页&#xff1a;&#x1f449;B站&#x1f448; &#x1f389;欢迎 &#x1f44d;点赞✍评论⭐收藏 &#x1f917;收录专栏&#xff1a;SpringBoot &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&…

安达发|APS软件在皮具箱包生产工艺中的应用

APS软件&#xff0c;即高级生产计划排程系统&#xff08;Advanced Planning and Scheduling&#xff09;&#xff0c;在皮具箱包生产工艺中的应用至关重要。它通过高效的生产计划和资源优化&#xff0c;帮助企业降低成本、提高生产效率和市场响应速度。以下是APS软件在皮具箱包…

day03-Docker

1.初识 Docker 1.1.什么是 Docker 1.1.1.应用部署的环境问题 大型项目组件较多&#xff0c;运行环境也较为复杂&#xff0c;部署时会碰到一些问题&#xff1a; 依赖关系复杂&#xff0c;容易出现兼容性问题开发、测试、生产环境有差异 例如一个项目中&#xff0c;部署时需要依…

代码随想录学习Day 24

93.复原IP地址 题目链接 讲解链接 本题属于切割问题&#xff0c;切割问题需要使用回溯算法来将所有的结果搜索出来&#xff0c;与前一题分割回文串是类似的。本题的树形结构如下图所示&#xff1a; 回溯三部曲&#xff1a; 1.递归函数参数及返回值&#xff1a;参数为待分割…

2012年认证杯SPSSPRO杯数学建模D题(第一阶段)人机游戏中的数学模型全过程文档及程序

2012年认证杯SPSSPRO杯数学建模 D题 人机游戏中的数学模型 原题再现&#xff1a; 计算机游戏在社会和生活中享有特殊地位。游戏设计者主要考虑易学性、趣味性和界面友好性。趣味性是本质吸引力&#xff0c;使玩游戏者百玩不厌。网络游戏一般考虑如何搭建安全可靠、丰富多彩的…

JVM高级篇之GC

文章目录 版权声明垃圾回收器的技术演进ShenandoahShenandoah GC体验Shenandoah GC循环过程 ZGCZGC简介ZGC的版本更迭ZGC体验&使用ZGC的参数设置ZGC的调优 版权声明 本博客的内容基于我个人学习黑马程序员课程的学习笔记整理而成。我特此声明&#xff0c;所有版权属于黑马…

【C++】拆分详解 - 内存管理

文章目录 前言一、C/C内存分布二、C语言中动态内存管理方式&#xff1a;malloc/calloc/realloc/free三、C内存管理方式  3.1 new/delete操作内置类型  3.2 new和delete操作自定义类型  3.3 operator new与operator delete函数 四、new和delete的实现原理  4.1 内置类型…

【微服务】SpringCloud之Feign远程调用

&#x1f3e1;浩泽学编程&#xff1a;个人主页 &#x1f525; 推荐专栏&#xff1a;《深入浅出SpringBoot》《java对AI的调用开发》 《RabbitMQ》《Spring》《SpringMVC》《项目实战》 &#x1f6f8;学无止境&#xff0c;不骄不躁&#xff0c;知行合一 文章目录 …

Solo 开发者周刊 (第10期):Sora 之后,谁是被遗忘的?谁又是被仰望的?

这里会整合 Solo 社区每周推广内容、产品模块或活动投稿&#xff0c;每周五发布。在这期周刊中&#xff0c;我们将深入探讨开源软件产品的开发旅程&#xff0c;分享来自一线独立开发者的经验和见解。本杂志开源&#xff0c;欢迎投稿。 好文推荐 Solo 社区 x 机器之心-再谈复现 …

如何利用HubSpot 出海CRM实现精准海外客户定位与拓展?

在当今全球化的商业环境中&#xff0c;企业寻求海外市场的拓展已成为增长的重要策略。然而&#xff0c;海外市场的复杂性和多样性为企业带来了巨大的挑战。为了有效地定位和拓展海外客户&#xff0c;许多企业选择了HubSpot 出海CRM作为他们的营销和销售管理工具。今天运营坛将带…