PyTorch使用细节

model.eval() :让BatchNorm、Dropout等失效;

with torch.no_grad() : 不再缓存activation,节省显存;

这是矩阵乘法:

y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)y3 = torch.rand_like(y1)
torch.matmul(tensor, tensor.T, out=y3)

这是点乘:

z1 = tensor * tensor
z2 = tensor.mul(tensor)z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)

Tensor如果是1*1大小的,可以转为普通Python变量

agg = tensor.sum()
agg_item = agg.item()

Tensor和numpy之间,是share内存的,改一个另一个也被改动

n = torch.ones(5).numpy()n = np.ones(5)
t = torch.from_numpy(n)

root本地文件夹里有,则从本地读;没有的话,如指定了ownload=True,则从远程下载;

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdatraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

Dataset类:通过index,拿到1条数据;

        数据可以都在磁盘上,用到哪条,就加载哪条;

        自定义一个类,需要继承Dataset类,并重写__init__、__len__、__getitem__

DataLoader类:batching, shuffle(sampling策略), multiprocess加载,pin memory,...

ToTensor(): 把PIL格式的Image,转成Tensor;

Lambda: 把int的y,转成10维度的1-hot向量;

一切模型层,皆继承自torch.nn.Module

class NeuralNetwork(nn.Module):

Module必须copy到device上

model = NeuralNetwork().to(device)

input data也必须copy到device上

X = torch.rand(1, 28, 28, device=device)

不能直接使用Module.forward,使用Module(input)语法可以使前后的hook起作用

logits = model(X)

model.parameters(): 可训练的参数;

model.named_parameters(): 可训练的参数;包含名称;

state_dict: 可训练的参数、不可训练的参数,都有;

继承自Function类,可以写自定义的forward和backward,input或output可以放在ctx里:

>>> class Exp(Function):
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
>>>
>>> # Use it by calling the apply method:
>>> output = Exp.apply(input)

 构造计算图:

Tensor的几大成员:grad, grad_fn, is_leaf, requires_grad

Tensor.grad_fn,就是用于backward梯度计算的Function:

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")# Output:
Gradient function for z = <AddBackward0 object at 0x7f5e9fb64e20>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x7f5e99b11b40>

backward时,注意,是累积加和到Tensor.grad上;这样,链式法则有些地方就是要加和的,accumulate step也可以实现;

只有满足这个条件的才会累积其grad: is_leaf==True && requires_grad==True

只有requires_grad==True,但is_leaf==False,则会将梯度传播给上游,自己的grad成员无值;

只用来inference时,可用"with torch.no_grad()"控制其不生成计算图:(好处:forward速度变快一点儿,不保存activation至ctx节省显存)

with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)Output: False

某些模型训练,有些parameter要设成frozen不参与权重更新,则手工设其requires_grad=False即可。

用detach()来创造数据引用,脱离了原计算图,原计算图可以被垃圾回收了:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)Output: False

backward DAG,在每次forward阶段,都会被重新搭建;所以每个step,计算图可以任意变化(例如根据Tensor的值来走不同的control flow)

向量对向量求偏导,得到的是雅克比矩阵:

以下例子演示:雅克比矩阵、梯度累积、zero_grad

inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")Output:First call
tensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.]])Second call
tensor([[8., 4., 4., 4., 4.],[4., 8., 4., 4., 4.],[4., 4., 8., 4., 4.],[4., 4., 4., 8., 4.]])Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.]])

optimizer使用例子

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)def train(...):model.train()for batch, (X, y) in enumerate(dataloader):# Compute prediction and losspred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()  # 将所有Tensor.grad清0

torch.save: 使用Python的pickle,将一个dict进行序列化,并存至文件;

torch.load: 读取文件,使用Python的pickle,将字节数组进行反序列化,至一个dict;

torch.nn.Module.state_dict: 一个Python的dict,key是字符串,value是Tensor;包含可学习的parameters,不可学习的buffers(例如batch normalization需要的running mean);

optimizer也有state_dic(learning rate,冲量等)

save下来仅仅用于推理:(注意:必须model.eval(),否则dropout、BN,会出毛病)

# save:
torch.save(model.state_dict(), PATH)# load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

save下来可用于继续训练:

# save:
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)# load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.train()

使用state_dict方式,load之前,model必须初始化好(内存已经被parameters占住了,只是权重是随机的)

map_location、model.to(device)等:Saving and Loading Models — PyTorch Tutorials 2.3.0+cu121 documentation

小众用法:(model不用初始化)

# save:
torch.save(model, PATH)# load:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/873390.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

初步认识HTML

目录 一. HTML概述 二. HTML基本语法 1. HTML的基本框架 2. 标签 2.1 标签分类 2.2 标签属性 三. 基本常用标签 3.1 标题标签 3.2 段落标签 3.3 换行标签 3.4 列表 3.5 超链接 四. 特殊符号转义 五. 表格 5.1 表格的基本构成标签 5.2 表格的基本结构 5.3 表格属…

js reduce 的别样用法

let mergedItems list.reduce((accumulator, currentItem) > {let existingItem accumulator.find((item) > item.manObject_name currentItem.manObject_name);if (existingItem) {existingItem.laborCostHand currentItem.laborCostHand; //劳务费existingItem.wor…

增量预训练和微调的区别

文章目录 前言一、增量预训练和微调的区别二、代码示例1. 增量预训练示例2. 微调示例3. 代码的区别 三、数据格式1. 增量预训练2. 微调3. 示例4. 小结 四、数据量要求1. 指导原则2. 示例3. 实际操作中的考虑4. 小结 前言 增量预训练是一种在现有预训练模型的基础上&#xff0c…

有了这5个高效视频剪辑工具,你一定会爱上剪辑

如果你是个剪辑新手&#xff0c;不知道如何挑选剪辑视频的工具&#xff0c;又或者是自己目前使用的剪辑工具不理想&#xff0c;想寻找新的剪辑软件&#xff1b;那就请你看看这篇文章&#xff0c;这里介绍的5款剪辑软件都是专业&#xff0c;简单&#xff0c;又高效的剪辑工具。 …

顺序表<数据结构 C版>

目录 线性表 顺序表 动态顺序表类型 初始化 销毁 打印 检查空间是否充足&#xff08;扩容&#xff09; 尾部插入 头部插入 尾部删除 头部删除 指定位置插入 指定位置删除 查找数据 线性表 线性表是n个相同特性的数据元素组成的有限序列&#xff0c;其是一种广泛运…

解决警告Creating a tensor from a list of numpy.ndarrays is extremely slow.

我的问题是创建一个列表x[]&#xff0c;然后不断读入数据使用x.append(sample)&#xff0c;chatgpt说这样转化比较低效&#xff0c;如果预先知道样本个数&#xff0c;可以用numpy来创建数组&#xff0c;再用索引x[i]sample赋值第二种方法更快&#xff0c;直接用numpy转化一下np…

04 Git与远程仓库

第4章&#xff1a;Git与远程仓库 一、Gitee介绍及创建仓库 一&#xff09;获取远程仓库 ​ 使用在线的代码托管平台&#xff0c;如Gitee&#xff08;码云&#xff09;、GitHub等 ​ 自行搭建Git代码托管平台&#xff0c;如GitLab 二&#xff09;Gitee创建仓库 ​ gitee官…

Gitee使用教程2-克隆仓库(下载项目)并推送更新项目

一、下载 Gitee 仓库 1、点击克隆-复制代码 2、打开Git Bash 并输入复制的代码 下载好后&#xff0c;找不到文件在哪的可以输入 pwd 找到仓库路径 二、推送更新 Gitee 项目 1、打开 Git Bash 用 cd 命令进入你的仓库&#xff08;我的仓库名为book&#xff09; 2、添加文件到 …

Spring-Boot基础--yaml

目录 Spring-Boot配置文件 注意&#xff1a; YAML简介 YAML基础语法 YAML:数据格式 YAML文件读取配置内容 逐个注入 批量注入 ConfigurationProperties 和value的区别 Spring-Boot配置文件 Spring-Boot中不用编写.xml文件&#xff0c;但是spring-Boot中还是存在.prope…

【Qt+opencv】基础的图像绘制

文章目录 前言line函数ellipse函数rectangle函数circle函数fillPoly函数putText函数总结 前言 在计算机视觉和图像处理领域&#xff0c;OpenCV是一个强大的库&#xff0c;提供了丰富的功能和算法。而Qt是一个跨平台的C图形用户界面应用程序开发框架&#xff0c;它为开发者提供…

参与开源项目 MySQL 的心得体会

前言 开源项目的数量和种类都在急剧增长&#xff0c;涵盖了从操作系统、数据库到人工智能、区块链等几乎所有的技术领域。这为技术的快速创新和迭代提供了强大的动力&#xff0c;使得新技术能够更快地普及和应用. 目录 经历 提升 挑战 良好的编程习惯 总结 经历 参与开源…

微信小程序-实现跳转链接并拼接参数(URL拼接路径参数)

第一种常用拼接方法&#xff1a;普通传值的拼接 //普通传值的拼接checkRouteBinttap: function (e) {wx.navigateTo({url: ../checkRoute/checkRoute?classId this.data.classInfo.classId "&taskId" this.data.classInfo.taskId,})}第二种&#xff1a;拼接…

Linux Namespace

Linux namespaces 介绍 namespaces是Linux内核用来隔离内核资源的方式。通过namespaces可以让一些进程只能看到与自己相关的那部分资源。而其它的进程也只能看到与他们自己相关的资源。这两拨进程根本感知不到对方的存在。而它具体的实现细节是通过Linux namespaces来实现的。 …

(三)C++之运算符重载

一.概念 C准许以运算符命名函数&#xff01;&#xff01;&#xff01; string a “hello”; a “ world”;// (a, “world”); cout<<“hello”; // <<(cout, “hello”); 可重载的运算符 不可重载的运算符 二.成员函数式(第一个行参是对象的引用) class T…

orcad导出pdf 缺少title block

在OrCAD中导出PDF时没有Title Block 最后确认问题在这里&#xff1a; 要勾选上Title Block Visible下面的print

k8s学习笔记——dashboard安装

重装了k8s集群后&#xff0c;重新安装k8s的仪表板&#xff0c;发现与以前安装不一样的地方。主要是镜像下载的问题&#xff0c;由于网络安全以及国外网站封锁的原因&#xff0c;现在很多镜像按照官方提供的仓库地址都下拉不下来&#xff0c;导致安装失败。我查了好几天&#xf…

Nginx详解(超级详细)

目录 Nginx简介 1. 为什么使用Nginx 2. 安装Nginx Nginx的核心功能 1. Nginx反向代理功能 2. Nginx的负载均衡 3 Nginx动静分离 Nginx简介 Nginx是一款轻量级的Web 服务器/反向代理服务器及电子邮件&#xff08;IMAP/POP3&#xff09;代理服务器&#xff0c;在BSD-like 协…

你能分清工业领域这些常见的技术文档吗?

在制造业领域中&#xff0c;技术文档是不可或缺的宝贵资源。它们不仅是产品设计理念的载体&#xff0c;更是指导生产、保证质量、降低错误的关键。技术文档详尽描述了产品的每一个细节&#xff0c;从设计原理到零部件规格&#xff0c;从装配步骤到操作指南&#xff0c;无所不包…

RabbitMQ 如何保证消息的可靠性

在分布式系统中&#xff0c;消息队列&#xff08;如 RabbitMQ&#xff09;扮演着至关重要的角色&#xff0c;它们作为中间件&#xff0c;帮助系统解耦、异步处理任务、提升系统性能和可靠性。然而&#xff0c;在使用消息队列时&#xff0c;确保消息的可靠性是一个不可忽视的问题…

Java 中快速生成唯一id

&#x1f446;&#x1f3fb;&#x1f446;&#x1f3fb;&#x1f446;&#x1f3fb;关注博主&#xff0c;让你的代码变得更加优雅。 前言 Hutool 是一个小而全的Java工具类库&#xff0c;通过静态方法封装&#xff0c;降低相关API的学习成本&#xff0c;提高工作效率&#xf…