Pytorch将标签转为One-Hot编码

一、标签映射与One-Hot编码过程

先进行标签映射,要为每个分类建立一个整数索引,对于每个样本的标签,使用整数索引创建一个长度为类别总数的二进制向量。这个向量的所有元素都是0,除了与整数索引相对应的位置,该位置的值为1。

二、pytorch的官方实现

在pytorch中实现了one hot编码,就在torch.nn.functional里面,下面是它的注释当中的示例,我们开看看:

Examples:>>> F.one_hot(torch.arange(0, 5) % 3)tensor([[1, 0, 0],[0, 1, 0],[0, 0, 1],[1, 0, 0],[0, 1, 0]])>>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)tensor([[1, 0, 0, 0, 0],[0, 1, 0, 0, 0],[0, 0, 1, 0, 0],[1, 0, 0, 0, 0],[0, 1, 0, 0, 0]])>>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)tensor([[[1, 0, 0],[0, 1, 0]],[[0, 0, 1],[1, 0, 0]],[[0, 1, 0],[0, 0, 1]]])

我们可以根据那自己实现的与它给出的这个示例进行比对,一样就当然没问题了。

三、手写实现

首先,在原先的函数(one_hot)当中numclass=-1,类别当然不能为1,说明这里是自动进行了计算,大家普遍使用的方式都是创建一个全零矩阵,使用 scatter_ 函数进行独热编码,作用是按照给定的索引,在指定的维度上进行赋值。

def one_hot(labels, num_classes=-1):"""将标签转为独热编码, 经过测试与torch.nn.functional里面的函数测试相同:param labels: 标签:param num_classes: 默认为-1, 表示进行自动计算类别最大的那个Examples:>>> label_1 = torch.arange(0, 5) % 3# tensor([0, 1, 2, 0, 1])>>> label_2 = torch.arange(0, 6).view(3, 2) % 3# tensor([[0, 1], [2, 0], [1, 2]])>>> print(one_hot(label_1))tensor([[1, 0, 0],[0, 1, 0],[0, 0, 1],[1, 0, 0],[0, 1, 0]])>>> print(one_hot(label_1, 5))tensor([[1, 0, 0, 0, 0],[0, 1, 0, 0, 0],[0, 0, 1, 0, 0],[1, 0, 0, 0, 0],[0, 1, 0, 0, 0]])>>> print(one_hot(label_2))tensor([[[1, 0, 0],[0, 1, 0]],[[0, 0, 1],[1, 0, 0]],[[0, 1, 0],[0, 0, 1]]])"""if num_classes == -1:num_classes = int(labels.max()) + 1one_hot_tensor = torch.zeros(labels.size() + (num_classes,), dtype=torch.int64)one_hot_tensor.scatter_(-1, labels.unsqueeze(-1).to(torch.int64), 1)return one_hot_tensorlabel_1 = torch.arange(0, 5) % 3
# tensor([0, 1, 2, 0, 1])
label_2 = torch.arange(0, 6).view(3, 2) % 3
# tensor([[0, 1], [2, 0], [1, 2]])
print(one_hot(label_1))
print(one_hot(label_1, 5))
print(one_hot(label_2))

首先是判断分类数是不是为-1,如果是就根据其中的最大值+1进行自动计算。然后创建一个契合分类数量的全零矩阵。

在这里,labels.unsqueeze(-1)用于在标签的最后一个维度上添加一个维度,以便与独热编码张量进行广播操作。

假设原始的 labels 张量的形状为 (batch_size,),那么经过 unsqueeze(-1) 操作后,形状变为 (batch_size, 1)。这样,每个样本的标签都被表示为一个列向量,而不再是一个标量。scatter_函数在最后一个维度进行操作,也就是对类别总数的维度进行操作,而 1 是要赋给相应位置的值。

labels.unsqueeze(-1) 已经确保了与 one_hot_tensor 的形状匹配,所以在这里能够正确地进行广播和赋值操作。

下面这一种是应用于分割网络当中,在保留输入标签张量形状的同时,将独热编码张量的最后一个维度设置为分类数num_classes,确保独热编码张量与输入标签张量具有相同的形状。

def get_one_hot(labels, num_classes=-1):"""用于分割网络的one hot"""labels = torch.as_tensor(labels)ones = one_hot(labels, num_classes)return ones.view(*labels.size(), num_classes)if __name__=="__main__":seg_labels = torch.randint(0, 3, size=[512, 512])print(get_one_hot(seg_labels))print(get_one_hot(seg_labels).shape)   # torch.Size([512, 512, 3])

你可以将这里应用于自定义dataset部分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/612089.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

引领行业赛道!聚铭网络入选安全419年度策划“2023年教育行业优秀解决方案”

近日,由网络安全产业资讯媒体安全419主办的《年度策划》2023年度优秀解决方案评选结果正式出炉,聚铭网络「高校大日志留存分析及实名审计解决方案」从众多参选方案中脱颖而出,被评为“教育行业优秀解决方案”,以硬核实力引领行业赛…

java基础 -02java集合之 List,AbstractList,ArrayList介绍

补充上篇 AbstractCollection < E > 在正式List之前&#xff0c;我们先了解我们补充上篇Collection接口的拓展实现&#xff0c;也就是说当我我们需要实现一个不可修改的Collection的时候&#xff0c;我们只需要拓展某个类&#xff0c;也就是AbstractCollection这个类&a…

ChatGPT4+Python近红外光谱数据分析及机器学习与深度学习建模

2022年11月30日&#xff0c;可能将成为一个改变人类历史的日子——美国人工智能开发机构OpenAI推出了聊天机器人ChatGPT3.5&#xff0c;将人工智能的发展推向了一个新的高度。2023年4月&#xff0c;更强版本的ChatGPT4.0上线&#xff0c;文本、语音、图像等多模态交互方式使其在…

安全漏洞周报(2024.01.01-2023.01.08)

漏洞速览 ■ 用友CRM系统存在逻辑漏洞 漏洞详情 1. 用友CRM系统存在逻辑漏洞 漏洞介绍&#xff1a; 某友CRM系统是一款综合性的客户关系管理软件&#xff0c;旨在帮助企业建立和维护与客户之间的良好关系。它提供了全面的功能&#xff0c;包括销售管理、市场营销、客户服…

1.10 Unity中的数据存储 XML

一、XML 1.介绍 XML是一个文档后缀名是*.xmlXML是一个特殊格式的文档XML是可扩展的标记性语言XML是Extentsible Markup Language的缩 写XML是由万维网联盟(W3C)创建的标记语言&#xff0c;用于定义编码人类和机器可以读取的文档的语法。它通过使用定义文档结构的标签以及如何…

代码随想录算法训练营第二十一天| 回溯 216. 组合总和 III 17. 电话号码的字母组合

216. 组合总和 III 可以参考77.组合中关于选取数组的相关操作。 递归函数的返回值以及参数&#xff1a;一般为void类型 递归函数终止条件&#xff1a;path这个数组的大小如果达到k&#xff0c;说明我们找到了一个子集大小为k的组合了&#xff0c;然后当n为0的时候&#xff0…

uniApp下载图片到手机相册,适配Android、Ios、微信小程序、H5

uniapp下载图片到手机&#xff0c;适配Android、Ios、微信小程序、H5 1.根据不同设备展示不同的按钮1.1 图片显示1.2 微信小程序显示的按钮1.3 h5显示的按钮1.4 app显示的按钮 2. 引入需要用到的文件3. data中需要的数据4. onload方法5. methods需要用到的方法6. 获取手机相册的…

Maven报错:Malformed \uxxxx encoding 解决办法

maven构建出现这个Malformed \uxxxx encoding问题&#xff0c;应该是maven仓库里面有脏东西进入了&#xff01; 解决&#xff1a; 将仓库中的resolver-status.properties文件全部干掉。 我使用的everything工具全局搜索resolver-status.properties文件&#xff0c;然后Ctrla,再…

Nodejs 第三十一章(响应头和请求头)

响应头 HTTP响应头&#xff08;HTTP response headers&#xff09;是在HTTP响应中发送的元数据信息&#xff0c;用于描述响应的特性、内容和行为。它们以键值对的形式出现&#xff0c;每个键值对由一个标头字段&#xff08;header field&#xff09;和一个相应的值组成。 例如…

第三十九级台阶

解题思路&#xff1a; 本题运用递归的思想&#xff0c;每走一步可以上一个或者两个台阶&#xff0c;一开始是左脚最后是右脚&#xff0c;所以走的总步数应该为偶数&#xff0c;最后跨过的台阶数应该等于39。 解题代码&#xff1a; public class disnashijiujitaijie {static i…

03. BI - 详解机器学习神器 XGBoost

本文专辑 : 茶桁的AI秘籍 - BI篇 原文链接: https://mp.weixin.qq.com/s/kLEg_VcxAACy8dH35kK3zg 文章目录 集成学习XGBoost Hi&#xff0c;你好。我是茶桁。 学习总是一个循序渐进的过程&#xff0c;之前两节课的内容中&#xff0c;咱们去了解了LR和SVM在实际项目中是如何使…

ROS2学习笔记二:开发准备

目录 1 DDS介绍 2. 工程介绍 4 构建工具colcon 5 启动一个节点 1 DDS介绍 DDS&#xff0c;全称 Data Distribution Service (数据分发服务)。是由对象管理组 (OMG) 于 2003 年发布并于 2007 年修订的开分布式系统标准。通过类似于ROS中的话题发布和订阅形式来进行通信&…

100V耐压 LED恒流驱动芯片 SL2516D兼容替换LN2516车灯照明芯片

SL2516D LED恒流驱动芯片是一款专为LED照明设计的高效、高精度恒流驱动芯片。与LN2516车灯照明芯片兼容&#xff0c;可直接替换LN2516芯片&#xff0c;为LED车灯照明提供稳定、可靠的电源解决方案。 一、SL2516D LED恒流驱动芯片的特点 1. 高效率&#xff1a;SL2516D采用先进的…

HarmonyOS4.0系统性深入开发17进程模型概述

进程模型概述 HarmonyOS的进程模型&#xff1a; 应用中&#xff08;同一包名&#xff09;的所有UIAbility运行在同一个独立进程中。WebView拥有独立的渲染进程。 基于HarmonyOS的进程模型&#xff0c;系统提供了公共事件机制用于一对多的通信场景&#xff0c;公共事件发布者…

深度解析-Java语言的未来

深度解析-Java语言的未来&#xff0c;文末有我耗时一个月&#xff0c;问遍了身边的大佬&#xff0c;零基础自学Java的路线&#xff0c;适用程序员入门&进阶&#xff0c;Java学习路线&#xff0c;2024新版最新版。 文章目录 Q1 - 能否自我介绍下&#xff1f; Q2 - Java语…

Python常用配置文件读取方法

常见的应用配置方式有环境变量和配置文件,对于微服务应用,还会从配置中心加载配置,比如nacos、etcd等,有的应用还会把部分配置写在数据库中。此处主要记录从环境变量、.env文件、.ini文件、.yaml文件、.toml文件、.json文件读取配置。 ini文件 ini文件格式一般如下: [m…

GBASE南大通用CreateParameter 方法

创建一个GBASE南大通用Parameter 对象的实例。  语法 [Visual Basic] Public Function CreateParameter As GBaseParameter [C#] public GBaseParameter CreateParameter()  返回值 创建的 GBaseParameter 对象。 执行一个 SQL 语句并返回影响的行数。  语法 […

设计模式——抽象工厂模式(Abstract Factory Pattern)

概述 抽象工厂模式的基本思想是将一些相关的产品组成一个“产品族”&#xff0c;由同一个工厂统一生产。在工厂方法模式中具体工厂负责生产具体的产品&#xff0c;每一个具体工厂对应一种具体产品&#xff0c;工厂方法具有唯一性&#xff0c;一般情况下&#xff0c;一个具体工厂…

数据结构与算法之美学习笔记:46 | 概率统计:如何利用朴素贝叶斯算法过滤垃圾短信?

目录 前言算法解析总结引申 前言 本节课程思维导图&#xff1a; 上一节我们讲到&#xff0c;如何用位图、布隆过滤器&#xff0c;来过滤重复的数据。今天&#xff0c;我们再讲一个跟过滤相关的问题&#xff0c;如何过滤垃圾短信&#xff1f; 垃圾短信和骚扰电话&#xff0c;我…

基于长短期神经网络LSTM的路径追踪

目录 背影 摘要 代码和数据下载:基于长短期神经网络LSTM的路径追踪(代码完整,数据齐全)资源-CSDN文库 https://download.csdn.net/download/abc991835105/88714816 LSTM的基本定义 LSTM实现的步骤 基于长短期神经网络LSTM的路径追踪 结果分析 展望 参考论文 背影 路径坐标…