从0开始基于transformer进行股价预测(pytorch版本)

目录

  • 数据阶段
    • 两个问题
    • 开始利用我们的代码进行切分
  • backbone网络
  • 训练
  • 效果 感觉还行,没有调参数。
  • 源码比较长,如果需要我后续会发(因为太长了!!)

数据阶段

!!!注意!!! , 本文不会讲原理,因为之前两篇文章已经讲过了,只会解释一些结构性问题,和思路问题。

所谓工欲善其事,必先利其器做量化分析的股价预测,完美必须要先把数据处理好。
那么本人的数据下载是在聚宽平台股票代码为601398的数据2014-3 到 2024-3年的默认数据。如何下载可以按照我的方式

在这里插入图片描述
进入研究环境后随便创建一个ipynb文件进行数据下载 ,运行以下代码就行

# 1.获取数据
data = get_price('601398.XSHG', start_date='2014-01-01', end_date='2024-01-01', frequency='daily', fields=None, skip_paused=False, fq='pre', panel=True)
# 2.保存数据
data.to_csv('data_沪深300/601398.XSHG(工商银行14-24).csv')

两个问题

1.为什么我们只需要用encoder部分去预测就行而不需要decoder部分?
答: 编码器用于将输入序列编码成一个上下文表示(contextual representation),然后解码器根据该上下文表示生成目标序列在时间序列预测任务中,我们不需要生成一个序列,而是预测单个或少量几个未来数据点。因此,编码器的上下文表示已经包含了足够的信息来进行预测,无需使用解码器。还有我觉得使用解码器的意思是,你用上一天的数据去预测下一天的数据,我感觉这样就没意思了,这和我们个人看有什么区别。而且对最后的结果也会造成不精准的效果。为什么这么说呢,你看解码器的mask编码部分应该可以理解了。
2.我们的维度为什么不是[batch, len, feature]? 因为这是pytorch要求,自己能实现的话,自己改吧。

开始利用我们的代码进行切分

我的思路用的是用五天的数据去预测下一天,数据集和测试及8/2分
但是我们要记住一点,就是我们必须要理解我们这么做的思路,就比如我们的特征有6列分别是,open,close,high,low,volume,money,我们可以通过训练得到我们想预测的某一特征。OK,我们这就开始。

说起数据分割里面的代码不难,最难的是
for i in range(len(X_CONVERT) - seq_length):
X_data.append(X_CONVERT[i:i+seq_length, :])
y_data.append(X_CONVERT[i+seq_length, 1])
你要知道我在干什么,就是用8成的数据集去预测得到我们所需要的train数据集和我们对应train数据集的label,举个例子就是,我们要炒菜,我们拿上原料后我们要知道炒的什么菜,那么菜单必须要知道。是吧,不然你炒完菜后说是红烧肉,但是没有菜单图片对比你怎么知道这是红烧肉?这也就是这一步的意义。

def split_data(batch_size,seq_length, pred_length, train_ratio):data_all = pd.read_csv(data_path)data_ha = []length = len(data_all)# 将数据转换为numpy数组,并添加到列表中for element in elements:data_element = data_all[element].values.astype(np.float32)data_element = data_element.reshape(length, 1)data_ha.append(data_element)X_hat = np.concatenate(data_ha, axis=1)X_CONVERT = torch.from_numpy(X_hat).float()X_CONVERT = X_CONVERT.flip(dims=[0])# 进行归一化min_val = np.min(X_hat, axis=0)max_val = np.max(X_hat, axis=0)X_normalized = (X_hat - min_val) / (max_val - min_val)X_CONVERT = torch.from_numpy(X_normalized).float()X_CONVERT = X_CONVERT.flip(dims=[0])#数据翻转# 划分训练集和验证集X_data = []y_data = []for i in range(len(X_CONVERT) - seq_length):#划分的时候是用8成的训练集去训练然后label是某##一列X_data.append(X_CONVERT[i:i+seq_length, :])y_data.append(X_CONVERT[i+seq_length, 1])X_data = torch.stack(X_data)y_data = torch.stack(y_data).squeeze(-1)print(X_data.shape, y_data.shape)dataset = TensorDataset(X_data, y_data)train_size = int(len(dataset) * train_ratio)val_size = len(dataset) - train_sizetrain_dataset, val_dataset = random_split(dataset, [train_size, val_size])train_loader = DataLoader(train_dataset, batch_size, shuffle=False)val_loader = DataLoader(val_dataset, batch_size, shuffle=False)return train_loader, val_loader,min_val, max_val

backbone网络

如其名,我们都知道这是这是transformer当然是用的transformer的结构。但是我们用,但是只用一部分,具体用什么部分开头说了,只用encoder

**但是具体操作起来的时候encoder里面的embadding部分我们需要修改,因为我们不是机器翻译,所以我们不需要把他变成词向量,我们时间序列数据,输入通常是连续的数值特征,使用线性层更直接地将这些数值特征映射到高维空间。并且我们的embadding嵌入层,适用于离散的输入,输出是固定维度的嵌入向量。而线性层,适用于连续的输入,可以灵活处理不同维度的输入特征,将其映射到高维表示。**具体看下面代码

class Encoder(nn.Module):def __init__(self):super(Encoder, self).__init__()self.src_emb = nn.Linear(feature, d_model)#这里替换了self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])def forward(self, enc_inputs):enc_outputs = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]enc_outputs = self.pos_emb(enc_outputs)  # [batch_size, src_len, d_model]enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]enc_self_attns = []for layer in self.layers:enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)enc_self_attns.append(enc_self_attn)return enc_outputs, enc_self_attns

当然其他的部分和我上一篇的一样,但是就是decode不要了,当然也可以换成其他结果,或者加个注意力机制

讲下各个参数


d_model = 512   # linnerer的输入维度 也就是字embedding的维度
d_ff = 2048     # 前向传播隐藏层维度
d_k = d_v = 64  # K(=Q), V的维度
n_layers = 6    # 有多少个encoder和decoder
n_heads = 8     # Multi-Head Attention设置为8
feature=6       # 输入特征维度

当然主体还是要看一下的最重要的是通过encoder后的维度转换比较繁琐,要和我们之前split的数据集得到的y_train一致这样才能计算损失


class Transformer(nn.Module):def __init__(self):super(Transformer, self).__init__()self.Encoder = Encoder()self.projection = nn.Linear(d_model, 1, bias=False)def forward(self, enc_inputs):  # enc_inputs: [batch_size, src_len, feature]enc_outputs, enc_self_attns = self.Encoder(enc_inputs)  # enc_outputs: [batch_size, src_len, d_model]dec_logits = self.projection(enc_outputs)  # dec_logits: [batch_size, src_len, 1]dec_logits = dec_logits.mean(dim=1)  # 将每个时间步的预测结果取平均,得到 [batch_size, 1]return dec_logits.squeeze(-1), enc_self_attns  # 输出 [batch_size]

训练

先解释参数

batch_size=64#批处理大小
seq_length=7#时间序列长度 也就是通过seq_length天预测后面pred_length天
pred_length=1#预测长度
train_ratio=0.8#训练集比例
epochs = 50 # 训练轮数
lr= 0.001 # 学习率
png_save_path="diytransformers/12.24transformer/picture"#所有的图片保存的地方
loss_history = []# 存储每个 epoch 的损失

训练代码很长,挺简单的


# 训练模型
for epoch in range(epochs):epoch_loss = 0y_pre = []y_true = []# 训练阶段for X, y in train_loader:X = X.float()  # 确保输入数据类型为float32y = y.float()  # 确保目标数据类型为float32outputs, enc_self_attns = model(X)# 计算损失,确保形状一致loss = criterion(outputs, y)epoch_loss += loss.item()optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()#转换我们的label和训练后得到的训练集的预测值 y_pre.append(outputs.detach())y_true.append(y.detach())avg_loss = epoch_loss / len(train_loader)loss_history.append(avg_loss)#获得最好的lossif avg_loss < best_loss:best_loss = avg_lossbest_epoch = epochbest_model_wts = copy.deepcopy(model.state_dict())torch.save(best_model_wts, path_train)y_pre_concat = torch.cat(y_pre, dim=0)y_true_concat = torch.cat(y_true, dim=0)# 计算并打印评估指标metrics = evaluate(y_pre_concat, y_true_concat, min_val, max_val)print(f'Epoch {epoch + 1}, Loss: {avg_loss:.6f}')# 可视化结果ht(y_true_concat.detach().cpu().numpy(), y_pre_concat.detach().cpu().numpy(), min_val, max_val,png_save_path)

最后是看我们的一些指标效果如何 比如这里我计算的mae,rmse,pcc等

# 加载最佳模型权重
model.load_state_dict(torch.load(train_over_path))# 测试模型并计算评估指标
test_metrics = test_model(model, val_loader, min_val, max_val)print(f'Test Metrics: {test_metrics}')

效果 感觉还行,没有调参数。

在这里插入图片描述

源码比较长,如果需要我后续会发(因为太长了!!)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/45070.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

还不懂 OOM ?详解内存溢出与内存泄漏区别!

内存溢出与内存泄漏 1. 内存溢出&#xff08;Out Of Memory&#xff0c;OOM&#xff09; 概念&#xff1a; 内存溢出是指程序在运行过程中&#xff0c;尝试申请的内存超过了系统所能提供的最大内存限制&#xff0c;并且垃圾收集器也无法提供更多的内存&#xff0c;导致程序无…

# Redis 入门到精通(一)数据类型(3)

Redis 入门到精通&#xff08;一&#xff09;数据类型&#xff08;3&#xff09; 一、redis 数据类型–set 类型介绍与基本操作 1、set 类型 新的存储需求: 存储大量的数据&#xff0c;在查询方面提供更高的效率。需要的存储结构: 能够保存大量的数据&#xff0c;高效的内部…

【爬虫】解析爬取的数据

目录 一、正则表达式1、常用元字符2、量词3、Re模块4、爬取豆瓣电影 二、Xpath1、Xpath解析Ⅰ、节点选择Ⅱ、路径表达式Ⅲ、常用函数 2、爬取豆瓣电影 解析数据&#xff0c;除了前面的BeautifulSoup库&#xff0c;还有正则表达式和Xpath两种方法。 一、正则表达式 正则表达式…

C++|智能指针

目录 引入 一、智能指针的使用及原理 1.1RAII 1.2智能指针原理 1.3智能指针发展 1.3.1std::auto_ptr 1.3.2std::unique_ptr 1.3.3std::shared_ptr 二、循环引用问题及解决方法 2.1循环引用 2.2解决方法 三、删除器 四、C11和boost中智能指针的关系 引入 回顾上…

谷粒商城学习笔记-19-快速开发-逆向生成所有微服务基本CRUD代码

文章目录 一&#xff0c;使用逆向工程步骤梳理1&#xff0c;修改逆向工程的application.yml配置2&#xff0c;修改逆向工程的generator.properties配置3&#xff0c;以Debug模式启动逆向工程4&#xff0c;使用逆向工程生成代码5&#xff0c;整合生成的代码到对应的模块中 二&am…

VMware Workstation 虚拟机网络配置为与主机使用同一网络

要将 VMware Workstation 虚拟机网络配置为与主机使用同一网络&#xff0c;我们需要将虚拟机的网络适配器设置为桥接模式。具体步骤如下&#xff1a; 配置 VMware Workstation 虚拟机网络为桥接模式 打开 VMware Workstation&#xff1a; 启动 VMware Workstation。 选择虚拟机…

实验场:在几分钟内使用 Bedrock Anthropic Models 和 Elasticsearch 进行 RAG 实验

作者&#xff1a;来自 Elastic Joe McElroy, Aditya Tripathi 我们最近发布了 Elasticsearch Playground&#xff0c;这是一个新的低代码界面&#xff0c;开发人员可以通过 A/B 测试 LLM、调整提示&#xff08;prompt&#xff09;和分块数据来迭代和构建生产 RAG 应用程序。今天…

Web3学习路线图,从入门到精通

前面我们聊了Web3的知识图谱&#xff0c;内容是相当的翔实&#xff0c;要从哪里入手可以快速的入门Web3&#xff0c;本篇就带你看看Web3的学习路线图&#xff0c;一步一步深入学习Web3。 这张图展示了Web3学习路线图&#xff0c;涵盖了区块链基础知识、开发方向、应用开发等内…

接上一回C++:补继承漏洞+多态原理(带图详解)

引子&#xff1a;接上一回我们讲了继承的分类与六大默认函数&#xff0c;其实继承中的菱形继承是有一个大坑的&#xff0c;我们也要进入多态的学习了。 注意&#xff1a;我学会了&#xff0c;但是讲述上可能有一些不足&#xff0c;希望大家多多包涵 继承复习&#xff1a; 1&…

windows环境下基于3DSlicer 源代码编译搭建工程开发环境详细操作过程和中间关键错误解决方法说明

说明: 该文档适用于  首次/重新 搭建3D-Slicer工程环境  Clean up(非增量) 编译生成 1. 3D-slicer 软件介绍 (1)3D Slicer为处理MRI\CT等图像数据软件,可以实行基于MRI图像数据的目标分割、标记测量、坐标变换及三维重建等功能,其源于3D slicer 4.13.0-2022-01-19开…

OS Copilot测评

1.按照第一步管理重置密码时报错了,搞不懂为啥?本来应该跳转到给的那个实例的,我的没跳过去 2.下一步重置密码的很丝滑没问题 3安全组新增入库22没问题 很方便清晰 4.AccessKey 还能进行预警提示 5.远程连接,网速还是很快,一点没卡,下载很棒 6.替换的时候我没有替换<>括…

【JavaEE】网络编程——UDP

&#x1f921;&#x1f921;&#x1f921;个人主页&#x1f921;&#x1f921;&#x1f921; &#x1f921;&#x1f921;&#x1f921;JavaEE专栏&#x1f921;&#x1f921;&#x1f921; 文章目录 1.数据报套接字(UDP)1.1特点1.2编码1.2.1DatagramSocket1.2.2DatagramPacket…

Spring Cloud Alibaba AI 介绍及使用

一、Spring Cloud Alibaba AI 介绍 Spring AI 是 Spring 官方社区项目&#xff0c;旨在简化 Java AI 应用程序开发&#xff0c;让 Java 开发者像使用 Spring 开发普通应用一样开发 AI 应用。而 Spring Cloud Alibaba AI 是阿里以 Spring AI 为基础&#xff0c;并在此基础上提供…

dive deeper into tensor:从底层开始学习tensor

inspired by karpathy/micrograd: A tiny scalar-valued autograd engine and a neural net library on top of it with PyTorch-like API (github.com)and Taking PyTorch for Granted | wh (nrehiew.github.io). 这属于karpathy的karpathy/nn-zero-to-hero: Neural Networks…

阐述 C 语言中的参数传递机制

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01; &#x1f4d9;C 语言百万年薪修炼课程 通俗易懂&#xff0c;深入浅出&#xff0c;匠心打磨&#xff0c;死磕细节&#xff0c;6年迭代&#xff0c;看过的人都说好。 文章目…

多表查询sql

概述&#xff1a;项目开发中,在进行数据库表结构设计时,会根据业务需求及业务模块之间的关系,分析并设计表结构,由于业务之间相互关联,所以各个表结构之间也存在着各种联系&#xff0c;分为三种&#xff1a; 一对多多对多一对一 一、多表关系 一对多 案例&#xff1a;部门与…

【PowerShell】-1-快速熟悉并使用PowerShell

目录 PowerShell是什么&#xff1f;和CMD的区别&#xff1f; PowerShell的演变 自动化IT管理任务 一些名词 详尽的PowerShell开始之路 1.打开PowerShell&#xff1a; 2.基本命令&#xff1a; &#xff08;1&#xff09;Get-Process &#xff08;2&#xff09;变量赋值…

【核心笔记】Java入门到起飞,小白都能看懂的Java教程 (五)——数组

一 数组的定义和初始化 定义数组 数据类型[] 数组名&#xff1b;例 int[] arr; 数据类型 数组名[]&#xff1b;例 int arr[]; 数组初始化 数据类型[] 数组名 new 数据类型[] {值}&#xff1b;例 int[] arr new int[] {1,2,3}; &#xff08;简化形式&#xff09;数据类型[] 数…

超赞!只需粘贴复制超赞,视频快速转换成文章

大家好&#xff01;我是闷声轻创&#xff01;是否还在为撰写高质量的文章而熬夜奋战&#xff1f;今天&#xff0c;我要给你们带来一个超级棒的消息——视频变文章的神奇工具&#xff0c;让你的创作之路从此不再艰辛&#xff01; 视频素材的宝藏——油管&#xff08;YTB&#xf…

2024年了还在学pytestday1

1、按照博主的说法&#xff0c;提出疑问&#xff1a;应该在电脑本地终端安装还是在pythoncharm终端安装&#xff1f; ------在pythoncharm终端安装就行 避免老是忘记&#xff0c;还是记下来比较好。 2、在公司安装不成功&#xff0c;换豆瓣源也不行&#xff0c;连接手机热点尝…