nn.functional 和 nn.Module入门讲解

本文来自《20天吃透Pytorch》

一,nn.functional 和 nn.Module

前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API。

利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数,模型层,损失函数)。

Pytorch和神经网络相关的功能组件大多都封装在 torch.nn模块下。

这些功能组件的绝大部分既有函数形式实现,也有类形式实现。

其中nn.functional(一般引入后改名为F)有各种功能组件的函数实现。例如:

(激活函数) * F.relu * F.sigmoid * F.tanh * F.softmax
(模型层) * F.linear * F.conv2d * F.max_pool2d * F.dropout2d * F.embedding
(损失函数) * F.binary_cross_entropy * F.mse_loss * F.cross_entropy

为了便于对参数进行管理,一般通过继承 nn.Module 转换成为类的实现形式,并直接封装在 nn 模块下。例如:

(激活函数) * nn.ReLU * nn.Sigmoid * nn.Tanh * nn.Softmax
(模型层) * nn.Linear * nn.Conv2d * nn.MaxPool2d * nn.Dropout2d * nn.Embedding
(损失函数) * nn.BCELoss * nn.MSELoss * nn.CrossEntropyLoss

二,使用nn.Module来管理参数

在Pytorch中,模型的参数是需要被优化器训练的,因此,通常要设置参数为 requires_grad = True 的张量。
同时,在一个模型中,往往有许多的参数,要手动管理这些参数并不是一件容易的事情。
Pytorch一般将参数用nn.Parameter来表示,并且用nn.Module来管理其结构下的所有参数。

# nn.Parameter 具有 requires_grad = True 属性
w = nn.Parameter(torch.randn(2,2))
print(w)
print(w.requires_grad)# nn.ParameterList 可以将多个nn.Parameter组成一个列表
params_list = nn.ParameterList([nn.Parameter(torch.rand(8,i)) for i in range(1,3)])
print(params_list)
print(params_list[0].requires_grad)# nn.ParameterDict 可以将多个nn.Parameter组成一个字典params_dict = nn.ParameterDict({"a":nn.Parameter(torch.rand(2,2)),"b":nn.Parameter(torch.zeros(2))})
print(params_dict)
print(params_dict["a"].requires_grad)# 可以用Module将它们管理起来
# module.parameters()返回一个生成器,包括其结构下的所有parametersmodule = nn.Module()
module.w = w
module.params_list = params_list
module.params_dict = params_dictnum_param = 0
for param in module.parameters():print(param,"\n")num_param = num_param + 1
print("number of Parameters =",num_param)#实践当中,一般通过继承nn.Module来构建模块类,并将所有含有需要学习的参数的部分放在构造函数中。#以下范例为Pytorch中nn.Linear的源码的简化版本
#可以看到它将需要学习的参数放在了__init__构造函数中,并在forward中调用F.linear函数来实现计算逻辑。class Linear(nn.Module):__constants__ = ['in_features', 'out_features']def __init__(self, in_features, out_features, bias=True):super(Linear, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = nn.Parameter(torch.Tensor(out_features, in_features))if bias:self.bias = nn.Parameter(torch.Tensor(out_features))else:self.register_parameter('bias', None)def forward(self, input):return F.linear(input, self.weight, self.bias)

三,使用nn.Module来管理子模块

实际上nn.Module除了可以管理其引用的各种参数,还可以管理其引用的子模块,功能十分强大。

一般情况下,我们都很少直接使用 nn.Parameter来定义参数构建模型,而是通过一些拼装一些常用的模型层来构造模型。

这些模型层也是继承自nn.Module的对象,本身也包括参数,属于我们要定义的模块的子模块。

nn.Module提供了一些方法可以管理这些子模块。

children() 方法: 返回生成器,包括模块下的所有子模块。

named_children()方法:返回一个生成器,包括模块下的所有子模块,以及它们的名字。

modules()方法:返回一个生成器,包括模块下的所有各个层级的模块,包括模块本身。

named_modules()方法:返回一个生成器,包括模块下的所有各个层级的模块以及它们的名字,包括模块本身。

其中chidren()方法和named_children()方法较多使用。

modules()方法和named_modules()方法较少使用,其功能可以通过多个named_children()的嵌套使用实现。

i = 0
for child in net.children():i+=1print(child,"\n")
print("child number",i)
i = 0
for name,child in net.named_children():i+=1print(name,":",child,"\n")
print("child number",i)
i = 0
for module in net.modules():i+=1print(module)
print("module number:",i)

下面我们通过named_children方法找到embedding层,并将其参数设置为不可训练(相当于冻结embedding层)。

children_dict = {name:module for name,module in net.named_children()}print(children_dict)
embedding = children_dict["embedding"]
embedding.requires_grad_(False) #冻结其参数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389364.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

10.30PMP试题每日一题

SC>0&#xff0c;CPI<1&#xff0c;说明项目截止到当前&#xff1a;A、进度超前&#xff0c;成本超值B、进度落后&#xff0c;成本结余C、进度超前&#xff0c;成本结余D、无法判断 答案将于明天和新题一起揭晓&#xff01; 10.29试题答案&#xff1a;A转载于:https://bl…

02-web框架

1 while True:print(server is waiting...)conn, addr server.accept()data conn.recv(1024) print(data:, data)# 1.得到请求的url路径# ------------dict/obj d["path":"/login"]# d.get(”path“)# 按着http请求协议解析数据# 专注于web业…

ai驱动数据安全治理_AI驱动的Web数据收集解决方案的新起点

ai驱动数据安全治理Data gathering consists of many time-consuming and complex activities. These include proxy management, data parsing, infrastructure management, overcoming fingerprinting anti-measures, rendering JavaScript-heavy websites at scale, and muc…

从Text文本中读值插入到数据库中

/// <summary> /// 转换数据&#xff0c;从Text文本中导入到数据库中 /// </summary> private void ChangeTextToDb() { if(File.Exists("Storage Card/Zyk.txt")) { try { this.RecNum.Visibletrue; SqlCeCommand sqlCreateTable…

Dataset和DataLoader构建数据通道

重点在第二部分的构建数据通道和第三部分的加载数据集 Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。 Dataset定义了数据集的内容&#xff0c;它相当于一个类似列表的数据结构&#xff0c;具有确定的长度&#xff0c;能够用索引获取数据集中的元素。 而D…

铁拳nat映射_铁拳如何重塑我的数据可视化设计流程

铁拳nat映射It’s been a full year since I’ve become an independent data visualization designer. When I first started, projects that came to me didn’t relate to my interests or skills. Over the past eight months, it’s become very clear to me that when cl…

Django2 Web 实战03-文件上传

作者&#xff1a;Hubery 时间&#xff1a;2018.10.31 接上文&#xff1a;接上文&#xff1a;Django2 Web 实战02-用户注册登录退出 视频是一种可视化媒介&#xff0c;因此视频数据库至少应该存储图像。让用户上传文件是个很大的隐患&#xff0c;因此接下来会讨论这俩话题&#…

BZOJ.2738.矩阵乘法(整体二分 二维树状数组)

题目链接 BZOJ洛谷 整体二分。把求序列第K小的树状数组改成二维树状数组就行了。 初始答案区间有点大&#xff0c;离散化一下。 因为这题是一开始给点&#xff0c;之后询问&#xff0c;so可以先处理该区间值在l~mid的修改&#xff0c;再处理询问。即二分标准可以直接用点的标号…

从数据库里读值往TEXT文本里写

/// <summary> /// 把预定内容导入到Text文档 /// </summary> private void ChangeDbToText() { this.RecNum.Visibletrue; //建立文件&#xff0c;并打开 string oneLine ""; string filename "Storage Card/YD" DateTime.Now.…

DengAI —如何应对数据科学竞赛? (EDA)

了解机器学习 (Understanding ML) This article is based on my entry into DengAI competition on the DrivenData platform. I’ve managed to score within 0.2% (14/9069 as on 02 Jun 2020). Some of the ideas presented here are strictly designed for competitions li…

Pytorch模型层简单介绍

模型层layers 深度学习模型一般由各种模型层组合而成。 torch.nn中内置了非常丰富的各种模型层。它们都属于nn.Module的子类&#xff0c;具备参数管理功能。 例如&#xff1a; nn.Linear, nn.Flatten, nn.Dropout, nn.BatchNorm2d nn.Conv2d,nn.AvgPool2d,nn.Conv1d,nn.Co…

有效沟通的技能有哪些_如何有效地展示您的数据科学或软件工程技能

有效沟通的技能有哪些What is the most important thing to do after you got your skills to be a data scientist? It has to be to show off your skills. Otherwise, there is no use of your skills. If you want to get a job or freelance or start a start-up, you ha…

java.net.SocketException: Software caused connection abort: socket write erro

场景&#xff1a;接口测试 编辑器&#xff1a;eclipse 版本&#xff1a;Version: 2018-09 (4.9.0) testng版本&#xff1a;TestNG version 6.14.0 执行testng.xml时报错信息&#xff1a; 出现此报错原因之一&#xff1a;网上有人说是testng版本与eclipse版本不一致造成的&#…

[博客..配置?]博客园美化

博客园搞定时间 -> 18年6月27日 [让我歇会儿 搞这个费脑子 代码一个都看不懂] 转载于:https://www.cnblogs.com/Steinway/p/9235437.html

使用K-Means对美因河畔法兰克福的社区进行聚类

介绍 (Introduction) This blog post summarizes the results of the Capstone Project in the IBM Data Science Specialization on Coursera. Within the project, the districts of Frankfurt am Main in Germany shall be clustered according to their venue data using t…

Pytorch损失函数losses简介

一般来说&#xff0c;监督学习的目标函数由损失函数和正则化项组成。(Objective Loss Regularization) Pytorch中的损失函数一般在训练模型时候指定。 注意Pytorch中内置的损失函数的参数和tensorflow不同&#xff0c;是y_pred在前&#xff0c;y_true在后&#xff0c;而Ten…

读取Mc1000的 唯一 ID 机器号

先引用Symbol.ResourceCoordination 然后引用命名空间 using System;using System.Security.Cryptography;using System.IO; 以下为类程序 /// <summary> /// 获取设备id /// </summary> /// <returns></returns> public static string GetDevi…

样本均值的抽样分布_抽样分布样本均值

样本均值的抽样分布One of the most important concepts discussed in the context of inferential data analysis is the idea of sampling distributions. Understanding sampling distributions helps us better comprehend and interpret results from our descriptive as …

玩转ceph性能测试---对象存储(一)

笔者最近在工作中需要测试ceph的rgw&#xff0c;于是边测试边学习。首先工具采用的intel的一个开源工具cosbench&#xff0c;这也是业界主流的对象存储测试工具。 1、cosbench的安装&#xff0c;启动下载最新的cosbench包wget https://github.com/intel-cloud/cosbench/release…

[BZOJ 4300]绝世好题

Description 题库链接 给定一个长度为 \(n\) 的数列 \(a_i\) &#xff0c;求 \(a_i\) 的子序列 \(b_i\) 的最长长度&#xff0c;满足 \(b_i\wedge b_{i-1}\neq 0\) &#xff08; \(\wedge\) 表示按位与&#xff09; \(1\leq n\leq 100000\) Solution 令 \(f_i\) 为二进制第 \(i…