[开源] 基于transformer的时间序列预测模型python代码

分享一下基于transformer的时间序列预测模型python代码,给大家,记得点赞哦

#!/usr/bin/env python
# coding: 帅帅的笔者import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import time
import math
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler# Set random seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)# Hyperparameters
input_window = 10
output_window = 1
batch_size = 250
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 100
lr = 0.00005# Load data
df = pd.read_csv("data1.csv", parse_dates=["value"], index_col=[0], encoding='gbk')
data = np.array(df['value']).reshape(-1, 1)# Normalize data
scaler = MinMaxScaler(feature_range=(-1, 1))
data_normalized = scaler.fit_transform(data)# Split the data into train and validation sets
train_ratio = 0.828
train_size = int(len(data) * train_ratio)
val_size = len(data) - train_size
train_data_normalized = data_normalized[:train_size]
val_data_normalized = data_normalized[train_size:]# Define the Transformer model
class PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super(PositionalEncoding, self).__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):return x + self.pe[:x.size(0), :]class TransAm(nn.Module):def __init__(self, feature_size=250, num_layers=1, dropout=0.1):super(TransAm, self).__init__()self.model_type = 'Transformer'self.src_mask = Noneself.pos_encoder = PositionalEncoding(feature_size)self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)self.decoder = nn.Linear(feature_size, 1)self.init_weights()def init_weights(self):initrange = 0.1self.decoder.bias.data.zero_()self.decoder.weight.data.uniform_(-initrange, initrange)def forward(self, src):if self.src_mask is None or self.src_mask.size(0) != len(src):device = src.devicemask = self._generate_square_subsequent_mask(len(src)).to(device)self.src_mask = masksrc = self.pos_encoder(src)output = self.transformer_encoder(src, self.src_mask)output = self.decoder(output)return outputdef _generate_square_subsequent_mask(self, sz):mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))return mask# Create the dataset for the model
def create_inout_sequences(data, input_window, output_window):inout_seq = []length = len(data)for i in range(length - input_window - output_window):train_seq = data[i:i+input_window]train_label = data[i+input_window:i+input_window+output_window]inout_seq.append((train_seq, train_label))return inout_seqtrain_data = create_inout_sequences(train_data_normalized, input_window, output_window)
val_data = create_inout_sequences(val_data_normalized, input_window, output_window)# Train the model
model = TransAm().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)def train(train_data):model.train()total_loss = 0.for i in range(0, len(train_data) - 1, batch_size):data, targets = torch.stack([torch.tensor(item[0], dtype=torch.float32) for item in train_data[i:i+batch_size]]).to(device), torch.stack([torch.tensor(item[1], dtype=torch.float32) for item in train_data[i:i+batch_size]]).to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, targets)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()total_loss += loss.item()return total_loss / len(train_data)def validate(val_data):model.eval()total_loss = 0.with torch.no_grad():for i in range(0, len(val_data) - 1, batch_size):data, targets = torch.stack([torch.tensor(item[0], dtype=torch.float32) for item in val_data[i:i+batch_size]]).to(device), torch.stack([torch.tensor(item[1], dtype=torch.float32) for item in val_data[i:i+batch_size]]).to(device)output = model(data)loss = criterion(output, targets)total_loss += loss.item()return total_loss / len(val_data)best_val_loss = float("inf")
best_model = Nonefor epoch in range(1, epochs + 1):epoch_start_time = time.time()train_loss = train(train_data)val_loss = validate(val_data)scheduler.step()if val_loss < best_val_loss:best_val_loss = val_lossbest_model = model# Predict and denormalize the data
def predict(model, dataset):model.eval()predictions = []actuals = []with torch.no_grad():for i in range(len(dataset)):data, target = dataset[i]data = torch.tensor(data, dtype=torch.float32).to(device)output = model(data.unsqueeze(0))prediction = output.squeeze().cpu().numpy()predictions.append(prediction)actuals.append(target)return np.array(predictions), np.array(actuals)predictions, actuals = predict(best_model, val_data)
print("Predictions shape:", predictions.shape)
print("Actuals shape:", actuals.shape)predictions_denorm = scaler.inverse_transform(predictions)
actuals_denorm = scaler.inverse_transform(actuals.flatten().reshape(-1, 1))# Plot the results
plt.plot(predictions_denorm, label='Predictions')
plt.plot(actuals_denorm, label='Actuals')
plt.legend(['Predictions', 'Actuals'])
plt.xlabel('Timestep')
plt.ylabel('High')
plt.legend()
plt.show()

更多时间序列预测代码:时间序列预测算法全集合--深度学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/803771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java8新特性】二、函数式接口

这里写自定义目录标题 一、什么是函数式接口二、自定义函数式接口三、作为参数传递 Lambda 表达式四、四大内置核心函数式接口1、消费形接口2、供给形接口3、函数型接口4、断言形接口 一、什么是函数式接口 只包含一个抽象方法的接口&#xff0c;称为函数式接口。你可以通过 L…

【MATLAB高级编程】第二篇 | 元胞数组(cell)操作

【第二篇】元胞数组&#xff08;cell&#xff09;操作 1. 创建元胞数组cell2. 查看和修改cell内的元素值3. 高级操作: 可视化作图显示cell内的内容4. 把矩阵转换成单元数组5. 把单元数组转换成结构体变量 你好&#xff01; 欢迎进入 《MATLAB高级编程》 文章系列 &#xff0c;每…

postgresql uuid

示例数据库版本PG16&#xff0c;对于参照官方文档截图&#xff0c;可以在最上方切换到对应版本查看&#xff0c;相差不大。 方法一&#xff1a;自带函数 select gen_random_uuid(); 去掉四个斜杠&#xff0c;简化成32位 select replace(gen_random_uuid()::text, -, ); 官网介绍…

《前端面试题》- CSS - CSS选择器的优先级

行内样式1000 d选择器100 属性选择器、class或者伪类10 元素选择器&#xff0c;或者伪元素1 通配符0 参考网址&#xff1a;https://blog.csdn.net/jbj6568839z/article/details/113888600https://www.cnblogs.com/RenshuozZ/p/10327285.htmlhttps://www.cnblogs.com/zxjwlh/p/6…

搭建Grafana+Prometheus监控Spring Boot应用

Spring项目改造 maven依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-actuator</artifactId> </dependency><dependency><groupId>io.micrometer</groupId><artif…

​如何使用 ArcGIS Pro 制作带贴图建筑

对于用GIS软件制作三维建筑&#xff0c;很多时候都是制作的建筑体块&#xff0c;这里为大家介绍一下怎么使用 ArcGIS Pro 制作带贴图的建筑&#xff0c;希望能对你有所帮助。 数据来源 教程所使用的数据是从水经微图中下载的建筑数据&#xff0c;除了建筑数据&#xff0c;常见…

最简洁的Docker环境配置

Docker环境配置 Docker 是一个开源的应用容器引擎&#xff0c;让开发者可以打包他们的应用以及依赖包到一个可移植的镜像中&#xff0c;然后发布到任何流行的 Mac、Linux或Windows操作系统的机器上&#xff0c;也可以实现虚拟化。容器是完全使用沙箱机制&#xff0c;相互之间不…

AI大模型探索之路-应用篇2:Langchain框架ModelIO模块—数据交互的秘密武器

目录 前言 一、概述​​​​​​​ 二、Model 三、Prompt 五、Output Parsers 总结 前言 随着人工智能技术的不断进步&#xff0c;大模型的应用场景越来越广泛。LangChain框架作为一个创新的解决方案&#xff0c;专为处理大型语言模型的输入输出而设计。其中&#xff0c;…

redis主从复制详解

redis主从复制(replica) 1、是什么&#xff1f; 目录 redis主从复制(replica) 1、是什么&#xff1f; 2、能干嘛&#xff1f; 3、怎么玩&#xff1f; 4、案例演示 前置操作 &#x1f357;一主二仆 &#x1f355;薪火相传 &#x1f32d;反客为主 5、复制的原理和工作…

Flutter仿Boss-6.底部tab切换

效果 实现 图片资源采用boss包中的动画webp资源。Flutter采用Image加载webp动画。 遇到的问题 问题&#xff1a;Flutter加载webp再次加载无法再次播放动画问题 看如下代码&#xff1a; Image.asset(assets/images/xxx.webp,width: 40.w,height: 30.w, )运行的效果&#xf…

Vue3 + Vite 构建组件库发布到 npm

你有构建完组件库后&#xff0c;因为不知道如何发布到 npm 的烦恼吗&#xff1f;本教程手把手教你用 Vite 构建组件库发布到 npm 搭建项目 这里我们使用 Vite 初始化项目&#xff0c;执行命令&#xff1a; pnpm create vite my-vue-app --template vue这里以我的项目 vue3-xm…

GPT提示词分享 —— 中医

&#x1f449; 中医诊断涉及因素较多&#xff0c;治疗方案仅供参考&#xff0c;具体的方子需由医生提供。AI建议不能替代专业医疗意见&#xff0c;如果症状严重或持续&#xff0c;建议咨询专业医生。 我希望你能扮演一位既是老中医同时又是一个营养学专家&#xff0c;我讲描述…

Linux部署FTP服务器

文章目录 什么是FTP协议&#xff1f;Linux上部署FTP服务器安装FTP服务启动FTP服务编辑/etc/vsftpd.conf重新启动服务测试FTP服务 什么是FTP协议&#xff1f; FTP协议是一种基于TCP的文件传输协议&#xff0c;能够实现高效的文件上传和下载功能&#xff0c;最重要的是它能够使用…

LeetCode-322. 零钱兑换【广度优先搜索 数组 动态规划】

LeetCode-322. 零钱兑换【广度优先搜索 数组 动态规划】 题目描述&#xff1a;解题思路一&#xff1a;Python动态规划五部曲&#xff1a;定推初遍举【先遍历物品 后遍历背包】解题思路二&#xff1a;Python动态规划五部曲&#xff1a;定推初遍举【先遍历背包 后遍历物品】解题思…

组装机械狗电子玩具方案

这款机械狗玩具电子方案结合了现代电子技术和人工智能元素&#xff0c;旨在为用户提供一个高科技、互动性强的娱乐体验。通过不断的软件更新和硬件迭代&#xff0c;机械狗的功能将持续扩展。 一、功能特点&#xff1a; 1、自动巡游&#xff1a;机械狗能够自主在房间内巡游&am…

一文详解手机IP地址如何改变

在互联网时代&#xff0c;手机的IP地址扮演着至关重要的角色。它不仅是手机在网络中的标识&#xff0c;还关系到手机的网络连接、隐私保护以及访问权限等方面。然而&#xff0c;在某些情况下&#xff0c;我们可能需要改变手机的IP地址&#xff0c;以满足特定的需求或解决网络问…

OLAP在线实时 数据分析平台

随着业务的增长&#xff0c;精细化运营的提出&#xff0c;产品对数据部门提出了更高的要求&#xff0c;包括需要对实时数据进行查询分析&#xff0c;快速调整运营策略&#xff1b;对小部分人群做 AB 实验&#xff0c;验证新功能的有效性&#xff1b;减少数据查询时间&#xff0…

逆向案例十七(1)——webpack加如果之前发送公钥如何定位参数,基于中国五矿

网址链接&#xff1a;中国五矿集团有限公司采购电子商务平台 定位到数据包&#xff0c;载荷中param是一个加密参数。 每一个数据包前都有一个public返回公钥。 点击查看返回的数据 如何定位参数加密位置&#xff1f; 复制公钥包url的后面&#xff0c;进行搜索 &#xff0c;查…

nodejs fs http express express-session jwt mysql mongoose

文件fs模块 读取文件内容 fs.readFile(./file/fs-01.txt, utf8, (err, data) > {if (err) {console.error(err)return}console.log(data) })写入内容到文件 const fs require(fs);const filePath "./file/output.txt";fs.writeFile(filePath, "Hello Wor…

[C++][算法基础]字符串统计(Trie树)

维护一个字符串集合&#xff0c;支持两种操作&#xff1a; I x 向集合中插入一个字符串 x&#xff1b;Q x 询问一个字符串在集合中出现了多少次。 共有 N 个操作&#xff0c;所有输入的字符串总长度不超过 &#xff0c;字符串仅包含小写英文字母。 输入格式 第一行包含整数…