🌞欢迎来到AI+生物医学的世界
🌈博客主页:卿云阁💌欢迎关注🎉点赞👍收藏⭐️留言📝
🌟本文由卿云阁原创!
🌠本阶段属于练气阶段,希望各位仙友顺利完成突破
📆首发时间:🌹2024年3月22日🌹
✉️希望可以和大家一起完成进阶之路!
🙏作者水平很有限,如果发现错误,请留言轰炸哦!万分感谢!
目录
基本思想
引入
传统的监督学习VSFew-Shot Learning
Siamese Network
Learning Pairwise Similarity Scores
Triplet损失
Pretraining and Fine Tuning
数学知识
简单的例子
基本思想
引入
下图有4张图片,左边两张是球鱼,右边两个是穿山甲,现在给了一张Query图片,这张图片是球鱼还是穿山甲呐?如果每一类只有一两个样本,计算机能否想我们一样正确的分类?对于这4个样本,我们不可能构建一个传统的神经网络。图中的Support set指的是一个很小的数据集,Few-Shot Learning针对的就是这种小样本分类问题。
Few-Shot Learning和传统的监督学习有所不同,它的目标不是让计算机识别训练集里的图片,并且泛化到测试集,而是让计算机自己学会学习,我拿一个很大是数据集来训练神经网络,学习的目的不是让计算机知道什么是大象什么是老虎,不是让计算机识别没见过的大象和老虎,学习的目的是让计算机理解事物的异同,学会区分不同的事物,给两张图片,让计算机知道两个图片是相同的东西还是不同的东西。
我们的数据集里有很多的动物,一共有5类但是没有松鼠这个类别,所以模型不会识别松鼠,模型看到这两张图片,不知道是否是松鼠,而是去判断异同。
现在我们换一种问法,我们给一张Query图片图片,让计算机判断这是什么东西,计算机会和Support set中的图片依次进行对比。(Support set中的图片可能很小,每一类的下面可能只有一类或者几类图片)
Meta-Learning (元学习)其实就是Few-Shot Learning,小朋友虽然没有见过水獭,但是看过一些动物的图片,就可以判断出来了。
传统的监督学习VSFew-Shot Learning
Few-Shot Learning里面有两个常用的术语
一般来说,分类准确率会随着ways是增加而降低,shot的增加而增加。
如何解决Few-Shot Learning?
学一个函数来判断相似度sim()
第一,从很大的数据集上,学习相似度函数。Siamese Network就可以作为相似度函数。
公共数据集介绍
Siamese Network
Learning Pairwise Similarity Scores
第一种方法是每次取两个样本比较它们的相似度。
正样本,就是从同类中取图片,把标签设置成1。
负样本,随机抽取一张图片,比如这辆车,排除汽车这个类别,然后再做随机抽样。
搭建一个卷积神经网络用来提取特征,输入是一张图片(X),输出是一个向量(f(x))。
现在开始训练神经网络,z表示两个向量的区别,再用全连接层来处理这个z向量,最后变成一个标量,最后经过激活函数,最后得到的就是0-1的实数,这个实数就可以衡量两张图片的相似度。
损失函数就是真实值和预测值的差别。有了损失函数就可以反向更新参数,模型一共有两个部分,一部分是卷积神经网络用来提取特征,一部分是全连接层用来计算相似度,训练的过程就是更新这两部分的参数。
训练好之后就可以进行预测,这6个类别都不在训练集里。现在来了一个Query,我们知道它一定属于6个样本中的一个。
Triplet损失
我们需要这样准备数据,我们有这样一个数据集我们每次选3张图片进行训练。
首先从训练集中随机选一张图片,把他当成anchor,然后从老虎中随机抽取另一张图片,作为正样本,排除老虎类别再抽取一张图片作为负样本。
搭建一个卷积神经网络用来提取特征,分别计算锚点和正负样本的距离。
d+越小越好,d-越大越好
训练好之后就可以进行预测,找出距离最小的。
Pretraining and Fine Tuning
基本想法是在数据上预训练模型,然后在小规模的support set上做Fine Tuning。虽然这类方法很简单,但是很准确。
数学知识
假设两个向量的长度都是1,二范数为1,把向量x和向量w的夹角记作g。
如果向量x和向量w的长度不是1,归一化,把长度变成1,然后再求内积。
Softmax Function假设它的输入是任意的k维向量,对每一个元素做归一化,得到k个大于0的数,然后对结果做归一化,让相加的结果为1。p就是 Softmax的输出。Softmax Classifier是一个全连接层,加上一个Softmax函数,分类器的输入是x。预训练一个神经网络用来提取特征,做预测的时候需要用到这个模型,我们把Query和support set中的图片都映射成特征向量,这样我们就可以比较Query和support set在特征你空间上的相似度。搭建一个卷积神经网络用来提取特征。提取特征向量预测Fine TuningXj是图片,Yj是标签,预训练的神经网络记作f(x)(特征向量)放到分类器。实际上W和b是可以学习的,support set中有一个甚至几十个有标注的样本,每个样本都对应一个损失函数,相加作为目标函数,用support set所有的图片和标签来学习这个分类器,对目标函数做minimazation。好的初始化防止过拟合
简单的例子
import torch from torchmeta.datasets import Omniglot from torchmeta.transforms import Categorical, ClassSplitter, Rotation from torchmeta.utils.data import BatchMetaDataLoader from torchmeta.modules import (MetaModule, MetaSequential, MetaConv2d, MetaBatchNorm2d, MetaLinear)# 加载 Omniglot 数据集 dataset = Omniglot("data",num_classes_per_task=5,meta_train=True,transform = Categorical(num_classes=5))# 定义元学习模型 class ConvolutionalNeuralNetwork(MetaModule):def __init__(self, in_channels, out_features):super(ConvolutionalNeuralNetwork, self).__init__()self.features = MetaSequential(MetaConv2d(in_channels, 64, kernel_size=3, padding=1),torch.nn.ReLU(),torch.nn.MaxPool2d(2),MetaConv2d(64, 64, kernel_size=3, padding=1),torch.nn.ReLU(),torch.nn.MaxPool2d(2),)self.classifier = MetaLinear(64*7*7, out_features)def forward(self, inputs, params=None):features = self.features(inputs, params=self.get_subdict(params, 'features'))features = features.view((features.size(0), -1))logits = self.classifier(features, params=self.get_subdict(params, 'classifier'))return logits# 构建元学习模型 model = ConvolutionalNeuralNetwork(1, dataset.num_classes)# 定义损失函数和优化器 loss_func = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)# 训练元学习模型 for epoch in range(num_epochs):for batch in dataloader:train_inputs, train_targets = batch["train"]test_inputs, test_targets = batch["test"]train_outputs = model(train_inputs)train_loss = loss_func(train_outputs, train_targets)model.zero_grad()train_loss.backward()optimizer.step()# 在测试集上进行元更新test_outputs = model(test_inputs)test_loss = loss_func(test_outputs, test_targets)model.adapt(test_loss)
使用 PyTorch 和 Torchmeta 库来实现。我们将使用 Omniglot 数据集,该数据集是一个小型的手写字符数据集,用于图像分类任务。我们将使用 MAML(Model-Agnostic Meta-Learning)算法来训练元学习模型,该模型能够快速适应新的任务。
【1】https://www.youtube.com/watch?v=UkQ2FVpDxHg&list=PLvOO0btloRnuGl5OJM37a8c6auebn-rH2