神经网络的工程基础(零)——PyTorch基础

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
本文涉及到的代码链接如下:regression2chatgpt/ch06_optimizer/gradient_descent.ipynb

本文将介绍PyTorch的基础。它是神经网络领域常用的建模工具。

关于大语言模型的内容,推荐参考这个专栏。

内容大纲

  • 相关说明
  • 一、PyTorch的数据基础:张量(Tensor)
  • 二、张量的基本计算

一、PyTorch的数据基础:张量(Tensor)

工欲善其事,必先利其器。在讨论如何实现梯度下降法之前,首先探讨一下PyTorch这个强大的工具。PyTorch是一种备受欢迎的开源机器学习框架,被广泛用于构建、训练和部署神经网络模型,因具有灵活性、动态计算图和卓越的GPU支持而成为神经网络领域的首选。

PyTorch的基础数据结构是张量。张量的创建方式如程序清单1所示(完整代码)。

程序清单1 张量的创建
 1 |  # 使用tensor封装的函数创建tensor2 |  zeros = torch.zeros(2, 3)3 |  tensor([[0., 0., 0.],4 |          [0., 0., 0.]])5 |  6 |  ones = torch.ones(2, 3)7 |  tensor([[1., 1., 1.],8 |          [1., 1., 1.]])9 |  
10 |  torch.manual_seed(1024)
11 |  random = torch.rand(3, 4)
12 |  tensor([[0.8090, 0.7935, 0.2099, 0.9279],
13 |          [0.8136, 0.7422, 0.4769, 0.4955],
14 |          [0.3602, 0.1178, 0.7852, 0.0228]])
15 |  
16 |  # 从Python对象创建
17 |  data = [[2, 3, 4], [1, 0, 1]]
18 |  t_data = torch.tensor(data)
19 |  tensor([[2, 3, 4],
20 |          [1, 0, 1]])
21 |  
22 |  ## 从numpy对象创建
23 |  import numpy as np
24 |  
25 |  n_data = np.array(data)
26 |  tn_data = torch.from_numpy(n_data)
27 |  tensor([[2, 3, 4],
28 |          [1, 0, 1]])
29 |  
30 |  ## Numpy bridge,也就是说对numpy对象的改变会传导到tensor
31 |  n_data += 1
32 |  torch.all(torch.from_numpy(n_data) == tn_data)
33 |  tensor(True)

张量的形状(Shape)是至关重要的概念,它定义了张量的维度以及每个维度的大小。在实际应用中,可以通过使用一系列函数来改变张量的形状,使其适应不同的运算需求,如程序清单2所示。

程序清单2 改变张量的形状
 1 |  # 增加或减少数据的维度2 |  a = torch.rand(3, 4)  # (3, 4)3 |  ## 增加维度4 |  b = a.unsqueeze(0)    # (1, 3, 4)5 |  ## 减少维度6 |  c = b.squeeze(0)      # (3, 4)7 |  ## 数据相同,但是维度不同8 |  print(torch.all(c.eq(b)))    # tensor(True)9 |  print(c.shape == b.shape)    # False
10 |  
11 |  # 变换tensor形状
12 |  data = torch.tensor(range(0, 10))   # tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
13 |  view1 = data.view(2, 5)
14 |  tensor([[0, 1, 2, 3, 4],
15 |          [5, 6, 7, 8, 9]])
16 |  transpose1 = view1.T
17 |  tensor([[0, 5],
18 |          [1, 6],
19 |          [2, 7],
20 |          [3, 8],
21 |          [4, 9]])
22 |  ## 非毗邻存储的对象不能进行view操作
23 |  print(view1.is_contiguous(), transpose1.is_contiguous()) 
24 |  True False
25 |  ## 下面的操作会报错
26 |  view2 = transpose1.view(1, 10)
  1. 程序清单2的第4—6行使用unsqueeze和squeeze函数来增加或减少张量的维度。需要注意的是,这些操作并不会改变张量实际存储的数据,也不会在实质上改变张量的形状。相反,它们只是在张量的形状中添加或删除一个空的维度。具体的变化可以在第8行和第9行中看到。
  2. 为了改变张量的形状,可以使用view函数,如第12—15行所示。但需要注意的是,view函数只能用在毗邻存储的张量1对象上。非毗邻存储的张量只能使用reshape函数来改变形状。尽管这两个函数在功能上相似,但在计算效率上存在显著差异:相较于 view 函数,reshape 的计算开销要大得多。因此,在实际应用中,最好优先选择使用 view 函数。

二、张量的基本计算

张量的运算分为两种:逐元素操作(Element-Wise Operations)和矩阵乘法,这些计算方法在处理数据和构建神经网络模型时都具有重要作用。程序清单6-3中讨论了这些操作,并介绍了PyTorch中的广播机制(Broadcasting Semantics),它在处理不同形状的张量时起到了重要的作用。

程序清单3 张量的常见运算
 1 |  # 逐元素操作2 |  twos = torch.ones(2, 2) * 23 |  tensor([[2., 2.],4 |          [2., 2.]])5 |  powers = twos ** torch.tensor([[1, 2], [3, 4]])6 |  tensor([[ 2.,  4.],7 |          [ 8., 16.]])8 |  9 |  ## tensor广播,tensor broadcasting
10 |  a = torch.tensor(range(1, 7)).view(2, 3)
11 |  tensor([[1, 2, 3],
12 |          [4, 5, 6]])
13 |  b = torch.tensor(range(1, 4)).view(   3)
14 |  tensor([1, 2, 3])
15 |  print(a * b)
16 |  tensor([[ 1,  4,  9],
17 |          [ 4, 10, 18]])    
18 |  ## 关于广播,更复杂的例子
19 |  a =     torch.ones(4, 1, 3, 2)
20 |  b = a * torch.rand(   5, 1, 2)
21 |  print(b.shape)
22 |  torch.Size([4, 5, 3, 2])
23 |  
24 |  # 矩阵运算
25 |  mat1 = torch.randn(3, 4)    # (3, 4)
26 |  mat2 = torch.randn(4, 5)    # (4, 5)
27 |  re = mat1 @ mat2            # (3, 5)
28 |  ## 矩阵运算的广播
29 |  mat1 = torch.randn(5, 1, 3, 4)   # (5, 1, 3, 4)
30 |  mat2 = torch.randn(   8, 4, 5)   # (   8, 4, 5)
31 |  re = mat1 @ mat2                 # (5, 8, 3, 5)
  1. 逐元素操作要求进行运算的两个张量的形状必须相同,如程序清单3中的第2—7行所示。然而,在实际应用中,常常需要对形状不同的张量进行操作。为此,PyTorch引入了广播机制,它允许在一定条件下对形状不同的张量进行逐元素操作,如第9—22行所示。
  2. 广播机制的流程相对复杂,如图1所示,需要注意几个关键步骤。首先,从后向前逐个比较两个张量的维度;接着,对缺失的维度进行扩充(类似于unsqueeze函数的操作);然后,检查广播规则,即两个张量的各分量要么相等,要么其中一个等于1;最后,复制数据,实现广播操作。
  3. 广播机制不仅适用于逐元素操作,它同样影响着张量的矩阵乘法。不同之处在于,当执行矩阵乘法时,广播机制只会作用于前面的维度,而不涉及最后两维,如第29—31行所示。

图1

图1


  1. 毗邻存储(C Contiguous)是一个与硬件相关的概念。简而言之,毗邻存储意味着数据在内存中是连续存储的,这种存储方式能够显著提升数据的读取和计算速度。张量在内存中的存储细节超出了本书的范围,对此感兴趣的读者可以在PyTorch的官方文档中找到更详细的信息。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/17942.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CUDA学习备份

CUDA项目配置 1.项目属性->配置属性->常规->Windows SDK版本->选实际的版本 2.项目属性->CUDA C/C>Device->修改为对应CUDA型号的算力&#xff0c;例如算力3.5&#xff0c;就设置为compute_35 sm_35 概念&#xff1a; gpuAdd << <1, 1 >>…

EasyExcel实现导入导出

EasyExcel实现导入导出 目录 EasyExcel实现导入导出1、使用场景2、特点3、使用1、使用EasyExcel进行写操作&#xff08;下载Excel&#xff09;1. 在pom文件中添加对应的依赖2. 创建实体类&#xff0c;和excel数据对应3. converter自定义转换器4、性别枚举类 5.普通导出6.多shee…

Linux防火墙(以iptables为例)

目录 Linux配置防火墙1. 引言2. 什么是防火墙3. Linux中的防火墙3.1 iptablesiptables命令参数常用方式&#xff1a;3.1.1 安装iptables3.1.2 配置iptables规则3.1.3 示例一&#xff1a;使用iptables配置防火墙规则4. iptables执行过程 Linux配置防火墙 1. 引言 在互联网时代&…

【从零开始学习RabbitMQ | 第三篇】什么是延迟消息

目录 前言&#xff1a; 延迟消息&#xff1a; 延迟消息实现方式&#xff1a; 死信交换机&#xff1a; 延迟消息插件&#xff1a; 1.基于注解的方式 2.基于Bean的方式 总结&#xff1a; 前言&#xff1a; 在现代软件开发中&#xff0c;异步消息处理已成为构建可扩展、高可…

php爬虫之获取淘宝商品数据

爬取淘宝信息数据 首先需要先导入webdriver 1.from selenium import webdriver webdriver支持主流的浏览器&#xff0c;比如说&#xff1a;谷歌浏览器、火狐浏览器、IE浏览器等等 然后可以创建一个webdriver对象&#xff0c;通过这个对象就可以通过get方法请求网站 1.driver…

学习前端第四十五天(冒泡和捕获、事件委托)

一、冒泡和捕捉 1、冒泡 当一个事件发生在一个元素上&#xff0c;它会首先运行在该元素上的处理程序&#xff0c;然后运行其父元素上的处理程序&#xff0c;然后一直向上到其他祖先上的处理程序 <div class"box" onclick"console.log(1)">box<d…

全身关节活动评估训练系统:提升健康与康复的新科技

随着科技的不断进步&#xff0c;医疗和健身领域也迎来了巨大的变革。其中&#xff0c;全身关节活动评估训练系统作为一种创新的科技产品&#xff0c;正在逐渐改变我们对健康、康复以及健身的认知。本文将深入探讨这一系统的原理、功能、应用以及其对个人健康和社会福祉的潜在影…

闲鱼详情API接口探析

随着互联网的快速发展&#xff0c;我国闲置交易市场逐渐繁荣&#xff0c;闲鱼作为阿里巴巴旗下闲置交易平台&#xff0c;已经成为众多用户的选择。为了方便开发者构建第三方应用&#xff0c;闲鱼提供了详细的API接口&#xff0c;联讯数据将对闲鱼详情API接口进行深入分析&#…

时序数据库InfluxDB面试题和参考答案

目录 InfluxDB如何处理大规模数据集? 如何使用InfluxDB进行实时分析?

你真的懂firewalld吗?不妨看看我的这篇文章

一、firewalld简介 firewalld防火墙是Linux系统上的一种动态防火墙管理工具&#xff0c;它是Red Hat公司开发的&#xff0c;并在许多Linux发行版中被采用。相对于传统的静态防火墙规则&#xff0c;firewalld使用动态的方式来管理防火墙规则&#xff0c;可以更加灵活地适应不同…

TypeScript中的`let`、`const`、`var`区别:变量声明的规范与实践

TypeScript中的let、const、var区别&#xff1a;变量声明的规范与实践 引言 在TypeScript中&#xff0c;变量声明是代码编写的基础部分。let、const、var 是三种用于变量声明的关键字&#xff0c;它们各自有不同的作用域规则和可变性特点。 基础知识 作用域&#xff1a;变量…

ctfhub中的SSRF相关例题(中)

目录 上传文件 gopher协议的工作原理&#xff1a; gopher协议的使用方法&#xff1a; 相关例题: FastCGI协议 FastCGI协议知识点 相关例题&#xff1a; Redis协议 知识点&#xff1a; 相关例题 第一种方法 第二种方法 上传文件 gopher协议的工作原理&#xff1a; …

开箱元宇宙| 探索家乐福如何在The Sandbox 中重新定义零售和可持续发展

有没有想过 The Sandbox 如何与世界上最具代表性的品牌和名人的战略保持一致&#xff1f;在本期的 "开箱元宇宙 "系列中&#xff0c;我们与家乐福团队进行了对话&#xff0c;这家法国巨头率先采用web3技术重新定义零售和可持续发展。 家乐福的用户平均游玩时间为 57 …

QWidget For Android之QDialog中QLineEdit无法编辑问题

项目场景&#xff1a; QWidget For Android 问题描述 QDialog打开对话框时&#xff0c;QLineEdit输入框无法输入 this->setWindowFlags(Qt::FramelessWindowHint | Qt::Tool | Qt::WindowStaysOnTopHint);this->setAttribute(Qt::WA_TranslucentBackground);原因分析&a…

maven部署到私服

方法一:网页上传 1、账号登录 用户名/密码 2、地址 http://自己的ip:自己的端口/nexus 3、查看Repositories列表&#xff0c;选择Public Repositories&#xff0c;确定待上传jar包不在私服中 4、选择3rd party仓库&#xff0c;点击Artifact Upload页签 5、GAV Definition选…

2024上半年软考 考试心得

考试的时候感觉选择题有点偏&#xff0c;很多概念题都不知道是什么&#xff0c;好像没怎么见过&#xff0c;什么拖库洗库&#xff0c;linux权限号不会&#xff0c;python也不确定&#xff0c;但也算顺利&#xff1b;下午题的数据库竟然没考主键外键&#xff0c;我的天哪&#x…

蓝桥杯嵌入式国赛笔记(3):其他拓展板程序设计(温、湿度传感器、光敏电阻等)

目录 1、DS18B20读取 2、DHT11 2.1 宏定义 2.2 延时 2.3 设置引脚输出 2.4 设置引脚输入 2.5 复位 2.6 检测函数 2.7 读取DHT11一个位 2.7.1 数据位为0的电平信号显示 2.7.2 数据位为1的电平信号显示 2.8 读取DHT11一个字节 2.9 DHT11初始化 2.10 读取D…

exe4j --实现把jar包打成exe可执行文件

工具准备 1.Java编辑器&#xff0c;如&#xff1a;idea、eclipse等&#xff0c;下载地址&#xff1a; IntelliJ IDEA: The Capable & Ergonomic Java IDE by JetBrains https://www.jetbrains.com/idea/ 2.exe4j&#xff0c;下载地址&#xff1a; ej-technologies - Java A…

SQL试题使得每个学生 按照姓名的字⺟顺序依次排列 在对应的⼤洲下⾯

学⽣地理信息报告 学校有来⾃亚洲、欧洲和美洲的学⽣。 表countries 数据如下&#xff1a; namecontinentJaneAmericaPascalEuropeXiAsiaJackAmerica 1、编写解决⽅案实现对⼤洲&#xff08;continent&#xff09;列的 透视表 操作&#xff0c;使得每个学生 按照姓名的字⺟顺…

常用批处理命令及批处理文件编写技巧

一常用批处理命令 1.查看命令用法&#xff1a;命令 /? //如&#xff1a;cd /? 2.切换盘符目录&#xff1a;cd /d D:\test 或直接输入 d: //进入上次d盘所在的目录 3.切换目录&#xff1a;cd test 4.清屏:cls 5.“arp -a” //它会列出当前设备缓存中的所有…