[pytorch、学习] - 4.4 自定义层

参考

4.4 自定义层

深度学习的一个魅力在于神经网络中各式各样的层,例如全连接层和后面章节将要用介绍的卷积层、池化层与循环层。虽然PyTorch提供了大量常用的层,但有时候我们依然希望自定义层。本节将介绍如何使用Module来自定义层,从而可以被重复调用。

4.4.1 不含模型参数的自定义层

我们先介绍如何定义一个不含模型参数的自定义层。

import torch
from torch import nnclass CenteredLayer(nn.Module):def __init__(self, **kwargs):super(CenteredLayer, self).__init__(**kwargs)def forward(self, x):return x - x.mean()
layer = CenteredLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float))

在这里插入图片描述
我们也可以用它来构造更复杂的模型。

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
y.mean().item()

在这里插入图片描述

4.4.2 含模型参数的自定义层

我们还可以自定义含模型参数的自定义层。其中的模型参数可以通过训练学习。
Parameter类其实是Tensor的子类,如果一个Tensor是Parameter,那么它会自动被添加到模型的参数列表里。所以在自定义含模型参数的层时,我们应该将参数定义成Parameter,除了像4.2.1节那样直接定义成Parameter类外,还可以使用ParameterListParameterDict分别定义参数的列表和字典。

ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用appendextend在列表后面新增参数。

class MyDense(nn.Module):def __init__(self):super(MyDense, self).__init__()self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])self.params.append(nn.Parameter(torch.randn(4, 1)))def forward(self, x):for i in range(len(self.params)):x =  torch.mm(x, self.params[i])return xnet = MyDense()
print(net)

在这里插入图片描述
ParameterDict接收一个Parameter实例的字典作为输入然后得到一个参数字典,然后可以按照字典的规则使用了。

class MyDictDense(nn.Module):def __init__(self):super(MyDictDense, self).__init__()self.params = nn.ParameterDict({'linear1': nn.Parameter(torch.randn(4, 4)),'linear2': nn.Parameter(torch.randn(4, 1))})self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})def forward(self, x, choice='linear1'):return torch.mm(x, self.params[choice])net = MyDictDense()
print(net)

在这里插入图片描述

x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))

在这里插入图片描述
我们也可以使用自定义层构造模型。它和PyTorch的其他层在使用上很类似。

net = nn.Sequential(MyDictDense(),MyDictDense()
)
print(net)
print(net(x))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250158.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

树的存储

父亲表示法 顾名思义,就是只记录每个结点的父结点。 int n; int p[MAX_N]; // 指向每个结点的父结点 孩子表示法 如上,就是只记录每个结点的子结点。 int n; int cnt[MAX_N]; // 记录每个结点的子结点的数量 int p[MAX_N][MAX_CNT]; // 指向每个结点的子…

spring-boot注解详解(四)

repository repository跟Service,Compent,Controller这4种注解是没什么本质区别,都是声明作用,取不同的名字只是为了更好区分各自的功能.下图更多的作用是mapper注册到类似于以前mybatis.xml中的mappers里. 也是因为接口没办法在spring.xml中用bean的方式来配置实现类吧(接口…

令人叫绝的EXCEL函数功能

http://club.excelhome.net/thread-166725-1-1.html https://wenku.baidu.com/view/db319da0bb0d4a7302768e9951e79b8969026864.html转载于:https://www.cnblogs.com/cqufengchao/articles/9150401.html

[pytorch、学习] - 4.5 读取和存储

参考 4.5 读取和存储 到目前为止,我们介绍了如何处理数据以及如何构建、训练和测试深度学习模型。然而在实际中,我们有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。 4.5.1 读写tensor 我们可以直…

JAVA排序的方法

//冒泡排序法: package fuxi;public class Bubble { public static void main(String[] args) { int a[] { 10,23,11,56,45,26,59,28,84,79 }; int i,temp; System.out.println("输出原始数组数据:"); for (i…

spring-boot注解详解(五)

AutoWired 首先要知道另一个东西,default-autowire,它是在xml文件中进行配置的,可以设置为byName、byType、constructor和autodetect;比如byName,不用显式的在bean中写出依赖的对象,它会自动的匹配其它bea…

什么是p12证书?ios p12证书怎么获取?

.cer是苹果的默认证书,在xcode开发打包可以使用,如果在lbuilder、phonegap、HBuilder、AppCan、APICloud这些跨平台开发工具打包,就需要用到p12文件。 .cer证书仅包含公钥,.p12证书可能既包含公钥也包含私钥,这就是他们…

[pytorch、学习] - 4.6 GPU计算

参考 4.6 GPU计算 到目前为止,我们一直使用CPU进行计算。对复杂的神经网络和大规模数据来说,使用CPU来计算可能不够高效。 在本节中,将要介绍如何使用单块NIVIDA GPU进行计算 4.6.1 计算设备 PyTorch可以指定用来存储和计算的设备,如果用内存的CPU或者显存的GPU。默认情况下…

adb connect 192.168.1.10 failed to connect to 192.168.1.10:5555

adb connect 192.168.1.10 输出 failed to connect to 192.168.1.10:5555 关闭安卓端Wi-Fi,重新打开连接即可 转载于:https://www.cnblogs.com/sea-stream/p/10020995.html

创建oracle数据库表空间并分配用户

我们在本地的oracle上或者virtualbox的oracle上 创建新的数据库表空间操作:通过system账号来创建并授权/*--创建表空间create tablespace YUJKDATAdatafile c:\yujkdata200.dbf --指定表空间对应的datafile文件的具体的路径size 100mautoextend onnext 10m*/ /*--创…

spring-boot注解详解(六)

Target Target说明了Annotation所修饰的对象范围:Annotation可被用于 packages、types(类、接口、枚举、Annotation类型)、类型成员(方法、构造方法、成员变量、枚举值)、方法参数和本地变量(如循环变量、…

[pytorch、学习] - 5.1 二维卷积层

参考 5.1 二维卷积层 卷积神经网络(convolutional neural network)是含有卷积层(convolutional layer)的神经网络。本章介绍的卷积神经网络均使用最常见的二维卷积层。它有高和宽两个空间维度,常用来处理图像数据。本节中,我们将介绍简单形式的二维卷积层的工作原理。 5.1.1…

[51CTO]给您介绍Windows10各大版本之间区别

给您介绍Windows10各大版本之间区别 随着win10的不断普及和推广,越来越多的朋友想安装win10系统了,但是很多朋友不知道win10哪个版本好用,为了让大家能够更好的选择win10系统版本,下面小编就来告诉你 http://os.51cto.com/art/201…

iOS中NSString转换成HEX(十六进制)-NSData转换成int

NSString *str "0xff055008"; //先以16为参数告诉strtoul字符串参数表示16进制数字,然后使用0x%X转为数字类型 unsigned long red strtoul([str UTF8String],0,16); //strtoul如果传入的字符开头是“0x”,那么第三个参数是0,也是会转为十…

spring-boot注解详解(七)

Configuration 从Spring3.0,Configuration用于定义配置类,可替换xml配置文件,被注解的类内部包含有一个或多个被Bean注解的方法,这些方法将会被AnnotationConfigApplicationContext或AnnotationConfigWebApplicationContext类进行…

[pytorch、学习] - 5.2 填充和步幅

参考 5.2 填充和步幅 5.2.1 填充 填充(padding)是指在输入高和宽的两侧填充元素(通常是0元素)。图5.2里我们在原输入高和宽的两侧分别添加了值为0的元素,使得输入高和宽从3变成了5,并导致输出高和宽由2增加到4。图5.2中的阴影部分为第一个输出元素及其计算所使用的输入和核数…

java实现Comparable接口和Comparator接口,并重写compareTo方法和compare方法

原文地址https://segmentfault.com/a/1190000005738975 实体类:java.lang.Comparable(接口) comareTo(重写方法),业务排序类 java.util.Comparator(接口) compare(重写方法). 这两个接口我们非常的熟悉,但是 在用的时候会有一些不知道怎么下手的感觉&a…

hdu 4714 树+DFS

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid4714 本来想直接求树的直径,再得出答案,后来发现是错的。 思路:任选一个点进行DFS,对于一棵以点u为根节点的子树来说,如果它的分支数大于1&#xff0c…

springboot----shiro集成

springboot中集成shiro相对简单,只需要两个类:一个是shiroConfig类,一个是CustonRealm类。 ShiroConfig类: 顾名思义就是对shiro的一些配置,相对于之前的xml配置。包括:过滤的文件和权限,密码加…

[pytorch、学习] - 5.3 多输入通道和多输出通道

参考 5.3 多输入通道和多输出通道 前面两节里我们用到的输入和输出都是二维数组,但真实数据的维度经常更高。例如,彩色图像在高和宽2个维度外还有RGB(红、绿、蓝)3个颜色通道。假设彩色图像的高和宽分别是h和w(像素),那么它可以表示为一个3 * h * w的多维数组。我们将大小为3…