【深度学习基础】理解 PyTorch 中的 logits 和交叉熵损失函数

在深度学习中,理解损失函数是训练模型的关键一步。在分类任务中,交叉熵损失函数是最常用的损失函数之一。本文将详细解释 PyTorch 中的 logits、交叉熵损失函数的工作原理,并展示如何调整张量的形状以确保计算正确的损失。

什么是 logits?

logits 是模型输出的未归一化预测值,通常是全连接层的输出。在分类任务中,logits 的形状通常为 (batch_size, num_labels),其中 batch_size 是一个批次中的样本数,num_labels 是分类任务中的类别数。

什么是交叉熵损失函数?

交叉熵损失函数(Cross-Entropy Loss)是一种常用于分类任务的损失函数。它衡量的是预测分布与真实分布之间的差异。具体而言,它会计算每个样本的预测类别与真实类别之间的距离,然后取平均值。

在 PyTorch 中,交叉熵损失函数可以通过 torch.nn.CrossEntropyLoss 来实现。该函数结合了 LogSoftmaxNLLLoss 两个操作,适用于未归一化的 logits。

示例:计算 logits 和交叉熵损失

让我们通过一个具体示例来详细解释如何计算 logits 和交叉熵损失。

定义模型

首先,我们定义一个简单的模型,其中包含一个全连接层和一个 dropout 层。

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.dropout = nn.Dropout(p=0.1)self.classifier = nn.Linear(768, 3)  # 假设输入的维度是768,输出的维度是3def forward(self, output):pooled_output = output[1]pooled_output = self.dropout(pooled_output)logits = self.classifier(pooled_output)return logits
训练循环

接下来,我们定义一个训练循环,并在其中计算损失。

# 假设你有数据加载器和优化器等
# dataloader = ...
# optimizer = ...model = MyModel()
criterion = nn.CrossEntropyLoss()  # 定义交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters())for epoch in range(num_epochs):for inputs, labels in dataloader:optimizer.zero_grad()outputs = model(inputs)logits = outputs# 计算损失loss = criterion(logits.view(-1, model.classifier.out_features), labels.view(-1))# 反向传播和优化loss.backward()optimizer.step()
解释代码细节
  1. logits:

    • logits 是模型的输出。假设 logits 的形状为 (batch_size, num_labels),例如 (32, 3),表示每个批次有 32 个样本,每个样本有 3 个类别的预测值。
  2. labels:

    • labels 是模型的真实标签。假设 labels 的形状为 (batch_size,),例如 (32,),表示每个批次有 32 个样本的真实类别标签。
  3. .view():

    • logits.view(-1, model.classifier.out_features)view 方法用于重新调整张量的形状。这里将 logits 的形状调整为 (-1, num_labels),其中 -1 表示自动计算的维度大小,使总元素数保持不变。这种调整通常用于确保张量形状与损失函数期望的输入形状相匹配。
    • labels.view(-1):同样,view(-1)labels 的形状调整为一维,便于与 logits 的形状对齐。
  4. 计算损失:

    • loss = criterion(logits.view(-1, model.classifier.out_features), labels.view(-1)):这行代码计算 logitslabels 之间的交叉熵损失。调整后的 logits 形状为 (batch_size * num_labels, num_labels),调整后的 labels 形状为 (batch_size * num_labels,)。这样,损失函数能够正确计算每个样本的损失。
具体示例

假设有一个分类任务,模型的输出和标签如下:

logits = torch.tensor([[2.0, 0.5, 0.3], [0.2, 2.0, 0.5]])
labels = torch.tensor([0, 1])

解释如下:

  • logits 的形状是 (2, 3),表示有 2 个样本,每个样本有 3 个类别的预测值。
  • labels 的形状是 (2,),表示有 2 个样本的真实类别标签。
  • model.classifier.out_features 是 3,表示有 3 个类别。

调整形状并计算损失:

logits = logits.view(-1, 3)  # 形状变为 (2, 3)
labels = labels.view(-1)     # 形状变为 (2,)loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits, labels)  # 计算交叉熵损失
交叉熵损失计算

交叉熵损失会分别计算每个样本的损失,并取平均值。例如,对于第一个样本,真实标签是类别 0,损失函数会对类别 0 的预测值计算损失。对于第二个样本,真实标签是类别 1,损失函数会对类别 1 的预测值计算损失。

总结

在本文中,我们深入解释了 PyTorch 中 logits 和交叉熵损失函数的工作原理,并展示了如何调整张量的形状以确保正确计算损失。这是分类任务中标准的损失计算步骤,有助于优化模型的参数。通过理解这些概念,你可以更好地调试和优化你的深度学习模型。

希望这篇文章对你理解 PyTorch 中的 logits 和交叉熵损失函数有所帮助!如果你有任何问题或建议,请在评论区留言。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/852957.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论人工智能与真实性

论人工智能与真实性 这让我们都感到不安:不是因为人工智能已经足够好,可以准确地预测某人可能会如何回答(包括猫的名字、表情符号的使用、汤的参考以及对“精神动物”的随意参考),而是因为提供这些反应菜单的模式首先代表了对这些互动功能的误解。即使回…

59.指向指针的指针(二级指针)

目录 一.什么是指向指针的指针 二.扩展 三.视频教程 一.什么是指向指针的指针 我们先看回顾一下指针&#xff1a; #include <stdio.h>int main(void) {int a 100;int *p &a;printf("*p is %d\n",*p);return 0;} 解析&#xff1a; 所以printf输出的结…

TCP/IP协议,三次握手,四次挥手

IP - 网际协议 IP 负责计算机之间的通信。 IP 负责在因特网上发送和接收数据包。 HTTP - 超文本传输协议 HTTP 负责 web 服务器与 web 浏览器之间的通信。 HTTP 用于从 web 客户端&#xff08;浏览器&#xff09;向 web 服务器发送请求&#xff0c;并从 web 服务器向 web …

Java 网站开发入门指南:如何用java写一个网站

Java 网站开发入门指南&#xff1a;如何用java写一个网站 Java 作为一门强大的编程语言&#xff0c;在网站开发领域也占据着重要地位。虽然现在 Python、JavaScript 等语言在网站开发中越来越流行&#xff0c;但 Java 凭借其稳定性、可扩展性和丰富的生态系统&#xff0c;仍然…

【CS.AL】算法必学之贪心算法:从入门到进阶 —— 关键概念和代码示例

文章目录 1. 概述2. 适用场景3. 设计步骤4. 优缺点5. 典型应用6. 题目和代码示例6.1 简单题目&#xff1a;找零问题6.2 中等题目&#xff1a;区间调度问题6.3 困难题目&#xff1a;分数背包问题 7. 题目和思路表格8. 总结References 1000.1.CS.AL.1.4-核心-GreedyAlgorithm-Cre…

李永乐线代笔记

线性方程组 解方程组的变换就是矩阵初等行变换 三秩相等 方程组系数矩阵的行秩列秩&#xff0c;线性相关的问题应求列秩&#xff0c;但求行秩方便 齐次线性方程组 对应向量组的线性相关&#xff0c;所以回顾下线性相关的知识&#xff1a; 其中k是x&#xff0c;所以用向…

Leaflet集成wheelnav在WebGIS中的应用

目录 前言 一、两种错误的实现方式 1、组件不展示 2、意外中的空白 二、不同样式的集成 1、在leaflet中集成wheelnav 2、给marker绑定默认组件 2、面对象绑定组件 3、如何自定义样式 三、总结 前言 在之前的博客中&#xff0c;我们曾经介绍了使用wheelnav.js构建酷炫…

http穿透怎么做?

众所周知http协议的默认端口是80&#xff0c;由于国家工信部要求&#xff0c;域名必须备案才给开放80端口&#xff0c;而备案需要固定公网IP&#xff0c;这就使得开放http80端口的费用成本和时间成本变的很高。那么能不能利用内网穿透技术做http穿透呢&#xff1f;下面我就给大…

【C语言】14. qsort 的底层与模拟实现

一、回调函数 回调函数就是⼀个通过函数指针调用的函数。 把函数的指针&#xff08;地址&#xff09;作为参数传递给另⼀个函数&#xff0c;当这个指针被用来调用其所指向的函数时&#xff0c;被调用的函数就是回调函数。回调函数不是由该函数的实现方直接调用&#xff0c;而是…

深入探索 Python 面向对象编程:封装、继承、多态和设计原则

在 Python 中&#xff0c;面向对象编程&#xff08;Object-Oriented Programming&#xff0c;简称 OOP&#xff09;是一种重要的编程范式&#xff0c;它将数据和操作封装在对象中&#xff0c;使得代码更加模块化、可复用和易于维护。 基本语法 Python 中的面向对象编程主要涉…

二分【3】 旋转数组

目录 旋转数组 旋转数组找最小值 旋转数组找指定值 严格递增序列 递增序列 旋转序列找中位数&#xff1a; 旋转数组 旋转数组找最小值 思路 #include <iostream> #include <vector> #include <cmath> #include <string> #include <cstrin…

is not null 、StringUtils.isNotEmpty和StringUtils.isNotBlank之间的区别?

这三者主要是针对对象是否为空、是否为空串和是否为空白字符串有不同的功能。 is not null 只是说明该对象不为空&#xff0c;没有考虑是否为空串和空白字符串。 StringUtils.isNotEmpty检查字符串是否不为 null且长度大于零&#xff0c;不考虑字符串中的空白字符。 StringU…

03通讯录管理系统——菜单功能

功能描述&#xff1a;用户选择功能的界面 菜单界面效果如下图&#xff1a; 步骤&#xff1a; 1.封装函数显示该界面&#xff0c;如void showMenu() 2.在main函数中调用封装好的函数 代码&#xff1a; 运行结果

【INTEL(ALTERA)】Quartus® 软件 Pin Planner 中 Agilex™ 5 FPGA的 HSIO 库可以选择 1.8V VCCIO?

目录 说明 解决方法 说明 由于 Quartus Prime Pro Edition 软件版本 24.1 存在一个问题&#xff0c;Quartus 软件 Pin Planner 中的 I/O 组属性 GUI 允许用户选择 1.8V 作为 HSIO 银行位置的 VCCIO。HSIO bank 支持的有效 VCCIO 电压仅为 1.0V、1.05V、1.1V、1.2V 和 1.3V。…

Java--数组的使用

1.普通For循环&#xff08;用的最多&#xff0c;需从中取出数据以及下标&#xff09; eg&#xff1a;图中三类问题都可 2.For-each循环&#xff08;一般用来打印一些结果&#xff09; eg&#xff1a;打印数组的具体元素 3.数组作方法入参&#xff08;对数组进行一些操作&#x…

蓝牙资讯|苹果iOS 18增加对AirPods Pro 2自适应音频的更多控制

苹果 iOS 18 系统将为 AirPods Pro 2 用户带来一项实用功能 —— 更精细的“自适应音频”控制。AirPods Pro 2 的“自适应音频”功能包含自适应降噪、个性化音量和对话增强等特性&#xff0c;可以根据周围环境自动调节声音和降噪效果。 当更新至最新测试版固件的 AirPods Pro 2…

KVM+GFS分布式存储系统构建高可用群集

KVMGFS 分布式存储系统构建 KVM 高可用群集 一&#xff1a;理论概述 1.1&#xff1a;Glusterfs 简介 Glusterfs 文件系统是由 Gluster 公司的创始人兼首席技术官 Anand Babu Periasamy编写。 一个可扩展的分布式文件系统&#xff0c; 用于大型的、 分布式的、 对大量数据进行访…

泛微开发修炼之旅--15后端开发连接外部数据源,实现在ecology系统中查询其他异构系统数据库得示例和源码

文章链接&#xff1a;15后端开发连接外部数据源&#xff0c;实现在ecology系统中查询其他异构系统数据库得示例和源码

深入理解Java正则表达式及其应用

正则表达式是一种强大的文本匹配和处理工具&#xff0c;可以在字符串中查找、替换、提取符合特定模式的内容。Java作为一种广泛应用的编程语言&#xff0c;提供了丰富的正则表达式支持。本文将深入探讨Java正则表达式的基本概念、语法以及常见应用场景&#xff0c;帮助读者全面…

太速科技-4通道 12bit 125Msps 直流耦合 AD FMC 子卡

4通道 12bit 125Msps 直流耦合 AD FMC 子卡 一、板卡概述: FMC 高速 AD 模块 FL9627 为 4 路 125MSPS&#xff0c; 12 位的模拟信号转数字信号模块。 FMC 模块的 AD 转换采用了 2 片 ADI 公司的 AD9627 芯片&#xff0c;每个 AD9627 芯片支持 2 路 AD 输入转换&#x…