PyTorch如何修改模型(魔改)

文章目录

  • PyTorch如何修改模型(魔改)
    • 1.修改模型层(模型框架⭐)
      • 1.1通过继承修改模型
      • 1.2通过组合修改模型(重点学👀)
      • 1.3通过猴子补丁修改模型
    • 2.添加外部输入
    • 3.添加额外输出
    • 参考

PyTorch如何修改模型(魔改)

对模型缝缝补补、修修改改,是我们必须要掌握的技能,本文详细介绍了如何修改PyTorch模型?也就是我们经常说的如何魔改。👍

PyTorch 的模型是一个 torch.nn.Module 的某个子类的对象,修改模型实际就等价于修改某个类,对面向对象熟悉的同学应该知道,对类做修改有两个经典的方法:组合继承

1.修改模型层(模型框架⭐)

1.1通过继承修改模型

首先创建自己需要的模型类,然后其父类指向需要被修改的模型,这时自己的模型则具有完备的父类行为,最后在子类中实现魔改的逻辑。其大致的框架代码如下所示:

from torchvision.models import ResNetclass CustomizedResNet(ResNet):def __init__(self):super().__init__()...def forward(self, x):...

下面这个例子,将对 ResNet 进行魔改,把 ResNet 的 4 个 stage 输出的特征连接起来,然后通过一个全连接层后输出一个标量。

from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet
import torch# 定义一个自定义的ResNet类,继承自torchvision的ResNet类
class CustomizedResNet(ResNet):def __init__(self, block, layers, num_classes=2):"""初始化函数block: ResNet中的基本块类型,可以是BasicBlock或Bottlenecklayers: 每个层级的基本块数量,是一个列表num_classes: 输出的类别数量,默认为2"""# 调用父类的初始化方法super().__init__(block, layers, num_classes)# 重新定义全连接层,改变输出的特征数量self.fc = torch.nn.Linear(int(512 * block.expansion * 1.875), num_classes)def forward(self, x):# 以下是ResNet的前向传播过程x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)# 通过四个残差层x1 = self.layer1(x)x2 = self.layer2(x1)x3 = self.layer3(x2)x4 = self.layer4(x3)# 将四个残差层的输出进行拼接x = torch.cat([self.avgpool(x1),self.avgpool(x2),self.avgpool(x3),self.avgpool(x4),], dim=1)# 将拼接后的张量展平x = torch.flatten(x, 1)# 通过全连接层,得到最终的输出x = self.fc(x)return x# 创建不同版本的ResNet模型
new_resnet34 = CustomizedResNet(BasicBlock, [3, 4, 6, 3], num_classes=1)
new_resnet50 = CustomizedResNet(Bottleneck, [3, 4, 6, 3], num_classes=1)
new_resnet101 = CustomizedResNet(Bottleneck, [3, 4, 23, 3], num_classes=1)
new_resnet200 = CustomizedResNet(Bottleneck, [3, 24, 36, 3], num_classes=1)

1.2通过组合修改模型(重点学👀)

在面向对象编程中,可能听说过「组合优于继承」,在模型修改的场景中其实也是这样,大多数情况下我们可能都适用组合而非继承。

首先依然需要创建模型的类,但这个类不再继承自魔改的类,而是直接继承 PyTorch 的模型基类 torch.nn.Module,然后将需要魔改的类作为类变量融入到模型中,下面是大致的框架代码:

from torchvision.models import resnet18
import torch.nn as nnclass CustomizedResNet(nn.Module):def __init__(self, backbone):super().__init__()self.backbone = backbone...def forward(self, x):...my_resnet18 = CustomizedResNet(resnet18)

同样,实现对 ResNet 进行魔改,把 ResNet 的 4 个 stage 输出的特征连接起来,然后通过一个全连接层后输出一个标量。

from torchvision.models import resnet50class CustomizedResNet(torch.nn.Module):def __init__(self, backbone, num_classes=2):super().__init__()self.backbone = backboneself.fc = torch.nn.Linear(3840, num_classes)def forward(self, x):x = self.backbone.conv1(x)x = self.backbone.bn1(x)x = self.backbone.relu(x)x = self.backbone.maxpool(x)x1 = self.backbone.layer1(x)x2 = self.backbone.layer2(x1)x3 = self.backbone.layer3(x2)x4 = self.backbone.layer4(x3)x = torch.cat([self.backbone.avgpool(x1),self.backbone.avgpool(x2),self.backbone.avgpool(x3),self.backbone.avgpool(x4),],dim=1,)x = torch.flatten(x, 1)x = self.fc(x)return xnew_resnet50 = CustomizedResNet(resnet50())

1.3通过猴子补丁修改模型

最简单粗暴的方法:猴子补丁(Monkey Patch)。之所以叫猴子补丁,是因为这种方法从程序设计的角度上来说,是具有破坏性的。而且这种方法仅能实现一些简单的修改需求,所以还是推荐使用继承或组合去修改我们的模型。😉

猴子补丁修改模型非常简单粗暴,直接使用需要修改的模型创建对象,然后直接对对象的属性做出修改。下面是把 ResNet34 的输出从 1000 改为 1 的简单例子:

from torchvision.models import resnet50
import torch.nn as nnmodel = resnet50()
model.fc = nn.Linear(2048, 1)

还有一个例子,以 PyTorch 官方视觉库 torchvision 预定义好的模型 ResNet50 为例,修改模型的某一层或者某几层。先观察一下它的网络结构:

import torch
import torch.nn as nn
from collections import OrderedDict
import torchvision.models as models
net = models.resnet50()
print(net)

假设要用这个模型去做一个10分类的问题,就应该修改模型的 fc 层,将其输出节点数替换为10。另外,想再加一层全连接层。可以做如下修改:

classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)),('relu1', nn.ReLU()), ('dropout1',nn.Dropout(0.5)),('fc2', nn.Linear(128, 10)),('output', nn.Softmax(dim=1))]))net.fc = classifier

这里的操作相当于将模型(net)最后名称为“fc”的层替换成了名称为“classifier”的结构。

2.添加外部输入

有时候在模型训练中,除了已有模型的输入之外,还需要输入额外的信息。比如在CNN网络中,我们除了输入图像,还需要同时输入图像对应的其他信息,这时候就需要在已有的CNN网络中添加额外的输入变量。基本思路是:将原模型添加输入位置前的部分作为一个整体,同时在forward中定义好原模型不变的部分、添加输入和后续层之间的连接关系,从而完成模型的修改。

以 torchvision 的 resnet50 模型为基础,任务还是10分类任务。不同点在于,我们希望利用已有的模型结构,在倒数第二层增加一个额外的输入变量 add_variable 来辅助预测。具体实现如下:

class Model(nn.Module):def __init__(self, net):super().__init__()self.net = netself.relu = nn.ReLU()self.dropout = nn.Dropout(0.5)self.fc_add = nn.Linear(1001, 10, bias=True)self.output = nn.Softmax(dim=1)def forward(self, x, add_variable):x = self.net(x)x = torch.cat((self.dropout(self.relu(x)),add_variable.unsqueeze(1)),1)x = self.fc_add(x)x = self.output(x)return x

这里的实现要点是通过torch.cat实现了tensor的拼接。torchvision 中的 resnet50 输出是一个1000维的 tensor,通过修改 forward 函数,先将 1000 维的 tensor 通过激活函数层和dropout层,再和外部输入变量"add_variable"拼接,最后通过全连接层映射到指定的输出维度 10。

另外这里对外部输入变量"add_variable"进行 unsqueeze 操作是为了和 net 输出的 tensor 保持维度一致,常用于 add_variable 是单一数值 (scalar) 的情况,此时 add_variable 的维度是 (batch_size, ),需要在第二维补充维数1,从而可以和 tensor 进行torch.cat操作。
unsqueeze与sequeeze语法说明

最后,对我们修改好的模型结构进行实例化,就可以使用了:

net = models.resnet50()
model = Model(net).cuda()

另外别忘了,训练中在输入数据的时候要给两个inputs:

outputs = model(inputs, add_var)

3.添加额外输出

有时候在模型训练中,除了模型最后的输出外,我们需要输出模型某一中间层的结果,以施加额外的监督,获得更好的中间层结果。基本的思路是修改模型定义中 forward 函数的 return 变量。

依然以 resnet50 做 10 分类任务为例,在已经定义好的模型结构上,同时输出 1000 维的倒数第二层和 10 维的最后一层结果。具体实现如下:

class Model(nn.Module):def __init__(self, net):super().__init__()self.net = netself.relu = nn.ReLU()self.dropout = nn.Dropout(0.5)self.fc1 = nn.Linear(1000, 10, bias=True)self.output = nn.Softmax(dim=1)def forward(self, x, add_variable):x1000 = self.net(x)x10 = self.dropout(self.relu(x1000))x10 = self.fc1(x10)x10 = self.output(x10)return x10, x1000

之后,对我们修改好的模型结构进行实例化,就可以使用了:

net = models.resnet50()
model = Model(net).cuda()out10, out1000 = model(inputs, add_var)

参考

  • Chenglu’s Log

  • Pytorch修改预训练模型的方法汇总

😃😃😃

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/6140.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【算法刷题 | 动态规划02】5.02(不同路径、不同路径||、整数拆分、不同的二叉搜索树)

文章目录 5.不同路径5.1题目5.2解法一:深度搜索5.2.1深度搜索思路5.2.2代码实现 5.3解法二:动规5.3.1动规思路5.3.2代码实现 6.不同路径||6.1题目6.2解法:动规6.2.1动规思路(1)dp数组以及下标含义(2&#x…

基于Springboot的交流互动系统

基于SpringbootVue的交流互动系统的设计与实现 开发语言:Java数据库:MySQL技术:SpringbootMybatis工具:IDEA、Maven、Navicat 系统展示 用户登录 首页 帖子信息 聚会信息 后台登录 后台管理首页 用户管理 帖子分类管理 帖子信息…

Python语言零基础入门——文件

目录 一、文件的基本概念 1.文件 2.绝对路径与相对路径 3.打开文件的模式 二、文件的读取 三、文件的追加 四、文件的写入 五、with语句 六、csv文件 1.csv文件的读取 2.csv文件的写入 七、练习题:实现日记本 一、文件的基本概念 1.文件 文件是以计算…

Mysql中索引的概念

索引相关概念 基础概念: 在MySQL中,索引是一种数据结构,用于加快数据库查询的速度和性能。索引可以帮助MySQL快速定位和访问表中的特定数据,就像书籍的索引一样,通过存储指向数据行的指针,可以快速…

ICode国际青少年编程竞赛- Python-1级训练场-路线规划

ICode国际青少年编程竞赛- Python-1级训练场-路线规划 1、 Dev.step(3) Dev.turnLeft() Dev.step(4)2、 Dev.step(3) Dev.turnLeft() Dev.step(3) Dev.step(-6)3、 Dev.step(-2) Dev.step(4) Dev.turnLeft() Dev.step(3)4、 Dev.step(2) Spaceship.step(2) Dev.step(3)5、…

Android手写自己的路由SDK

实现自己的路由框架 ​ 在较大型的Android app中常会用到组件化技术,针对不同的业务/基础功能对模块进行划分,从上到下为壳工程、业务模块、基础模块。其中业务模块依赖基础模块,壳工程依赖业务模块。同级的横向模块(比如多个业务…

软件杯 深度学习的动物识别

文章目录 0 前言1 背景2 算法原理2.1 动物识别方法概况2.2 常用的网络模型2.2.1 B-CNN2.2.2 SSD 3 SSD动物目标检测流程4 实现效果5 部分相关代码5.1 数据预处理5.2 构建卷积神经网络5.3 tensorflow计算图可视化5.4 网络模型训练5.5 对猫狗图像进行2分类 6 最后 0 前言 &#…

MySQL-逻辑架构

1、MySQL服务器处理客户端请求 MySQL是典型的C/S架构,服务端程序使用 mysqld。实现效果:客户端进程像服务端发送(SQL语句),服务器进程处理后再像客户端进程发送 处理结果。 2、connectors 指不同语言中与SQL的交互…

【C++】双指针算法:四数之和

1.题目 2.算法思路 这道题目十分困难,在leetcode上的通过率只有36%,大家要做好心理准备。 在做个题目前强烈建议大家先看看我的上一篇博客:有效三角形个数,看完之后再去leetcode上写一写三数之和,搞懂那两个题目之后…

JavaEE 初阶篇-深入了解 Junit 单元测试框架和 Java 中的反射机制(使用反射做一个简易版框架)

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 Junit 单元测试框架概述 1.1 使用 Junit 框架进行测试业务代码 1.2 Junit 单元测试框架的常用注解(Junit 4.xxx 版本) 2.0 反射概述 2.1 获…

计算机毕业设计php自行车在线租赁管理系统-vue+mysql

本系统的开发使获取自行车在线租赁管理系统信息能够更加方便快捷,同时也使自行车在线租赁管理系统管理信息变的更加系统化、有序化。系统界面较友好,易于操作。 自行车在线租赁管理系统,主要的模块包括首页、个人中心、用户管理、会员管理、自…

软件系统安全设计(安全保证措施)

软件安全保证措施word 软件所有全套资料获取进主页或者本文末个人名片直接。

C++之set/map相关实现

看着上面的图片,你可能对set和map的多样变化产生疑惑,下面我们就来详细讲解他们的区别以及实现 一.set/map 首先,在这里我们要声明,如果你对二叉搜索树一点都不了解的话,建议你先去将搜索二叉树学会再来学习这里的内…

ArkTS开发原生鸿蒙HarmonyOS短视频应用

HarmonyOS实战课程“2024鸿蒙零基础快速实战-仿抖音App开发(ArkTS版)”已经于今日上线至慕课网(https://coding.imooc.com/class/843.html),有致力于鸿蒙生态开发的同学们可以关注一下。 课程简介 本课程以原生鸿蒙Ha…

【Canvas与艺术】新制无底图安布雷拉暗黑系桌面(1920*1080)

【主要变化】 1.去掉底图&#xff0c;改为金丝正六边形组合而成的网格&#xff1b; 2.将安布雷拉标志调暗&#xff1b; 【成图】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html;…

力扣HOT100 - 78. 子集

解题思路&#xff1a; class Solution {public List<List<Integer>> subsets(int[] nums) {List<List<Integer>> lists new ArrayList<>(); // 解集lists.add(new ArrayList<Integer>()); // 首先将空集加入解集中for(int i 0; i < n…

Mac 安装 JDK21 流程

一、下载JDK21 访问Oracle官方网站或选择OpenJDK作为替代品。Oracle JDK从11版本开始是商业的&#xff0c;可能需要支付费用。OpenJDK是一个免费开源选项。 Oracle JDK官方网站&#xff1a;Oracle JDK Downloads OpenJDK官方网站&#xff1a;OpenJDK Downloads 这里以JDK21为…

FP16、BF16、INT8、INT4精度模型加载所需显存以及硬件适配的分析

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

EDA(一)Verilog

EDA&#xff08;一&#xff09;Verilog Verilog是一种用于电子系统设计自动化&#xff08;EDA&#xff09;的硬件描述语言&#xff08;HDL&#xff09;&#xff0c;主要用于设计和模拟电子系统&#xff0c;特别是在集成电路&#xff08;IC&#xff09;和印刷电路板&#xff08;…

CogVLM/CogAgent环境搭建推理测试

引子 对于多模态大语言模型&#xff0c;一直没有怎么接触。刚巧一朋友有问到这方面的问题&#xff0c;也就顺手调研下。智谱AI的东西一直以来&#xff0c;还是很不错的。ChatGLM的忠实fans&#xff0c;看到白嫖网站github上有他们开源的多模态CogVLM/CogAgent&#xff0c;那就…