Pytorch深度学习实践笔记4(b站刘二大人)

🎬个人简介:一个全栈工程师的升级之路!
📋个人专栏:pytorch深度学习
🎀CSDN主页 发狂的小花
🌄人生秘诀:学习的本质就是极致重复!

视频来自【b站刘二大人】

1 反向传播


Back propagation (BP),训练神经网络的目标是优化代价函数cost,使得cost找到以一个全局或者局部最优值。让cost尽可能的接近0,这样得到的weights和bias是最好的,由于需要不断的调整参数让cost收敛,cost在梯度的相反反向下降最快,所以提出了BP算法,就是来计算weights和bias的梯度(偏导数的,加速训练时的收敛速度,避免无效的训练
反向传播求梯度用到了链式求导,很好理解,高中就学习过了。

  • 反向传播的优点:尽力用一次前向传播和一次反向传播,就同时计算出所有参数的偏导数。 反向传播计算量和前向传播差不多,并且有效利用前向传播过程中的计算结果,前向传播的主要计算量 在 权重矩阵和input vector的乘法计算, 反向传播则主要是 矩阵和input vector 的转置的乘法计

2 链式求导

 

神经网络反向传播理解_反向传播的作用-CSDN博客​


3 计算图


计算图可以减轻网络构建的难度,以前需要为每一个神经网络写反向传播算法。
(1)计算图为有向无环图
(2)Pytorch为动态计算图,Tensorflow为静态计算图,后来也改进支持动态计算图
(3)Pytorch的动态计算图,为了节约内存,一轮迭代完后计算图就被在内存释放,因此每次都需要构建新的计算图,计算图代表程序中变量之间的关系
(4)pytorch计算图中,只有两种元素:数据(Tensor)和运算。tensor可以分为两种:叶子节点和非叶子节点。使用backward()函数反向传播计算tensor的梯度时,并不计算所有tensor的梯度,而是只计算满足这几个条件的tensor的梯度:1.类型为叶子节点、2.requires_grad=True、3.依赖该tensor的所有tensor的requires_grad=True。
自己定义的tensor中,requires_grad属性默认是False,而神经网络中的权重w的tensor中requires_grad属性默认为True。
(5)autograd包提供Tensor所有操作的自动求导方法。
torch.Tensor是这个包里面最重要的类。如果设置了requires_grad为True,那么它开始追踪所有在它上面的操作。当你完成了计算,可以使用调用backward(),回自动计算所有的梯度。然后这个tensor的梯度会被自动累积到grad属性上。

pytorch计算图_pytorch 计算图-CSDN博客​

Pytorch快速入门系列---(二)动态计算图、自动微分、torch.nn模块_pytorch计算图训练-CSDN博客​

blog.csdn.net/qq_42681787/article/details/129394170​编辑


4 tensor




Tensor 中指定需要计算梯度,requires_grad = True




w是Tensor(张量类型),Tensor中包含data和grad,data和grad也是Tensor。grad初始为None,调用l.backward()方法后w.grad为Tensor,故更新w.data时需使用w.grad.data。如果w需要计算梯度,那构建的计算图中,跟w相关的tensor都默认需要计算梯度。
调用backward()会将所有的需要计算梯度的都求出来,存储待对应的w.grad.data中。
 

  • torch.tensor() 和 torch.Tensor():

【PyTorch】Tensor和tensor的区别_pytorch tensor tensor-CSDN博客​

torch.FloatTensor和torch.Tensor、torch.tensor-CSDN博客​

  • torch.FloatTensor()


5 代码
 

import matplotlib.pyplot as plt
import torch
import numpy as np# SGD随机梯度下降x_data = np.arange(1.0,200.0,1.0)
y_data = np.arange(2.0,400.0,2.0)def forward(x,w):return x * wdef loss(x,y_true,w):y_pred = forward(x,w)return (y_pred-y_true)**2w = torch.Tensor([1.0])
w.requires_grad = Truelr = 0.00001epoch_list = []
loss_list = []print("Before train 4: ",forward(torch.Tensor([400.]),w).data.item())
for epoch in range(100):seed = np.random.choice(range(len(x_data)))loss_val = loss(x_data[seed],y_data[seed],w)loss_val.backward()w.data -= lr*w.grad.dataw.grad.data.zero_()print("epoch: ",epoch," loss: ",loss_val.data.item()," w: ",w.data.item())epoch_list.append(epoch)loss_list.append(loss_val.data.item())if (loss_val < 1e-7):break
print("After train 4: ",forward(torch.Tensor([400.]),w).data.item())plt.plot(epoch_list,loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig("./data/pytorch3.png")

import numpy as np
import matplotlib.pyplot as plt
import torch# 假设 3 * x^2 + 2 * x + 2 
x_data = [1.0,2.0,3.0]
y_data = [7.0,18.0,35.0]def forward(x,w1,w2,b):return (w1 * x **2 + w2 *x +b)def loss(x,y_true,w1,w2,b):y_pred = forward(x,w1,w2,b)return (y_pred-y_true)**2w1 = torch.Tensor([1.0])#初始权值
w1.requires_grad = True#计算梯度,默认是不计算的
w2 = torch.Tensor([1.0])
w2.requires_grad = True
b = torch.Tensor([1.0])
b.requires_grad = Truelr = 0.001epoch_list = []
loss_list = []print("Before train 4: ",forward(torch.Tensor([4.]),w1,w2,b).data.item())
for epoch in range(10000):seed = np.random.choice(range(len(x_data)))loss_val = loss(x_data[seed],y_data[seed],w1,w2,b)loss_val.backward()w1.data -= lr*w1.grad.dataw2.data -= lr*w2.grad.datab.data -= lr*b.grad.dataw1.grad.data.zero_()w2.grad.data.zero_()b.grad.data.zero_()print("epoch: ",epoch," loss: ",loss_val.data.item()," w1: ",w1.data.item()," w2: ",w2.data.item()," b: ",b.data.item())epoch_list.append(epoch)loss_list.append(loss_val.data.item())if (loss_val < 1e-7):break
print("After train 4: ",forward(torch.Tensor([4.]),w1,w2,b).data.item())plt.plot(epoch_list,loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig("./data/pytorch3_1.png")

🌈我的分享也就到此结束啦🌈
如果我的分享也能对你有帮助,那就太好了!
若有不足,还请大家多多指正,我们一起学习交流!
📢未来的富豪们:点赞👍→收藏⭐→关注🔍,如果能评论下就太惊喜了!
感谢大家的观看和支持!最后,☺祝愿大家每天有钱赚!!!欢迎关注、关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/19269.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FFMPEG+ANativeWinodow渲染播放视频

前言 学习音视频开发&#xff0c;入门基本都得学FFMPEG&#xff0c;按照目前互联网上流传的学习路线&#xff0c;FFMPEGANativeWinodow渲染播放视频属于是第一关卡的Boss&#xff0c;简单但是关键。这几天写了个简单的demo&#xff0c;可以比较稳定进行渲染播放&#xff0c;便…

【运维】Linux 端口管理实用指南,扫描端口占用

在 Linux 系统中&#xff0c;你可以使用以下几种方法来查看当前被占用的端口&#xff0c;并检查 7860 到 7870 之间的端口&#xff1a; 推荐命令&#xff1a; sudo lsof -i :7860-7870方法一&#xff1a;使用 netstat 命令 sudo netstat -tuln | grep :78[6-7][0-9]这个命令…

全球痛风年轻化趋势明显 别嘌醇制剂需求增多

全球痛风年轻化趋势明显 别嘌醇制剂需求增多 别嘌醇制剂包括片剂和缓释胶囊两种剂型&#xff0c;别嘌醇片剂吸收快&#xff0c;可能会出现胃肠道反应&#xff1b;别嘌醇缓释胶囊释放比较缓慢&#xff0c;作用更持久&#xff0c;对胃肠道损害比较小。别嘌醇制剂是抑制尿酸合成的…

【CSDN唯一】Python解析.dwg格式文件信息提取

目录 一、装环境1、下载 ODAFileConverter2、安装 ODAFileConverter(1)、安装 gdebi 来处理依赖关系(2)、使用 gdebi 安装 DEB 包 3、解决 libxcb库问题(1)、安装依赖(2)、确认 libxcb-util.so.1 是否存在(3)、创建符号链接 4、安装python包 二、上代码 一、装环境 这里搞的人…

Java内存空间

Java内存空间划分 Java虚拟机在执行Java程序的过程中会把他管理的内存划分为若干个不同的数据区域&#xff0c;如图所示1.7和1.8两个版本的Java内存空间划分。 JDK1.7: JDK1.8: 线程私有&#xff1a; 程序计数器虚拟机栈本地方法栈 线程共享 &#xff1a; 堆方法区直接内…

股价飙升:AI PC大变革,联想的“联想时刻”正在缔造?

按照产业的传导逻辑&#xff0c;在颠覆式技术到来之时&#xff0c;当引发这场变革的最核心技术及产品真正进入了产品化、商业化阶段&#xff0c;此时直触需求端的终端厂商&#xff0c;其成长性估算将得到市场的重新预估。 眼下AI PC之于联想就是如此。 5月27日&#xff0c;联…

SPI协议的基本介绍

1. 基本介绍 SPI&#xff08;Serial Peripheral Interface&#xff0c;串行外设接口&#xff09;是一种高速、全双工、同步的通信协议&#xff0c;主要用于微控制器和各种外部硬件或外设之间的通信&#xff0c;例如传感器、SD卡、液晶显示屏等。 SPI协议由四根线组成&#xff1…

mysql中InnoDB的统计数据

大家好。我们知道&#xff0c;mysql中存在许多的统计数据&#xff0c;比如通过SHOW TABLE STATUS 可以看到关于表的统计数据&#xff0c;通过SHOW INDEX可以看到关于索引的统计数据&#xff0c;那么这些统计数据是怎么来的呢&#xff1f;它们是以什么方式收集的呢&#xff1f;今…

正方形 II

描述 正方形是特殊的平行四边形之一。即有一组邻边相等&#xff0c;并且有一个角是直角的平行四边形称为正方形。设a为正方形的边长&#xff0c;s为正方形的面积&#xff0c;c为正方形的周长。 输入 一行&#xff0c;包含一个正整数a&#xff0c;表示正方形的边长。 输出 …

vscode:如何解决”检测到include错误,请更新includePath“

vscode:如何解决”检测到include错误&#xff0c;请更新includePath“ 前言解决办法1 获取includePath路径2 将includePath路径添加到指定文件3 保存 前言 配置vscode是出现如下错误&#xff1a; 解决办法 1 获取includePath路径 通过cmd打开终端&#xff0c;输入如下指令&a…

【第8章】SpringBoot之单元测试

文章目录 前言一、准备1. 引入库2. 目录结构 二、测试代码1. SpringBoot3ApplicationTests2.测试结果 总结 前言 单元测试是SpringBoot项目的一大利器&#xff0c;在SpringBoot我们可以很轻松地测试我们的接口。 一、准备 1. 引入库 <dependency><groupId>org.s…

Java基于saas模式云MES制造执行系统源码Spring Boot + Hibernate Validation什么是MES系统?

Java基于saas模式云MES制造执行系统源码Spring Boot Hibernate Validation 什么是MES系统&#xff1f; MES制造执行系统&#xff0c;通过互联网技术实现从订单下达到产品完成的整个生产过程进行优化管理。能有效地对生产现场的流程进行智能控制&#xff0c;防错防呆防漏&…

大模型时代的具身智能系列专题(五)

stanford宋舒然团队 宋舒然是斯坦福大学的助理教授。在此之前&#xff0c;他曾是哥伦比亚大学的助理教授&#xff0c;是Columbia Artificial Intelligence and Robotics Lab的负责人。他的研究聚焦于计算机视觉和机器人技术。本科毕业于香港科技大学。 主题相关作品 diffusio…

用Python编写一个开放端口扫描脚本

现在的Ai是真的好用&#xff0c;下面是我编写的开放端口扫描脚本&#xff1a; # coding&#xff1a;utf-8 # 时间&#xff1a;2024/5/27 上午12:15 # 红客技术网&#xff1a;blog.hongkewang.cnimport socket# 设置目标IP地址 ip input("请输入需要扫描端口的IP&#xf…

生成 SSH 证书和私钥

生成 SSH 证书和私钥的过程通常涉及使用 ssh-keygen 命令。以下是生成 SSH 证书和私钥的步骤&#xff1a; 打开终端。 输入 ssh-keygen 命令并按回车。 根据提示设置文件保存位置和对证书的加密密码&#xff08;可选&#xff09;。 示例代码&#xff1a; ssh-keygen -t rs…

hashmap 插入1万条数据会有什么影响

在 Java 中&#xff0c;HashMap 是基于哈希表的 Map 接口的非同步实现。当你向 HashMap 中插入大量数据&#xff0c;如 1 万条数据时&#xff0c;会涉及到以下几个方面的影响&#xff1a; 1. 性能 初始插入速度&#xff1a;通常&#xff0c;HashMap 的插入操作非常快&#xf…

RestTemplate使用详解

文章目录 简介基本操作uri参数传递json参数与header参数设置form-dataexchange复杂类型处理上传文件下载文件 简介 对于http请求之前一直用apache的httpclient&#xff0c;已经习惯了&#xff0c;特别是使用fluent之后&#xff0c;更加方便了。 所以一直没有怎么太过关注Rest…

C 语言实例 - 表格形式输出数据

将 1~100 的数据以 10x10 矩阵格式输出。 #include <stdio.h>int main() {int i, j, count;for(i 1; i < 10; i) {for(j i; j <100; j 10 )printf(" %3d", j);printf("\n");}return 0; }运行结果&#xff1a; 1 11 21 31 41 51 61 …

数据库内核-基础知识

常用索引&#xff1a; 介绍&#xff1a; 哈希表&#xff1a;数组加链表&#xff0c;取字段Hash值做Key,B树&#xff1a; 树形结构&#xff0c;排序后N分查找B树&#xff1a; 树形结构&#xff0c;仅叶子结点存放数据跳表索引&#xff1a;链表链表&#xff0c;相当于一级链…

【YashanDB知识库】kettle从DM8的number类型同步到YashanDB的varchar类型,存入是科学计数法形式的数据

【标题】kettle从DM8的number类型同步到YashanDB的varchar类型&#xff0c;存入是科学计数法形式的数据 【问题分类】数据导入导出 【关键字】数据同步&#xff0c;number类型&#xff0c;科学计数法 【问题描述】客户查询不到准确数据&#xff0c;只看到科学计数法展示的字…