逐行分析Transformer的程序代码,最后免费附上该代码!!

在这里插入图片描述

1. 代码详细解释

1. 第一段代码

这段代码首先定义了一些参数,包括编码器个数、输入维度、句子长度、词嵌入维度等。然后它保存了这些超参数到指定路径。接着,它加载训练和验证数据集,并创建了对应的数据加载器。之后,它定义了一个模型,使用了一个叫做DSCTransformer的模型,以及交叉熵损失函数和Adam 优化器。最后,它将模型移动到可用的设备(如果有 GPU 则移动到 GPU,否则移动到 CPU)。

def train(model_save_path, train_result_path, val_result_path, hp_save_path, epochs=100):#定义参数N = 4  # 编码器个数input_dim = 1024  # 输入维度seq_len = 16  # 句子长度d_model = 64  # 词嵌入维度d_ff = 256  # 全连接层维度head = 4  # 注意力头数dropout = 0.1  # Dropout 比率lr = 3E-5  # 学习率batch_size = 64  # 批大小# 保存超参数hyper_parameters = {'任务编码器堆叠数: ': '{}'.format(N),'全连接层维度: ': '{}'.format(d_ff),'任务注意力头数: ': '{}'.format(head),'dropout: ': '{}'.format(dropout),'学习率: ': '{}'.format(lr),'batch_size: ': '{}'.format(batch_size)}fs = open(hp_save_path, 'w')  # 打开文件以保存超参数fs.write(str(hyper_parameters))  # 将超参数写入文件fs.close()  # 关闭文件# 加载数据train_path = r'.\data\train\train.csv'  # 训练数据路径val_path = r'.\data\val\val.csv'  # 验证数据路径train_dataset = MyDataset(train_path, 'fd')  # 加载训练数据集val_dataset = MyDataset(val_path, 'fd')  # 加载验证数据集train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)  # 创建训练数据加载器val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True, drop_last=True)  # 创建验证数据加载器# 定义模型model = DSCTransformer(input_dim=input_dim, num_classes=10, dim=d_model, depth=N,heads=head, mlp_dim=d_ff, dim_head=d_model, emb_dropout=dropout, dropout=dropout)  # 初始化模型criterion = nn.CrossEntropyLoss()  # 定义损失函数params = [p for p in model.parameters() if p.requires_grad]  # 获取模型参数optimizer = optim.Adam(params, lr=lr)  # 定义优化器device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # 判断是否有可用的 GPUprint("using {} device.".format(device))  # 打印使用的设备model.to(device)  # 将模型移动到对应的设备(GPU 或 CPU)

2. 第二段代码

这段代码是一个训练循环,它用于在每个训练周期(epoch)中训练模型,并在每个周期结束后评估模型的性能。在每个训练周期中,代码首先使用模型在训练数据上进行训练,然后使用模型在验证数据上进行验证,并打印出每个周期的训练损失、训练准确率、验证损失和验证准确率。

best_acc_fd = 0.0  # 初始化最佳准确率为0
train_result = []  # 记录训练结果
result_train_loss = []  # 记录训练损失
result_train_acc = []  # 记录训练准确率
val_result = []  # 记录验证结果
result_val_loss = []  # 记录验证损失
result_val_acc = []  # 记录验证准确率# 训练循环
for epoch in range(epochs):  # 遍历每个训练周期# traintrain_loss = []  # 用于记录每个批次的训练损失train_acc = []  # 用于记录每个批次的训练准确率model.train()  # 将模型设置为训练模式train_bar = tqdm(train_loader)  # 创建一个进度条,用于显示训练进度for datas, labels in train_bar:  # 遍历训练数据加载器的每个批次optimizer.zero_grad()  # 梯度清零outputs = model(datas.float().to(device))  # 前向传播loss = criterion(outputs, labels.type(torch.LongTensor).to(device))  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新模型参数# 计算准确率acc = (outputs.argmax(dim=-1) == labels.to(device)).float().mean()# 记录损失和准确率train_loss.append(loss.item())train_acc.append(acc)# 更新进度条的显示信息train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss.item())# valmodel.eval()  # 将模型设置为评估模式valid_loss = []  # 用于记录每个批次的验证损失valid_acc = []  # 用于记录每个批次的验证准确率val_bar = tqdm(val_loader)  # 创建一个进度条,用于显示验证进度for datas, labels in val_bar:  # 遍历验证数据加载器的每个批次with torch.no_grad():  # 禁止梯度计算outputs = model(datas.float().to(device))  # 前向传播loss = criterion(outputs, labels.type(torch.LongTensor).to(device))  # 计算损失# 计算准确率acc = (outputs.argmax(dim=-1) == labels.to(device)).float().mean()# 记录损失和准确率valid_loss.append(loss.item())valid_acc.append(acc)# 更新进度条的显示信息val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)# 打印每个周期的训练和验证损失以及准确率print(f"[{epoch + 1:02d}/{epochs:02d}] train loss = "f"{sum(train_loss) / len(train_loss):.5f}, train acc = {sum(train_acc) / len(train_acc):.5f}", end="  ")print(f"valid loss = {sum(valid_loss) / len(valid_loss):.5f}, valid acc = {sum(valid_acc) / len(valid_acc):.5f}")

3. 第三段代码

这段代码是用于记录训练和验证结果,并保存这些结果到文件中。

result_train_loss.append(sum(train_loss) / len(train_loss))
result_train_acc.append((sum(train_acc) / len(train_acc)).item())
result_val_loss.append(sum(valid_loss) / len(valid_loss))
result_val_acc.append((sum(valid_acc) / len(valid_acc)).item())

这几行代码分别计算并记录了每个训练周期中的平均训练损失、平均训练准确率、平均验证损失和平均验证准确率。


if best_acc_fd <= sum(valid_acc) / len(valid_acc):best_acc_fd = sum(valid_acc) / len(valid_acc)torch.save(model.state_dict(), model_save_path)

这段代码用于更新最佳验证准确率并保存最佳模型参数。如果当前的验证准确率大于之前记录的最佳准确率,则更新最佳准确率为当前准确率,并保存当前模型参数到指定路径。


train_result.append(result_train_loss)
train_result.append(result_train_acc)
val_result.append(result_val_loss)
val_result.append(result_val_acc)

这里将每个训练和验证结果存储到对应的列表中。


np.savetxt(train_result_path, np.array(train_result), fmt='%.5f', delimiter=',')
np.savetxt(val_result_path, np.array(val_result), fmt='%.5f', delimiter=',')

最后,将训练和验证结果保存到文件中。它使用了 NumPy 库的 savetxt 函数将列表转换为数组,并将数组保存到指定路径的文件中。


2.附上所有代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from model import DSCTransformer
from tqdm import tqdm
from data_set import MyDataset
from torch.utils.data import DataLoaderdef train(model_save_path, train_result_path, val_result_path, hp_save_path, epochs=100):#定义参数N = 4 #编码器个数input_dim = 1024seq_len = 16 #句子长度d_model = 64 #词嵌入维度d_ff = 256 #全连接层维度head = 4 #注意力头数dropout = 0.1lr = 3E-5 #学习率batch_size = 64#保存超参数hyper_parameters = {'任务编码器堆叠数: ': '{}'.format(N),'全连接层维度: ': '{}'.format(d_ff),'任务注意力头数: ': '{}'.format(head),'dropout: ': '{}'.format(dropout),'学习率: ': '{}'.format(lr),'batch_size: ': '{}'.format(batch_size)}fs = open(hp_save_path, 'w')fs.write(str(hyper_parameters))fs.close()#加载数据train_path = r'.\data\train\train.csv'val_path = r'.\data\val\val.csv'train_dataset = MyDataset(train_path, 'fd')val_dataset = MyDataset(val_path, 'fd')train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True, drop_last=True)#定义模型model = DSCTransformer(input_dim=input_dim, num_classes=10, dim=d_model, depth=N,heads=head, mlp_dim=d_ff, dim_head=d_model, emb_dropout=dropout, dropout=dropout)criterion = nn.CrossEntropyLoss()params = [p for p in model.parameters() if p.requires_grad]optimizer = optim.Adam(params, lr=lr)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')print("using {} device.".format(device))model.to(device)best_acc_fd = 0.0train_result = []result_train_loss = []result_train_acc = []val_result = []result_val_loss = []result_val_acc = []#训练for epoch in range(epochs):#traintrain_loss = []train_acc = []model.train()train_bar = tqdm(train_loader)for datas, labels in train_bar:optimizer.zero_grad()outputs = model(datas.float().to(device))loss = criterion(outputs, labels.type(torch.LongTensor).to(device))loss.backward()optimizer.step()# torch.argmax(dim=-1), 求每一行最大的列序号acc = (outputs.argmax(dim=-1) == labels.to(device)).float().mean()# Record the loss and accuracytrain_loss.append(loss.item())train_acc.append(acc)train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss.item())#valmodel.eval()valid_loss = []valid_acc = []val_bar = tqdm(val_loader)for datas, labels in val_bar:with torch.no_grad():outputs = model(datas.float().to(device))loss = criterion(outputs, labels.type(torch.LongTensor).to(device))acc = (outputs.argmax(dim=-1) == labels.to(device)).float().mean()# Record the loss and accuracyvalid_loss.append(loss.item())valid_acc.append(acc)val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)print(f"[{epoch + 1:02d}/{epochs:02d}] train loss = "f"{sum(train_loss) / len(train_loss):.5f}, train acc = {sum(train_acc) / len(train_acc):.5f}", end="  ")print(f"valid loss = {sum(valid_loss) / len(valid_loss):.5f}, valid acc = {sum(valid_acc) / len(valid_acc):.5f}")result_train_loss.append(sum(train_loss) / len(train_loss))result_train_acc.append((sum(train_acc) / len(train_acc)).item())result_val_loss.append(sum(valid_loss) / len(valid_loss))result_val_acc.append((sum(valid_acc) / len(valid_acc)).item())if best_acc_fd <= sum(valid_acc) / len(valid_acc):best_acc_fd = sum(valid_acc) / len(valid_acc)torch.save(model.state_dict(), model_save_path)train_result.append(result_train_loss)train_result.append(result_train_acc)val_result.append(result_val_loss)val_result.append(result_val_acc)np.savetxt(train_result_path, np.array(train_result), fmt='%.5f', delimiter=',')np.savetxt(val_result_path, np.array(val_result), fmt='%.5f', delimiter=',')if __name__ == '__main__':group_index = 4for i in range(5):model_save_path = "result/result_own_noisy/group{}/exp0{}/model.pth".format(group_index, i + 1)hp_save_path = "result/result_own_noisy/group{}/parameters.txt".format(group_index)train_result_path = "result/result_own_noisy/group{}/exp0{}/train_result.txt".format(group_index, i + 1)val_result_path = "result/result_own_noisy/group{}/exp0{}/val_result.txt".format(group_index, i + 1)train(model_save_path, train_result_path, val_result_path, hp_save_path)

3. 所有文件链接如下:

在这里插入图片描述
链接:https://pan.baidu.com/s/12SEwGc36TN-jAbfx5fmHrw
提取码:rmue

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/2704.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PHP 爬虫如何配置代理 IP(CURL 函数)

在 PHP中 配置代理IP&#xff0c;可以通过设置 CURL 库的选项来实现&#xff0c;代码如下&#xff1a; 当然你要有代理ip来源&#xff0c;比如我用的这个 代理商 &#xff0c;如果想服务稳定不建议找开源代理池&#xff0c;避免被劫持。 <?php // 初始化cURL会话 $ch cu…

xgp会员一年多少钱?xgp一个月多少钱?微软商店xgp会员价格指南

xgp是xbox游戏平台。xgp是类似于steam、epic等&#xff0c;拥有丰富游戏资源的平台。该平 台的全称为“XBox Game Pass”&#xff0c;俗称为“西瓜皮”。xgp是会员订阅模式&#xff0c;开启会员后&#xff0c;所有游戏资源都为你开放。pc版的&#xff0c;第一个月10港币&#x…

基于springboot+vue+Mysql的漫画网站

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

L1-099 帮助色盲 - java

L1-099 帮助色盲 代码长度限制 16 KB 时间限制 400 ms 内存限制 64 MB 栈限制 8192 KB 题目描述&#xff1a; 在古老的红绿灯面前&#xff0c;红绿色盲患者无法分辨当前亮起的灯是红色还是绿色&#xff0c;有些聪明人通过路口的策略是这样的&#xff1a;当红灯或绿灯亮起时&am…

WebServer项目介绍文章【四叶专属】

Linux项目实战C轻量级Web服务器源码分析TinyWebServer 书接上文&#xff0c;学习开源项目的笔记没想到居然有不少阅读量&#xff0c;后面结合另一个前端开源项目简单做了点修改&#xff0c;没想到居然有需要的同学&#xff0c;那么我就专门为四叶开一篇文章吧&#xff0c;【源码…

【c++】探究C++中的list:精彩的接口与仿真实现解密

&#x1f525;个人主页&#xff1a;Quitecoder &#x1f525;专栏&#xff1a;c笔记仓 朋友们大家好&#xff0c;本篇文章来到list有关部分&#xff0c;这一部分函数与前面的类似&#xff0c;我们简单讲解&#xff0c;重难点在模拟实现时的迭代器有关实现 目录 1.List介绍2.接…

【网络安全】跨站脚本攻击(XSS)

专栏文章索引&#xff1a;网络安全 有问题可私聊&#xff1a;QQ&#xff1a;3375119339 目录 一、XSS简介 二、XSS漏洞危害 三、XSS漏洞类型 1.反射型XSS 2.存储型XSS 3.DOM型XSS 四、XSS漏洞防御 一、XSS简介 XSS&#xff08;Cross-Site Scripting&#xff09; XSS 被…

Git merge的版本冲突实验

实验目的 发现 两个分支的 相同文件 怎样被修改 才会发生冲突&#xff1f; 实验过程 1.初始状态 现在目前有1.py、2.py两个文件&#xff0c;已经被git管理。现在我想制造冲突&#xff0c;看怎样的修改会发生冲突&#xff0c;先看怎么不会发生冲突。 目前仓库里的版本是这样…

C语言实现简单CRC校验

目录 一、实现题目 二、send模块 三、receive模块 四、运行截图 一、实现题目 二、send模块 #include <stdio.h> #include <string.h>// 执行模2除法&#xff0c;并计算出余数&#xff08;CRC校验码&#xff09; //dividend被除, divisor除数 void divide…

免费SSL证书和付费SSL证书区别在哪

免费SSL证书与付费SSL证书在多个方面存在差异&#xff0c;这些差异主要体现在认证级别、保障金额以及服务范围上。在以下几个方面存在显著区别&#xff1a; 1、验证类型和信任级别&#xff1a; 免费SSL证书&#xff1a;通常只提供域名验证&#xff08;DV&#xff09;级别的证…

实验:使用apache + yum实现自制yum仓库

实验准备 Web服务器端&#xff1a;cenos-1&#xff08;IP&#xff1a;10.9.25.33&#xff09; 客户端&#xff1a;centos-2 保证两台机器网络畅通&#xff0c;原yum仓库可用&#xff0c;关闭防火墙和selinux Web服务器端 ①安装httpd并运行&#xff0c;设置开机自启动 安装…

多模态模型

转换器成功作为构建语言模型的一种方法&#xff0c;促使 AI 研究人员考虑同样的方法是否对图像数据也有效。 研究结果是开发多模态模型&#xff0c;其中模型使用大量带有描述文字的图像进行训练&#xff0c;没有固定的标签。 图像编码器基于像素值从图像中提取特征&#xff0c;…

力扣数据库题库学习(4.23日)

610. 判断三角形 问题链接 解题思路 题目要求&#xff1a;对每三个线段报告它们是否可以形成一个三角形。以 任意顺序 返回结果表。 对于三个线段能否组成三角形的判定&#xff1a;任意两边之和大于第三边&#xff0c;对于这个表内的记录&#xff0c;要求就是&#xff08;x…

Maven基础篇7

私服-idea访问私服与组件上传 公司团队开发流程 本地上传–>repository–>私服 其他成员从私服拿 1.项目完成后发布到私服 在pom文件最后写上发布的配置管理 ​ //写发布的url也就是你发布到哪一个版本&#xff0c;以及写入id ​ ​ 发布的时候&#xff0c;将项…

安装Selenium

安装Selenium 【0】引言 ​ 由于sleenium4.1.0需要python3.7以上方可支持&#xff0c;请注意自己的python版本。 【1】使用Pycharm安装 使用 快捷键 Ctrl Alt S 【2】使用 pip 安装 Python3.x安装后就默认就会有pip&#xff08;pip.exe默认在python的Scripts路径下&…

VUE2版本的仿微信通讯录侧滑列表

<template><!-- Vue模板部分 --><div><div v-for"(group, index) in groupedArray" :key"index" ref"indexcatch"><h2>{{ letter[index] }}</h2><ul><li v-for"item in group" :key&quo…

Notepad++使用SFTP连接虚拟机编辑文档

一.前言 当我们在虚拟机中使用vim编辑有时候不太方便&#xff0c;可以使用远程工具连接进行编辑。 常用的远程连接编辑方式有 vscode下载remote-ssh插件notepad下载nppftp插件finallshell中可以直接打开文件编辑xftp软件 根据个人习惯去选择使用即可。 这里分享一下notepa…

模型训练时报错Failed to allocate 12192768 bytes in function ‘cv::OutOfMemoryError‘

目录 报错信息&#xff1a; 查找网上解决方法&#xff1a; 改进思路&#xff1a; 改进方法&#xff1a; 报错信息&#xff1a; D:\Programs\miniconda3\envs\python311\python.exe D:\python\project\VisDrone2019-DET-MOT\train.py Ultralytics YOLOv8.1.9 &#x1f680…

【Linux】gdb的简单使用

文章目录 一、gdb是什么&#xff1f;二、使用说明1. 安装2. 注意事项3. 常用调试指令3.1 gdb3.2 l3.3 r3.4 n3.5 s3.6 b3.7 info b3.8 finish3.9 p3.10 set var3.11 c3.12 d breakpoints3.13 d n3.14 disable/enable breakpoints3.15 disable/enable n3.16 info b3.17 display …

复习python函数

复习python函数 1.对函数的理解函数的传递方式返回值 return可通过help()函数查看函数说明作用域 2.不定长参数3.递归4.高阶函数将函数作为参数传递将函数作为返回值返回 5.匿名函数6.装饰器 1.对函数的理解 函数可以用来保存一些可执行的代码&#xff0c;并且可以在需要时&am…