深度学习之pytorch 中 torch.nn介绍

1. torch.nn 介绍

pytorch 中必用的包就是 torch.nn,torch.nn 中按照功能分,主要如下有几类:

1. Layers(层):包括全连接层、卷积层、池化层等。
2. Activation Functions(激活函数):包括ReLU、Sigmoid、Tanh等。
3. Loss Functions(损失函数):包括交叉熵损失、均方误差等。
4. Optimizers(优化器):包括SGD、Adam、RMSprop等。
5. Initialization Functions(初始化函数):包括Xavier初始化、He初始化等。
6. Utilities(实用工具):包括数据处理的函数、模型构建的函数等。

2. Layers(层):

2.1 全连接层

在 `torch.nn` 中,全连接层的函数主要包括以下几个:
1. `torch.nn.Linear(in_features, out_features, bias=True)`: 创建一个线性层,其中in_features是输入特征的大小,out_features是输出特征的大小,bias表示是否使用偏置。
2. `torch.nn.Bilinear(in1_features, in2_features, out_features, bias=True)`: 创建一个双线性层,用于计算两组输入之间的双线性操作。
这些函数用于创建神经网络中的全连接层,其中 `torch.nn.Linear` 是应用最广泛的全连接层函数。

2.2 卷积层

在 `torch.nn` 中,卷积层的函数主要包括以下几个:
1. `torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)`: 创建一维卷积层。
2. `torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)`: 创建二维卷积层。
3. `torch.nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)`: 创建三维卷积层。
这些函数用于创建神经网络中的卷积层。

2.3 池化层

在 `torch.nn` 中,池化层的函数主要包括以下几个:
1. `torch.nn.MaxPool1d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)`: 创建一维最大池化层。
2. `torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)`: 创建二维最大池化层。
3. `torch.nn.MaxPool3d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)`: 创建三维最大池化层。
4. `torch.nn.AvgPool1d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True)`: 创建一维平均池化层。
5. `torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)`: 创建二维平均池化层。
6. `torch.nn.AvgPool3d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)`: 创建三维平均池化层。
这些函数用于创建神经网络中的不同类型的池化层。

3. Activation Functions(激活函数)

在 `torch.nn` 中,常见的激活函数包括:
1. `torch.nn.ReLU`: ReLU激活函数。
2. `torch.nn.ReLU6`: 限制线性整流激活函数,将负值截断为0,将正值截断为6。
3. `torch.nn.ELU`: 指数线性单元激活函数。
4. `torch.nn.Sigmoid`: Sigmoid激活函数。
5. `torch.nn.Tanh`: 双曲正切激活函数。
6. `torch.nn.Softmax`: Softmax激活函数,常用于多分类问题。
这些激活函数可用于神经网络的不同层中,以引入非线性特性和增加网络的表达能力。

4. Loss Functions(损失函数)

在 `torch.nn` 中,常见的损失函数包括:
1. `torch.nn.MSELoss`: 均方误差损失函数,用于回归问题。
2. `torch.nn.L1Loss`: 平均绝对误差损失函数,也用于回归问题。
3. `torch.nn.CrossEntropyLoss`: 交叉熵损失函数,通常用于多分类问题。
4. `torch.nn.NLLLoss`: 负对数似然损失函数,结合了LogSoftmax和负对数似然损失,通常用于多分类问题。
5. `torch.nn.BCELoss`: 二元交叉熵损失函数,通常用于二分类问题。
6. `torch.nn.BCEWithLogitsLoss`: 结合了Sigmoid和二元交叉熵损失的函数,通常用于二分类问题。
这些损失函数可以帮助定义神经网络的训练目标,并用于计算预测值与真实标签之间的误差。

5、Optimizers(优化器)

在 `torch.nn` 中,常见的优化器函数包括:
1. `torch.optim.SGD`: 随机梯度下降优化器。
2. `torch.optim.Adam`: Adam优化器,结合了动量和自适应学习率。
3. `torch.optim.Adagrad`: 自适应学习率优化器。
4. `torch.optim.RMSprop`: RMSprop优化器,结合了动量和自适应学习率。
5. `torch.optim.Adadelta`: Adadelta优化器,自适应学习率方法之一。
6. `torch.optim.Adamax`: Adamax优化器,基于Adam算法的变种。
这些优化器函数可用于调整神经网络的参数以最小化定义的损失函数,从而实现模型的训练和优化。

6. Initialization Functions(初始化函数)

在 `torch.nn` 中,常见的初始化函数包括:
1. `torch.nn.init.uniform_`: 均匀分布初始化。
2. `torch.nn.init.normal_`: 正态分布初始化。
3. `torch.nn.init.constant_`: 常数初始化。
4. `torch.nn.init.eye_`: 单位矩阵初始化。
5. `torch.nn.init.xavier_uniform_`: Xavier均匀分布初始化。
6. `torch.nn.init.xavier_normal_`: Xavier正态分布初始化。
7. `torch.nn.init.kaiming_uniform_`: Kaiming均匀分布初始化(用于ReLU激活函数)。
8. `torch.nn.init.kaiming_normal_`: Kaiming正态分布初始化(用于ReLU激活函数)。
这些初始化函数可用于初始化神经网络的权重和偏置,以帮助训练神经网络时更好地收敛和避免梯度消失或梯度爆炸等问题。

7. Utilities(实用工具)

在 `torch.nn` 中,有一些实用的工具函数,包括但不限于:
1. `torch.nn.functional.conv2d`: 二维卷积操作。
2. `torch.nn.functional.linear`: 线性变换,等同于全连接层。
3. `torch.nn.functional.relu`: ReLU激活函数。
4. `torch.nn.functional.max_pool2d`: 二维最大池化操作。
5. `torch.nn.functional.dropout`: 随机失活操作。
6. `torch.nn.functional.cross_entropy`: 交叉熵损失函数。
7. `torch.nn.functional.mse_loss`: 均方误差损失函数。
8. `torch.nn.functional.softmax`: softmax函数,常用于多分类问题。
这些工具函数提供了在神经网络中常见的操作,用于构建神经网络结构和定义损失函数。

8、其他

8.1 torch.nn.Conv2d 与 torch.nn.functional.conv2d 的区别

`torch.nn.Conv2d` 是一个类,用于创建二维卷积层,它是 PyTorch 中的一个模块。而 `torch.nn.functional.conv2d` 是一个函数,用于执行二维卷积操作。
区别在于:
- `torch.nn.Conv2d` 是一个类,它是一个包含可学习参数的卷积层,可以包含权重和偏置项,并且具有其他卷积层相关的属性和方法。
- `torch.nn.functional.conv2d` 是一个函数,它执行卷积操作,但是它不包含可学习的参数,需要手动输入卷积核和偏置项,通常用于在自定义的网络结构或者函数中执行卷积操作。


以下是它们的使用示例:
使用 `torch.nn.Conv2d` 创建一个简单的卷积层:

import torch
import torch.nn as nn
# 创建一个二维卷积层
conv_layer = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
# 定义输入数据
input_data = torch.randn(1, 3, 32, 32)  # (batch_size, channels, height, width)
# 通过卷积层进行前向传播
output = conv_layer(input_data)
print(output.shape)  # 输出结果的形状

使用 `torch.nn.functional.conv2d` 执行二维卷积操作:

import torch
import torch.nn.functional as F
# 定义输入数据
input_data = torch.randn(1, 3, 32, 32)  # (batch_size, channels, height, width)
# 定义卷积核
weight = torch.randn(16, 3, 3, 3)  # (out_channels, in_channels, kernel_height, kernel_width)
# 执行二维卷积操作
output = F.conv2d(input_data, weight, bias=None, stride=1, padding=1)
print(output.shape)  # 输出结果的形状

8.2 torch.nn.Linear 与torch.nn.functional.linear的区别

`torch.nn.Linear` 是一个类,用于创建全连接层(也称为线性层),它是 PyTorch 中的一个模块。而 `torch.nn.functional.linear` 是一个函数,用于执行线性变换操作。
区别在于:
- `torch.nn.Linear` 是一个类,它是一个包含可学习参数的全连接层,可以包含权重和偏置项,并且具有其他全连接层相关的属性和方法。
- `torch.nn.functional.linear` 是一个函数,它执行线性变换操作,但是它不包含可学习的参数,需要手动输入权重和偏置项,通常用于在自定义的网络结构或者函数中执行线性变换操作。


下面是它们的使用示例:
使用 `torch.nn.Linear`:

import torch
import torch.nn as nn
# 创建一个线性变换模块
linear_module = nn.Linear(in_features=5, out_features=3)
# 构造输入张量
input_data = torch.randn(2, 5)  # 2个样本,每个样本5个特征
# 应用线性变换
output = linear_module(input_data)


使用 `torch.nn.functional.linear`:

import torch
import torch.nn.functional as F
# 构造输入张量和权重
input_data = torch.randn(2, 5)  # 2个样本,每个样本5个特征
weight = torch.randn(3, 5)      # 权重矩阵
# 应用线性变换
output = F.linear(input_data, weight)


在这两个示例中,都创建了输入数据并应用了线性变换,但是 `torch.nn.Linear`更适合用于创建神经网络模型,而 `torch.nn.functional.linear` 更适合在自定义的前向传播函数中使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/693490.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

速盾网络:有什么不用备案的CDN吗?

备案是指将网站相关信息提交给当地通信管理部门进行备案登记,以确保网站的合法性和规范运营。然而,对于一些特定的CDN服务,是否存在不用备案的情况呢?让我们来了解一下速盾网络在这方面的看法和实践。 备案与CDN服务 在中国大陆…

在qml中的ShaderEffect在arm板的3568的系统上是用GPU渲染的吗

在QML中的ShaderEffect通常是利用GPU进行渲染的。 ShaderEffect 是 Qt Quick 提供的一个功能强大的组件,它允许开发者在 QML 层面实现像素级别的操作。这个组件的设计目的就是为了充分利用 GPU 的强大计算能力来进行图形渲染。因此,当你在 QML 中使用 S…

挑战!贪吃蛇小游戏的实现(3)

经过(1)(2)两篇文章的介绍,相信大家对该游戏的实现已经有了具体的思路,废话不多说,让我们开始实现相关的代码吧! 1.游戏主逻辑 void test() {int ch 0;srand((unsigned int)time(NU…

【Unity3D】ASE制作天空盒

找到官方shader并分析 下载对应资源包找到\DefaultResourcesExtra\Skybox-Cubed.shader找到\CGIncludes\UnityCG.cginc观察变量, 观察tag, 观察代码 需要注意的内容 ASE要处理的内容 核心修改 添加一个Custom Expression节点 code内容为: return DecodeHDR(In0, In1);outp…

JavaSpringBoot中,Mybatis plus 语法展示

目录 语法展示 基础的增删改查 分页查询 语法指导 删除操作 条件操作 语法展示 Mapper public interface UserMapper extends BaseMapper<User> {} public interface UserService extends IService<User> {} Service public class UserServiceImpl extends…

在Win系统部署WampServer并实现公网访问本地服务【内网穿透】

目录 推荐 前言 1.WampServer下载安装 2.WampServer启动 3.安装cpolar内网穿透 3.1 注册账号 3.2 下载cpolar客户端 3.3 登录cpolar web ui管理界面 3.4 创建公网地址 4.固定公网地址访问 推荐 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0…

企业资产|企业资产管理系统|基于springboot企业资产管理系统设计与实现(源码+数据库+文档)

企业资产管理系统目录 目录 基于springboot企业资产管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、用户信息管理 2、用户审核管理 3、资产分类管理 4、资产信息管理 5、资产信息添加 6、资产借出统计 7、资产归还审核 8、资产维修管理 9、资产维修…

事务的4大特性、隔离级别、传播机制

目录 一、4大特性&#xff08;ACID&#xff09;二、隔离级别三、传播机制 一、4大特性&#xff08;ACID&#xff09; 原子性&#xff08;A&#xff09;&#xff1a;在一个事务中&#xff0c;要么全部成功&#xff0c;要么全部失败。一致性&#xff08;C&#xff09;&#xff1…

effective c++ 笔记 条款26-31

条款 26&#xff1a;尽可能延后变量定义式出现的时间 应该延后变量的定义&#xff0c;直到非得使用该变量的前一刻为止&#xff0c;甚至应该尝试延后这份定义直到能够给它初值实参为止&#xff0c;以此避免构造&#xff08;和析构&#xff09;非必要对象&#xff0c;还可以避免…

c++笔记理解

1.封装 &#xff08;1&#xff09;构造函数不是必须在的 可以通过行为修改属性 &#xff08;2&#xff09;private和protected区别在于继承那里要学 &#xff08;3&#xff09;类默认是私有&#xff0c;struct是共有 私有的好处&#xff1a;控制数据的有效性&#xff0c;意…

编程笔记 Golang基础 012 项目构建

编程笔记 Golang基础 012 项目构建 一、模块&#xff08;Module&#xff09;、包&#xff08;Package&#xff09;和文件二、项目结构三、VsCode项目管理四、Goland项目管理五、工作空间小结 如何构建和组织一个项目&#xff0c;是学习该语言编程的开始。 一、模块&#xff08;…

MySQL 8.0.36 WorkBench安装

一、下载安装包 百度网盘链接&#xff1a;点击此处下载安装文件 提取码&#xff1a;hhwz 二、安装&#xff0c;跟着图片来 选择Custom,然后点Next 顺着左边框每一项的加号打开到每一个项的最底层&#xff0c;点击选中最底层的项目&#xff0c;再点击传过去右边的绿色箭头&a…

Codeforces Round 530 (Div. 2)

CF1099A Snowball 题目 有一个重量为 w 的雪球正在高度为 h 的地方向下滚动。每秒它的高度会减少 1。同时在高度 i 的位置它的重量会增加 i&#xff08;包括初始位置&#xff09; 同时在滚动的路线上有 2 块石头&#xff0c;第 i 块石头的高度为 hi​&#xff0c;即雪球会在 hi…

【论文阅读|基于 YOLO 的红外小目标检测的逆向范例】

基于 YOLO 的红外小目标检测的逆向范例 摘要1 引言2 相关工作2.1 逆向推理2.2 物体检测方法 3 方法3.1 总体架构3.2 逆向标准的可微分积分 4 实验4.1 数据集和指标4.2 实验环境4.4 OL-NFA 为少样本环境带来稳健性 5 结论 论文题目&#xff1a; A Contrario Paradigm for YOLO-b…

详解 leetcode_078. 合并K个升序链表.小顶堆实现

/*** 构造单链表节点*/ class ListNode{int value;//节点值ListNode next;//指向后继节点的引用public ListNode(){}public ListNode(int value){this.valuevalue;}public ListNode(int value,ListNode next){this.valuevalue;this.nextnext;} }package com.ag; import java.ut…

[树形DP] 最长乘积链

题目 1.最长乘积链 - 蓝桥云课 (lanqiao.cn) 初始思路 对问题进行分析&#xff0c;对每个点dfs去求走不同路的最远距离与次远距离求乘积&#xff0c;时间复杂度为O(n^2) 看了答案怎么弄的优化 解题思路 总的来说 预处理&#xff08;对每个结点的信息进行统计&#xff09…

AWS无服务器直播解决方案

随着媒体系统的发展&#xff0c;越来越多的直播客户想要一个即开即用的平台&#xff0c;在不需要管理和运维底层资源的同时使用一站式的媒体平台。九河云对多家云厂商有所了解及有一定合作&#xff0c;下面将按客户的需求介绍aws的无服务器直播解决方案。 架构概述&#xff1a…

Flutter插件开发指南02: 事件订阅 EventChannel

Flutter插件开发指南02: 事件订阅 EventChannel 视频 https://www.bilibili.com/video/BV1zj411d7k4/ 前言 上一节我们讲了 Channel 通道&#xff0c;但是如果你是卫星定位业务&#xff0c;原生端主动推消息给 Flutter 这时候就要用到 EventChannel 通道了。 本节会写一个 1~…

HarmonyOS 权限 介绍

权限说明 权限等级 根据权限对于不同等级应用有不同的开放范围&#xff0c;权限类型对应分为以下三种&#xff0c;等级依次提高。 normal权限 normal 权限允许应用访问超出默认规则外的普通系统资源。 这些系统资源的开放&#xff08;包括数据和功能&#xff09;对用户隐私以及…