持续学习动态架构算法LwF(Learning without Forgetting )解读总结与代码注释

0.持续学习

  • 持续学习相关文章汇总,包含论文地址、代码地址、具体分析解读地址

1.LwF算法相关链接

  • 论文地址
  • 代码地址

2.基本想法

  • 针对问题:在无法获得原始任务训练数据的情况下,适合使视觉系统适应新任务,并且保证其在旧任务上的性能
  • 问题建模:学习对新任务具有判别能力的参数,同时保留训练数据上原始任务的输出
  • 将网络分为所有任务共享部分和特定任务独享部分,网络架构如下:
    图片

3.损失函数

  • 待学习参数有三种:共享部分参数、旧任务们的独享参数、新任务独享参数
  • 由三部分组成:旧任务损失、新任务损失、正则化项
  • 旧任务损失:增长后的网络的输出与增长前的输出尽可能相同,采用知识蒸馏损失,类似交叉熵损失,只不过加大了较小概率的惩罚权重(其中关键参数T,要大于1来加大小概率的权重,文中通过网格搜索将其定位2)
  • 新任务损失:对于新任务的预测与真实值尽可能相同,使用交叉熵损失或者NLL损失
  • 正则化项:限制网络中所有参数,权重0.0005
  • 新旧任务权衡:在新任务损失前面有一个系数来表示对新旧任务性能的权衡,文中取1,参数越大,在新任务上的性能越好,在旧任务上的性能越差。通过改变该参数可以获得新旧任务性能曲线。

4.训练流程

  • 热身阶段(warm-up step):冻结共享部分参数、旧任务们的独享参数,单独训练新任务独享参数
  • 联合优化阶段(joint-optimize step):优化所有参数

5.特点

  • 与传统联合调优方法相比:无需存储旧任务的数据,新任务只需要通过一次共享层便可以用来进行旧任务和新任务的更新,却具有了联合调优的优点。但因为不同任务的分布会不相同,所以文中的方法效果会不如传统联合调优,传统联合调优的效果可以视为本文方法的上限。
  • 效率分析
    • 最慢:共享参数的正反向传播
    • 最快:特征提取层,因为只需要训练新任务的参数
    • 与传统微调相比:多了一步旧任务的独享参数更新,效率稍微低一点
    • 与传统联合调优相比:新旧任务共享的参数只需要进行一次前后向传播,效率更高

6.具体细节

  • 使用动量0.9的随机梯度下降
  • 在全连接层使用了dropout
  • 用旧任务的信息对新任务进行归一化
  • 数据增强:
    • 5X5的网格上对调整过大小的图像进行随机的固定尺寸裁剪
    • 随机镜像裁剪
    • RGB值上添加方差
  • 使用Xavier初始化新任务独享参数
  • 学习率是原网络学习率的0.1-0.02倍
  • 由于任务独享的特征提取部分参数量少,所以使用5倍学习率
  • 对于学习速度相似的方法,使用相同的训练epoch来进行公平比较
  • 有时为了防止过拟合、提升学习速度,会接近平稳在的时候将学习率变为0.1倍
  • 为了公平比较,将热身阶段后的共享网络作为联合训练和微调训练的起始点

7.实验

  • 添加单个新任务
  • 添加多个新任务
  • 数据集大小的影响
  • 网络设计的影响
  • 不同损失
  • 扩展网络结构的效用
  • 小学习率微调来保证旧任务的影响
  • 改变任务专属部分的网络层数

8.结论

  • 对于增长节点式的任务专属网络,其性能与原本的LwF性能相近,但是计算开销却大很多
  • 仅仅降低共享网络的学习率对保留旧任务性能的帮助并不大,但却会很大程度影响新任务
  • 用网络输出的变化来现在旧任务的变化要优于用网络参数的变化来衡量,因为网络参数一点小小的改变就可能引起输出巨大的改变
  • 知识蒸馏损失略优于L1、L2、交叉熵损失,但优势很小
  • 训练速度优于联合优化,对新任务的性能优于微调
  • 本文针对旧任务的损失对旧任务性能上的表现更可解释

9.未来工作

  • 应用到图像分类、跟踪等更多领域:分割、检测、视觉外的任务
  • 探索根据任务分布针对性地保留一些过去的任务数据和输出(由于是面对重尾分布)

10.代码解读

  • 参考文章
  • 含有备注的model.py
import torch
torch.backends.cudnn.benchmark=True
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from PIL import Image
from tqdm import tqdm
import time
import copyimport torchvision.models as models
import torchvision.transforms as transformsdef MultiClassCrossEntropy(logits, labels, T):# Ld = -1/N * sum(N) sum(C) softmax(label) * log(softmax(logit))labels = Variable(labels.data, requires_grad=False).cuda()outputs = torch.log_softmax(logits/T, dim=1)   # compute the log of softmax valueslabels = torch.softmax(labels/T, dim=1)# print('outputs: ', outputs)# print('labels: ', labels.shape)outputs = torch.sum(outputs * labels, dim=1, keepdim=False)outputs = -torch.mean(outputs, dim=0, keepdim=False)# print('OUT: ', outputs)return Variable(outputs.data, requires_grad=True).cuda()def kaiming_normal_init(m):if isinstance(m, nn.Conv2d):#判断m是不是nn.Conv2d的类型或子类nn.init.kaiming_normal_(m.weight, nonlinearity='relu')#一种初始化方法,要指明激活函数,保证输出有一定方差https://zhuanlan.zhihu.com/p/536483424elif isinstance(m, nn.Linear):nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid')class Model(nn.Module):'''分为超参数、网络架构、类增加三个部分前向传播里没有softmax'''def __init__(self, classes, classes_map, args):# Hyper Parametersself.init_lr = args.init_lrself.num_epochs = args.num_epochsself.batch_size = args.batch_sizeself.lower_rate_epoch = [int(0.7 * self.num_epochs), int(0.9 * self.num_epochs)] #hardcoded decay scheduleself.lr_dec_factor = 10self.pretrained = Falseself.momentum = 0.9self.weight_decay = 0.0001# Constant to provide numerical stability while normalizingself.epsilon = 1e-16# Network architecturesuper(Model, self).__init__()self.model = models.resnet34(pretrained=self.pretrained)self.model.apply(kaiming_normal_init)"""独享层:一层全连接层,与classes数量有关,且没有偏置"""num_features = self.model.fc.in_featuresself.model.fc = nn.Linear(num_features, classes, bias=False)self.fc = self.model.fc'''共享层:resnet34除去最后一层'''#nn.Sequential按序列构建模型https://blog.csdn.net/hxxjxw/article/details/106231242#.children()返回模型的最外层,与.model()的区别类似于attend和extendself.feature_extractor = nn.Sequential(*list(self.model.children())[:-1])#*用于迭代地取出list中的内容#用nn.DataParallel包装模型,可以在多GPU上运行https://zhuanlan.zhihu.com/p/647169457self.feature_extractor = nn.DataParallel(self.feature_extractor) # n_classes is incremented(递增) before processing new data in an iteration# n_known is set to n_classes after all data for an iteration has been processed数据处理完后n_known设为n_classesself.n_classes = 0self.n_known = 0self.classes_map = classes_mapdef forward(self, x):x = self.feature_extractor(x)x = x.view(x.size(0), -1)x = self.fc(x)return xdef increment_classes(self, new_classes):"""Add n classes in the final fc layer"""n = len(new_classes)print('new classes: ', n)in_features = self.fc.in_featuresout_features = self.fc.out_featuresweight = self.fc.weight.data#保存旧任务的网络权重if self.n_known == 0:new_out_features = nelse:new_out_features = out_features + nprint('new out features: ', new_out_features)self.model.fc = nn.Linear(in_features, new_out_features, bias=False)self.fc = self.model.fckaiming_normal_init(self.fc.weight)#所有任务网络统一初始化self.fc.weight.data[:out_features] = weight#还原旧任务网络权重self.n_classes += ndef classify(self, images):"""Classify images by softmaxArgs:x: input image batchReturns:preds: Tensor of size (batch_size,)"""_, preds = torch.max(torch.softmax(self.forward(images), dim=1), dim=1, keepdim=False)return predsdef update(self, dataset, class_map, args):self.compute_means = True# Save a copy to compute distillation outputs保存旧网络来计算旧任务原始输出prev_model = copy.deepcopy(self)prev_model.cuda()classes = list(set(dataset.train_labels))#print("Classes: ", classes)print('Known: ', self.n_known)if self.n_classes == 1 and self.n_known == 0:#self.n_classes初始值是1不是0吗?!new_classes = [classes[i] for i in range(1,len(classes))]else:new_classes = [cl for cl in classes if class_map[cl] >= self.n_known]#有新任务就动态调整网络if len(new_classes) > 0:self.increment_classes(new_classes)self.cuda()loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,shuffle=True, num_workers=12)print("Batch Size (for n_classes classes) : ", len(dataset))optimizer = optim.SGD(self.parameters(), lr=self.init_lr, momentum = self.momentum, weight_decay=self.weight_decay)with tqdm(total=self.num_epochs) as pbar:for epoch in range(self.num_epochs):# Modify learning rate# if (epoch+1) in lower_rate_epoch:# 	self.lr = self.lr * 1.0/lr_dec_factor# 	for param_group in optimizer.param_groups:# 		param_group['lr'] = self.lrfor i, (indices, images, labels) in enumerate(loader):seen_labels = []images = Variable(torch.FloatTensor(images)).cuda()seen_labels = torch.LongTensor([class_map[label] for label in labels.numpy()])labels = Variable(seen_labels).cuda()# indices = indices.cuda()optimizer.zero_grad()logits = self.forward(images)cls_loss = nn.CrossEntropyLoss()(logits, labels)if self.n_classes//len(new_classes) > 1:dist_target = prev_model.forward(images)logits_dist = logits[:,:-(self.n_classes-self.n_known)]dist_loss = MultiClassCrossEntropy(logits_dist, dist_target, 2)loss = dist_loss+cls_losselse:loss = cls_lossloss.backward()optimizer.step()if (i+1) % 1 == 0:tqdm.write('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f' %(epoch+1, self.num_epochs, i+1, np.ceil(len(dataset)/self.batch_size), loss.data))pbar.update(1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/222171.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在Vue开发中v-if指令和v-show指令的使用介绍和区别及使用场景

一、条件渲染 v-if v-if 指令用于条件性地渲染一块内容。这块内容只会在指令的表达式返回真值时才被渲染。 <h1 v-if"awesome">Vue is awesome!</h1>v-else 你也可以使用 v-else 为 v-if 添加一个“else 区块”。 <h1 v-if"awesome"&g…

什么品牌的猫粮比较好?主食冻干猫粮品牌十大排行

咱们养猫人每天最愁的就是咋给自家猫咪选一款优质的猫粮&#xff0c;让猫主子吃了健健康康的。早些年大多养猫人的标准就是盯着进口的买&#xff0c;所以之前进口猫粮的销量一直遥遥领先&#xff0c;感觉品控也严&#xff0c;也就放心大胆的冲进口猫粮了&#xff0c;但近期百利…

34.用过JavaConfig方式的spring配置吗?它是如何替代xml的?

用过JavaConfig方式的spring配置吗?它是如何替代xml的? 基于Java的配置,允许你在少量的Java注解的帮助下,进行你的大部分Spring配置而非通过XML文件。 以@Configuration 注解为例,它用来标记类可以当做一个bean的定义,被Spring IOC容器使用。 另一个例子是@Bean注解,它…

【开题报告】基于SpringBoot的艺术类家教平台的设计与实现

1.选题背景 随着人们生活水平的提高和文化教育的重视&#xff0c;越来越多的家长开始注重孩子的艺术教育&#xff0c;希望让孩子在绘画、音乐、舞蹈等方面得到更加专业的指导和培养。 然而&#xff0c;市场上现有的艺术类家教资源不够丰富和专业&#xff0c;家长们很难找到合…

flink中如何把DB大表的配置数据加载到内存中对数据流进行增强处理

背景 在处理flink的数据流时&#xff0c;比如处理商品流时&#xff0c;一般我们从kafka中只拿到了商品id&#xff0c;此时我们需要把商品的其他配置信息比如品牌品类等也拿到&#xff0c;此时就需要关联上外部配置表来达到丰富数据流的目的&#xff0c;如果外部配置表很大&…

我的隐私计算学习——隐私集合求交(1)

笔记内容来自多本书籍、学术资料、白皮书及ChatGPT等工具&#xff0c;经由自己阅读后整理而成。 &#xff08;一&#xff09;PSI的介绍 隐私计算关键技术&#xff1a;隐私集合求交&#xff08;PSI&#xff09;原理介绍 隐私计算关键技术&#xff1a;隐私集合求交&#xff08…

在系统中查找重复文件

说在前面 &#x1f388;不知道大家对于算法的学习是一个怎样的心态呢&#xff1f;为了面试还是因为兴趣&#xff1f;不管是出于什么原因&#xff0c;算法学习需要持续保持。 一、题目描述 给你一个目录信息列表 paths &#xff0c;包括目录路径&#xff0c;以及该目录中的所有…

事务--03---TCC空回滚、悬挂、幂等解决方案

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 Seata TCC 模式设计思路TCC存在的问题1、空回滚以及解决方案解决方案&#xff1a; 2、幂等问题以及解决方案解决方案&#xff1a; 3、悬挂问题以及解决方案解决方案…

PCIe设备热插拔-理论篇

硬件层面理解热插拔 PRSNT1#和PRSNT2#信号与PCIe设备的热插拔相关。在基于PCIe总线的Add-in 卡中&#xff0c;PRSNT1# 和PRSNT2#信号直接相连&#xff0c;而在处理器主板中&#xff0c;PRSNT1#信号接地&#xff0c;而PRSNT2#信号通过上 拉电阻接为高。 不同的处理器系统处理PC…

【Mysql】InnoDB的表空间(九)

概述 表空间是一个在 InnoDB 中比较抽象的概念&#xff0c;对于系统表空间来说&#xff0c;对应着文件系统中一个或多个实际文件&#xff1b;而对于每个独立表空间来说&#xff0c;对应着文件系统中一个名为表名.ibd 的实际文件。可以把表空间想象成由很多个页组成的池子&…

【Unity 实用工具篇】| 游戏多语言解决方案,官方插件Localization 实现本地化及多种语言切换

前言 【Unity 实用工具篇】| 游戏多语言解决方案&#xff0c;官方插件Localization 实现本地化及多种语言切换一、多语言本地化插件 Localization1.1 介绍1.2 效果展示1.3 使用说明 二、 插件导入并配置2.1 安装 Localization2.2 全局配置 三、多语言映射表3.1 创建多语言文本配…

Python之面向对象程序设计

文章目录 1、类定义2、创建实例3、属性4、方法5、继承6、多态7、组合8、导入类 1、类定义 面向对象程序设计的一个关键性观念是将数据以及对数据的操作封装在一起&#xff0c;组成一个相互依存、不可分割的整体&#xff0c;即对象。对于相同类型的对象进行分类、抽象后&#x…

字符处理 C语言xdoj52

问题描述 从键盘输入一个字符&#xff0c;若为小写字母&#xff0c;则输出其对应的大写字母&#xff1b;若为大写字母&#xff0c;则输出对应的小写字母&#xff1b;其他字符原样输出。 输入说明 输入一个字符 输出说明 输出一个字符 输入样例 样例1输入 a 样例…

分布式块存储 ZBS 的自主研发之旅|元数据管理

重点内容 元数据管理十分重要&#xff0c;犹如整个存储系统的“大黄页”&#xff0c;如果元数据操作出现性能瓶颈&#xff0c;将严重影响存储系统的整体性能。如何提升元数据处理速度与高可用是元数据管理的挑战之一。SmartX 分布式存储 ZBS 采用 Log Replication 的机制&…

安装ingress-nginx

1、下载helm压缩包 wget https://get.helm.sh/helm-v3.2.3-linux-amd64.tar.gz2、解压 [rootk8s-master-10 helm]# tar -zxvf helm-v3.2.3-linux-amd64.tar.gz linux-amd64/ linux-amd64/README.md linux-amd64/LICENSE linux-amd64/helm3、进入linux-amd64 [rootk8s-maste…

论文修改润色平台 PaperBERT

大家好&#xff0c;今天来聊聊论文修改润色平台&#xff0c;希望能给大家提供一点参考。 以下是针对论文重复率高的情况&#xff0c;提供一些修改建议和技巧&#xff1a; 标题&#xff1a;论文修改润色平台――助力学术研究&#xff0c;提升论文质量 一、引言 在学术研究中&am…

复制粘贴——QT实现原理

复制粘贴——QT实现原理 QT 剪贴板相关类 QClipboard 对外通用的剪贴板类&#xff0c;一般通过QGuiApplication::clipboard() 来获取对应的剪贴板实例。 // qtbase/src/gui/kernel/qclipboard.h class Q_GUI_EXPORT QClipboard : public QObject {Q_OBJECT private:explici…

单片机——通信协议(FPGA+c语言应用之spi协议解析篇)

引言 串行外设接口(SPI)是微控制器和外围IC&#xff08;如传感器、ADC、DAC、移位寄存器、SRAM等&#xff09;之间使用最广泛的接口之一。本文先简要说明SPI接口&#xff0c;然后介绍ADI公司支持SPI的模拟开关与多路转换器&#xff0c;以及它们如何帮助减少系统电路板设计中的数…

ChatGLM大模型推理加速之Speculative Decoding

目录 一、推测解码speculative decoding 1、自回归解码 2、speculative decoding 3、细节理解 二、核心逻辑代码 1、算法流程代码 2、模型自回归代码 a、带缓存的模型自回归实现代码 b、优化版本带缓存的模型自回归实现代码 c、ChatGLM的past_key_values的回滚 三、…

要求CHATGPT高质量回答的艺术:提示工程技术的完整指南—第 21 章:课程学习提示

要求CHATGPT高质量回答的艺术&#xff1a;提示工程技术的完整指南—第 21 章&#xff1a;课程学习提示 课程学习是一种技术&#xff0c;它允许模型通过首先训练较简单的任务并逐渐增加难度来学习复杂的任务。 要在 ChatGPT 中使用课程学习提示&#xff0c;应为模型提供一系列…