PyTorch使用细节

model.eval() :让BatchNorm、Dropout等失效;

with torch.no_grad() : 不再缓存activation,节省显存;

这是矩阵乘法:

y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)y3 = torch.rand_like(y1)
torch.matmul(tensor, tensor.T, out=y3)

这是点乘:

z1 = tensor * tensor
z2 = tensor.mul(tensor)z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)

Tensor如果是1*1大小的,可以转为普通Python变量

agg = tensor.sum()
agg_item = agg.item()

Tensor和numpy之间,是share内存的,改一个另一个也被改动

n = torch.ones(5).numpy()n = np.ones(5)
t = torch.from_numpy(n)

root本地文件夹里有,则从本地读;没有的话,如指定了ownload=True,则从远程下载;

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdatraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

Dataset类:通过index,拿到1条数据;

        数据可以都在磁盘上,用到哪条,就加载哪条;

        自定义一个类,需要继承Dataset类,并重写__init__、__len__、__getitem__

DataLoader类:batching, shuffle(sampling策略), multiprocess加载,pin memory,...

ToTensor(): 把PIL格式的Image,转成Tensor;

Lambda: 把int的y,转成10维度的1-hot向量;

一切模型层,皆继承自torch.nn.Module

class NeuralNetwork(nn.Module):

Module必须copy到device上

model = NeuralNetwork().to(device)

input data也必须copy到device上

X = torch.rand(1, 28, 28, device=device)

不能直接使用Module.forward,使用Module(input)语法可以使前后的hook起作用

logits = model(X)

model.parameters(): 可训练的参数;

model.named_parameters(): 可训练的参数;包含名称;

state_dict: 可训练的参数、不可训练的参数,都有;

继承自Function类,可以写自定义的forward和backward,input或output可以放在ctx里:

>>> class Exp(Function):
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
>>>
>>> # Use it by calling the apply method:
>>> output = Exp.apply(input)

 构造计算图:

Tensor的几大成员:grad, grad_fn, is_leaf, requires_grad

Tensor.grad_fn,就是用于backward梯度计算的Function:

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")# Output:
Gradient function for z = <AddBackward0 object at 0x7f5e9fb64e20>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x7f5e99b11b40>

backward时,注意,是累积加和到Tensor.grad上;这样,链式法则有些地方就是要加和的,accumulate step也可以实现;

只有满足这个条件的才会累积其grad: is_leaf==True && requires_grad==True

只有requires_grad==True,但is_leaf==False,则会将梯度传播给上游,自己的grad成员无值;

只用来inference时,可用"with torch.no_grad()"控制其不生成计算图:(好处:forward速度变快一点儿,不保存activation至ctx节省显存)

with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)Output: False

某些模型训练,有些parameter要设成frozen不参与权重更新,则手工设其requires_grad=False即可。

用detach()来创造数据引用,脱离了原计算图,原计算图可以被垃圾回收了:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)Output: False

backward DAG,在每次forward阶段,都会被重新搭建;所以每个step,计算图可以任意变化(例如根据Tensor的值来走不同的control flow)

向量对向量求偏导,得到的是雅克比矩阵:

以下例子演示:雅克比矩阵、梯度累积、zero_grad

inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")Output:First call
tensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.]])Second call
tensor([[8., 4., 4., 4., 4.],[4., 8., 4., 4., 4.],[4., 4., 8., 4., 4.],[4., 4., 4., 8., 4.]])Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.]])

optimizer使用例子

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)def train(...):model.train()for batch, (X, y) in enumerate(dataloader):# Compute prediction and losspred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()  # 将所有Tensor.grad清0

torch.save: 使用Python的pickle,将一个dict进行序列化,并存至文件;

torch.load: 读取文件,使用Python的pickle,将字节数组进行反序列化,至一个dict;

torch.nn.Module.state_dict: 一个Python的dict,key是字符串,value是Tensor;包含可学习的parameters,不可学习的buffers(例如batch normalization需要的running mean);

optimizer也有state_dic(learning rate,冲量等)

save下来仅仅用于推理:(注意:必须model.eval(),否则dropout、BN,会出毛病)

# save:
torch.save(model.state_dict(), PATH)# load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

save下来可用于继续训练:

# save:
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)# load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.train()

使用state_dict方式,load之前,model必须初始化好(内存已经被parameters占住了,只是权重是随机的)

map_location、model.to(device)等:Saving and Loading Models — PyTorch Tutorials 2.3.0+cu121 documentation

小众用法:(model不用初始化)

# save:
torch.save(model, PATH)# load:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/873390.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

js reduce 的别样用法

let mergedItems list.reduce((accumulator, currentItem) > {let existingItem accumulator.find((item) > item.manObject_name currentItem.manObject_name);if (existingItem) {existingItem.laborCostHand currentItem.laborCostHand; //劳务费existingItem.wor…

有了这5个高效视频剪辑工具,你一定会爱上剪辑

如果你是个剪辑新手&#xff0c;不知道如何挑选剪辑视频的工具&#xff0c;又或者是自己目前使用的剪辑工具不理想&#xff0c;想寻找新的剪辑软件&#xff1b;那就请你看看这篇文章&#xff0c;这里介绍的5款剪辑软件都是专业&#xff0c;简单&#xff0c;又高效的剪辑工具。 …

顺序表<数据结构 C版>

目录 线性表 顺序表 动态顺序表类型 初始化 销毁 打印 检查空间是否充足&#xff08;扩容&#xff09; 尾部插入 头部插入 尾部删除 头部删除 指定位置插入 指定位置删除 查找数据 线性表 线性表是n个相同特性的数据元素组成的有限序列&#xff0c;其是一种广泛运…

04 Git与远程仓库

第4章&#xff1a;Git与远程仓库 一、Gitee介绍及创建仓库 一&#xff09;获取远程仓库 ​ 使用在线的代码托管平台&#xff0c;如Gitee&#xff08;码云&#xff09;、GitHub等 ​ 自行搭建Git代码托管平台&#xff0c;如GitLab 二&#xff09;Gitee创建仓库 ​ gitee官…

Gitee使用教程2-克隆仓库(下载项目)并推送更新项目

一、下载 Gitee 仓库 1、点击克隆-复制代码 2、打开Git Bash 并输入复制的代码 下载好后&#xff0c;找不到文件在哪的可以输入 pwd 找到仓库路径 二、推送更新 Gitee 项目 1、打开 Git Bash 用 cd 命令进入你的仓库&#xff08;我的仓库名为book&#xff09; 2、添加文件到 …

Spring-Boot基础--yaml

目录 Spring-Boot配置文件 注意&#xff1a; YAML简介 YAML基础语法 YAML:数据格式 YAML文件读取配置内容 逐个注入 批量注入 ConfigurationProperties 和value的区别 Spring-Boot配置文件 Spring-Boot中不用编写.xml文件&#xff0c;但是spring-Boot中还是存在.prope…

参与开源项目 MySQL 的心得体会

前言 开源项目的数量和种类都在急剧增长&#xff0c;涵盖了从操作系统、数据库到人工智能、区块链等几乎所有的技术领域。这为技术的快速创新和迭代提供了强大的动力&#xff0c;使得新技术能够更快地普及和应用. 目录 经历 提升 挑战 良好的编程习惯 总结 经历 参与开源…

Linux Namespace

Linux namespaces 介绍 namespaces是Linux内核用来隔离内核资源的方式。通过namespaces可以让一些进程只能看到与自己相关的那部分资源。而其它的进程也只能看到与他们自己相关的资源。这两拨进程根本感知不到对方的存在。而它具体的实现细节是通过Linux namespaces来实现的。 …

(三)C++之运算符重载

一.概念 C准许以运算符命名函数&#xff01;&#xff01;&#xff01; string a “hello”; a “ world”;// (a, “world”); cout<<“hello”; // <<(cout, “hello”); 可重载的运算符 不可重载的运算符 二.成员函数式(第一个行参是对象的引用) class T…

orcad导出pdf 缺少title block

在OrCAD中导出PDF时没有Title Block 最后确认问题在这里&#xff1a; 要勾选上Title Block Visible下面的print

Nginx详解(超级详细)

目录 Nginx简介 1. 为什么使用Nginx 2. 安装Nginx Nginx的核心功能 1. Nginx反向代理功能 2. Nginx的负载均衡 3 Nginx动静分离 Nginx简介 Nginx是一款轻量级的Web 服务器/反向代理服务器及电子邮件&#xff08;IMAP/POP3&#xff09;代理服务器&#xff0c;在BSD-like 协…

你能分清工业领域这些常见的技术文档吗?

在制造业领域中&#xff0c;技术文档是不可或缺的宝贵资源。它们不仅是产品设计理念的载体&#xff0c;更是指导生产、保证质量、降低错误的关键。技术文档详尽描述了产品的每一个细节&#xff0c;从设计原理到零部件规格&#xff0c;从装配步骤到操作指南&#xff0c;无所不包…

关于dom4j主节点的xmlns无法写入的问题

由于最近需要做一个xml的文件&#xff0c;使用dom4j的时候发现了一个bug&#xff0c;就是我的xmlns根本无法写入到xml的头部标签中。 Element element document.addElement("test"); element.addAttribute("xmlns", "urn:Declaration:datamodel:sta…

Windows10 22H2专业工作站版:功能全新升级,工作更高效!

Windows10 22H2专业工作站版是一款专为具有高级数据需求的人士设计的操作系统&#xff0c;拥有强大的服务器级数据保护和性能&#xff0c;可以帮助用户不断突破高级工作负载的挑战。接下来系统之家小编给大家带来全新升级的Windows10 22H2专业工作站版系统&#xff0c;喜欢的用…

刚起步的家庭海外仓:涉及到的全部业务优化流程

对于家庭海外仓来说&#xff0c;最难的阶段应该就是刚起步的时候。对业务流程不熟悉&#xff0c;也没有客户积累&#xff0c;本身的预算又十分有限。 在这个情况下应该注意什么&#xff0c;怎样才能顺利的开展业务&#xff1f;今天我们就针对这个问题详细的梳理了一下家庭海外…

界面控件DevExpress Blazor UI v24.1 - 发布全新TreeList组件

DevExpress Blazor UI组件使用了C#为Blazor Server和Blazor WebAssembly创建高影响力的用户体验&#xff0c;这个UI自建库提供了一套全面的原生Blazor UI组件&#xff08;包括Pivot Grid、调度程序、图表、数据编辑器和报表等&#xff09;。 DevExpress Blazor控件目前已经升级…

电脑屏幕录制怎么弄?分享3个简单的电脑录屏方法

在信息爆炸的时代&#xff0c;屏幕上的每一个画面都可能成为我们生活中不可或缺的记忆。作为一名年轻男性&#xff0c;我对于录屏软件的需求可以说是既挑剔又实际。今天&#xff0c;我就为大家分享一下我近期体验的三款录屏软件&#xff1a;福昕录屏大师、转转大师录屏大师和OB…

高频面试题-CSS

BFC 介绍下BFC (块级格式化上下文) 1>什么是BFC BFC即块级格式化上下文&#xff0c;是CSS可视化渲染的一部分, 它是一块独立的渲染区域&#xff0c;只有属于同一个BFC的元素才会互相影响&#xff0c;且不会影响其它外部元素。 2>如何创建BFC 根元素&#xff0c;即HTM…

maven项目容器化运行之2-maven中使用docker插件调用远程docker构建服务并在1Panel中运行

一.背景 公司主机管理小组的同事期望我们开发的maven项目能够在1Panel管理的docker容器部署。上一篇写了先开放1Panel中docker镜像构建能力maven项目容器化运行之1-基于1Panel软件将docker镜像构建能力分享给局域网-CSDN博客。这一篇就是演示maven工程的镜像构建、容器运行、运…

昇思25天学习打卡营第14天|DCGAN 与漫画头像生成:原理剖析与训练实战

目录 数据集下载 数据处理 构建生成器 构建判别器 模型训练 结果展示 数据集下载 首先尝试卸载已安装的 mindspore 库&#xff0c;然后通过指定的镜像源安装特定版本&#xff08;2.2.14&#xff09;的 mindspore 库。从指定的 URL 下载一个 zip 文件到当前目录下的 ./faces…