Pytorch深度学习实践笔记4

🎬个人简介:一个全栈工程师的升级之路!
📋个人专栏:pytorch深度学习
🎀CSDN主页 发狂的小花
🌄人生秘诀:学习的本质就是极致重复!

视频来自【b站刘二大人】

1 反向传播


Back propagation (BP),训练神经网络的目标是优化代价函数cost,使得cost找到以一个全局或者局部最优值。让cost尽可能的接近0,这样得到的weights和bias是最好的,由于需要不断的调整参数让cost收敛,cost在梯度的相反反向下降最快,所以提出了BP算法,就是来计算weights和bias的梯度(偏导数的,加速训练时的收敛速度,避免无效的训练
反向传播求梯度用到了链式求导,很好理解,高中就学习过了。

  • 反向传播的优点:尽力用一次前向传播和一次反向传播,就同时计算出所有参数的偏导数。 反向传播计算量和前向传播差不多,并且有效利用前向传播过程中的计算结果,前向传播的主要计算量 在 权重矩阵和input vector的乘法计算, 反向传播则主要是 矩阵和input vector 的转置的乘法计

2 链式求导

 

神经网络反向传播理解_反向传播的作用-CSDN博客​


3 计算图


计算图可以减轻网络构建的难度,以前需要为每一个神经网络写反向传播算法。
(1)计算图为有向无环图
(2)Pytorch为动态计算图,Tensorflow为静态计算图,后来也改进支持动态计算图
(3)Pytorch的动态计算图,为了节约内存,一轮迭代完后计算图就被在内存释放,因此每次都需要构建新的计算图,计算图代表程序中变量之间的关系
(4)pytorch计算图中,只有两种元素:数据(Tensor)和运算。tensor可以分为两种:叶子节点和非叶子节点。使用backward()函数反向传播计算tensor的梯度时,并不计算所有tensor的梯度,而是只计算满足这几个条件的tensor的梯度:1.类型为叶子节点、2.requires_grad=True、3.依赖该tensor的所有tensor的requires_grad=True。
自己定义的tensor中,requires_grad属性默认是False,而神经网络中的权重w的tensor中requires_grad属性默认为True。
(5)autograd包提供Tensor所有操作的自动求导方法。
torch.Tensor是这个包里面最重要的类。如果设置了requires_grad为True,那么它开始追踪所有在它上面的操作。当你完成了计算,可以使用调用backward(),回自动计算所有的梯度。然后这个tensor的梯度会被自动累积到grad属性上。

pytorch计算图_pytorch 计算图-CSDN博客​

Pytorch快速入门系列---(二)动态计算图、自动微分、torch.nn模块_pytorch计算图训练-CSDN博客​

blog.csdn.net/qq_42681787/article/details/129394170​编辑


4 tensor




Tensor 中指定需要计算梯度,requires_grad = True




w是Tensor(张量类型),Tensor中包含data和grad,data和grad也是Tensor。grad初始为None,调用l.backward()方法后w.grad为Tensor,故更新w.data时需使用w.grad.data。如果w需要计算梯度,那构建的计算图中,跟w相关的tensor都默认需要计算梯度。
调用backward()会将所有的需要计算梯度的都求出来,存储待对应的w.grad.data中。
 

  • torch.tensor() 和 torch.Tensor():

【PyTorch】Tensor和tensor的区别_pytorch tensor tensor-CSDN博客​

torch.FloatTensor和torch.Tensor、torch.tensor-CSDN博客​

  • torch.FloatTensor()


5 代码
 

import matplotlib.pyplot as plt
import torch
import numpy as np# SGD随机梯度下降x_data = np.arange(1.0,200.0,1.0)
y_data = np.arange(2.0,400.0,2.0)def forward(x,w):return x * wdef loss(x,y_true,w):y_pred = forward(x,w)return (y_pred-y_true)**2w = torch.Tensor([1.0])
w.requires_grad = Truelr = 0.00001epoch_list = []
loss_list = []print("Before train 4: ",forward(torch.Tensor([400.]),w).data.item())
for epoch in range(100):seed = np.random.choice(range(len(x_data)))loss_val = loss(x_data[seed],y_data[seed],w)loss_val.backward()w.data -= lr*w.grad.dataw.grad.data.zero_()print("epoch: ",epoch," loss: ",loss_val.data.item()," w: ",w.data.item())epoch_list.append(epoch)loss_list.append(loss_val.data.item())if (loss_val < 1e-7):break
print("After train 4: ",forward(torch.Tensor([400.]),w).data.item())plt.plot(epoch_list,loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig("./data/pytorch3.png")

import numpy as np
import matplotlib.pyplot as plt
import torch# 假设 3 * x^2 + 2 * x + 2 
x_data = [1.0,2.0,3.0]
y_data = [7.0,18.0,35.0]def forward(x,w1,w2,b):return (w1 * x **2 + w2 *x +b)def loss(x,y_true,w1,w2,b):y_pred = forward(x,w1,w2,b)return (y_pred-y_true)**2w1 = torch.Tensor([1.0])#初始权值
w1.requires_grad = True#计算梯度,默认是不计算的
w2 = torch.Tensor([1.0])
w2.requires_grad = True
b = torch.Tensor([1.0])
b.requires_grad = Truelr = 0.001epoch_list = []
loss_list = []print("Before train 4: ",forward(torch.Tensor([4.]),w1,w2,b).data.item())
for epoch in range(10000):seed = np.random.choice(range(len(x_data)))loss_val = loss(x_data[seed],y_data[seed],w1,w2,b)loss_val.backward()w1.data -= lr*w1.grad.dataw2.data -= lr*w2.grad.datab.data -= lr*b.grad.dataw1.grad.data.zero_()w2.grad.data.zero_()b.grad.data.zero_()print("epoch: ",epoch," loss: ",loss_val.data.item()," w1: ",w1.data.item()," w2: ",w2.data.item()," b: ",b.data.item())epoch_list.append(epoch)loss_list.append(loss_val.data.item())if (loss_val < 1e-7):break
print("After train 4: ",forward(torch.Tensor([4.]),w1,w2,b).data.item())plt.plot(epoch_list,loss_list)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig("./data/pytorch3_1.png")

🌈我的分享也就到此结束啦🌈
如果我的分享也能对你有帮助,那就太好了!
若有不足,还请大家多多指正,我们一起学习交流!
📢未来的富豪们:点赞👍→收藏⭐→关注🔍,如果能评论下就太惊喜了!
感谢大家的观看和支持!最后,☺祝愿大家每天有钱赚!!!欢迎关注、关注!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/14347.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

信息化项目交付验收流程管理办法

项目交付验收流程制度 管理办法 (执行版) (文件编号: ) 编制: 审核: 批准: 版本: 生效日期: 管理办法概述 制定目的为了保证公司在建项目交付验收工作事项的顺利开展,保证交付验收进度及…

创新力作 焕新首发丨捷顺科技·捷曜系列智慧停车新品全新上市

2024捷顺科技智慧停车全家族新品全面上市 全新外观、全新特性、全新体验 新控制机、新道闸、新超眸相机... 每款新品都有哪些功能亮点 带您一探究竟

解决vue3 vite打包报Root file specified for compilation问题

解决方法&#xff1a; 修改package.json打包命令 把 "build": "vue-tsc --noEmit && vite build" 修改为 "build": "vite build" 就可以了 另外关于allowJs这个问题&#xff0c;在tsconfig.json文件中配置"allowJs&qu…

C++入门:从C语言到C++的过渡(1)

目录 1.什么是C 2.C的标准库 3.命名空间 3.1为什么要存在命名空间 3.2命名空间的定义 3.3命名空间的使用 3.3.1域作用限定符 3.3.2using关键字引入某个成员 3.3.3using关键字引入命名空间名称 3.4命名空间的嵌套 3.5命名空间的合并 4.C中的输入与输出 1.什么是C C&am…

mysql binlog统一恢复误删数据库、表、数据(没有任何备份)

先将mysql文件夹中的my.ini进行设置 在 [mysqld]下边加上 # mysql-bin 是日志的基本名或前缀名&#xff0c;最后生成的日志文件是mysql-bin.000001类似&#xff0c;重启mysql数字会递增 log_binmysql-bin #binlog格式&#xff0c;statement&#xff0c;row&#xff0c;mixed可…

Reactor设计模式

Reactor设计模式 Reactor模式称为反应器模式或应答者模式&#xff0c;是基于事件驱动的设计模式&#xff0c;拥有一个或多个并发输入源&#xff0c;有一个服务处理器和多个请求处理器&#xff0c;服务处理器会同步的将输入的请求事件以多路复用的方式分发给相应的请求处理器。…

前端自动将 HTTP 请求升级为 HTTPS 请求

前端将HTTP请求升级为HTTPS请求有两种方式&#xff1a; 一、index.html 中插入meta 直接在首页 index.html 的 head 中加入一条 meta 即可&#xff0c;如下所示&#xff1a; <meta http-equiv"Content-Security-Policy" content"upgrade-insecure-requests&…

树洞陪聊系统源码/陪聊/陪玩/树洞/陪陪/公众号开发/源码交付/树洞系统源码

独立版本源码交付&#xff0c;自研UI和前后端代码 平台自带店员&#xff0c;无需自主招募&#xff0c;搭建直接运营 支持三方登录&#xff0c;官方支付、虎皮椒、易支付/码支付 支持首单体验、盲盒订单、指定下单等多个模式 支持钱包预充值、店员收藏、订单评价等功能 支持…

AI日报:讯飞星火Lite API永久免费;李开复称大模型疯狂降价是双输;AI特效末日滤镜抖音爆火;AI音乐Suno 融资1.25亿美元

欢迎来到【AI日报】栏目!这里是你每天探索人工智能世界的指南&#xff0c;每天我们为你呈现AI领域的热点内容&#xff0c;聚焦开发者&#xff0c;助你洞悉技术趋势、了解创新AI产品应用。 新鲜AI产品点击了解&#xff1a;AIbase - 智能匹配最适合您的AI产品和网站 1、科大讯飞…

can设备调试 - linux driver

这篇文章主要介绍can设备的调试相关信息&#xff0c;不具体介绍驱动的实现。 如果驱动写完&#xff0c;对can设备进行验证&#xff0c;可能会出现很多不可预见的问题。下面说说验证步骤 验证can设备可以使用工具can-utils。这个工具包中会有cansend candump等程序。可以直接通…

系统架构师考试(十)

SaaS为在线客服 PaaS为二次开发&#xff0c;比如低代码平台 IaaS 硬件开发 B 是基础设施作为服务 软件架构的概念 架构风格 数据流风格 网络报文是在计算机网络中通过网络传输的数据单元&#xff0c;它是网络通信的基本单位。网络报文包含了发送方和接收方之间传输的数据&…

『网络攻防和AI安全之家』星球正式运营及CSDN安全知识汇总,欢迎广大博友加入

“今天是Eastmount的安全星球 —— 『网络攻防和AI安全之家』正式创建和运营的日子&#xff0c;该星球目前主营业务为 安全零基础答疑、安全技术分享、AI安全技术分享、AI安全论文交流、威胁情报每日推送、网络攻防技术总结、系统安全技术实战、面试求职、安全考研考博、简历修…

计算机操作系统核心组件

我是荔园微风&#xff0c;作为一名在IT界整整25年的老兵&#xff0c;今天给大家讲讲操作系统。 操作系统核心组件 用户借助于一个或多个应用程序与操作系统进行交互&#xff0c;常常是通过一个称为shell的特殊应用程序进行的&#xff0c;shell也叫作命令解释器。105今天的大多…

Postgresql源码(130)ExecInterpExpr转换为IR的流程

相关 《Postgresql源码&#xff08;127&#xff09;投影ExecProject的表达式执行分析》 《Postgresql源码&#xff08;128&#xff09;深入分析JIT中的函数内联llvm_inline》 《Postgresql源码&#xff08;129&#xff09;JIT函数中如何使用PG的类型llvmjit_types》 表达式计算…

Java设计模式 _行为型模式_迭代器模式

一、迭代器模式 1、迭代器模式 迭代器模式&#xff08;Iterator Pattern&#xff09;是一种行为型设计模式&#xff0c;用于顺序访问集合对象的元素&#xff0c;不需要关心集合对象的底层表示。如&#xff1a;java中的Iterator接口就是这个工作原理。 2、实现思路 &#xff0…

tomcat jdbc连接池的默认配置配置方案

MySQL 5.0 以后针对超长时间数据库连接做了一个处理&#xff0c;即一个数据库连接在无任何操作情况下过了 8 个小时后(MySQL 服务器默认的超时时间是 8 小时)&#xff0c;MySQL 会自动把这个连接关闭。在数据库连接池中的 connections 如果空闲超过 8 小时&#xff0c;MySQL 将…

国家自然博物馆“云端自然”线上虚拟展厅是如何搭建的?

国家级综合性自然博物馆国家自然博物馆&#xff0c;联手积木易搭打造“云端自然”线上虚拟展览&#xff0c;形成一个集参观游览、科普教育为一体的线上虚拟数字博物馆平台&#xff0c;让数千以至数万年的古生物&#xff0c;栩栩如生地呈现在我们面前。 通过数字化的展示手段&am…

在做题中学习(61):连续数组

525. 连续数组 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a;前缀和 哈希表 转化&#xff1a;将 0 ——> -1 转变为&#xff1a;找到和为0的最长子数组 细节&#xff1a; 1.哈希表存什么 前缀和 &#xff0c; 长度 2.什么时候存入哈希表 先处理前一个&…

怎么用二维码看excel表格?生成文件二维码的制作技巧

Excel表格怎么放到二维码中&#xff0c;让其他人通过扫码查看数据呢&#xff1f;现在文件放入二维码中展示在很多的场景中都有应用&#xff0c;比如通知、数据、作品、报告等类型的内容都可以通过扫码的方式在手机上展现&#xff0c;那么如何将文件生成二维码呢&#xff1f; 文…