【机器学习】pytorch 常用函数解析

目录

一、基本函数介绍

1.1 nn.Module 类

1.2 nn.Embedding

1.3 nn.LSTM

1.4 nn.Linear

1.5 nn.CrossEntropyLoss

1.6 torch.save

1.7 torch.load

1.8 nn.functional

1.9 nn.functional.softmax


本文主要对 pytorch 中用到的函数进行介绍,本文会不断更新~

一、基本函数介绍

1.1 nn.Module 类

我们自己定义模型的时候,通常继承 nn.Module 类,然后重写 nn.Module 中的方法,nn.Module 的主要方法如下所示。

class Module(object):def __init__(self):def forward(self, *input):def add_module(self, name, module):def cuda(self, device=None):def cpu(self):def __call__(self, *input, **kwargs):def parameters(self, recurse=True):def named_parameters(self, prefix='', recurse=True):def children(self):def named_children(self):def modules(self):  def named_modules(self, memo=None, prefix=''):def train(self, mode=True):def eval(self):def zero_grad(self):def __repr__(self):def __dir__(self):#......还有一部分,此处未列出

自定义模型一般重写 __init__ 和 forward 函数。

1.2 nn.Embedding

nn.Embedding(num_embeddings, embedding_dim)

(1)参数

num_embeddings :嵌入字典的大小,也可以理解为该模型可以表示词的数量;

embedding_size:表示嵌入向量的维度。

nn.Embedding 层的本质是一个查找表,它将输入的每个索引映射到一个固定大小的向量,可以理解为每个词都有一个固定的向量。这个映射表在初始化时会随机生成,然后在训练过程中通过反向传播进行优化。

(2)主要步骤

初始化:在初始化时,nn.Embedding 会创建一个大小为 (num_embeddings, embedding_dim)的权重矩阵。这些权重是嵌入层的参数,会在训练过程中更新;

前向传播:在前向传播过程中,nn.Embedding 层会将输入的索引映射到权重矩阵的相应行,从而得到对应的嵌入向量;

反向传播:在训练过程中,嵌入层的权重矩阵会根据损失函数的梯度进行更新。这使得嵌入向量能够捕捉到输入的语义信息。

(3)nn.Embedding 原理

nn.Embedding 的核心是一个查找表,其大小为 (num_embeddings,embedding_dim),每一行代表一个词或索引的嵌入向量。 在向量化时,输入的索引被用来查找嵌入向量,假设输入是 [1, 2, 3],则输出是权重矩阵(num_embeddings,embedding_dim)中第 1、2、3 行的向量。

下面通过一个例子进行说明。

import torch
import torch.nn as nn# 创建 Embedding 层
num_embeddings = 10  # 词汇表大小
embedding_dim = 3    # 嵌入向量的维度
embedding_layer = nn.Embedding(num_embeddings, embedding_dim)# 输入
input_indices = torch.LongTensor([1, 2, 3, 4])# 转换为嵌入向量
output_vectors = embedding_layer(input_indices)# 输出
print("input_indices:", input_indices)
print("output_vectors:", output_vectors)

输出如下所示。

(chat6b) D:\code\ChatGLM-6B-main>python test.py
input_indices: tensor([1, 2, 3, 4])
output_vectors: tensor([[-0.3269, -1.2620,  0.0695],[-1.6919, -1.6591, -0.7417],[ 2.0479,  0.9768,  1.4318],[-0.7075,  1.1718,  0.7530]], grad_fn=<EmbeddingBackward0>)(chat6b) D:\code\ChatGLM-6B-main>

输出一共包含四个向量,每行表示一个。

1.3 nn.LSTM

后续更新~

1.4 nn.Linear

nn.Linear 是神经网络的线性层,可以看作是通过一个二维矩阵做了一个转换。

torch.nn.Linear(in_features,  # 输入的神经元个数out_features, # 输出神经元个数bias=True     # 是否包含偏置)

nn.Linear 对输入执行线性变换,如下所示。

其中,X 表示输入,Y 表示输出,b 为偏置。

下面来看一个例子。

import torch
from torch import nninput = torch.Tensor([1, 2, 3]) # 样本有 3 个特征
model = nn.Linear(3, 2) # 输入特征数为 3,输出特征数为 2
print("model = ", model)
# nn.Linear 权重
for param in model.parameters():print(param)output = model(input)
print(output)

输出如下所示。

model =  Linear(in_features=3, out_features=2, bias=True)
Parameter containing:
tensor([[-0.4270,  0.0396,  0.2899],[-0.4481,  0.4071,  0.4366]], requires_grad=True)
Parameter containing:
tensor([-0.1091,  0.3018], requires_grad=True)
tensor([0.4128, 1.9777], grad_fn=<ViewBackward0>)

 X(1x3)= [1,2,3], W(3x2) = [[-0.4270,  0.0396,  0.2899], [-0.4481,  0.4071,  0.4366]]的转置,b = [-0.1091,  0.3018],可以手动计算最后的结果,例如:0.4128 = -0.4270 * 1 + 0.0396 * 2 + 0.2899*3 - 0.1091,同理也可以计算 1.9777。

1.5 nn.CrossEntropyLoss

交叉熵(Cross-Entropy)是一种用于比较真实标签和预测标签概率之间差异的度量,交叉熵通常用作损失函数,用于衡量模型预测与真实标签之间的差异,尤其在分类任务中广泛使用。

交叉熵越小,模型预测越准确。当模型的预测与真实标签完全一致时,交叉熵达到最小值为 0。

import torch
import torch.nn as nn
from torch.nn.functional import one_hotoutput = torch.randn(4, 3)  # 模型预测,4 个样本,3 分类
print('output:\n', output)target = torch.tensor([1, 2, 0, 1])  # 真实标签值
target1 = target
# 实际上不需要转换为 one_hot,这里测试证明了这一点
target = one_hot(target, num_classes=3)
target = target.to(dtype=torch.float)
crossentropyloss = nn.CrossEntropyLoss()
output_loss = crossentropyloss(output, target)
output_loss1 = crossentropyloss(output, target1)print('output_loss:\n', output_loss)
print('output_loss1:\n', output_loss1)

 顺便测试了下是否需要转换为 one_hat。

1.6 torch.save

torch.save() 的主要作用就是将 PyTorch 对象(如模型、张量等)保存到磁盘上,以文件的形式进行存储。如果想使用训练后的模型,从磁盘上加载即可。

torch.save(model,保存路径)  # 保存整个模型
torch.save(model.state_dict(), 保存路径) # 只保存模型参数

 CrossEntropyLoss() 损失函数结合了 nn.LogSoftmax() 和 nn.NLLLoss() 两个函数。它在做分类训练的时候是非常有用的。

1.7 torch.load

 torch.load() 函数用于加载磁盘上模型文件。 

 torch.load(模型路径)

1.8 nn.functional

nn.functional 是 PyTorch 中一个重要的模块,它包含了许多用于构建神经网络的函数。与 nn.Module 不同,nn.functional 中的函数不具有可学习的参数。

这些函数通常用于执行各种非线性操作、损失函数、激活函数等。 这个模块的主要优势是它的计算效率和灵活性,因为它允许你以函数的方式直接调用这些操作,而不需要创建额外的层。

1.9 nn.functional.softmax

softmax 有两种形式。

torch.nn.Softmax(input, dim)
torch.nn.functional.softmax(input, dim)

下面主要对 torch.nn.functional.softmax 进行介绍。 

对 n 维输入张量运用 softmax 函数,将张量的每个元素缩放到(0,1)区间且和为1。

softmax(input, dim=None, _stacklevel=3, dtype=None)

主要参数:

input : 输入的张量;

dim : 指明维度,dim=0表示按列计算;dim=1表示按行计算。默认dim的方法已经弃用了,最好声明dim,否则会警告。

softmax 公式如下所示。

下面来看一个例子。

import torch
import torch.nn.functional as Finput = torch.Tensor([[1, 2, 3, 4],[1, 2, 3, 4]])output1 = F.softmax(input, dim=0) #对每一列进行softmax
print(output1)output2 = F.softmax(input, dim=1) #对每一行进行softmax
print(output2)

 输出如下所示。

tensor([[0.5000, 0.5000, 0.5000, 0.5000],[0.5000, 0.5000, 0.5000, 0.5000]])
tensor([[0.0321, 0.0871, 0.2369, 0.6439],[0.0321, 0.0871, 0.2369, 0.6439]])

分别对输入张量的列和行进行了 softmax。 

后续更新:torch.randn、torch.tensor、one_hot、torch.LongTensor

参考链接:

[1] Pytorch nn.Linear()的基本用法与原理详解及全连接层简介_nn.linear()作用-CSDN博客

[2] pytorch教程之nn.Module类详解——使用Module类来自定义模型-CSDN博客

[3] torch.nn - PyTorch中文文档 

[4] pytorch nn.Embedding 用法和原理_pytorch nn.embedding 设置初始化函数-CSDN博客

[5] Pytorch nn.Linear()的基本用法与原理详解及全连接层简介_nn.linear()作用-CSDN博客 

[6] https://www.cnblogs.com/wanghui-garcia/p/10675588.html 

[7] PyTorch `nn.functional` 模块详解:探索神经网络的魔法工具箱_torch.nn.functional-CSDN博客  

[8] Pytorch CrossEntropyLoss() 原理和用法详解-CSDN博客 

[9] https://www.cnblogs.com/peixu/p/13194801.html 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/50928.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言内存函数精讲

目录 引言 1.内存分配函数malloc 2.内存释放函数free 3.内存拷贝函数memcpy 4.内存移动函数memmove 5.内存设置函数memset 6.内存比较函数memcmp 总结 引言 在C语言编程中&#xff0c;内存管理是核心技能之一。C语言提供了一系列内存操作函数&#xff0c;这些函数在动…

jmeter-beanshell学习-try处理异常

有时候代码执行过程中&#xff0c;出现一些不能处理的情况&#xff0c;就会报错&#xff0c;还影响之后的代码执行&#xff0c;就需要跳过异常。 上面这情况报错了&#xff0c;还影响了下面的打印。beanshell用try和catch处理异常&#xff0c;下面是try的用法&#xff0c;和if有…

技术守护尊严||Chat GPT在抵抗性骚扰的作用分析

就在本周&#xff0c;中国人民大学女博士实名举报导师性骚扰的事情&#xff0c;引发全网关注&#xff01; 性骚扰&#xff0c;无论在线上还是线下&#xff0c;无论在职场还是校园&#xff0c;都是对个人尊严与权益的严重侵犯。 幸运的是&#xff0c;随着人工智能技术的飞速发…

优化冗余代码:提升前端项目开发效率的实用方法

目录 前言代码复用与组件化模块化开发与代码分割工具辅助与自动化结束语 前言 在前端开发中&#xff0c;我们常常会遇到代码冗余的问题&#xff0c;这不仅增加了代码量&#xff0c;还影响了项目的可维护性和开发效率。还有就是有时候会接到紧急业务需求&#xff0c;要求立马完…

[网络通信原理]——TCP/IP模型—网络层

网络层 网络层概述 网络层位于OSI模型的第三层&#xff0c;它定义网络设备的逻辑地址&#xff0c;也就是我们说的IP地址&#xff0c;能够在不同的网段之间选择最佳数据转发路径。在网络层中有许多协议&#xff0c;其中主要的协议是IP协议。 IP数据包格式 IP数据报是可变长度…

《最新出炉》系列入门篇-Python+Playwright自动化测试-55- 上传文件 (非input控件)- 中篇

软件测试微信群&#xff1a;https://bbs.csdn.net/topics/618423372 有兴趣的可以扫码加入 1.简介 在实际工作中&#xff0c;我们进行web自动化的时候&#xff0c;文件上传是很常见的操作&#xff0c;例如上传用户头像&#xff0c;上传身份证信息等。所以宏哥打算按上传文件…

Java从入门到精通(十二)~ 动态代理

晚上好&#xff0c;愿这深深的夜色给你带来安宁&#xff0c;让温馨的夜晚抚平你一天的疲惫&#xff0c;美好的梦想在这个寂静的夜晚悄悄成长。 文章目录 目录 前言 主要作用和功能&#xff1a; 应用场景&#xff1a; 二、代理概念 1.静态代理 2.动态代理 2.1 概念介绍 …

网址导航系统PHP源码分享

1、采用光年全新v5模板开发后台 2、后台内置8款主题色&#xff0c;分别是简约白、炫光绿、渐变紫、活力橙、少女粉、少女紫、科幻蓝、护眼黑 3、可管理无数引导页主题并且主题内可以进行不同的自定义设置&#xff0c;目前内置16套主题 持续增加中… 4、可单独开发各种插件&a…

OSPF Type2 Message / DBD Packet (Database Descriptor)

注&#xff1a;机翻&#xff0c;未校对。 OSPF Type2 Message / DBD Packet (Database Descriptor) DBD (Database Description or Type2 OSPF Packet) is a sort of summary of the OSPF Database in a router. DBD is used to check if the LSDB between 2 routers is the s…

Linux---make/makefile工具

目录 基本了解 makefile基础语法 依赖关系 依赖方法 makefile文件内容格式 make执行机制 补充知识 机制解释 PHONY关键字 makefile补充语法 基本了解 在Linux中&#xff0c;make/makefile是项目自动化构建工具。如果我们没有make/makefile&#xff0c;那我们要编译一…

基于Java的模拟写字板的设计与实现

点击下载链接 基于Java的模拟写字板的设计与实现 摘要&#xff1a;目前&#xff0c;很多新的技术领域都涉及到了Java语言&#xff0c;Java语言是面向对象编程&#xff0c;并且涉及到网络、多线程等重要的基础知识&#xff0c;因此Java语言也是学习面向对象编程和网络编程的首…

Linux系统编程——生产者消费者模型

目录 一&#xff0c;模型介绍 1.1 预备知识&#xff08;超市买东西的例子&#xff09; 1.2 模型介绍 1.3 CP模型特点 二&#xff0c;基于阻塞队列的CP模型 2.1 介绍 2.2 阻塞队列的实现 2.3 主函数实现 2.4 效果展示 三&#xff0c;POSIX信号量 3.1 信号量原理 3…

力扣 快慢指针

1 环形链表 141. 环形链表 - 力扣&#xff08;LeetCode&#xff09; 定义两个指针&#xff0c;一快一慢。慢指针每次只移动一步&#xff0c;而快指针每次移动两步。初始时&#xff0c;慢指针和快指针都在位置 head&#xff0c;这样一来&#xff0c;如果在移动的过程中&#x…

05。拿捏ArkTS 第 3 天 --- 对象、联合类型、枚举

1&#xff0c;什么是对象&#xff1f;对象是干什么的&#xff1f; &#xff5e;用来存储不同类型数据的容器 &#xff5e;用来描述物体的特征和行为 //特征就是属性&#xff0c;行为就是方法&#xff08;对象内的函数&#xff09; 2&#xff0c;对象的基本样式是&#xff1f; …

Noah-MP陆面生态水文模拟与多源遥感数据同化技术

了解陆表过程的主要研究内容以及陆面模型在生态水文研究中的地位和作用&#xff1b;熟悉模型的发展历程&#xff0c;常见模型及各自特点&#xff1b;理解Noah-MP模型的原理&#xff0c;掌握Noah-MP模型在单站和区域的模拟、模拟结果的输出和后续分析及可视化等方法&#xff1b;…

OpenGL入门第六步:材质

目录 结果显示 材质介绍 函数解析 具体代码 结果显示 材质介绍 当描述一个表面时&#xff0c;我们可以分别为三个光照分量定义一个材质颜色(Material Color)&#xff1a;环境光照(Ambient Lighting)、漫反射光照(Diffuse Lighting)和镜面光照(Specular Lighting)。通过为每…

23.jdk源码阅读之Thread(下)

1. 写在前面 上篇文章我们介绍了Tread的一些方法的底层代码实现&#xff0c;这篇文章我们继续。 2. join()方法的底层实现 /*** Waits at most {code millis} milliseconds for this thread to* die. A timeout of {code 0} means to wait forever.** <p> This impleme…

从工艺到性能:模具3D打印材料不断革新

在模具3D打印领域&#xff0c;材料性能的持续优化与创新是推动模具3D打印的关键因素&#xff0c;近年来&#xff0c;各种3D打印新材料不断涌现&#xff0c;模具3D打印材料也开始重工艺导向逐步向性能导向发展&#xff0c;如毅速公司推出的ESU-EM191/191S及ESU-EM201不锈钢粉末、…

电脑文件误删除如何恢复?数据恢复第一步是什么?这五点要第一时间处理!

电脑文件误删除如何恢复&#xff1f;数据删除恢复的第一时间要做什么&#xff0c;你知道吗&#xff1f; 在使用电脑的过程中&#xff0c;误删除重要文件的情况时有发生。面对这种情况&#xff0c;不必过于慌张&#xff0c;因为有多种方法可以帮助你恢复误删除的文件。以下是恢复…

网络通信---UDP

前两天做了个mplayer项目&#xff0c;今日继续学习 网络内容十分重要&#xff01;&#xff01;&#xff01; 1.OSI七层模型 应用层:要传输的数据信息&#xff0c;如文件传输&#xff0c;电子邮件等&#xff08;最接近用户&#xff0c;看传输的内容类型到底是什么&#xff09; …