深度学习_18_模型的下载与读取

在深度学习的过程中,需要将训练好的模型运用到我们要使用的另一个程序中,这就需要模型的下载与转移操作

代码:

import math
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt# 生成随机的数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = torch.zeros(max_degree)
true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])# 生成特征
features = torch.randn((n_train + n_test, 1))
permutation_indices = torch.randperm(features.size(0))
# 使用随机排列的索引来打乱features张量(原地修改)
features = features[permutation_indices]
poly_features = torch.pow(features, torch.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)# 生成标签
labels = torch.matmul(poly_features, true_w)
labels += torch.normal(0, 0.1, size=labels.shape)# 以下是你原来的训练函数,没有修改
def evaluate_loss(net, data_iter, loss):metric = d2l.Accumulator(2)for X, y in data_iter:out = net(X)y = y.reshape(out.shape)l = loss(out, y)metric.add(l.sum(), l.numel())return metric[0] / metric[1]def l2_penalty(w):w = w[0].weightreturn torch.sum(w.pow(2)) / 2def train(train_features, test_features, train_labels, test_labels, lambd,num_epochs=100):loss = d2l.squared_lossinput_shape = train_features.shape[-1]net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 模型batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),batch_size)test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),batch_size, is_train=False)# 用于存储训练和测试损失的列表train_losses = []test_losses = []total_loss = 0total_samples = 0for epoch in range(num_epochs):for X, y in train_iter:out = net(X)y = y.reshape(-1, 1)  # 确保y是二维的l = loss(out, y) + lambd * l2_penalty(net)# 反向传播和优化器更新l.sum().backward()d2l.sgd(net.parameters(), lr=0.01, batch_size= batch_size)total_loss += l.sum().item()  # 统计所有元素损失total_samples += y.numel()  # 统计个数a = total_loss / total_samples  # 本次训练的平均损失train_losses.append(a)test_loss = evaluate_loss(net, test_iter, loss)test_losses.append(test_loss)total_loss = 0total_samples = 0print(f"Epoch {epoch + 1}/{num_epochs}:")print(f"训练损失: {a:.4f}   测试损失: {test_loss:.4f} ")print(net[0].weight)torch.save(net.state_dict(), "NetSave")  # 存模型net_try = nn.Sequential(nn.Linear(input_shape, 1, bias=False))print("net_try")print(net_try[0].weight)net_try.load_state_dict(torch.load("NetSave"))net_try.eval()  # 评估模式print("net_try_load")print(net_try[0].weight)# 绘制损失曲线plt.figure(figsize=(10, 6))plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')plt.xlabel('epoch')plt.ylabel('loss')plt.title('Loss over Epochs')plt.legend()plt.grid(True)plt.ylim(0, 1)  # 设置y轴的范围从0.01到100plt.show()# 选择多项式特征中的前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:], 0)##  net.parameters() 是一个 PyTorch 模型的方法,用于返回模型所有参数的迭代器。这个迭代器产生模型中所有可学习的参数(例如权重和偏置)。

上述代码的模型是简单线性模型

net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 模型

此模型的下载与储存如下

    torch.save(net.state_dict(), "NetSave")  # 存模型net_try = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 搭建模型框架print("net_try")print(net_try[0].weight)net_try.load_state_dict(torch.load("NetSave"))  # 下载模型net_try.eval()  # 评估模式print("net_try_load")print(net_try[0].weight)

效果
在这里插入图片描述

所以说要想在另一个程序中将训练好的模型加载到上面去,首先要保存训练好的模型,另一个程序必须有和本模型一样的框架,再将训练好的模型权重加载到另一个程序框架内即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/724466.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NFTScan :什么是 ERC-404?深入解读 NFT 协议的未来

上月初,ERC-404 成为最首要热门的话题,ERC-404 是由 Pandora 团队在 2 月初为创作者和开发者等开源的实验性代币标准,其混合 ERC-20 / ERC-721 实现,具有原生流动性和碎片化等特点。伴随着早期的发展,越来越多参与者开…

win10安全中心误删文件怎么办?解析恢复与预防策略

在使用Windows 10的过程中,许多用户依赖于其内置的安全中心来保护电脑免受恶意软件的侵害。然而,有时安全中心的误判可能导致重要文件被错误地删除。当面对这种情况时,了解如何恢复误删的文件并掌握预防措施显得尤为重要。本文将为您详细解析…

java常用技术栈,java面试带答案

前言 我们从一个问题引入今天的主题。 在日常业务开发中,我们可能经常听到 DBA 对我们说“不要”(注意:不是禁止)使用 join,那么为什么 DBA 对 join 这么抵触呢?是 join 本身有问题,还是我们使…

私募证券基金动态-24年2月报

成交量:2月日均9492.60亿元 2024年2月A股两市日均成交9492.60亿元,环比增加30.38%、同比增加5.77%。2月整体15个交易日,有4个单日交易日成交金额过万亿,单日交易日最高成交金额为13576.43亿元(2月28日)&am…

MySQL 学习笔记(基础篇 Day1)

「写在前面」 本文为黑马程序员 MySQL 教程的学习笔记。本着自己学习、分享他人的态度,分享学习笔记,希望能对大家有所帮助。 目录 0 课程介绍 1 MySQL 概述 1.1 数据库相关概念 1.2 MySQL 数据库 2 SQL 2.1 SQL 通用语法 2.2 SQL 分类 2.3 DDL 2.4 图形…

【leetcode C++】电话号码的字母组合

17. 电话号码的字母组合 题目 给定一个仅包含数字 2-9 的字符串,返回所有它能表示的字母组合。答案可以按 任意顺序 返回。 给出数字到字母的映射如下(与电话按键相同)。注意 1 不对应任何字母。 题目链接 . - 力扣(LeetCode&…

1.类和对象-友元

文章目录 1.全局函数做友元代码运行结果 2.类做友元代码运行结果 1.全局函数做友元 思路分析: 正常情况下,全局函数visit()中的ROOM 类变量r是访问不到Building类中的私有成员的。但是通过在Building类中添加使用全局函数做友元,即可访问私有…

什么是ElasticSearch的深度分页问题?如何解决?

在ElasticSearch中进行分页查询通常使用from和size参数。当我们对ElasticSearch发起一个带有分页参数的查询(如使用from和size参数)时,ElasticSearch需要遍历所以匹配的文档直到达到指定的起始点(from),然后返回从这一点开始的size个文档 在这个例子中: 1.from 参数定义…

代码学习记录13

随想录日记part13 t i m e : time: time: 2024.03.06 主要内容:今天的主要内容是二叉树的第二部分哦,主要有层序遍历;翻转二叉树;对称二叉树。 102.二叉树的层序遍历226.翻转二叉树101. 对称二叉…

LeetCode-第67题-二进制求和

1.题目描述 给你两个二进制字符串 a 和 b ,以二进制字符串的形式返回它们的和。 2.样例描述 3.思路描述 将两个二进制字符串转换成整型,然后相加后的整型转为二进制字符串 4.代码展示 class Solution(object):def addBinary(self, a, b):# 将字符串…

AI新工具(20240306) mlx-swift-chat Mac运行本地模型;Comflowyspace开源AI图像和视频生成工具

1: mlx-swift-chat 专为 Apple 硅片设计的高效机器学习框架,支持在本地实时运行 LLM 模型(如 Llama、Mistral) mlx-swift-chat 是一个为苹果系统(例如你的笔记本电脑上的Apple Silicon)特别设计的机器学习框架 MLX 的…

计划任务和日志

一、计划任务 计划任务概念解析 在Linux操作系统中,除了用户即时执行的命令操作以外,还可以配置在指定的时间、指定的日期执行预先计划好的系统管理任务(如定期备份、定期采集监测数据)。RHEL6系统中默认已安装了at、crontab软件…

扫码看图的预览效果怎么做?图片的二维码如何在线生成?

图片二维码是现在很常用的一种预览图片的方式,比如照片、海报、动态图、拍摄的图片等类型的内容都可以用二维码的方式在手机上预览。在制作图片二维码时候,现在大多会通过网上的图片二维码生成器来制作,直接用专业的功能,就可以快…

SoraAI优先体验资格注册教程

SoraA1视频工具优先体验资格申请 申请网址:https://openai.com/form/red-teaming-network 申请步骤: 填写基础信息 请使用英文根据内容填写以下内容,名、姓、电子邮件、居住国家、组织隶属关系(如果有)、教育水平 、学位(哪个领…

视频推拉流EasyDSS平台直播通道重连无法转推的原因排查与解决

视频推拉流EasyDSS视频直播点播平台,集视频直播、点播、转码、管理、录像、检索、时移回看等功能于一体,可提供音视频采集、视频推拉流、播放H.265编码视频、存储、分发等视频能力服务。 用户使用EasyDSS平台对直播通道进行转推,发现只要关闭…

大势智慧黄先锋:现实世界数字重建 拥抱AI 擘画自主可控的三维画卷

来源:中国地理信息产业协会 实景三维涉及到大面积、高精度的地理空间信息数据,然而早期国内99%以上的实景三维数据制作测绘单位都基于国外软件进行三维重建,如此重要的工作大量使用国外软件,如何确保国家地理空间信息的安全&#…

【MySQL】事务?隔离级别?锁?详解MySQL并发控制机制

目录 1.先理清一下概念 2.锁 2.1.分类 2.2.表锁 2.3.行锁(MVCC) 2.4.间隙锁 2.5.行锁变表锁 2.6.强制锁行 1.先理清一下概念 所谓并发控制指的是在对数据库进行并发操作时如何保证数据的一致性和正确性。在数据库中与并发控制相关的概念有如下几…

android开发基础有哪些,985研究生入职电网6个月

不好意思久等了 这篇文章让小伙伴们久等了。 一年多以来,关于嵌入式开发学习路线、规划、看什么书等问题,被问得没有一百,也有大几十次了。但是无奈自己对这方面了解有限,所以每次都没法交代,搞得实在不好意思。 但…

SanctuaryAI推出Phoenix: 专为工作而设计的人形通用机器人

文章目录 1. Company2. Main2.1 关于凤凰™ (Phoenix)2.2 关于碳™(Carbon)2.3 商业化部署2.4 关于 Sanctuary Corporation 3. My thoughtsReference彩蛋:将手机变为桌面小机器人 唯一入选《时代》杂志 2023 年最佳发明的通用机器人。 称机器人自主做家务的速度和灵…

7.使用os.Args或flag解析命令行参数

文章目录 一、os.Args二、flag包基本使用 Go语言内置的flag包实现了命令行参数的解析,flag包使得开发命令行工具更为简单。 一、os.Args 如果你只是简单的想要获取命令行参数,可以像下面的代码示例一样使用os.Args来获取命令行参数。 package mainimp…