softmax回实战

1.数据集

MNIST数据集 (LeCun et al., 1998) 是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。 我们将使用类似但更复杂的Fashion-MNIST数据集 (Xiao et al., 2017)。

import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()

读取数据集

我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。 通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

整合所有组件

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

我们现在已经准备好使用Fashion-MNIST数据集,便于下面的章节调用来评估各种分类算法。

  • 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。

2.初始化模型参数

在softmax回归中,我们的输出与类别一样多。 因为我们的数据集有10个类别,所以网络输出维度为10。 因此,权重将构成一个784×10的矩阵, 偏置将构成一个1×10的行向量。

num_inputs = 784
num_outputs = 10W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

3.定义softmax操作

def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True)return X_exp / partition  # 这里应用了广播机制

正如上述代码,对于任何随机输入,我们将每个元素变成一个非负数。 此外,依据概率原理,每行总和为1。

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
X_prob, X_prob.sum(1)(tensor([[0.1686, 0.4055, 0.0849, 0.1064, 0.2347],[0.0217, 0.2652, 0.6354, 0.0457, 0.0321]]),tensor([1.0000, 1.0000]))

4.定义模型

定义softmax操作后,我们可以实现softmax回归模型。 下面的代码定义了输入如何通过网络映射到输出。 注意,将数据传递到模型之前,我们使用reshape函数将每张原始图像展平为向量。

def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

5.定义损失函数

接下来,我们引入交叉熵损失函数。 这可能是深度学习中最常见的损失函数,因为目前分类问题的数量远远超过回归问题的数量。

def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])

小批量随机梯度下降来优化模型的损失函数,设置学习率为0.1

lr = 0.1def updater(batch_size):return d2l.sgd([W, b], lr, batch_size)

6.分类精度

为了计算精度,我们执行以下操作。 首先,如果y_hat是矩阵,那么假定第二个维度存储每个类的预测分数。 我们使用argmax获得每行中最大元素的索引来获得预测类别。 然后我们将预测类别与真实y元素进行比较。 由于等式运算符“==”对数据类型很敏感, 因此我们将y_hat的数据类型转换为与y的数据类型一致。 结果是一个包含0(错)和1(对)的张量。 最后,我们求和会得到正确预测的数量。

def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())

7.训练

在这里,我们重构训练过程的实现以使其可重复使用。 首先,我们定义一个函数来训练一个迭代周期。 请注意,updater是更新模型参数的常用函数,它接受批量大小作为参数。

def train_epoch_ch3(net, train_iter, loss, updater):  #@save"""训练模型一个迭代周期(定义见第3章)"""# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()# 训练损失总和、训练准确度总和、样本数metric = Accumulator(3)for X, y in train_iter:# 计算梯度并更新参数y_hat = net(X)l = loss(y_hat, y)if isinstance(updater, torch.optim.Optimizer):# 使用PyTorch内置的优化器和损失函数updater.zero_grad()l.mean().backward()updater.step()else:# 使用定制的优化器和损失函数l.sum().backward()updater(X.shape[0])metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())# 返回训练损失和训练精度return metric[0] / metric[2], metric[1] / metric[2]

在展示训练函数的实现之前,我们定义一个在动画中绘制数据的实用程序类Animator, 它能够简化本书其余部分的代码。

class Animator:  #@save"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# 向图表中添加多个数据点if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)

接下来我们实现一个训练函数, 它会在train_iter访问到的训练数据集上训练一个模型net。 该训练函数将会运行多个迭代周期(由num_epochs指定)。 在每个迭代周期结束时,利用test_iter访问到的测试数据集对模型进行评估。 我们将利用Animator类来可视化训练进度。

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save"""训练模型(定义见第3章)"""animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, train_metrics + (test_acc,))train_loss, train_acc = train_metricsassert train_loss < 0.5, train_lossassert train_acc <= 1 and train_acc > 0.7, train_accassert test_acc <= 1 and test_acc > 0.7, test_acc

训练:

num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

8.预测

现在训练已经完成,我们的模型已经准备好对图像进行分类预测。 给定一系列图像,我们将比较它们的实际标签(文本输出的第一行)和模型预测(文本输出的第二行)。

def predict_ch3(net, test_iter, n=6):  #@save"""预测标签(定义见第3章)"""for X, y in test_iter:breaktrues = d2l.get_fashion_mnist_labels(y)preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))titles = [true +'\n' + pred for true, pred in zip(trues, preds)]d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])predict_ch3(net, test_iter)

9.总结:

  • 借助softmax回归,我们可以训练多分类的模型。
  • 训练softmax回归循环模型与训练线性回归模型非常相似:先读取数据,再定义模型和损失函数,然后使用优化算法训练模型。大多数常见的深度学习模型都有类似的训练过程。

10.简洁实现:

10.1读取数据

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

10.2定义模型,初始化参数

跟线性规划模型是一样的,只不过有十个输出

# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);

10.3定义损失函数

loss = nn.CrossEntropyLoss(reduction='none')

10.4定义优化器

我们使用学习率为0.1的小批量随机梯度下降作为优化算法

updater = torch.optim.SGD(net.parameters(), lr=0.1)

10.5训练

调用上面写好的函数,只不过这次的模型,损失函数,优化器都是pytorch内置的

num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

10.6预测:

predict_ch3(net, test_iter)

10.7小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。
  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。
    数,优化器**都是pytorch内置的
num_epochs = 10
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

10.6预测:

predict_ch3(net, test_iter)

10.7小结

  • 使用深度学习框架的高级API,我们可以更简洁地实现softmax回归。
  • 从计算的角度来看,实现softmax回归比较复杂。在许多情况下,深度学习框架在这些著名的技巧之外采取了额外的预防措施,来确保数值的稳定性。这使我们避免了在实践中从零开始编写模型时可能遇到的陷阱。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/636624.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM篇--垃圾回收高频面试题

JVM垃圾回收 1 简单说下Java垃圾回收机制&#xff1f; 首先在java运行过程中&#xff0c;其实程序员并不需要去显示的调用程序来释放对象的内存&#xff0c;而是由虚拟机来完成的&#xff0c;具体来看是在jvm中有个垃圾回收线程&#xff0c;这个线程是个守护线程&#xff0c;…

Leetcode 用队列实现栈

题目&#xff1a; 请你仅使用两个队列实现一个后入先出&#xff08;LIFO&#xff09;的栈&#xff0c;并支持普通栈的全部四种操作&#xff08;push、top、pop 和 empty&#xff09;。 实现 MyStack 类&#xff1a; void push(int x) 将元素 x 压入栈顶。 int pop() 移除并…

【c++笔记】用c++解决一系列质数问题!

质数是c语言和c中比较常见的数学问题&#xff0c;本篇文章将带你走进有关质数的一系列基础问题&#xff0c;其中包含常见的思路总结&#xff0c;本篇文章过后&#xff0c;将会持续更新c算法系列&#xff0c;感兴趣的话麻烦点个关注吧&#xff01; 希望能给您带来帮助&#xff…

零基础学Python(2)— 安装Python开发工具之PyCharm

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。PyCharm是由JetBrains公司开发的一款Python开发工具。在Windows、Mac OS和Linux操作系统中都可以使用。它具有语法高亮显示、Project&#xff08;项目&#xff09;管理代码跳转、智能提示、自动完成、调试、单元测试和版本…

C++——函数的常见样式

常见的函数样式有4种&#xff0c;即在函数定义过程中函数的四种格式&#xff0c;他们也分别对应了四种调用方法&#xff1a; 1&#xff0c;无参无返 2&#xff0c;有参无返 3&#xff0c;无参有返 4&#xff0c;有参有返 示例&#xff1a; #include<bits/stdc.h> u…

x-cmd pkg | yt-dlp - 专注于 YouTube 的下载工具

目录 简介首次用户功能特点竞品和相关作品进一步探索 简介 yt-dlp 是一款强大的命令行下载工具&#xff0c;专注于下载 YouTube 视频和音频。它是 youtube-dl 的一个改进和拓展版本&#xff0c;提供了更多功能和修复了一些问题。 yt-dlp 具有灵活的支持&#xff0c;可下载 Yo…

武汉灰京文化:手游市场游戏体验和社交互动的新趋势

随着移动设备的普及和技术的不断提高&#xff0c;手游市场正在迎来全新的发展时期。用户的游戏习惯正在发生重大变化&#xff0c;他们越来越倾向于随时随地玩游戏。手游的便携性使得用户可以在公交车上、休息时间或等待朋友时轻松进行游戏。这种随时随地的游戏体验满足了现代生…

【成本价特惠】招募证书代理:工信部、PMP、阿里云、华为等认证,机会难得!

扫码和我联系 亲爱的读者朋友们&#xff0c; 今天&#xff0c;我想和大家分享一个难得的机会。我们目前正在积极招募各类证书的代理&#xff0c;包括工信部的证书、PMP&#xff08;项目管理专业人士&#xff09;证书、阿里云证书、华为证书、OCP 证书、CFA 证书等。这些证书在…

最大流—EK算法,流网络,残留网络,定理证明,详细代码

文章目录 零、卡车运输一、流网络1.1流网络1.2流1.3最大流1.4残留网络1.5增广路径1.6流网络的割1.7最大流最小割定理1.7.1证明 1.8Ford-Fulkerson方法 二、Edmonds-Karp算法2.1定义2.2EK算法的实现2.3EK算法详细代码2.4OJ练习 零、卡车运输 Lucky Puck公司有冰球工厂Vancouver…

Unity导出Android项目踩坑记录

导出的时候需要注意以下地方的配置&#xff1a; 1、buildSetting-> 设置ExportProject 2、buildsetting ->playerSetting ->设置IL2CPP 3、设置ndk edit->preferences->external tools->ndk 如果unity的ndk版本和android项目里的ndk版本不一致会报错&…

【Qt开发】初识Qt

文章目录 1. Qt的背景1.1 Qt是什么1.2 Qt的发展史1.3 Qt支持的平台 2. Qt开发环境的搭建2.1 Qt SDK下载2.2 Qt SDK的安装 3. 一个简单的Qt模板程序的创建4. Qt模板程序的代码讲解4.1 main.cpp4.2 widget.h4.3 widget.cpp4.4 widget.ui4.5 test_1_18.pro4.6 一些中间文件 5. Qt在…

keil软件仿真

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 前言 例如&#xff1a;随着人工智能的不断发展&#xff0c;机器学习这门技术也越来越重要…

vue基于Spring Boot框架的甘肃敦煌文化旅游管理系统

本敦煌文化旅游管理系统是为了提高用户查阅信息的效率和管理人员管理信息的工作效率&#xff0c;可以快速存储大量数据&#xff0c;还有信息检索功能&#xff0c;这大大的满足了用户和管理员这两者的需求。操作简单易懂&#xff0c;合理分析各个模块的功能&#xff0c;尽可能优…

(蓝桥杯每日一题)平方末尾及补充(常用的字符串函数功能)

能够表示为某个整数的平方的数字称为“平方数 虽然无法立即说出某个数是平方数&#xff0c;但经常可以断定某个数不是平方数。因为平方数的末位只可能是:0,1,4,5,6,9 这 6 个数字中的某个。所以&#xff0c;4325435332 必然不是平方数。 如果给你一个 2 位或 2 位以上的数字&am…

六、标准对话框、多应用窗体

一、标准对话框 Qt提供了一些常用的标准对话框&#xff0c;如打开文件对话框、选择颜色对话框、信息提示和确认选择对话框、标准输入对话框等。1、预定义标准对话框 &#xff08;1&#xff09;QFileDialog 文件对话框 QString getOpenFileName() 打开一个文件QstringList ge…

You need to add dependency of ‘poi-ooxml‘ to your project, and version >= 4.1.2

原因 由于在依赖中引用了多个版本的 hutool,导致在最终打包时使用的版本不是由在开发时所引用的版本 cn.hutool.core.exceptions.DependencyException: You need to add dependency of poi-ooxml to your project, and version > 4.1.2at cn.hutool.poi.excel.ExcelUtil.get…

MyBatis-Plus 日常操作

本文主要介绍 mybatis-plus 日常操作。 一、快速开始 本文基于 springboot、maven、jdk1.8、mysql 环境。 新建如下数据库&#xff1a; 建议大家选择 utf8mb4 这种字符集&#xff0c;做过微信的同学应该会知道&#xff0c;微信用户名称的表情&#xff0c;是需要这种字符集才…

基于python旅游推荐系统 协同过滤算法 爬虫 Echarts可视化 Django框架(源码)✅

毕业设计&#xff1a;2023-2024年计算机专业毕业设计选题汇总&#xff08;建议收藏&#xff09; 毕业设计&#xff1a;2023-2024年最新最全计算机专业毕设选题推荐汇总 &#x1f345;感兴趣的可以先收藏起来&#xff0c;点赞、关注不迷路&#xff0c;大家在毕设选题&#xff…

Android14之DefaultKeyedVector实现(一百八十二)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

计算机组成原理 第一弹

ps&#xff1a;本文章的图片来源都是来自于湖科大教书匠高老师的视频&#xff0c;声明&#xff1a;仅供自己复习&#xff0c;里面加上了自己的理解 这里附上视频链接地址&#xff1a;1-2 计算机的发展_哔哩哔哩_bilibili ​​ 目录 &#x1f680;计算机系统 &#x1f680;计…