【PyTorch】softmax回归

文章目录

  • 1. 模型与代码实现
    • 1.1. 模型
    • 1.2. 代码实现
  • 2. Q&A

1. 模型与代码实现

1.1. 模型

  • 背景
    在分类问题中,模型的输出层是全连接层,每个类别对应一个输出。我们希望模型的输出 y ^ j \hat{y}_j y^j可以视为属于类 j j j的概率,然后选择具有最大输出值的类别作为我们的预测。
    softmax函数能够将未规范化的输出变换为非负数并且总和为1,同时让模型保持可导的性质,而且不会改变未规范化的输出之间的大小次序
  • softmax函数
    y ^ = s o f t m a x ( o ) \mathbf{\hat{y}}=\mathrm{softmax}(\mathbf{o}) y^=softmax(o)其中 y ^ j = e x p ( o j ) ∑ k e x p ( o k ) \hat{y}_j=\frac{\mathrm{exp}({o_j})}{\sum_{k}\mathrm{exp}({o_k})} y^j=kexp(ok)exp(oj)
  • softmax是一个非线性函数,但softmax回归的输出仍然由输入特征的仿射变换决定,因此,softmax回归是一个线性模型
  • 为了避免将softmax的输出直接送入交叉熵损失造成的数值稳定性问题,将softmax和交叉熵损失结合在一起,具体做法是:不将softmax概率传递到损失函数中, 而是在交叉熵损失函数中传递未规范化的输出,并同时计算softmax及其对数。因此定义交叉熵损失函数时也进行了softmax运算

1.2. 代码实现

import torch
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from tensorboardX import SummaryWriter# 全局参数设置
batch_size = 256
num_workers = 0
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')writer = SummaryWriter()# 加载数据集
root = "./dataset"
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = FashionMNIST(root=root, train=True, transform=transform, download=True
)
mnist_test = FashionMNIST(root=root, train=False, transform=transform, download=True
)
dataloader_train = DataLoader(mnist_train, batch_size, shuffle=True, num_workers=num_workers
)
dataloader_test = DataLoader(mnist_test, batch_size, shuffle=False,num_workers=num_workers
)# 定义神经网络
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)).to(device)# 初始化模型参数
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)nn.init.constant_(m.bias, val=0)
net.apply(init_weights)# 定义损失函数
criterion = nn.CrossEntropyLoss(reduction='none')# 定义优化器
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)class Accumulator:"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]def accuracy(y_hat, y):"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())for epoch in range(num_epochs):# 训练模型net.train()train_metrics = Accumulator(3)  # 训练损失总和、训练准确度总和、样本数for X, y in dataloader_train:X, y = X.to(device), y.to(device)y_hat = net(X)loss = criterion(y_hat, y)optimizer.zero_grad()loss.mean().backward()optimizer.step()train_metrics.add(float(loss.sum()), accuracy(y_hat, y), y.numel())train_loss = train_metrics[0] / train_metrics[2]train_acc = train_metrics[1] / train_metrics[2]# 测试模型net.eval()with torch.no_grad():    test_metrics = Accumulator(2)   # 测试准确度总和、样本数for X, y in dataloader_test:X, y = X.to(device), y.to(device)y_hat = net(X)loss = criterion(y_hat, y)test_metrics.add(accuracy(y_hat, y), y.numel())test_acc = test_metrics[0] / test_metrics[1]writer.add_scalars("metrics", {'train_loss': train_loss, 'train_acc': train_acc, 'test_acc': test_acc}, epoch)
writer.close()   

输出结果:
tensorboard

2. Q&A

  • 运行过程中出现以下警告:

UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at …\torch\csrc\utils\tensor_numpy.cpp:180.)
return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

该警告的大致意思是给定的NumPy数组不可写,并且PyTorch不支持不可写的张量。这意味着你可以使用张量写入底层(假定不可写)NumPy数组。在将数组转换为张量之前,可能需要复制数组以保护其数据或使其可写。在本程序的其余部分,此类警告将被抑制。因此需要修改C:\Users\%UserName%\anaconda3\envs\%conda_env_name%\lib\site-packages\torchvision\datasets\mnist.py的第498行,将return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)中的False改成True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/195498.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ArcGIS提示当前许可不支持影像服务器

1、问题&#xff1a; 在用ArcGIS上处理影像栅格数据时&#xff08;比如栅格数据集裁剪、镶嵌数据集构建镶嵌线等&#xff09;经常会出现。 无法启动配置 RasterComander.ImageServer <详信息 在计算机XXXXX上创建服务器对象实例失败 当前许可不支持影像服务器。 ArcGIS提示当…

Python的模块与库,及if __name__ == ‘__main__语句【侯小啾Python基础领航计划 系列(二十四)】

Python的模块与库,及if name == ‘__main__语句【侯小啾Python基础领航计划 系列(二十四)】 大家好,我是博主侯小啾, 🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔…

MYSQL练题笔记-聚合函数-各赛事的用户注册率

一、题目相关内容 1&#xff09;相关的表 2&#xff09;题目 3&#xff09;帮助理解题目的示例&#xff0c;提供返回结果的格式 二、自己初步的理解 有两张不同左右的表&#xff0c;用户表和赛事注册表。然后解题。 1.各种赛事的用户注册百分率 各种赛事的意味着通过contes…

synchronized底层原理(一)

文章目录 1. 问题引入2. 相关概念3. Synchronized使用4. Synchronized底层原理1. 简介2. Monitor&#xff08;管程/监视器&#xff09;3. Java语言的内置管程synchronized4. Java对象的内存布局5. 如何使用MarkWord记录锁状态6. 偏向锁7. 轻量级锁 1. 问题引入 假设我们有1000…

Spring cloud - gateway

什么是Spring Cloud Gateway 先去看下官网的解释&#xff1a; This project provides an API Gateway built on top of the Spring Ecosystem, including: Spring 6, Spring Boot 3 and Project Reactor. Spring Cloud Gateway aims to provide a simple, yet effective way t…

Git:分布式版本控制系统的崛起与演变

简介 Git是一个开源的分布式版本控制系统&#xff0c;旨在有效、高速地处理从很小到非常大的项目版本管理。它是由Linus Torvalds于2005年创建的&#xff0c;最初是为了服务于Linux内核开发的版本控制需求。Git通过强大的分支功能、高效的缓存机制以及可扩展的架构设计&#xf…

Golang 并发 — 流水线

并发模式 我们可以将流水线理解为一组由通道连接并由 goroutine 处理的阶段。每个阶段都被定义为执行特定的任务&#xff0c;并按顺序执行&#xff0c;下一个阶段在前一个阶段完成后开始执行。 流水线的另一个重要特性是&#xff0c;除了连接在一起&#xff0c;每个阶段都使用…

统信UOS_麒麟KYLINOS配置apt及git内网代理

原文链接&#xff1a;统信UOS/麒麟KYLINOS上配置APT和GIT内网代理 **hello&#xff0c;大家好啊&#xff01;**在企业环境中&#xff0c;出于安全和管理的考虑&#xff0c;很多公司会设置内网代理服务器&#xff0c;以控制和监管内部网络的访问。这就意味着&#xff0c;员工在使…

jsp多站点图书管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 JSP 多站点图书管理系统是一套完善的java web信息管理系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为 TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库为Mysql5…

git常用命令小记

&#xff08;文章正在持续更新中&#xff09; git init - 在当前目录下初始化一个新的 Git 仓库。 git clone [url] - 克隆远程仓库到本地。 git add [file] - 将文件添加到暂存区。 git commit -m "commit message" - 将添加到暂存区的文件提交到本地仓库。 git pus…

STM32 Nucleo-64 boards 外设资源引脚对应关系图

STM32 Nucleo-64 boards 外设资源引脚对应关系图 1. STM32 NUCLEO-F103RB1.1 串口对应关系图1.2 I2C对应关系图 【参考博文】 1. STM32 NUCLEO-F103RB 1.1 串口对应关系图 1.2 I2C对应关系图 注意&#xff1a;STM32 NUCLEO-F103RB 在Arduino 端子分配的 I2C 重映射为 PB8 PB9 …

直击2023云栖大会-大模型时代到来:“计算,为了无法计算的价值”

2023年的云栖大会以“计算&#xff0c;为了无法计算的价值”为主题&#xff0c;强调了计算技术在现代社会中的重要性&#xff0c;特别是在大模型时代到来的背景下。 大模型时代指的是以深度学习为代表的人工智能技术的快速发展&#xff0c;这些技术需要大量的计算资源来训练和优…

深度学习设计基于Tensorflow卷积神经网络猫的品种识别系统

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 基于Tensorflow卷积神经网络的猫的品种识别系统可以用于自动识别猫的品种类型。下面我将为您介绍一下这个系统的基本…

Python函数的基本使用(一)

Python函数的基本使用&#xff08;一&#xff09; 一、函数概述二、函数的定义2.1 函数的语法2.2 语法说明2.3 函数定义的方式2.4 总结 三、函数的调用3.1 函数调用语法3.2 语法说明3.3 函数调用 四、函数的参数4.1 参数的分类4.2 必需参数4.3 默认值参数4.4 关键字参数4.5 不定…

路由策略,gRPC 路由如何实现

目录 一、为啥我们要路由策略&#xff1a; 二、基于gRPC 路由策略 一、为啥我们要路由策略&#xff1a; 我们可以重新回到调用方发起 RPC 调用的流程。在 RPC 发起真实请求的时候&#xff0c;有一个步骤就是从服务提供方节点集合里面选择一个合适的节点&#xff08;就是我们…

保育员个人简历精选7篇

想要在保育员职位的求职过程中脱颖而出吗&#xff0c;参考这7篇精选的保育员简历案例&#xff01;无论您的经验如何&#xff0c;都能找到适合自己的简历样式及参考内容。 保育员个人简历模板下载&#xff08;可在线编辑制作&#xff09;&#xff1a;来幻主简历&#xff0c;做好…

微服务的流量管理-服务网格

对于单体应用来说&#xff0c;一般只有流入和流出两种流量。而微服务架构引入了跨进程的网络通信&#xff0c;流量发生在服务之间。由许多服务组成了复杂的网络拓扑结构&#xff0c;每次请求都会产生流量。 这些流量如果没有妥善的管理&#xff0c;整个应用的行为和状态将会不…

封装Servlet使用自定义注解进行参数接收

文章目录 前言一、前后对比✨二、具体实现&#x1f387;三、效果展示&#x1f38f; 前言 先说项目背景&#xff0c;本项目是本人在校期间老师布置的作业&#xff08;就一个CRUD&#xff09;&#xff0c;课程是后端应用程序设计&#xff0c;其实就是servlet和jsp那一套&#xf…

【c】课程满意度计算

我们不好直接比较二维数组中任意多个元素的值是否相等&#xff0c;我们可以创建一维数组&#xff0c;首先将一维数组的值全部设为0&#xff0c;一维数组的下标代表你喜欢课程的量&#xff0c;一维数组的各个元素的值代表你喜欢的次数 例如 你输入3 5&#xff0c;代表你喜欢第三…

好用的挂耳式蓝牙耳机有哪些?分享几款热门好用的蓝牙耳机

挂耳式蓝牙耳机已经成为我们日常生活中的一部分&#xff0c;无论是在通勤、运动还是日常休闲时&#xff0c;它们都发挥着不可替代的作用&#xff0c;随着技术的不断进步&#xff0c;挂耳式蓝牙耳机的音质、连接稳定性以及续航时间都有了显著的提升&#xff0c;下面&#xff0c;…