一个有意思pytorch的简单应用小实验

        通过一个简单的脚本,来学习pytorch的基本应用,比如:前向传播、反向传播、学习率以及预测、模型的基本原理和套路。

        得到结果。。。保存模型。。。输入参数。。。预测。。。像不像?。。。像多少?。。。

        设计目标:一个包含了两个元素的输入张量,经过一个线性模型的运算后输出预测结果,经过前向传播、反向传播、学习调整后,使预测的结果尽量接近目标结果。

        输入张量:in_tensor=[2.0, 9.0]

        线性模型:model(k0 * in_tensor[0] + k1 * in_tensor[1])

        目标结果:100。

        总结来说就是:设计目标:2.0*k0 + 9.0*k1 = 100,通过Pytorch的惯用框架和套路,经过多次学习和迭代优化之后,求出k0和k1的最优值。

基本代码

import torch
import random# 定义常量
TARGET_VALUE = 100
LR = 0.01  # 学习率# 初始化张量和权重
in_tensor = torch.tensor([2.0, 4.0])
k0 = torch.tensor(random.random(), requires_grad=True)  # 权重需要计算梯度
k1 = torch.tensor(random.random(), requires_grad=True)  # 权重需要计算梯度# 定义模型
def model(in_tensor, k0, k1):return k0 * in_tensor[0] + k1 * in_tensor[1]   # 定义了一个简单的线性模型# 定义损失函数
def loss_fn(y_pred, y_true):return (y_pred - y_true) ** 2   # 均方误差(MSE),计算预测值与真实值之间的平方差。# 训练过程
def train(iterations, in_tensor, k0, k1):for i in range(iterations):# 前向传播y_pred = model(in_tensor, k0, k1)   # 预测结果loss = loss_fn(y_pred, TARGET_VALUE)   # 损失值(可以理解为误差)# 反向传播loss.backward()# 更新权重with torch.no_grad():   # 停止梯度跟踪k0 -= LR * k0.grad  # k0减去它的梯度*学习率,完成一次权重的调整k1 -= LR * k1.grad  # k1减去它的梯度*学习率,完成一次权重的调整# 清零权重的梯度k0.grad.zero_()k1.grad.zero_()print(f"Iteration {i}: y_pred = {y_pred}, loss = {loss.item()}")# 开始训练
train(10, in_tensor, k0, k1)

运行结果:

 可以看到,由于模型很简单,收敛很快,经过10次训练,loss已经降到了0.98。

学习率的实验

上面是学习率LR = 0.01得到的训练结果,现改为LR = 0.015:

同样的10次训练,当学习率增加之后,loss已经降到了0.000625,模型的收敛速度加快。 

继续加大学习率,改为LR = 0.03: 

 loss已经降到了2.85e-9,模型的收敛速度更快了。

继续加大,LR = 0.06:

预测值剧烈震荡,模型无法收敛。

知识点:

        加大学习率,可以加快模型收敛速度,但是也不能过大,学习率过大的后果:

        1. 无法收敛

        • 跳过最优解: 学习率过大时,每次参数更新的步长也会很大,这可能导致模型在优化过程中跳过最优解。

        • 震荡: 模型参数可能会在最佳值附近来回震荡,无法稳定地达到收敛。

        • 梯度爆炸: 在极端情况下,学习率过大可能导致梯度值变得非常大,进而使得参数更新步长过大,甚至导致数值溢出(如NaN)。  

        2. 训练不稳定

        • 损失函数波动: 损失函数的值可能会在每次迭代中剧烈波动,而不是逐渐减小。

        • 泛化能力差: 由于模型参数未能稳定收敛,可能导致模型在测试集上的表现不稳定,泛化能力差。  

        3. 过拟合风险增加

        • 在某些情况下,即使模型最终收敛,也可能因为学习率过大而错过最优解,导致过拟合。

再来,将学习率变小,LR = 0.006:

模型也在持续收敛,但是比起LR = 0.01,收敛变慢了。 

LR = 0.004:

收敛更慢了。

知识点:

        学习率过小的后果:

        1. 收敛速度慢

        • 训练时间长: 由于每次参数更新的步长很小,模型需要更多的迭代次数才能达到最优解,导致训练时间显著增加。

        • 陷入局部最优: 在某些情况下,学习率过小可能导致模型陷入局部最优解,而不是全局最优解。  

        2. 过拟合风险增加

        • 过度训练: 由于训练时间过长,模型可能在训练集上过度拟合,导致在测试集上的表现下降。  

        3. 梯度消失

        • 接近零的梯度: 学习率过小,尤其是在深度神经网络中,可能导致梯度值变得非常小,进而使得参数更新几乎停滞,这种现象称为梯度消失。

早停机制

将局部代码改为:

LR = 0.016
train(100, in_tensor, k0, k1)

在训练了16次之后,loss已经为0。所以,就可以停止训练了。

局部代码修改为:

# 训练过程
def train(iterations, in_tensor, k0, k1):for i in range(iterations):# 前向传播y_pred = model(in_tensor, k0, k1)   # 预测结果loss = loss_fn(y_pred, TARGET_VALUE)   # 损失值(可以理解为误差)if loss <= 0.00001:   # 早停机制print(f"Iteration {i}: y_pred = {y_pred}, loss = {loss.item()}, ki = {k0}, k1 = {k1}")break# 反向传播
。。。。。。。。。。。

 当偏差足够小时,停止训练,并输出训练结果。

保存模型和使用模型预测

import torch
import random# 定义常量
TARGET_VALUE = 100
LR = 0.016  # 学习率# 初始化张量和权重
in_tensor = torch.tensor([2.0, 4.0])
pre_tensor = torch.tensor([2.2, 4.0])
k0 = torch.tensor(random.random(), requires_grad=True)  # 权重需要计算梯度
k1 = torch.tensor(random.random(), requires_grad=True)  # 权重需要计算梯度
model_state = []   # 模型参数# 定义模型
def model(in_tensor, k0, k1):return k0 * in_tensor[0] + k1 * in_tensor[1]# 定义损失函数
def loss_fn(y_pred, y_true):return (y_pred - y_true) ** 2# 训练过程
def train(iterations, in_tensor, k0, k1):for i in range(iterations):# 前向传播y_pred = model(in_tensor, k0, k1)loss = loss_fn(y_pred, TARGET_VALUE)if loss <= 0.00001:print(f"Iteration {i}: y_pred = {y_pred}, loss = {loss.item()}, ki = {k0}, k1 = {k1}")return [k0, k1]   # 返回训练后的模型参数break# 反向传播loss.backward()# 更新权重with torch.no_grad():   # 停止梯度跟踪k0 -= LR * k0.gradk1 -= LR * k1.grad# 清零权重的梯度k0.grad.zero_()k1.grad.zero_()# print(f"Iteration {i}: y_pred = {y_pred}, loss = {loss.item()}")# 开始训练
model_state = train(100, in_tensor, k0, k1) # 训练100次# 保存模型
torch.save(model_state, "model_state.pt")# 加载模型
model_state = torch.load("model_state.pt")
print(model_state)# 预测
y_pred = model(pre_tensor, model_state[0], model_state[1])
print(f"y_pred = {y_pred}, loss = {loss_fn(y_pred, TARGET_VALUE)}")

        在上面的代码中,我们保存了一个模型,并且用它预测了一个张量[2.2, 4.0],与我们训练用的数据[2.0, 4.0]相差不多,所以预测结果也相差不多。如果换成不同的数据,那么预测的结果也将会不同。

        推而广之,如果把输入的张量换成一个图像的像素阵列,预测结果换为判断类别,模型换为多层的卷积神经网络,再加上一些层间池化、输出激活函数,那么就是pytorch最常见的图像识别套路了。所以,无论模型和应用框架多么复杂,也是由最简单的结构迭加、衍生而成,将一个复杂的任务分解成一个个简单任务,它就不再复杂。

        以上为一点点初学者的肤浅心得,与大家交流共勉,望多指教!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/63177.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用lumerical脚本语言创建定向耦合器并进行数据分析(纯代码实现)

本文使用lumerical脚本语言创建定向耦合器波导、计算定向耦合器的偶数和奇数模式、分析定向耦合器的波长依赖性、分析定向耦合器的间隙依赖性(代码均有注释详解)。 一、绘制定向耦合器波导 1.1 代码实现 # 这段代码主要实现了绘制定向耦合器波导几何结构的功能。通过定义各种…

Linux 35.6 + JetPack v5.1.4之RTP实时视频Python框架

Linux 35.6 JetPack v5.1.4之RTP实时视频Python框架 1. 源由2. 思路3. 方法论3.1 扩展思考 - 慎谋而后定3.2 扩展思考 - 拒绝拖延或犹豫3.3 扩展思考 - 哲学思考3.4 逻辑实操 - 方法论 4 准备5. 分析5.1 gst-launch-1.05.1.1 xvimagesink5.1.2 nv3dsink5.1.3 nv3dsink sync05…

企业风险投资、融资事件数据(1921-2024)

数据包括历年上市与非上市企业的风险投资融资数据等数据&#xff0c;包括融资时间、被投企业、投资方、退出方等数据&#xff0c;希望对大家的研究有所帮助 一、数据介绍 数据名称&#xff1a;企业风险投资、融资事件 数据范围&#xff1a;上市与非上市企业 数据年份&#x…

移远5G模块移植

移远5G模块移植 1.NCM网卡配置2.拨号工具编译3.程序运行 1.NCM网卡配置 1.1、内核配置 打开内核配置界面&#xff0c;并找到USB Network Adapters进行NCM网卡配置 > Device Drivers > Network device support > USB Network Adapters 1.2、驱动修改 打开内核源码钟的…

煤矿 35kV 变电站 3 套巡检机器人 “上岗”,力破供电瓶颈

近日&#xff0c;杭州旗晟智能科技与甘肃某变电站配电室的三套智能巡检机器人线下测试顺利完成&#xff0c;并成功交付使用&#xff0c;这为电力运维工作注入了全新的活力与强大的技术支撑。 一、项目背景 甘肃某变电站总建筑面积1098平方米的变电站集变电、配电、监控等多功能…

docker 相关问题记录

docker mysql 一直重启解决办法&#xff08;断电或者重启&#xff09; 一直重启。。因为是内部开发&#xff0c;也没有备份最新的。所以不能删了重来。 方法&#xff1a; docker logs mysql5.7 看到错误跟innodb有关。 具体原因可以参考 http://acuilab.com/articles/2019/1…

Linux中Crontab(定时任务)命令详解

文章目录 Linux中Crontab&#xff08;定时任务&#xff09;命令详解一、引言二、Crontab的基本使用1、Crontab命令格式2、Crontab常用操作 三、Crontab的配置与服务管理1、配置Crontab2、服务管理 四、使用示例1、每天凌晨2点备份网站数据2、每周一凌晨3点清理临时文件3、每月的…

记录学习《手动学习深度学习》这本书的笔记(三)

这两天看完了第六章&#xff1a;卷积神经网络&#xff0c;巧的是最近上的专业选修课刚讲完卷积神经网络&#xff0c;什么卷积层池化层听得云里雾里的&#xff0c;这一章正好帮我讲解了基础的知识。 第六章&#xff1a;卷积神经网络 6.1 从全连接层到卷积 在之前的学习中&…

测试知识-高阻示波器的探头补偿

目录 探头补偿 探头补偿 调节补偿电容 调节补偿电容 探头补偿 设计到一个知识盲点&#xff0c;刚好复习补充下 探头补偿 理论知识 示波器和 10:1 探头的简化模型如上图所示&#xff0c;其中示波器的输入阻抗为 RscopeRscope​&#xff0c;探头的补偿电容为 CcompCcomp​。…

低空经济的第一助推力,基于鸿道Intewell操作系统的无人机控制系统

低空经济背景 低空经济是指利用低空空域资源进行经济活动的总和&#xff0c;包括无人机、通用航空、低空物流等新兴产业。近年来&#xff0c;随着技术的不断突破和政策的支持&#xff0c;低空经济正逐渐成为全球瞩目的新经济增长点。在中国&#xff0c;低空经济的发展受到了政…

软件测试基础详解(自动化测试/安全测试/性能测试)

&#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 自动化测试的意义 缩短软件开发测试周期&#xff0c;可以让产品更快投放市场 测试效率高&#xff0c;充分利用硬件资源 节省人力资源&#xff0c;降低测试成本 …

最小二乘法实际应用

最小二乘法 使用最小二乘法拟合大气二氧化碳浓度数据 数据保存在monthly_co2.xls文件中(只截取部分) python需要安装的库 xlrdnumpypandasmatplotlib 绘制图像代码(绘制整体数据趋势图) # -*- coding: utf-8 -*- """ File : 绘制趋势图.py Time : …

原生html+css+ajax实现二级下拉选择的增删改及树形结构列出

<?php $db_host localhost; $db_user info_chalide; $db_pass j8c2rRr2RnA; $db_name info_chalide; /* 数据库结构SQL CREATE TABLE categories ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, parent_id INT DEFAULT 0 ); */ try { $pdo new PD…

Linux操作系统--文件的重定向以及文件缓冲区

目录 前言 一、文件描述符的分配规则 二、重定向 三、系统中的重定向接口 1、dup2()介绍 2、dup2()使用 1&#xff09;输出重定向和追加重定向 2&#xff09;输入重定向 四、文件缓冲区 1、定义 2、缓冲区刷新的条件 1&#xff09;文件缓冲区存在的意义 2&…

5G CPE核心器件-基带处理器(三)

5G CPE 核心器件 -5G基带芯片 基带芯片简介基带芯片组成与结构技术特点与发展趋势5G基带芯片是5G CPE中最核心的组件,负责接入5G网络,并进行上下行数据业务传输。移动通信从1G发展到5G,终端形态产生了极大的变化,在集成度、功耗、性能等方面都取得巨大的提升。 基带芯片简…

mmdection配置-yolo转coco

基础配置看我的mmsegmentation。 也可以参考b站 &#xff1a;https://www.bilibili.com/video/BV1xA4m1c7H8/?vd_source701421543dabde010814d3f9ea6917f6#reply248829735200 这里面最大的坑就是配置coco数据集。我一般是用yolo&#xff0c;这个yolo转coco格式很难搞定&#…

Java 单元测试模拟框架-Mockito 的介绍

Mockito 是什么 Mockito 是一个用于单元测试的模拟框架&#xff0c;基于它可以使用简洁易用的API编写出色的测试。 Mockito 允许开发人员创建和管理模拟对象&#xff08;mock objects&#xff09;&#xff0c;以便在测试过程中替换那些不容易构造或获取的对象。 Mockito的基本…

NiFi-从部署到开发(图文详解)

NiFi简介 Apache NiFi 是一款强大的开源数据集成工具&#xff0c;旨在简化数据流的管理、传输和自动化。它提供了直观的用户界面和可视化工具&#xff0c;使用户能够轻松设计、控制和监控复杂的数据流程&#xff0c;NiFi 具备强大的扩展性和可靠性&#xff0c;可用于处理海量数…

draggable插件——实现元素的拖动排序——拖动和不可拖动的两种情况处理

最近在写后台管理系统的时候&#xff0c;遇到一个需求&#xff0c;就是关于拖动排序的功能。 我之前是写过一个关于拖动表格的功能&#xff0c;此功能可以实现表格中的每一行数据上下拖动实现排序的效果。 vue——实现表格的拖拽排序功能——技能提升 但是目前我这边的需求是…

Delphi Web前端开发教程(9):基于TMS WEB Core框架

3、REST Servers服务端(后端)框架 REST服务端特点&#xff1a; – 为远程资源提供一个REST API接口。也可以为其他网络内容提供服务&#xff1b; – 包括在Delphi Enterprise & Architect企业版和架构师版中的RAD服务器、DataSnap、WebBroker&#xff1b; – 开源框架&a…