笔记:Few-Shot Learning小样本分类问题 + 孪生网络 + 预训练与微调

内容摘自王老师的B站视频,大家还是尽量去看视频,老师讲的特别好,不到一小时的时间就缕清了小样本学习的基础知识点~Few-Shot Learning (1/3): 基本概念_哔哩哔哩_bilibili

Few-Shot Learning(小样本分类)

假设现在每类只有一两个样本,计算机能否做到像人一样的正确分类?

  • 这个例子Support Set有两类,每类只有一两个样本,靠这些样本,难以训练出一个深度神经网络,这个集合只能提供一些参考信息。对于小样本问题,不能用传统的分类方法。

小样本分类与传统的监督学习有所不同,小样本学习的目标不是让机器通过学习训练集中图片,知道哪类是什么样子;当我拿一个很大的训练集来训练神经网络后进行小样本分类,预训练模型的目的是让机器自己学会学习-----也就是学习事物的异同,学会区分不同的事物。

现在训练集有五类,其中并没有松鼠这个类别

训练完成之后,可以问模型这两张图片是否是相同的东西呢?这时候模型已经学会分辨了事物的异同,比如给出两张松鼠图片,模型知道这两个动物之间长得很像,模型能够告诉你两张图片很可能是相同的东西。

支持集

给出一张图片,神经网络不知道这是什么。

这时候就需要支持集(Support Set),每类给出少样本(1~2)张,神经网络将Query图片和支持集中的每个类别依次对比,找出最相似的。

训练集和支持集的区别
  • 训练集规模很大,每类有很多张图片,可以训练一个深度神经网络

  • 支持集每类只有一张或几张图片,不足以训练一个大的神经网络,只能在做预测时候提供一些额外信息。

  • 用足够大的训练集训练的目的不是让模型识别训练集中的大象、老虎,而是知道事物的异同。对于训练的模型,只要提供含有该类别的小样本信息,模型就能区分类别,尽管训练集中没有这个类别。

小样本分类:Learn To Learn

带小朋友去动物园,小朋友不知道这个动物是什么,但是小朋友只需要翻一遍卡片(将目标与卡片上动物对应),就知道看到的动物是什么,这个卡片就是支持集,前提是小朋友有读卡片的能力,也就是得先经过训练学习。

如果卡片中每类只有一张,那就是One-Shot Learning(单样本学习)

传统监督学习 和 小样本学习 步骤的区别

  • 传统监督学习:测试图片虽然不是训练集中图片,但包含在训练集类别,模型已经见过上千张该类别图片,能够判断出是哪类。

  • 小样本学习:测试图片不但不包含在训练集中,也不是训练集中的类别。所以小样本学习比传统监督学习更难。因为不是训练集中的类别,所以要提供支持集,提供更多信息(给模型看小卡片,每张卡片有一个图片和一个标签,模型发现测试图片和某张卡片相似度高,就知道测试图片属于哪个标签)

小样本学习两个术语
  • k-way :支持集含有的种类数

  • n-shot : 支持集中每个种类有多少张图片

小样本学习预测准确率

  • 横轴是支持集类别数量。随着类别数量增加,分类准确率会降低。

  • 比如从三选一变成六选一

  • 每类样本越多,做预测越容易

相似度函数

sim(x, x'), x,x'为两个input

理想情况:sim(x1,x2) = 1 , sim(x1,x3) = 0, sim(x2,x3) = 0

从一个很大的训练集上学习一个相似度函数,它可以判断两张图片的相似度有多高。

孪生神经网络就可以作为相似度函数,可以拿大规模数据集做训练,训练结束之后,可以拿得到的相似度函数做预测。给一个测试图片,可以拿他跟支持集中的图片逐一对比,计算相似度,找到相似度最高作为预测结果。

  • Omniglot 特点:小样本(20个,105*105)

孪生网络(Siamese Network)

孪生网络要解决的问题
  • 第一类,分类数量较少,每一类的数据量较多,比如ImageNet、VOC等。这种分类问题可以使用神经网络或者SVM解决,只要事先知道了所有的类。

  • 第二类,分类数量较多(或者说无法确认具体数量),每一类的数据量较少,比如人脸识别、人脸验证任务。(少样本问题)

孪生网络的优点
  • 这个网络主要的优点是淡化了标签,使得网络具有很好的扩展性,可以对那些没有训练过的小样本类别进行分类,这点是优于很多算法的。

第一种训练孪生网络方法:每次取2样本,比较相似度 。
  • 训练这个神经网络要用一个大的数据集,每类有标注,每类下面都有很多个样本。

  • 我们需要用训练集来构造正样本和负样本

    • 正样本告诉神经网络什么东西是同一类。

    • 负样本告诉神经网络事物之间的区别。

  • 正样本获取

    • 每次从训练集中抽取一张图片(老虎),然后从同一类中随机抽取另一张图片(老虎),标签设置为1 (tiger, tiger, 1),意思是相似度满分。

  • 负样本获取

    • 每次从训练集中抽取一张图片(汽车),排除汽车这个类别,再从数据集中随机抽样(大象),标签设计为(car, elephant, 0),意思是相似度为0

  • 搭建一个卷积神经网络CNN用来提取特征,这个神经网络有很多卷积层,Pooling层,以及一个flatten层。输入是一张图片x,输出是提取的特征向量 f(x)

  • 现在开始训练神经网络,输入为(x1, x2 , 0或1),把这两张图片输入神经网络,把刚才搭建的卷积神经网络记作函数f。

  • 对于提取的特征向量,第一张图片特征向量记作h1 = f(x1),第二张图片特征向量记作h2 = f(x2),如果都是用CNN,这两个f需要是相同的卷积神经网络,共享相同的权值W(之所以叫孪生,就是因为共享特征提取的部分)。也可以不同权值,则不同场景,允许不同神经网络。

  • 然后拿h1 - h2 得到一个向量,再对这个向量所有元素求绝对值,记作z = ||h1 - h2||,表示两个特征向量之间的区别,再用一些全连接层来处理z向量,输出一些标量。

  • 最后用Sigmoid激活函数,得到输出是一个介于0~1之间的实数,可以衡量两个图片之间的相似度。如果两张图片是同一个类别,输出应该接近1,如果两张图片不同类别,输出应该接近0(希望神经网络的训练输出接近1),把标签与预测之间的差别作为损失函数

  • 损失函数可以是标签与预测的交叉熵损失函数cross-entropy loss function,可以衡量标签与预测的差别

  • 有了损失函数可以用反向传播计算梯度,用梯度下降来更新模型参数。

  • 模型主要有两部分,一个是卷积神经网络f用来从图片提取特征,一个是全连接层预测相似度,训练部分就是更新这两个的参数

  • 做反向传播,梯度从损失函数传回到向量z以及全连接层的参数,有了损失函数关于全连接层的梯度,就可以更新全连接层的参数了。

  • 然后梯度进一步从向量z传回到卷积神经网络,更新卷积神经网络参数,这样就完成了一轮训练

  • 做训练时候,我们要准备同样数量正样本和负样本。负样本标签设置为0,希望神经网络预测接近0,意思是这两张图片不同。还是用同样方法做反向传播,更新参数。

训练好模型之后,可以做One-Shot Prediction

  • 六个类别,每个类别一张图片,这六个类别可以都不在训练集中

  • 将Query与Support Set支持集中图片作对比:

    • 将Query图片与支持集中某一类一张图片作为input1 和 input2 ,输入到孪生网络中,孪生网络会输出一个0~1之间的值。用同样方法算出Query与所有图片相似度,查找相似度最高的。

孪生网络第二种训练方法:Triplet Loss
准备数据
  • 有这样一个训练集,每次选出三张图片

  • 首先从训练集随机选一张图片,作为anchor(锚点),记录这个锚点,然后从同类中随机抽取一张图片作为正样本Positive;排除该类别,从数据集中作随机抽样,得到不同类别的负样本Negative。

  • 现在有锚点x^a,正样本x+,负样本x-,把三张图片分别输入卷积神经网络f来提取特征(f指的是同一个卷积神经网络),得到三个特征向量

  • 计算正样本和锚点再特征空间上的距离,将特征向量 f(x+)与f(xa)求差,然后算二范数的平方,得到距离d+

  • 类似操作得到d-

  • 我们希望得到的神经网络有这样性质,像同类别特征向量聚在一起,不同类别的特征向量能够被分开,所以d+应该很小,d-应该很大

  • 这个坐标系是特征空间,卷积神经网络可以把图片映射到这个特征空间

  • d-应该比d+大很多,否则模型分辨不了同类和不同类

  • 所以鼓励正样本在特征空间接近锚点(d+尽量小),鼓励负样本在特征空间远离锚点(d-尽量大)

  • 指定一个margin :α,α>0。如果d- >= d+ + α,我们就认为没有损失loss=0,分类正确。假如条件不满足,则会有loss = d+ + α - d- , 我们希望loss越小越好

  • 有了损失函数,就可以求损失函数关于神经网络的梯度,作梯度下降来更新模型参数

测试模型
  • 给一个query,一个支持集,用神经网络提取特征,把所有这些图片变为特征向量,比较特征向量之间的距离。找出距离最小的。

总结

我们使用了Siamese Network解决了少样本学习

基本思路:

  • 用一个比较大的训练集来训练孪生网络,让孪生网络知道事物之间的异同

  • 训练结束之后拿孪生网络作预测,解决少样本问题。少样本的问题是少样本的类别不在训练集中。比如query是松鼠,但训练集中没有松鼠这个类别,需要额外的信息来识别query的图片,这个额外的信息就是少样本支持集。

  • 支持集称为k-way, n-shot,k个类别,类别越多,预测越困难,n个样本,样本越少,预测越困难,one-shot learning单样本预测最困难。

  • 有了训练好的孪生网络,我们就可以将query与support set中的样本逐一对比,选出距离最小或相似度最高作为分类结果。

  • 两种训练孪生网络方法:1.两个input,标签0或1,输出0~1之间数值,与标签差值作为loss,目标是让预测尽量接近标签。 2.另一种是Triplet Loss,xa,x+,x-,用CNN提取得到三个特征向量,输出d+,d-,目标是让d+尽量小,d-尽量大。有了这样一个神经网络就可以用它提取特征,比较两张图片在特征空间距离,作出few-shot分类

Fine Tuning

基本思路

在大规模数据上预训练模型,然后再小规模的support set上做fine-tuning。方法简单,准确率高。

  • 看个例子,余弦相似度consine similarity,衡量两个向量之间相似度,现在两个向量长度都是1,即他们的二范数都为1。

  • 把向量x和w的夹角记作θ,由于向量x和w长度都是1,cosθ就是x和w的内积,表示两个向量的相似度

  • 可以理解,把向量x投影到w方向上,投影长度就是-1到+1之间

  • 如果向量x和w的长度不是1,则需要做归一化把他们程度变为1,然后求得的内积才是余弦相似度

微调主要用到Softmax Function
  • 它是一个常用的激活函数,可以把一个k维向量映射成一个概率分布

  • 输入为Φ,它是任意的k维向量。把Φ的每一个元素做指数变换,得到k个大于0的数;然后对其作归一化,让得到的k个数相加等于1,把得到的k个数记为向量p

  • 向量p就是softmax函数的输出

  • 性质

    • 输入Φ和输出p都是k维向量

    • 向量p的元素都是正数,而且相加等于1

    • 所以p是个概率分布

  • softmax通常用于分类器的输出层,如果有k个类别,那么softmax的输出就是k个概率值,每个概率值表示对一个类别的confidence

  • softmax会让最大的值变大,其余的值变小。softmax比max函数要温柔一些

Softmax分类器
  • 是一个全连接层加一个Softmax函数

  • 分类器的输入是特征向量x,表示输入的测试图片的特征向量,把x乘到参数矩阵w上,再加上向量b,得到一个向量

  • 对得到的向量做softmax变换,得到输出向量p

  • 假如类别数量为k,那么向量p就是k维的

  • 矩阵W和b是这一层的参数,可以从训练数据中学习。W有K行,k是类别数量,所以W每一行对应一个类别,d是每个类别的特征数量

使用预训练好的神经网络,在query和support set上做fine-tuning的过程
  • 把query和support set中的图片都映射成特征向量,这样可以比较query和support set在特征空间上的相似度,比如可以计算两两之间的cosine similarity。最后选择相似度最高的作为query的分类结果

  • 预训练

    • 搭一个卷积神经网络用来提取特征,有很多卷积层、Pooling层以及一个Flatten层,也可以有全连接层

    • 神经网络输入是一张图片x,输出一个特征向量f(x)

    • 可以用传统的监督学习,预训练好后把全连接层都去掉;也可以用孪生网络训练

  • Few-Shot分类方法

    • 3-way 2-shot,三类别,每类别两样本

    • 拿预训练的神经网络提取特征,每张图片变成一个特征向量,每个类别两个特征向量

    • 平均每个类别特征向量作平均,得到一个同样大小的向量,也就是均值向量

    • 有三个类别,一共得到三个均值向量

    • 均值向量归一化,得到三个向量μ1,μ2,μ3,它们的二范数都等于一,μ1,μ2,μ3就是对三个类别的表征

    • 做分类的时候,要拿query的特征向量对μ1,μ2,μ3作对比

  • 对query作分类

    • 给一张query图片,需要判断是三个类别中的哪一个

    • 拿预训练的神经网络f来提取特征,得到一个特征向量

    • 对特征向量作归一化,得到向量q,它的二范数等于1

    • 与刚才从support set中提取的三个向量μ1,μ2,μ3,它们的二范数也是1,每个μ向量表征一个类别

    • 可以把三个μ向量堆叠起来,作为矩阵M的三个行向量

  • 做few-shot预测

    • query的特征向量q乘到矩阵M上,再做Softmax变换,得到p = Softmax(Mq),p是个概率分布,这个例子里,p是三维向量,表示对三个类别的confidence

    • 三个元素分别是q与μ1,μ2,μ3的内积

    • 很显然,在向量p中,第一个元素最大,分类结果是第一类

Fine-tuning可以大幅提高预测准确率
  • 基本都是先做预训练,后做Fine-Tuning

  • 刚才我们用了固定的W和b,没有学习这两个参数

  • 可以在Support Set上学习W和b,这叫做fine tuning

    • Cross Entropy来衡量yj与pj的差别有多大,yj是真实标签,pj是分类器做出的预测,损失函数就是Cross Entropy Loss

    • Support set中有几个或者几十个有标注的样本,每个样本都对应一个Cross Entropy Loss,把这些Cross entropy loss加起来,作为损失函数

    • 也就是说我们用support set中所有的图片和标签来学习这个分类器

    • CrossEntropyLoss做最小化Minimization,让预测pj尽量接近真实标签yj

    • Minimization是对分类器参数W和b求的,希望学习W和b;当然也可以让梯度传播到卷积神经网络,更新神经网络参数,让提取的特征向量更有效

    • support通常很小几十个到几百个样本,最好加个regularization来防止过拟合。有一篇文章建议用Entropy Regularization

  • 有一篇ICLR2020的论文说 对于5-way 1-shot,做fine tuning可以提到2%~7%的准确率;对5-way 5-shot,提高1.5%~4%准确率

  • 尽管support set很小,但用support set来训练分类器有助于提高准确率,预训练+fine tuning比只用预训练好很多

  • W,b默认值

  • Entropy Regularization防止过拟合

    • 希望Entropy Regularization越小越好

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/873936.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】基础I/O——动静态库的制作

我想把我写的头文件和源文件给别人用 1.把源代码直接给他2.把我们的源代码想办法打包为库 1.制作静态库 1.1.制作静态库的过程 我们先看看怎么制作静态库的! makefile 所谓制作静态库 需要将所有的.c源文件都编译为(.o)目标文件。使用ar指令将所有目标文件打包…

【前端】JavaScript入门及实战41-45

文章目录 41 嵌套的for循环42 for循环嵌套练习(1)43 for循环嵌套练习(2)44 break和continue45 质数练习补充 41 嵌套的for循环 <!DOCTYPE html> <html> <head> <title></title> <meta charset "utf-8"> <script type"…

谷粒商城实战笔记-38-前端基础-Vue-指令-单向绑定双向绑定

文章目录 一&#xff0c;插值表达式注意事项1&#xff1a;不适合复杂的逻辑处理注意事项2&#xff1a;插值表达式支持文本拼接注意事项3&#xff1a;插值表达式只能在标签体中 二&#xff0c;v-html和v-textv-textv-html区别总结&#xff1a;最佳实践 三&#xff0c;v-model复选…

WordPress杂技

WordPress杂技 WordPress页面构建器: Avada、Elementor、astra、 Elementor作为一款强大的页面构建工具。 Avada&#xff1a;是一款非常受欢迎的WordPress主题&#xff0c;它的设计理念是简洁、现代、响应式&#xff0c;Avada拥有丰富的模板和布局&#xff0c;可以轻松创建出…

多线程顺序执行

前言 现在面试中&#xff0c;不光会问力扣之类的算法&#xff0c;手撕多线程问题也被提上了日程。多线程之间的顺序执行是一个高频的面试手撕题&#xff0c;而且在实际应用中也会有用武之地。因此在这里&#xff0c;我们考虑使用不同的方式来实现多线程的顺序执行。在本文中&a…

Jackson 库简介--以及数据脱敏

Jackson 是一个流行的 Java JSON 处理库&#xff0c;它提供了将 Java 对象与 JSON 数据相互转换的功能。Jackson 的主要功能包括&#xff1a; 序列化&#xff1a;将 Java 对象转换为 JSON 字符串。反序列化&#xff1a;将 JSON 字符串转换为 Java 对象。 Jackson 提供了以下几…

C2W2.Assignment.Parts-of-Speech Tagging (POS).Part2

理论课&#xff1a;C2W2.Part-of-Speech (POS) Tagging and Hidden Markov Models 文章目录 2 Hidden Markov Models2.1 Generating MatricesCreating the A transition probabilities matrixExercise 03Create the B emission probabilities matrixExercise 04 理论课&#x…

FastAPI 学习之路(五十六)将token缓存到redis

在之前的文章中&#xff0c;FastAPI 学习之路&#xff08;二十九&#xff09;使用&#xff08;哈希&#xff09;密码和 JWT Bearer 令牌的 OAuth2&#xff0c;FastAPI 学习之路&#xff08;二十八&#xff09;使用密码和 Bearer 的简单 OAuth2&#xff0c;FastAPI 学习之路&…

Kubernetes 之 Ingress

Kubernetes 之 Ingress 定义 Ingress 可以把外部需要进入到集群内部的请求转发到集群中的一些服务上&#xff0c;从而实现把服务映射到集群外部的需要。Ingress 能把集群内 Service 配置成外网能够访问的 URL&#xff0c;流量负载均衡&#xff0c;提供基于域名访问的虚拟主机…

RabbitMQ 和 RocketMQ 的区别

RabbitMQ 和 RocketMQ 都是流行的开源消息中间件&#xff0c;它们用于在分布式系统中异步传输消息。尽管它们都实现了核心的消息队列功能&#xff0c;但它们在设计、性能、特性和使用场景上有一些关键的区别&#xff1a; 基础架构: RabbitMQ: 基于AMQP&#xff08;高级消息队列…

阵列信号处理学习笔记(二)--空域滤波基本原理

阵列信号 阵列信号处理学习笔记&#xff08;一&#xff09;–阵列信号处理定义 阵列信号处理学习笔记&#xff08;二&#xff09;–空域滤波基本原理 文章目录 阵列信号前言一、阵列信号模型1.1 信号的基本模型1.2 阵列的几何构型1.3 均匀直线阵的阵列信号基本模型 总结 前言…

HOW - React 处理不紧急的更新和渲染

目录 useDeferredValueuseTransitionuseIdleCallback 在 React 中&#xff0c;有一些钩子函数可以帮助你处理不紧急的更新或渲染&#xff0c;从而优化性能和用户体验。 以下是一些常用的相关钩子及其应用场景&#xff1a; useDeferredValue 用途&#xff1a;用于处理高优先级…

嵌入式面试总结

C语言中struct和union的区别 struct和union都是常见的复合结构。 结构体和联合体虽然都是由多个不同的数据类型成员组成的&#xff0c;但不同之处在于联合体中所有成员共用一块地址空间&#xff0c;即联合体只存放了一个被选中的成员&#xff0c;结构体中所有成员占用空间是累…

【网络】windows和linux互通收发

windows和linux互通收发 一、windows的udp客户端代码1、代码剖析2、总体代码 二、linux服务器代码三、成果展示 一、windows的udp客户端代码 1、代码剖析 首先我们需要包含头文件以及lib的一个库&#xff1a; #include <iostream> #include <WinSock2.h> #inclu…

前端页面是如何禁止被查看源码、被下载,被爬取,以及破解方法

文章目录 1.了解禁止查看,爬取原理1.1.JS代码,屏蔽屏蔽键盘和鼠标右键1.2.查看源码时,通过JS控制浏览器窗口变化2.百度文库是如何防止抓包2.1.HTPPS2.2. 动态加载为什么看不到?如何查看动态加载的内容?3.禁止复制,如果解决3.1.禁止复制原理3.2.如何破解1.了解禁止查看,爬…

使用scikit-learn进行机器学习:基础教程

使用scikit-learn进行机器学习&#xff1a;基础教程 Scikit-learn是Python中最流行的机器学习库之一。它提供了简单易用的工具&#xff0c;帮助我们进行数据预处理、模型训练、评估和预测。本文将带你通过一个基础教程&#xff0c;了解如何使用scikit-learn进行机器学习。 1.…

【模板代码】用于编写Threejs Demo的模板代码

基础模板代码 使用须知常规模板代码常规Shader模板代码 使用须知 本模板代码&#xff0c;主要用于编写Threejs的Demo&#xff0c;因为本人在早期学习的过程中&#xff0c;大量抄写Threejs/examples下的代码以及各个demo站的代码&#xff0c;所以养成了编写Threejs的demo的习惯…

SAP 采购订单 Adobe 消息输出

目录 1 简介 2 业务数据例子 3 选择增强 & 代码 1&#xff09;BADI: MM_PUR_S4_PO_MODIFY_HEADER 2&#xff09;BADI: MM_PUR_S4_PO_MODIFY_ITEM 4 自定义 Adobe form 1&#xff09;PO Master form 2&#xff09;PO form 5 前台主数据配置 6 后台配置 1&#xf…

昇思22天

CycleGAN图像风格迁移互换 CycleGAN&#xff08;循环生成对抗网络&#xff09;是一种用于在没有成对训练数据的情况下学习将图像从源域 X 转换到目标域 Y 的方法。该技术的一个重要应用是域迁移&#xff0c;即图像风格迁移。 模型介绍 模型简介: CycleGAN 来自于论文 Unpair…

掌握Rust:函数、闭包与迭代器的综合运用

掌握Rust&#xff1a;函数、闭包与迭代器的综合运用 引言&#xff1a;解锁 Rust 高效编程的钥匙函数定义与模式匹配&#xff1a;构建逻辑的基石高阶函数与闭包&#xff1a;代码复用的艺术迭代器与 for 循环&#xff1a;高效数据处理的引擎综合应用案例&#xff1a;构建一个简易…