简单的神经网络

一、softmax的基本概念

我们之前学过sigmoid、relu、tanh等等激活函数,今天我们来看一下softmax。

先简单回顾一些其他激活函数:

  1. Sigmoid激活函数:Sigmoid函数(也称为Logistic函数)是一种常见的激活函数,它将输入映射到0到1之间。它常用于二分类问题中,特别是在输出层以概率形式表示结果时。Sigmoid函数的优点是输出值限定在0到1之间,相当于对每个神经元的输出进行了归一化处理。
  2. Tanh激活函数:Tanh函数(双曲正切函数)将输入映射到-1到1之间。与Sigmoid函数相比,Tanh函数的中心点在零值附近,这意味着它的输出是以0为中心的。这种特性可以在某些情况下提供更好的性能。
  3. ReLU激活函数:ReLU(Rectified Linear Unit)函数是当前非常流行的一个激活函数,其表达式为f(x)=max(0, x)。ReLU函数的优点是计算简单,能够在正向传播过程中加速计算。此外,ReLU函数在正值区间内梯度为常数,有助于缓解梯度消失问题。但它的缺点是在负值区间内梯度为零,这可能导致某些神经元永远不会被激活,即“死亡ReLU”问题。

Softmax函数是一种在机器学习中广泛使用的函数,尤其是在处理多分类问题时。它的主要作用是将一组未归一化的分数转换成一个概率分布。Softmax函数的一个重要性质是其输出的总和等于1,这符合概率分布的定义。这意味着它可以将一组原始分数转换为概率空间,使得每个类别都有一个明确的概率值。

  • 二分类问题选择sigmoid激活函数

  • 多分类问题选择softmax激活函数

二、交叉熵损失函数

交叉熵损失函数的公式可以分为二分类和多分类两种情况。对于二分类问题,假设我们只考虑正类(标签为1)和负类(标签为0)在多分类问题中,交叉熵损失函数可以扩展为−∑𝑖=1𝐾𝑦𝑖⋅log⁡(𝑝𝑖)−∑i=1K​yi​⋅log(pi​),其中𝐾K是类别的总数,( y_i )是样本属于第𝑖i个类别的真实概率(通常用one-hot编码表示),而𝑝𝑖pi​是模型预测该样本属于第( i )个类别的概率。

import torch
from torch import nn# 确定随机数种子
torch.manual_seed(7)
# 自定义数据集
X = torch.rand((7, 2, 2))
target = torch.randint(0, 2, (7,))

定义网络结构

  • 一层全连接层 + Softmax层
  • x1𝑥1,x2𝑥2,x3𝑥3,x4𝑥4为 X
  • o1𝑜1,o2𝑜2,o3𝑜3为 target
class LinearNet(nn.Module):def __init__(self):super(LinearNet, self).__init__()# 定义一层全连接层self.dense = nn.Linear(4, 3)# 定义Softmaxself.softmax = nn.Softmax(dim=1)def forward(self, x):y = self.dense(x.view((-1, 4)))y = self.softmax(y)return ynet = LinearNet()
  •  nn.Softmax(dim=1)用于计算输入张量在指定维度上的softmax激活。dim=1表示沿着第二个维度(即列)进行softmax操作。

定义损失函数和优化函数

  • torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
  • 衡量模型输出与真实标签的差异,在分类时相当有用。
  • 结合了nn.LogSoftmax()和nn.NLLLoss()两个函数,进行交叉熵计算。
loss = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # 随机梯度下降法

训练模型

for epoch in range(70):train_l = 0.0y_hat = net(X)l = loss(y_hat, target).sum()# 梯度清零optimizer.zero_grad()# 自动求导梯度l.backward()# 利用优化函数调整所有权重参数optimizer.step()train_l += lprint('epoch %d, loss %.4f' % (epoch + 1, train_l))

三、自动微分模块

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False)  :自动求取梯度

  • grad_tensors:多梯度权重
  • create_graph:创建导数计算图,用于高阶求导
  • retain_graph:保存计算图
  • tensors:用于求导的张量,如 loss
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)y.backward(retain_graph=True)

 注意点:

  1. 梯度不自动清零
  2. 依赖于叶子节点的节点,requires_grad默认为True
  3. 叶子节点不可执行in-place

神经网络全连接层: 每个神经元都与前一层的所有神经元相连接。全连接层通常用于网络的最后几层,它将之前层(如卷积层和池化层)提取的特征进行整合,以映射到样本标记空间,即最终的分类或回归结果。

关于loss.backward()方法:

主要作用就是计算损失函数对模型参数的梯度,loss.backward()实现了反向传播算法,它通过链式法则计算每个模型参数相对于最终损失的梯度。这个过程从输出层开始,向后传递到输入层,逐层计算梯度。

过程:得到每个参数相对于损失函数的梯度,这些梯度信息会存储在对应张量的.grad属性中。loss.backward本身不负责更细权重,但它为权重更新提供了梯度值,方便配合optimizer.step()来更新参数。

前向传播过程中,数据从输入层流向输出层,并生成预测结果;而在反向传播过程中,误差(即预测值与真实值之间的差距,也就是损失函数的值)会从输出层向输入层传播,逐层计算出每个参数相对于损失函数的梯度。这些梯度指示了如何调整每一层中的权重和偏置,以最小化损失函数。

  • 损失函数衡量了当前模型预测与真实情况之间的不一致程度,而梯度则提供了损失函数减少最快的方向。

建立一个简单的全连接层:

import torch
import torch.nn as nn# 定义一个简单的全连接层模型
class SimpleFC(nn.Module):def __init__(self, input_size, output_size):super(SimpleFC, self).__init__()self.fc = nn.Linear(input_size, output_size)def forward(self, x):  return self.fc(x)# 创建输入数据和目标输出
input_data = torch.tensor([[1.0, 2.0, 3.0]])
target_output = torch.tensor([[4.0, 5.0]])# 实例化模型、损失函数和优化器
model = SimpleFC(input_size=3, output_size=2)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 前向传播
output = model(input_data)# 计算损失
loss = criterion(output, target_output)# 反向传播
loss.backward()# 更新参数
optimizer.step()

当调用loss.backward()时,PyTorch会自动计算损失值关于模型参数的梯度,并将这些梯度存储在模型参数的.grad属性中。然后优化器(torch.optim.SGD)可以使用这些梯度来更新模型参数,以最小化损失函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/835380.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

四十九坊股权设计,白酒新零售分红制度,新零售策划机构

肆拾玖坊商业模式 | 白酒新零售体系 | 新零售系统开发 坐标:厦门,我是易创客肖琳 深耕社交新零售行业10年,主要提供新零售系统工具及顶层商业模式设计、全案策划运营陪跑等。 不花钱开3000多家门店,只靠49个男人用一套方法卖白酒…

2010年认证杯SPSSPRO杯数学建模D题(第一阶段)服务网点的分布全过程文档及程序

2010年认证杯SPSSPRO杯数学建模 D题 服务网点的分布 原题再现: 服务网点、通讯基站的设置,都存在如何设置较少的站点,获得较大效益的问题。通讯基站的覆盖范围一般是圆形的,而消防、快餐、快递服务则受到道路情况和到达时间的限…

[图解]实现领域驱动设计译文暴露的问题01

0 00:00:00,430 --> 00:00:03,470 今天呢,我们来说一个主题 1 00:00:03,810 --> 00:00:04,041 2 00:00:04,041 --> 00:00:05,430 我们来谈一谈 3 00:00:05,960 --> 00:00:07,710 实现领域驱动设计 4 00:00:09,120 --> 00:00:11,070 这本书的中译本…

Android使用Chaquo来运行Python的librosa的相关代码【有详细案例教程】

在某些情况下,我们可能需要在android上运行python的代码,那么常见的解释器有很多,目前比较成熟的就是chaquo,它适配的第三方机器学习的库很多,下面是它的简单使用教程 1.环境的搭建 1.1 在Android studio中新建安卓工…

社交媒体数据恢复:飞书

飞书数据恢复过程包括以下几个步骤: 确认数据丢失:首先要确认数据是否真的丢失,有时候可能只是被隐藏或者误操作删除了。 检查回收站:飞书中删除的文件会默认保存在回收站中,用户可以通过进入回收站找到被删除的文件&…

springboot整合redis多数据源(附带RedisUtil)

单数据源RedisUtil(静态) 单数据源RedisUtil,我这里implements ApplicationContextAware在setApplicationContext注入redisTemplate,工具类可以直接类RedisUtil.StringOps.get()使用 package com.vehicle.manager.core.util;import com.alibaba.fastjson.JSON; import lombok.e…

如何向Linux内核提交开源补丁?

2021年,我曾经在openEuler社区上看到一项改进Linux内核工具的需求,因此参与过Linux内核社区的开源贡献。贡献开源社区的流程都可以在内核社区文档中找到,但是,单独学习需要一个较长的过程,新手难以入门,因此…

AI 数据观 | TapData Cloud + MongoDB Atlas:大模型与 RAG 技术有机结合,落地实时工单处理智能化解决方案

本篇为「AI 数据观」系列文章第二弹,在这里,我们将进一步探讨 AI 行业的数据价值。以 RAG 的智能工单应用场景为例,共同探索如何使用 Tapdata Cloud MongoDB Atlas 实现具备实时更新能力的向量数据库,为企业工单处理的智能化和自…

[C/C++] -- 大数的加减法

大数加减法的问题主要产生于计算机基本数据类型的表示范围限制。通常情况下,计算机采用有限位数的数据类型(如int、long)来表示整数,这些数据类型的表示范围有限,无法表示超出范围的大整数。 例如超过了long类型的表示…

【JavaScript】内置对象 - 数组对象 ⑤ ( 数组转字符串 | toString 方法 | join 方法 )

文章目录 一、数组转字符串1、数组转字符串 ( 逗号分割 ) - toString()2、数组转字符串 ( 自定义分割符 ) - join() Array 数组对象参考文档 : https://developer.mozilla.org/zh-CN/docs/Web/JavaScript/Reference/Global_Objects/Array 一、数组转字符串 1、数组转字符串 ( 逗…

指针(脑图梳理)

今天让我们来梳理一下指针都有哪些概念吧 这个脑图是整理的一些指针相关知识的概念,希望对大家有帮助

Web前端开发 小实训(三) 商品秒杀小练习

学生能够在本次实训中完成商品秒杀页面的基本逻辑 任务要求 能够实现某一个商品的秒杀&#xff0c;在倒计时结束后不再进行秒杀。 操作步骤 1、打开预设好的页面 <html><head><meta charset"utf-8"><title>秒杀</title><link …

python中如何把list变成字符串

python中如何把list变成字符串&#xff1f;方法如下&#xff1a; python中list可以直接转字符串&#xff0c;例如&#xff1a; data ["hello", "world"] print(data1:,str(data)) 得到结果&#xff1a; (data1:, "[hello, world]") 这里将整个…

视频号小店究竟有什么秘密,值得商家疯狂入驻,商家必看!

大家好&#xff0c;我是电商花花。 我们都知道视频号和抖音本身都是一个短视频平台&#xff0c;但是随着直播电商的发展&#xff0c;背后的流量推动逐步显露出强大的红利市场和变现机会。 视频号小店流量大和赚钱之外&#xff0c;还非常适合普通人创业。 这也使得越来越多的…

easypoi动态表头导出数据

需求&#xff1a;动态导出某年某月用户和用户评分数据信息&#xff0c;表头(序号、姓名、用户姓名)&#xff0c;数据(所有用户对应的评分以及平均分)&#xff1b; 分析&#xff1a;1、表头除过序号、姓名&#xff0c;用户姓名要动态生成&#xff1b; 2、用户评分信息要和表头中…

【赠书活动第4期】《Rust编程与项目实战》

赠书活动 《Rust编程与项目实战》免费赠书 3 本&#xff0c; 收到赠书之后&#xff0c;写一篇 本书某一节内容 的学习博客文章。 可在本帖评论中表示参加&#xff0c;即可获得赠书&#xff0c;先到先得。学习心得博客链接&#xff0c;后面有空发上来。 赠书截止日期为送出3…

无人播剧直播收益在哪里!快手无人播剧新秘籍:版权无忧,日入四位数攻略

无人播剧顾名思义就是通过短视频平台直播不需要真人出镜受众群体通过网络短视频平台看到的经典影视剧集可以实现24小时不停断的播放利用多种途径变现的一种直播形式 1、操作简单、不露脸、不出镜2、手机、电脑都可以操作3、可以矩阵操作4、0粉丝、0作品、0保证金就可以开播5、…

2010-2030年GHS-POP数据集下载

扫描文末二维码&#xff0c;关注微信公众号&#xff1a;ThsPool 后台回复 g008&#xff0c;领取 2010-2030年100m分辨率GHS-POP 数据集 &#x1f4ca; GHS Population Grid (R2023)&#xff1a;全球人口分布的精准视图与深度应用 &#x1f310; 在全球化和快速城市化的今天&am…

[嵌入式系统-73]:RT-Thread-快速上手:如何选择RT Thread的版本?

目录 如何选择合适的 RT-Thread 版本进行开发&#xff1f; RT-Thread 分支与版本介绍 如何选择 发布版本&#xff08;GitHub releases&#xff09; 开发分支&#xff08;GitHub master 主分支&#xff09; 长期支持分支&#xff08;GitHub lts-v3.1.x 分支&#xff09; …

10.轮转数组

文章目录 题目简介题目解答解法一&#xff1a;使用额外的数组代码&#xff1a;复杂度分析&#xff1a; 解法二&#xff1a;数组反转代码&#xff1a;复杂度分析&#xff1a; 题目链接 大家好&#xff0c;我是晓星航。今天为大家带来的是 轮转数组 相关的讲解&#xff01;&#…