基于Pytorch的NLP入门任务思想及代码实现:判断文本中是否出现指定字

今天学了第一个基于Pytorch框架的NLP任务:

判断文本中是否出现指定字

思路:(注意:这是基于字的算法)

任务:判断文本中是否出现“xyz”,出现其中之一即可

训练部分:

一,我们要先设计一个模型去训练数据。
这个Pytorch的模型:
首先通过embedding层:将字符转化为离散数值(矩阵)
通过线性层:设置网络的连接层,进行映射
通过dropout层:将一部分输入设为0(可去掉)
通过激活层:sigmoid激活
通过一个pooling层:降维,将矩阵->向量
通过另一个输出线性层:使输出是一维(1或0)
通过一个激活层*:sigmoid激活。

二,设置一个函数:这个函数能将设定的字符变成字符集,将每一个字符设定一个代号,比如说:“我爱你”-> 我:1,爱:2,你:3。当出现"你爱我"时,计算机接受的是:3,2,1。这样方便计算机处理字符。

三,因为我们没有训练样本和测试样本,所以我们要自己生成一些随机样本。通过random.choice在字符集中随机顺序输出字符作为输入,并将输入中含有"xyz"的样本的输出值为“1”,反之为“0”

四,设置一个函数,将随机得到的样本,放入数据集中(列表),便于运算。

五,设置测试函数:随机建立一些样本,根据样本的输出来设定有多少个正样本,多少个负样本,再将预测的样本输出来与样本输出对比,得到正确率。

六,最后的main函数:按照训练轮数和训练组数,通过BP反向传播更新权重进行训练。然后调取测试函数得到acc等数据。将loss和acc的数值绘制下来,保存模型和词表。

预测部分

将保存的词表和模型加载进来,将输入的字符转化为列表,然后进入模型forward函数进行预测,最后打印出结果。

代码实现:

import torch
import torch.nn as nn
import numpy as np
import random
import json
import matplotlib.pyplot as plt"""
基于pytorch的网络编写
实现一个网络完成一个简单nlp任务
判断文本中是否有某些特定字符出现
"""class TorchModel(nn.Module):def __init__(self, input_dim, sentence_length, vocab):super(TorchModel, self).__init__()self.embedding = nn.Embedding(len(vocab) + 1, input_dim)    #embedding层:将字符转化为离散数值self.layer = nn.Linear(input_dim, input_dim)    #对输入数据做线性变换self.classify = nn.Linear(input_dim, 1)     #映射到一维self.pool = nn.AvgPool1d(sentence_length)   #pooling层:降维self.activation = torch.sigmoid     #sigmoid做激活函数self.dropout = nn.Dropout(0.1)  #一部分输入为0self.loss = nn.functional.mse_loss  #loss采用均方差损失#当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):x = self.embedding(x)#输入维度:(batch_size, sen_len)输出维度:(batch_size, sen_len, input_dim)将文本->矩阵x = self.layer(x)#输入维度:(batch_size, sen_len, input_dim)输出维度:(batch_size, sen_len, input_dim)x = self.dropout(x)#输入维度:(batch_size, sen_len, input_dim)输出维度:(batch_size, sen_len, input_dim)x = self.activation(x)#输入维度:(batch_size, sen_len, input_dim)输出维度:(batch_size, sen_len, input_dim)x = self.pool(x.transpose(1,2)).squeeze()#输入维度:(batch_size, sen_len, input_dim)输出维度:(batch_size, input_dim)将矩阵->向量x = self.classify(x)#输入维度:(batch_size, input_dim)输出维度:(batch_size, 1)y_pred = self.activation(x)#输入维度:(batch_size, 1)输出维度:(batch_size, 1)if y is not None:return self.loss(y_pred, y)else:return y_pred#字符集随便挑了一些汉字,实际上还可以扩充
#为每个汉字生成一个标号
#{"不":1, "东":2, "个":3...}
#不东个->[1,2,3]
def build_vocab():chars = "不东个么买五你儿几发可同名呢方人上额旅法xyz"  #随便设置一个字符集vocab = {}for index, char in enumerate(chars):vocab[char] = index + 1   #每个字对应一个序号,+1是序号从1开始vocab['unk'] = len(vocab)+1   #不在表中的值设为前一个+1return vocab#随机生成一个样本
#从所有字中选取sentence_length个字
#如果vocab中的xyz出现在样本中,则为正样本
#反之为负样本
def build_sample(vocab, sentence_length):#将vacab转化为字表,随机从字表选取sentence_length个字,可能重复x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]#指定哪些字必须在正样本出现if set("xyz") & set(x):     #若xyz与x中的字符相匹配,则为1,为正样本y = 1else:y = 0x = [vocab.get(word,vocab['unk']) for word in x]   #将字转换成序号return x, y#建立数据集
#输入需要的样本数量。需要多少生成多少
def build_dataset(sample_length,vocab, sentence_length):dataset_x = []dataset_y = []for i in range(sample_length):x, y = build_sample(vocab, sentence_length)dataset_x.append(x)dataset_y.append([y])return torch.LongTensor(dataset_x), torch.FloatTensor(dataset_y)#建立模型
def build_model(vocab, char_dim, sentence_length):model = TorchModel(char_dim, sentence_length, vocab)return model#测试代码
#用来测试每轮模型的准确率
def evaluate(model, vocab, sample_length):model.eval()x, y = build_dataset(200, vocab, sample_length)#建立200个用于测试的样本(因为测试样本是随机生成的,所以不存在过拟合)print("本次预测集中共有%d个正样本,%d个负样本"%(sum(y), 200 - sum(y)))correct, wrong = 0, 0with torch.no_grad():y_pred = model(x)      #调用Pytorch模型预测for y_p, y_t in zip(y_pred, y):  #与真实标签进行对比if float(y_p) < 0.5 and int(y_t) == 0:correct += 1   #负样本判断正确elif float(y_p) >= 0.5 and int(y_t) == 1:correct += 1   #正样本判断正确else:wrong += 1print("正确预测个数:%d, 正确率:%f"%(correct, correct/(correct+wrong)))return correct/(correct+wrong)def main():epoch_num = 10        #训练轮数batch_size = 20       #每次训练样本个数train_sample = 1000   #每轮训练总共训练的样本总数char_dim = 20         #每个字的维度sentence_length = 10   #样本文本长度vocab = build_vocab()       #建立字表model = build_model(vocab, char_dim, sentence_length)    #建立模型optim = torch.optim.Adam(model.parameters(), lr=0.005)   #建立优化器log = []for epoch in range(epoch_num):model.train()watch_loss = []for batch in range(int(train_sample / batch_size)):x, y = build_dataset(batch_size, vocab, sentence_length) #每次训练构建一组训练样本optim.zero_grad()    #梯度归零loss = model(x, y)   #计算losswatch_loss.append(loss.item())  #将loss存下来,方便画图loss.backward()      #计算梯度optim.step()         #更新权重print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))acc = evaluate(model, vocab, sentence_length)   #测试本轮模型结果log.append([acc, np.mean(watch_loss)])plt.plot(range(len(log)), [l[0] for l in log])  #画acc曲线:蓝色的plt.plot(range(len(log)), [l[1] for l in log])  #画loss曲线:黄色的plt.show()#保存模型torch.save(model.state_dict(), "model.pth")writer = open("vocab.json", "w", encoding="utf8")#保存词表writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))writer.close()return#最终预测
def predict(model_path, vocab_path, input_strings):char_dim = 20  # 每个字的维度sentence_length = 10  # 样本文本长度vocab = json.load(open(vocab_path, "r", encoding="utf8"))model = build_model(vocab, char_dim, sentence_length)    #建立模型model.load_state_dict(torch.load(model_path))   #将模型文件加载进来x = []for input_string in input_strings:  #转化输入x.append([vocab[char] for char in input_string])model.eval()    #在torch中预测注意这个:停止dropoutwith torch.no_grad():   #在torch中预测注意这个:停止梯度result = model.forward(torch.LongTensor(x)) #根据自己设计的函数定义,只输入x就会输出预测值for i, input_string in enumerate(input_strings):print(round(float(result[i])), input_string, result[i])#round(float(result))是将预测结果四舍五入得到0或1的预测值if __name__ == "__main__":main()#如果是进行预测,将下面两行解除注释,将main()注释掉,即可调用最终预测函数进行预测# test_strings = ["个么买不东五你儿x发", "不东东么儿几买五你发", "不东个么买五你个么买", "不z个五么买你儿几发"]# predict("model.pth", "vocab.json", test_strings)

运行结果展示:

训练部分:

在这里插入图片描述
(蓝色是acc曲线,黄色是loss曲线)

预测部分:

在这里插入图片描述

一些补充:

一:model.eval()或者model.train()的作用

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接。

二:Pytorch模型中用了两次激活函数

在每一个网络层后使用一个激活层是一种比较常见的模型搭建方式,但不是必要的。这个只是举例,去掉也是可行的。在具体任务中,带着好还是不带好也跟数据和任务本身有关,没有确定答案(如果在代码 中把第一个激活层注释掉 反而性能更好)

三:对x = self.pool(x.transpose(1,2)).squeeze()代码的解读

通过shape方法我们能知道,在pool前,x的维度输出是[20,10,20],代表20个10×20的矩阵,代表着[这一批的个数,样本文本长度,输入维度],transpose(1,2)是将x中行和列调换(转置),然后通过pooling层将[20,20,10]->[20,20,1],最后通过squeeze()进行降维变成[20,20]。
(池化层的作用及理解)

四:embedding层的理解

embedding层并不是单纯的单词映射,而是将单词表中每个单词的数值与权重相乘。在第一次时有默认权重,然后在接下来的训练中,embedding层的权重与分类权重一起经过训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389571.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

erlang下lists模块sort(排序)方法源码解析(二)

上接erlang下lists模块sort&#xff08;排序&#xff09;方法源码解析(一)&#xff0c;到目前为止&#xff0c;list列表已经被分割成N个列表&#xff0c;而且每个列表的元素是有序的&#xff08;从大到小&#xff09; 下面我们重点来看看mergel和rmergel模块&#xff0c;因为我…

洛谷P4841 城市规划(多项式求逆)

传送门 这题太珂怕了……如果是我的话完全想不出来…… 题解 1 //minamoto2 #include<iostream>3 #include<cstdio>4 #include<algorithm>5 #define ll long long6 #define swap(x,y) (x^y,y^x,x^y)7 #define mul(x,y) (1ll*(x)*(y)%P)8 #define add(x,y) (x…

支撑阻力指标_使用k表示聚类以创建支撑和阻力

支撑阻力指标Note from Towards Data Science’s editors: While we allow independent authors to publish articles in accordance with our rules and guidelines, we do not endorse each author’s contribution. You should not rely on an author’s works without seek…

高版本(3.9版本)python在anaconda安装opencv库及skimage库(scikit_image库)诸多问题解决办法

今天开始CV方向的学习&#xff0c;然而刚拿到基础代码的时候发现 from skimage.color import rgb2gray 和 import cv2标红&#xff08;这里是因为我已经配置成功了&#xff0c;所以没有红标&#xff09;&#xff0c;我以为是单纯两个库没有下载&#xff0c;去pycharm中下载ski…

python 实现斐波那契数列

# coding:utf8 __author__ blueslidef fun(arg1,arg2,stop):if arg10:print(arg1,arg2)arg3 arg1arg2print(arg3)if arg3<stop:arg3 fun(arg2,arg3,stop)fun(0,1,100)转载于:https://www.cnblogs.com/bluesl/p/9079705.html

单机安装ZooKeeper

2019独角兽企业重金招聘Python工程师标准>>> zookeeper下载、安装以及配置环境变量 本节介绍单机的zookeeper安装&#xff0c;官方下载地址如下&#xff1a; https://archive.apache.org/dist/zookeeper/ 我这里使用的是3.4.11版本&#xff0c;所以找到相应的版本点…

均线交易策略的回测 r_使用r创建交易策略并进行回测

均线交易策略的回测 rR Programming language is an open-source software developed by statisticians and it is widely used among Data Miners for developing Data Analysis. R can be best programmed and developed in RStudio which is an IDE (Integrated Development…

opencv入门课程:彩色图像灰度化和二值化(采用skimage库和opencv库两种方法)

用最简单的办法实现彩色图像灰度化和二值化&#xff1a; 首先采用skimage库&#xff08;skimage库现在在scikit_image库中&#xff09;实现&#xff1a; from skimage.color import rgb2gray import numpy as np import matplotlib.pyplot as plt""" skimage库…

SVN中Revert changes from this revision 跟Revert to this revision

譬如有个文件&#xff0c;有十个版本&#xff0c;假定版本号是1&#xff0c;2&#xff0c;3&#xff0c;4&#xff0c;5&#xff0c;6&#xff0c;7&#xff0c;8&#xff0c;9&#xff0c;10。Revert to this revision&#xff1a; 如果是在版本6这里点击“Revert to this rev…

归 [拾叶集]

归 心归故乡 想象行走在 乡间恬静小路上 让那些疲惫的梦 都随风飞散吧&#xff01; 不去想那些世俗 人来人往 熙熙攘攘 秋日午后 阳光下 细数落叶 来日方长 世上的路 有诗人、浪子 歌咏吟唱 世上的人 在欲望、信仰中 彷徨 彷徨又迷茫 亲爱的人儿 快结束那 无休止的独自流浪 莫要…

instagram分析以预测与安的限量版运动鞋转售价格

Being a sneakerhead is a culture on its own and has its own industry. Every month Biggest brands introduce few select Limited Edition Sneakers which are sold in the markets according to Lottery System called ‘Raffle’. Which have created a new market of i…

opencv:用最邻近插值和双线性插值法实现上采样(放大图像)与下采样(缩小图像)

上采样与下采样 概念&#xff1a; 上采样&#xff1a; 放大图像&#xff08;或称为上采样&#xff08;upsampling&#xff09;或图像插值&#xff08;interpolating&#xff09;&#xff09;的主要目的 是放大原图像,从而可以显示在更高分辨率的显示设备上。 下采样&#xff…

CSS魔法堂:那个被我们忽略的outline

前言 在CSS魔法堂&#xff1a;改变单选框颜色就这么吹毛求疵&#xff01;中我们要模拟原生单选框通过Tab键获得焦点的效果&#xff0c;这里涉及到一个常常被忽略的属性——outline&#xff0c;由于之前对其印象确实有些模糊&#xff0c;于是本文打算对其进行稍微深入的研究^_^ …

初创公司怎么做销售数据分析_初创公司与Faang公司的数据科学

初创公司怎么做销售数据分析介绍 (Introduction) In an increasingly technological world, data scientist and analyst roles have emerged, with responsibilities ranging from optimizing Yelp ratings to filtering Amazon recommendations and designing Facebook featu…

opencv:灰色和彩色图像的像素直方图及直方图均值化的实现与展示

直方图及直方图均值化的理论&#xff0c;实现及展示 直方图&#xff1a; 首先&#xff0c;我们来看看什么是直方图&#xff1a; 理论概念&#xff1a; 在图像处理中&#xff0c;经常用到直方图&#xff0c;如颜色直方图、灰度直方图等。 图像的灰度直方图就描述了图像中灰度分…

mysql.sock问题

Cant connect to local MySQL server through socket /tmp/mysql.sock 上述提示可能在启动mysql时遇到&#xff0c;即在/tmp/mysql.sock位置找不到所需要的mysql.sock文件&#xff0c;主要是由于my.cnf文件里对mysql.sock的位置设定导致。 mysql.sock默认的是在/var/lib/mysql,…

交换机的基本原理配置(一)

1、配置主机名 在全局模式下输入hostname 名字 然后回车即可立马生效&#xff08;在生产环境交换机必须有自己唯一的名字&#xff09; Switch(config)#hostname jsh-sw1jsh-sw1(config)#2、显示系统OS名称及版本信息 特权模式下&#xff0c;输入命令 show version Switch#show …

opencv:卷积涉及的基础概念,Sobel边缘检测代码实现及Same(相同)填充与Vaild(有效)填充

滤波 线性滤波可以说是图像处理最基本的方法&#xff0c;它可以允许我们对图像进行处理&#xff0c;产生很多不同的效果。 卷积 卷积的概念&#xff1a; 卷积的原理与滤波类似。但是卷积却有着细小的差别。 卷积操作也是卷积核与图像对应位置的乘积和。但是卷积操作在做乘…

机器学习股票_使用概率机器学习来改善您的股票交易

机器学习股票Note from Towards Data Science’s editors: While we allow independent authors to publish articles in accordance with our rules and guidelines, we do not endorse each author’s contribution. You should not rely on an author’s works without seek…

BZOJ 2818 Gcd

传送门 题解&#xff1a;设p为素数 &#xff0c;则gcd(x/p,y/p)1也就是说求 x&#xff0f;p以及 y&#xff0f;p的欧拉函数。欧拉筛前缀和就可以解决 #include <iostream> #include <cstdio> #include <cmath> #include <algorithm> #include <map&…