Pytorch构建模型的3种方法

这个地方一直是我思考的地方!因为学的代码太多了,构建的模型各有不同,这里记录一下!
可以使用以下3种方式构建模型:

1,继承nn.Module基类构建自定义模型。

2,使用nn.Sequential按层顺序构建模型。

3,继承nn.Module基类构建模型并辅助应用模型容器进行封装(nn.Sequential,nn.ModuleList,nn.ModuleDict)。

其中 第1种方式最为常见,第2种方式最简单,第3种方式最为灵活也较为复杂。

推荐使用第1种方式构建模型。

头文件:

import torch 
from torch import nn

一,继承nn.Module基类构建自定义模型

以下是继承nn.Module基类构建自定义模型的一个范例。模型中的用到的层一般在__init__函数中定义,然后在forward方法中定义模型的正向传播逻辑。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)self.pool1 = nn.MaxPool2d(kernel_size = 2,stride = 2)self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)self.pool2 = nn.MaxPool2d(kernel_size = 2,stride = 2)self.dropout = nn.Dropout2d(p = 0.1)self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))self.flatten = nn.Flatten()self.linear1 = nn.Linear(64,32)self.relu = nn.ReLU()self.linear2 = nn.Linear(32,1)self.sigmoid = nn.Sigmoid()def forward(self,x):x = self.conv1(x)x = self.pool1(x)x = self.conv2(x)x = self.pool2(x)x = self.dropout(x)x = self.adaptive_pool(x)x = self.flatten(x)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)y = self.sigmoid(x)return ynet = Net()
print(net)

二,使用nn.Sequential按层顺序构建模型

使用nn.Sequential按层顺序构建模型无需定义forward方法。仅仅适合于简单的模型。

以下是使用nn.Sequential搭建模型的一些等价方法。

1,利用add_module方法

net = nn.Sequential()
net.add_module("conv1",nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3))
net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("dropout",nn.Dropout2d(p = 0.1))
net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
net.add_module("flatten",nn.Flatten())
net.add_module("linear1",nn.Linear(64,32))
net.add_module("relu",nn.ReLU())
net.add_module("linear2",nn.Linear(32,1))
net.add_module("sigmoid",nn.Sigmoid())print(net)

2,利用变长参数

这种方式构建时不能给每个层指定名称。

net = nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Dropout2d(p = 0.1),nn.AdaptiveMaxPool2d((1,1)),nn.Flatten(),nn.Linear(64,32),nn.ReLU(),nn.Linear(32,1),nn.Sigmoid()
)print(net)

3,利用OrderedDict

from collections import OrderedDictnet = nn.Sequential(OrderedDict([("conv1",nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)),("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2)),("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)),("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2)),("dropout",nn.Dropout2d(p = 0.1)),("adaptive_pool",nn.AdaptiveMaxPool2d((1,1))),("flatten",nn.Flatten()),("linear1",nn.Linear(64,32)),("relu",nn.ReLU()),("linear2",nn.Linear(32,1)),("sigmoid",nn.Sigmoid())]))
print(net)

三,继承nn.Module基类构建模型并辅助应用模型容器进行封装

当模型的结构比较复杂时,我们可以应用模型容器(nn.Sequential,nn.ModuleList,nn.ModuleDict)对模型的部分结构进行封装。

这样做会让模型整体更加有层次感,有时候也能减少代码量。

注意,在下面的范例中我们每次仅仅使用一种模型容器,但实际上这些模型容器的使用是非常灵活的,可以在一个模型中任意组合任意嵌套使用。

1,nn.Sequential作为模型容器

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Dropout2d(p = 0.1),nn.AdaptiveMaxPool2d((1,1)))self.dense = nn.Sequential(nn.Flatten(),nn.Linear(64,32),nn.ReLU(),nn.Linear(32,1),nn.Sigmoid())def forward(self,x):x = self.conv(x)y = self.dense(x)return y net = Net()
print(net)

2,nn.ModuleList作为模型容器

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layers = nn.ModuleList([nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),nn.MaxPool2d(kernel_size = 2,stride = 2),nn.Dropout2d(p = 0.1),nn.AdaptiveMaxPool2d((1,1)),nn.Flatten(),nn.Linear(64,32),nn.ReLU(),nn.Linear(32,1),nn.Sigmoid()])def forward(self,x):for layer in self.layers:x = layer(x)return x
net = Net()
print(net)

3,nn.ModuleDict作为模型容器

注意下面中的ModuleDict不能用Python中的字典代替。

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.layers_dict = nn.ModuleDict({"conv1":nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3),"pool": nn.MaxPool2d(kernel_size = 2,stride = 2),"conv2":nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),"dropout": nn.Dropout2d(p = 0.1),"adaptive":nn.AdaptiveMaxPool2d((1,1)),"flatten": nn.Flatten(),"linear1": nn.Linear(64,32),"relu":nn.ReLU(),"linear2": nn.Linear(32,1),"sigmoid": nn.Sigmoid()})def forward(self,x):layers = ["conv1","pool","conv2","pool","dropout","adaptive","flatten","linear1","relu","linear2","sigmoid"]for layer in layers:x = self.layers_dict[layer](x)return x
net = Net()
print(net)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389342.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue取数据第一个数据_我作为数据科学家的第一个月

vue取数据第一个数据A lot.很多。 I landed my first job as a Data Scientist at the beginning of August, and like any new job, there’s a lot of information to take in at once.我于8月初找到了数据科学家的第一份工作,并且像任何新工作一样,一…

Flask-SocketIO 简单使用指南

Flask-SocketIO 使 Flask 应用程序能够访问客户端和服务器之间的低延迟双向通信。客户端应用程序可以使用 Javascript,C ,Java 和 Swift 中的任何 SocketIO 官方客户端库或任何兼容的客户端来建立与服务器的永久连接。 安装 直接使用 pip 来安装&#xf…

STL-开篇

基本概念 STL: Standard Template Library,标准模板库 定义: c引入的一个标准类库 特点:1)数据结构和算法的 c实现( 采用模板类和模板函数)2)数据的存储和算法的分离3)高…

Symbol Mc1000 声音的设置以及播放

首先引用Symbol.Audio 加一命名空间using Symbol.Audio; /声音设备的设置 //Select Device from device list Symbol.Audio.Device MyDevice (Symbol.Audio.Device)Symbol.StandardForms.SelectDevice.Select( Symbol.Audio.Controller.Title, Symbol.Audio.Devic…

/bin/bash^M: 坏的解释器: 没有那个文件或目录

在win下编辑的时候,换行结尾是\n\r , 而在linux下 是\n,所以会多出来一个\r,这样会出现错误 此时执行 sed -i s/\r$// file.sh 将file.sh中的\r都替换为空白,问题解决转载于:https://www.cnblogs.com/zzdbullet/p/9890…

rcp rapido_为什么气流非常适合Rapido

rcp rapidoBack in 2019, when we were building our data platform, we started building the data platform with Hadoop 2.8 and Apache Hive, managing our own HDFS. The need for managing workflows whether it’s data pipelines, i.e. ETL’s, machine learning predi…

pandas处理丢失数据与数据导入导出

3.4pandas处理丢失数据 头文件: import numpy as np import pandas as pd丢弃数据部分: dates pd.date_range(20130101,periods6) df pd.DataFrame(np.random.randn(6,4),indexdates,columns[A,B,C,D]) df.iloc[0,1] np.nan df.iloc[1,2] np.nanp…

Mysql5.7开启远程

2019独角兽企业重金招聘Python工程师标准>>> 1.注掉bind-address #bind-address 127.0.0.1 2.开启远程访问权限 grant all privileges on *.* to root"xxx.xxx.xxx.xxx" identified by "密码"; 或 grant all privileges on *.* to root"%…

分类结果可视化python_可视化分类结果的另一种方法

分类结果可视化pythonI love good data visualizations. Back in the days when I did my PhD in particle physics, I was stunned by the histograms my colleagues built and how much information was accumulated in one single plot.我喜欢出色的数据可视化。 早在我获得…

算法组合 优化算法_算法交易简化了风险价值和投资组合优化

算法组合 优化算法Photo by Markus Spiske (left) and Jamie Street (right) on UnsplashMarkus Spiske (左)和Jamie Street(右)在Unsplash上的照片 In the last post, we saw how actual algorithms are developed and tested. In this post, we will figure out the level of…

Symbol Mc1000 快捷键 的 设置 事件 开发

switch (e.KeyCode) { ///数据 case Keys.F1://清除数据 if(File.Exists("Storage Card/CG.sdf")) { Mc.gConn.Close(); Mc.gConn.Dispose(); File.Delete("Storage Card/CG.sdf"); } MessageBox.S…

pandas合并concatmerge和plot画图

3.6,3.7pandas合并concat&merge 头文件: import pandas as pd import numpy as npconcat基础合并用法 df1 pd.DataFrame(np.ones((3,4))*0,columns [a,b,c,d]) df2 pd.DataFrame(np.ones((3,4))*1,columns [a,b,c,d]) df3 pd.DataFrame(np.ones…

Android跳转WIFI界面的四种方式

第一种 Intent intent new Intent(); intent.setAction("android.net.wifi.PICK_WIFI_NETWORK"); startActivity(intent); 第二种 startActivity(new Intent(android.provider.Settings.ACTION_WIFI_SETTINGS)); 第三种 Intent i new Intent(); if(android.os.Buil…

PS抠发丝技巧 「选择并遮住…」

PS抠发丝技巧 「选择并遮住…」 现在的海报设计,大多数都有模特MM,然而MM的头发实用太多了,有的还飘起来…… 对于设计师(特别是淘宝美工)没有一个强大、快速、实用的抠发丝技巧真的混不去哦。而PS CC 2017版本开始,就有了一个强大…

covid 19如何重塑美国科技公司的工作文化

未来 , 技术 , 观点 (Future, Technology, Opinion) Who would have thought that a single virus would take down the whole world and make us stay inside our homes? A pandemic wave that has altered our lives in such a way that no human (bi…

Symbol Mc1000 Text文本阅读器整体代码

using System; using System.ComponentModel;using System.Data;using System.Drawing;using System.Text;using System.Windows.Forms;using System.Collections;using System.IO;namespace text{ /// <summary> /// Form1 的摘要说明。 /// </summary> public c…

python生日悖论分析_生日悖论

python生日悖论分析If you have a group of people in a room, how many do you need to for it to be more likely than not, that two or more will have the same birthday?如果您在一个房间里有一群人&#xff0c;那么您需要多少个才能使两个或两个以上的人有相同的生日&a…

统计0-n数字中出现k的次数

/*** 统计0-n数字中出现k的次数&#xff0c;其中k范围为0-9 */ public static int countOne(int k, int n) {if (k > n) {return 0;}int sum 0;int right 0;for (int i 0; n > 0; i) {int last n % 10;sum last * i * (int) Math.pow(10, i - 1);if (k 0) {sum - (…

房价预测 search Search 中对数据预处理的学习

对于缺失的数据&#xff1a; 我们对连续数值的特征做标准化&#xff08;standardization&#xff09;&#xff1a;设该特征在整个数据集上的均值为 μ &#xff0c;标准差为 σ 。那么&#xff0c;我们可以将该特征的每个值先减去 μ 再除以 σ 得到标准化后的每个特征值。对于…

3.6.1.非阻塞IO

本节讲解什么是非阻塞IO&#xff0c;如何将文件描述符修改为非阻塞式 3.6.1.1、阻塞与非阻塞 &#xff08;1&#xff09;阻塞是指函数调用会被阻塞。本质是当前进程调用了函数&#xff0c;进入内核里面去后&#xff0c;因为当前进程的执行条件不满足&#xff0c;内核无法里面完…