pytorch再次学习

目录

  • 基础
    • 数据可视化
    • 切换设备device
    • 定义神经网络类
    • 打印每层的参数大小
    • 自动微分
      • 计算梯度
      • 禁用梯度追踪
      • 优化模型参数
    • 模型保存
    • 模型加载
  • 进阶
    • padding更准确的补法
    • ReLU增加计算量但是减少内存消耗的办法
    • 输出合并
    • 自适应平均池化(将输入shape变成指定的输出shape)
  • python方法
    • 读取图片
    • 获取绝对路径方法
    • 数据加载
    • 获取数据的分类索引,并进行翻转
    • 将类别写入json文件
    • 读取json文件
    • 打印训练时间

基础

数据可视化

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data1",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data1",train=False,download=True,transform=ToTensor()
)
labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item() #用于随机取出一个training_dataimg, label = training_data[sample_idx]plt.subplot(3,3,i) #此处i必须是1开始plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

在这里插入图片描述

切换设备device

device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)
print(f"Using {device} device")

定义神经网络类

class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28*28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logits

打印每层的参数大小

print(f"Model structure: {model}\n\n")for name, param in model.named_parameters():print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

自动微分

详见文章Variable
需要优化的参数需要加requires_grad=True,会计算这些参数对于loss的梯度

import torchx = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

计算梯度

计算导数

loss.backward()
print(w.grad)
print(b.grad)

禁用梯度追踪

训练好后进行测试,也就是不要更新参数时使用

z = torch.matmul(x, w)+b
print(z.requires_grad)with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)

优化模型参数

  1. 调用optimizer.zero_grad()来重置模型参数的梯度。梯度会默认累加,为了防止重复计算(梯度),我们在每次遍历中显式的清空(梯度累加值)。
  2. 调用loss.backward()来反向传播预测误差。PyTorch对每个参数分别存储损失梯度。
  3. 我们获取到梯度后,调用optimizer.step()来根据反向传播中收集的梯度来调整参数。
optmizer.zero_grad()
loss.backward()
optmizer.step()

模型保存

import torch
import torchvision.models as modelsmodel = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

模型加载

model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

进阶

padding更准确的补法

input=torch.randn(1,1,3,3)
m=torch.nn.ZeroPad2d((1,1,2,0)) #左,右,上,下
m(input)

在这里插入图片描述

ReLU增加计算量但是减少内存消耗的办法

torch.nn.ReLU(inplace=True)

输出合并

output=[branch1,branch2,branch3,branch4]
torch.cat(output,1) #在1维度上合并

自适应平均池化(将输入shape变成指定的输出shape)

avgpool=torch.nn.AdaptiveAvgPool2d((1,1)) #输出shape为(1,1)

python方法

读取图片

from PIL import Image
im=Image.open('1.jpg')

获取绝对路径方法

import os
data_root=os.path.abspath(os.getcwd())
data_root

在这里插入图片描述

数据加载

data_transform={'train':transforms.Compose([transforms.RandomResizedCrop(224),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]),'val':transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
}from torchvision import transforms,datasets,utils
train_dataset=datasets.ImageFolder(root=data_root+'train',transform=data_transform['train'])

获取数据的分类索引,并进行翻转

train_dataset.class_to_idx
cla_dict=dict((val,key) for key,val in flower_list.items())

将类别写入json文件

import json
json_str=json.dumps(cla_dict,indent=4) #转成json文件
with open('class_indices.json','w') as json_file: #写入json文件json_file.write(json_str)

读取json文件

try:json_file=open('./class_indices.json','r')class_indict=json.load(json_file)
except Exception as e:print(e)exit(-1)

打印训练时间

import time
t1=time.perf_counter()
'.....'
time.perf_counter-t1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/74379.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一个产品级MCU菜单框架设计

分享一个用单色屏做的菜单框架。代码托管在github: https://github.com/wujique/stm32f407/tree/sw_arch 1、概述 本处所说的菜单是用在128*64这种小屏幕的菜单,例如下面这种,不是彩屏上的GUI。 2、菜单框架设计 作为一个底层驱动工程师&a…

单片机-LED介绍

简介 LED 即发光二极管。它具有单向导电性,通过 5mA 左右电流即可发光 电流 越大,其亮度越强,但若电流过大,会烧毁二极管,一般我们控制在 3 mA-20mA 之间,通常我们会在 LED 管脚上串联一个电阻&#xff0c…

汇编语言Nasmide编辑软件

用来编写汇编语言源程序,Windows 记事本并不是一个好工具。同时,在命令行编译源程序也令很多人迷糊。毕竟,很多年轻的朋友都是用着 Windows 成长起来的,他们缺少在 DOS和 UNIX 下工作的经历。 我一直想找一个自己中意的汇编语言编…

文件导入之Validation校验List对象数组

背景: 我们的接口是一个List对象,对象里面的数据基本都有一些基础数据校验的注解,我们怎么样才能校验这些基础规则呢? 我们在导入excel文件进行数据录入的时候,数据录入也有基础的校验规则,这个时候我们又…

PaddleOCR学习笔记1-初步尝试

尝试使用PaddleOCR方法,如何使用自定义的模型方法,参数怎么配置,图片识别尝试简单提高识别率方法。 目前仅仅只是初步学习下如何使用PaddleOCR的方法。 一,测试识别图片: 1.png : 正确文本内容为“哲学可以帮助辩别现…

在很多公司里面会使用打tag的方式保留版本

:git tag|grep "xxx-dev“等分支来查看 2:git cherry-pick XXXXX 然后就是查看有冲突这些 git status 会出现相关的异常 然后解决相关的冲突 git add . git cherry-pick --continue git push XXX HEAD:refs/for/XXX 第一:git ta…

JTAG 简介

文章目录 1、JTAG 基本原理1.1、JTAG接口包括以下几个信号:1.2、The Debug TAP State Machine (DBGTAPSM) 2、JTAG 的应用 1、JTAG 基本原理 JTAG是Joint Test Action Group的缩写,它是一种国际标准测试协议,主要用于芯片或印制电路板的边界…

探索在云原生环境中构建的大数据驱动的智能应用程序的成功案例,并分析它们的关键要素。

文章目录 1. Netflix - 个性化推荐引擎2. Uber - 实时数据分析和决策支持3. Airbnb - 价格预测和优化5. Google - 自然语言处理和搜索优化 🎈个人主页:程序员 小侯 🎐CSDN新晋作者 🎉欢迎 👍点赞✍评论⭐收藏 ✨收录专…

Zoom正式发布类ChatGPT产品—AI Companion

9月6日,全球视频会议领导者Zoom在官网宣布,正式发布生成式AI助手——AI Companion。 AI Companion提供了与ChatGPT类似的功能,包括根据文本对话起草各种内容,自动生成会议摘要,自动回答会议相关问题等,以帮…

Unity VideoPlayer 指定位置开始播放

如果 source是 videoclip(以下两种方式都可以): _videoPlayer.Play();Debug.Log("time: " _videoPlayer.clip.length);_videoPlayer.time 10; [SerializeField] VideoPlayer videoPlayer;public void SetClipWithTime(VideoClip…

微服务·架构组件之网关- Spring Cloud Gateway

微服务架构组件之网关- Spring Cloud Gateway 引言 微服务架构已成为构建现代化应用程序的关键范式之一,它将应用程序拆分成多个小型、可独立部署的服务。Spring Cloud Gateway是Spring Cloud生态系统中的一个关键组件,用于构建和管理微服务架构中的网…

yml配置动态数据源(数据库@DS)与引起(If you want an embedded database (H2, HSQL or Derby))类问题

1:yml 配置 spring:datasource:dynamic:datasource:master:url: jdbc:mysql://192.168.11.50:3306/dsdd?characterEncodingUTF-8&useUnicodetrue&useSSLfalse&tinyInt1isBitfalse&allowPublicKeyRetrievaltrue&serverTimezoneUTCusername: ro…

0401hive入门-hadoop-大数据学习.md

文章目录 1 Hive概述2 Hive部署2.1 规划2.2 安装软件 3 Hive体验4 Hive客户端4.1 HiveServer2 服务4.2 DataGrip 5 问题集5.1 Could not open client transport with JDBC Uri 结语 1 Hive概述 Apache Hive是一个开源的数据仓库查询和分析工具,最初由Facebook开发&…

Direct3D绘制旋转立方体例程

初始化文件见Direct3D的初始化_direct3dcreate9_寂寂寂寂寂蝶丶的博客-CSDN博客 D3DPractice.cpp #include <windows.h> #include "d3dUtility.h" #include <d3dx9math.h>IDirect3DDevice9* Device NULL; IDirect3DVertexBuffer9* VB NULL; IDirect3…

华为云云耀云服务器L实例评测 | 分分钟完成打地鼠小游戏部署

前言 在上篇文章【华为云云耀云服务器L实例评测 | 快速部署MySQL使用指南】中&#xff0c;我们已经用【华为云云耀云服务器L实例】在命令行窗口内完成了MySQL的部署并简单使用。但是后台有小伙伴跟我留言说&#xff0c;能不能用【华为云云耀云服务器L实例】来实现个简单的小游…

Jmeter系列-阶梯加压线程组Stepping Thread Group详解(6)

前言 tepping Thread Group是第一个自定义线程组但&#xff0c;随着版本的迭代&#xff0c;已经有更好的线程组代替Stepping Thread Group了【Concurrency Thread Group】&#xff0c;所以说Stepping Thread Group已经是过去式了&#xff0c;但还是介绍一下 Stepping Thread …

详解Redis之Lettuce实战

摘要 是 Redis 的一款高级 Java 客户端&#xff0c;已成为 SpringBoot 2.0 版本默认的 redis 客户端。Lettuce 后起之秀&#xff0c;不仅功能丰富&#xff0c;提供了很多新的功能特性&#xff0c;比如异步操作、响应式编程等&#xff0c;还解决了 Jedis 中线程不安全的问题。 …

Python 内置函数速查手册(函数大全,带示例)

1. abs() abs() 返回数字的绝对值。 >>> abs(-7) **输出&#xff1a;**7 >>> abs(7) 输出&#xff1a; 7 2. all() all() 将容器作为参数。如果 python 可迭代对象中的所有值都是 True &#xff0c;则此函数返回 True。空值为 False。 >>>…

使用Spring WebSocket实现实时通信功能

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

ElasticSearch详解

目录 一、Elasticsearch是什么&#xff1f; 二、为什么要使用ElasticSearch 2.1 关系型数据库有什么问题&#xff1f; 2.2 ElasticSearch有什么优势&#xff1f; 2.3 ES使用场景 三、ElasticSearch概念、原理与实现 3.1 搜索引擎原理 3.2 Lucene 倒排索引核心原理 倒排…