【PyTorch】(三)模型的创建、参数初始化、保存和加载

文章目录

  • 1. 模型的创建
    • 1.1. 模型组件
      • 1.1.1. 网络层
      • 1.1.2. 激活函数
      • 1.1.3. 函数包
      • 1.1.4. 容器
    • 1.2. 创建方法
      • 1.1.1. 通过使用模型组件
      • 1.1.2. 通过继承nn.Module类
    • 1.3. 将模型转移到GPU
  • 2. 模型参数初始化
  • 3. 模型的保存与加载
    • 3.1. 只保存参数
    • 3.2. 保存模型和参数

1. 模型的创建

1.1. 模型组件

1.1.1. 网络层

1.1.2. 激活函数

1.1.3. 函数包

1.1.4. 容器

1.2. 创建方法

1.1.1. 通过使用模型组件

可以直接使用模型组件快速创建模型。

import torch.nn as nnmodel =	nn.Linear(10, 10)
print(model)

输出结果:

Linear(in_features=10, out_features=10, bias=True)

1.1.2. 通过继承nn.Module类

在__init__方法中使用模型组件定义模型各层。必须重写forward方法实现前向传播。

import torch.nn as nnclass Model(nn.Module):def __init__(self):super().__init__()self.layer1 = nn.Linear(10, 10)self.layer2 = nn.Linear(10, 10)self.layer3 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return xmodel = Model()
print(model)

输出结果:

Model((layer1): Linear(in_features=10, out_features=10, bias=True)(layer2): Linear(in_features=10, out_features=10, bias=True)(layer3): Sequential((0): Linear(in_features=10, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=10, bias=True))
)

1.3. 将模型转移到GPU

方法与将数据转移到GPU类似,都有两种方法:

  1. model.to(device)
  2. mode.cuda()
import torch
import torch.nn as nn# 创建模型实例
model = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 将模型移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# 也可以
model = model.cuda()

2. 模型参数初始化

torch.nn.init提供了许多初始化参数的函数:

函数名作用参数
uniform_从均匀分布 U ( a , b ) U(a,b) U(a,b)中生成值,填充输入的张量tensor, a = 0, b = 1
normal_从正态分布 N ( m e a n , s t d 2 ) N(mean, std^2) N(mean,std2)中生成值,填充输入的张量tensor, mean = 0, std = 1
constant_用常数 v a l val val,填充输入的张量tensor, val
eye_用单位矩阵,填充二维输入张量tensor(二维)
dirac_用狄拉克函数,填充{3, 4, 5}维输入张量tensor({3, 4, 5}维), groups = 1
xavier_uniform_从xavier均匀分布中生成值,填充输入张量tensor, gain = 1
xavier_normal_从xavier正态分布中生成值,填充输入张量tensor, gain = 1
kaiming_uniform_从kaiming均匀分布中生成值,填充输入张量tensor, a = 0, mode = ‘fan_in’, nonlinearity = ‘leaky_relu’
kaiming_normal_从kaiming正态分布中生成值,填充输入张量tensor, a = 0, mode = ‘fan_in’, nonlinearity = ‘leaky_relu’
orthogonal_用一个(半)正交矩阵,填充输入张量tensor, gain = 1
sparse_用非零元素服从 N ( 0 , s t d 2 ) N(0, std^2) N(0,std2)的稀疏矩阵,填充二维输入张量tensor, sparsity, std = 0.01

3. 模型的保存与加载

模型保存和加载使用的python内置的pickle模块。

3.1. 只保存参数

import torch
import torch.nn as nn# 创建模型实例
model1 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 保存和加载参数
torch.save(model1.state_dict(), '../model/model_params.pkl')
model1.load_state_dict(torch.load('../model/model_params.pkl'))

3.2. 保存模型和参数

import torch
import torch.nn as nn# 创建模型实例
model1 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 保存和加载模型和参数
torch.save(model1, '../model/model.pt')
model2 = torch.load('../model/model.pt')
print(model2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/186762.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ruby和HTTParty库下载代码示例

ruby require httparty require nokogiri # 设置服务器 proxy_host "" proxy_port "" # 定义URL url "" # 创建HTTParty对象,并设置服务器 httparty HTTParty.new( :proxy > "#{proxy_host}:#{proxy_port}" ) …

MySQL之binlog日志

聊聊BINLOG binlog记录什么? MySQL server中所有的搜索引擎发生了更新(DDL和DML)都会产生binlog日志,记录的是语句的原始逻辑 为什么需要binlog? binlog主要有两个应用场景,一是数据复制,在…

训练自己的个性化Stable diffusion模型,LORA

一、背景 需要训练自己的LORA模型 二、分析 1、有sd-webui有训练插件功能 2、有单独的LORA训练开源web界面 两个开源训练界面 1、秋叶写的SD-Trainer https://github.com/Akegarasu/lora-scripts/ 没成功,主要也是cudnn和nvidia-smi中的CUDA版本不一致退出 2…

Netty Review - 探索Channel和Pipeline的内部机制

文章目录 概念Channel Pipeline实现原理分析详解 Inbound事件和Outbound事件演示Code 概念 Netty中的Channel和Pipeline是其核心概念,它们在构建高性能网络应用程序时起着重要作用。 Channel: 在Netty中,Channel表示一个开放的连接&#xff…

由于找不到msvcp120.dll的解决方法,msvcp120.dll修复指南

当你尝试运行某些程序或游戏时,可能会遇到系统弹出的错误消息,提示"找不到msvcp120.dll"或"msvcp120.dll丢失"。这种情况通常会妨碍程序的正常启动。为了帮助解决这一问题,本文将深入讨论msvcp120.dll是什么,…

C语言中的预处理指令

预处理指令是在编译之前由预处理器处理的命令。这些指令不是C语言的一部分,而是指导预处理器如何准备代码进行编译。预处理指令以井号(#)开头,主要可以分为以下几组: 一、 宏定义指令 #define: 定义宏。 #undef: 取消已定义的宏。宏可以定义常量,如 #define PI 3.14159。…

YOLOv8优化策略:检测头结构全新创新篇 | RT-DETR检测头助力,即插即用

🚀🚀🚀本文改进:RT-DETR检测头助力YOLOv8检测,保持v8轻量级的同时提升检测精度 🚀🚀🚀YOLOv8改进专栏:http://t.csdnimg.cn/hGhVK 学姐带你学习YOLOv8,从入门到创新,轻轻松松搞定科研; 1.RT-DETR介绍 论文: https://arxiv.org/pdf/2304.08069.pdf 摘要:…

再探Java集合系列—LinkedHashMap

LinkedHashMap 继承了 HashMap 所以LinkedHashMap也是一种k-v的键值对,并且内部是双链表的形式维护了插入的顺序 LinkedHashMap如何保证顺序插入的? 在HashMap中时候说到过HashMap插入无序的 LinkedHashMap使用了双向链表,内部的node节点包含…

【Linux】 服务器优化之定时任务:自动清理日志,重启服务

文章目录 ⭐️背景🏆处理流程查看进程清理日志文件重启服务 💖问题总结👍完整处理方案清理日志脚本自动重启服务计划任务定时清理日志文件定时重启服务 开机启动定时任务 ⭐️背景 部署在客户服务器项目无法访问,最后发现服务器上…

Docker 的基本概念和常用命令,应用程序开发中的实际应用。

Docker 是一种开源的容器化平台,能够帮助开发人员更加轻松地打包、部署和运行应用程序。以下是 Docker 的基本概念和优势: 基本概念: 镜像(image):类似于虚拟机镜像,包含了应用程序运行所需的所…

CityEngine2023安装与快速入门

目录 0 引言1 安装2 CityEngine官方示例2.1 官方地址2.2 导入示例工程 3 结尾 🙋‍♂️ 作者:海码007📜 专栏:CityEngine专栏💥 标题:CityEngine2023安装与快速入门❣️ 寄语:书到用时方恨少&am…

Linux基础命令之网络配置管理常用命令

在Linux中,有许多命令可以用于网络管理。以下是一些常用的Linux网络管理相关的命令 # 1、ifconfig 这是一个常用的网络配置工具,可以用来查看和配置网络接口。这个命令在大多数Linux发行版中都可以使用,包括Ubuntu、Debian、CentOS、Fedora…

解读拼多多Q3财报:Temu崭露头角,跨境故事刚刚开场

11月28日,拼多多发布了2023年第三季度的业绩报告,季度营收688.4亿元,较去年同期大涨94%,比市场预期高出100多亿元。 截止到11月28日美股收盘,拼多多股价上涨18.8%,总市值达到1834.23亿美元。11月29日美股开…

P1025 [NOIP2001 提高组] 数的划分

暴搜 剪枝 枚举固定的位置 #include<bits/stdc.h> using namespace std; using ll long long; const int N 1e310; int n,k; int res; void dfs(int last,int sum,int cur){if(curk){if(sumn)res;return;}for(int ilast;isum<n;i)dfs(i,sumi,cur1); } int main() {c…

倒计时(JS计时器)

<script>function countDown() {document.body.innerHTML ;//清空页面内容var nowTimer new Date(); //现在时间的毫秒数var valueTimer new Date("2024-1-1 12:00"); //用户输入年份倒计时时间毫秒数var timer (valueTimer - nowTimer) / 1000; //倒计时秒…

有什么值得推荐的node. js练手项目吗?

前言 可以参考一下下面的nodejs相关的项目&#xff0c;希望对你的学习有所帮助&#xff0c;废话少说&#xff0c;让我们直接进入正题 1、 NodeBB Star: 13.3k 一个基于Node.js的现代化社区论坛软件&#xff0c;具有快速、可扩展、易于使用和灵活的特点。它支持多种数据库&…

解决:ValueError: the first two maketrans arguments must have equal length

解决&#xff1a;ValueError: the first two maketrans arguments must have equal length 文章目录 解决&#xff1a;ValueError: the first two maketrans arguments must have equal length背景报错问题报错翻译报错位置代码报错原因解决方法今天的分享就到此结束了 背景 在…

大数据-之LibrA数据库系统告警处理(ALM-37018 数据库用户连接数超限)

告警解释 当集群中单个CN实例上某个用户的连接数超过限制时&#xff0c;产生该告警。 告警属性 告警ID 告警级别 可自动清除 37018 严重 是 告警参数 参数名称 参数含义 ServiceName 产生告警的服务名称 RoleName 产生告警的角色名称 HostName 产生告警的主机名…

如何在Ubuntu系统上安装Git

简单介绍 Git是一个开源的分布式版本控制系统&#xff0c;用于敏捷高效地处理任何或小或大的项目。Git是Linus Torvalds为了帮助管理Linux内核开发而开发的一个开放源码的版本控制软件。Git 与常用的版本控制工具CVS&#xff0c;Subversion 等不同&#xff0c;它采用了分布式版…

四、shell - 字符串

目录 1、单引号 2、双引号 3、拼接字符串 3.1 使用双引号拼接 3.2 使用单引号拼接 4、获取字符串长度 ​​​​​​​5、提取子字符串 ​​​​​​​6、查找子字符串 ​​​​​​​字符串是shell编程中最常用最有用的数据类型&#xff08;除了数字和字符串&#xff0…