6-pytorch - 网络的保存和提取

前言

我们训练好的网络,怎么保存和提取呢?
总不可以一直不关闭电脑吧,训练到一半,想结束到明天再来训练,这就需要进行网络的保存和提取了。
本文以前面博客3-pytorch搭建一个简单的前馈全连接层网络(回归问题)的网络进行网络的保存和提取,建议先看完上面博客再来看本博客。

一、生成训练数据

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np# 生成数据(fake data)
x = torch.linspace(-1,1,100).reshape(-1,1)
# 加上点噪声
y = x.pow(2) + 0.2*torch.rand(x.shape)# 可视化一下数据
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

输出:
在这里插入图片描述

二、网络保存

def save():net1 = torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(), torch.nn.Linear(10,1))optimizer = torch.optim.SGD(net1.parameters(),lr=0.5)loss_func = torch.nn.MSELoss()for t in range(100):prediction = net1(x)loss = loss_func(prediction,y)optimizer.zero_grad()loss.backward()optimizer.step()# 下面介绍两种不同的保存方法,方法二可能运行速度要快点# 保存整个网络的所有torch.save(net1, 'net.pkl')     # 保存好网络的参数torch.save(net1.state_dict(),'net_params.pkl')# plot resultplt.figure(1,figsize=(10,3))plt.subplot(131)plt.title('Net1')plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

【注】:保存整个网络还是保存网络参数,个人建议仅保存参数,这个速度更快。

三、网络提取

def restore_net():net2 = torch.load('net.pkl')prediction = net2(x)# plot resultplt.subplot(132)plt.title('Net1')plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)def restore_params():# 如果只是保留参数的情况,提取时需要再次定义相同网络才行net3 = torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(),torch.nn.Linear(10,1))net3.load_state_dict(torch.load('net_params.pkl'))prediction = net3(x)# plot resultplt.subplot(133)plt.title('Net1')plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

四、对保存网络提取进行结果展示

save()
restore_net()
restore_params()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/822717.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开通订阅plus

提示: 您的信用卡被拒绝了,请尝试用借记卡支付。您的金融卡已被拒绝。您拒绝了,请尝试用签账卡支付。我们未能验证您的支付方式,请选择另一支付方式并重试。 我都崩溃了,一次又一次的不行,换了好多方式。…

Java switch使用

Java switch使用 涉及关键字: switch: 表达式 变量类型可以是: byte、short、int 或者 char。从 Java SE 7 开始,switch 支持字符串 String 类型, case: 分支语句,需要指定当前分支的常量或者字…

学习R语言第三天

R语句中的函数信息 1. 函数信息 x <- c(1:100) x #获取x的长度信息 length(x) # 获取第一个数据信息 x[1] # 获取4到18的数据信息 x[c(4:18)]2.存入逻辑值的方式 # y中存入逻辑值的方式 y[c(T,F,T,F)]#输出大于5的数据信息 y[y>5]#输出大于5小于9的数字 y[y<5 &…

【图文教程】在PyCharm中导入Conda环境

文章目录 &#xff08;1&#xff09;在Anaconda Prompt中新建一个conda虚拟环境&#xff08;2&#xff09;使用PyCharm打开需要搭建环境的项目&#xff08;3&#xff09;配置环境 &#xff08;1&#xff09;在Anaconda Prompt中新建一个conda虚拟环境 conda create - myenv py…

OWASP发布十大开源软件安全风险清单

OWASP发布了“十大开源软件风险”TOP10清单&#xff0c;并针对每种风险给出了安全建议。 近年来开源软件安全风险快速增长&#xff0c;不久前曝光的XZ后门更是被称为“核弹级”的开源软件供应链漏洞。虽然XZ后门事件侥幸未酿成灾难性后果&#xff0c;但为全球科技界敲响了警钟&…

Day99:云上攻防-云原生篇K8s安全实战场景攻击Pod污点Taint横向移动容器逃逸

目录 云原生-K8s安全-横向移动-污点Taint 云原生-K8s安全-Kubernetes实战场景 知识点&#xff1a; 1、云原生-K8s安全-横向移动-污点Taint 2、云原生-K8s安全-Kubernetes实战场景 云原生-K8s安全-横向移动-污点Taint 如何判断实战中能否利用污点Taint&#xff1f; 设置污点…

STM32学习和实践笔记(14):按键控制实验

消除抖动有软件和硬件两种方法 软件方法就是在首次检测到低电平时加延时&#xff0c;通常延时5-10ms&#xff0c;让抖动先过去&#xff0c;然后再来检测是否仍为低电平&#xff0c;如果仍然是&#xff0c;说明确实按下。 硬件方法就是加RC滤波电路&#xff0c;硬件方法会增加…

✌粤嵌—2024/4/3—合并K个升序链表✌

代码实现&#xff1a; /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode* merge(struct ListNode *l1, struct ListNode *l2) {if (l1 NULL) {return l2;}if (l2 NULL) {return l1;}struct Lis…

DNS服务器配置与管理(3)——综合案例

DNS服务器配置与管理 前言 在之前&#xff0c;曾详细介绍了DNS服务器原理和使用BIND部署DNS服务器&#xff0c;本文主要以一个案例为驱动&#xff0c;在网络中部署主DNS服务器、辅助DNS服务器以及子域委派的配置。 案例需求 某公司申请了域名example.com&#xff0c;公司服…

第七周学习笔记DAY.1-封装

学完本次课程后&#xff0c;你能够&#xff1a; 理解封装的作用 会使用封装 会使用Java中的包组织类 掌握访问修饰符&#xff0c;理解访问权限 没有封装的话属性访问随意&#xff0c;赋值也可能不合理&#xff0c;为了解决这些代码设计缺陷&#xff0c;可以使用封装。 面向…

vue快速入门(二十八)页面渲染完成后让输入框自动获取焦点

注释很详细&#xff0c;直接上代码 上一篇 新增内容 使用挂载完成的钩子函数用focus使输入框获取焦点 源码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"width…

leetcode:739.每日温度/496.下一个更大元素

单调栈的应用&#xff1a; 求解当前元素右边比该元素大的第一个元素&#xff08;左右、大小都可以&#xff09;。 单调栈的构成&#xff1a; 单调栈里存储数组的下标&#xff1b; 单调栈里的元素递增&#xff0c;求解当前元素右边比该元素大的第一个元素&#xff1b;元素递…

上海计算机学会2021年1月月赛C++丙组T4三倍游戏

题目描述 三倍游戏是一种单人游戏。玩家会得到 n 个整数a1​,a2​,…,an​。玩家从这些整数中挑出两个数字相加&#xff0c;如果它们的和是 3 的倍数&#xff0c;则可以将这两个整数消除&#xff0c;如此反复&#xff0c;直到不能再消除数字为止。 请问玩家最多能消除多少对数…

(十一)C++自制植物大战僵尸游戏客户端更新实现

植物大战僵尸游戏开发教程专栏地址http://t.csdnimg.cn/cFP3z 更新检查 游戏启动后会下载服务器中的版本号然后与本地版本号进行对比&#xff0c;如果本地版本号小于服务器版本号就会弹出更新提示。让用户选择是否更新客户端。 在弹出的更新对话框中有显示最新版本更新的内容…

2024-04-17 问AI: 介绍一下卷积网络GoogleNet

文心一言 GoogleNet&#xff0c;也被称为Inception-v1&#xff0c;是由Google团队在2014年提出的一种深度卷积神经网络架构&#xff0c;专门用于图像分类和特征提取任务。它在ILSVRC&#xff08;ImageNet Large Scale Visual Recognition Challenge&#xff09;比赛中取得了优…

2024五一杯数学建模A题思路分析

文章目录 1 赛题思路2 比赛日期和时间3 组织机构4 建模常见问题类型4.1 分类问题4.2 优化问题4.3 预测问题4.4 评价问题 5 建模资料 1 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 2 比赛日期和时间 报名截止时间&#xff1a;2024…

代码随想录算法训练营day41

343. 整数拆分 五部曲&#xff1a; dp数组下标及含义&#xff1a;dp[i]表示第i个位置最大乘积dp数组初始化&#xff1a;dp[2]1递推公式&#xff1a;dp[i] max({dp[i], (i - j) * j, dp[i - j] * j});遍历方向&#xff1a;从前往后遍历dp数组推到举例&#xff1a; 2 3 4 5 6 …

光学雨量计雨量传感器的工作原理与实时数据采集

光学雨量计雨量传感器的工作原理与实时数据采集 河北稳控科技光学雨量计是一种常用的雨量传感器&#xff0c;其工作原理基于光学原理和实时数据采集技术。它的主要作用是测量雨水的大小和强度&#xff0c;为气象、农业、水文等领域提供重要的数据支持。 光学雨量计的工作原理是…

单片机之ESP8266模块

目录 ESP8266简介 前言 ESP8266的工作模式 ESP8266引脚说明 ESP8266测试 步骤 单片机与esp8266交互 前言 收到数据的格式 AP模式 服务器模式 外部执行命令 代码内执行命令 代码部分 客户端模式 外部执行命令 内部执行命令 代码部分 STA模式 服务器模式 外…

使用ansible的连通性检查的关键参数

使用ansible进行ping命令的时候发现有些不通 ansible cba -m ping 10.1.1.1 | FAILED! > {"msg": "Using a SSH password instead of a key is not possible because Host Key checking is enabled and sshpass does not support this. Please add this h…