pytorch backbone

1 简介

在PyTorch深度学习中,预训练backbone(骨干网络)是一个常见的做法,特别是在处理图像识别、目标检测、图像分割等任务时。预训练backbone通常是指在大型数据集(如ImageNet)上预先训练好的卷积神经网络(CNN)模型,这些模型能够提取图像中的通用特征,这些特征在多种任务中都是有用的。

1. 常见的预训练Backbone

以下是一些在PyTorch中常用的预训练backbone:

  • ResNet:由何恺明等人提出的深度残差网络,通过引入残差连接解决了深层网络训练中的梯度消失或梯度爆炸问题。ResNet系列包括ResNet-18、ResNet-34、ResNet-50、ResNet-101、ResNet-152等,数字表示网络的层数。
  • VGG:由牛津大学的Visual Geometry Group提出,特点是使用了多个小卷积核(如3x3)的卷积层和池化层来构建深层网络。VGG系列包括VGG16、VGG19等。
  • MobileNet:专为移动和嵌入式设备设计的轻量级网络,通过深度可分离卷积减少了计算量和模型大小。
  • DenseNet:通过密集连接(dense connections)提高了信息流动和梯度传播效率,进一步增强了特征重用。
  • EfficientNet:通过同时缩放网络的深度、宽度和分辨率来优化网络,实现了在保持模型效率的同时提高准确率。

2. 如何使用预训练Backbone

在PyTorch中,使用预训练backbone通常涉及以下几个步骤:

  1. 导入模型:使用PyTorch的torchvision.models模块导入所需的预训练模型。

    import torchvision.models as models  # 导入预训练的ResNet50模型  
    resnet50 = models.resnet50(pretrained=True)
    print(resnet50)
  2. 修改模型:根据需要修改模型的最后几层以适应特定的任务(如分类任务中的类别数)。

    # 假设我们有一个100类的分类任务  
    num_ftrs = resnet50.fc.in_features  
    resnet50.fc = torch.nn.Linear(num_ftrs, 100)
  3. 冻结backbone:在训练时,可以选择冻结backbone的参数,只训练新添加的层(如分类层),这有助于加快训练速度并防止过拟合。

    for param in resnet50.parameters():  param.requires_grad = False  # 只对新添加的层设置requires_grad=True  
    resnet50.fc.parameters().requires_grad = True
  4. 训练模型:使用适当的数据集和训练策略来训练模型。

  5. 评估模型:在测试集上评估模型的性能。

3. 注意事项

  • 使用预训练权重时,应确保输入图像的预处理(如大小调整、归一化等)与预训练时使用的预处理一致。
  • 冻结backbone时,应确保模型的其余部分(如新添加的层)有足够的容量来学习任务特定的特征。
  • 在某些情况下,解冻backbone的一部分或全部并在目标数据集上进行微调可能会获得更好的性能。

通过以上步骤,可以在PyTorch中有效地利用预训练backbone来解决各种计算机视觉任务。

2 查看模型源码

想查看models.resnet50的源码,可以点击查看pytorch中的官方注释,可以看到源码链接为

vision/torchvision/models/resnet.py at main · pytorch/vision · GitHub

这样就可以看到 class ResNet(nn.Module) 的定义

3 查看权重参数

在PyTorch中,查看深度学习预训练backbone的权重参数可以通过几种方法实现。以下是一些常用的步骤和方法:

1. 加载预训练模型

首先,你需要使用torchvision.models模块加载所需的预训练模型。例如,加载一个预训练的ResNet50模型:

import torchvision.models as models  # 加载预训练的ResNet50模型  
resnet50 = models.resnet50(pretrained=True)

2. 查看模型参数

方法一:使用model.parameters()

model.parameters()方法返回一个生成器,包含模型的所有参数(权重和偏置)。但是,这个方法不会直接显示参数的名称,只适合在训练循环中迭代参数。

方法二:使用model.named_parameters()

model.named_parameters()方法返回一个生成器,其中每个元素都是一个包含参数名称和参数本身的元组。这是查看模型每层权重参数及其名称的最直接方法。

for name, param in resnet50.named_parameters():  print(name, param.size())

这段代码会遍历模型的所有参数,并打印出每个参数的名称和尺寸。

3. 专注于特定层的参数

如果你只对backbone中的特定层感兴趣,可以进一步筛选named_parameters()的输出。例如,如果你想看ResNet50中第一个卷积层的参数:

for name, param in resnet50.named_parameters():  if 'conv1' in name:  print(name, param.size())

4. 注意事项

  • 当查看模型参数时,请确保你了解模型的架构,以便正确地解释参数的名称和尺寸。
  • 预训练模型的权重是在特定数据集(如ImageNet)上训练的,因此这些权重可能对你的特定任务有所帮助,但也可能需要进一步的微调。
  • 如果你的模型是基于预训练模型进行修改的(例如,更改了最后一层以匹配不同的类别数),请确保你理解这些修改如何影响模型的参数。

5. 示例输出

运行上述代码(针对ResNet50的named_parameters())将输出类似以下的信息(输出将非常长,这里只展示部分):

conv1.weight torch.Size([64, 3, 7, 7])  
conv1.bias torch.Size([64])  
bn1.weight torch.Size([64])  
bn1.bias torch.Size([64])  
bn1.running_mean torch.Size([64])  
bn1.running_var torch.Size([64])  
...

这表示conv1层有一个权重参数(大小为[64, 3, 7, 7])和一个偏置参数(大小为[64]),以及对应的批量归一化层的权重、偏置、运行均值和运行方差等参数。

4 常见bakcbone以及适用业务

在PyTorch中,预训练的backbone模型是深度学习领域中的重要组成部分,它们为各种任务提供了强大的特征提取能力。然而,由于PyTorch本身是一个灵活的深度学习框架,它并不直接提供所有可能的预训练backbone模型,而是由社区和研究者基于PyTorch框架实现并分享。以下是一些常见的PyTorch预训练backbone模型,以及它们的优劣和适用场景:

1. ResNet(残差网络)

优势

  • 引入了残差连接,解决了深层网络训练中的梯度消失或梯度爆炸问题。
  • 在多个计算机视觉任务中表现出色,如图像分类、目标检测等。

劣势

  • 对于某些特定任务,可能不是最优选择,需要根据任务特点进行调整。

适用场景

  • 图像分类、目标检测、语义分割等。

2. VGG

优势

  • 结构简单明了,易于理解和实现。
  • 在多个基准数据集上取得了良好的性能。

劣势

  • 参数量较大,计算成本较高。

适用场景

  • 早期深度学习研究和教学。

3. MobileNet

优势

  • 专为移动和嵌入式设备设计,具有较小的模型大小和较快的推理速度。
  • 采用了深度可分离卷积等技术,减少了计算量和参数量。

劣势

  • 相比于其他大型模型,可能在某些复杂任务上的精度稍低。

适用场景

  • 移动应用、嵌入式设备上的实时图像处理和分类。

4. DenseNet(密集连接网络)

优势

  • 每一层都直接与后面的所有层相连,增强了特征传播和复用。
  • 在多个数据集上取得了比ResNet更好的性能。

劣势

  • 参数量和计算量相对较大。

适用场景

  • 需要高精度和强特征表达能力的任务,如医学图像分析。

5. EfficientNet

优势

  • 通过复合缩放方法(compound scaling)平衡了网络的深度、宽度和分辨率,实现了在有限资源下的最佳性能。
  • 在多个计算机视觉任务中取得了SOTA(state-of-the-art)性能。

劣势

  • 需要根据具体任务进行微调以获得最佳性能。

适用场景

  • 追求极致性能的计算机视觉任务,如大规模图像分类和检测。

6. YOLOv5的Backbone(如CSPDarknet)

优势

  • 专为目标检测任务设计,具有较快的推理速度和较高的检测精度。
  • 采用了CSPNet等结构,进一步提升了网络性能。

劣势

  • 相比于专门的分类网络,可能在分类任务上的性能稍逊。

适用场景

  • 实时目标检测任务,如自动驾驶、视频监控等。

请注意,以上列举的backbone模型并不全面,PyTorch社区和研究者们不断在推出新的模型和架构。此外,每种模型都有其特定的优势和劣势,以及适用的场景。在选择模型时,需要根据具体任务的需求、计算资源等因素进行综合考虑。

对于PyTorch中预训练backbone模型的获取,可以通过PyTorch的官方模型库(如torchvision)或第三方库(如timmpretrainedmodels等)来获取。这些库提供了大量预训练的backbone模型,并支持多种加载和使用方式。

5 从backbone提取特征图(☆)

import torch
import torch.nn as nn
import torchvision.models as models
from collections import OrderedDictclass ResNet18(nn.Module):def __init__(self):super().__init__()self.resnet18 = models.resnet18(pretrained=True)def forward(self, x):features = OrderedDict()x = self.resnet18.conv1(x)x = self.resnet18.bn1(x)x = self.resnet18.relu(x)x = self.resnet18.maxpool(x)features['3'] = xx = self.resnet18.layer1(x)x = self.resnet18.layer2(x)features['2'] = xx = self.resnet18.layer3(x)features['1'] = xx = self.resnet18.layer4(x)features['0'] = xreturn featuresmodel = ResNet18()
input = torch.ones(1, 3, 640, 640)  # NCHW
y = model(input)
for key, value in y.items():print(key, value.shape)

打印信息

3 torch.Size([1, 64, 160, 160])
2 torch.Size([1, 128, 80, 80])
1 torch.Size([1, 256, 40, 40])
0 torch.Size([1, 512, 20, 20])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/49783.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Flink SQL CDC的实时数据同步

基于Flink SQL CDC(Change Data Capture)的实时数据同步是一种高效的数据处理方案,它允许用户实时捕获数据库中的变更操作,并将这些变更以流的形式进行处理和同步到其他系统或数据库中。以下是关于基于Flink SQL CDC的实时数据同步…

Linux嵌入式学习——数据结构——队列

一、概念 1)定义 是只允许在一端进行插入操作,而在另一端进行删除操作的线性表 队列 是一种 先进先出(First In First Out) 的线性表 线性表有顺序存储和链式存储,栈是线性表,所以有这两种存储方式 同样…

【在开发小程序的时候如何排查问题】

在开发小程序的时候如何排查问题 在最近开发小程序的时候,经常出现本地在浏览器中调试没有问题,但是一发布到预发环境就出现各种个样的问题 手机兼用性问题 有时候会出现苹果🍎手机键盘弹出,导致ui界面高度出现异常边界问题&#…

使用PageHelper插件来分页查询

目录 一.什么是PageHelper? 二.PageHelper的实战操作: 1.导入PageHelper的相关依赖: 2.配置代码展示: 3.分页查询代码解析: 另外,肯定读者会好奇为什么能够自动动态拼接? 一.什么是PageH…

关于Static 误用问题,总是记不住

一、常规的 静态局部变量,静态成员变量和成员函数没啥疑问 二、全局变量问题。。。 * 如果在 C 文件中使用 static 修饰全局变量, * 它将限制变量的作用域在当前文件内。 * 这意味着其他文件无法直接访问或修改这个变量的值。 …

Arduino IDE界面和设置(基础知识)

Arduino IDE界面和设置(基础知识) 1-2 Arduino IDE界面和设置如何来正确选择Arduino开发板型号如何正确选择Arduino这个端口如何来保存一个Arduino程序Arduino ide 的界面功能按钮验证编译上传新建打开保存工作状态 1-2 Arduino IDE界面和设置 大家好这…

day00-系统重要文件

01.知识点回顾 1.resolv.conf dns的配置文件 [rootlinux ~]# vim /etc/resolv.conf [rootlinux ~]# nslookup www.baidu.com Server: 8.8.8.8 Address: 8.8.8.8#53Non-authoritative answer: www.baidu.com canonical name www.a.shifen.com. Name: www.a.shifen.com Addre…

MongoDB适合哪些人使用

MongoDB 是一款高性能、开源、无模式的文档型数据库,它使用 BSON(Binary JSON)作为其数据格式,这使得 MongoDB 非常适合于存储和查询复杂的数据结构。MongoDB 的灵活性、可扩展性和高性能特性吸引了多种类型的用户。以下是 MongoD…

如何穿透模糊,还原图片真实面貌

目录 图像清晰化的魔法棒:AI如何穿透模糊,还原图片真实面貌 前言 论文背景 论文思路 模型介绍 复现过程 演示视频 使用方式 本文所涉及所有资源均在传知代码平台可获取。 图像清晰化的魔法棒:AI如何穿透模糊,还原图片真实面貌 在我…

全网最最实用--模型高效推理:量化基础

文章目录 一、量化基础--计算机中数的表示1. 原码(Sign-Magnitude)2. 反码(Ones Complement)3. 补码(Twos Complement)4. 浮点数(Floating Point)a.常用的浮点数标准--IEEE 754(FP32…

状态机 XState 使用

状态机 一般指的是有限状态机(Finite State Machine,FSM),又可以称为有限状态自动机(Finite State Automation,FSA),简称状态机,它是一个数学模型,表示有限个…

【计算机网络】数据链路层实验

一:实验目的 1:学习WireShark软件的抓包操作,分析捕获的以太网的MAC帧结构。 2:学习网络中交换机互相连接、交换机连接计算机的拓扑结构,理解虚拟局域网(WLAN)的通信机制。 3:学习…

cas 和 synchronized 优化过程

cas 什么是CAS CAS:全称Compareandswap,字⾯意思:”⽐较并交换“,⼀个CAS涉及到以下操作: 我们假设内存中的原数据V,旧的预期值A,需要修改的新值B。 1. ⽐较A与V是否相等。(⽐较) 2. 如果⽐较…

半导体行业黑话-02

31. #Silicon Chef# - 硅厨师,指负责设计和制造芯片的工程师。 32. #Silicon Chefs Kitchen# - 硅厨师的厨房,指半导体设计和制造的实验室或工作区。 33. #Silicon Ghetto# - 硅贫民区,有时用来形容那些技术落后或条件较差的制造厂。 34. #Silicon Jungle# - 硅丛林,形容半…

ubuntu22.04单个网口两个IP

其中 4网段IP可用来上网,3 网段用来内网 界面显示: 配置文件: 01-network-manager-all.yaml 放在 /etc/netplan/ # Let NetworkManager manage all devices on this systemnetwork:version: 2renderer: networkdethernets:eth0:dhcp4: falsedhcp6: …

防火墙与入侵检测系统(IDS/IPS)在现代网络安全中的关键角色

在数字化日益加速的今天,网络安全变得尤为重要。随着网络攻击的复杂性和频率不断增加,保护关键信息资产已成为各大小组织的首要任务。防火墙(Firewall)和入侵检测系统(Intrusion Detection System,IDS&…

开放式耳机哪个牌子好?五大超值机型整理,速速收藏!!

大家都知道现在的开放式耳机是越来越火了,后台也有非常多的小伙伴来私信,作为一个耳机测评师,当然是为了你们服务啦,所以这一期文章,就是为了个大家答疑解惑,告诉大家如何才能选购出一款比较好用的开放式耳…

【Python】字母 Rangoli 图案

一、题目 You are given an integer N. Your task is to print an alphabet rangoli of size N. (Rangoli is a form of Indian folk art based on creation of patterns.) Different sizes of alphabet rangoli are shown below: # size 3 ----c---- --c-b-c-- c-b-a-b-c --…

3106. 满足距离约束且字典序最小的字符串 Medium

给你一个字符串 s 和一个整数 k 。 定义函数 distance(s1, s2) ,用于衡量两个长度为 n 的字符串 s1 和 s2 之间的距离,即: 字符 a 到 z 按 循环 顺序排列,对于区间 [0, n - 1] 中的 i ,计算所有「 s1[i] 和 s2[i] 之间…

万字长文详解Java反射技术 | JavaSE | Java进阶知识 | 源码

🙋大家好!我是毛毛张! 🌈个人首页: 神马都会亿点点的毛毛张 📌今天分享的是JavaSE中的进阶知识🛑:反射技术。内容有点长,非常全面,记得点赞👍、收藏✅加关…