[deeplearning]pytorch实现softmax多分类问题预测训练

写在前面:俺这两天也是刚刚加入实验室,因为之前的学习过程中用到更多的框架是tensorflow,所以突然上手pytorch多少有些力不从心了。

这两个框架的主要区别在与tensorflow更偏向于工业使用,所以里面的很多函数和类都已经封装得很完整了,直接调用,甚至连w,b等尺寸都会自动调整。但是pytorch更加偏向于学术,。。。。或者说更加偏向于数学,很多功能都需要我们自己手动去实现:

刚刚跟这d2l的课程学习了如何去实现最基本的神经网络和计算,这里使用当时学过的solfmax作为经典案例,作为一个简单的补充,我会在这里面简单讲解一下softmax是怎么实现的,以及一些库函数

纯手动实现:

其实是有一些更高级别的api可以调用,比如损失函数就不用我们自己手写,但是训练的过程还是要的。

1.获取一些数据,这里我们通过一个特殊数据集合来或去数据

#先凑成一个数据集合
batch_size = 256
#这里好像就上面那么恶心了,直接从这个数据集合中获取数据
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

这里注意一个问题,batch_size不是你获取到的全部数据,而是你确定每一批数据的大小

接下来根据这个大小,获取多批数据,然后保存为训练集合以及测试集合

(由于我们这里要的事情非常简单,所以我们不验证)

2.我们开始创建一层神经元,输出为10个分量的估计数值

#初始化参数
num_inputs = 784      #输入,也就是特征值的数目为784
num_outputs = 10      #输出也就是softmax层神经元的数目,10#这段代码用于构建某一层的w和b,并且先将其初始化
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros (num_outputs, requires_grad=True)

这里w和b是仅仅是一对数字,而是一个完整的对象,除了基本的数值以外,还能存储一些注入诸如梯度等等信息。代表了这一层神经元的具体情况。

这个layer构建出来的神经元其实就是10个神经元,每个神经元支持的输入为784个特征。

3.创建solftmax函数,这个函数内部将会对神经网络的输出作出一些处理

#创建一个softmax函数,用来完成最后的softmax操作
#X在这里应该是一个10个分量的tensor,下面的函数就是正常的softmax操作
def softmax(X):X_exp = torch.exp(X)partition = X_exp.sum(1, keepdim=True) #沿着列展开的方向求和return X_exp / partition               #这里应用了广播机制

我们先进行指数化,然后求和,最后使用广播技术(其实这个所谓的广播也算是线性代数计算时候的基本特征了)得到一个(归一化)的tensor(所有分量相加为1,符合我们先是生活中对事物的预测逻辑,比如:连衣裙可能性0.55,鞋子可能性0.25,帽子可能性0.20)

4.然后是定义最核心的预测函数,称之为网络本身到也可以

#定义一个神经网络
#其实说是神经网络,这里只是进行了一个简单的数据变换,然后计算wx+b
#最后计算出来的结果因为是matmul的矩阵乘法,而且w和b本身也是size=10 的 tensor
#所以计算结果也是一样大小的tensor,然后就可以放心进行softmax操作
def net(X):return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

其实这个就是对于十个神经元,然后进行计算操作,得到估计数值

其实直接返回torch.matmul(X.reshape((-1, W.shape[0])), W) + b的话就变化成一个很常见的10线性回归了,在这里可以很清楚的看到softmax实现的是一个激活函数的作用

5.定义损失函数loss function


def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])

这个东西稍微有一点点复杂。。。

首先先解释一下这个东西

y_hat[  range(len(y_hat))  ,   y )

首先要先说明一点就是,y_hat是预测数值,一个二维tensor,比如说其中的第一条数据

[0.22,0.23,0.35.........]这代表的是某一个物体的预测结果,在10个标签中每一种可能性的概率

y则是一个一维tensor,每个分量代表的是该物体到底是什么,是确切数值

而这个[]中携带两个tensor的语法,被称为“高级索引”

#补充一下:这个语法的名字叫做高级索引,是从二维矩阵中选择出一个一维tensor
#第一个tensor是选择哪些行,这里选择所有行
#第二个是选择有哪些列
#在这个数据中我们实现的效果就是
#y-hat是一个二维tensor,每行是一个数据,每一列是对不同类型的预测
#y。。。严格来说是一个一维tensor,每个分量代表第i个数据到底是什么标签
#也就是说这个的逻辑意义是:每条数据猜对的概率?差不多可以这样子理解

6.优化/迭代函数

其实这个部分就是我们迭代,gradient descent 时候的操作

所谓的梯度就是求得的偏导数

#优化函数,其实这玩意就是我们的迭代函数,就那个repeat部分的东西,0.1是learning rate
def updater(batch_size):return d2l.sgd([W, b], 0.1, batch_size)

sgd就是d2l包下内置的“随机 gd”函数,这个里面梯度已经保存起来了

7.创建单次训练函数

#把模型训练了
def train_epoch_ch3(net, train_iter, loss, updater):  #@save# 将模型设置为训练模式if isinstance(net, torch.nn.Module):net.train()for X, y in train_iter:# loss是已经封装好的损失计算函数l = loss(net(X), y)# 使用定制的优化器和损失函数l.sum().backward()           #计算梯度,也就是代价函数导的东西updater(X.shape[0])          #梯度在这里好像是没有传入进来,但是实际上已经保存在w和b中了,对所有的w和b进行迭代计算

这个函数执行一次也就是一次训练

8.训练10次

#训练函数def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save"""训练模型(定义见第3章)"""for epoch in range(num_epochs):train_epoch_ch3(net, train_iter, loss, updater) # 直接就是训练了,不验证了#开始训练
num_epochs = 10
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

这里我们直接根据训练集合进行验证

9.最后进行预测以及可视化展示

#预测函数
def predict_ch3(net, test_iter, n=6):  #@save"""预测标签(定义见第3章)"""for X, y in test_iter:break# 将真实标签转换为对应的类别名称trues = d2l.get_fashion_mnist_labels(y)# 使用net进行预测,并且寻找预测结果转化为名称preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))#转化为title(还是使用对列生成器语法)titles = [   true +'\n' + pred    for true, pred in zip(trues, preds)   ]#展示图片d2l.show_images( X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])#展示预测
predict_ch3(net, test_iter)plt.show()

关于在训练和预测的时候我们需要干什么

其实前面也算是写了不少代码了(其实也就是单纯实现了一个单一神经元以及softmax的预测)

这里就简单总结一下,在这个“训练”部分,我们一般都会做一些什么事情:

我们先拿出一个很简单的单一线性回归预测来举个例子

for X, y in data_iter:l = loss(net(X) ,y)  #计算这个一批数据(10)个的损失trainer.zero_grad()  #清除已经有的梯度l.backward()         # 计算损失对当前模型的梯度trainer.step()       #根据梯度更新模型参数,梯度下降的根本操作

其实看这个代码,我们第一步做的就是遍历,通过一开始设置的数据批次进行分批次的训练

进入某一次训练中的时候,我们要先根据损失函数,计算出这一批的损失

(不同的框架和代码对这个玩应的实现和理解都完全不一样,但是你要记住这个东西的数学本质是损失函数之和,即为这个批次数据的代价函数,我们最后梯度下降的公式,最重要的一个步骤就是对代价函数求偏倒数,这也就是框架中常说的gradient梯度)

然后根据损失,通过一种称之为“反向传递”的技术,计算出偏导

最后这个step,就代表开始训练

大致架构就是这个样子实现的,如果这个样子还不是太明白具体要做什么,那么我们直接把上面是用softmax技术的东西简化一下再放出来:

#把模型训练了for X, y in train_iter:l = loss(net(X), y)          #loss是已经封装好的损失计算函数l.sum().backward()           #计算梯度,也就是代价函数导的东西updater(X.shape[0])          #梯度在这里好像是没有传入进来,但是实际上已经保存在w和b中了

也是进行分批次的训练

然后计算一下损失,再计算代价函数,对代价函数是用反向传播求偏导数

最后进行训练

最终总结一下,像这样子手动实现一个训练的过程中,我们能做的就是

(1)想办法得到代价函数(也许还要清除之前计算得到的梯度)

(2)获取代价函数的梯度(一般是反向传递)

(3)训练

至于在预测的时候做什么,就是一些预测结果的分析,精度计算什么的,那都是后话了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/82179.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法通关村-----链表中环的问题

环形链表 问题描述 给你一个链表的头节点 head ,判断链表中是否有环。如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达,则链表中存在环。 为了表示给定链表中的环,评测系统内部使用整数 pos 来表示链表尾连接到链表中…

Reactor 第十二篇 WebFlux集成PostgreSQL

1 引言 在现代的应用开发中,数据库是存储和管理数据的关键组件。PostgreSQL 是一种强大的开源关系型数据库,而 WebFlux 是 Spring 框架提供的响应式编程模型。本文将介绍如何使用 Reactor 和 WebFlux 集成 PostgreSQL,实现响应式的数据库访问…

【chrome扩展开发】消息通讯之onMessage消息监听

前言 chrome.runtime.onMessage.addListener 是 Chrome 扩展程序中用于监听其他模块发送的消息并做出响应的 API 当从扩展进程 (by runtime.sendMessage) 或内容脚本 (by tabs.sendMessage)发送消息时触发 语法 chrome.runtime.onMessage.addListener(callback: function, )ca…

使用Scrapy构建高效的网络爬虫

💂 个人网站:【工具大全】【游戏大全】【神级源码资源网】🤟 前端学习课程:👉【28个案例趣学前端】【400个JS面试题】💅 寻找学习交流、摸鱼划水的小伙伴,请点击【摸鱼学习交流群】 Scrapy是一个强大的Pyth…

python虚拟环境(venv)

一、什么是python环境 首先要知道什么是python环境? Python环境主要包括以下内容: 解释器 python.exe (python interpreter,使用的哪个解释看环境配置) Lib目录 标准库 第三方库:site-pakages目录,默认安装第三方…

题目:2859.计算 K 置位下标对应和

​​题目来源: leetcode题目,网址:2859. 计算 K 置位下标对应元素的和 - 力扣(LeetCode) 解题思路: 逐个计算下标是否符合要求即可。 解题代码: class Solution {public int sumIndicesWithK…

敏捷开发工具:提升软件研发效率的重要利器

在当今的软件开发领域,敏捷开发方法越来越受到推崇。敏捷开发的核心是灵活应对需求变化,以快速迭代的方式不断优化产品。为了助力敏捷开发的实施,各种敏捷开发工具应运而生。 本文将介绍几种常用的敏捷开发工具,阐述其特点、应用…

18 Python的sys模块

概述 在上一节,我们介绍了Python的os模块,包括:os模块中一些常用的属性和函数。在这一节,我们将介绍Python的sys模块。sys模块提供了访问解释器使用或维护的变量,以及与解释器进行交互的函数。 通俗来讲,sy…

第三十一章 Classes - 继承规则

第三十一章 Classes - 继承规则 继承规则 与其他基于类的语言一样,可以通过继承组合多个类定义。 类定义可以扩展(或继承)多个其他类。这些类又可以扩展其他类。 请注意,类不能继承 Python 中定义的类(即 .py 文件中…

基于DSPACE功率平衡理论的并联有源电力滤波器模型(Simulink)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

入行IC | 新人入行IC选择哪个岗位更好?

很多同学入行不知道怎么选择岗位。IC的岗位一般有设计、验证、后端、封装、测试、FPGA等等。但是具体到每个人身上,就要在开始的时候确定下你要找的职位,可以有两个或三个,但是要分出主次,主次不分会让你纠结整个找工作的过程。 …

webpack配置alias后eslint和ts无法识别

背景 我们在 webpack 配置 alias 后,发现项目中引入的时候,还是会报错,如下: 可以看到,有一个是 ts报错,还有一个是 eslint 报错。 解决 ts 报错 tsconfig.json {"compilerOptions": {...&q…

【力扣每日一题】2023.9.18 打家劫舍Ⅲ

目录 题目: 示例: 分析: 代码: 题目: 示例: 分析: 今天是打家劫舍3,明天估计就是打家劫舍4了。 今天的打家劫舍不太一样,改成二叉树了,不过规则没有变&…

ORACLE多列中取出数据最大的一条

1.需求说明: 当查询出来的数据存在多条数据时,想按照一定条件排序取出其中一条数据。 2.使用函数: row_number() over( partition by 分组字段 order by 排序字段 desc) 3.示例: --根据table_a中的pk_house&#x…

狗dog 数据集VOC-5912张

狗,是食肉目犬科犬属 哺乳动物 ,别称犬,与马、牛、羊、猪、鸡并称“六畜” 。狗的体型大小、毛色因品种不同而不同,体格匀称;鼻吻部较长;眼呈卵圆形;两耳或竖或垂;四肢矫健&#xff…

网站降权的康复办法(详解百度SEO数据分析)

随着搜索引擎算法的不断升级,很多网站在SEO优化过程中遭遇到降权的情况。如果您的网站也遭遇到了类似的问题,不必惊慌失措。本文将为您详细介绍网站降权恢复的方法,包括百度SEO数据分析、网站收录少的5个原因、网站被降权的6个因素以及百度SE…

超自动化的未来

如今,部分企业正尝试从小规模的自动化开始,将超级自动化用于营销分析和数据库维护等不同任务。企业应该对超自动化进行更深入的挖掘,如果人们能够更加仔细的观察总结企业的每个流程,那么就能发现更多可以从自动化技术中受益的领域…

IP风险查询:抵御DDoS攻击和CC攻击的关键一步

随着互联网的普及,网络攻击变得越来越普遍和复杂,对企业和个人的网络安全构成了重大威胁。其中,DDoS(分布式拒绝服务)攻击和CC(网络连接)攻击是两种常见且具有破坏性的攻击类型,它们…

js写一个判断字符串是否能够转为JSON 的函数

其实非常简单 这里我们需要涉及到 捕获异常 因为如果你直接在if里面转 我已经试过了 直接就报错了 一点面子不给 我们写一个这样的函数 function isJsonString(str) {try {JSON.parse(str);return true;} catch (e) {return false;} }编写如下代码 console.log(isJsonString(…

企业架构LNMP学习笔记58

开始学习Tomcat: 学习目标和内容: 1)能够描述Tomcat的使用场景; 2)能够简单描述Tomcat的工作原理; 3)能够实现部署安装Tomcat; 4)能够实现和配置Tomcat的Server服务…