【PyTorch】模型

文章目录

  • 1. 模型的创建
    • 1.1. 创建方法
      • 1.1.1. 通过使用模型组件
      • 1.1.2. 通过继承nn.Module类
    • 1.2. 将模型转移到GPU
  • 2. 模型参数初始化
  • 3. 模型的保存与加载
    • 3.1. 只保存参数
    • 3.2. 保存模型和参数

1. 模型的创建

1.1. 创建方法

1.1.1. 通过使用模型组件

可以直接使用模型组件快速创建模型。

import torch.nn as nnmodel =	nn.Linear(10, 10)
print(model)

输出结果:

Linear(in_features=10, out_features=10, bias=True)

1.1.2. 通过继承nn.Module类

在__init__方法中使用模型组件定义模型各层。必须重写forward方法实现前向传播。

import torch.nn as nnclass Model(nn.Module):def __init__(self):super().__init__()self.layer1 = nn.Linear(10, 10)self.layer2 = nn.Linear(10, 10)self.layer3 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10))def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return xmodel = Model()
print(model)

输出结果:

Model((layer1): Linear(in_features=10, out_features=10, bias=True)(layer2): Linear(in_features=10, out_features=10, bias=True)(layer3): Sequential((0): Linear(in_features=10, out_features=10, bias=True)(1): ReLU()(2): Linear(in_features=10, out_features=10, bias=True))
)

1.2. 将模型转移到GPU

方法与将数据转移到GPU类似,都有两种方法:

  1. model.to(device)
  2. mode.cuda()
import torch
import torch.nn as nn# 创建模型实例
model = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 将模型移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# 也可以
model = model.cuda()

2. 模型参数初始化

torch.nn.init提供了许多初始化参数的函数:

函数名作用参数
uniform_从均匀分布 U ( a , b ) U(a,b) U(a,b)中生成值,填充输入的张量tensor, a = 0, b = 1
normal_从正态分布 N ( m e a n , s t d 2 ) N(mean, std^2) N(mean,std2)中生成值,填充输入的张量tensor, mean = 0, std = 1
constant_用常数 v a l val val,填充输入的张量tensor, val
eye_用单位矩阵,填充二维输入张量tensor(二维)
dirac_用狄拉克函数,填充{3, 4, 5}维输入张量tensor({3, 4, 5}维), groups = 1
xavier_uniform_从xavier均匀分布中生成值,填充输入张量tensor, gain = 1
xavier_normal_从xavier正态分布中生成值,填充输入张量tensor, gain = 1
kaiming_uniform_从kaiming均匀分布中生成值,填充输入张量tensor, a = 0, mode = ‘fan_in’, nonlinearity = ‘leaky_relu’
kaiming_normal_从kaiming正态分布中生成值,填充输入张量tensor, a = 0, mode = ‘fan_in’, nonlinearity = ‘leaky_relu’
orthogonal_用一个(半)正交矩阵,填充输入张量tensor, gain = 1
sparse_用非零元素服从 N ( 0 , s t d 2 ) N(0, std^2) N(0,std2)的稀疏矩阵,填充二维输入张量tensor, sparsity, std = 0.01

3. 模型的保存与加载

模型保存和加载使用的python内置的pickle模块。

3.1. 只保存参数

import torch
import torch.nn as nn# 创建模型实例
model1 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 保存和加载参数
torch.save(model1.state_dict(), '../model/model_params.pkl')
model1.load_state_dict(torch.load('../model/model_params.pkl'))

3.2. 保存模型和参数

import torch
import torch.nn as nn# 创建模型实例
model1 = nn.Sequential(nn.Linear(10, 10),nn.ReLU(),nn.Linear(10, 10)
)# 保存和加载模型和参数
torch.save(model1, '../model/model.pt')
model2 = torch.load('../model/model.pt')
print(model2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/204923.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Unity动画】Unity 2D动画创建流程

本文以2D为案例,讲解Unity 播放动画的流程 准备和导入2D动画资源 外部导入序列帧生成的 Unity内部制作的 外部导入的3D动画 2.创建动画过程 打开时间轴Ctrl6 选中场景中的一个未来需要播放动画的物体 回到时间轴点击Create一个新动画片段 拖动2D动画资源放入…

什么是SPA(Single Page Application)?它的优点和缺点是什么?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 欢迎来到前端入门之旅!感兴趣的可以订阅本专栏哦!这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

IT外包对中小企业的独特优势

在竞争激烈的商业环境中,企业的发展稍有缓慢,就很有可能被竞争对手快速赶超、趁机抢占市场。一些中小企业为了更好地应对市场变化和提高自身竞争力,越来越多地转向了IT外包服务。相较于大型企业,中小企业在选择IT外包时能够获得一…

数据结构实验任务七:基于广度优先搜索的六度空间理论验证

问题描述 “六度空间”理论又称作“六度分隔(Six Degrees of Separation)”理论。这个理论 可以通俗地阐述为:“你和任何一个陌生人之间所间隔的人不会超过六个,也就是 说,最多通过五个人你就能够认识任何一个陌生人。”假如给你一个社交网络图&#xf…

java中用thumbnailator依赖写一个压缩图片的类,只要图片大小超过几兆就无限循环下去的详细代码实例?(经典)

下面是使用Thumbnailator依赖编写的一个压缩图片的类。该类会不断循环压缩图片,直到图片大小小于指定的阈值(以字节为单位)。 java Copy code import net.coobird.thumbnailator.Thumbnails; import java.io.File; import java.io.IOExcept…

Tap虚拟网卡

1 概述 Tap设备通常用于虚拟化场景下,其驱动代码位于drivers/net/tun.c,tap与tun复用大部分代码, 注:drivers/net/tap.c并不是tap设备的代码,而是macvtap和ipvtap; 下文中,我们统一称tap&#…

父子进程继承问题:OSError: [Errno 88] Socket operation on non-socket错误记录

目录 1 错误:self.server_address = self.socket.getsockname()OSError: [Errno 88] Socket operation on non-socket 2 错误排查过程 3 解决方法

java中用thumbnailator依赖写一个压缩图片的类,只要图片大小超过固定尺寸就无限循环下去的详细代码实例?

下面是使用thumbnailator依赖编写的一个压缩图片类的详细代码示例,该类会对大小超过固定尺寸的图片进行无限循环压缩。 java Copy code import net.coobird.thumbnailator.Thumbnails; import javax.imageio.ImageIO; import java.awt.image.BufferedImage; import…

四、分代垃圾回收机制及垃圾回收算法

学习垃圾回收的意义 Java 与 C等语言最大的技术区别:自动化的垃圾回收机制(GC) 为什么要了解 GC 和内存分配策略 1、面试需要 2、GC 对应用的性能是有影响的; 3、写代码有好处 栈:栈中的生命周期是跟随线程&…

重型堆垛机钢丝绳维护经验

钢丝绳是重型堆垛机一个非常重要的组成部分,平时我们给一些客户做堆垛机的维保,每次都会特地去检查堆垛机的钢丝绳,如果发现起毛刺,那必须得赶紧跟客户讲,让客户自己的维修人员不定期地观察,情况严重就要做…

CPU密集型和IO密集型对 CPU内核之间的关系

多线程如何合理的配置核心线程数? 对于 CPU 密集型任务,由于 CPU 密集型任务的性质,导致 CPU 的使用率很高,如果使用线程池中的核心线程数量过多,会增加上下文切换的次数,带来额外的开销。因此&#xff0c…

Python 日志(略讲)

日志操作 日志输出: # 输出日志信息 logging.debug("调试级别日志") logging.info("信息级别日志") logging.warning("警告级别日志") logging.error("错误级别日志") logging.critical("严重级别日志")级别设置…

Java程序员,你掌握了多线程吗?(文末送书)

目录 01、多线程对于Java的意义02、为什么Java工程师必须掌握多线程03、Java多线程使用方式04、如何学好Java多线程送书规则 摘要:互联网的每一个角落,无论是大型电商平台的秒杀活动,社交平台的实时消息推送,还是在线视频平台的流…

unity 2d 入门 飞翔小鸟 下坠功能且碰到地面要停止 刚体 胶囊碰撞器 (四)

1、实现对象要受重力 在对应的图层添加刚体 改成持续 2、设置胶囊碰撞器并设置水平方向 3、地面添加盒状碰撞器 运行则能看到小鸟下坠并落到地面上

Windows本地如何添加域名映射?(修改hosts文件)

1. DNS(域名系统) Domain Name System(域名系统):为了加快定位IP地址的速度, 将域名映射进行层层缓存的系统. 目的:互联网通过IP(10.223.146.45)定位浏览器建立连接,但是我们不易区别IP,为了方便用户辨识I…

柏睿网络分析:为什么微模块化机房越来越受欢迎?

与传统机房相比,微模块化机房的建设周期更短,扩展性更强,能耗更低,运维难度也相对较低。因此,微模块化机房是一种高效、灵活、节能的机房解决方案,适用于各种规模的数据中心。 一体化分布式部署&#xff1a…

idea利用SpringMVC框架整合ThymeLeaf

简洁一些:两个重要文件 1.controller指定html文件:我们访问http://localhost:8080/test package com.example.appledemo.controller;import org.springframework.stereotype.Controller; import org.springframework.web.bind.annotation.RequestMapping; import o…

甘草书店:#9 2023年11月23日 星期四 「麦田创业历程分享1——联合创始人的魔幻相遇」

既然甘草是一家创业主题的书店咖啡馆,那就从我,从麦田开始分享一下创业历程吧。 需要声明的是,我从不认为我有资格对别人的创业指指点点,每位创业者的性格、背景、基础、诉求各有不同,时代发展也日新月异,…

netty07-粘包半包以及解决方案

粘包指的是发送方在发送数据时,多个数据包被合并成一个大的数据包发送到接收方,接收方在接收时无法准确地区分各个数据包的边界,从而导致数据粘在一起。 半包指的是发送方发送的数据包被拆分成了多个小的数据包,在接收方接收时&a…

springboot中优雅实现异常拦截和返回统一结构数据

做前后端分离的项目,为了方便前端处理数据,都会将返回的数据封装到统一的结构下,这样前端拿到数据可以根据指定的字段做不同的业务逻辑处理。 1、异常信息统一拦截 项目开发中,难免会发生异常,如果不做拦截&#xff…