深度学习_18_模型的下载与读取

在深度学习的过程中,需要将训练好的模型运用到我们要使用的另一个程序中,这就需要模型的下载与转移操作

代码:

import math
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt# 生成随机的数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = torch.zeros(max_degree)
true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])# 生成特征
features = torch.randn((n_train + n_test, 1))
permutation_indices = torch.randperm(features.size(0))
# 使用随机排列的索引来打乱features张量(原地修改)
features = features[permutation_indices]
poly_features = torch.pow(features, torch.arange(max_degree).reshape(1, -1))
for i in range(max_degree):poly_features[:, i] /= math.gamma(i + 1)# 生成标签
labels = torch.matmul(poly_features, true_w)
labels += torch.normal(0, 0.1, size=labels.shape)# 以下是你原来的训练函数,没有修改
def evaluate_loss(net, data_iter, loss):metric = d2l.Accumulator(2)for X, y in data_iter:out = net(X)y = y.reshape(out.shape)l = loss(out, y)metric.add(l.sum(), l.numel())return metric[0] / metric[1]def l2_penalty(w):w = w[0].weightreturn torch.sum(w.pow(2)) / 2def train(train_features, test_features, train_labels, test_labels, lambd,num_epochs=100):loss = d2l.squared_lossinput_shape = train_features.shape[-1]net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 模型batch_size = min(10, train_labels.shape[0])train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),batch_size)test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),batch_size, is_train=False)# 用于存储训练和测试损失的列表train_losses = []test_losses = []total_loss = 0total_samples = 0for epoch in range(num_epochs):for X, y in train_iter:out = net(X)y = y.reshape(-1, 1)  # 确保y是二维的l = loss(out, y) + lambd * l2_penalty(net)# 反向传播和优化器更新l.sum().backward()d2l.sgd(net.parameters(), lr=0.01, batch_size= batch_size)total_loss += l.sum().item()  # 统计所有元素损失total_samples += y.numel()  # 统计个数a = total_loss / total_samples  # 本次训练的平均损失train_losses.append(a)test_loss = evaluate_loss(net, test_iter, loss)test_losses.append(test_loss)total_loss = 0total_samples = 0print(f"Epoch {epoch + 1}/{num_epochs}:")print(f"训练损失: {a:.4f}   测试损失: {test_loss:.4f} ")print(net[0].weight)torch.save(net.state_dict(), "NetSave")  # 存模型net_try = nn.Sequential(nn.Linear(input_shape, 1, bias=False))print("net_try")print(net_try[0].weight)net_try.load_state_dict(torch.load("NetSave"))net_try.eval()  # 评估模式print("net_try_load")print(net_try[0].weight)# 绘制损失曲线plt.figure(figsize=(10, 6))plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')plt.xlabel('epoch')plt.ylabel('loss')plt.title('Loss over Epochs')plt.legend()plt.grid(True)plt.ylim(0, 1)  # 设置y轴的范围从0.01到100plt.show()# 选择多项式特征中的前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],labels[:n_train], labels[n_train:], 0)##  net.parameters() 是一个 PyTorch 模型的方法,用于返回模型所有参数的迭代器。这个迭代器产生模型中所有可学习的参数(例如权重和偏置)。

上述代码的模型是简单线性模型

net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 模型

此模型的下载与储存如下

    torch.save(net.state_dict(), "NetSave")  # 存模型net_try = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 搭建模型框架print("net_try")print(net_try[0].weight)net_try.load_state_dict(torch.load("NetSave"))  # 下载模型net_try.eval()  # 评估模式print("net_try_load")print(net_try[0].weight)

效果
在这里插入图片描述

所以说要想在另一个程序中将训练好的模型加载到上面去,首先要保存训练好的模型,另一个程序必须有和本模型一样的框架,再将训练好的模型权重加载到另一个程序框架内即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/724466.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NFTScan :什么是 ERC-404?深入解读 NFT 协议的未来

上月初,ERC-404 成为最首要热门的话题,ERC-404 是由 Pandora 团队在 2 月初为创作者和开发者等开源的实验性代币标准,其混合 ERC-20 / ERC-721 实现,具有原生流动性和碎片化等特点。伴随着早期的发展,越来越多参与者开…

Java学习笔记006——子类与父类的类型转换

在Java中,类型转换主要涉及到两种类型:向上类型转换(Upcasting)和向下类型转换(Downcasting)。 1. 向上类型转换(Upcasting): 向上类型转换是将子类的对象转换为父类类…

leetcode 227. 基本计算器 II

题目描述 给你一个字符串表达式 s ,请你实现一个基本计算器来计算并返回它的值。 整数除法仅保留整数部分。 你可以假设给定的表达式总是有效的。所有中间结果将在 [-231, 231 - 1] 的范围内。 注意 不允许使用任何将字符串作为数学表达式计算的内置函数&#x…

win10安全中心误删文件怎么办?解析恢复与预防策略

在使用Windows 10的过程中,许多用户依赖于其内置的安全中心来保护电脑免受恶意软件的侵害。然而,有时安全中心的误判可能导致重要文件被错误地删除。当面对这种情况时,了解如何恢复误删的文件并掌握预防措施显得尤为重要。本文将为您详细解析…

java常用技术栈,java面试带答案

前言 我们从一个问题引入今天的主题。 在日常业务开发中,我们可能经常听到 DBA 对我们说“不要”(注意:不是禁止)使用 join,那么为什么 DBA 对 join 这么抵触呢?是 join 本身有问题,还是我们使…

【weblogic 报错 application webapp does not have any Components in it.】

当你启动你的web时,报错weblogic 报错 application webapp does not have any Components in it. 检查你的startWeblogic.sh 看一下你的项目路径是否正确。

G1018选择排序

题目描述 完善程序&#xff1a; 输入N个整数&#xff0c;使用选择排序法从小到大输出。 #include<bits/stdc.h> using namespace std; int N; int a[100010]; int main() {freopen("1455.in","r",stdin);freopen("1455.out","w&quo…

私募证券基金动态-24年2月报

成交量&#xff1a;2月日均9492.60亿元 2024年2月A股两市日均成交9492.60亿元&#xff0c;环比增加30.38%、同比增加5.77%。2月整体15个交易日&#xff0c;有4个单日交易日成交金额过万亿&#xff0c;单日交易日最高成交金额为13576.43亿元&#xff08;2月28日&#xff09;&am…

MySQL 学习笔记(基础篇 Day1)

「写在前面」 本文为黑马程序员 MySQL 教程的学习笔记。本着自己学习、分享他人的态度&#xff0c;分享学习笔记&#xff0c;希望能对大家有所帮助。 目录 0 课程介绍 1 MySQL 概述 1.1 数据库相关概念 1.2 MySQL 数据库 2 SQL 2.1 SQL 通用语法 2.2 SQL 分类 2.3 DDL 2.4 图形…

【leetcode C++】电话号码的字母组合

17. 电话号码的字母组合 题目 给定一个仅包含数字 2-9 的字符串&#xff0c;返回所有它能表示的字母组合。答案可以按 任意顺序 返回。 给出数字到字母的映射如下&#xff08;与电话按键相同&#xff09;。注意 1 不对应任何字母。 题目链接 . - 力扣&#xff08;LeetCode&…

1.类和对象-友元

文章目录 1.全局函数做友元代码运行结果 2.类做友元代码运行结果 1.全局函数做友元 思路分析&#xff1a; 正常情况下&#xff0c;全局函数visit()中的ROOM 类变量r是访问不到Building类中的私有成员的。但是通过在Building类中添加使用全局函数做友元&#xff0c;即可访问私有…

什么是ElasticSearch的深度分页问题?如何解决?

在ElasticSearch中进行分页查询通常使用from和size参数。当我们对ElasticSearch发起一个带有分页参数的查询(如使用from和size参数)时,ElasticSearch需要遍历所以匹配的文档直到达到指定的起始点(from),然后返回从这一点开始的size个文档 在这个例子中: 1.from 参数定义…

ThreadLocal通俗解读,举个例子?

ThreadLocal 解读 ThreadLocal 是 Java 中一个用来创建线程局部变量的类。它为每个使用该变量的线程提供独立的变量副本 线程局部变量意味着对于同一个 ThreadLocal 实例&#xff0c;在不同的线程中访问到的值是不同的&#xff0c;每个线程都有自己的变量副本。这样可以在多线…

代码学习记录13

随想录日记part13 t i m e &#xff1a; time&#xff1a; time&#xff1a; 2024.03.06 主要内容&#xff1a;今天的主要内容是二叉树的第二部分哦&#xff0c;主要有层序遍历&#xff1b;翻转二叉树&#xff1b;对称二叉树。 102.二叉树的层序遍历226.翻转二叉树101. 对称二叉…

MySQL用户创建和权限分配

MySQL用户创建和权限分配 用户创建 查看用户 select user,host from user; 创建用户 create user ‘用户名’ ‘host’ identified by 密码’; 删除用户 drop user ‘用户名’ ‘host’; 更新host update user set host ‘%’ where user 用户名‘&#xff1b; 权限分配 查…

逢7过,从任意一个数字开始报数,当你要报的数字包含7或者是7的倍数时都要说:过(1~100之间满足逢7必过规则的数据)

分析&#xff1a;这题就是碰到 7是个位&#xff0c;7是十位&#xff0c;7的倍数 就要过 // 1 2 3 4 5 6 过 8 9 10 11 12 13 过 14 15 16 过 18 19 20 过。。 69 过 过 过 过 过 。。80.。。 判断每个数字&#xff0c;如果符合条件&#xff0c;就打印过&#xff0c;如果不符…

2024中国重庆沐浴博览会5月29日-31日

2024中国沐浴展——世界级温泉胜地&#xff0c;引领健康产业新风向 2024第五届中国沐浴健康产业&#xff08;重庆&#xff09;博览会暨第十一届中国沐浴温泉文化节 ——世界级温泉胜地&#xff0c;引领健康产业新风向 随着人们生活水平的提高和健康意识的增强&#xff0c;沐…

LeetCode-第67题-二进制求和

1.题目描述 给你两个二进制字符串 a 和 b &#xff0c;以二进制字符串的形式返回它们的和。 2.样例描述 3.思路描述 将两个二进制字符串转换成整型&#xff0c;然后相加后的整型转为二进制字符串 4.代码展示 class Solution(object):def addBinary(self, a, b):# 将字符串…

AI新工具(20240306) mlx-swift-chat Mac运行本地模型;Comflowyspace开源AI图像和视频生成工具

1: mlx-swift-chat 专为 Apple 硅片设计的高效机器学习框架&#xff0c;支持在本地实时运行 LLM 模型&#xff08;如 Llama、Mistral&#xff09; mlx-swift-chat 是一个为苹果系统&#xff08;例如你的笔记本电脑上的Apple Silicon&#xff09;特别设计的机器学习框架 MLX 的…

计划任务和日志

一、计划任务 计划任务概念解析 在Linux操作系统中&#xff0c;除了用户即时执行的命令操作以外&#xff0c;还可以配置在指定的时间、指定的日期执行预先计划好的系统管理任务&#xff08;如定期备份、定期采集监测数据&#xff09;。RHEL6系统中默认已安装了at、crontab软件…