nlp(6)--构建找规律模型任务

前言

仅记录学习过程,有问题欢迎讨论

包含了两个例子
第一个为5分类任务
第二个为2分类任务
Demo1比Demo2难一点,放上边方便以后看。
练习顺序为 Demo2—>Demo1

代码

DEMO1:
"""
自定义一个模型
解决 5分类问题
问题如下:
给定5维向量,0-4下标哪个值对应最大,为对应分类
如 [1,3,4,1,7] 为 5 分类
如 [9,3,1,6,2] 为 1 分类"""
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import matplotlib.pyplot as pltclass TorchModel(nn.Module):def __init__(self, input_size):super(TorchModel, self).__init__()# 5分类任务 y =0~4self.linear = nn.Linear(input_size, 5)# 激活函数self.activation = torch.sigmoid  # sigmoid做激活函数# loss 交叉熵self.loss = nn.functional.cross_entropy# 传数据进来def forward(self, x, y=None):x = self.linear(x)y_pred = self.activation(x)if y is None:return y_predelse:return self.loss(y_pred, y.long())# def test():
#     x = torch.tensor(np.random.random(5), dtype=torch.float32)
#     y = torch.tensor(np.array(1), dtype=torch.long)
#     print(x.dtype)
#     print(y.dtype)
#     ce_loss = nn.CrossEntropyLoss()
#     loss = ce_loss(x, y)
#     print(loss)
#
# test()def build_dataset(num):X = []Y = []for i in range(num):x = np.random.random(5)X.append(x)# 获取最大的值的indexmax_val, max_index = torch.max(torch.tensor(x), 0)Y.append(max_index)return torch.FloatTensor(np.array(X)), torch.FloatTensor(np.array(Y))# evaluate accuracy
def evaluate(model):# testmodel.eval()test_simple_num = 100y_sum = np.zeros(5)x, y_true = build_dataset(test_simple_num)for i in range(test_simple_num):if int(y_true.data[i]) == 0:y_sum[0] += 1elif int(y_true.data[i]) == 1:y_sum[1] += 1elif int(y_true.data[i]) == 2:y_sum[2] += 1elif int(y_true.data[i]) == 3:y_sum[3] += 1else:y_sum[4] += 1print("本轮中y_sum的值为%s", y_sum)correct, wrong = 0, 0# 调用模型with torch.no_grad():y_pred = model(x)for y_p, y_t in zip(y_pred, y_true):# 通过获取最大值的下标来预测结果if int(torch.argmax(y_p)) == int(y_t):correct += 1else:wrong = 1print("正确预测个数:%d / %d, 正确率:%f" % (correct, test_simple_num, correct / (correct + wrong)))return correct / (correct + wrong)def main():batch_size = 10lr = 0.002input_size = 5train_simple = 5000epoch_size = 40# build modelmodel = TorchModel(input_size)# 優化器optim = torch.optim.Adam(model.parameters(), lr=lr)# 訓練的數據X, Y = build_dataset(train_simple)# 分割數據dataset = Data.TensorDataset(X, Y)log = []data_item = Data.DataLoader(dataset, batch_size, shuffle=True)for epoch in range(epoch_size):# start trainingmodel.train()epoch_loss = []# x.shape == 20*5 y_true.shape == 20for x, y_true in data_item:# print(x, y_true)# 交叉熵需要传递整个x,y过去,而非单个的loss = model(x, y_true)# print(loss)# 反向传播过程,在反向传播过程中会计算每个参数的梯度值loss.backward()# 改變權重;所有的 optimizer 都实现了 step() 方法,该方法会更新所有的参数。optim.step()# 将上一轮计算的梯度清零,避免上一轮的梯度值会影响下一轮的梯度值计算optim.zero_grad()epoch_loss.append(loss.data)print("========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(epoch_loss)))# 测试准确率acc = evaluate(model)log.append([acc, float(np.mean(epoch_loss))])# save modeltorch.save(model.state_dict(), "model_work.pt")# 画图# print(log)plt.plot(range(len(log)), [l[0] for l in log], label="acc")  # 画acc曲线plt.plot(range(len(log)), [l[1] for l in log], label="loss")  # 画loss曲线plt.legend()plt.show()return# 测试
def predict(model_path, test_vec_x):# 数据维度input_size = 5model = TorchModel(input_size)# 读取路径model.load_state_dict(torch.load(model_path))# 测试模式model.eval()with torch.no_grad():  # 不计算梯度# 模型预测的结果result = model.forward(torch.FloatTensor(test_vec_x))print(result[1])for i in range(len(test_vec_x)):print(torch.argmax(result[i]), test_vec_x[i])if __name__ == '__main__':# main()test_vec_x = [[0.27889086, 0.15229675, 0.41082123, 0.03504317, 0.18920843],[0.04963533, 0.5524256, 0.95758807, 0.95520434, 0.84890681],[0.98797868, 0.67482528, 0.13625847, 0.34675372, 0.19871392],[0.99349776, 0.59416669, 0.12579291, 0.41567412, 0.7358894]]predict("model_work.pt", test_vec_x)
DEMO2:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt"""基于pytorch框架编写模型训练
实现一个自行构造的找规律(机器学习)任务
规律:x是一个5维向量,如果第1个数>第5个数,则为正样本,反之为负样本"""# 自定义模型
class TorchModel(nn.Module):def __init__(self, input_size):super(TorchModel, self).__init__()# 1*5 的线性层self.linear = nn.Linear(input_size, 1)# sigmoid归一化函数 激活层self.activation = torch.sigmoid# 均方差的损失函数()self.loss = nn.functional.mse_loss# 当输入真实标签,返回loss值;无真实标签,返回预测值 默认y=nonedef forward(self, x, y=None):x = self.linear(x)  # (batch_size, input_size) -> (batch_size, 1)y_pred = self.activation(x)if y is not None:return self.loss(y_pred, y)else:return y_pred# 构建数据
def build_dataset(size):X = []Y = []for i in range(size):x, y = build_sample()X.append(x)Y.append(y)return torch.FloatTensor(X), torch.FloatTensor(Y)# 生成一个样本, 样本的生成方法,代表了我们要学习的规律
# 随机生成一个5维向量,如果第一个值大于第五个值,认为是正样本,反之为负样本
def build_sample():x = np.random.random(5)if x[0] > x[4]:return x, 1else:return x, 0# 评估目前模型效果
def evaluate(model):# 切换模型到测试模式!!!!model.eval()test_sample_num = 100x, y = build_dataset(test_sample_num)print("本次预测集中共有%d个正样本,%d个负样本" % (sum(y), test_sample_num - sum(y)))correct, wrong = 0, 0# 无需计算梯度with torch.no_grad():y_pred = model(x)  # 模型预测for y_p, y_t in zip(y_pred, y):  # 与真实标签进行对比if float(y_p) < 0.5 and int(y_t) == 0:correct += 1  # 负样本判断正确elif float(y_p) >= 0.5 and int(y_t) == 1:correct += 1  # 正样本判断正确else:wrong += 1print("正确预测个数:%d, 正确率:%f" % (correct, correct / (correct + wrong)))return correct / (correct + wrong)def main():# 配置参数# 训练轮数epoch_num = 30# 小样本个数batch_size = 20# 总样本个数train_simple = 5000# 数据样本维度input_size = 5# 学习率lr = 0.002# 建立模型model = TorchModel(input_size)# 选择优化器optim = torch.optim.Adam(model.parameters(), lr=lr)log = []# 创建训练集train_x, train_y = build_dataset(train_simple)# 训练过程for epoch in range(epoch_num):model.train()# 本轮次损失函数 主要为了检查损失是否下降epoch_loss = []# python中“//”是一个算术运算符,表示整数除法,它可以返回商的整数部分(向下取整)for batch_index in range(train_simple // batch_size):# 代表取出来的具体的x,y_turex = train_x[batch_index * batch_size: (batch_index + 1) * batch_size]y = train_y[batch_index * batch_size: (batch_index + 1) * batch_size]loss = model(x, y)loss.backward()  # 计算梯度(对 loss求导)optim.step()  # 更新权重(学习)optim.zero_grad()  # 梯度归0(不要影响到下一批次)epoch_loss.append(loss.item())# np.mean 表示计算数组元素的平均值print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(epoch_loss)))# 测试本轮模型结果 准确率acc = evaluate(model)log.append([acc, float(np.mean(epoch_loss))])# 保存模型 保存的是模型的权重!torch.save(model.state_dict(), "model.pt")# 画图print(log)plt.plot(range(len(log)), [l[0] for l in log], label="acc")  # 画acc曲线plt.plot(range(len(log)), [l[1] for l in log], label="loss")  # 画loss曲线plt.legend()plt.show()return# 使用训练好的模型做预测
def predict(model_path, input_vec):input_size = 5model = TorchModel(input_size)model.load_state_dict(torch.load(model_path))  # 加载训练好的权重# print(model.state_dict())model.eval()  # 测试模式with torch.no_grad():  # 不计算梯度result = model.forward(torch.FloatTensor(input_vec))  # 模型预测for vec, res in zip(input_vec, result):print("输入:%s, 预测类别:%d, 概率值:%f" % (vec, round(float(res)), res))  # 打印结果if __name__ == "__main__":main()# test_vec = [[0.27889086,0.15229675,0.31082123,0.03504317,0.18920843],#             [0.04963533,0.5524256,0.95758807,0.95520434,0.84890681],#             [0.08797868,0.67482528,0.13625847,0.34675372,0.19871392],#             [0.99349776,0.59416669,0.92579291,0.41567412,0.7358894]]# predict("model.pt", test_vec)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/1479.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL概述

1. SQL的分类 SQL语言在功能上主要分为如下3大类&#xff1a; DDL&#xff08;Data Definition Languages、数据定义语言&#xff09;&#xff0c;这些语句定义了不同的数据库、表、视图、索引等数据库对象&#xff0c;还可以用来创建、删除、修改数据库和数据表的结构。主要…

2.1K Star微软开源的高质量 iot库

功能描述 该项目是一个开源的 .NET Core 实现&#xff0c;旨在帮助开发者构建适用于物联网(IoT)设备和场景的应用程序。它提供了与传感器、显示器和输入设备等相互作用所需的 GPIO 引脚、串口等硬件的接口。该仓库包含 System.Device.Gpio 库以及针对各种板卡&#xff08;如 Ra…

redis底层数据结构之ziplist

目录 一、概述二、ziplist结构三、Entry结构四、为什么ZipList特别省内存五、ziplist的缺点 上一篇 redis底层数据结构之SDS 下一篇 明天更新 一、概述 一种连续内存空间存储的顺序数据结构&#xff0c;每个元素可以是字符串或整数。优点:节省内存空间。适用于存储小规模的列表…

STM32 | USART实战案例

STM32 | 通用同步/异步串行接收/发送器USART带蓝牙(第六天)随着扩展的内容越来越多,很多小伙伴已经忘记了之前的学习内容,然后后面这些都很难理解。STM32合集已在专栏创建,方面大家学习。1、通过电脑串口助手发送数据,控制开发板LED灯 从题目中可以挖掘出,本次使用led、延…

【Linux文件系统开发】认知篇

【Linux文件系统开发】认知篇 文章目录 【Linux文件系统开发】认知篇一、文件系统的概念二、文件系统的种类&#xff08;文件管理系统的方法&#xff09;三、分区四、文件系统目录结构五、虚拟文件系统&#xff08;Virtual File System&#xff09;1.概念2.原因3.作用4.总结 一…

[ LeetCode ] 题刷刷(Python)-第35题:搜索插入位置

题目描述 给定一个排序数组和一个目标值&#xff0c;在数组中找到目标值&#xff0c;并返回其索引。如果目标值不存在于数组中&#xff0c;返回它将会被按顺序插入的位置。 nums 为 无重复元素 的 升序 排列数组 请必须使用时间复杂度为 O(log n) 的算法。 示例 示例 1: 输入: …

减肥变成一种趋势!足球直播是一种刺激——早读(逆天打工人爬取热门微信文章解读)

看直播是打发时间的好方式 引言Python 代码第一篇 洞见 跌入粪坑的钟美美&#xff0c;才是真正的“人间清醒”第二篇 人民日报 来了&#xff01;新闻早班车要闻社会政策 结尾 变化是生活的法则 而直播的比赛则是这一法则的缩影 每一秒都可能带来转折和惊喜 充满了未知和奇迹 引…

磁盘损坏无法读取:原因、恢复方案与防范之道

在数字化信息爆炸的时代&#xff0c;磁盘作为数据存储的重要载体&#xff0c;承载着无数重要的文件和资料。然而&#xff0c;当磁盘突然损坏&#xff0c;无法读取数据时&#xff0c;我们往往会陷入困境&#xff0c;焦虑不已。面对这种情况&#xff0c;我们该如何应对&#xff1…

插入排序(insertionSort)

插入排序是一种简单直观的排序算法&#xff0c; 基本思想 将待排序的元素逐个插入到已经排序好的部分中的适当位置&#xff0c;从而得到新的有序序列。核心思想是不断地比较和移动元素&#xff0c;直到找到合适的插入位置。 插入排序的特点&#xff1a; 稳定性&#xff1a;插…

Yolo-world使用

1、安装 python pip install ultralytics 前往官网下载模型&#xff1a;https://docs.ultralytics.com/models/yolo-world/#key-features 我这里使用yolov8s-world.pt举例 最简单的使用示例 if __name__ __main__:model YOLO(model/yolov8s-world.pt)results model.pre…

中仕公考:考公还是考编?区别是什么?

公务员和事业编应该如何选择?区别在哪里?中仕为大家总结以下几点&#xff0c;看完就明白了! 事业编制&#xff1a;主要指从事事业单位工作人员所获得的稳定的事业单位编制。 公务员&#xff1a;是指在各级政府机关中&#xff0c;行使国家行政职权&#xff0c;执行国家公务的…

Ubuntu的apt命令用法汇总

在Ubuntu系统中&#xff0c;apt 是一个十分常用的包管理工具&#xff0c;用于安装、更新、卸载和管理软件包。 本文将汇总apt 命令的用法&#xff0c;以便你更好地利用Ubuntu系统进行软件管理。 一. 安装软件包要安装一个软件包&#xff0c;使用以下命令&#xff1a; sudo a…

pyhton学习之-分支结构-运费计算模板-第二练

第1关:运费计算模板第二练-地区运费计算模板 任务描述 现在有一个淘宝店铺,发货地在天津,店主设计了一个运费规则如下图所示: 可以选择区域来指定每件商品的运费,达到一定的购买金额以后可以包邮 测试说明 计算运费 根据地区和购买数量计算运费。 输入:北京,1件 输出:…

2024HW ---->内网横向移动

在蓝队的面试过程中&#xff0c;如果你会内网渗透的话&#xff0c;那是肯定的一个加分选项&#xff01;&#xff01;&#xff01; 那么从今天开始&#xff0c;我们就来讲一下内网的横向移动&#xff01;&#xff01;&#xff01; 目录 1.域内任意用户枚举 2.Password-Sprayi…

ffmpeg入门

ffmpeg入——安装 Fmpeg地址 FFmpeg源码地址&#xff1a;GitHub - FFmpeg/FFmpeg: Mirror of https://git.ffmpeg.org/ffmpeg.git FFmpeg可执行文件地址&#xff1a;Download FFmpeg Windows平台 ​ ​ Windows平台下载解压后如图所示&#xff08;文件名称以-share结尾的…

深入剖析Spring框架:循环依赖的解决机制

你好&#xff0c;我是柳岸花开。 什么是循环依赖&#xff1f; 很简单&#xff0c;就是A对象依赖了B对象&#xff0c;B对象依赖了A对象。 在Spring中&#xff0c;一个对象并不是简单new出来了&#xff0c;而是会经过一系列的Bean的生命周期&#xff0c;就是因为Bean的生命周期所…

如何添加所有未跟踪文件到暂存区?

文章目录 如何将所有未跟踪文件添加到Git暂存区&#xff1f;步骤与示例代码1. 打开命令行或终端2. 列出所有未跟踪的文件3. 添加所有未跟踪文件到暂存区4. 验证暂存区状态 如何将所有未跟踪文件添加到Git暂存区&#xff1f; 在版本控制系统Git中&#xff0c;当我们首次创建新文…

Java每日面试题

Java 高级面试问题及答案 问题1: Java中的垃圾回收机制是如何工作的&#xff1f;请描述一下垃圾收集器的工作原理。 答案: Java的垃圾回收机制主要依赖于垃圾收集器&#xff08;Garbage Collector&#xff0c;GC&#xff09;&#xff0c;它负责自动回收不再使用的对象&#x…

最全!2024腾讯春招Spring Circuit Breaker面试题大全,附详解和技巧,必备收藏!

面对2024年腾讯春季招聘&#xff0c;准备充分的技术面试答案至关重要&#xff0c;尤其是在微服务架构和高可用性设计方面。Spring Circuit Breaker作为维持微服务稳定性和可靠性的关键技术&#xff0c;了解其工作原理和实际应用对于任何希望在当今技术驱动的环境中取得成功的软…

安全狗云眼的主要功能有哪些?

"安全狗云眼"是一款综合性的网络安全产品&#xff0c;主要用于实时监控和保护企业的网络安全。其核心功能包括威胁检测、漏洞扫描、日志管理和合规性检查等。 以下是安全狗云眼的主要功能详细介绍&#xff1a; 1、资产管理 定期获取并记录主机上的Web站点、Web容器、…