AI学习指南深度学习篇-Python实践

AI学习指南深度学习篇 - Python实践

引言

在现代深度学习中,学习率是一个至关重要的超参数,它直接影响模型的收敛速度和最终效果。适当的学习率能够加速训练,但过大会导致模型不收敛,过小则可能导致训练过程过慢。因此,学习率衰减成为了深度学习中的一种常见策略,可以帮助我们在训练过程中逐步减小学习率。

在本篇文章中,我们将通过使用Python中的深度学习库(如TensorFlow和PyTorch)来演示学习率衰减的实现。同时,我们还将讨论在模型训练过程中如何进行有效的调参。

1. 学习率衰减的概念

学习率衰减指的是在训练过程中逐渐减小学习率的策略,目的是为了在训练初期快速收敛,并在后期细致优化。常用的学习率衰减策略包括:

  • 固定步长衰减:每隔固定步数就减小学习率。
  • 指数衰减:学习率按一定的指数基数衰减。
  • 余弦退火:学习率在一个固定范围内周期性变化。

2. 使用TensorFlow实现学习率衰减

2.1 环境准备

在开始之前,请确保您已经安装了TensorFlow库。如果未安装,可以通过以下命令进行安装:

pip install tensorflow

2.2 示例代码

在此示例中,我们创建一个简单的全连接神经网络,使用TensorFlow实现学习率衰减。

2.2.1 导入必要的库
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import numpy as np
import matplotlib.pyplot as plt
2.2.2 生成数据集

我们将生成一个简单的合成数据集,用于训练模型。

# 生成合成数据集
x_train = np.random.rand(1000, 20)
y_train = (np.sum(x_train, axis=1) > 10).astype(int)
x_test = np.random.rand(200, 20)
y_test = (np.sum(x_test, axis=1) > 10).astype(int)
2.2.3 建立模型
def create_model():model = models.Sequential([layers.Dense(64, activation="relu", input_shape=(20,)),layers.Dense(32, activation="relu"),layers.Dense(1, activation="sigmoid")])return model
2.2.4 定义学习率衰减策略

这里我们使用ExponentialDecay来实现指数衰减。

initial_learning_rate = 0.1
decay_steps = 100
decay_rate = 0.96lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate,decay_steps=decay_steps,decay_rate=decay_rate,staircase=True)
2.2.5 编译和训练模型
model = create_model()
model.compile(optimizer=optimizers.Adam(learning_rate=lr_schedule),loss="binary_crossentropy",metrics=["accuracy"])history = model.fit(x_train, y_train, epochs=100, validation_split=0.2, verbose=0)
2.2.6 可视化训练过程
plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy with Learning Rate Decay")
plt.show()

2.3 结果分析

通过实际运行上述代码,我们可以观察到学习率的变化以及模型性能的提升。我们可以在训练过程中看到训练和验证准确率的折线图,更容易监控模型的学习效果。

3. 使用PyTorch实现学习率衰减

3.1 环境准备

确保您已经安装了PyTorch。如果未安装,可以通过以下命令进行安装:

pip install torch torchvision

3.2 示例代码

同样的,我们将使用PyTorch创建一个简单的神经网络并实现学习率衰减。

3.2.1 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
3.2.2 生成数据集

与TensorFlow示例相同,生成合成数据集。

# 生成合成数据集
x_train = np.random.rand(1000, 20).astype(np.float32)
y_train = (np.sum(x_train, axis=1) > 10).astype(np.float32)
x_test = np.random.rand(200, 20).astype(np.float32)
y_test = (np.sum(x_test, axis=1) > 10).astype(np.float32)# 转换为PyTorch张量
x_train_tensor = torch.tensor(x_train)
y_train_tensor = torch.tensor(y_train).view(-1, 1)
3.2.3 建立模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(20, 64)self.fc2 = nn.Linear(64, 32)self.fc3 = nn.Linear(32, 1)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = torch.sigmoid(self.fc3(x))return xmodel = SimpleNN()
3.2.4 定义学习率衰减策略

使用torch.optim.lr_scheduler来实现学习率衰减。

initial_lr = 0.1
optimizer = optim.Adam(model.parameters(), lr=initial_lr)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)
3.2.5 训练模型
epochs = 100
train_losses = []for epoch in range(epochs):model.train()optimizer.zero_grad()output = model(x_train_tensor)loss = nn.BCELoss()(output, y_train_tensor)loss.backward()optimizer.step()scheduler.step()  # 更新学习率train_losses.append(loss.item())if epoch % 10 == 0:print(f"Epoch {epoch}, Loss: {loss.item()}, Learning Rate: {scheduler.get_last_lr()}")
3.2.6 可视化训练过程
plt.plot(train_losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss with Learning Rate Decay")
plt.show()

3.3 结果分析

通过观察训练损失的变化,可以记住在学习率衰减策略下模型的学习过程。降低学习率使得模型在训练后期能够更加细致地优化,避免错过局部最优。

4. 调参技巧

学习率衰减是深度学习模型训练中的重要一环,但选择合适的衰减参数(例如:初始学习率、衰减步长和衰减率)对于训练效果有显著影响。以下是一些调参技巧:

  1. 网格搜索(Grid Search):系统性地尝试不同的学习率、衰减率和衰减步长的组合,以找到最佳设置。

  2. 学习率范围测试:以线性或对数方式增加学习率,观察损失变化,从而找到一个合理的初始化学习率。

  3. 早停法和检查点:结合其他技术(如早停法),记住保存最佳模型,以防止过拟合。

  4. 微调策略:对大规模预训练模型进行微调时,使用较小的学习率衰减策略。

5. 小结

本文介绍了在深度学习中如何使用TensorFlow和PyTorch实现学习率衰减策略。我们从基本概念入手,展示了具体的代码示例,并探讨了调参技巧。学习率衰减不仅能够帮助模型更好地收敛,也为我们在深度学习中的其他调参策略提供了启示。

希望这些实践能够帮助到您在深度学习的研究与应用中更进一步!如有任何问题或建议,欢迎交流讨论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/881666.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python异步网络编程详解:多种实现方法与案例分析

目录 Python异步网络编程详解:多种实现方法与案例分析一、异步网络编程的基本概念1.1 同步与异步1.2 并发与并行1.3 Python中的异步编程模型 二、使用 asyncio 实现异步网络编程2.1 基本概念2.2 面向对象的异步HTTP请求实现代码实现详细解释: 三、使用 a…

【汇编语言】寄存器(CPU工作原理)(七)—— 查看CPU和内存,用机器指令和汇编指令编程

文章目录 前言1. 预备知识:Debug的使用1.1 什么是Debug?1.2 我们用到的Debug功能1.3 进入Debug1.3.1 对于16位或者32位机器的进入方式1.3.2 对于64位机器的进入方式 1.4 R命令1.5 D命令1.6 E命令1.7 U命令1.8 T命令1.9 A命令 2. 总结3. 实操练习结语 前言…

grpc的python使用

RPC 什么是 RPC ? RPC(Remote Procedure Call)远程过程调用,是一种计算机通信协议,允许一个程序(客户端)通过网络向另一个程序(服务器)请求服务,而无需了解…

笔试算法总结

文章目录 题目1题目2题目3题目4 题目1 使用 StringBuilder 模拟栈的行为&#xff0c;通过判断相邻2个字符是否相同&#xff0c;如果相同就进行删除 public class Main {public static String fun(String s) {if (s null || s.length() < 1) return s;StringBuilder builde…

前端开发基础NodeJS+NPM基本使用(零基础入门)

文章目录 1、Nodejs基础1.1、NodeJs简介1.2、下载安装文件1.3、安装NodeJS1.4、验证安装2、Node.js 创建第一个应用2.1、说明2.2、创建服务脚本2.3、执行运行代码2.4、测试访问3、npm 基本使用3.1、测试安装3.2、配置淘宝npm镜像3.3.1、本地安装3.3.2、全局安装3.4、查看安装信…

【网络】详解TCP协议的流量控制和拥塞控制

【网络】详解TCP协议的流量控制和拥塞控制 一. 流量控制模型窗口探测 二. 拥塞控制模型 总结 一. 流量控制 流量控制主要考虑的是接收方的处理速度。 接收端处理数据的速度是有限的.。如果发送端发的太快, 导致接收端的缓冲区被打满, 这个时候如果发送端继续发送, 就会造成丢包…

Composer入门详解

文章目录 Composer入门详解一、引言二、Composer的安装1、Windows平台1.1、步骤 2、Linux平台2.1、步骤 3、Mac OS系统3.1、步骤 三、Composer的使用1、创建composer.json文件1.1、示例 2、安装依赖3、更新依赖4、移除依赖 四、总结 Composer入门详解 一、引言 Composer 是 P…

IP地址如何支持远程办公?

由于当今社会经济的飞速发展&#xff0c;各个方向的业务都不免接触到跨省、跨市以及跨国办公的需要&#xff0c;随之而来的远程操作的不方便&#xff0c;加载缓慢&#xff0c;传输文件时间过长等困难&#xff0c;如何在万里之外实现远程办公呢&#xff1f;我们以以下几点进行阐…

今日sql学习

目录 一,union all 二&#xff0c;GROUP_CONCAT函数 三&#xff0c;字符串函数&#xff1a;RIGHT()函数 四&#xff0c;字符串函数&#xff1a;LENGTH() 函数 / REPLACE&#xff08;&#xff09;函数 五&#xff0c;条件更新表内的值 六&#xff0c;创建外键 七&…

【GaussDB】产品简介

产品定位 GaussDB 200是一款具备分析及混合负载能力的分布式数据库&#xff0c;支持x86和Kunpeng硬件架构&#xff0c;支持行存储与列存储&#xff0c;提供PB(Petabyte)级数据分析能力、多模分析能力和实时处理能力&#xff0c;用于数据仓库、数据集市、实时分析、实时决策和混…

3DCAT实时云渲染赋能2024广东旅博会智慧文旅元宇宙体验馆上线!

广东国际旅游产业博览会&#xff08;以下简称“旅博会”&#xff09;是广东省倾力打造的省级展会品牌&#xff0c;自2009年独立成展至今已成功举办十五届。2024广东旅博会于9月13—15日在广州中国进出口商品交易会展馆A区举办&#xff0c;线上旅博会“智慧文旅元宇宙体验馆”于…

【C++刷题】力扣-#35-搜索插入位置

题目描述 给定一个已排序的数组 nums 和一个目标值 target&#xff0c;你需要找出 target 在数组中的插入位置&#xff0c;以保持数组的有序性。如果 target 已经存在于数组中&#xff0c;则返回它的索引。 示例 示例 1 输入: nums [1,3,5,6], target 5 输出: 2 解释: 值 5…

力扣21~30题

21题&#xff08;简单&#xff09;&#xff1a; 分析&#xff1a; 按要求照做就好了&#xff0c;这种链表基本操作适合用c写&#xff0c;python用起来真的很奇怪 python代码&#xff1a; # Definition for singly-linked list. # class ListNode: # def __init__(self, v…

awk脚本通过将多行拼接成一行并在 SQL 语句的末尾遇到分号(;)时打印出来

concact_sql.awk 脚本如下 {if ($0 ~ /^SELECT|UPDATE|DELETE|INSERT/) {flag 1buffer ""gsub(/\n/, "", $0)cli $0} else if (flag) {if ($0 ~ /;/) {flag 0print cli buffer $0} else {buffer buffer $0}} else if (flag 0) {print $0} } 这个 aw…

《镂空·幻象---基于新中式古国文化艺术文旅可延续VR个性化体验》| OPENAIGC开发者大赛高校组优秀作品

在第二届拯救者杯OPENAIGC开发者大赛中&#xff0c;涌现出一批技术突出、创意卓越的作品。为了让这些优秀项目被更多人看到&#xff0c;我们特意开设了优秀作品报道专栏&#xff0c;旨在展示其独特之处和开发者的精彩故事。 无论您是技术专家还是爱好者&#xff0c;希望能带给…

前端将JSON或者table直接导出为excel

一、引入Sheetjs或者npm直接下载 <script lang"javascript" src"https://cdn.sheetjs.com/xlsx-0.20.3/package/dist/xlsx.full.min.js"></script> 二、页面中使用 //json导出为excel <button onclick"exportExcel()">导出…

Python+whisper/vosk实现语音识别

目录 一、Whisper 1、Whisper介绍 2、安装Whisper 3、使用Whisper-base模型 4、使用Whisper-large-v3-turbo模型 二、vosk 1、Vosk介绍 2、vosk安装 3、使用vosk 三、总结 一、Whisper 1、Whisper介绍 Whisper 是一个由 OpenAI 开发的人工智能语音识别模型&#xf…

Parallels Desktop意外退出,Parallels Desktop安装软件很卡闪退怎么办?

Parallels Desktop是目前很优秀的虚拟机软件&#xff0c;操作简单&#xff0c;兼容性强而且安装也非常方便&#xff0c;备受苹果用户的喜爱和满意。然而&#xff0c;部分用户在使用Parallels Desktop的时候&#xff0c;会遇到意外退出或终端关机的情况&#xff0c;这不仅会影响…

利用 Llama 3.1模型 + Dify开源LLM应用开发平台,在你的Windows环境中搭建一套AI工作流

文章目录 1. 什么是Ollama&#xff1f;2. 什么是Dify&#xff1f;3. 下载Ollama4. 安装Ollama5. Ollama Model library模型库6. 本地部署Llama 3.1模型7. 安装Docker Desktop8. 使用Docker-Compose部署Dify9. 注册Dify账号10. 集成本地部署的 Llama 3.1模型11. 集成智谱AI大模型…