train_encoder_decoder.py

train_encoder_decoder.py

from __future__ import print_function #为了确保代码同时兼容Python 2和Python 3版本中的print函数# 导入标准库和第三方库
import os.path #导入了Python的os.path模块,用于处理文件和目录路径
from os import path #从os模块中导入了path子模块,可以直接使用path来调用os.path中的函数import sys #导入了sys模块,用于系统相关的参数和函数
import math #导入了math模块,提供了数学运算函数
import numpy as np #导入了NumPy库,并使用np作为别名,NumPy是用于科学计算的基础库
import pandas as pd #导入了Pandas库,并使用pd作为别名,Pandas是用于数据分析的强大库# 导入深度学习相关库
import tensorflow as tf #导入了TensorFlow深度学习框架from keras import backend as K #导入了Keras的backend模块,并使用K作为别名,用于访问后端引擎的函数
from keras.models import Model #从Keras导入了Model类,用于定义神经网络模型
from keras.layers import LSTM, GRU, TimeDistributed, Input, Dense, RepeatVector #从Keras导入了LSTM、Input和Dense等神经网络层
from keras.callbacks import CSVLogger, EarlyStopping, TerminateOnNaN #从Keras导入了CSVLogger、EarlyStopping和TerminateOnNaN等回调函数,用于模型训练时的控制和记录
from keras import regularizers #从Keras导入了regularizers模块,用于正则化
from keras.optimizers import Adam #从Keras导入了Adam优化器,用于编译模型时指定优化算法# 导入其他功能模块
from functools import partial, update_wrapper #从Python标准库functools中导入了partial和update_wrapper函数,用于函数式编程中的功能扩展和包装# 这个函数的作用是创建一个部分应用(partial application)的函数,并保留原始函数的文档字符串等信息。
def wrapped_partial(func, *args, **kwargs):partial_func = partial(func, *args, **kwargs)update_wrapper(partial_func, func)return partial_func# 这是一个自定义的损失函数,计算加权的均方误差(Mean Squared Error),其中y_true是真实值,y_pred是预测值,weights是权重。
def weighted_mse(y_true, y_pred, weights):return K.mean(K.square(y_true - y_pred) * weights, axis=-1)# 这部分代码用于选择使用的GPU设备。它从命令行参数中获取一个整数值gpu,如果gpu小于3,则设置CUDA环境变量以指定使用的GPU设备
import os
gpu = int(sys.argv[-13])
if gpu < 3:os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152os.environ["CUDA_VISIBLE_DEVICES"]= "{}".format(gpu)from tensorflow.python.client import device_libprint(device_lib.list_local_devices())# 这部分代码获取了一系列命令行参数,并将它们分别赋值给变量 
# 这些参数可能包括数据集名称、训练的批次数量、训练周期数、学习率、正则化惩罚、Dropout率、耐心(用于Early Stopping)等 
imp = sys.argv[-1]
T = sys.argv[-2]
t0 = sys.argv[-3]
dataname = sys.argv[-4] 
nb_batches = sys.argv[-5]
nb_epochs = sys.argv[-6]
lr = float(sys.argv[-7])
penalty = float(sys.argv[-8])
dr = float(sys.argv[-9])
patience = sys.argv[-10]
n_hidden = int(sys.argv[-11])
hidden_activation = sys.argv[-12]# results_directory 是一个字符串,表示将要创建的结果文件夹路径。dataname 是之前从命令行参数中获取的数据集名称
# 如果这个文件夹路径不存在,就使用 os.makedirs 函数创建它。这个路径通常用于存储训练模型的结果或者日志
results_directory = 'results/encoder-decoder/{}'.format(dataname)if not os.path.exists(results_directory):os.makedirs(results_directory)# 定义了一个函数 create_model,用于创建、编译和返回一个循环神经网络(RNN)模型
def create_model(n_pre, n_post, nb_features, output_dim, lr, penalty, dr, n_hidden, hidden_activation):""" creates, compiles and returns a RNN model @param nb_features: the number of features in the model"""# 这里定义了两个输入层:inputs 是一个形状为 (n_pre, nb_features) 的输入张量,用于模型的主输入;weights_tensor 是一个形状相同的张量,用于传递权重或其他需要的信息inputs = Input(shape=(n_pre, nb_features), name="Inputs")  weights_tensor = Input(shape=(n_pre, nb_features), name="Weights") # 这里使用了两个 LSTM 层:lstm_1 是一个具有 n_hidden 个单元的 LSTM 层,应用了 dropout 和 recurrent_dropout,并且返回整个时间序列的输出。lstm_2 是一个相同的 LSTM 层,但它只返回最后一个时间步的输出。lstm_1 = LSTM(n_hidden, dropout=dr, recurrent_dropout=dr, activation=hidden_activation, return_sequences=True, name='LSTM_1')(inputs) # Encoderlstm_2 = LSTM(n_hidden, activation=hidden_activation, return_sequences=False, name='LSTM_2')(lstm_1) # Encoderrepeat = RepeatVector(n_post, name='Repeat')(lstm_2) # get the last output of the LSTM and repeats itgru_1 = GRU(n_hidden, activation=hidden_activation, return_sequences=True, name='Decoder')(repeat)  # Decoderoutput= TimeDistributed(Dense(output_dim, activation='linear', kernel_regularizer=regularizers.l2(penalty), name='Dense'), name='Outputs')(gru_1)model = Model([inputs, weights_tensor], output)# Compilecl = wrapped_partial(weighted_mse, weights=weights_tensor)model.compile(optimizer=Adam(lr=lr), loss=cl)print(model.summary()) return modeldef train_model(model, dataX, dataY, weights, nb_epoches, nb_batches):# Prepare model checkpoints and callbacksstopping = EarlyStopping(monitor='val_loss', patience=int(patience), min_delta=0, verbose=1, mode='min', restore_best_weights=True)csv_logger = CSVLogger('results/encoder-decoder/{}/training_log_{}_{}_{}_{}_{}_{}_{}_{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), separator=',', append=False)terminate = TerminateOnNaN()# Model fithistory = model.fit(x=[dataX,weights], y=dataY, batch_size=nb_batches, verbose=1,epochs=nb_epoches, callbacks=[stopping,csv_logger,terminate],validation_split=0.2)def test_model():n_post = int(1)n_pre =int(t0)-1seq_len = int(T)wx = np.array(pd.read_csv("data/{}-wx-{}.csv".format(dataname,imp)))print('raw wx shape', wx.shape)  wXC = []for i in range(seq_len-n_pre-n_post):wXC.append(wx[i:i+n_pre]) wXC = np.array(wXC)print('wXC shape:', wXC.shape)x = np.array(pd.read_csv("data/{}-x-{}.csv".format(dataname,imp)))print('raw x shape', x.shape) dXC, dYC = [], []for i in range(seq_len-n_pre-n_post):dXC.append(x[i:i+n_pre])dYC.append(x[i+n_pre:i+n_pre+n_post])dataXC = np.array(dXC)dataYC = np.array(dYC)print('dataXC shape:', dataXC.shape)print('dataYC shape:', dataYC.shape)nb_features = dataXC.shape[2]output_dim = dataYC.shape[2]# create and fit the encoder-decoder networkprint('creating model...')model = create_model(n_pre, n_post, nb_features, output_dim, lr, penalty, dr, n_hidden, hidden_activation)train_model(model, dataXC, dataYC, wXC, int(nb_epochs), int(nb_batches))# now testprint('Generate predictions on full training set')preds_train = model.predict([dataXC,wXC], batch_size=int(nb_batches), verbose=1)print('predictions shape =', preds_train.shape)preds_train = np.squeeze(preds_train)print('predictions shape (squeezed)=', preds_train.shape)print('Saving to results/encoder-decoder/{}/encoder-decoder-{}-train-{}-{}-{}-{}-{}-{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches))np.savetxt("results/encoder-decoder/{}/encoder-decoder-{}-train-{}-{}-{}-{}-{}-{}.csv".format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), preds_train, delimiter=",")print('Generate predictions on test set')wy = np.array(pd.read_csv("data/{}-wy-{}.csv".format(dataname,imp)))print('raw wy shape', wy.shape)  wY = []for i in range(seq_len-n_pre-n_post):wY.append(wy[i:i+n_pre]) # weights for outputswXT = np.array(wY)print('wXT shape:', wXT.shape)y = np.array(pd.read_csv("data/{}-y-{}.csv".format(dataname,imp)))print('raw y shape', y.shape)  dXT = []for i in range(seq_len-n_pre-n_post):dXT.append(y[i:i+n_pre]) # treated is inputdataXT = np.array(dXT)print('dataXT shape:', dataXT.shape)preds_test = model.predict([dataXT, wXT], batch_size=int(nb_batches), verbose=1)print('predictions shape =', preds_test.shape)preds_test = np.squeeze(preds_test)print('predictions shape (squeezed)=', preds_test.shape)print('Saving to results/encoder-decoder/{}/encoder-decoder-{}-test-{}-{}-{}-{}-{}-{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches))np.savetxt("results/encoder-decoder/{}/encoder-decoder-{}-test-{}-{}-{}-{}-{}-{}.csv".format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), preds_test, delimiter=",")def main():test_model()return 1if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/38592.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【场景题】数据库优化和接口优化——异步思想

理解 异步处理&#xff1a; 对于耗时的操作&#xff0c;可以考虑使用异步处理方式来提升接口的响应速度。用户可以在不阻塞当前操作的情况下&#xff0c;等待异步操作的结果。 异步处理在数据库优化中的应用 虽然数据库操作本身&#xff08;如查询、插入、更新等&#xff09…

Git 安装

目录 Git 安装 Git 安装 在使用 Git 前我们需要先安装 Git。Git 目前支持 Linux/Unix、Solaris、Mac 和 Windows 平台上运行。Git 各平台安装包下载地址为&#xff1a;http://git-scm.com/downloads 在 Linux 平台上安装&#xff08;包管理工具安装&#xff09; 首先&#xff0…

IIS在Windows上的搭建

&#x1f4d1;打牌 &#xff1a; da pai ge的个人主页 &#x1f324;️个人专栏 &#xff1a; da pai ge的博客专栏 ☁️宝剑锋从磨砺出&#xff0c;梅花香自苦寒来 目录 一 概念&#xff1a; 二网络…

深入理解C++中的锁

目录 1.基本互斥锁&#xff08;std::mutex&#xff09; 2.递归互斥锁&#xff08;std::recursive_mutex&#xff09; 3.带超时机制的互斥锁&#xff08;std::timed_mutex&#xff09; 4.带超时机制的递归互斥锁&#xff08;std::recursive_timed_mutex&#xff09; 5.共享…

【python脚本】批量检测sql延时注入

文章目录 前言批量检测sql延时注入工作原理脚本演示 前言 SQL延时注入是一种在Web应用程序中利用SQL注入漏洞的技术&#xff0c;当传统的基于错误信息或数据回显的注入方法不可行时&#xff0c;例如当Web应用进行了安全配置&#xff0c;不显示任何错误信息或敏感数据时&#x…

【TS】TypeScript 原始数据类型深度解析

&#x1f308;个人主页: 鑫宝Code &#x1f525;热门专栏: 闲话杂谈&#xff5c; 炫酷HTML | JavaScript基础 ​&#x1f4ab;个人格言: "如无必要&#xff0c;勿增实体" 文章目录 TypeScript 原始数据类型深度解析一、引言二、基础原始数据类型2.1 boolean2.2 …

苍穹外卖--sky-take-out(四)10-12

苍穹外卖--sky-take-out&#xff08;一&#xff09; 苍穹外卖--sky-take-out&#xff08;一&#xff09;-CSDN博客​编辑https://blog.csdn.net/kussm_/article/details/138614737?spm1001.2014.3001.5501https://blog.csdn.net/kussm_/article/details/138614737?spm1001.2…

Unity动画系统(2)

6.1 动画系统基础2-3_哔哩哔哩_bilibili p316 模型添加Animator组件 动画控制器 AnimatorController AnimatorController 可以通过代码控制动画速度 建立动画间的联系 bool值的设定 trigger p318 trigger点击的时候触发&#xff0c;如喊叫&#xff0c;开枪及换子弹等&#x…

在js中如何Json字符串格式不对,如何处理

如果 JSON 字符串格式不正确&#xff0c;解析它时会抛出异常&#xff0c;但我们可以尝试尽可能提取有效的信息。以下是一个方法&#xff0c;可以使用正则表达式和字符串操作来提取部分有效的 JSON 内容&#xff0c;即使整个字符串无法被 JSON.parse 完全解析。 示例代码如下&a…

错误 [WinError 10013] 以一种访问权限不允许的方式做了一个访问套接字的尝试 python ping

报错提示&#xff1a;错误 [WinError 10013] 以一种访问权限不允许的方式做了一个访问套接字的尝试 用python做了一个批量ping脚本&#xff0c;在windows专业版上没问题&#xff0c;但是到了windows服务器就出现这个报错 解决方法&#xff1a;右键 管理员身份运行 这个脚本 …

sql拉链表

1、定义&#xff1a;维护历史状态以及最新数据的一种表 2、使用场景 1、有一些表的数据量很大&#xff0c;比如一张用户表&#xff0c;大约1亿条记录&#xff0c;50个字段&#xff0c;这种表 2.表中的部分字段会被update更新操作&#xff0c;如用户联系方式&#xff0c;产品的…

compute和computeIfAbsent的区别和用法

compute和computeIfAbsent都是Map接口中的默认方法&#xff0c;用于在映射中进行键值对的计算和更新。它们的主要区别在于它们的行为和使用场景。 compute 方法 定义: V compute(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction);参数: k…

在 WebGPU 与 Vulkan 之间做出正确的选择(Making the Right Choice between WebGPU vs Vulkan)

在 WebGPU 与 Vulkan 之间做出正确的选择&#xff08;Making the Right Choice between WebGPU vs Vulkan&#xff09; WebGPU 和 Vulkan 之间的主要区别WebGPU 是什么&#xff1f;它适合谁使用&#xff1f;Vulkan 是什么&#xff1f;它适合谁使用&#xff1f;WebGPU 和 Vulkan…

修改CentOS7 yum源

修改CentOS默认yum源为阿里镜像源 备份系统自带yum源配置文件 mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backup 下载ailiyun的yum源配置文件 CentOS7 yum源如下&#xff1a; wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun…

AI领域最需要掌握的技术是什么?

在AI领域&#xff0c;掌握一系列核心技术和相关知识是非常重要的&#xff0c;以下是AI专业人士最需要掌握的一些关键技术&#xff1a; 1. **数学基础** - 线性代数&#xff1a;用于处理向量和矩阵&#xff0c;是机器学习和深度学习的基石。 - 微积分&#xff1a;用于理解函数的…

SpringBoot项目使用WebSocket提示Error creating bean with name ‘serverEndpointExporter‘

问题描述&#xff1a;WebSocket在Controller中正常工作&#xff0c;但是在之后使用SpringBootTest进行单元测试的时候&#xff0c;突然提示WebSocket的相关错误。 错误提示&#xff1a; Exception encountered during context initialization - cancelling refresh attempt: …

项目中的代码记录日常

项目中的代码记录日常 /// <summary> /// 修改任务状态 /// </summary> private void StartProcess21() {Process21Task new Thread(() >{while (CommonUtility.IsWorking){try{if (tPAgvTasksList.Count > 0){Parallel.ForEach(tPAgvTasksList, new Paral…

gitlab push的时候需要密码,你忘记了密码

情景: 忘记密码,且登入网页端gitlab的密码并不能在push的时候使用,应该两者是两个不同的密码 解决方法: 直接设置ssh密钥登入,不使用密码gitlab添加SSH密钥——查看本地密钥 & 生成ssh密钥_gitlab生成ssh密钥-CSDN博客

[OC]萝卜圈Python手动机器人脚本

这是给机器人设置的端口&#xff0c;对照用 代码 # #作者:溥哥’ ##机器人驱动主程序 #请在main中编写您自己的机器人驱动代码 import msvcrt def main():a"none"while True:key_input msvcrt.getch()akey_inputif abw:print(a)robot_drv.set_motors(1,40,2,40,3,…

uniapp学习笔记

uniapp官网地址&#xff1a;https://uniapp.dcloud.net.cn/ 学习源码&#xff1a;https://gitee.com/qingnian8/uniapp-ling_project.git 颜色网址&#xff1a;https://colordrop.io/ uniapp中如何获取导航中的路由信息&#xff1f; onLoad(e){console.log(e)console.log(e.w…