【PyTorch】模型的基本操作

文章目录

  • 1. 模型的创建
    • 1.1. 创建方法
      • 1.1.1. 通过使用模型组件
      • 1.1.2. 通过继承nn.Module类
    • 1.2. 将模型转移到GPU
  • 2. 模型参数初始化
  • 3. 模型的保存与加载
    • 3.1. 只保存参数
    • 3.2. 保存模型和参数

1. 模型的创建

1.1. 创建方法

1.1.1. 通过使用模型组件

可以直接使用模型组件快速创建模型。

import torch.nn as nnmodel =	nn.Linear(10, 10)
print(model)

输出结果:

Linear(in_features=10, out_features=10, bias=True)

1.1.2. 通过继承nn.Module类

__init__方法中使用模型组件定义模型各层,还可以直接使用torch.nn.functional中的函数。必须重写forward方法实现前向传播。

import torch.nn as nnclass Model(nn.Module):def __init__(self):super().__init__()self.layer1 = nn.Linear(10, 10)self.layer2 = nn.Linear(10, 10)self.layer3 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return xmodel = Model()
print(model)

输出结果:

Model((layer1): Linear(in_features=10, out_features=10, bias=True)(layer2): Linear(in_features=10, out_features=10, bias=True)(layer3): Sequential((0): Linear(in_features=10, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=10, bias=True))
)

1.2. 将模型转移到GPU

方法与将数据转移到GPU类似,都有两种方法:

  1. model.to(device)
  2. mode.cuda()
import torch
import torch.nn as nn# 创建模型实例
model = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 将模型移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# 也可以
model = model.cuda()

2. 模型参数初始化

torch.nn.init提供了许多初始化参数的函数:

函数名作用参数
uniform_从均匀分布 U ( a , b ) U(a,b) U(a,b)中生成值,填充输入的张量tensor, a = 0, b = 1
normal_从正态分布 N ( m e a n , s t d 2 ) N(mean, std^2) N(mean,std2)中生成值,填充输入的张量tensor, mean = 0, std = 1
constant_用常数 v a l val val,填充输入的张量tensor, val
eye_用单位矩阵,填充二维输入张量tensor(二维)
dirac_用狄拉克函数,填充{3, 4, 5}维输入张量tensor({3, 4, 5}维), groups = 1
xavier_uniform_从xavier均匀分布中生成值,填充输入张量tensor, gain = 1
xavier_normal_从xavier正态分布中生成值,填充输入张量tensor, gain = 1
kaiming_uniform_从kaiming均匀分布中生成值,填充输入张量tensor, a = 0, mode = ‘fan_in’, nonlinearity = ‘leaky_relu’
kaiming_normal_从kaiming正态分布中生成值,填充输入张量tensor, a = 0, mode = ‘fan_in’, nonlinearity = ‘leaky_relu’
orthogonal_用一个(半)正交矩阵,填充输入张量tensor, gain = 1
sparse_用非零元素服从 N ( 0 , s t d 2 ) N(0, std^2) N(0,std2)的稀疏矩阵,填充二维输入张量tensor, sparsity, std = 0.01

3. 模型的保存与加载

模型保存和加载使用的python内置的pickle模块。

3.1. 只保存参数

import torch
import torch.nn as nn# 创建模型实例
model1 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 保存和加载参数
torch.save(model1.state_dict(), '../model/model_params.pkl')
model1.load_state_dict(torch.load('../model/model_params.pkl'))

3.2. 保存模型和参数

import torch
import torch.nn as nn# 创建模型实例
model1 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 保存和加载模型和参数
torch.save(model1, '../model/model.pt')
model2 = torch.load('../model/model.pt')
print(model2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/205443.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3.镜像加速器

目录 1 阿里云 2 网易云 从网络上拉取镜像的时候使用默认的源可能会慢,用国内的源会快一些 1 阿里云 访问 阿里云-计算,为了无法计算的价值 然后登录,登录后搜索 容器镜像服务 点击容器镜像服务 点击管理控制台 点击 镜像工具->镜像…

【web安全】文件包含漏洞详细整理

前言 菜某的笔记总结,如有错误请指正。 本文用的是PHP语言作为案例 文件包含漏洞的概念 开发者使用include()等函数,可以把别的文件中的代码引入当前文件中执行,而又没有对用户输入的内容进行充分的过滤&#xff0…

5G入门到精通 - 5G的十大关键技术

文章目录 一、网络切片二、自组织网络三、D2D技术四、低时延技术五、MIMO技术六、毫米波七、内容分发网络八、M2M技术九、频谱共享十、信息中心网络 一、网络切片 5G中的网络切片是一项关键技术,它允许将整个5G网络分割成多个独立的虚拟网络,每个虚拟网络…

CodeBlocks添加头文件,解决fatal error: ui.h No such file or directory

问题描述 在使用codeblocks工具进行LVGL仿真过程中报错,找不到头文件 原因分析: 没有将头文件加入编辑器搜索的目录中,编译时找不到头文件。 解决方案: 将要包含的头文件的目录加进去就可以了

BCI-Two-streams hypothesis(双流假说)

双流假说 双流假设(Two-stream hypothesis)是关于视觉和听觉神经处理的模型。该假设最初由大卫米尔纳(David Milner)和梅尔文古德尔(Melvyn A. Goodale)于1992年的一篇论文中进行了初步描述,认为人类拥有两个独立的视觉…

【爬取音乐,并将音乐信息储存到数据库中】

爬取音乐,并将音乐信息储存到数据库中 确定音乐网站的url并分析网站分析二级页面创建数据库使用Xpath解析,进行多层爬取保存信息完整代码结果 确定音乐网站的url并分析网站 分析二级页面 创建数据库 # 创建一个链接对象 conn pymysql.connect(hostmaster, userroo…

虚拟网络技术:bond技术

网卡bond也称为网卡捆绑,就是将两个或者更多的物理网卡绑定成一个虚拟网卡。 bond的作用: 1.提高网卡的吞吐量 2.增加网络的高可用,实现负载均衡。 一、bond简介 bond技术即bonding,能将多块物理网卡绑定到一块虚拟网卡上&…

六、C语言数组

1. 数组的概念 数组是⼀组相同类型元素的集合;从这个概念中我们就可以发现2个有价值的信息: 数组中存放的是1个或者多个数据,但是数组元素个数不能为0。数组中存放的多个数据,类型是相同的。 数组分为⼀维数组和多维数组&#xf…

Prometheus 发现机制和告警

1.服务发现 Prometheus Server的数据抓取工作于Pull模型,因而,它必需要事先知道各Target的位置,然后才能从相应的Exporter或Instrumentation中抓取数据。在不同的场景下,需要结合不同的机制来实现对应的数据抓取目的。 对于小型的…

企业级 接口自动化测试框架:Pytest+Allure+Excel

1. Allure 简介 简介 Allure 框架是一个灵活的、轻量级的、支持多语言的测试报告工具,它不仅以 Web 的方式展示了简介的测试结果,而且允许参与开发过程的每个人可以从日常执行的测试中,最大限度地提取有用信息。 Allure 是由 Java 语言开发…

基于selenium工具刷b站播放量(请谨慎使用)

基于selenium工具刷b站播放量(请谨慎使用) from selenium import webdriver import time import random# 打开B站视频 url input("url:") if url "":url https://www.bilibili.com/video/BV1K64y1574T for i in range(50):# 设置…

【学习记录】从0开始的Linux学习之旅——字符型设备驱动及应用

一、概述 Linux操作系统通常是基于Linux内核,并结合GNU项目中的工具和应用程序而成。Linux操作系统支持多用户、多任务和多线程,具有强大的网络功能和良好的兼容性。基于前面应用与驱动的开发学习,本文主要讲述如何在linux系统上把应用与驱动…

参考信号速度变化存在跳跃时容易发生不稳定的阻抗调节

问题描述 当参考信号速度存在跳跃变化时,阻抗调节系统容易发生不稳定。这是因为阻抗调节系统需要根据参考信号的速度来调整其输出阻抗,以匹配负载阻抗,从而保持系统的稳定性。 当参考信号速度突然变化时,阻抗调节系统可能无法及…

『TypeScript』深入理解变量声明、函数定义、类与接口及泛型

📣读完这篇文章里你能收获到 了解TypeScript变量声明与类型注解掌握TypeScript函数与方法的使用掌握TypeScript类与接口的使用掌握TypeScript泛型的应用 文章目录 一、变量声明与类型注解1. 变量声明2. 类型注解3. 类型推断 二、函数与方法定义1. 函数定义2. 方法定…

Kubernetes权威指南:从Docker到Kubernetes实践全接触(第5版)读书笔记 目录

完结状态:未完结 文章目录 前言第1章 Kubernetes入门 11.1 了解Kubernetes 2 附录A Kubernetes核心服务配置详解 915总结 前言 提示:这里可以添加本文要记录的大概内容: Kubernetes权威指南:从Docker到Kubernetes实践全接触&…

Jmeter 性能测试基础!

压力测试   压力测试分两种场景:一种是单场景,压一个接口的;第二种是混合场景,多个有关联的接口。压测时间,一般场景都运行10-15分钟。如果是疲劳测试,可以压一天或一周,根据实际情况来定。 压…

【编程技术】CUDA TencoreCore编程实例说明

概述 通过一个m16n8k16矩阵乘法的CUDA TencoreCore编程实例,展示load/store mma 的矩阵乘法运行过程 动画实例 CUDA TensoreCore 编程实例

springboot 在自定义注解中注入bean,解决注入bean为null的问题

问题: 在我们开发过程中总会遇到比如在某些场合中需要使用service或者mapper等读取数据库,或者某些自动注入bean失效的情况 解决方法: 1.在构造方法中通过工具类获取需要的bean 工具类代码: import org.springframework.beans…

Spring到底是如何解决循环依赖问题的?

Spring作为当前使用最广泛的框架之一,其重要性不言而喻。所以充分理解Spring的底层实现原理对于咱们Java程序员来说至关重要,那么今天笔者就详细说说Spring框架中一个核心技术点:如何解决循环依赖问题? 什么是循环依赖问题&#x…

深入理解Java中的逃逸分析

目录 1. 对象作用域分析2. 栈上分配3. 同步省略(锁消除)4. 标量替换 逃逸分析是一种编译器优化技术,用于确定对象的作用域和生命周期。其主要特点包括:对象作用域分析、栈上分配、同步省略和标量替换。现在将详细阐述这些特点&…