解读简单的一段深度学习代码(已跑通)

 

最近一直想要学习深度学习的内容,想要复现大佬的代码,试了好多有的是不给数据,有的总是跑不通,这一个是已经跑通的一个代码,以上为证,学习最快的方式就是直接实战,大部分的内容都是比较偏向于理论,个人还是偏向于直接跑通代码,这部分代码相对比较容易理解,从头到尾逐行剖析它的每一行代码。

训练文件train.py

import numpy as np
import tensorflow as tf
from utils import get_data, model_LSTM
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import os

 

刚开始接触的时候对于第三行代码确实没太理解,utils.py相当于是你自己编写的一个程序包,可以把它想象成类似于机器学习中sklearn集成包那种,可以从里边调用许多自己需要的包,本文中的utils.py是由自己进行编写的符合本次项目的一个集成包,但是注意的内容是要放在同一个目录下进行调用。此时从utils.py包中调用的两个包一个是获取数据的方式,一个是自己定义的lstm模型,

此时打开utils.py文件后可以看到已经定义好的两个模型都存在。此时我们的模型的好坏是使用验证损失来衡量的,所以以下这三个参数都是以损失值来衡量的。 

ModelCheckpoint: 每个epoch结束后,如果验证损失下降,就自动保存模型权重。这样可以保存训练过程中损失最小的最优模型,在模型训练过程中我们会选择让模型迭代多少次,后一次循环结束之后的损失值小于前一次时便会自动进行保存最好的模型。

EarlyStopping: 当验证损失停止下降时,自动中止训练,防止过拟合,当在一次次的训练过程中损失值一直增长反而不下降时模型便会终止训练,假如损失值没有下降再训练下去也是毫无意义。说到这里可能大家会有一个小的疑问,难道时一次增长之后便停止吗,当然不是了,这个参数也是可以自己进行设置,设置的是patiences参数耐心值,在这个范围之内损失函数没有变化的话模型便会自动停止训练,根据下边的代码可以看到本次设置的是15。

ReduceLROnPlateau: 当验证损失平台期时,自动减小学习率。这样可以帮助训练更快收敛到最优点。这里可能会有疑问既然是减少学习率的话,具体是减少多少呢,本文中的factor为0.5,那样的话是不是可以无限减少呢,当然不是了,下文中的min_lr规定了减小到最低什么程度。

    checkpoint = ModelCheckpoint(filepath=checkpoint_save_path,monitor='val_loss',verbose=1,save_weights_only=True,save_best_only=True)earlystopping = EarlyStopping(monitor='val_loss',patience=15,verbose=1,mode='auto',restore_best_weights=True)rlrop = ReduceLROnPlateau(monitor='val_loss',patience=5,factor=0.5,min_lr=1e-7,verbose=1)

verbose参数在ReduceLROnPlateau回调函数中控制打印日志的详细程度。

它有以下几个取值:

0: 不打印日志,静默模式。

1: 打印一个日志消息来表示学习率下降。

2: 打印日志并显示学习率下降的老值和新值。

save_weights_only参数在ModelCheckpoint回调函数中是这个意思:

设置为True时,只保存模型的权重参数(权重文件),不保存模型的结构信息和编译信息。设置为False时(默认值),会保存整个模型结构、编译信息以及权重等所有信息。使用save_weights_only=True的好处是:保存的文件尺寸更小,重量更轻。常用在云训练场景,可以节约存储空间。只保存权重参数就可以,加载时可使用原模型结构。便于版本控制和权重文件分享。在功能上等同于保存整个模型,但是文件更小。

缺点是:

需要确保加载权重时模型结构保持一致。不保存编译信息,加载时需要重新编译模型。所以,如果只关心模型训练结果,可以设置save_weights_only=True以节省空间。如果需要保存完整模型结构信息,可以保持默认False。一般情况下,ModelCheckpoint常用这个设置来自动收集训练过程中的最佳权重参数。

 

filepath参数指定ModelCheckpoint回调函数将检查点文件保存的路径。

一些关键点:

filepath可以是目录路径或包含文件名的全路径。如果只指定目录但不指定文件名,它将自动生成文件名规则为'epoch{epoch:02d}-loss{loss:.2f}.h5'。每次触发保存都会覆盖同名文件,所以只会得到最后一个检查点。可以设置save_best_only=True仅保存损失最小的那个检查点。默认情况下,检查点文件将包含模型的结构、权重和优化器状态等完整信息。

min_lr参数是在ReduceLROnPlateau回调函数中使用的,它用于设置学习率的最小值。

具体说:ReduceLROnPlateau会在验证集精度停滞时自动减少学习率。每次减少的比例由factor参数控制,默认是0.1。但减少后学习率不能低于min_lr定义的最小值。所以:min_lr的默认值是0,表示学习率可以降至0。如果设置为1e-7,那么学习率每次减小后,将保证大于或等于1e-7。

monitor参数指定ReduceLROnPlateau和EarlyStopping这两个回调函数要监测的指标。

对于ReduceLROnPlateau,默认监测指标是:monitor='val_loss',即监测验证集上的损失函数(loss)值。当验证集loss值在一定数量的epoch内没有下降,就会触发学习率下调。

而EarlyStopping的默认也是:monitor='val_loss'它会监测验证集loss的变化情况,如果一定 epoch内loss没有下降,就会提前结束训练防止过拟合。

所以:monitor参数允许用户更改监测的指标,如:monitor='val_accuracy',那么二者就会监测验证集上准确率是否在提升等。此外一些常用的指标如:

'accuracy'
'f1_score'
'categorical_accuracy'
'mae'
'mse'
都可以作为替代monitor的选项。

通过修改该参数,可以让回调函数基于其他 mehr指标自动调整学习率和停止训练,更符合实际需要,看你自己最注重的是什么以此来确定衡量函数来判断。

根据以上对各个参数变量的解读,训练模型所需要的整体框架的各个参数便基本确定了,接下来便是开始进行具体的数据输入和训练

if __name__ == "__main__":time = 144X_train, y_train, X_test, y_test = get_data(test_num=time)model = model_LSTM()tf.keras.utils.plot_model(model, show_shapes=True)model.compile(loss=tf.keras.losses.mse,optimizer=tf.keras.optimizers.Adam(),)

这段代码一般放在if name == "main"下,保证主流程只在直接运行文件时执行,作为单元测试模型的主程序部分,加入if name == "main"这一行代码之后,它后边的代码类似于上文utils.py中定义的函数便不会被调用。

def get_data(step=200, test_num=144):df = pd.read_csv("../datasets/Training/PV_data_all.csv")df.drop("t", axis=1, inplace=True)df.drop(['InclAngle', "Current_V", "Voltage", "Humidity"], axis=1, inplace=True)sc = MinMaxScaler(feature_range=(-1, 1))df = sc.fit_transform(df)df = pd.DataFrame(df)datasets_X, datasets_y = [], []for index in range(len(df) - step):datasets_X.append(np.array(df.iloc[index: index+step, :]))datasets_y.append(np.array(df.iloc[index + step]))datasets_X, datasets_y = np.array(datasets_X), np.array(datasets_y)X_train = datasets_X[:-test_num]y_train = datasets_y[:-test_num]X_test = datasets_X[-test_num:]y_test = datasets_y[-test_num:]return X_train, y_train, X_test, y_test

那具体的获取数据的流程是怎么样的呢,定义的获取数据的函数是如何定义的呢,看上图所示。本文所使用的数据如下图所示。

数据处理过程中首先将不需要的列删掉,以及进行数据标准化转换格式等常规操作。这里设置了测试数据的数量为144,同时设置了步长为200 ,根据索引确定数据X和y.

索引从0到len(df)-step设置范围的原因是:df中的数据点数量为len(df)

每个被提取出来的序列样本的长度为step个数据点

所以序列样本的第一个数据点的索引必须小于len(df)-step

举个简单例子:

如果df总长度是10个点:

len(df) = 10

取每序列3个点作为一个样本(step=3)

那么:

第一个样本可以是点0-2,索引从0提取
第二个样本可以是点1-3,索引从1提取
第三个样本可以是点2-4,索引从2提取
但是第四个样本点3-5就超出了df范围。

所以最大索引只能是len(df)-step,也就是10-3=7

这就是为什么设置索引范围从0到len(df)-step的原因:

保证每次提取的序列范围始终在原始数据df范围内
提取完所有可能的序列样本不会溢出原始数据
这个设置很重要,可以正确和完整地从时序数据中分段提取样本

def model_LSTM():seq_length, data_dim, output_dim = 200, 4, 4input_layer = Input(shape=(seq_length, data_dim))x = Bidirectional(LSTM(256, return_sequences=True))(input_layer)x1 = Bidirectional(LSTM(256, return_sequences=True))(x)x2 = Bidirectional(GRU(256, return_sequences=True))(x)x = Concatenate(axis=2)([x1, x2])x = Bidirectional(LSTM(128, return_sequences=True))(x)dense = tf.keras.layers.Flatten()(x)dense = Dense(512, activation="relu")(dense)dense = Dropout(0.2)(dense)dense = Dropout(0.2)(dense)output = Dense(output_dim)(dense)model = tf.keras.Model(inputs=input_layer, outputs=output)return model

此时定义的LSTM模型结构如上图所示

这段代码定义了一个基于LSTM和Bidirectional的时序模型:

Input定义输入序列的长度和维度,第一个Bidirectional-LSTM层返回序列,加深特征提取,第二个Bidirectional-LSTM层再进行一次提取,一个Bidirectional-GRU层提取另一种特征,Concatenate层把LSTM和GRU提取的特征拼接在一起,第三个Bidirectional-LSTM层进行整合,Flatten层展平多维输入为一维,后面是标准Dense层结构进行分类任务

主要特点:

使用Bidirectional可以从前向和后向流提取特征,拼接LSTM和GRU提取不同特征联合,3层LSTM提取深层次时间特征,最后 Dense层作为分类器,该模型充分利用LSTM various能力进行序列建模并实现了前向后向拼接多路网络设计,是一种相对完整的时序分类模型定义方式。

Bidirectional意为双向的。在定义LSTM层时使用了Bidirectional参数,它可以让LSTM模型具有双向学习能力:

标准的RNN只能从输入序列的第一个时间步开始依次向后处理每个时间步的特征,但实际上后续时间步的信息也可能对当前时间步预测结果有帮助。Bidirectional LSTM可以同时用前向网络和后向网络来处理输入序列:前向网络从第一个时间步开始依次处理输入;后向网络从最后一个时间步开始依次处理输入;两网结果通过拼接相加,得到每个时间步的完整表示。这样实际上让模型学习输入序列的双向上下文信息:前向网络学习前向上下文;后向网络学习后向上下文;两者融合可以学习更完整的序列表示。所以使用Bidirectional LSTM可以充分利用序列中不同时间步的关联信息,借此提升模型学习和推理效果。它成为RNN在许多序列任务上的首选结构之一。

    checkpoint_save_path = "./checkpoint/LSTM.weights.h5"if os.path.exists(checkpoint_save_path + ".index"):print("load the model".center(50, "-"))model.load_weights(checkpoint_save_path)

 

这段代码是用于加载已有预训练模型权重的:

首先定义了权重保存文件的路径为checkpoint_save_path,判断这个路径下是否存在".index"文件(Keras自动生成的文件)。如果存在,说明这套权重文件之前已经保存过,打印一句提示语句,中心对齐长度为50的"-",起示意作用,然后直接调用model.load_weights(filepath),将权重文件加载到模型

作用:

检查指定路径是否有保存的模型权重,如果有,直接加载这套权重初始化模型

这样可以实现:

从上次训练得出的权重加载模型,继续训练,避免从零开始,加快收敛,对已有模型做微调等使用,这在模型训练过程中迭代和微调很有用。通过预训练权重的载入,可以在保留学习信息的基础上优化模型,而不是从零开始训练。

    callbacks = [checkpoint, earlystopping, rlrop]history = model.fit(X_train, y_train, epochs=100, batch_size=64, validation_split=0.2, callbacks=callbacks)

callbacks列表的作用是把各种回调函数传给model.fit().需要注意几点:回调函数都定义好了,如checkpoint, earlystopping, rlrop等。将它们收集到一个列表callbacks中。在调用model.fit()时,会通过callbacks参数传入这个列表。batch_size=64:设置每个训练批次(batch)的数据量大小为64个样本。区分批训练和全量训练,可以减小内存占用。

    plan = 1if plan == 1:pred = []X_pred = X_test[0]pred_0 = np.array(model.predict(X_pred.reshape(1, 200, 4)))pred.append(pred_0[0])for i in range(time-1):X_pred = np.vstack((X_pred[1:], pred_0))pred_0 = np.array(model.predict(X_pred.reshape(1, 200, 4)))[0]pred.append(pred_0)pred = np.array(pred)feature_list = ["AmbiTemp", "Irradiance", "ModuleTemp", "TruePower"]plt.figure(figsize=(20, 12))for index in range(4):plt.subplot(2, 2, index+1)y_pred = pred[:, index]y_test_show = y_test[:time, index]plt.plot(y_test_show, label="{}_true".format(feature_list[index]))plt.plot(y_pred, label="{}_pred".format(feature_list[index]))plt.legend()plt.savefig("./pred_1.png")if plan == 1:pred = np.array(model.predict(X_test))feature_list = ["AmbiTemp", "Irradiance", "ModuleTemp", "TruePower"]plt.figure(figsize=(20, 12))for index in range(4):plt.subplot(2, 2, index + 1)y_pred = pred[:, index]y_test_show = y_test[:, index]plt.plot(y_test_show, label="{}_true".format(feature_list[index]))plt.plot(y_pred, label="{}_pred".format(feature_list[index]))plt.legend()plt.savefig("./pred_2.png")

这段代码实现了两种LSTM模型在测试数据上的预测方式,并可视化对比结果:

第一个if块(plan=1):

使用时序预测的方式循环预测,每次使用前面步的预测值作为后面步的输入

图片命名为pred_1.png

第二个if块(plan=1):

直接使用模型 predict()接口批量预测整个测试集

图片命名为pred_2.png

主要区别:

第一种循环预测考虑时序关系,但可能会累积误差

第二种直接批量预测考虑全局信息,但忽略时序

两种方式结果画在一张图中可视化对比:

性能是否一致

预测曲线是否平滑或者是否存在震荡

这样可以评估两种预测策略在时序任务中的表现,选择更适合的方式。

是LSTM和其他RNN模型评估的常用方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/6706.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习笔记导航(吴恩达版)

01.机器学习笔记01:机器学习前置概念导入、线性回归、梯度下降算法 02.机器学习笔记02:多元线性回归、多元梯度下降算法、特征缩放、均值归一化、正规方程 03.机器学习笔记03:octave安装、创建矩阵 04.机器学习笔记04:octave中移动…

LeetCode题练习与总结:分隔链表--86

一、题目描述 给你一个链表的头节点 head 和一个特定值 x ,请你对链表进行分隔,使得所有 小于 x 的节点都出现在 大于或等于 x 的节点之前。 你应当 保留 两个分区中每个节点的初始相对位置。 示例 1: 输入:head [1,4,3,2,5,2]…

MyScaleDB:SQL+向量驱动大模型和大数据新范式

大模型和 AI 数据库双剑合璧,成为大模型降本增效,大数据真正智能的制胜法宝。 大模型(LLM)的浪潮已经涌动一年多了,尤其是以 GPT-4、Gemini-1.5、Claude-3 等为代表的模型你方唱罢我登场,成为当之无愧的风口…

第五十三节 Java设计模式 - 工厂模式

Java设计模式 - 工厂模式 工厂模式是一种创建模式,因为此模式提供了更好的方法来创建对象。 在工厂模式中,我们创建对象而不将创建逻辑暴露给客户端。 例子 在以下部分中,我们将展示如何使用工厂模式创建对象。 由工厂模式创建的对象将是…

因果推断 | 潜在结果框架的基础知识

文章目录 1 引言2 框架描述2.1 问题定义2.2 数学表达式 3 实现方案3.1 随机实验数据3.2 一般数据 4 方案评估5 总结6 相关阅读 1 引言 在上一篇文章(运筹从业者也需要的因果推断入门:基础概念解析和体系化方法理解)中,已经对因果…

Linux下Palabos源码编译安装及使用

目录 软件介绍 基本依赖 其它可选依赖 一、源码下载 二、解压缩(通过方式1下载源码.zip格式) 三、编译安装 3.1 自带算例 ​编辑3.2 自行开发算例 四、简单使用 4.1 串行运行 4.2 并行运行 4.3 查看结果 软件介绍 Palabos是一款基于LBM&…

EXCEL怎样把筛选后含有公式的数据,复制粘贴到同一行的其它列?

自excel2003版之后,常规情况下,复制筛选后的数据,会忽略隐藏行,仅复制其筛选后的数据,粘贴则是粘贴到连续单元格区域,不管行是在显示状态还是隐藏状态。 一、初始数据: 二、题主的复制粘贴问题…

windows驱动开发-内核调度(一)

驱动层面的调度和同步一向是内核中比较困难的部分,和应用层不一样,内核位于系统进程下,所以它的调度和同步一旦出现纰漏,那会影响所有的程序,而内核并不具备对于这种情况下的纠错能力,没有异常手段能够让挂…

植物生态化学计量主要理论和假说

1 功能关联假说 描述化学计量特征与植物生长功能的关联, 主要包括: (1) 生长速率假说(Growth Rate Hypothesis) (Sterner & Elser, 2002): 随生长速率增加, 植物N:P和C:P呈降低趋势, 而P 含量呈增加趋势。该假说有助于理解植物生长速率的调控机制, 但受其他因素调控…

EPAI手绘建模APP动画编辑器、信息、工程图

④ 动画:打开关闭动画编辑器。APP中动画包含两个部分,动画编辑器和动画控制器。动画编辑器用来编辑动画。具体来说,选中一个模型后,给模型添加移动、旋转、缩放三种关键帧,不同的模型添加不同的关键帧,实现…

40.乐理基础-拍号-什么是一拍

拍: 首先 以Y分音符的时长为一拍 这一句话,然后拍是音乐中的时长单位,但这个时长单位有点特殊,它并不是完全绝对的某一个时间,而正是因为如此,所以不能用 秒 之类的,已经很确定很绝对的时间单位…

matlab例题大全

1.第1章 MATLAB系统环境 1.1 注:plot函数为画图函数。例plot(x1,y1,:,x2,y2,*); 1.2 注:root为求根函数。p为方程变量前面系数矩阵。 1.3 注: 2*x3y-1*z 2; 8*x2*y3*z 4; 45*x3*y9*z 23 求:x,y,z的…

关于位操作符的实际应用<C语言>

前言 位操作符在C语言初学阶段相对其他操作符来说,是一种难度比较大的操作符,且运用较少的一类操作符,但是位操作符并不是“一无是处”,合理运用的位操作符,在某些场景下可以优化算法,提高代码的执行效率&a…

PyQt5:Qt Designer使用重载的自定义类提升控件

1,以QPushButton举例 2,右击需要提升的控件,选择【提升为...】 3,添加自定义类,不用管 .h 的后缀,不影响使用。 4,完成 5,说明:自定义类的:__init__()方法…

基于STC12C5A60S2系列1T 8051单片机的IIC通信的0.96寸4针OLED12864显示16行点x16列点字模的功能

基于STC12C5A60S2系列1T 8051单片机的IIC通信的0.96寸4针OLED12864显示16行点x16列点字模的功能 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式及配置STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式介绍液晶显示器OLED12864简…

抖音直播间小风车怎么挂?直播间小风车跳转微信怎么开通!

抖音直播已经成为了一个非常受欢迎的直播平台,而在直播间引流也是用户非常关注的一个话题。而针对这个问题,抖音也提供了一种非常好用的小工具——小风车,可以帮助用户在直播间进行引流。那么,抖音直播间小风车怎么挂?…

记录几种排序算法

十种常见排序算法可以分类两大类别:比较类排序和非比较类排序。 常见的快速排序、归并排序、堆排序以及冒泡排序等都属于比较类排序算法。比较类排序是通过比较来决定元素间的相对次序,其时间复杂度不能突破 O(nlogn)。在冒泡排序之类的排序中&…

扩展学习|本体研究进展

文献来源: 王向前,张宝隆,李慧宗.本体研究综述[J].情报杂志,2016,35(06):163-170. 一、本体的定义 本体概念被引入人工智能、知识工程等领域后被赋予了新的含义。然而不同的专家学者对本体的理解不同,所给出的定义也有所差异。 人工智能领域的学者Neches(1991)等人对…

Docker Compose 部署若依前后端分离版

准备一台服务器 本次使用虚拟机,虚拟机系统 Ubuntu20.04,内存 4G,4核。 确保虚拟机能连接互联网。 Ubuntu20.04 安装 Docker 添加 Docker 的官方 GPG key: sudo apt-get update sudo apt-get install ca-certificates curl su…

初始面相对象

初始面向对象 类和对象的关系 类:对对象向上抽取出像的部分、公共的部分以此形成类,类就相当于一个模版。 对象:在某个模版下的具体的产物可以理解为对象,对象就是一个一个具体的实例,就相当于这个模版下具体的产品&…