⌈ 传知代码 ⌋ 将一致性正则化用于弱监督学习

💛前情提要💛

本文是传知代码平台中的相关前沿知识与技术的分享~

接下来我们即将进入一个全新的空间,对技术有一个全新的视角~

本文所涉及所有资源均在传知代码平台可获取

以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦!!!

以下内容干货满满,跟上步伐吧~


📌导航小助手📌

  • 💡本章重点
  • 🍞一. 论文概述
  • 🍞二. 算法原理
  • 🍞三.核心逻辑
  • 🍞四.效果演示
  • 🫓总结


💡本章重点

  • 一致性正则化用于弱监督学习

🍞一. 论文概述

本文复现论文 Revisiting Consistency Regularization for Deep Partial Label Learning[1] 提出的偏标记学习方法。程序基于Pytorch,会保存完整的训练日志,并生成损失变化图和准确度变化图。

偏标记学习(Partial Label Learning)是一个经典的弱监督问题。在偏标记学习中,每个样例的监督信息为一个包含多个标签的候选标签集合。目前的偏标记方法大多基于自监督或者对比学习范式,或多或少地会遇到低性能或低效率的问题。该论文基于一致性正则化的思想,改进基于自监督的偏标记学习方法。具体地,该论文所提出的方法设计了两个训练目标。其中第一个训练目标为最小化非候选标签的预测输出,第二个目标最大化不同视图的预测输出之间的一致性。

在这里插入图片描述
总的来说,该论文所提出的方法着眼于将模型对同一图像不同增强视图的预测输出对齐,以提升模型输出的可靠性和对标签的消歧能力,这一方法同样能给其他弱监督学习任务带来提升。


🍞二. 算法原理

首先,论文所提出方法的第一项损失(监督损失)如下:

在这里插入图片描述

其中,当事件 A 为真时,I(A)= 1 否则 I(A)= 0,f(.)表示模型的输出概率。

然后,论文所提出方法的第二项损失(一致性损失)如下:

在这里插入图片描述
其在训练过程中通过所有增强视图预测结果的几何平均来更新标签分布:

在这里插入图片描述
由于数据增强的不稳定性,该论文通过叠加 K 个不同的增强视图的一致性损失来提升方法性能。

最后,考虑到训练初期模型的预测准确率较低,一致性损失的权重被设置为从零开始随着训练轮数的增加逐渐提高:

在这里插入图片描述

综上所述,模型的总损失函数如下:

在这里插入图片描述


🍞三.核心逻辑

具体的核心逻辑如下所示:

def dpll_sup_loss(probs, partial_labels):loss = -torch.sum(torch.log(1 + 1e-6 - probs) * (1 - partial_labels), dim=-1)loss_avg = torch.mean(loss)return loss_avgdef dpll_cont_loss(logits, targets):logits_log = torch.log_softmax(logits, dim=-1)loss = F.kl_div(logits_log, targets, reduction='batchmean')return lossdef train():# main loopsfor epoch_id in range(total_epochs):# trainmodel.train()for batch in train_dataloader:optimizer.zero_grad()ids = batch['ids']data1 = batch['data1'].to(device)data2 = batch['data2'].to(device)data3 = batch['data3'].to(device)partial_labels = batch['partial_labels'].to(device)targets = train_targets[ids].to(device)logits1 = model(data1)logits2 = model(data2)logits3 = model(data3)probs1 = F.softmax(logits1, dim=-1)# update targetswith torch.no_grad():probs2 = F.softmax(logits2.detach(), dim=-1)probs3 = F.softmax(logits3.detach(), dim=-1)new_targets = torch.pow(probs1.detach() * probs2 * probs3, 1 / 3)new_targets = F.normalize(new_targets * partial_labels, p=1, dim=-1)train_targets[ids] = new_targets.cpu()# dynamic weightbalancing_weight = max_weight * (epoch_id + 1) / max_weight_epochbalancing_weight = min(max_weight, balancing_weight)# supervised lossloss_sup = dpll_sup_loss(probs1, partial_labels)# consistency regularization lossloss_cont1 = dpll_cont_loss(logits1, targets)loss_cont2 = dpll_cont_loss(logits2, targets)loss_cont3 = dpll_cont_loss(logits3, targets)# all lossloss = loss_sup + balancing_weight * (loss_cont1 + loss_cont2 + loss_cont3)loss.backward()optimizer.step()if epoch_id in lr_decay_epochs:lr_scheduler.step()

🍞四.效果演示

本文基于网络 Wide-ResNet[2] 和数据集 CIFAR-10[3] 进行实验,偏标记的随机翻转概率为0.1。当然,本文所提供的程序不仅仅提供了上述的实验设置,同时也可以直接基于CIFAR-100(100类图像分类数据集),SVHN(数字号牌识别数据集),Fashion-MNIST(时装识别数据集),Kuzushiji-MNIST(日本古草体识别数据集)进行实验。仅仅需要替换运行命令的对应部分即可(使用说明见下文)

  • 损失曲线:

在这里插入图片描述

  • 准确率曲线:

在这里插入图片描述


🫓总结

综上,我们基本了解了“一项全新的技术啦” 🍭 ~~

恭喜你的内功又双叒叕得到了提高!!!

感谢你们的阅读😆

后续还会继续更新💓,欢迎持续关注📌哟~

💫如果有错误❌,欢迎指正呀💫

✨如果觉得收获满满,可以点点赞👍支持一下哟~✨

【传知科技 – 了解更多新知识】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/54322.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

查看 Git 对象存储中的内容

查看 Git 对象存储中的内容 ls -C .git/objects/<dir>ls: 列出目录内容的命令。-C: 以列的形式显示内容。.git/objects/<dir>: .git 是存储仓库信息的 Git 目录&#xff0c;objects 是其中存储对象的子目录。<dir> 是对象存储目录下的一个特定的子目录。 此…

mysql学习教程,从入门到精通,SQL 修改表(ALTER TABLE 语句)(29)

1、SQL 修改表&#xff08;ALTER TABLE 语句&#xff09; 在编写一个SQL的ALTER TABLE语句时&#xff0c;你需要明确你的目标是什么。ALTER TABLE语句用于在已存在的表上添加、删除或修改列和约束等。以下是一些常见的ALTER TABLE语句示例&#xff0c;这些示例展示了如何修改表…

H.264编解码 - I/P/B帧详解

一、概述 在H.264编解码中,I/P/B帧是一种常见的帧类型。以下是它们的解释: I帧(关键帧):也称为关键帧,它是视频序列中的第一个帧或每个关键时刻的第一个帧。I帧是完整的、自包含的图像帧,不依赖于其他帧进行解码。它存储了关键时刻的完整图像信息。 P帧(预测帧):P帧…

<STC32G12K128入门第十六步>获取NTP网络时间

前言 这里主要讲解如何通过NTP服务器获取网络时间。 一、NTP是什么? NTP全名“Network TimeProtocol”,即网络时间协议,是由RFC 1305定义的时间同步协议,用来在分布式时间服务器和客户端之间进行时间同步。 NTP基于UDP报文进行传输,使用的UDP端口号为123。使用NTP的目的…

2款.NET开源且免费的Git可视化管理工具

Git是什么&#xff1f; Git是一种分布式版本控制系统&#xff0c;它可以记录文件的修改历史和版本变化&#xff0c;并可以支持多人协同开发。Git最初是由Linux开发者Linus Torvalds创建的&#xff0c;它具有高效、灵活、稳定等优点&#xff0c;如今已成为软件开发领域中最流行…

some 蓝桥杯题

12.反异或01串 - 蓝桥云课 (lanqiao.cn) #include "bits/stdc.h" #define int long long using namespace std; char c[10000000]; char s[10000000]; int cnt,Ans,mr,mid; int maxi; int p[10000000],pre[10000000]; signed main() {ios::sync_with_stdio(0);cin.t…

如何使用EventChannel

文章目录 1 知识回顾2 示例代码3 经验总结我们在上一章回中介绍了MethodChannel的使用方法,本章回中将介绍EventChannel的使用方法.闲话休提,让我们一起Talk Flutter吧。 1 知识回顾 我们在前面章回中介绍了通道的概念和作用,并且提到了通道有不同的类型,本章回将其中一种…

使用Apifox创建接口文档,部署第一个简单的基于Vue+Axios的前端项目

前言 在当今软件开发的过程中&#xff0c;接口文档的创建至关重要&#xff0c;它不仅能够帮助开发人员更好地理解系统架构&#xff0c;还能确保前后端开发的有效协同。Apifox作为一款集API文档管理、接口调试、Mock数据模拟为一体的工具&#xff0c;能够大幅度提高开发效率。在…

我为什么决定关闭ChatGPT的记忆功能?

你好&#xff0c;我是三桥君 几个月前&#xff0c;ChatGPT宣布即将推出一项名为“记忆功能”的新特性&#xff0c;英文名叫memory。 这个功能听起来相当吸引人&#xff0c;宣传口号是让GPT更加了解用户&#xff0c;仿佛是要为我们每个人量身打造一个专属的AI助手。 在记忆功…

用Arduino单片机读取PCF8591模数转换器的模拟量并转化为数字输出

PCF8591是一款单芯片&#xff0c;单电源和低功耗8位CMOS数据采集设备。博文[1]对该产品已有介绍&#xff0c;此处不再赘述。但该博文是使用NVIDIA Jetson nano运行python读取输入PCF8591的模拟量的&#xff0c;读取的结果显示在屏幕上&#xff0c;或输出模拟量点亮灯。NVIDIA J…

Ubuntu下Kafka安装及使用

Kafka是由Apache软件基金会开发的一个开源流处理平台&#xff0c;同时也是一个高吞吐量的分布式发布订阅消息系统。它由Scala和Java编写&#xff0c;具有多种特性和广泛的应用场景。 Kafka是一个分布式消息系统&#xff0c;它允许生产者&#xff08;Producer&#xff09;发布消…

docker 部署nacos

目录 一、拉取镜像 二、部署 三、访问&#xff08;默认是用内嵌数据库&#xff09; 四、配置 五、重启容器 一、拉取镜像 docker pull nacos/nacos-server 二、部署 docker run --name nacos -d -p 8848:8848 -p 9848:9848 -p 9849:9849 --restartalways --privilegedt…

软考鸭微信小程序:助力软考备考的便捷工具

一、软考鸭微信小程序的功能 “软考鸭”微信小程序是一款针对软考考生的备考辅助工具&#xff0c;提供了丰富的备考资源和功能&#xff0c;帮助考生提高备考效率&#xff0c;顺利通过考试。其主要功能包括&#xff1a; 历年试题库&#xff1a;小程序内集成了历年软考试题&…

加油站智能视频监控预警系统(AI识别烟火打电话抽烟) Python 和 OpenCV 库

加油站作为存储和销售易燃易爆油品的场所&#xff0c;是重大危险源之一&#xff0c;随着科技的不断发展&#xff0c;智能视频监控预警系统在加油站的安全保障方面发挥着日益关键的作用&#xff0c;尤其是其中基于AI的烟火识别、抽烟识别和打电话识别功能&#xff0c;以及其独特…

云服务架构与华为云架构

目录 1.云服务架构是什么&#xff1f; 1.1 云服务模型 1.2 云部署模型 1.3 云服务架构的组件 1.4 云服务架构模式 1.5 关键设计考虑 1.6 优势 1.7 常见的云服务架构实践 2.华为云架构 2.1 华为云服务模型 2.2 华为云部署模型 2.3 华为云服务架构的核心组件 2.4 华…

MFC工控项目实例之十九手动测试界面输出信号切换

承接专栏《MFC工控项目实例之十八手动测试界面输入信号实时检测》 根据板卡设置界面组合框选项设定的输出信号&#xff0c;通过读取文件中保存的键值&#xff0c;用单选按钮切换输出信号接通、关闭。 1、在Data_1.h文件中添加代码 CString COMB_Data_O_1[]{"夹紧",&…

JS基础练习|ES6-类定义和基础

class Animal {constructor(name) {this.name name;}speak() {console.log(${this.name} makes a noise.);} }class Dog extends Animal {constructor(name, breed) {super(name); // 调用父类的构造函数this.breed breed;}speak() {console.log(${this.name} barks.);} }con…

实时语音交互,打造更加智能便捷的应用

随着人工智能和自然语言处理技术的进步&#xff0c;用户对智能化和便捷化应用的需求不断增加。语音交互技术以其直观的语音指令&#xff0c;革新了传统的手动输入方式&#xff0c;简化了用户操作&#xff0c;让应用变得更加易用和高效。 通过语音交互&#xff0c;用户可以在不…

Label-Studio ML利用yolov8模型实现自动标注

引言 Label Studio ML 后端是一个 SDK&#xff0c;用于包装您的机器学习代码并将其转换为 Web 服务器。Web 服务器可以连接到正在运行的 Label Studio 实例&#xff0c;以自动执行标记任务。我们提供了一个示例模型库&#xff0c;您可以在自己的工作流程中使用这些模型&#x…

基于SpringCloud的微服务架构下安全开发运维准则

为什么要进行安全设计 微服务架构进行安全设计的原因主要包括以下几点&#xff1a; 提高数据保护&#xff1a;微服务架构中&#xff0c;服务间通信频繁&#xff0c;涉及到大量敏感数据的交换。安全设计可以确保数据在传输和存储过程中的安全性&#xff0c;防止数据泄露和篡改。…