【PYG】Cora数据集分类任务计算损失,cross_entropy为什么不能直接替换成mse_loss

  • cross_entropy计算误差方式,输入向量z为[1,2,3],预测y为[1],选择数为2,计算出一大坨e的式子为3.405,再用-2+3.405计算得到1.405
  • MSE计算误差方式,输入z为[1,2,3],预测向量应该是[1,0,0],和输入向量维度相同

在这里插入图片描述
将cross_entropy直接替换成mse_loss报错RuntimeError: The size of tensor a (7) must match the size of tensor b (140) at non-singleton dimension 1

cross_entropy 换成 mse_loss 会报错的原因是,这两个损失函数的输入和输出形状要求不同。cross_entropy 是一个分类损失函数,它期望输入是未归一化的logits(形状为 [batch_size, num_classes]),而标签是整数类别(形状为 [batch_size])。mse_loss 是一个回归损失函数,它期望输入和标签的形状相同。

如果你想使用 mse_loss 来替代 cross_entropy,你需要对标签进行one-hot编码,使它们与模型的输出形状匹配。下面是如何修改代码以使用 mse_loss 的示例:

修改代码以使用 mse_loss

  1. 加载必要的库
    你需要一个工具来将标签转换为one-hot编码。这里我们使用 torch.nn.functional.one_hot

  2. 修改训练函数
    在训练函数中,将标签转换为one-hot编码,然后计算 mse_loss

核心测试代码讲解

out=model(data)模型输出形状为torch.Size([140, 7])
data.y中测试数据输出形状为torch.Size([140]),打印第一个数据为3,7个类别中的第4个类别
将3转化为7位置独热码计算MSE,对应train_labels_one_hot第一个数据[0., 0., 0., 1., 0., 0., 0.]为4
out形状为torch.Size([140, 7]),train_labels_one_hot的形状为[140, 7]

torch.Size([140, 7]) torch.Size([140])
tensor([-0.0166,  0.0191, -0.0036, -0.0053, -0.0160,  0.0071, -0.0042],device='cuda:0', grad_fn=<SelectBackward0>) tensor(3, device='cuda:0')
tensor([[0., 0., 0., 1., 0., 0., 0.],...[0., 1., 0., 0., 0., 0., 0.]], device='cuda:0')
train_labels_one_hot shape torch.Size([140, 7])
test out torch.Size([2708, 7])
train_labels_one_hot = F.one_hot(data.y[data.train_mask], num_classes=dataset.num_classes).float()
print(out[data.train_mask].shape, data.y[data.train_mask].shape)
print(out[data.train_mask][0], data.y[data.train_mask][0])
print(train_labels_one_hot)
print(f"train_labels_one_hot shape {train_labels_one_hot.shape}")
loss = F.mse_loss(out[data.train_mask], train_labels_one_hot)

解释

  1. 加载库:我们使用 torch.nn.functional.one_hot 将标签转换为one-hot编码。
  2. 修改训练函数
    • 将标签 train_labels 转换为one-hot编码,train_labels_one_hot = F.one_hot(train_labels, num_classes=dataset.num_classes).float()
    • 使用 mse_loss 计算均方误差损失 loss = F.mse_loss(train_out, train_labels_one_hot)
  3. 保持评估函数不变:评估函数仍然使用 argmax 提取预测类别,并计算准确性。

魔改完整代码

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora',  transform=NormalizeFeatures())
data = dataset[0]# 定义GCN模型
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(dataset.num_node_features, 16)self.conv2 = GCNConv(16, dataset.num_classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return x# return F.log_softmax(x, dim=1)# 初始化模型和优化器
model = GCN()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
data = data.to('cuda')
model = model.to('cuda')# 打印归一化后的特征
print(data.x[0])print(f"data.train_mask{data.train_mask}")# 训练模型
def train():model.train()optimizer.zero_grad()out = model(data)# print(f"out[data.train_mask] {data.train_mask.shape} {out[data.train_mask].shape} {out[data.train_mask]}")# loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])train_labels_one_hot = F.one_hot(data.y[data.train_mask], num_classes=dataset.num_classes).float()print(out[data.train_mask].shape, data.y[data.train_mask].shape)print(out[data.train_mask][0], data.y[data.train_mask][0])print(train_labels_one_hot)print(f"train_labels_one_hot shape {train_labels_one_hot.shape}")loss = F.mse_loss(out[data.train_mask], train_labels_one_hot)loss.backward()optimizer.step()return loss.item()# 评估模型
def test():model.eval()out = model(data)print(f"test out {out.shape}")print(f"test out[0] {out[0].shape} {out[0]}")print(f"test out[0:1,:] {out[0:1,:].shape} {out[0:1,:]}")print(f"test out[0:1,:].argmax(dim=1) {out[0:1,:].argmax(dim=1)}")pred = out.argmax(dim=1)print(f"test pred {pred[data.test_mask].shape} {pred[data.test_mask]}")print(f"data {data.y[data.test_mask].shape} {data.y[data.test_mask]}")correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()acc = int(correct) / int(data.test_mask.sum())return accfor epoch in range(1):loss = train()acc = test()print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

原始代码

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures# 加载Cora数据集,并应用NormalizeFeatures变换
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]# 计算训练、验证和测试集的大小
num_train = data.train_mask.sum().item()
num_val = data.val_mask.sum().item()
num_test = data.test_mask.sum().item()print(f'Number of training nodes: {num_train}')
print(f'Number of validation nodes: {num_val}')
print(f'Number of test nodes: {num_test}')# 定义GCN模型
class GCN(torch.nn.Module):def __init__(self):super(GCN, self).__init__()self.conv1 = GCNConv(dataset.num_node_features, 16)self.conv2 = GCNConv(16, dataset.num_classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return x  # 返回未归一化的logits# 初始化模型和优化器
model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
data = data.to('cuda')
model = model.to('cuda')# 训练模型
def train():model.train()optimizer.zero_grad()out = model(data)  # out 的形状是 [num_nodes, num_classes]train_out = out[data.train_mask]  # 选择训练集节点的输出train_labels = data.y[data.train_mask]  # 选择训练集节点的标签# 将标签转换为one-hot编码train_labels_one_hot = F.one_hot(train_labels, num_classes=dataset.num_classes).float()# 计算均方误差损失loss = F.mse_loss(train_out, train_labels_one_hot)loss.backward()optimizer.step()return loss.item()# 评估模型
def test():model.eval()out = model(data)pred = out.argmax(dim=1)  # 提取预测类别correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()acc = int(correct) / int(data.test_mask.sum())return accfor epoch in range(200):loss = train()acc = test()print(f'Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

通过这些修改,你可以将交叉熵损失函数替换为均方误差损失函数,并确保输入和标签的形状匹配,从而避免报错。

  • 简单版本的的答案

Cross Entropy vs. MSE Loss

  1. Cross Entropy Loss:

    • 输入:模型的logits,形状为 ([N, C]),其中 (N) 是批次大小,(C) 是类别数量。
    • 目标:目标类别的索引,形状为 ([N])。
  2. MSE Loss:

    • 输入:模型的预测值,形状为 ([N, C])。
    • 目标:实际值,形状为 ([N, C])(通常是 one-hot 编码)。

要将 cross_entropy 换成 mse_loss,需要确保输入和目标的形状匹配。具体来说,你需要将目标类别索引转换为 one-hot 编码。

示例代码

假设你有一个分类任务,其中模型输出的是 logits,目标是类别索引。我们将这个设置转换为使用 MSE Loss。

import torch
import torch.nn.functional as F# 假设有一个批次的模型输出和目标标签
logits = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], requires_grad=True)  # 模型输出
target = torch.tensor([0, 2])  # 目标类别# 使用 cross_entropy
cross_entropy_loss = F.cross_entropy(logits, target)
print("Cross-Entropy Loss:")
print(cross_entropy_loss)# 转换目标类别为 one-hot 编码
target_one_hot = F.one_hot(target, num_classes=logits.size(1)).float()
print("One-Hot Encoded Targets:")
print(target_one_hot)# 计算 MSE Loss
mse_loss = F.mse_loss(F.softmax(logits, dim=1), target_one_hot)
print("MSE Loss:")
print(mse_loss)

输出

Cross-Entropy Loss:
tensor(1.4076, grad_fn=<NllLossBackward>)
One-Hot Encoded Targets:
tensor([[1., 0., 0.],[0., 0., 1.]])
MSE Loss:
tensor(0.2181, grad_fn=<MseLossBackward>)

解释

  1. logits: 模型的原始输出,形状为 ([N, C])。
  2. target: 原始目标类别索引,形状为 ([N])。
  3. target_one_hot: 将目标类别索引转换为 one-hot 编码,形状为 ([N, C])。
  4. F.mse_loss: 使用 F.softmax(logits, dim=1) 计算模型的概率分布,然后与 target_one_hot 计算 MSE 损失。

通过将目标类别转换为 one-hot 编码并确保输入和目标的形状匹配,可以成功地将 cross_entropy 换成 mse_loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/39555.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Dify入门指南

一.Dify介绍 生成式 AI 应用创新引擎&#xff0c;开源的 LLM 应用开发平台。提供从 Agent 构建到 AI workflow 编排、RAG 检索、模型管理等能力&#xff0c;轻松构建和运营生成式 AI 原生应用&#xff0c;比 LangChain 更易用。一个平台&#xff0c;接入全球大型语言模型。不同…

德璞资本:桥水公司如何利用AI实现投资决策的精准提升?

摘要&#xff1a; 在金融科技的浪潮中&#xff0c;桥水公司推出了一只依靠机器学习决策的创新基金&#xff0c;吸引了大量投资者的关注。本文将深入探讨该基金的背景、AI技术的应用、对桥水公司转型的影响&#xff0c;以及未来发展的前景。 新基金背景&#xff1a;桥水公司的创…

2024年7月2日 (周二) 叶子游戏新闻

老板键工具来唤去: 它可以为常用程序自定义快捷键&#xff0c;实现一键唤起、一键隐藏的 Windows 工具&#xff0c;并且支持窗口动态绑定快捷键&#xff08;无需设置自动实现&#xff09;。 卸载工具 HiBitUninstaller: Windows上的软件卸载工具 经典名作30周年新篇《恐怖惊魂夜…

MyBatis入门案例

实施前的准备工作&#xff1a; 1.准备数据库表2.创建一个新的springboot工程&#xff0c;选择引入对应的起步依赖&#xff08;mybatis、mysql驱动、lombok&#xff09;3.在application.properties文件中引入数据库连接信息4.创建对应的实体类Emp&#xff08;实体类属性采用驼峰…

补浏览器环境

一&#xff0c;导言 // global是node中的关键字&#xff08;全局变量&#xff09;&#xff0c;在node中调用其中的元素时&#xff0c;可以直接引用&#xff0c;不用加global前缀&#xff0c;和浏览器中的window类似&#xff1b;在浏览器中可能会使用window前缀&#xff1a;win…

校园水质信息化监管系统——水质监管物联网系统

随着物联网技术的发展越来越成熟&#xff0c;它不断地与人们的日常生活和工作深入融合&#xff0c;推动着社会的进步。其中物联网系统集成在高校实践课程中可以应用到许多项目&#xff0c;如环境气象检测、花卉种植信息化监管、水质信息化监管、校园设施物联网信息化改造、停车…

springboot项目jar包修改数据库配置运行时异常

一、背景 我将软件成功打好jar包了&#xff0c;到部署的时候发现jar包中数据库配置写的有问题&#xff0c;不想再重新打包了&#xff0c;打算直接修改配置文件&#xff0c;结果修改配置后&#xff0c;再通过java -jar运行时就报错了。 二、问题描述 本地项目是springBoot项目…

【计算机图形学 | 基于MFC三维图形开发】期末考试知识点汇总(上)

文章目录 视频教程第一章 计算机图形学概述计算机图形学的定义计算机图形学的应用计算机图形学 vs 图像处理 vs模式识别图形显示器的发展及工作原理理解三维渲染管线 第二章 基本图元的扫描转换扫描转换直线的扫描转换DDA算法Bresenham算法中点画线算法圆的扫描转换中点画圆算法…

【Godot4.2】Godot中的贝塞尔曲线

概述 通过指定平面上的多个点&#xff0c;然后顺次连接&#xff0c;我们可以得到折线段&#xff0c;如果闭合图形&#xff0c;就可以获得多边形。通过向量旋转我们可以获得圆等特殊图形。 但是对于任意曲线&#xff0c;我们无法使用简单的方式来获取其顶点&#xff0c;好在计…

mac上使用finder时候,显示隐藏的文件或者文件夹

默认在finder中是不显示隐藏的文件和文件夹的&#xff0c;但是想创建.gitignore文件&#xff0c;并向里面写入内容&#xff0c;即便是打开xcode也是不显示这几个隐藏文件的&#xff0c;那有什么办法呢&#xff1f; 使用快捷键&#xff1a; 使用finder打开包含隐藏文件的文件夹…

Linux如何安装openjdk1.8

文章目录 Centosyum安装jdk和JRE配置全局环境变量验证ubuntu使用APT(适用于Ubuntu 16.04及以上版本)使用PPA(可选,适用于需要特定版本或旧版Ubuntu)Centos yum安装jdk和JRE yum install java-1.8.0-openjdk-devel.x86_64 安装后的目录 配置全局环境变量 vim /etc/pr…

ISP IC/FPGA设计-第一部分-SC130GS摄像头分析-IIC通信(1)

1.摄像头模组 SC130GS通过一个引脚&#xff08;SPI_I2C_MODE&#xff09;选择使用IIC或SPI配置接口&#xff0c;通过查看摄像头模组的原理图&#xff0c;可知是使用IIC接口&#xff1b; 通过手册可知IIC设备地址通过一个引脚控制&#xff0c;查看摄像头模组的原理图&#xff…

中日区块链“大比拼”!中国蚂蚁加大区块链押注资本!日本索尼进军加密货币市场!

科技巨头在区块链和加密货币领域的动作越来越频繁。近期&#xff0c;中国金融科技巨头蚂蚁集团进一步加大了在区块链业务上的投资&#xff0c;而日本电子科技巨头索尼集团则正式进军加密货币交易领域。这些举措反映了两国对于区块链和加密资产领域的不同态度和布局。 蚂蚁集团加…

disql使用

进入bin目录&#xff1a;cd /opt/dmdbms/bin 启动disql&#xff1a;./disql&#xff0c;然后输入用户名、密码 sh文件直接使用disql&#xff1a; 临时添加路径到PATH环境变量&#xff1a;在当前会话中临时使用disql命令而无需每次都写完整路径&#xff0c;可以在执行脚本之前…

973. 最接近原点的 K 个点-k数组维护+二分查找

973. 最接近原点的 K 个点-k数组维护二分查找 给定一个数组 points &#xff0c;其中 points[i] [xi, yi] 表示 X-Y 平面上的一个点&#xff0c;并且是一个整数 k &#xff0c;返回离原点 (0,0) 最近的 k 个点。 这里&#xff0c;平面上两点之间的距离是 欧几里德距离&#…

初学嵌入式是弄linux还是单片机?

在开始前刚好我有一些资料&#xff0c;是我根据网友给的问题精心整理了一份「单片机的资料从专业入门到高级教程」&#xff0c; 点个关注在评论区回复“666”之后私信回复“666”&#xff0c;全部无偿共享给大家&#xff01;&#xff01;&#xff01;1、先入门了51先学了89c52…

leetcode每日一练:链表OJ题

链表经典算法OJ题 1.1 移除链表元素 题目要求&#xff1a; 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,6,3,4,5,6], val 6 输出&a…

模电-二极管及其应用51单片机LED点亮前置工作!

今日小记 2024-7-2&#xff0c;星期二&#xff0c;16:32&#xff0c;天气&#xff1a;晴&#xff0c;心情&#xff1a;晴。持续了两个星期的梅雨天终于暂时过去啦&#xff0c;迎来了久违的阳光&#xff0c;虽然没有雨天凉快&#xff0c;但是能看到太阳也是开心哒&#xff0c;心…

2021强网杯

一、环境 网上自己找 二、步骤 2.1抛出引题 在这个代码中我们反序列&#xff0c;再序列化 <?php$raw O:1:"A":1:{s:1:"a";s:1:"b";};echo serialize(unserialize($raw));//O:1:"A":1:{s:1:"a";s:1:"b";…

工业 web4.0UI 风格品质卓越

工业 web4.0UI 风格品质卓越