原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列4

在这里插入图片描述

文章目录

  • 原型网络进行分类的基本流程
  • 一、原始代码---计算欧氏距离,设计原型网络(计算原型+开始训练)
  • 二、每一行代码的详细解释
  • 总结


原型网络进行分类的基本流程

利用原型网络进行分类,基本流程如下:

1.对于每一个样本使用编码的方式fφ (),学习到每一个样本的编码表示(信息抽取)。
2.学习到每一个样本的编码表示之后,对于每一个分类下的所有的样本编码进行求和求取平均的操作,将结果作为分类的原型表示。
3.当一个新的数据样本被输入到网络中的时候,对于这个样本使用fφ(),生成其编码表示。
4.计算新的样本的编码表示和每一个分类的原型表示之间的距离情况,通过最下距离来确定查询样本属于哪一个分类。
5.在计算出所有的分类之间的距离之后,使用softmax的方式将距离转换成概率的形式。

一、原始代码—计算欧氏距离,设计原型网络(计算原型+开始训练)

def eucli_tensor(x,y):	#计算两个tensor的欧氏距离,用于loss的计算return -1*torch.sqrt(torch.sum((x-y)*(x-y))).view(1)class Protonets(object):def __init__(self,input_shape,outDim,Ns,Nq,Nc,log_data,step,trainval=False):#Ns:支持集数量,Nq:查询集数量,Nc:每次迭代所选类数,log_data:模型和类对应的中心所要储存的位置,step:若trainval==True则读取已训练的第step步的模型和中心,trainval:是否从新开始训练模型self.input_shape = input_shapeself.outDim = outDimself.batchSize = 1self.Ns = Nsself.Nq = Nqself.Nc = Ncif trainval == False:#若训练一个新的模型,初始化CNN和中心点self.center = {}self.model = CNNnet(input_shape,outDim)else:#否则加载CNN模型和中心点self.center = {}self.model = torch.load(log_data+'model_net_'+str(step)+'.pkl')		#'''修改,存储模型的文件名'''self.load_center(log_data+'model_center_'+str(step)+'.csv')	#'''修改,存储中心的文件名'''def compute_center(self,data_set):	#data_set是一个numpy对象,是某一个支持集,计算支持集对应的中心的点center = 0for i in range(self.Ns):data = np.reshape(data_set[i], [1, self.input_shape[0], self.input_shape[1], self.input_shape[2]])data = Variable(torch.from_numpy(data))data = self.model(data)[0]	#将查询点嵌入另一个空间if i == 0:center = dataelse:center += datacenter /= self.Nsreturn centerdef train(self,labels_data,class_number):	#网络的训练#Select class indices for episodeclass_index = list(range(class_number))random.shuffle(class_index)choss_class_index = class_index[:self.Nc]#选20个类sample = {'xc':[],'xq':[]}for label in choss_class_index:D_set = labels_data[label]#从D_set随机取支持集和查询集support_set,query_set = self.randomSample(D_set)#计算中心点self.center[label] = self.compute_center(support_set)#将中心和查询集存储在list中sample['xc'].append(self.center[label])	#listsample['xq'].append(query_set)#优化器optimizer = torch.optim.Adam(self.model.parameters(),lr=0.001)optimizer.zero_grad()protonets_loss = self.loss(sample)protonets_loss.backward()optimizer.step()

二、每一行代码的详细解释

def eucli_tensor(x, y):return -1 * torch.sqrt(torch.sum((x - y) * (x - y))).view(1)

这是一个函数,用于计算两个张量(tensor)之间的欧氏距离(Euclidean Distance)。它通过计算两个张量差的平方和的平方根,并乘以-1。最后通过 view(1) 将结果转换成一个形状为 (1,) 的张量。

class Protonets(object):def __init__(self, input_shape, outDim, Ns, Nq, Nc, log_data, step, trainval=False):self.input_shape = input_shapeself.outDim = outDimself.batchSize = 1self.Ns = Nsself.Nq = Nqself.Nc = Ncif trainval == False:self.center = {}self.model = CNNnet(input_shape, outDim)else:self.center = {}self.model = torch.load(log_data + 'model_net_' + str(step) + '.pkl')self.load_center(log_data + 'model_center_' + str(step) + '.csv')

这是一个 Protonets 类的定义,它有一个构造函数 __init__,用于初始化类的属性。其中的参数含义如下:

  • input_shape:输入数据的形状。
  • outDim:输出维度。
  • Ns:支持集(support set)的数量。
  • Nq:查询集(query set)的数量。
  • Nc:每次迭代所选类别数。
  • log_data:模型和中心的存储位置。
  • step:训练的步数。
  • trainval:是否重新开始训练模型。

根据 trainval 的取值,分为两种情况进行初始化:

  1. trainval=False:表示训练一个新的模型。此时,初始化一个空的中心字典 self.center,并创建一个名为 CNNnet 的模型对象 self.model,其输入形状为 input_shape,输出维度为 outDim
  2. trainval=True:表示加载已经训练好的模型和中心。同样,初始化一个空的中心字典 self.center。然后通过 torch.load 加载之前训练保存的模型文件 log_data + 'model_net_' + str(step) + '.pkl',并将其赋给 self.model。接着调用 load_center 方法加载之前训练保存的中心文件 log_data + 'model_center_' + str(step) + '.csv'

总结

这段代码是一个用于实现 Protonets 算法的类。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/146927.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Confluence 快速安装教程

安装jdk yum install -y java-1.8.0-openjdk.x86_64 java -version 安装MySQL mkdir -p /data/mysql/data chmod 777 /data/mysql/datadocker rm -f mysql docker run -d --name mysql \-p 3306:3306 \-e MYSQL_ROOT_PASSWORDfingard1 \-v /data/mysql/data:/var/lib/mysql …

​软考-高级-系统架构设计师教程(清华第2版)【第18章 安全架构设计理论与实践(P648~690)-思维导图】​

软考-高级-系统架构设计师教程(清华第2版)【第18章 安全架构设计理论与实践(P648~690)-思维导图】 课本里章节里所有蓝色字体的思维导图

视频剪辑技巧:简单步骤,批量剪辑并随机分割视频

随着社交媒体平台的广泛普及和视频制作需求的急剧增加,视频剪辑已经成为了当今社会一项不可或缺的技能。然而,对于许多初学者来说,视频剪辑可能是一项令人望而生畏的复杂任务。可能会面临各种困难,如如何选择合适的软件和硬件、如…

VBA技术资料MF84:判断文件夹是否存在并创建

我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高自己的工作效率,而且可以提高数据的准确度。我的教程一共九套,分为初级、中级、高级三大部分。是对VBA的系统讲解,从简单的入门,到…

【Qt-23】基于QCharts绘制曲线图

一、QChart简介 QChart是Qt中专门用于绘制图表的模块,支持折线图、柱状图、饼图等常见类型。其主要组成部分有: QChart:整个图表的容器,管理图表中的所有数据和图形属性QChartView:继承自QGraphicsView,用于…

基于单片机C51全自动洗衣机仿真设计

**单片机设计介绍, 基于单片机C51全自动洗衣机仿真设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机C51的全自动洗衣机仿真设计是一个复杂的项目,它涉及到硬件和软件的设计和实现。以下是对这…

镀膜与干刻中的平均自由程是什么?

在芯片制造中,镀膜和干刻是其中的重要环节,通常要用到CVD,RIE等技术,对材料表面进行纳米级的精细操作。在这些工序中,原子,分子,离子等,会在气体或真空中进行自由运动,直…

IDEA 高分辨率卡顿优化

VM设置优化 -Dsun.java2d.uiScale.enabledfalse 增加该条设置,关闭高分切换 https://intellij-support.jetbrains.com/hc/en-us/articles/115001260010-Troubleshooting-IDE-scaling-DPI-issues-on-Windows​intellij-support.jetbrains.com/hc/en-us/articles/1…

金融业务系统: Service Mesh用于安全微服务集成

随着云计算的不断演进,微服务架构变得日益复杂。为了有效地管理这种复杂性,人们开始采用服务网格。在本文中,我们将解释什么是Service Mesh,为什么它对现代云架构至关重要,以及它是如何解决开发人员今天面临的一些最紧…

剑指JUC原理-19.线程安全集合

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码🔥如果感觉博主的文章还不错的话,请👍三连支持&…

IDEA-git commit log 线

一、本地代码颜色标识 红色:新建的文件,没有add到git本地仓库蓝色:修改的文件,没有提交到git远程仓库绿色:已添加到git本地仓库,没有提交到git远程仓库灰色:删除的文件,没有提交到g…

QT专栏1 -Qt安装教程

#本文时间2023年11月18日,Qt 6.6# Qt 安装简要说明: Qt有两个版本一个是商业版本(收费),另一个是开源版本(免费); 打开安装程序时,通过判断账号是否有公司,安…

基于SSM的学生疫情信息管理系统设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

100套Axure RP大数据可视化大屏模板及通用组件库

106套Axure RP大数据可视化大屏模板包括了多种实用美观的可视化组件库及行业模板库,行业模板涵盖:金融、教育、医疗、政府、交通、制造等多个行业提供设计参考。 随着大数据的发展,可视化大屏在各行各业得到越来越广泛的应用。可视化大屏不再…

高斯积分-Gaussian Quadrature

https://mathworld.wolfram.com/GaussianQuadrature.html

mmdet 3.x 打印各类指标

和mmdet2.x中的修改地方不一样,在mmdet/evaluation/metrics/coco_metric.py中第72行将classwise设为True就可以打印各类指标了 但是在test的时候一直都是什么指标都不打印,不管是上面总的指标还是下面的各类指标,暂时不知道怎么处理 找到原因…

解决docker运行elastic服务端启动不成功

现象: 然后查看docker日志,发现有vm.max_map_count报错 ERROR: [1] bootstrap checks failed [1]: max virtual memory areas vm.max_map_count [65530] is too low, increase to at least [262144] 解决办法: 1. 宿主机(运行doc…

移远EC600U-CN开发板 11.15

制作一个简单UI: 1."端口设置"模块 *效果图 *代码 def backEvent(evt): #返回主界面code evt.get_code() if code lv.EVENT.CLICKED:lv.scr_load(mainInterface)def popUpEvent(evt): #弹窗提醒code evt.get_code()if code lv.EVENT.CL…

Azure 机器学习:使用 Azure 机器学习 CLI、SDK 和 REST API 训练模型

目录 环境准备克隆示例存储库 示例案例在云中训练1.连接到工作区PythonAzure CLIREST API 2. 创建用于训练的计算资源4. 提交训练作业PythonAzure CLIREST API 注册已训练的模型PythonAzure CLIREST API Azure 机器学习提供了多种提交 ML 训练作业的方法。 在本文中&#xff0c…

算法萌新闯力扣:存在重复元素II

力扣题:存在重复元素II 开篇 这道题是217.存在重复元素的升级版,难度稍微提高。通过这道题,能加强对哈希表和滑动窗口的运用。 题目链接:219.存在重复元素II 题目描述 代码思路 1.利用哈希表,来保存数组元素及其索引位置 2.遍…