6-pytorch - 网络的保存和提取

前言

我们训练好的网络,怎么保存和提取呢?
总不可以一直不关闭电脑吧,训练到一半,想结束到明天再来训练,这就需要进行网络的保存和提取了。
本文以前面博客3-pytorch搭建一个简单的前馈全连接层网络(回归问题)的网络进行网络的保存和提取,建议先看完上面博客再来看本博客。

一、生成训练数据

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np# 生成数据(fake data)
x = torch.linspace(-1,1,100).reshape(-1,1)
# 加上点噪声
y = x.pow(2) + 0.2*torch.rand(x.shape)# 可视化一下数据
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

输出:
在这里插入图片描述

二、网络保存

def save():net1 = torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(), torch.nn.Linear(10,1))optimizer = torch.optim.SGD(net1.parameters(),lr=0.5)loss_func = torch.nn.MSELoss()for t in range(100):prediction = net1(x)loss = loss_func(prediction,y)optimizer.zero_grad()loss.backward()optimizer.step()# 下面介绍两种不同的保存方法,方法二可能运行速度要快点# 保存整个网络的所有torch.save(net1, 'net.pkl')     # 保存好网络的参数torch.save(net1.state_dict(),'net_params.pkl')# plot resultplt.figure(1,figsize=(10,3))plt.subplot(131)plt.title('Net1')plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

【注】:保存整个网络还是保存网络参数,个人建议仅保存参数,这个速度更快。

三、网络提取

def restore_net():net2 = torch.load('net.pkl')prediction = net2(x)# plot resultplt.subplot(132)plt.title('Net1')plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)def restore_params():# 如果只是保留参数的情况,提取时需要再次定义相同网络才行net3 = torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(),torch.nn.Linear(10,1))net3.load_state_dict(torch.load('net_params.pkl'))prediction = net3(x)# plot resultplt.subplot(133)plt.title('Net1')plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

四、对保存网络提取进行结果展示

save()
restore_net()
restore_params()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/822717.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开通订阅plus

提示: 您的信用卡被拒绝了,请尝试用借记卡支付。您的金融卡已被拒绝。您拒绝了,请尝试用签账卡支付。我们未能验证您的支付方式,请选择另一支付方式并重试。 我都崩溃了,一次又一次的不行,换了好多方式。…

Java switch使用

Java switch使用 涉及关键字: switch: 表达式 变量类型可以是: byte、short、int 或者 char。从 Java SE 7 开始,switch 支持字符串 String 类型, case: 分支语句,需要指定当前分支的常量或者字…

【图文教程】在PyCharm中导入Conda环境

文章目录 (1)在Anaconda Prompt中新建一个conda虚拟环境(2)使用PyCharm打开需要搭建环境的项目(3)配置环境 (1)在Anaconda Prompt中新建一个conda虚拟环境 conda create - myenv py…

Day99:云上攻防-云原生篇K8s安全实战场景攻击Pod污点Taint横向移动容器逃逸

目录 云原生-K8s安全-横向移动-污点Taint 云原生-K8s安全-Kubernetes实战场景 知识点: 1、云原生-K8s安全-横向移动-污点Taint 2、云原生-K8s安全-Kubernetes实战场景 云原生-K8s安全-横向移动-污点Taint 如何判断实战中能否利用污点Taint? 设置污点…

STM32学习和实践笔记(14):按键控制实验

消除抖动有软件和硬件两种方法 软件方法就是在首次检测到低电平时加延时,通常延时5-10ms,让抖动先过去,然后再来检测是否仍为低电平,如果仍然是,说明确实按下。 硬件方法就是加RC滤波电路,硬件方法会增加…

✌粤嵌—2024/4/3—合并K个升序链表✌

代码实现: /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode* merge(struct ListNode *l1, struct ListNode *l2) {if (l1 NULL) {return l2;}if (l2 NULL) {return l1;}struct Lis…

DNS服务器配置与管理(3)——综合案例

DNS服务器配置与管理 前言 在之前,曾详细介绍了DNS服务器原理和使用BIND部署DNS服务器,本文主要以一个案例为驱动,在网络中部署主DNS服务器、辅助DNS服务器以及子域委派的配置。 案例需求 某公司申请了域名example.com,公司服…

第七周学习笔记DAY.1-封装

学完本次课程后,你能够: 理解封装的作用 会使用封装 会使用Java中的包组织类 掌握访问修饰符,理解访问权限 没有封装的话属性访问随意,赋值也可能不合理,为了解决这些代码设计缺陷,可以使用封装。 面向…

vue快速入门(二十八)页面渲染完成后让输入框自动获取焦点

注释很详细&#xff0c;直接上代码 上一篇 新增内容 使用挂载完成的钩子函数用focus使输入框获取焦点 源码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"width…

leetcode:739.每日温度/496.下一个更大元素

单调栈的应用&#xff1a; 求解当前元素右边比该元素大的第一个元素&#xff08;左右、大小都可以&#xff09;。 单调栈的构成&#xff1a; 单调栈里存储数组的下标&#xff1b; 单调栈里的元素递增&#xff0c;求解当前元素右边比该元素大的第一个元素&#xff1b;元素递…

(十一)C++自制植物大战僵尸游戏客户端更新实现

植物大战僵尸游戏开发教程专栏地址http://t.csdnimg.cn/cFP3z 更新检查 游戏启动后会下载服务器中的版本号然后与本地版本号进行对比&#xff0c;如果本地版本号小于服务器版本号就会弹出更新提示。让用户选择是否更新客户端。 在弹出的更新对话框中有显示最新版本更新的内容…

2024五一杯数学建模A题思路分析

文章目录 1 赛题思路2 比赛日期和时间3 组织机构4 建模常见问题类型4.1 分类问题4.2 优化问题4.3 预测问题4.4 评价问题 5 建模资料 1 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 2 比赛日期和时间 报名截止时间&#xff1a;2024…

光学雨量计雨量传感器的工作原理与实时数据采集

光学雨量计雨量传感器的工作原理与实时数据采集 河北稳控科技光学雨量计是一种常用的雨量传感器&#xff0c;其工作原理基于光学原理和实时数据采集技术。它的主要作用是测量雨水的大小和强度&#xff0c;为气象、农业、水文等领域提供重要的数据支持。 光学雨量计的工作原理是…

单片机之ESP8266模块

目录 ESP8266简介 前言 ESP8266的工作模式 ESP8266引脚说明 ESP8266测试 步骤 单片机与esp8266交互 前言 收到数据的格式 AP模式 服务器模式 外部执行命令 代码内执行命令 代码部分 客户端模式 外部执行命令 内部执行命令 代码部分 STA模式 服务器模式 外…

10个常用的损失函数及Python代码实现

本文深入理解并详细介绍了10个常用的损失函数及Python代码实现。 什么是损失函数&#xff1f; 损失函数是一种衡量模型与数据吻合程度的算法。损失函数测量实际测量值和预测值之间差距的一种方式。损失函数的值越高预测就越错误&#xff0c;损失函数值越低则预测越接近真实值…

LUCF-Net:轻量级U形级联 用于医学图像分割的融合网络

LUCF-Net&#xff1a;轻量级U形级联 用于医学图像分割的融合网络 摘要IntroductionRelated WorkProposed MethodLocal-Global Feature ExtractionEncoder and DecoderFeature FusionLoss Function LUCF-Net: Lightweight U-shaped Cascade Fusion Network for Medical Image Se…

Android zxing库实现扫码识别

第一步 加库zxing库 //导入二维码识别库ZXingimplementation("com.journeyapps:zxing-android-embedded:4.2.0") 第二部获取摄像机权限 <uses-permission android:name="android.permission.CAMERA" /><uses-permission android:name="andro…

【日常记录】【CSS】利用动画延迟实现复杂动画

文章目录 1、介绍2、原理3、代码4、参考链接 1、介绍 对于这个效果而言&#xff0c;最先想到的就是 监听滑块的input事件来做一些操作 ,但是会发现&#xff0c;对于某一个节点的时候&#xff0c;这个样式操作起来比较麻烦 只看这个代码的话&#xff0c;发现他用的是动画&#x…

超详细!Python中 pip 常用命令

相信对于大多数熟悉Python的人来说&#xff0c;一定都听说并且使用过pip这个工具&#xff0c;但是对它的了解可能还不一定是非常的透彻&#xff0c;今天小编就来为大家介绍10个使用pip的小技巧&#xff0c;相信对大家以后管理和使用Python当中的标准库会有帮助。 安装 当然在…

【算法一则】编辑距离 【动态规划】

题目 给你两个单词 word1 和 word2&#xff0c; 请返回将 word1 转换成 word2 所使用的最少操作数 。 你可以对一个单词进行如下三种操作&#xff1a; 插入一个字符 删除一个字符 替换一个字符 示例 1&#xff1a;输入&#xff1a;word1 "horse", word2 "…