PyTorch之nn.Module与nn.functional用法区别

文章目录

  • 1. nn.Module
  • 2. nn.functional
    • 2.1 基本用法
    • 2.2 常用函数
  • 3. nn.Module 与 nn.functional
    • 3.1 主要区别
    • 3.2 具体样例:nn.ReLU() 与 F.relu()
  • 参考资料

1. nn.Module

在PyTorch中,nn.Module 类扮演着核心角色,它是构建任何自定义神经网络层、复杂模块或完整神经网络架构的基础构建块。通过继承 nn.Module 并在其子类中定义模型结构和前向传播逻辑(forward() 方法),开发者能够方便地搭建并训练深度学习模型。

关于 nn.Module 的更多介绍可以参考博客:PyTorch之nn.Module、nn.Sequential、nn.ModuleList使用详解

这里,我们基于nn.Module创建一个简单的神经网络模型,实现代码如下:

import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(MyModel, self).__init__()self.layer1 = nn.Linear(input_size, hidden_size)self.layer2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.layer1(x))x = self.layer2(x)return x

2. nn.functional

nn.functional 是PyTorch中一个重要的模块,它包含了许多用于构建神经网络的函数。与 nn.Module 不同,nn.functional 中的函数不具有可学习的参数。这些函数通常用于执行各种非线性操作、损失函数、激活函数等。

2.1 基本用法

如何在神经网络中使用nn.functional?

在PyTorch中,你可以轻松地在神经网络中使用 nn.functional 函数。通常,你只需将输入数据传递给这些函数,并将它们作为网络的一部分。

以下是一个简单的示例,演示如何在一个全连接神经网络中使用ReLU激活函数:

import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc1 = nn.Linear(64, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.fc1(x))x = self.fc2(x)return x

在上述示例中,我们首先导入nn.functional 模块,然后在网络的forward 方法中使用F.relu 函数作为激活函数。

nn.functional 的主要优势是它的计算效率和灵活性,因为它允许你以函数的方式直接调用这些操作,而不需要创建额外的层。

2.2 常用函数

(1)激活函数

激活函数是神经网络中的关键组件,它们引入非线性性,使网络能够拟合复杂的数据。以下是一些常见的激活函数:

  • ReLU(Rectified Linear Unit)
    ReLU是一种简单而有效的激活函数,它将输入值小于零的部分设为零,大于零的部分保持不变。它的数学表达式如下:
output = F.relu(input)
  • Sigmoid
    Sigmoid函数将输入值映射到0和1之间,常用于二分类问题的输出层。它的数学表达式如下:
output = F.sigmoid(input)
  • Tanh(双曲正切)
    Tanh函数将输入值映射到-1和1之间,它具有零中心化的特性,通常在循环神经网络中使用。它的数学表达式如下:
output = F.tanh(input)

(2)损失函数

损失函数用于度量模型的预测与真实标签之间的差距。PyTorch的nn.functional 模块包含了各种常用的损失函数,例如:

  • 交叉熵损失(Cross-Entropy Loss)
    交叉熵损失通常用于多分类问题,计算模型的预测分布与真实分布之间的差异。它的数学表达式如下:
loss = F.cross_entropy(input, target)
  • 均方误差损失(Mean Squared Error Loss)
    均方误差损失通常用于回归问题,度量模型的预测值与真实值之间的平方差。它的数学表达式如下:
loss = F.mse_loss(input, target)
  • L1 损失
    L1损失度量预测值与真实值之间的绝对差距,通常用于稀疏性正则化。它的数学表达式如下:
loss = F.l1_loss(input, target)

(3)非线性操作

nn.functional 模块还包含了许多非线性操作,如池化、归一化等。

  • 最大池化(Max Pooling)
    最大池化是一种用于减小特征图尺寸的操作,通常用于卷积神经网络中。它的数学表达式如下:
output = F.max_pool2d(input, kernel_size)
  • 批量归一化(Batch Normalization)
    批量归一化是一种用于提高训练稳定性和加速收敛的技术。它的数学表达式如下:
output = F.batch_norm(input, mean, std, weight, bias)

3. nn.Module 与 nn.functional

3.1 主要区别

nn.Module 与 nn.functional 的主要区别在于:

  • nn.Module实现的layers是一个特殊的类,都是由class Layer(nn.Module)定义,会自动提取可学习的参数;
  • nn.functional中的函数更像是纯函数,由def function(input)定义。

注意:

  1. 如果模型有可学习的参数时,最好使用nn.Module。
  2. 激活函数(ReLU、sigmoid、Tanh)、池化(MaxPool)等层没有可学习的参数,可以使用对应的functional函数。
  3. 卷积、全连接等有可学习参数的网络建议使用nn.Module。
  4. dropout没有可学习参数,但建议使用nn.Dropout而不是nn.functional.dropout。

3.2 具体样例:nn.ReLU() 与 F.relu()

nn.ReLU() :

import torch.nn as nn
'''
nn.ReLU()

F.relu():

import torch.nn.functional as F
'''
out = F.relu(input)

其实这两种方法都是使用relu激活,只是使用的场景不一样,F.relu()是函数调用,一般使用在foreward函数里。而nn.ReLU()是模块调用,一般在定义网络层的时候使用。

当用print(net)输出时,nn.ReLU()会有对应的层,而F.ReLU()是没有输出的。

import torch.nn as nn
import torch.nn.functional as Fclass NET1(nn.Module):def __init__(self):super(NET1, self).__init__()self.conv = nn.Conv2d(3, 16, 3, 1, 1)self.bn = nn.BatchNorm2d(16)self.relu = nn.ReLU()  # 模块的激活函数def forward(self, x):out = self.conv(x)x = self.bn(x)out = self.relu()return outclass NET2(nn.Module):def __init__(self):super(NET2, self).__init__()self.conv = nn.Conv2d(3, 16, 3, 1, 1)self.bn = nn.BatchNorm2d(16)def forward(self, x):x = self.conv(x)x = self.bn(x)out = F.relu(x)  # 函数的激活函数return outnet1 = NET1()
net2 = NET2()
print(net1)
print(net2)

在这里插入图片描述

参考资料

  • PyTorch的nn.Module类的详细介绍
  • PyTorch nn.functional 模块详解:探索神经网络的魔法工具箱
  • pytorch:F.relu() 与 nn.ReLU() 的区别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/37501.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Spring Boot 源码学习】初识 ConfigurableEnvironment

《Spring Boot 源码学习系列》 初识 ConfigurableEnvironment 一、引言二、主要内容2.1 Environment2.1.1 配置文件(profiles)2.1.2 属性(properties) 2.2 ConfigurablePropertyResolver2.2.1 属性类型转换配置2.2.2 占位符配置2.…

Python--进程基础

创建进程 os.fork() 该方法只能在linux和mac os中使用,因为其主要基于系统的fork来实现。window中没有这个方法。 通过os.fork()方法会创建一个子进程,子进程的程序集为该语句下方的所有语句。 import os​​print("主进程的PID为:" , os.g…

Python pdfkit wkhtmltopdf html转换pdf 黑体字体乱码

wkhtmltopdf 黑体在html转换pdf时&#xff0c;黑体乱码&#xff0c;分析可能wkhtmltopdf对黑体字体不太兼容&#xff1b; 1.html内容如下 <html> <head> <meta http-equiv"content-type" content"text/html;charsetutf-8"> </head&…

DreamView数据流

DreamView数据流 查看DV中界面启动dag&#xff0c;/apollo/modules/dreamview_plus/conf/hmi_modes/pnc.pb.txt可以看到点击界面的planning按钮&#xff0c;后台其实启动的是/apollo/modules/planning/planning_component/dag/planning.dag和/apollo/modules/external_command…

Oracle连接mysql

oracle使用的11g&#xff0c;在一台windows服务器&#xff1b;mysql使用的是5.7版本&#xff0c;在另一台windows服务器&#xff0c;这两个服务器之间的网络是互通的。做BI时&#xff0c;要获取不同数据源的数据&#xff0c;这些数据源可能是Oracle&#xff0c;也可能是sqlserv…

springboot基础入门2(profile应用)

Profile应用 一、何为Profile二、profile配置方式1.多profile文件方式2.yml多文档方式 三、加载顺序1. file:./config/: 当前项目下的/config目录下2. file:./ &#xff1a;当前项目的根目录3. classpath:/config/:classpath的/config目录4. classpath:/ : classpath的根目录 四…

Redis 集群模式

一、集群模式概述 Redis 中哨兵模式虽然提高了系统的可用性&#xff0c;但是真正存储数据的还是主节点和从节点&#xff0c;并且每个节点都存储了全量的数据&#xff0c;此时&#xff0c;如果数据量过大&#xff0c;接近或超出了 主节点 / 从节点机器的物理内存&#xff0c;就…

个人网站制作 Part 28 添加用户活动跟踪功能 | Web开发项目添加页面缓存

文章目录 &#x1f469;‍&#x1f4bb; 基础Web开发练手项目系列&#xff1a;个人网站制作&#x1f680; 添加用户活动跟踪功能&#x1f528;使用分析工具&#x1f527;步骤 1: 选择分析工具&#x1f527;步骤 2: 注册Google Analytics账户&#x1f527;步骤 3: 获取Analytics…

Java面试题--JVM大厂篇之深入了解G1 GC:高并发、响应时间敏感应用的最佳选择

引言&#xff1a; 在现代Java应用的性能优化中&#xff0c;垃圾回收器&#xff08;GC&#xff09;的选择至关重要。对于高并发、响应时间敏感的应用而言&#xff0c;G1 GC&#xff08;Garbage-First Garbage Collector&#xff09;无疑是一个强大的工具。本文将深入探讨G1 GC适…

李一桐遭遇蜈蚣惊魂

李一桐遭遇“蜈蚣惊魂”&#xff01;刘宇宁展现真男人本色在娱乐圈的幕后&#xff0c;总有一些心跳加速的惊险。近日&#xff0c;李一桐在拍戏时遭遇了一场“蜈蚣惊魂”&#xff0c;让无数粉丝和网友为她捏了一把冷汗。而在这场惊险的遭遇中&#xff0c;刘宇宁展现出了真男人的…

ActiveMq工具之管理页面说明

文章目录 安装ActiveMQ一: 访问管理页面二: 进入管理页面&#xff0c;主页三: Queues页说明四: Topics页说明五: Subscribers页说明 安装ActiveMQ wget https://archive.apache.org/dist//activemq/5.13.3/apache-activemq-5.13.3-bin.tar.gz wget https://mirrors.huaweiclou…

为什么越来越多的企业选择外包?赋能企业未来

软件开发过程包括设计需求、设计方案、产品研发、产品交付、后期维护&#xff0c;许多企业并沒有软件开发的专业能力与工作经验&#xff0c;将软件开发工作进行外包是比较节约成本的&#xff0c;企业能少走不少弯路。 YesPMP平台&#xff08;一站式软件外包、项目外包服务-YesP…

UWA Pipeline 2.6.1版本更新

UWA Pipeline是专为游戏开发团队设计的本地协作平台&#xff0c;旨在帮助团队建立专业的DevOps研发交付流水线。本平台提供了可视化的CI/CD操作界面&#xff0c;高可用的自动化测试和无缝集成的UWA性能保障服务等核心功能。 在最新的Pipeline更新中&#xff0c;UWA引入了参数配…

protobufjs解析proto消息出错RangeError: index out of range: 2499 + 10 > 2499解决办法

使用websocket通讯传输protobuf消息的时候&#xff0c;decode的时候出错了&#xff1a; RangeError: index out of range: 2499 10 > 2499 Error: invalid wire type 4 at offset 1986 出现这种错误的时候&#xff0c;99%是因为proto里面的消息类型和服务端发送的消息类型不…

vue表头字段添加鼠标悬浮提示

<el-table-column prop"jfScore" align"center" min-width"100px"><template slot"header" slot-scope"scope"><div><span>信用积分</span><el-tooltip:aa"scope"class"it…

关于windows,wifi图标显示不了的解决方法

解决方法一&#xff08;解决了我的问题的方法&#xff09;&#xff1a; winr -->输入 regedit 打开注册表 --> 删除HKEY-CLASSES_ROOT\CLSID\{3d09c1ca-2bcc-40b7-b9bb-3f3ec143a87b} CLSID在下面仔细找&#xff0c;然后找到09开头那个删掉重启就可以了&#xff0c;可能…

CAS服务端部署

部署CAS Cas服务端其实就是一个war包。 在资源\cas\source\cas-server-4.0.0-release\cas-server-4.0.0\modules目录下cas-server-webapp-4.0.0.war 将其改名为cas.war放入tomcat目录下的webapps下。启动tomcat自动解压war包。浏览器输入 登录页面 http://localhost:8080/ca…

20240701 每日AI必读资讯

&#x1f3eb;AI真炼丹&#xff1a;整整14天&#xff0c;无需人类参与 - 英矽智能推出全球首个AI参与决策的生物学实验室&#xff0c;实现了14天内完成靶点发现和验证的全自动化闭环实验。 - 该实验室由PandaOmics平台驱动&#xff0c;集成多种预测模型和海量数据&#xff0…

【Python】从基础到进阶(二):了解Python语言基础以及数据类型转换、基础输入输出

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 一、引言二、基本数据类型转换1. 隐式转换2. 显式转换 三、基本输入输出1. 输入&#xff08;input&#xff09;2. 输出&#xff08;print&#xff09;3. 案例&#xff1a;输入姓名、年龄、身高以及体重&#xff0c;计算BMI指…

DM表级触发器

可以理解为行变动级 触发体中写逻辑 这是表修改时调用存储过程 感谢大哥分享: https://blog.csdn.net/WuLex/article/details/83181449 感谢大哥分享: https://blog.csdn.net/ChennyWJS/article/details/131913198