文章目录
- 原型网络进行分类的基本流程
- 一、原始代码---计算欧氏距离,设计原型网络(计算原型+开始训练)
- 二、每一行代码的详细解释
- 总结
原型网络进行分类的基本流程
利用原型网络进行分类,基本流程如下:
1.对于每一个样本使用编码的方式fφ (),学习到每一个样本的编码表示(信息抽取)。
2.学习到每一个样本的编码表示之后,对于每一个分类下的所有的样本编码进行求和求取平均的操作,将结果作为分类的原型表示。
3.当一个新的数据样本被输入到网络中的时候,对于这个样本使用fφ(),生成其编码表示。
4.计算新的样本的编码表示和每一个分类的原型表示之间的距离情况,通过最下距离来确定查询样本属于哪一个分类。
5.在计算出所有的分类之间的距离之后,使用softmax的方式将距离转换成概率的形式。
一、原始代码—计算欧氏距离,设计原型网络(计算原型+开始训练)
def eucli_tensor(x,y): #计算两个tensor的欧氏距离,用于loss的计算return -1*torch.sqrt(torch.sum((x-y)*(x-y))).view(1)class Protonets(object):def __init__(self,input_shape,outDim,Ns,Nq,Nc,log_data,step,trainval=False):#Ns:支持集数量,Nq:查询集数量,Nc:每次迭代所选类数,log_data:模型和类对应的中心所要储存的位置,step:若trainval==True则读取已训练的第step步的模型和中心,trainval:是否从新开始训练模型self.input_shape = input_shapeself.outDim = outDimself.batchSize = 1self.Ns = Nsself.Nq = Nqself.Nc = Ncif trainval == False:#若训练一个新的模型,初始化CNN和中心点self.center = {}self.model = CNNnet(input_shape,outDim)else:#否则加载CNN模型和中心点self.center = {}self.model = torch.load(log_data+'model_net_'+str(step)+'.pkl') #'''修改,存储模型的文件名'''self.load_center(log_data+'model_center_'+str(step)+'.csv') #'''修改,存储中心的文件名'''def compute_center(self,data_set): #data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点center = 0for i in range(self.Ns):data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])data = Variable(torch.from_numpy(data))data = self.model(data)[0] #将查询点嵌入另一个空间if i == 0:center = dataelse:center += datacenter /= self.Nsreturn centerdef train(self,labels_data,class_number): #网络的训练#Select class indices for episodeclass_index = list(range(class_number))random.shuffle(class_index)choss_class_index = class_index[:self.Nc]#选20个类sample = {'xc':[],'xq':[]}for label in choss_class_index:D_set = labels_data[label]#从D_set随机取支持集和查询集support_set,query_set = self.randomSample(D_set)#计算中心点self.center[label] = self.compute_center(support_set)#将中心和查询集存储在list中sample['xc'].append(self.center[label]) #listsample['xq'].append(query_set)#优化器optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)optimizer.zero_grad()protonets_loss = self.loss(sample)protonets_loss.backward()optimizer.step()
二、每一行代码的详细解释
def eucli_tensor(x, y):return -1 * torch.sqrt(torch.sum((x - y) * (x - y))).view(1)
这是一个函数,用于计算两个张量(tensor)之间的欧氏距离(Euclidean Distance)。它通过计算两个张量差的平方和的平方根,并乘以-1。最后通过 view(1)
将结果转换成一个形状为 (1,) 的张量。
class Protonets(object):def __init__(self, input_shape, outDim, Ns, Nq, Nc, log_data, step, trainval=False):self.input_shape = input_shapeself.outDim = outDimself.batchSize = 1self.Ns = Nsself.Nq = Nqself.Nc = Ncif trainval == False:self.center = {}self.model = CNNnet(input_shape, outDim)else:self.center = {}self.model = torch.load(log_data + 'model_net_' + str(step) + '.pkl')self.load_center(log_data + 'model_center_' + str(step) + '.csv')
这是一个 Protonets 类的定义,它有一个构造函数 __init__
,用于初始化类的属性。其中的参数含义如下:
input_shape
:输入数据的形状。outDim
:输出维度。Ns
:支持集(support set)的数量。Nq
:查询集(query set)的数量。Nc
:每次迭代所选类别数。log_data
:模型和中心的存储位置。step
:训练的步数。trainval
:是否重新开始训练模型。
根据 trainval
的取值,分为两种情况进行初始化:
trainval=False
:表示训练一个新的模型。此时,初始化一个空的中心字典self.center
,并创建一个名为CNNnet
的模型对象self.model
,其输入形状为input_shape
,输出维度为outDim
。trainval=True
:表示加载已经训练好的模型和中心。同样,初始化一个空的中心字典self.center
。然后通过torch.load
加载之前训练保存的模型文件log_data + 'model_net_' + str(step) + '.pkl'
,并将其赋给self.model
。接着调用load_center
方法加载之前训练保存的中心文件log_data + 'model_center_' + str(step) + '.csv'
。
总结
这段代码是一个用于实现 Protonets 算法的类。