【DL】FocalLoss的PyTorch实现

【DL】FocalLoss的PyTorch实现

此篇不介绍FocalLoss的原理,仅展示PyTorch实现FocalLoss的两种方式。个人认为相关原理已在文章《FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现》中讲得很清晰,故此篇不再介绍。

方式一

同时计算一个batch中所有样本关于FocalLoss的损失值(来自文章《FocalLoss原理通俗解释及其二分类和多分类场景下的原理与实现》,个人补充了一些注释):

import torch
from torch import nn
import random
class FocalLoss(nn.Module):"""参考 https://github.com/lonePatient/TorchBlocks"""def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):super(FocalLoss, self).__init__()self.gamma = gammaif isinstance(alpha, list):self.alpha = torch.Tensor(alpha, device=device)else:self.alpha = alphaself.epsilon = epsilon'''batch中所有样本一起计算loss'''def forward(self, input, target):"""Args:input: model's output, shape of [batch_size, num_cls]target: ground truth labels, shape of [batch_size]Returns:shape of [batch_size]"""num_labels = input.size(-1) # 类别数量idx = target.view(-1, 1).long() # 行向量target变成列向量idxone_hot_key = torch.zeros(idx.size(0), num_labels, dtype=torch.float32, device=idx.device)one_hot_key = one_hot_key.scatter_(1, idx, 1) # one_hot_key矩阵中的每一行对应相应样本的标签one_hot向量,利用scatter_方法将样本的标签类别标记为1,其余位置为0one_hot_key[:, 0] = 0  # ignore 0 index. 此行需要视具体情况决定是否保留,如果标签中存在类别0(而不是直接从类别1开始),此行应当注释、不使用logits = torch.softmax(input, dim=-1)loss = -self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log() # 计算FocalLossloss = loss.sum(1)return loss.mean()# 固定随机数种子,方便复现
def setup_seed(seed):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Trueif __name__ == '__main__':loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])# 设置随机数种子setup_seed(20) input = torch.randn(3, 5, requires_grad=True) # torch.Size([3, 5]) [sample_num, class_num]target = torch.empty(3, dtype=torch.long).random_(5) # torch.Size([3]) [sample_num]output = loss(input, target)# print(output)output.backward()

方式二

一个batch中逐个样本计算关于FocalLoss的损失值,将它们求平均,返回一个batch内所有样本的FocalLoss的平均值:

import torch
from torch import nn
import random
class FocalLoss(nn.Module):"""参考 https://github.com/lonePatient/TorchBlocks"""def __init__(self, gamma=2.0, alpha=1, epsilon=1.e-9, device=None):super(FocalLoss, self).__init__()self.gamma = gammaif isinstance(alpha, list):self.alpha = torch.Tensor(alpha, device=device)else:self.alpha = alphaself.epsilon = epsilon'''逐个样本计算loss'''    def forward(self, input, target):"""Args:input: model's output, shape of [batch_size, num_cls]target: ground truth labels, shape of [batch_size]Returns:shape of [batch_size]"""num_labels = input.size(-1) # 类别数量loss = []for i, sample in enumerate(input):one_hot_key = torch.zeros(1, num_labels, dtype=torch.float32, device=input.device)one_hot_key.scatter_(1, target[i].view(1, -1), 1)logits = torch.softmax(sample, dim=-1)loss_this_sample = - self.alpha * one_hot_key * torch.pow((1 - logits), self.gamma) * (logits + self.epsilon).log()loss_this_sample = loss_this_sample.sum(1)if i == 0:loss = loss_this_sampleelse:loss = torch.cat((loss, loss_this_sample))return loss.mean()# 固定随机数种子,方便复现
def setup_seed(seed):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Trueif __name__ == '__main__':loss = FocalLoss(alpha=[0.1, 0.2, 0.3, 0.15, 0.25])# 设置随机数种子setup_seed(20) input = torch.randn(3, 5, requires_grad=True) # torch.Size([3, 5]) [sample_num, class_num]target = torch.empty(3, dtype=torch.long).random_(5) # torch.Size([3]) [sample_num]output = loss(input, target)# print(output)output.backward()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/9955.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【iOS】frame与bounds区别

文章目录 前言framebounds两者区别size的区别总结 前言 在学习响应者链的过程中用到了frame与bounds的混用,这两个属性经常出现在我们的开发中,特别撰写一篇博客分析区别 首先,我们来看一下iOS特有的坐标系,在iOS坐标系中以左上…

C语言如何查看进程中环境变量中所有的值

示例代码&#xff1a;查看进程中环境变量中所有的值。 #include <stdio.h>int main(){extern char** environ;for (char** pp environ; *pp; pp){printf("%s\n", *pp);}return 0; }输出结果&#xff1a; SHELL/bin/bash WSL2_GUI_APPS_ENABLED1 WSL_DISTRO_…

【debug】如何使用pycharm对代码调试

后续会将所有debug中遇到的知识放入&#xff0c;建议关注收藏 本站友情链接&#xff1a; 基本理论专栏&#xff08;当前更新好的debug所有内容都在这里&#xff09; 【debug】报错解决方法&#xff08;CondaHTTPError&#xff1a;HTTP 000 connection failed for url&#xff…

【回溯 状态压缩 深度优先】37. 解数独

本文涉及知识点 回溯 状态压缩 深度优先 LeetCode37. 解数独 编写一个程序&#xff0c;通过填充空格来解决数独问题。 数独的解法需 遵循如下规则&#xff1a; 数字 1-9 在每一行只能出现一次。 数字 1-9 在每一列只能出现一次。 数字 1-9 在每一个以粗实线分隔的 3x3 宫内只…

leetCode刷题记录4-面试经典150题-2

文章目录 不要摆&#xff0c;没事干就刷题&#xff0c;只有好处&#xff0c;没有坏处&#xff0c;实在不行&#xff0c;看看竞赛题面试经典 150 题 - 2210. 课程表 II 不要摆&#xff0c;没事干就刷题&#xff0c;只有好处&#xff0c;没有坏处&#xff0c;实在不行&#xff0c…

[C++核心编程-06]----C++类和对象之对象模型和this指针

&#x1f3a9; 欢迎来到技术探索的奇幻世界&#x1f468;‍&#x1f4bb; &#x1f4dc; 个人主页&#xff1a;一伦明悦-CSDN博客 ✍&#x1f3fb; 作者简介&#xff1a; C软件开发、Python机器学习爱好者 &#x1f5e3;️ 互动与支持&#xff1a;&#x1f4ac;评论 &…

Microsoft 365 for Mac v16.84 office365全套办公软件

Microsoft 365 for Mac是一款功能丰富的办公软件套件&#xff0c;为Mac用户提供了丰富的功能和工具&#xff0c;提高了工作效率和协作能力。Microsoft 365 for Mac是一款专为Mac用户设计的订阅式办公软件套件&#xff0c;旨在提高生产力和效率。 Microsoft 365 for Mac v16.84正…

数据赋能(83)——数据要素:数据要素管理与数据管理

数据要素管理则更关注数据作为生产性资源在创造经济价值中的作用&#xff1b;数据管理更侧重于数据在整个生命周期中的控制、保护和价值提升。 数据要素管理是对数据作为关键生产要素进行系统性管理的过程。它聚焦于数据在经济和社会活动中的价值创造和贡献&#xff0c;将数据…

ubantu安装docker以及docker-compose

ubantu安装docker以及docker-compose 安装docker1、从官方存储库中安装Docker2、启动Docker服务3、验证 安装docker compose使用docker部署服务1、需要再opt文件夹下创建以下文件夹&#xff0c;/opt文件夹目录说明2、可将已备份对应文件夹拷至对应文件夹下3、在/opt/compose目录…

python集合

集合是一个无序的不重复元素序列&#xff0c;集合中的元素必须是不可变类型 集合的创建与删除 用{}直接创建 用集合推导式创建 用ser&#xff08;&#xff09;函数将列表&#xff0c;元组&#xff0c;range对象转换成集合 numset1{1,2,3,4,5}numset2{x**2 for x in range(…

【代码】Mysql 查询近一个月各类型设备新增数量

错误示例 SELECT COUNT(*) AS count, p.type, d.active_date FROM device d LEFT JOIN product p ON d.product_id p.pid WHERE MONTH (active_date) MONTH (CURRENT_DATE - INTERVAL 1 MONTH) AND YEAR (active_date) YEAR (CURRENT_DATE - INTERVAL 1 MONTH) group by p.…

mysql高可用集群MGR组复制的介绍、部署及配置说明

前言 MGR全称MySQL Group Replication(Mysql组复制),是MySQL官方于2016年12月推出的一个全新的高可用与高扩展的解决方案。MGR提供了高可用、高扩展、高可靠的MySQL集群服务。 高一致性:基于分布式paxos协议实现组复制,保证数据一致性; 高容错性:自动检测机制,只要不…

霍金《时间简史 A Brief History of Time》书后索引(A--D)

图源&#xff1a;Wikipedia INDEX A Abacus Absolute position Absolute time Absolute zero Acceleration Age of the universe Air resistance Albrecht, Andreas Alpha Centauri Alpher, Ralph Anthropic principle Antigravity Antiparticles Aristotle Arrows of time …

基于Vant UI的微信小程序开发(随时更新的写手)

基于Vant UI的微信小程序开发✨ &#xff08;一&#xff09;悬浮浮动1、效果图&#xff1a;只要无脑引用样式就可以了2、页面代码3、js代码4、样式代码 &#xff08;二&#xff09;底部跳转1、效果图&#xff1a;点击我要发布跳转到发布的页面2、js代码3、页面代码4、app.json代…

vue项目设置主题色

在vue开发过程中&#xff0c;很多页面为了保持主题颜色统一&#xff0c;且方便后期管理&#xff0c;通常会设有主题色&#xff0c;通过主题色可以使得页面上的按钮单选框等控件保持颜色统一。 接下来介绍其中一种方法&#xff1a; 1.先建立一个js文件用于存放主题色&#xff…

我觉得POC应该贴近实际

今天我看到一位老师给我一份测试数据。 这是三个国产数据库。算是分布式的。其中有两个和我比较熟悉&#xff0c;但是这个数据看上去并不好。看上去第一个黄色的数据库数据是这里最好的了。但是即使如此&#xff0c;我相信大部分做数据库的人都知道。MySQL和PostgreSQL平时拿出…

Spark Streaming笔记总结(保姆级)

万字长文警告&#xff01;&#xff01;&#xff01; 目录 一、离线计算与流式计算 1.1 离线计算 1.1.1 离线计算的特点 1.1.2 离线计算的应用场景 1.1.3 离线计算代表技术 1.2 流式计算 1.2.1 流式计算的特点 1.2.2 流式计算的应用场景 1.2.3 流式计算的代表技术 二…

最小生成树刷题笔记

算法基础&#xff1a; 首先是prim算法三部曲&#xff1a; &#xff08;1&#xff09;找到距离最小生成树最近的节点。 &#xff08;2&#xff09;将距离最小生成树最近的节点加入到最小生成树中。 &#xff08;3&#xff09;更新非最小生成树节点到最小生成树的距离。 实现…

HTML批量文件上传3—Servlet批量文件处理FileUpLoad

作者:私语茶馆 1.开源的文件上传组件介绍 本文使用的是Apache Commons下面的一个子项目FileUpload,另外一个常见组件是SmartUpload。FileUpload遵循RFC 1897,即“Form-based File Upload in HTML”,对于请求需要满足:HTTP协议,Post请求,content Type=“multipart/form-d…

Kafka 面试题(五)

1. kafka的消费者是pull(拉)还是push(推)模式&#xff0c;这种模式有什么好处&#xff1f; Kafka的消费者是pull&#xff08;拉&#xff09;模式。在这种模式下&#xff0c;消费者主动从Kafka的broker中拉取数据来进行消费。 这种pull模式的好处主要体现在以下几个方面&#…