PyTorch使用教程(6)一文讲清楚torch.nn和torch.nn.functional的区别

torch.nn torch.nn.functional 在 PyTorch 中都是用于构建神经网络的重要组件,但它们在设计理念、使用方式和功能上存在一些显著的区别。以下是关于这两个模块的详细区别:

1. 继承方式与结构

torch.nn

  • torch.nn 中的模块大多数是通过继承 torch.nn.Module 类来实现的。这些模块都是 Python 类,包含了神经网络的各种层(如卷积层、全连接层等)和其他组件(如损失函数、优化器等)。
  • torch.nn 中的模块可以包含可训练参数,如权重和偏置,这些参数在训练过程中会被优化。

torch.nn.functional

  • torch.nn.functional 中的函数是直接调用的,无需实例化。这些函数通常用于执行各种非线性操作、损失函数计算、激活函数应用等。
  • torch.nn.functional 中的函数没有可训练参数,它们只是执行操作并返回结果。

2. 实现方式与调用方式

torch.nn

  • torch.nn 中的模块是基于面向对象的方法实现的。开发者需要创建类的实例,并在类的 forward 方法中定义数据的前向传播路径。
  • torch.nn 中的模块通常需要先创建模型实例,再将输入数据传入模型中进行前向计算。

torch.nn.functional

  • torch.nn.functional 中的函数是基于函数式编程实现的。它们提供了灵活的接口,允许开发者以函数调用的方式轻松定制和扩展神经网络架构。
  • torch.nn.functional 中的函数可以直接调用,只需要将输入数据传入函数中即可进行前向计算。

3. 使用场景与优势

torch.nn

  • torch.nn 更适合用于定义有状态的模块,如包含可训练参数的层。
  • 当定义具有变量参数的层时(如卷积层、全连接层等),torch.nn 会帮助初始化好变量,并且模型类本身就是 nn.Module 的实例,看起来会更加协调统一。
  • torch.nn 可以结合 nn.Sequential 来简化模型的构建过程。

torch.nn.functional

  • torch.nn.functional 中的函数相比 torch.nn 更偏底层,封装性不高但透明度很高。开发者可以在其基础上定义出自己想要的功能。
  • 使用 torch.nn.functional 可以更方便地进行函数组合、复用等操作,适合那些喜欢使用函数式编程风格的开发者。当激活函数只需要在前向传播中使用时,使用 torch.nn.functional 中的激活函数会更加简洁。

4. 权重与参数管理

torch.nn

  • torch.nn 中的模块会自动管理权重和偏置等参数,这些参数可以通过 model.parameters() 方法获取,并用于优化算法的训练。

torch.nn.functional

  • torch.nn.functional 中的函数不直接管理权重和偏置等参数。如果需要使用这些参数,开发者需要在函数外部定义并初始化它们,然后将它们作为参数传入函数中。

5.举例说明

例子1:定义卷积层

使用 torch.nn

import torch.nn as nnclass MyConvNet(nn.Module):def __init__(self):super(MyConvNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)def forward(self, x):x = self.conv1(x)return x# 实例化模型
model = MyConvNet()# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_conv_net(input_tensor, weight, bias=None):output_tensor = F.conv2d(input_tensor, weight, bias=bias, stride=1, padding=1)return output_tensor# 定义卷积核的权重和偏置
weight = nn.Parameter(torch.randn(16, 1, 3, 3))
bias = nn.Parameter(torch.randn(16))# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = my_conv_net(input_tensor, weight, bias)

在这个例子中,使用 torch.nn 定义了一个包含卷积层的模型类,而使用 torch.nn.functional 则是通过函数直接进行卷积操作。注意在使用 torch.nn.functional 时,需要手动定义和传递卷积核的权重和偏置。

例子2:应用激活函数

使用 torch.nn

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.relu = nn.ReLU()def forward(self, x):x = self.relu(x)return x# 实例化模型
model = MyModel()# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_model(input_tensor):output_tensor = F.relu(input_tensor)return output_tensor# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = my_model(input_tensor)

在这个例子中,使用 torch.nn 定义了一个包含 ReLU 激活函数的模型类,而使用 torch.nn.functional 则是通过函数直接应用 ReLU 激活函数。

例子3:定义和计算损失

使用 torch.nn

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = nn.Linear(10, 2)def forward(self, x):x = self.linear(x)return x# 实例化模型
model = MyModel()# 定义损失函数
criterion = nn.CrossEntropyLoss()# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()# 前向传播和计算损失
output_tensor = model(input_tensor)
loss = criterion(output_tensor, target)

使用 torch.nn.functional

import torch.nn.functional as Fdef my_model(input_tensor):output_tensor = torch.matmul(input_tensor, weight.t()) + biasreturn output_tensor# 定义权重和偏置
weight = nn.Parameter(torch.randn(10, 2))
bias = nn.Parameter(torch.randn(2))# 定义损失函数
criterion = nn.CrossEntropyLoss()# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()# 前向传播和计算损失
output_tensor = my_model(input_tensor)
loss = criterion(output_tensor, target)

在这个例子中,使用 torch.nn 定义了一个包含全连接层的模型类,并使用了 torch.nn 中的损失函数来计算损失。而使用 torch.nn.functional 则是通过函数直接进行线性变换,并使用 torch.nn 中的损失函数来计算损失。注意在使用 torch.nn.functional 时,需要手动定义和传递权重和偏置。

6. 小结

torch.nn 和 torch.nn.functional 在定义神经网络组件、应用激活函数和计算损失等方面存在显著的区别。torch.nn 提供了一种面向对象的方式来构建模型,而 torch.nn.functional 则提供了一种更灵活、更函数式的方式来执行相同的操作。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/68488.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

移动端布局 ---- 学习分享

响应式布局实现方法 主流的实现方案有两种: 通过rem \ vw/vh \ 等单位,实现在不同设备上显示相同比例进而实现适配. 响应式布局,通过媒体查询media 实现一套HTML配合多套CSS实现适配. 在学习移动端适配之前,还需要学习移动端适配原理: 移动端适配原理(Viewport) 了解VSCo…

cuda + cudnn安装

1.安装CUDA Toolkit 在设备管理器(此电脑–右键–属性)的显示适配器中可以查看自己的显卡型号,去下载对应的CUDA Toolkit 。或者输入以下命令查看Driver Version ,cuda Version:12.2代表12.2版本以下兼容可以进行安装 …

C++入门基础之命名空间|using声明|缺省参数

命名空间 一、命名空间的基本概念 命名空间是 C 中的一种机制,用于将相关的标识符(包括函数、类、变量、对象、结构体、枚举等)分组在一起,以避免命名冲突。 它可以看作是一个逻辑容器,将具有相似功能或来自同一模块…

【PyCharm】快捷键使用

【PyCharm】相关链接 【PyCharm】连接 Git【PyCharm】连接Jupyter Notebook【PyCharm】快捷键使用【PyCharm】远程连接Linux服务器【PyCharm】设置为中文界面 【PyCharm】快捷键使用 PyCharm 是一个功能强大且专为 Python 开发设计的集成开发环境(IDE&#xff09…

DevUI 2024 年度运营报告:开源生态的成长足迹与未来蓝图

在当今数字化飞速发展的时代,开源已成为推动技术创新与协作的重要力量。DevUI 作为开源领域的重要一员,其发展历程与成果备受关注。值此之际,GitCode 精心整理了 DevUI 年度运营报告,为您全面呈现 DevUI 社区在过去一年里的开源之…

python中的RPA->playwright自动化录制脚本实战案例笔记

playwright录制功能使用绕过登录操作 1、首先安装playwright pip install playwright2、 安装支持的浏览器 playwright install # 安装支持的浏览器:cr, chromium, ff, firefox, wk 和 webkit3、接着在自己的项目下运行录制命令: playwright codegen…

Vue.js组件开发全解析

Vue.js组件开发全解析 文章目录 Vue.js组件开发全解析一、Vue.js基础回顾(一)Vue实例(二)指令(三)计算属性和方法 二、组件基础(一)组件的概念(二)全局组件注…

如何选择适合特定项目需求的人工智能学习框架?

人工智能学习框架(AI Learning Framework)是一种用于开发、训练和部署人工智能模型的软件平台,旨在简化AI模型的设计、训练和部署过程。这些框架通常提供一系列工具、库和预构建模块,使开发者能够快速实现机器学习任务&#xff0c…

架构设计:微服务还是集群更适合?

在现代软件开发中,微服务和集群是两种广泛应用的架构设计方案。随着系统需求的不断复杂化和规模的扩大,选择一种适合的架构对系统的性能、可维护性和扩展性至关重要。那么,在架构设计中,是选择微服务还是集群更适合?本…

Spring Bug解决

报错: Exception in thread "main" org.springframework.beans.factory.NoUniqueBeanDefinitionException: No qualifying bean of type com.itxl.spring6.iocxml.User available: expected single matching bean but found 2: user,user1 at org.sp…

算法(蓝桥杯)贪心算法5——删数问题的解题思路

问题描述 给定一个高精度的正整数 n(n≤1000 位),需要删除其中任意 s 个数字,使得剩下的数字按原左右顺序组成一个新的正整数,并且这个新的正整数最小。例如,对于数字 153748,删除 2 个数字后&a…

元素隐式具有 “any“ 类型,因为类型为 “string“ 的表达式不能用于索引类型

元素隐式具有 “any” 类型,因为类型为 “string” 的表达式不能用于索引类型 “{ minLon: string; maxLon: string; minLat: string; maxLat: string; minTime: string; maxTime: string; keyword: string; subjct: string; ele: string; }”。 在类型 “{ minLon:…

《探秘鸿蒙NEXT中的人工智能核心架构》

在当今科技飞速发展的时代,华为HarmonyOS NEXT的发布无疑是操作系统领域的一颗重磅炸弹,其将人工智能与操作系统深度融合,开启了智能新时代。那么,鸿蒙NEXT中人工智能的核心架构究竟是怎样的呢?让我们一同探秘。 基础…

git克隆原项目到新目录,保留提交记录分支等,与原项目保持各自独立

1、克隆原仓库到本地 --mirror 会完整克隆所有git数据,包括所有分支、标签、提交记录 git clone --mirror http://gitlab.com.../old-project.git2、进入文件夹 cd old-project.git3、添加目录仓库为远程 xx-orgin表示给远程地址命名 git remote add xx-orgin htt…

U盘被格式化后的数据救赎与防范策略

一、U盘格式化后的数据困境 在日常的工作与生活中,U盘作为数据传输与存储的重要工具,扮演着不可或缺的角色。然而,当U盘不幸遭遇格式化操作后,存储在其中的宝贵数据瞬间化为乌有,给用户带来极大的困扰。格式化后的U盘…

PyBroker:利用 Python 和机器学习助力算法交易

PyBroker:利用 Python 和机器学习助力算法交易 你是否希望借助 Python 和机器学习的力量来优化你的交易策略?那么你需要了解一下 PyBroker!这个 Python 框架专为开发算法交易策略而设计,尤其关注使用机器学习的策略。借助 PyBrok…

【AI论文】LlamaV-o1:重新思考大型语言模型(LLMs)中的逐步视觉推理方法

摘要:推理是解决复杂多步骤问题的基本能力,特别是在需要逐步顺序理解的视觉环境中尤为重要。现有的方法缺乏一个全面的视觉推理评估框架,并且不强调逐步解决问题。为此,我们通过三项关键贡献,提出了一个在大型语言模型…

【HTTP】详解

目录 HTTP 基本概念啥是HTTP,有什么用?一次HTTP请求的过程当你在浏览器中输入一个浏览器地址,它会发送什么 ?---(底层流程)HTTP的协议头请求头(对应客户端)一些请求头请求方法 响应头…

EasyExcel - 行合并策略(二级列表)

😼前言:博主在工作中又遇到了新的excel导出挑战:需要导出多条文章及其下联合作者的信息,简单的来说是一个二级列表的数据结构。 🕵️‍♂️思路:excel导出实际上是一行一行的记录,再根据条件对其…

第9章:基于Vision Transformer(ViT)网络实现的迁移学习图像分类任务:早期秧苗图像识别

目录 1. ViT 模型 2. 早期秧苗分类 2.1 数据集 2.2 训练 2.3 训练结果 2.4 可视化网页推理 3. 下载 1. ViT 模型 视觉变换器(ViT)是一种神经网络架构,它将变换器架构的原理应用于视觉数据。最初,Transformers主要用于自然…