深入解析 PyTorch 的 torch.load() 函数:用法、参数与实际应用示例

深入解析 PyTorch 的 torch.load() 函数:用法、参数与实际应用示例

函数 torch.load() 是一个在PyTorch中用于加载通过 torch.save() 保存的序列化对象的核心功能。这个函数广泛应用于加载预训练模型、模型的状态字典(state dictionaries)、优化器状态以及其他PyTorch对象。它利用Python的反序列化能力,特别地对张量的底层存储(storages)进行了特殊处理,以支持跨设备加载和内存效率。

基本语法和参数详解

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
参数详细说明
  • f (Union[str, PathLike, BinaryIO, IO[bytes]])

    • 类型:可以是字符串、路径对象或文件对象。
    • 含义:指定要加载的文件的路径或文件对象。如果是文件对象,它必须实现基本的文件读取方法,如 read()seek()
  • map_location (Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]])

    • 类型:可选,可以是函数、设备对象、字符串或字典。
    • 含义:用于指定存储设备的重新映射策略。
      • 函数:如果提供了函数,它应该接受存储和位置标签作为参数,并返回新的存储位置。
      • 设备或字符串:可以直接指定所有张量应该被加载到的设备,如 'cpu''cuda:0'
      • 字典:将文件中的位置标签映射到新的存储位置。
  • pickle_module (Optional[Any])

    • 类型:模块。
    • 含义:用于反序列化的模块,默认为Python的 pickle 模块。如果序列化时使用了特定的模块,则加载时也必须使用相同的模块。
  • weights_only (Optional[bool])

    • 类型:布尔值。
    • 含义:如果设置为 True,则加载过程将限制为仅加载张量、基本数据类型、字典和通过 torch.serialization.add_safe_globals() 添加的安全类型。
  • mmap (Optional[bool])

    • 类型:布尔值。
    • 含义:如果设置为 True,则文件将通过内存映射的方式访问,而不是完全加载到内存中。这对处理大型数据文件特别有用,因为它减少了内存使用并可能提高访问速度。
  • pickle_load_args (Any)

    • 类型:关键字参数。
    • 含义:传递给 pickle_module.load()pickle_module.Unpickler() 的附加参数,例如 encoding

实际使用示例

示例 1: 基础加载模型

加载一个在GPU上训练并保存的模型到CPU上进行推理:

import torch# 设置加载路径
model_path = 'gpu_trained_model.pth'# 加载模型到CPU
model = torch.load(model_path, map_location='cpu')# 打印模型结构确认加载无误
print(model)
示例 2: 使用内存映射和仅加载权重

对于大型模型文件,使用内存映射加载权重,减少内存占用:

import torch# 模型文件路径
large_model_path = 'large_model_weights.pth'# 使用内存映射方式加载模型权重到CPU,限制为仅加载权重
model_weights = torch.load(large_model_path, map_location='cpu', mmap=True, weights_only=True)# 假设 MyModel 是模型的架构类
model = MyModel()
model.load_state_dict(model_weights)# 输出模型确保权重被正确加载
print(model)

这些示例清楚地展示了如何灵活使用 torch.load() 的不同参数来优化模型的加载策略,适应不同的硬件环境和内存限制,从而实现高效的模型部署。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62353.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Web开发基础学习——axios的理解

Web开发基础学习系列文章目录 第一章 基础知识学习之axios的理解 文章目录 Web开发基础学习系列文章目录前言一、使用方法1.1 安装 axios:1.2 在前端代码中使用 axios: 总结 前言 Axios 是一个基于 Promise 的 HTTP 客户端,用于在浏览器和 …

FileReader和 FileWriter

FileReader和FileWriter是用于操作文件的类,它们分别用于读取和写入数据。下面是它们的一些基本用法: FileReader: 创建一个FileReader对象,指定要读取的文件路径。使用read()方法读取文件的内容,返回一个整数字符表…

FreeRTOS posix 实现低功耗tickless

文章目录 打印重定向FreeRTOSConfig.h 配置portmacro.h 实现低功耗流程vPortSuppressTicksAndSleep 实现测试效果注意事项 打印重定向 为了观察睡眠时间,重定向打印函数,打印的时候将时间戳打印出来,实现如下 #define printf(fmt, ...) …

解析客服知识库搭建的五个必要性

在当今竞争激烈的商业环境中,客服知识库的搭建已成为企业提升服务质量、优化客户体验的重要手段。一个完善的客服知识库不仅能帮助企业高效管理客户服务流程,还能显著提升客户满意度和忠诚度。以下是搭建客服知识库的五个必要性: 1. 提升服务…

淘宝Vision Pro:革新购物体验的沉浸式未来

引言 简要介绍淘宝Vision Pro版的背景,包括它在美区AppStore的发布及WWDC上的展示。阐述本文的目的:为读者提供一个全面的功能概览与设计背后的思考。设计原则 列出并简要解释5条设计原则(熟悉、直观、真实、实用、易用)。说明这些原则如何指导整个产品设计过程。核心功能详…

网站怎么防御https攻击

HTTPS攻击,它不仅威胁到网站的数据安全,还可能影响用户隐私和业务稳定运行。 HTTPS攻击主要分为以下几种类型: 1.SSL劫持:攻击者通过中间人攻击手段,篡改HTTPS流量,从而实现对数据的窃取或伪造。 2.中间人攻…

【从0学英语】 04.句型 - 英语句子的骨架

在学习英语的过程中,句型就像建筑的骨架一样,是构建完整句子的基础。俗话说,万变不离其宗,即使英语句子千变万化,也离不开几种基本的句型结构。本节内容将从零开始,带您逐步了解英语句子的五种核心骨架&…

【CSS in Depth 2 精译_062】第 10 章 CSS 中的容器查询(@container)概述 + 10.1 容器查询的一个简单示例

当前内容所在位置(可进入专栏查看其他译好的章节内容) 【第十章 CSS 容器查询】 ✔️ 10.1 容器查询的一个简单示例 ✔️ 10.1.1 容器尺寸查询的用法 ✔️ 10.2 深入理解容器10.3 与容器相关的单位10.4 容器样式查询的用法10.5 本章小结 文章目录 第 10…

openjdk17 jvm 对象 内存溢出 在C++源码体现

##java大对象类 public class MiBigObject {private String f1;private String f2;private String f3;private String f4;private String f5;private String f6;private String f7;private String f8;private String f9;private String f10;private String f11;private String…

HCIE:详解OSPF,从基础到高级特性再到深入研究

目录 前言 一、OSPF协议基本原理 简介 基本原理 OSPF路由器类型 OSPF网络类型 OSPF报文类型和封装 OSPF邻居的建立的维护 DR和BDR的选举 伪节点 LSDB的更新 OSPF的配置 二、OSPF的高级特性 虚连接(Virtual-Link) OSPF的LSA和路由选择 OSPF…

C++算法练习-day45——236.二叉树的最近公共祖先

题目来源:. - 力扣(LeetCode) 题目思路分析 题目要求在一个二叉树中找到两个给定节点的最低公共祖先(Lowest Common Ancestor, LCA)。最低公共祖先是指在树中同时包含两个给定节点的所有节点中,深度最大的…

think php处理 异步 url 请求 记录

1、需求 某网站 需要 AI生成音乐,生成mp3文件的时候需要等待,需要程序中实时监听mp3文件是否生成 2、用的开发框架 为php 3、文件结构 配置路由设置 Route::group(/music, function () {Route::post(/musicLyrics, AiMusic/musicLyrics);//Ai生成歌词流式…

【VRChat 改模】开发环境搭建:VCC、VRChat SDK、Unity 等环境配置

一、配置 Unity 相关 1.下载 UnityHub 下载地址:https://unity.com/download 安装打开后如图所示: 2.下载 VRChat 官方推荐版本的 Unity 跳转界面(VRChat 官方推荐页面):https://creators.vrchat.com/sdk/upgrade/…

AJAX 实时搜索

AJAX 实时搜索 AJAX(Asynchronous JavaScript and XML)实时搜索是一种无需刷新整个网页就能从服务器获取数据并在网页上展示的技术。这种技术极大地提升了用户体验,尤其是在搜索引擎、在线购物网站、社交媒体平台等应用中。本文将详细介绍AJ…

ollama部署bge-m3,并实现与dify平台对接

概述 这几天为了写技术博客,各种组件可谓是装了卸,卸了装,只想复现一些东西,确保你们看到的东西都是可以复现的。 (看在我这么认真的份上,求个关注啊,拜托各位观众老爷了。) 这不,为了实验在windows上docker里运行pytorch,把docker重装了。 dify也得重装: Dify基…

详细介绍HTTP与RPC:为什么有了HTTP,还需要RPC?

目录 一、HTTP 二、RPC 介绍 工作原理 核心功能 如何服务寻址 如何进行序列化和反序列化 如何网络传输 基于 TCP 协议的 RPC 调用 基于 HTTP 协议的 RPC 调用 实现方式 优点和缺点 使用场景 常见框架 示例 三、问题 问题一:是先有HTTP还是先有RPC&…

Paddle Inference部署推理(十五)

十五:Paddle Inference推理 (python)API详解 枚举类型 DataType DataType 定义了 Tensor 的数据类型,由传入 Tensor 的 numpy 数组类型确定。 # DataType 枚举定义 class paddle.inference.DataType:# 获取各个 DataType 对应…

blender 视频背景

准备视频文件 首先,确保你有想要用作背景的视频文件。视频格式最好是 Blender 能够很好兼容的,如 MP4 等常见格式。 创建一个新的 Blender 场景或打开现有场景 打开 Blender 软件后,你可以新建一个场景(通过点击 “文件” - “新建…

Elasticsearch与NLP的深度融合:文本嵌入与向量搜索实战指南

Elasticsearch与NLP的深度融合:文本嵌入与向量搜索实战指南 引言 在当今信息爆炸的时代,如何从海量文本数据中快速准确地检索出相关信息,成为了一个迫切需要解决的问题。自然语言处理(NLP)技术的发展为这一挑战提供了新的解决方案。Elasticsearch,作为一个强大的搜索引…

Lesson 10 GNN

听课(李宏毅老师的)笔记,方便梳理框架,以作复习之用。本节课主要讲了生成式对抗网络(GNN)。 目录 Generation Network as Generator 到目前为止,我们学习到的是类似于函数的network&#xf…