Pytorch 神经网络nn模块

文章目录

    • 1. nn模块
    • 2. torch.optim 优化器
    • 3. 自定义nn模块
    • 4. 权重共享

参考 http://pytorch123.com/

1. nn模块

import torch
N, D_in, Hidden_size, D_out = 64, 1000, 100, 10
  • torch.nn.Sequential 建立模型,跟 keras 很像
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)model = torch.nn.Sequential(torch.nn.Linear(D_in, Hidden_size),torch.nn.ReLU(),torch.nn.Linear(Hidden_size, D_out)
)# 损失函数
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-4
loss_list = []for t in range(500):y_pred = model(x) # 前向传播loss = loss_fn(y_pred, y) # 损失loss_list.append(loss.item())print(t, loss.item())model.zero_grad() # 清零梯度loss.backward() # 反向传播,计算梯度with torch.no_grad(): # 更新参数,不计入网络图的操作当中for param in model.parameters():param -= learning_rate*param.grad # 更新参数
# 绘制损失
import pandas as pd
loss_curve = pd.DataFrame(loss_list, columns=['loss'])
loss_curve.plot()

2. torch.optim 优化器

  • torch.optim.Adam 使用优化器
  • optimizer.zero_grad() # 清零梯度
  • optimizer.step() # 更新参数
learning_rate = 1e-4optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)loss_list = []
for t in range(500):y_pred = model(x) # 前向传播loss = loss_fn(y_pred, y) # 损失loss_list.append(loss.item())print(t, loss.item())optimizer.zero_grad() # 清零梯度loss.backward() # 反向传播,计算梯度optimizer.step() # 更新参数

3. 自定义nn模块

  • 继承 nn.module,并定义 forward 前向传播函数
import torch
class myModel(torch.nn.Module):def __init__(self, D_in, Hidden_size, D_out):super(myModel, self).__init__()self.fc1 = torch.nn.Linear(D_in, Hidden_size)self.fc2 = torch.nn.Linear(Hidden_size, D_out)def forward(self, x):x = self.fc1(x).clamp(min=0) # clamp 修剪数据在 min - max 之间,relu的作用x = self.fc2(x)return x
N, D_in, Hidden_size, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)model = myModel(D_in, Hidden_size, D_out) # 自定义模型loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)loss_val = []for t in range(500):y_pred = model(x)loss = loss_fn(y_pred, y)loss_val.append(loss.item())optimizer.zero_grad()loss.backward()optimizer.step()import pandas as pd
loss_val = pd.DataFrame(loss_val, columns=['loss'])
loss_val.plot()

4. 权重共享

  • 建立一个有3种FC层的玩具模型,中间 shareFC层会被 for 循环重复 0-3 次(随机),这几层(次数随机)的参数是共享的
import random
import torchclass shareParamsModel(torch.nn.Module):def __init__(self, D_in, Hidden_size, D_out):super(shareParamsModel, self).__init__()self.inputFC = torch.nn.Linear(D_in, Hidden_size)self.shareFC = torch.nn.Linear(Hidden_size, Hidden_size)self.outputFC = torch.nn.Linear(Hidden_size, D_out)self.sharelayers = 0 # 记录随机出了多少层def forward(self, x):x = self.inputFC(x).clamp(min=0)self.sharelayers = 0for _ in range(random.randint(0, 3)):x = self.shareFC(x).clamp(min=0)self.sharelayers += 1x = self.outputFC(x)return x
N, D_in, Hidden_size, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)model = shareParamsModel(D_in, Hidden_size, D_out)loss_fn = torch.nn.MSELoss(reduction='sum')optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)loss_val = []for t in range(500):y_pred = model(x)print('share layers: ', model.sharelayers)loss = loss_fn(y_pred, y)loss_val.append(loss.item())optimizer.zero_grad()loss.backward()optimizer.step()for p in model.parameters():print(p.size())import pandas as pd
loss_val = pd.DataFrame(loss_val, columns=['loss'])
loss_val.plot()

输出:

share layers:  1
share layers:  0
share layers:  2
share layers:  1
share layers:  2
share layers:  1
share layers:  0
share layers:  1
share layers:  0
share layers:  0
share layers:  3
share layers:  3
。。。省略

参数数量,多次运行,均为以下结果

torch.Size([100, 1000])
torch.Size([100])
torch.Size([100, 100])
torch.Size([100])
torch.Size([10, 100])
torch.Size([10])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473278.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

全面系统地总结Linux的基本操作(下)

4、 Linux命令-系统管理 4.1 查看日历:cal cal 命令用于查看当前日历,-y 显示整年日历: 4.2 显示或设置日期:date 设置时间格式(需要管理员权限): date [MMDDhhmm[[CC]YY][.ss]] format CC 为年前两位 yy 为年的后…

免费个人博客:使用hexo+github搭建详细教程

前言 使用github pages服务搭建博客的好处有: 全是静态文件,访问速度快;免费方便,不用花一分钱就可以搭建一个自由的个人博客,不需要服务器不需要后台;可以随意绑定自己的域名,不仔细看的话根…

LeetCode 1235. 规划兼职工作(动态规划+二分查找)

文章目录1. 题目2. 解题1. 题目 你打算利用空闲时间来做兼职工作赚些零花钱。 这里有 n 份兼职工作,每份工作预计从 startTime[i] 开始到 endTime[i] 结束,报酬为 profit[i]。 给你一份兼职工作表,包含开始时间 startTime,结束…

刷新页面,无论点击多少次让Element UI的Message消息提示弹出一个

一、遇到的问题 Element UI的Message消息提示是点击一次触发一次的。在开发的时候经常会作为一些校验提示,但是公司的测试人员在进行测试时会一直点,然后就会出现如下图的情况。虽然客户使用的时候一般来说不会出现这种情况(毕竟客户不会闲着…

如何让二维码自适应浏览器的尺寸

一、遇到的问题: 正常浏览网页,二维码正常显示,但是随着浏览器的扩大与缩小,二维码尺寸不会随着屏幕自适应 正常浏览(截取部分): 缩小浏览器(截取部分&#xf…

E6全部刷机包

此版本号基于R533_G_11.11.10P_GSZMCAUT679DA01B_LP064DA_T679DA_S005_E001_P002_R001_G004_1FF.sbf制作耳机接听或挂机正常内置Loader(asmotoe2)、Console(网上的大侠)、showQ(bint大侠)、SetupPKG&#x…

LeetCode 330. 按要求补齐数组(贪心)

文章目录1. 题目2. 解题1. 题目 给定一个已排序的正整数数组 nums,和一个正整数 n 。 从 [1, n] 区间内选取任意个数字补充到 nums 中,使得 [1, n] 区间内的任何数字都可以用 nums 中某几个数字的和来表示。请输出满足上述要求的最少需要补充的数字个数…

系统总结vue组件间通信、数据传递(父子组件,同级组件)

总结一下对vue组件通信的理解和使用。一、组件目录结构 父组件&#xff1a;app.vue子组件&#xff1a;page1.vue子组件&#xff1a;page2.vue 父组件 app.vue <template><div id"app"><p>请输入单价: <input type"text" v-model&qu…

LeetCode 1224. 最大相等频率(哈希)

文章目录1. 题目2. 解题1. 题目 给出一个正整数数组 nums&#xff0c;请你帮忙从该数组中找出能满足下面要求的 最长 前缀&#xff0c;并返回其长度&#xff1a; 从前缀中 删除一个 元素后&#xff0c;使得所剩下的每个数字的出现次数相同。 如果删除这个元素后没有剩余元素…

从零开始,手把手交给你vue如何新建一个项目

vue创建项目&#xff08;npm安装→初始化项目&#xff09; 第一步npm安装 首先&#xff1a;先从nodejs.org中下载nodejs 图1 双击安装&#xff0c;在安装界面一直Next 图2 图3 图4 直到Finish完成安装。 打开控制命令行程序&#xff08;CMD&#xff09;,检查是否正常 图5 …

数学图形(1.33) 棕子曲线

#http://www.mathcurve.com/courbes2d/vasques/vasques.shtml vertices 10000 t from 0 to (8*PI) a rand_int2(1, 30) b rand_int2(1, 4) n 8 x cos(n*t - t)*cos(n*t) y cos(n*t)^2 a 10 x x*a y y*a 相关软件参见:数学图形可视化工具,使用自己定义语法的脚本代码生…

LeetCode 1278. 分割回文串 III(区间DP)

文章目录1. 题目2. 解题1. 题目 给你一个由小写字母组成的字符串 s&#xff0c;和一个整数 k。 请你按下面的要求分割字符串&#xff1a; 首先&#xff0c;你可以将 s 中的部分字符修改为其他的小写英文字母。接着&#xff0c;你需要把 s 分割成 k 个非空且不相交的子串&…

LeetCode 1187. 使数组严格递增(DP)*

文章目录1. 题目2. 解题1. 题目 给你两个整数数组 arr1 和 arr2&#xff0c;返回使 arr1 严格递增所需要的最小「操作」数&#xff08;可能为 0&#xff09;。 每一步「操作」中&#xff0c;你可以分别从 arr1 和 arr2 中各选出一个索引&#xff0c;分别为 i 和 j&#xff0c…

用Python进行屏幕截图,只用两行代码搞定

一、计算机中如何进行屏幕截图呢&#xff1f; 1、全屏截图 按下键盘中的‘PRTSC’或者‘Print Screen’键&#xff0c;即可实现全屏截图&#xff08;不同键盘位置和名称可能不同&#xff09;。此时&#xff0c;并不能看到效果&#xff0c;只是将截图保存在粘贴板中&#xff0…

利用nginx建立windows软连,实现IP访问文件

一、运行nginx 1、首先下载nginx&#xff0c;下载地址&#xff1a;https://www.lanzous.com/ianm7tg 2、解压文件如图&#xff1a; 3、运行nginx.exe&#xff0c;浏览器运行电脑ip地址&#xff0c;如图&#xff1a; 二、cmd管理员权限 运行中输入“cmd”&#xff0c;按住shi…

LeetCode 1263. 推箱子(BFS+DFS / 自定义哈希set)

文章目录1. 题目2. 解题2.1 超时解2.2 BFS DFS1. 题目 「推箱子」是一款风靡全球的益智小游戏&#xff0c;玩家需要将箱子推到仓库中的目标位置。 游戏地图用大小为 n * m 的网格 grid 表示&#xff0c;其中每个元素可以是墙、地板或者是箱子。 现在你将作为玩家参与游戏&a…

深入浅出Java回调机制

前几天看了一下Spring的部分源码&#xff0c;发现回调机制被大量使用&#xff0c;觉得有必要把Java回调机制的理解归纳总结一下&#xff0c;以方便在研究类似于Spring源码这样的代码时能更加得心应手。 注&#xff1a;本文不想扯很多拗口的话来充场面&#xff0c;我的目的是希望…

前端:实现div等块元素添加X轴滚动显示(Y轴不滚动)

一、建立外盒子与内盒子 原生态代码&#xff1a; <div class"tol_dev"><div class"dev_li"></div><div class"dev_li"></div><div class"dev_li"></div><div class"dev_li"…

2020年学习总结

文章目录1. CSDN 博客数据2. 基础算法练习3. 机器学习4. 深度学习5. MySQL6. 总结和展望时间过得很快&#xff0c;2020结束了&#xff01; 写个流水账&#xff0c;记录一下。 1. CSDN 博客数据 截个图对比下&#xff1a; 2019年终2020年终 2. 基础算法练习 LeetCode 刷题 …

npm全局环境变量配置及解决VsCode使用时遇到的问题

一、npm全局环境变量配置 1、我们要先配置npm的全局模块的存放路径以及cache的路径 例如我希望将以上两个文件夹放在NodeJS的主目录下&#xff0c;便在NodeJs下建立”node_global”及”node_cache”两个文件夹。如下图 2、cmd 中输入如下命令 npm config set prefix “d:\no…