PyTorch 基础篇(1):Pytorch 基础

Pytorch 学习开始
入门的材料来自两个地方:

第一个是官网教程:WELCOME TO PYTORCH TUTORIALS,特别是官网的六十分钟入门教程 DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ。

第二个是韩国大神 Yunjey Choi 的 Repo:pytorch-tutorial,代码写得干净整洁。

目的:我是直接把 Yunjey 的教程的 python 代码挪到 Jupyter Notebook 上来,一方面可以看到运行结果,另一方面可以添加注释和相关资料链接。方便后面查阅。

顺便一题,我的 Pytorch 的版本是 0.4.1

  
  1. import torch
  2. print(torch.version)
  
  1. 0.4.1
  
  1. # 包
  2. import torch
  3. import torchvision
  4. import torch.nn as nn
  5. import numpy as np
  6. import torchvision.transforms as transforms

autograd(自动求导 / 求梯度) 基础案例 1

  
  1. # 创建张量(tensors)
  2. x = torch.tensor(1., requires_grad=True)
  3. w = torch.tensor(2., requires_grad=True)
  4. b = torch.tensor(3., requires_grad=True)
  5.  
  6. # 构建计算图( computational graph):前向计算
  7. y = w * x + b # y = 2 * x + 3
  8.  
  9. # 反向传播,计算梯度(gradients)
  10. y.backward()
  11.  
  12. # 输出梯度
  13. print(x.grad) # x.grad = 2
  14. print(w.grad) # w.grad = 1
  15. print(b.grad) # b.grad = 1
  
  1. tensor(2.)
  2. tensor(1.)
  3. tensor(1.)

autograd(自动求导 / 求梯度) 基础案例 2

  
  1. # 创建大小为 (10, 3) 和 (10, 2)的张量.
  2. x = torch.randn(10, 3)
  3. y = torch.randn(10, 2)
  4.  
  5. # 构建全连接层(fully connected layer)
  6. linear = nn.Linear(3, 2)
  7. print ('w: ', linear.weight)
  8. print ('b: ', linear.bias)
  9.  
  10. # 构建损失函数和优化器(loss function and optimizer)
  11. # 损失函数使用均方差
  12. # 优化器使用随机梯度下降,lr是learning rate
  13. criterion = nn.MSELoss()
  14. optimizer = torch.optim.SGD(linear.parameters(), lr=0.01)
  15.  
  16. # 前向传播
  17. pred = linear(x)
  18.  
  19. # 计算损失
  20. loss = criterion(pred, y)
  21. print('loss: ', loss.item())
  22.  
  23. # 反向传播
  24. loss.backward()
  25.  
  26. # 输出梯度
  27. print ('dL/dw: ', linear.weight.grad)
  28. print ('dL/db: ', linear.bias.grad)
  29.  
  30. # 执行一步-梯度下降(1-step gradient descent)
  31. optimizer.step()
  32.  
  33. # 更底层的实现方式是这样子的
  34. # linear.weight.data.sub_(0.01 * linear.weight.grad.data)
  35. # linear.bias.data.sub_(0.01 * linear.bias.grad.data)
  36.  
  37. # 进行一次梯度下降之后,输出新的预测损失
  38. # loss的确变少了
  39. pred = linear(x)
  40. loss = criterion(pred, y)
  41. print(‘loss after 1 step optimization: ‘, loss.item())
  
  1. w: Parameter containing:
  2. tensor([[ 0.5180, 0.2238, -0.5470],
  3. [ 0.1531, 0.2152, -0.4022]], requires_grad=True)
  4. b: Parameter containing:
  5. tensor([-0.2110, -0.2629], requires_grad=True)
  6. loss: 0.8057981729507446
  7. dL/dw: tensor([[-0.0315, 0.1169, -0.8623],
  8. [ 0.4858, 0.5005, -0.0223]])
  9. dL/db: tensor([0.1065, 0.0955])
  10. loss after 1 step optimization: 0.7932316660881042

从 Numpy 装载数据

  
  1. # 创建Numpy数组
  2. x = np.array([[1, 2], [3, 4]])
  3. print(x)
  4.  
  5. # 将numpy数组转换为torch的张量
  6. y = torch.from_numpy(x)
  7. print(y)
  8.  
  9. # 将torch的张量转换为numpy数组
  10. z = y.numpy()
  11. print(z)
  
  1. [[1 2]
  2. [3 4]]
  3. tensor([[1, 2],
  4. [3, 4]])
  5. [[1 2]
  6. [3 4]]

输入工作流(Input pipeline)

  
  1. # 下载和构造CIFAR-10 数据集
  2. # Cifar-10数据集介绍:https://www.cs.toronto.edu/~kriz/cifar.html
  3. train_dataset = torchvision.datasets.CIFAR10(root=’…/…/…/data/’,
  4. train=True,
  5. transform=transforms.ToTensor(),
  6. download=True)
  7.  
  8. # 获取一组数据对(从磁盘中读取)
  9. image, label = train_dataset[0]
  10. print (image.size())
  11. print (label)
  12.  
  13. # 数据加载器(提供了队列和线程的简单实现)
  14. train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
  15. batch_size=64,
  16. shuffle=True)
  17.  
  18. # 迭代的使用
  19. # 当迭代开始时,队列和线程开始从文件中加载数据
  20. data_iter = iter(train_loader)
  21.  
  22. # 获取一组mini-batch
  23. images, labels = data_iter.next()
  24.  
  25.  
  26. # 正常的使用方式如下:
  27. for images, labels in train_loader:
  28. # 在此处添加训练用的代码
  29. pass
  
  1. Files already downloaded and verified
  2. torch.Size([3, 32, 32])
  3. 6

自定义数据集的 Input pipeline

  
  1. # 构建自定义数据集的方式如下:
  2. class CustomDataset(torch.utils.data.Dataset):
  3. def init(self):
  4. # TODO
  5. # 1. 初始化文件路径或者文件名
  6. pass
  7. def getitem(self, index):
  8. # TODO
  9. # 1. 从文件中读取一份数据(比如使用nump.fromfile,PIL.Image.open)
  10. # 2. 预处理数据(比如使用 torchvision.Transform)
  11. # 3. 返回数据对(比如 image和label)
  12. pass
  13. def len(self):
  14. # 将0替换成数据集的总长度
  15. return 0
  16. # 然后就可以使用预置的数据加载器(data loader)了
  17. custom_dataset = CustomDataset()
  18. train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
  19. batch_size=64,
  20. shuffle=True)
  21.  
  22. 预训练模型
  
  1. # 下载并加载预训练好的模型 ResNet-18
  2. resnet = torchvision.models.resnet18(pretrained=True)
  3.  
  4.  
  5. # 如果想要在模型仅对Top Layer进行微调的话,可以设置如下:
  6. # requieres_grad设置为False的话,就不会进行梯度更新,就能保持原有的参数
  7. for param in resnet.parameters():
  8. param.requires_grad = False
  9. # 替换TopLayer,只对这一层做微调
  10. resnet.fc = nn.Linear(resnet.fc.in_features, 100) # 100 is an example.
  11.  
  12. # 前向传播
  13. images = torch.randn(64, 3, 224, 224)
  14. outputs = resnet(images)
  15. print (outputs.size()) # (64, 100)
  
  1. torch.Size([64, 100])

保存和加载模型

  
  1. # 保存和加载整个模型
  2. torch.save(resnet, ‘model.ckpt’)
  3. model = torch.load(‘model.ckpt’)
  4.  
  5. # 仅保存和加载模型的参数(推荐这个方式)
  6. torch.save(resnet.state_dict(), ‘params.ckpt’)
  7. resnet.load_state_dict(torch.load(‘params.ckpt’))

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/204053.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Transformer中的layer norm(包含代码解释)

在transformer中存在add&norm操作,add操作很简单,就是把注意力矩阵和原来的矩阵相加,也就是残差链接,可以有效减少梯度消失。如下图所示,为layer norm的解释图,可以看出layer norm是针对一个token来做的…

接口自动化测试总结,接口鉴权+加密与解密+数据库操作/断言...

前言 1、接口鉴权的多种方式 1)后端接口鉴权常用方法 cookie: 携带身份信息请求认证 之后的每次请求都携带cookie信息,cookie记录在请求头中 token: 携带身份信息请求认证 之后的每次请求都携带token认证信息 可能记录在请求头…

Java随记

Java java保留两位小数 1、使用String.format()方法: public static void stringFormatdecimalFormatKeepTwoDecimalPlaces(){double number 3.1415926;String result String.format("%.2f", number);System.out.println(result);}输出:3…

Large Language Models areVisual Reasoning Coordinators

目录 一、论文速读 1.1 摘要 1.2 论文概要总结 二、论文精度 2.1 论文试图解决什么问题? 2.2 论文中提到的解决方案之关键是什么? 2.3 用于定量评估的数据集是什么?代码有没有开源? 2.4 这篇论文到底有什么贡献&#xff1…

第十五章 : Spring Boot 集成MyBatis 多种方式

第十五章 : Spring Boot 集成MyBatis 方式 前言 本章知识重点:Spring Boot集成MyBatis的两种方式:注解方式和配置文件集成方式,重点推荐一款脚手架工具-mybatis-plus3以及如何在Idea中集成与应用;大大提高了开发效率,代码更加规范和简洁。 Spring Boot数据访问概述 在…

振弦采集仪助力岩土工程质量控制

振弦采集仪助力岩土工程质量控制 随着工程建设规模越来越大,建筑结构的安全性和稳定性越来越成为人们所关注的焦点。岩土工程在工程建设中占据着非常重要的地位,岩土工程质量控制更是至关重要。而振弦采集仪作为一种先进的检测设备,正得到越…

[PyTorch][chapter 5][李宏毅深度学习][Classification]

前言: 这章节主要讲解常用的分类器原理.分类主要是要找到一个映射函数 比如垃圾邮件分类 : c0, 垃圾邮件 c1 正常邮件 主要应用场景: 垃圾邮件分类,手写数字识别,金融信用评估. 这里面简单了解一下,很少用 目录: 1: …

还记得当初自己为什么选择计算机?

还记得当初自己为什么选择计算机? 当初你问我为什么选择计算机,我笑着回答:“因为我梦想成为神奇的码农!我想像编织魔法一样编写程序,创造出炫酷的虚拟世界!”谁知道,我刚入门的那天&#xff0…

离线数仓构建案例一

数据采集 日志数据(文件)到Kafka 自己写个程序模拟一些用户的行为数据,这些数据存在一个文件夹中。 接着使用flume监控采集这些文件,然后发送给kafka中待消费。 1、flume采集配置文件 监控文件将数据发给kafka的flume配置文件…

STM32——定时器Timer

定时器工作原理 软件定时 缺点:不精确、占用 CPU 资源 void Delay500ms() //11.0592MHz {unsigned char i, j, k;_nop_();i 4;j 129;k 119;do{do{while (--k);} while (--j);} while (--i); } 使用精准的时基,通过硬件的方式,实现定时功…

Linux---访问NFS存储及自动挂载

本章主要介绍NFS客户端的使用 创建NFS服务器并通过NFS共享一个目录在客户端上访问NFS共享的目录自动挂载的配置和使用 访问NFS存储 前面介绍了本地存储,本章就来介绍如何使用网络上的存储设备。NFS即网络文件系统, 所实现的是 Linux 和 Linux 之间的共…

TypeScript中泛型函数

一.概览 此前,对泛型有了整体的概览,详见TypeScript中的泛型,后面的系列会详细地介绍TypeScript的泛型。此篇文章主要介绍泛型函数 二. 泛型函数 泛型是类型不明确的数据类型,在定义时,接收泛指的数据类型&#xff…

易点易动:颠覆固定资产用量管理,实现高效精准的企业固定资产管理

固定资产用量管理是企业日常运营中不可或缺的一环。然而,传统的人工管理方式面临着时间成本高、数据不准确、难以监控等问题。为了解决这些挑战,易点易动应运而生,它是一款先进的资产管理系统,能够帮助企业实现高效精准的固定资产…

【Java项目管理工具】Maven

Maven 文章目录 Maven一、简介二、安装和配置三、GAVP四、IDEA Maven Java Web工程五、插件、命令、生命周期六、依赖配置七、构建配置八、依赖传递与依赖冲突九、Maven工程继承和聚合关系9.1 工程继承关系9.2 工程聚合关系 十、Maven私服10.1 Nexus下载安装10.2 Nexus上的各种…

案例054:基于微信的追星小程序

文末获取源码 开发语言:Java 框架:SSM JDK版本:JDK1.8 数据库:mysql 5.7 开发软件:eclipse/myeclipse/idea Maven包:Maven3.5.4 小程序框架:uniapp 小程序开发软件:HBuilder X 小程序…

linux的权限741

741权限 在 Linux 中,文件和目录的权限由三组权限来定义,分别是所有者(Owner)、所属组(Group)和其他用户(Others)。每一组权限又分为读(Read)、写&#xff0…

c++函数模板STL详解

函数模板 函数模板语法 所谓函数模板,实际上是建立一个通用函数,其函数类型和形参类型不具体指定,用一个虚拟的类型来代表。这个通用函数就称为函数模板。 凡是函数体相同的函数都可以用这个模板来代替,不必定义多个函数&#xf…

Java安全之Commons Collections5

CC5分析 import org.apache.commons.collections.Transformer; import org.apache.commons.collections.functors.ChainedTransformer; import org.apache.commons.collections.functors.ConstantTransformer; import org.apache.commons.collections.functors.InvokerTransfo…

基于ssm绿色农产品推广应用网站论文

摘 要 21世纪的今天,随着社会的不断发展与进步,人们对于信息科学化的认识,已由低层次向高层次发展,由原来的感性认识向理性认识提高,管理工作的重要性已逐渐被人们所认识,科学化的管理,使信息存…

Cloudways和SiteGround哪个更好?

当提及WordPress托管服务提供商时,人们常常会拿Cloudways和SiteGround做比较。Cloudways作为备受欢迎的品牌,而SiteGround则是业界的老牌巨头。它们之间主要的区别在于服务范围。SiteGround提供广泛的托管服务,包括Web托管、WordPress托管、W…