从0开始基于transformer进行股价预测(pytorch版本)

目录

  • 数据阶段
    • 两个问题
    • 开始利用我们的代码进行切分
  • backbone网络
  • 训练
  • 效果 感觉还行,没有调参数。
  • 源码比较长,如果需要我后续会发(因为太长了!!)

数据阶段

!!!注意!!! , 本文不会讲原理,因为之前两篇文章已经讲过了,只会解释一些结构性问题,和思路问题。

所谓工欲善其事,必先利其器做量化分析的股价预测,完美必须要先把数据处理好。
那么本人的数据下载是在聚宽平台股票代码为601398的数据2014-3 到 2024-3年的默认数据。如何下载可以按照我的方式

在这里插入图片描述
进入研究环境后随便创建一个ipynb文件进行数据下载 ,运行以下代码就行

# 1.获取数据
data = get_price('601398.XSHG', start_date='2014-01-01', end_date='2024-01-01', frequency='daily', fields=None, skip_paused=False, fq='pre', panel=True)
# 2.保存数据
data.to_csv('data_沪深300/601398.XSHG(工商银行14-24).csv')

两个问题

1.为什么我们只需要用encoder部分去预测就行而不需要decoder部分?
答: 编码器用于将输入序列编码成一个上下文表示(contextual representation),然后解码器根据该上下文表示生成目标序列在时间序列预测任务中,我们不需要生成一个序列,而是预测单个或少量几个未来数据点。因此,编码器的上下文表示已经包含了足够的信息来进行预测,无需使用解码器。还有我觉得使用解码器的意思是,你用上一天的数据去预测下一天的数据,我感觉这样就没意思了,这和我们个人看有什么区别。而且对最后的结果也会造成不精准的效果。为什么这么说呢,你看解码器的mask编码部分应该可以理解了。
2.我们的维度为什么不是[batch, len, feature]? 因为这是pytorch要求,自己能实现的话,自己改吧。

开始利用我们的代码进行切分

我的思路用的是用五天的数据去预测下一天,数据集和测试及8/2分
但是我们要记住一点,就是我们必须要理解我们这么做的思路,就比如我们的特征有6列分别是,open,close,high,low,volume,money,我们可以通过训练得到我们想预测的某一特征。OK,我们这就开始。

说起数据分割里面的代码不难,最难的是
for i in range(len(X_CONVERT) - seq_length):
X_data.append(X_CONVERT[i:i+seq_length, :])
y_data.append(X_CONVERT[i+seq_length, 1])
你要知道我在干什么,就是用8成的数据集去预测得到我们所需要的train数据集和我们对应train数据集的label,举个例子就是,我们要炒菜,我们拿上原料后我们要知道炒的什么菜,那么菜单必须要知道。是吧,不然你炒完菜后说是红烧肉,但是没有菜单图片对比你怎么知道这是红烧肉?这也就是这一步的意义。

def split_data(batch_size,seq_length, pred_length, train_ratio):data_all = pd.read_csv(data_path)data_ha = []length = len(data_all)# 将数据转换为numpy数组,并添加到列表中for element in elements:data_element = data_all[element].values.astype(np.float32)data_element = data_element.reshape(length, 1)data_ha.append(data_element)X_hat = np.concatenate(data_ha, axis=1)X_CONVERT = torch.from_numpy(X_hat).float()X_CONVERT = X_CONVERT.flip(dims=[0])# 进行归一化min_val = np.min(X_hat, axis=0)max_val = np.max(X_hat, axis=0)X_normalized = (X_hat - min_val) / (max_val - min_val)X_CONVERT = torch.from_numpy(X_normalized).float()X_CONVERT = X_CONVERT.flip(dims=[0])#数据翻转# 划分训练集和验证集X_data = []y_data = []for i in range(len(X_CONVERT) - seq_length):#划分的时候是用8成的训练集去训练然后label是某##一列X_data.append(X_CONVERT[i:i+seq_length, :])y_data.append(X_CONVERT[i+seq_length, 1])X_data = torch.stack(X_data)y_data = torch.stack(y_data).squeeze(-1)print(X_data.shape, y_data.shape)dataset = TensorDataset(X_data, y_data)train_size = int(len(dataset) * train_ratio)val_size = len(dataset) - train_sizetrain_dataset, val_dataset = random_split(dataset, [train_size, val_size])train_loader = DataLoader(train_dataset, batch_size, shuffle=False)val_loader = DataLoader(val_dataset, batch_size, shuffle=False)return train_loader, val_loader,min_val, max_val

backbone网络

如其名,我们都知道这是这是transformer当然是用的transformer的结构。但是我们用,但是只用一部分,具体用什么部分开头说了,只用encoder

**但是具体操作起来的时候encoder里面的embadding部分我们需要修改,因为我们不是机器翻译,所以我们不需要把他变成词向量,我们时间序列数据,输入通常是连续的数值特征,使用线性层更直接地将这些数值特征映射到高维空间。并且我们的embadding嵌入层,适用于离散的输入,输出是固定维度的嵌入向量。而线性层,适用于连续的输入,可以灵活处理不同维度的输入特征,将其映射到高维表示。**具体看下面代码

class Encoder(nn.Module):def __init__(self):super(Encoder, self).__init__()self.src_emb = nn.Linear(feature, d_model)#这里替换了self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])def forward(self, enc_inputs):enc_outputs = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]enc_outputs = self.pos_emb(enc_outputs)  # [batch_size, src_len, d_model]enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]enc_self_attns = []for layer in self.layers:enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)enc_self_attns.append(enc_self_attn)return enc_outputs, enc_self_attns

当然其他的部分和我上一篇的一样,但是就是decode不要了,当然也可以换成其他结果,或者加个注意力机制

讲下各个参数


d_model = 512   # linnerer的输入维度 也就是字embedding的维度
d_ff = 2048     # 前向传播隐藏层维度
d_k = d_v = 64  # K(=Q), V的维度
n_layers = 6    # 有多少个encoder和decoder
n_heads = 8     # Multi-Head Attention设置为8
feature=6       # 输入特征维度

当然主体还是要看一下的最重要的是通过encoder后的维度转换比较繁琐,要和我们之前split的数据集得到的y_train一致这样才能计算损失


class Transformer(nn.Module):def __init__(self):super(Transformer, self).__init__()self.Encoder = Encoder()self.projection = nn.Linear(d_model, 1, bias=False)def forward(self, enc_inputs):  # enc_inputs: [batch_size, src_len, feature]enc_outputs, enc_self_attns = self.Encoder(enc_inputs)  # enc_outputs: [batch_size, src_len, d_model]dec_logits = self.projection(enc_outputs)  # dec_logits: [batch_size, src_len, 1]dec_logits = dec_logits.mean(dim=1)  # 将每个时间步的预测结果取平均,得到 [batch_size, 1]return dec_logits.squeeze(-1), enc_self_attns  # 输出 [batch_size]

训练

先解释参数

batch_size=64#批处理大小
seq_length=7#时间序列长度 也就是通过seq_length天预测后面pred_length天
pred_length=1#预测长度
train_ratio=0.8#训练集比例
epochs = 50 # 训练轮数
lr= 0.001 # 学习率
png_save_path="diytransformers/12.24transformer/picture"#所有的图片保存的地方
loss_history = []# 存储每个 epoch 的损失

训练代码很长,挺简单的


# 训练模型
for epoch in range(epochs):epoch_loss = 0y_pre = []y_true = []# 训练阶段for X, y in train_loader:X = X.float()  # 确保输入数据类型为float32y = y.float()  # 确保目标数据类型为float32outputs, enc_self_attns = model(X)# 计算损失,确保形状一致loss = criterion(outputs, y)epoch_loss += loss.item()optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()#转换我们的label和训练后得到的训练集的预测值 y_pre.append(outputs.detach())y_true.append(y.detach())avg_loss = epoch_loss / len(train_loader)loss_history.append(avg_loss)#获得最好的lossif avg_loss < best_loss:best_loss = avg_lossbest_epoch = epochbest_model_wts = copy.deepcopy(model.state_dict())torch.save(best_model_wts, path_train)y_pre_concat = torch.cat(y_pre, dim=0)y_true_concat = torch.cat(y_true, dim=0)# 计算并打印评估指标metrics = evaluate(y_pre_concat, y_true_concat, min_val, max_val)print(f'Epoch {epoch + 1}, Loss: {avg_loss:.6f}')# 可视化结果ht(y_true_concat.detach().cpu().numpy(), y_pre_concat.detach().cpu().numpy(), min_val, max_val,png_save_path)

最后是看我们的一些指标效果如何 比如这里我计算的mae,rmse,pcc等

# 加载最佳模型权重
model.load_state_dict(torch.load(train_over_path))# 测试模型并计算评估指标
test_metrics = test_model(model, val_loader, min_val, max_val)print(f'Test Metrics: {test_metrics}')

效果 感觉还行,没有调参数。

在这里插入图片描述

源码比较长,如果需要我后续会发(因为太长了!!)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/45070.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

多个uilabel添加同一个UITapGestureRecognizer对象,只有最后那个生效么

如果多个 UILabel 添加同一个 UITapGestureRecognizer 对象&#xff0c;确实只有最后一个 UILabel 会响应手势。这是因为一个手势识别器只能被添加到一个视图上&#xff0c;多次添加实际上是重新指定该识别器的视图目标。 要实现多个 UILabel 响应相同的手势&#xff0c;可以为…

还不懂 OOM ?详解内存溢出与内存泄漏区别!

内存溢出与内存泄漏 1. 内存溢出&#xff08;Out Of Memory&#xff0c;OOM&#xff09; 概念&#xff1a; 内存溢出是指程序在运行过程中&#xff0c;尝试申请的内存超过了系统所能提供的最大内存限制&#xff0c;并且垃圾收集器也无法提供更多的内存&#xff0c;导致程序无…

# Redis 入门到精通(一)数据类型(3)

Redis 入门到精通&#xff08;一&#xff09;数据类型&#xff08;3&#xff09; 一、redis 数据类型–set 类型介绍与基本操作 1、set 类型 新的存储需求: 存储大量的数据&#xff0c;在查询方面提供更高的效率。需要的存储结构: 能够保存大量的数据&#xff0c;高效的内部…

【爬虫】解析爬取的数据

目录 一、正则表达式1、常用元字符2、量词3、Re模块4、爬取豆瓣电影 二、Xpath1、Xpath解析Ⅰ、节点选择Ⅱ、路径表达式Ⅲ、常用函数 2、爬取豆瓣电影 解析数据&#xff0c;除了前面的BeautifulSoup库&#xff0c;还有正则表达式和Xpath两种方法。 一、正则表达式 正则表达式…

C++|智能指针

目录 引入 一、智能指针的使用及原理 1.1RAII 1.2智能指针原理 1.3智能指针发展 1.3.1std::auto_ptr 1.3.2std::unique_ptr 1.3.3std::shared_ptr 二、循环引用问题及解决方法 2.1循环引用 2.2解决方法 三、删除器 四、C11和boost中智能指针的关系 引入 回顾上…

谷粒商城学习笔记-19-快速开发-逆向生成所有微服务基本CRUD代码

文章目录 一&#xff0c;使用逆向工程步骤梳理1&#xff0c;修改逆向工程的application.yml配置2&#xff0c;修改逆向工程的generator.properties配置3&#xff0c;以Debug模式启动逆向工程4&#xff0c;使用逆向工程生成代码5&#xff0c;整合生成的代码到对应的模块中 二&am…

VPS拨号服务器:独享的高效与安全

在当今互联网高速发展的时代&#xff0c;虚拟私人服务器&#xff08;VPS&#xff09;已成为许多企业和个人用户托管网站、应用程序的首选。特别是带有拨号功能的VPS服务器&#xff0c;以其独特的优势受到广泛关注。本文将深入探讨VPS拨号服务器的独享特性&#xff0c;以及它如何…

Vue 使用Audio或AudioContext播放本地音频

使用Audio 第一种 使用标签方式 <audio src"./tests.mp3" ref"audio"></audio><el-button click"audioPlay()">播放Audio</el-button>audioPlay() {this.$refs.audio.currentTime 0;this.$refs.audio.play();// this.$…

c++方法

std::transform方法 std::transform 是 C 标准库算法中的一个非常有用的函数&#xff0c;它定义在头文件 中。这个函数用于将给定范围内的每个元素按照指定的操作进行转换&#xff0c;并将转换结果存储在另一个位置&#xff08;可以是原始范围的另一个容器&#xff0c;或者完全…

HarmonyOS应用开发前景及使用工具

HarmonyOS应用开发001 文章目录 前言前景一、技术特性二、使用工具1.项目目录结构 前言 学习之前&#xff0c;需要有一定的开发基础&#xff08;如&#xff1a;java、c#、c、WEB前端的一些了解)。 HarmonyOS开发使用的ArkTS&#xff0c;ArkTS是在TS的基础之上进行封装的&#…

外科休克病人的护理

一、引言 休克是外科常见的危急重症之一,它是由于机体遭受强烈的致病因素侵袭后,有效循环血量锐减、组织灌注不足所引起的以微循环障碍、细胞代谢紊乱和器官功能受损为特征的综合征。对于外科休克病人的护理,至关重要。 二、休克的分类 外科休克主要分为低血容量性休克(包括…

VMware Workstation 虚拟机网络配置为与主机使用同一网络

要将 VMware Workstation 虚拟机网络配置为与主机使用同一网络&#xff0c;我们需要将虚拟机的网络适配器设置为桥接模式。具体步骤如下&#xff1a; 配置 VMware Workstation 虚拟机网络为桥接模式 打开 VMware Workstation&#xff1a; 启动 VMware Workstation。 选择虚拟机…

博客网站目录网址导航自适应主题php源码

开源免费 博客屋网址导航自适应主题php源码v1.0是一款免费开源的PHP分类导航建站程序&#xff0c;源代码公开且无任何加密代码、安全有保障、无后门隐患。 系统稳定 内核安全稳定、PHPMYSQL/Sqlite架构、跨平台运行;版本自带ico接口集成&#xff0c;添加网站时&#xff0c;可自…

PostGIS2.4服务器编译安装

PostGIS的最新版本已经到3.5&#xff0c;但是还有一些国产数据库内核使用的旧版本的PostgreSQL&#xff0c;支持PostGIS2.4。但PostGIS2.4的版本已经在yum中找不到了&#xff0c;安装只能通过本地编译的方式。这里介绍一下如何在Centos7的系统上&#xff0c;编译部署PostGIS2.4…

实验场:在几分钟内使用 Bedrock Anthropic Models 和 Elasticsearch 进行 RAG 实验

作者&#xff1a;来自 Elastic Joe McElroy, Aditya Tripathi 我们最近发布了 Elasticsearch Playground&#xff0c;这是一个新的低代码界面&#xff0c;开发人员可以通过 A/B 测试 LLM、调整提示&#xff08;prompt&#xff09;和分块数据来迭代和构建生产 RAG 应用程序。今天…

Web3学习路线图,从入门到精通

前面我们聊了Web3的知识图谱&#xff0c;内容是相当的翔实&#xff0c;要从哪里入手可以快速的入门Web3&#xff0c;本篇就带你看看Web3的学习路线图&#xff0c;一步一步深入学习Web3。 这张图展示了Web3学习路线图&#xff0c;涵盖了区块链基础知识、开发方向、应用开发等内…

桥接模式案例

桥接模式&#xff08;Bridge Pattern&#xff09;是一种结构型设计模式&#xff0c;它将抽象部分与实现部分分离&#xff0c;使它们可以独立变化。桥接模式通过创 建一个桥接接口&#xff0c;将抽象部分和实现部分连接起来&#xff0c;从而实现两者的解耦。下面是一个详细的桥接…

接上一回C++:补继承漏洞+多态原理(带图详解)

引子&#xff1a;接上一回我们讲了继承的分类与六大默认函数&#xff0c;其实继承中的菱形继承是有一个大坑的&#xff0c;我们也要进入多态的学习了。 注意&#xff1a;我学会了&#xff0c;但是讲述上可能有一些不足&#xff0c;希望大家多多包涵 继承复习&#xff1a; 1&…

windows环境下基于3DSlicer 源代码编译搭建工程开发环境详细操作过程和中间关键错误解决方法说明

说明: 该文档适用于  首次/重新 搭建3D-Slicer工程环境  Clean up(非增量) 编译生成 1. 3D-slicer 软件介绍 (1)3D Slicer为处理MRI\CT等图像数据软件,可以实行基于MRI图像数据的目标分割、标记测量、坐标变换及三维重建等功能,其源于3D slicer 4.13.0-2022-01-19开…

duplicate key value violates unique constraint

duplicate key value violates unique constraint 遇到的问题 你在尝试向数据库表 goods 插入新记录时&#xff0c;收到了 duplicate key value violates unique constraint 的错误。尽管你确认数据库中没有与尝试插入的 id 相同的记录&#xff0c;但错误依旧存在。进一步的调…