创建和探索VGG16模型

        PyTorch在torchvision库中提供了一组训练好的模型。这些模型大多数接受一个称为 pretrained 的参数,当这个参数为True 时,它会下载为ImageNet 分类问题调整好的权重。让我们看一下创建 VGG16模型的代码片段:

from torchvision import models
vgg = models.vggl6(pretrained=True)

        现在有了所有权重已经预训练好且可马上使用的VGG16模型。当代码第一次运行时,可能需要几分钟,这取决于网络速度。权重的大小可能在500MB左右。我们可以通过打印快速查看下 VGG16模型。当使用现代架构时,理解这些网络的实现方式非常有用。我们来看看这个模型:

VGG((features): Sequential((0):Conv2d(3,64,kernel_size=(3,3),stride=(1,1),padding=(1,1))(1):ReLU (inplace)(2):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1))(3):ReLU(inplace)(4):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1))(5):Conv2d(64,128,kernel_size=(3,3),stride=(1,1),padding=(1,1))(6):ReLU(inplace)(7):Conv2d(128,128,kernel_size=(3,3),stride=(1,1),padding=(1,1))(8):ReLU(inplace)(9):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1))(10):Conv2d(128,256,kernel_size=(3,3),stride=(1,1),padding=(1,1))(11):ReLU(inplace)(12):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1))(13):ReLU(inplace)(14):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1))(15):ReLU(inplace)(16):MaxPool2d(size=(2,2),stride=(2,2)dilation=(1,1))(17):Conv2d(256,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(18):ReLU(inplace)(19):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(20):ReLU(inplace)(21):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(22):ReLU(inplace)(23):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1))(24):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(25):ReLU(inplace)(26):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(27):ReLU(inplace)(28):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))(29):ReLU(inplace)(30):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1)))(classifier):Sequential((0):Linear(25088>4096)(1):ReLU(inplace)(2):Dropout(p=0.5)(3):Linear(4096->4096)(4):ReLU (inplace)(5):Dropout(p=0.5)(6):Linear(4096>1000))
)

        模型摘要包含了两个序列模型:features和classifiers。features和sequentia1模型包含了将要冻结的层。

冻结层

        下面冻结包含卷积块的features模型的所有层。冻结层中的权重将阻止更新这些卷积块的权重。由于模型的权重被训练用来识别许多重要的特征,因而我们的算法从第一个迭代开时就具有了这样的能力。使用最初为不同用例训练的模型权重的能力,被称为迁移学习。现在看一下如何冻结层的权重或参数:

for param in vgg.features.parameters():param.requires_grad = False

        该代码阻止优化器更新权重。

微调VGG16模型

        VGG16模型被训练为针对1000个类别进行分类,但没有训练为针对狗和猫进行分类。因此,需要将最后一层的输出特征从1000改为2。以下代码片段执行此操作:

vgg.classifier[6].out_features = 2

        vgg.classifier可以访问序列模型中的所有层,第6个元素将包含最后一个层。当训练VGG16模型时,只需要训练分类器参数。因此,我们只将classifier.parameters传入优化器,如下所示:

optimizer=
optim.SGD(vgg.classifier.parameters(),lr=0.0001,momentum=0.5)

训练VGG16模型

        我们已经创建了模型和优化器。由于使用的是Dogs vs. Cats数据集,因此可以使用相同的数据加载器和train函数来训练模型。请记住,当训练模型时,只有分类器内的参数会发生变化。下面的代码片段对模型进行了20轮的训练,在验证集上达到了98.45%的准确率:

train_losses, train_accuracy =[],[]
val_losses, val_accuracy =[],[]
for epoch in range(l,20):epoch_loss,epoch_accuracy=fit(epoch,vgg,train_data_loader,phase='training')val_epoch_loss,val_epoch_accuracy=fit(epoch,vgg,valid_data_loader,phase='validation')train_losses.append(epoch_loss)train_accuracy.append(epoch_accuracy)val_losses.append(val_epoch_loss)val_accuracy.append(val_epoch_accuracy)

        将训练和验证的损失可视化,如图5.19所示。

        将训练和验证的准确率可视化,如图5.20所示:

        我们可以应用一些技巧,例如数据增强和使用不同的dropout值来改进模型的泛化能力。以下代码片段将 VGG分类器模块中的dropout值从0.5更改为0.2并训练模型:

for layer in vgg.classifier.children():if(type(layer)== nn.Dropout):layer.p=0.2
#训练
train_losses,train_accuracy = [][]
val_losses, val accuracy =[],[ ]
for epoch in range(1,3):epoch_loss,epoch_accuracy=fit(epoch,vgg,train_data_loader,phase='training')val_epoch_loss,val_epoch_accuracy=fit(epoch,vgg,valid_data_loader,phase='validation')train_losses.append(epoch_loss)train_accuracy.append(epoch_accuracy)val_losses.append(val_epoch_loss)val_accuracy.append(val_epoch_accuracy)

        通过几轮的训练,模型得到了些许改进。还可以尝试使用不同的dropout值。改进模型泛化能力的另一个重要技巧是添加更多数据或进行数据增强。我们将通过随机地水平翻转图像或以小角度旋转图像来进行数据增强。torchvision转换为数据增强提供了不同的功能,它们可以动态地进行,每轮都发生变化。我们使用以下代码实现数据增强:

train transform =transforms.Compose([transforms,Resize((224,224)),transforms.RandomHorizontalFlip(),transforms.RandomRotation(0.2),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
train = ImageFolder('dogsandcats/train/',train_transform)
valid = ImageFolder('dogsandcats/valid/',simple_transform)
#训练
train_losses,train_accuracy=[][]
val_losses,val_accuracy = [],[]
for epoch in range(1,3):epoch_loss,epoch_accuracy=fit(epoch,vgg,train_data_loader,phase='training')val_epoch_loss,val_epoch_accuracy=fit(epoch,vgg,valid_data_loader,phase='validation')train_losses.append(epoch_loss)train_accuracy.append(epoch_accuracy)val_losses.append(val_epoch_loss)val_accuracy.append(val_epoch_accuracy)

        前面的代码输出如下:

#结果
training loss is 0.041 and training accuracy is 22657/23000 98.51
validation loss is 0.043 and validation accuracy is 1969/2000 98.45
training loss is 0.04 and training accuracy is 22697/23000 98.68 
validation loss is 0.043 and validation accuracy is 1970/2000 98.5

        使用增强数据训练模型仅运行两轮就将模型准确率提高了0.1%;可以再运行几轮以进一步改进模型。如果大家在阅读本书时一直在训练这些模型,将意识到每轮的训练可能需要几分钟,具体取决于运行的GPU。让我们看一下可以在几秒钟内训练一轮的技术。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/34354.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JavaScript脚本宇宙】加速您的网站:图像优化工具和库的终极指南

别让大图拖垮你的应用:如何正确优化图像 前言 在数字时代,图像是我们日常生活中不可或缺的一部分。然而,随着图像数量的增加和分辨率的提高,它们也占据了越来越多的存储空间和带宽。为了解决这个问题,开发人员可以使…

什么美业系统好用?美业门店收银系统源码分享、小程序展示

专业美业系统与普通系统相比,更加贴合美业门店的经营需求,提供了更全面、便捷、高效的管理功能,有助于提升门店的服务质量和经营效益。 博弈美业系统包括PC、iPad、手机、小程序四大端口,满足不同人群的各种需求。客户可从小程序…

python并行批量存储mat文件

输入:包含数组的列表arrays_list,以及包含每个数组存储位置的列表save_path_list from concurrent.futures import ThreadPoolExecutor, as_completed from tqdm import tqdm from scipy.io import * def save_array_to_mat(array, filepath):savemat(f…

有什么能和ai聊天的软件?5个软件教你快速和ai进行聊天

有什么能和ai聊天的软件?5个软件教你快速和ai进行聊天 当今数字化时代,人工智能(AI)技术已经逐渐渗透到我们的日常生活中,而与AI进行聊天也成为了一种趋势和乐趣。以下是五款可以和AI进行聊天的软件,它们提…

如何提高台式扫描电镜的放大倍数

台式扫描电镜(SEM)因其紧凑的设计和高效的成像能力,在材料科学、生物学和纳米技术等领域中扮演着重要角色。然而,用户在使用过程中可能会遇到需要更高放大倍数以获得更细微结构图像的情况。以下是一些提高台式扫描电镜放大倍数的策…

大厂面试官问我:Redis持久化RDB有没有可能阻塞?阻塞点在哪里?【后端八股文三:Redis持久化八股文合集】

往期内容: 大厂面试官问我:Redis处理点赞,如果瞬时涌入大量用户点赞(千万级),应当如何进行处理?【后端八股文一:Redis点赞八股文合集】-CSDN博客 大厂面试官问我:布隆过滤…

与其他自动化配置管理工具(如 Ansible 、Chef )相比,Puppet 的独特优势和局限性分别是什么?

Puppet的独特优势包括: 基于声明式语言:Puppet使用自己的声明式语言(Puppet DSL)来描述系统配置,使得配置更加简洁、易于理解和维护。 完善的资源模型:Puppet具有丰富的资源模型,可以管理各种不…

C++ 入门

前言 c的发展史: C的起源可以追溯到1979年,当时Bjarne Stroustrup在贝尔实验室开始开发一种名为“C with Classes”的语言。以下是C发展的几个关键阶段: 1979年:Bjarne Stroustrup在贝尔实验室开始开发“C with Classes”。1983…

鸿蒙NEXT,保障亿万中国老百姓数据安全的操作系统

吉祥学安全知识星球🔗除了包含技术干货:Java代码审计、web安全、应急响应等,还包含了安全中常见的售前护网案例、售前方案、ppt等,同时也有面向学生的网络安全面试、护网面试等。 上周华为发布了最新的鸿蒙NEXT操作系统&#xff0…

windows系统上nginx搭建文件共享

1、下载windows版nginx 下载地址 2、配置nginx 编辑nginx.conf配置文件 在http模块下添加这个参数 underscores_in_headers on;#修改location内容,共享哪个文件夹,就写哪个文件夹,最后一定要跟上/,否则无法访问 location / {…

深入解析Ansible

文章目录 引言Ansible的原理Ansible的使用安装Ansible配置Ansible编写Playbook执行Playbook Ansible的优缺点Ansible的优点Ansible的缺点 总结 引言 在现代IT运维中,自动化工具扮演着至关重要的角色。Ansible作为一款开源的自动化运维工具,凭借其易用性…

Depth Anything环境搭建推理测试

引子 基于单目摄像头的深度估计,一直是CV领域的一个难点,之前也对此关注也不够多。偶然浏览技术博客,看到Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data这个最新CVPR2024的工作。看到名字,大概也能猜出来…

【机器学习300问】130、什么是Seq2Seq?又叫编码器(Encoder)和解码器(Decoder)。

Seq2Seq,全称为Sequence to Sequence,是一种用于处理序列数据的神经网络模型,特别适用于如机器翻译、语音识别、聊天机器人等需要将一个序列转换为另一个序列的任务。这种模型由两部分核心组件构成:编码器(Encoder&…

服务器(Linux系统的使用)——自学习梳理

root表示用户名 后是机器的名字 ~表示文件夹,刚上来是默认的用户目录 ls -a 可以显示出隐藏的文件 蓝色的表示文件夹 白色的是文件 ll -a 查看详细信息 total表示所占磁盘总大小 一般以KB为单位 d开头表示文件夹 -代表文件 后面得三组rwx分别对应管理员用户-组…

shell的正则表达------awk

一、awk:按行取列 1.awk原理:根据指令信息,逐行的读取文本内容,然后按照条件进行格式化输出。 2.awk默认分隔符:空格、tab键,把多个空格自动压缩成一个。 3.awk的选项: awk ‘操作符 {动作}’…

pytorch库 03 基础知识

文章目录 一、准备工作二、tensorboard的使用1、add_scalar()方法2、add_image()方法 三、transforms的使用1、ToTensor()类2、常见transforms的类 三、torchvision中的数据集使用 官网 https://pytorch.org/ 一、准备工作 ①在pycharm和jupyter上,检查当前系统是…

构建LangChain应用程序的示例代码:42、如何使用 `LLMCheckerChain` 来验证和校正由大型语言模型(LLM)生成的文本

自我检查链使用指南 概述 本指南展示了如何使用 LLMCheckerChain 来验证和校正由大型语言模型(LLM)生成的文本。 代码示例 from langchain.chains import LLMCheckerChain # 导入 LLMCheckerChain 类 from langchain_openai import OpenAI # 导入 …

SpringBoot的Web开发支持【超详细【一篇搞定】果断收藏系列】

Override public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception { System.out.println(“MyInterceptor.afterCompletion”); } } 使用Java的形式配置拦截器的拦截路径 在WebMvcConfig…

记录:windows 命令板快捷键

dir 列出当前目录下的所有文件cd 目录名: cd. 进入当前目录 cd…进入上一层目录md 目录名 创建文件夹rd 目录名 删除文件夹cd.>文件名.后缀名 比如 cd.>a.txtcls 清除exit 退出

与亚马逊云科技深度合作,再获WAPP、ISV认证

上半年,VERYCLOUD睿鸿股份加入亚马逊云科技的WAPP(Well-Architected Partner Programs)和ISV加速计划(ISV Accelerate Program),为客户带来更坚实优质的海外云服务。 Well-Architected 获得WAPP这项认证代表…