【NLP 9、实践 ① 五维随机向量交叉熵多分类】

目录

五维向量交叉熵多分类

规律:

实现:

1.设计模型

2.生成数据集

3.模型测试

4.模型训练

5.对训练的模型进行验证

调用模型


你的平静,是你最强的力量

                                —— 24.12.6

五维向量交叉熵多分类

规律:

x是一个五维(索引)向量,对x做五分类任务

改用交叉熵实现一个多分类任务,五维随机向量中最大的数字在哪维就属于哪一类


实现:

1.设计模型

Linear():模型函数中定义线性层

activation = nn.Softmax(dim=1):定义激活层为softmax激活函数

nn.CrossEntropyLoss() / nn.functional.cross_entropy:定义交叉熵损失函数

pyTorch中定义的交叉熵损失函数内部封装了softMax函数, 而使用交叉熵必须使用softMax函数,对数据进行归一化

经过 Softmax 归一化后,输出向量的每个元素可以被解释为样本属于相应类别的概率。这使得我们能够直接比较不同类别上的概率大小,并且与真实的类别概率分布(如one-hot编码)进行合理的对比。

例如,在一个三分类问题中,经过 Softmax 后的输出可能是[0.2,0.3,0.5],我们可以直观地说样本属于第三类的概率是 0.5,这是一个符合概率意义的解释

forward函数,前向计算,定义网络的使用方式,声明模型计算过程

# 1.设计模型
class TorchModel(nn.Module):def __init__(self, input_size):super(TorchModel, self).__init__()# 预测出一个五维的向量,五维向量代表五个类别上的概率分布self.linear = nn.Linear(input_size, 5)  # 线性层# 类交叉熵写法:CrossEntropyLoss()     函数交叉熵写法:cross_entropy# nn.CrossEntropyLoss() pycharm交叉的熵损失函数内部封装了softMax函数, 而使用交叉熵必须使用softMax函数self.loss = nn.functional.cross_entropy # loss函数采用交叉熵损失self.activation = nn.Softmax(dim=1)# 当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):# 输入过第一个网络层y_pred = self.linear(x)  # (batch_size, input_size) -> (batch_size, 1)if y is not None:return self.loss(y_pred, y)  # 预测值和真实值计算损失else:return self.activation(y_pred)  # 输出预测结果# return y_pred

2.生成数据集

由于题目要求,要在一个五维随机向量中查找标量最大的数所在维度,所以用np.random函数随机生成一个五维向量,然后通过np.argmax函数找出生成向量中最大标量所对应的维度,并将其作为数据 x标注 y 返回

当我们输出一串数字,要告诉模型输出的是一串单独的数而不是一串样本时,需要用到 "[ ]",换句话说当y是单独的一个数(标量)时,才需要加“[ ]”

而该模型输出的预测结果是一个向量,而不是一个数(标量的概率)时,不需要拼在一起

# 2.生成数据集标签label   数据构建
# 生成一个样本, 样本的生成方法,代表了我们要学习的规律,随机生成一个5维向量,如果第一个值大于第五个值,认为是正样本,反之为负样本
def build_sample():x = np.random.random(5)# 获取最大值对应的索引max_index = np.argmax(x)return x, max_index# 随机生成一批样本
# 正负样本均匀生成
def build_dataset(total_sample_num):X = []Y = []# 随机生成样本,total_sample_num 生成的随机样本数for i in range(total_sample_num):x, y = build_sample()X.append(x)# 当我们输出一串数字,要告诉模型输出的是一串单独的数而不是一串样本时,需要用到"[]",换句话说当y是单独得一个数(标量)时,才需要加“[]”# 而该模型输出的预测结果是一个向量,而不是一个数(标量的概率)时,不需要拼在一起Y.append(y)X_array = np.array(X)Y_array = np.array(Y)# 一般torch中的Long整形类型用来判定类型return torch.FloatTensor(X_array), torch.LongTensor(Y_array)

3.模型测试

用来测试每轮模型预测的精确度

model.eval():声明模型框架在这个函数中不做训练

with torch.no_grad():在模型测试的部分中,声明是测试函数,不计算梯度,增加模型训练效率

zip():zip 函数是一个内置函数,用于将多个可迭代对象(如列表、元组、字符串等)中对应的元素打包成一个个元组,然后返回由这些元组组成的可迭代对象(通常是一个 zip 对象)。如果各个可迭代对象的长度不一致,那么 zip 操作会以最短的可迭代对象长度为准。

# 3.模型测试
# 用来测试每轮模型的准确率
def evaluate(model):model.eval()test_sample_num = 100x, y = build_dataset(test_sample_num)print("本次预测集中共有%d个正样本,%d个负样本" % (sum(y), test_sample_num - sum(y)))correct, wrong = 0, 0with torch.no_grad():y_pred = model(x)  # 模型预测 model.forward(x)for y_p, y_t in zip(y_pred, y):  # 与真实标签进行对比# np.argmax是求最大数所在维,max求最大数,torch.argmax是求最大数所在维if torch.argmax(y_p) == int(y_t):correct += 1  # 正确预测加一else:wrong += 1  # 错误预测加一print("正确预测个数:%d, 正确率:%f" % (correct, correct / (correct + wrong)))return correct / (correct + wrong)

4.模型训练

① 配置参数        

② 建立模型

③ 选择优化器(Adam)

④ 读取训练集

⑤ 训练过程

        Ⅰ、model.train():设置训练模式

        Ⅱ、对训练集样本开始循环训练(循环取出训练数据)

        Ⅲ、根据模型函数和损失函数的定义计算模型损失

        Ⅳ、计算梯度

        Ⅴ、通过梯度用优化器更新权重

        Ⅵ、计算完一轮训练数据后梯度进行归零,下一轮重新计算

torch.save(model.state_dict(), "model.pt"):模型保存model.pt文件

一般任务不同只需更改数据读取(步骤③)模型构建(步骤①)内容,训练过程一般无需更改,evaluate测试代码可能也需更改,因为不同模型测试正确率的方式不同

# 4.模型训练
def main():# 配置参数epoch_num = 20  # 训练轮数batch_size = 20  # 每次训练样本个数train_sample = 5000  # 每轮训练总共训练的样本总数input_size = 5  # 输入向量维度learning_rate = 0.001  # 学习率# ① 建立模型model = TorchModel(input_size)# ② 选择优化器optim = torch.optim.Adam(model.parameters(), lr=learning_rate)log = []# ③ 创建训练集,正常任务是读取训练集train_x, train_y = build_dataset(train_sample)# 训练过程# 轮数进行自定义for epoch in range(epoch_num):model.train()watch_loss = []# ④ 读取数据集for batch_index in range(train_sample // batch_size):x = train_x[batch_index * batch_size : (batch_index + 1) * batch_size]y = train_y[batch_index * batch_size : (batch_index + 1) * batch_size]# ⑤ 计算lossloss = model(x, y)  # 计算loss  model.forward(x,y)# ⑥ 计算梯度loss.backward()  # 计算梯度# ⑦ 权重更新optim.step()  # 更新权重# ⑧ 梯度归零optim.zero_grad()  # 梯度归零watch_loss.append(loss.item())# 一般任务不同只需更改数据读取(步骤③)和模型构建(步骤①)内容,训练过程一般无需更改,evaluate测试代码可能也需更改,因为不同模型测试正确率的方式不同print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))acc = evaluate(model)  # 测试本轮模型结果log.append([acc, float(np.mean(watch_loss))])# 保存模型torch.save(model.state_dict(), "model.pt")# 画图print(log)plt.plot(range(len(log)), [l[0] for l in log], label="acc")  # 画acc曲线plt.plot(range(len(log)), [l[1] for l in log], label="loss")  # 画loss曲线plt.legend()plt.show()return

5.对训练的模型进行验证

调用main函数

if __name__ == "__main__":main()


调用模型

model.eval():声明模型框架在这个函数中不做训练

predict("model.pt", test_vec):调用模型存储的文件model.pt,通过调用模型对数据进行预测

# 使用训练好的模型做预测
def predict(model_path, input_vec):input_size = 5model = TorchModel(input_size)# 加载训练好的权重model.load_state_dict(torch.load(model_path, weights_only=True))# print(model.state_dict())model.eval()  # 测试模式,不计算梯度with torch.no_grad():# 输入一个真实向量转成Tensor,让模型forward一下result = model.forward(torch.FloatTensor(input_vec))  # 模型预测for vec, res in zip(input_vec, result):# python中,round函数是对浮点数进行四舍五入print("输入:%s, 预测类别:%s, 概率值:%s" % (vec, torch.argmax(res), res))  # 打印结果if __name__ == "__main__":test_vec = [[0.97889086,0.15229675,0.31082123,0.03504317,0.88920843],[0.74963533,0.5524256,0.95758807,0.95520434,0.84890681],[0.00797868,0.67482528,0.13625847,0.34675372,0.19871392],[0.09349776,0.59416669,0.92579291,0.41567412,0.1358894]]predict("model.pt", test_vec)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/63259.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

windows文件下换行, linux上不换行 解决CR换行符替换为LF notepad++

html文件是用回车换行的,在windows电脑上,显示正常。 文件上传到linux服务器后,文件不换行了。只有一行。而且相关js插件也没法正常运行。 用notepad查看,显示尾部换行符,是CR,这就是原因。CR是不被识别的。…

ES6关于解构的详细探讨,以及可能会出现的错误

ES6关于解构的详细探讨,以及可能会出现的错误 1.解构赋值时,如果等号右边是数值和布尔值,则会先转为对象。2.字符串的解构赋值,字符串被转换成了一个类似数组的对象3.默认值生效的条件是,对象的属性值严格等于undefined。4.不能使用圆括号的情…

Unity 模拟百度地图,使用鼠标控制图片在固定区域内放大、缩小、鼠标左键拖拽移动图片

效果展示: 步骤流程: 1.使用的是UGUI,将下面的脚本拖拽到图片上即可。 using UnityEngine; using UnityEngine.UI; using UnityEngine.EventSystems;public class CheckImage : MonoBehaviour, IDragHandler, IBeginDragHandler, IEndDragH…

游戏引擎学习第30天

仓库: https://gitee.com/mrxiao_com/2d_game 回顾 在这段讨论中,重点是对开发过程中出现的游戏代码进行梳理和进一步优化的过程。 工作回顾:在第30天,回顾了前一天的工作,并提到今天的任务是继续从第29天的代码开始&#xff0c…

基于MFC绘制门电路

MFC绘制门电路 1. 设计内容、方法与难点 本课题设计的内容包括了基本门电路中与门和非门的绘制、选中以及它们之间的连接。具体采用的方法是在OnDraw函数里面进行绘制,并设计元器件基类,派生出与门和非门,并组合了一个引脚类,在…

【text2sql】低资源场景下Text2SQL方法

SFT使模型能够遵循输入指令并根据预定义模板进行思考和响应。如上图,、 和 是用于通知模型在推理过程中响应角色的角色标签。 后面的内容表示模型需要遵循的指令,而 后面的内容传达了当前用户对模型的需求。 后面的内容代表模型的预期输出,也…

学习threejs,实现配合使用WebWorker

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言1.1 ☘️WebWorker web端多线程 二、…

16-03、JVM系列之:内存与垃圾回收篇(三)

JVM系列之:内存与垃圾回收篇(三) ##本篇内容概述: 1、执行引擎 2、StringTable 3、垃圾回收一、执行引擎 ##一、执行引擎概述 如果想让一个java程序运行起来,执行引擎的任务就是将字节码指令解释/编译为对应平台上的本地机器指令才可以。 简…

正逆断层剪应力方向

正断层(Normal Fault): 在正断层中,上盘相对于下盘向下滑动。由于正断层是由垂直拉伸应力引起的,因此,剪应力的方向实际上是指向下盘的,也就是说,剪应力的作用是沿断层面从上盘向下盘…

Android11.0系统关闭App所有通知

通过广播接收方式&#xff0c;根据包名关闭App所有通知。 packages/apps/Settings$ git diff diff --git a/AndroidManifest.xml b/AndroidManifest.xml index d4c54c6ed8..1ce7d4136f 100644 --- a/AndroidManifest.xmlb/AndroidManifest.xml-106,6 106,7 <uses-permissio…

小程序 - 美食列表

小程序交互练习 - 美食列表小程序开发笔记 目录 美食列表 功能描述 准备工作 创建项目 配置页面 配置导航栏 启动本地服务器 页面初始数据 设置获取美食数据 设置onload函数 设置项目配置 页面渲染 页面样式 处理电话格式 创建处理电话格式脚本 页面引入脚本 …

Qt6.8 QGraphicsView鼠标坐标点偏差

ui文件拖放QGraphicsView&#xff0c;src文件定义QGraphicsScene赋值给图形视图。 this->scene new QGraphicsScene();ui.graph->setScene(this->scene);对graphicview过滤事件&#xff0c;只能在其viewport之后安装&#xff0c;否则不响应。 ui.graph->viewport…

springboot/ssm购物系统Java代码web项目在线购物商城电商源码

springboot/ssm购物系统Java代码web项目在线购物商城电商源码 基于springboot(可改ssm)vue项目 开发语言&#xff1a;Java 框架&#xff1a;springboot/可改ssm vue JDK版本&#xff1a;JDK1.8&#xff08;或11&#xff09; 服务器&#xff1a;tomcat 数据库&#xff1a;m…

若依 ruoyi VUE el-select 直接获取 选择option 的 label和value

1、最新在研究若依这个项目&#xff0c;我使用的是前后端分离的方案&#xff0c;RuoYi-Vue-fast(后端) RuoYi-Vue-->ruoyi-ui(前端)。RuoYi-Vue-fast是单应用版本没有区分那么多的modules 自己开发起来很方便&#xff0c;这个项目运行起来很方便&#xff0c;但是需要自定义的…

基于队列(Queue)的部分笔试题

1. 设计一个循环队列&#xff08;环形队列&#xff09; 问题描述&#xff1a; 设计一个支持以下操作的队列&#xff1a; enqueue(int x)&#xff1a;将元素 x 添加到队尾。 dequeue()&#xff1a;移除并返回队头元素。 peek()&#xff1a;返回队头元素&#xff0c;但不移除它…

springboot事务手动回滚报错

捕捉异常之后手动标记回滚事务 TransactionAspectSupport.currentTransactionStatus().setRollbackOnly(); 没有嵌套事务&#xff0c;还是报Transaction rolled back because it has been marked as rollback-only异常错误 查看错误堆栈&#xff0c;service调用的方法外层还套…

Pytorch使用手册- TorchVision目标检测微调Tutorial的使用指南(专题十二)

这篇教程的目标是对一个预训练的 Mask R-CNN 模型进行微调,应用于 Penn-Fudan 行人检测与分割数据集。该数据集包含 170 张图像,里面有 345 个行人实例,我们将通过这个教程来演示如何使用 torchvision 中的新特性,训练一个面向自定义数据集的目标检测和实例分割模型。 注意…

使用 LlamaFactory 结合开源大语言模型实现文本分类:从数据集构建到 LoRA 微调与推理评估

文章目录 背景介绍文本分类数据集Lora 微调模型部署与推理期待模型的输出结果 文本分类评估代码 背景介绍 本文将一步一步地&#xff0c;介绍如何使用llamafactory框架利用开源大语言模型完成文本分类的实验&#xff0c;以 LoRA微调 qwen/Qwen2.5-7B-Instruct 为例。 文本分类…

发论文参考文献部分怎么注明数据集出处gitee

见的参考文献标注格式&#xff08;如APA、MLA、Chicago等&#xff09;&#xff0c;电子文献或网络资源的标注通常包括作者&#xff08;或组织&#xff09;、标题、发布年份、获取路径&#xff08;URL&#xff09;等信息。 二、具体步骤 查找数据集信息&#xff1a; 在Gitee上找…

ARM内核与单片机

1.单片机硬件架构如下所示&#xff1a;各种硬件通过总线进行连接。 2.M4内核架构 3.单片机如何工作&#xff1a; 4.CPU是通过读写寄存器来控制GPIO的 5.GPIO的硬件框架&#xff1a;一共有8种模式 &#xff08;1&#xff09;推挽/推挽复用输出。下图先看图1&#xff0c;如果输入…