【NLP 16、实践 ③ 找出特定字符在字符串中的位置】

看着父亲苍老的白发和渐渐老态的面容

希望时间再慢一些

                                                —— 24.12.19

一、定义模型

1.初始化模型

① 初始化父类

super(TorchModel, self).__init__(): 调用父类 nn.Module 初始化方法,确保模型能够正确初始化。

② 创建嵌入层

self.embedding = nn.Embedding(len(vocab), vector_dim): 创建一个嵌入层,将词汇表中的每个词映射到一个 vector_dim 维度的向量。

③ 创建RNN层

self.rnn = nn.RNN(vector_dim, vector_dim, batch_first=True): 创建一个 RNN 层输入和输出的特征维度均为 vector_dim,并且输入数据的第一维是批量大小。

④ 创建线性分类层

self.classify = nn.Linear(vector_dim, sentence_length + 1): 创建一个线性层,将 RNN 输出的特征向量映射到 sentence_length + 1 个分类标签。+1 是因为可能有某个词不存在的情况,此时的真实标签被设为 sentence_length。

⑤ 定义损失函数

self.loss = nn.functional.cross_entropy: 定义交叉熵损失函数,用于计算模型预测值与真实标签之间的差异。

class TorchModel(nn.Module):def __init__(self, vector_dim, sentence_length, vocab):super(TorchModel, self).__init__()self.embedding = nn.Embedding(len(vocab), vector_dim)  #embedding层# self.pool = nn.AvgPool1d(sentence_length)   #池化层#可以自行尝试切换使用rnnself.rnn = nn.RNN(vector_dim, vector_dim, batch_first=True)# +1的原因是可能出现a不存在的情况,那时的真实label在构造数据时设为了sentence_lengthself.classify = nn.Linear(vector_dim, sentence_length + 1)self.loss = nn.functional.cross_entropy

2、前向传播定义

① 输入嵌入

x = self.embedding(x):将输入 x 通过嵌入层转换为向量表示

② RNN处理

rnn_out, hidden = self.rnn(x):将嵌入后的向量输入到RNN层,得到每个时间步的输出 rnn_out 和最后一个时间步的隐藏状态 hidden。

③ 提取特征

x = rnn_out[:, -1, :]:从RNN的输出中提取最后一个时间步(最后一维)的特征向量。

④ 分类

y_pred = self.classify(x):将提取的特征向量通过线性层进行分类,得到预测值 y_pred。

⑤ 损失计算

如果提供了真实标签 y,则计算并返回损失值;否则,返回预测值。

    #当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):x = self.embedding(x)#使用pooling的情况,先使用pooling池化层会丢失模型语句的时序信息# x = x.transpose(1, 2)# x = self.pool(x)# x = x.squeeze()#使用rnn的情况# rnn_out:每个字对应的向量  hidden:最后一个输出的隐含层对应的向量rnn_out, hidden = self.rnn(x)# 中间维度改变,变成(batch_size数据样本数量, sentence_length文本长度, vector_dim向量维度)x = rnn_out[:, -1, :]  #或者写hidden.squeeze()也是可以的,因为rnn的hidden就是最后一个位置的输出#接线性层做分类y_pred = self.classify(x)if y is not None:return self.loss(y_pred, y)   #预测值和真实值计算损失else:return y_pred     

二、数据

1.建立词表

① 定义字符集

定义一个字符集 chars,包含字母 'a' 到 'k'。

② 定义字典

初始化一个字典 vocab,其中键为 'pad',值为 0。

③ 遍历字符集

使用 enumerate 遍历字符集 chars,为每个字符分配一个唯一的序号,从 1 开始。

④ 定义unk键

添加一个特殊的键 'unk',其值为当前字典的长度(即 26)。

⑤ 返回词汇表

将生成的词汇表返回

#字符集随便挑了一些字,实际上还可以扩充
#为每个字生成一个标号
#{"a":1, "b":2, "c":3...}
#abc -> [1,2,3]
def build_vocab():chars = "abcdefghijk"  #字符集vocab = {"pad":0}for index, char in enumerate(chars):vocab[char] = index+1   #每个字对应一个序号vocab['unk'] = len(vocab) #26return vocab

2.随机生成样本

① 采样

random.sample(list(vocab.keys()), sentence_length):从词汇表 vocab 的键中随机选择 sentence_length 个不同的字符,生成列表 x

② 标签生成

index('a'):检查列表 x 中是否包含字符 "a",如果包含,记录 "a" 在列表中的索引位置为 y,否则,设置 y 为 sentence_length。

③ 转换

将列表 x 中的每个字符转换为其在词汇表中的序号,如果字符不在词汇表中,则使用 unk 的序号

④ 返回结果

返回转换后的列表 x 和标签 y

#随机生成一个样本
def build_sample(vocab, sentence_length):#注意这里用sample,是不放回的采样,每个字母不会重复出现,但是要求字符串长度要小于词表长度x = random.sample(list(vocab.keys()), sentence_length)#指定哪些字出现时为正样本if "a" in x:y = x.index("a")else:y = sentence_lengthx = [vocab.get(word, vocab['unk']) for word in x]   #将字转换成序号,为了做embeddingreturn x, y

3.建立数据集

① 初始化数据集

创建两个空列表 dataset_x 和 dataset_y,用于存储生成的样本和对应的标签

② 生成样本

使用 for 循环,循环次数为 sample_length,即需要生成的样本数量。在每次循环中,调用 build_sample 函数生成一个样本 (x, y),其中 x 是输入数据,y 是标签

③ 存储样本

将生成的样本 x 添加到 dataset_x 列表中。将生成的标签 y 添加到 dataset_y 列表

④ 返回数据集

将 dataset_x 和 dataset_y 转换为 torch.LongTensor 类型,以便在 PyTorch 中使用。返回转换后的数据集。

#建立数据集
#输入需要的样本数量。需要多少生成多少
def build_dataset(sample_length, vocab, sentence_length):dataset_x = []dataset_y = []for i in range(sample_length):x, y = build_sample(vocab, sentence_length)dataset_x.append(x)dataset_y.append(y)return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)

三、模型测试、训练、评估

1.建立模型

① 参数:

vocab:词汇表,通常是一个包含所有字符或单词的列表或字典

char_dim:字符的维度,即每个字符在嵌入层中的向量长度

sentence_length:句子的最大长度

② 过程:

使用传入的参数 char_dim、sentence_length 和 vocab 实例化一个 TorchModel 对象并返回

#建立模型
def build_model(vocab, char_dim, sentence_length):model = TorchModel(char_dim, sentence_length, vocab)return model

2.测试模型

① 设置模型为评估模式

model.eval():将模型设置为评估模式禁用 dropout 训练时的行为

② 生成测试数据集

调用 build_dataset 函数生成 200 个用于测试的样本

③ 打印样本数量

输出当前测试集中样本的数量

④ 模型预测

使用 torch.no_grad() 禁用梯度计算,提高推理速度并减少内存消耗,然后对生成的测试数据进行预测

⑤ 计算准确率

遍历预测结果和真实标签,统计正确和错误的预测数量,并计算准确率

⑥ 输出结果

打印正确预测的数量和准确率,并返回准确率

#测试代码
#用来测试每轮模型的准确率
def evaluate(model, vocab, sample_length):model.eval()x, y = build_dataset(200, vocab, sample_length)   #建立200个用于测试的样本print("本次预测集中共有%d个样本"%(len(y)))correct, wrong = 0, 0with torch.no_grad():y_pred = model(x)      #模型预测for y_p, y_t in zip(y_pred, y):  #与真实标签进行对比if int(torch.argmax(y_p)) == int(y_t):correct += 1else:wrong += 1print("正确预测个数:%d, 正确率:%f"%(correct, correct/(correct+wrong)))return correct/(correct+wrong)

3.模型训练

① 配置参数

设置训练轮数epoch_num批量大小batch_size训练样本数train_sample字符维度char_dim句子长度sentence_length学习率learning_rate

② 建立字表

调用 build_vocab 函数生成字符到索引的映射。

③ 建立模型

调用 build_model 函数创建模型。

④ 选择优化器

torch.optim.Adam(model.parameters(), lr=learning_rate):使用 Adam 优化器

⑤ 训练过程

model.train():模型进入训练模式。每个 epoch 中,按批量生成训练数据,计算损失,反向传播并更新权重。记录每个 epoch 的平均损失

⑥ 评估模型

每个 epoch 结束后,调用 evaluate 函数评估模型性能。

⑦ 记录日志

记录每个 epoch 的准确率和平均损失

⑧ 绘制图表

绘制准确率和损失的变化曲线

⑨ 保存模型和词表

保存模型参数和词表

def main():#配置参数epoch_num = 20        #训练轮数batch_size = 40       #每次训练样本个数train_sample = 1000    #每轮训练总共训练的样本总数char_dim = 30         #每个字的维度sentence_length = 10   #样本文本长度learning_rate = 0.001 #学习率# 建立字表vocab = build_vocab()# 建立模型model = build_model(vocab, char_dim, sentence_length)# 选择优化器optim = torch.optim.Adam(model.parameters(), lr=learning_rate)log = []# 训练过程for epoch in range(epoch_num):model.train()watch_loss = []for batch in range(int(train_sample / batch_size)):x, y = build_dataset(batch_size, vocab, sentence_length) #构造一组训练样本optim.zero_grad()    #梯度归零loss = model(x, y)   #计算lossloss.backward()      #计算梯度optim.step()         #更新权重watch_loss.append(loss.item())print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))acc = evaluate(model, vocab, sentence_length)   #测试本轮模型结果log.append([acc, np.mean(watch_loss)])#画图plt.plot(range(len(log)), [l[0] for l in log], label="acc")  #画acc曲线plt.plot(range(len(log)), [l[1] for l in log], label="loss")  #画loss曲线plt.legend()plt.show()#保存模型torch.save(model.state_dict(), "model.pth")# 保存词表writer = open("vocab.json", "w", encoding="utf8")writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))writer.close()return

四、模型预测

1.训练并保存模型

if __name__ == "__main__":main()


2.预测数据

用保存的训练好的模型进行预测

① 初始化参数

设置每个字的维度 char_dim 和样本文本长度 sentence_length

② 加载字符表

从指定路径加载字符表 vocab

③ 建立模型

调用 build_model 函数构建模型

④ 加载模型权重

从指定路径加载预训练的模型权重

⑤ 序列化输入

将输入字符串转换为模型所需的输入格式

⑥ 模型预测

将输入数据传递给模型进行预测

⑦ 输出结果

打印每个输入字符串的预测类别和概率值

#使用训练好的模型做预测
def predict(model_path, vocab_path, input_strings):char_dim = 30  # 每个字的维度sentence_length = 10  # 样本文本长度vocab = json.load(open(vocab_path, "r", encoding="utf8")) #加载字符表model = build_model(vocab, char_dim, sentence_length)     #建立模型model.load_state_dict(torch.load(model_path,weights_only=True))             #加载训练好的权重x = []for input_string in input_strings:x.append([vocab[char] for char in input_string])  #将输入序列化model.eval()   #测试模式with torch.no_grad():  #不计算梯度result = model.forward(torch.LongTensor(x))  #模型预测for i, input_string in enumerate(input_strings):print("输入:%s, 预测类别:%s, 概率值:%s" % (input_string, torch.argmax(result[i]), result[i])) #打印结果

3.调用函数进行预测

if __name__ == "__main__":# main()test_strings = ["kijabcdefh", "gijkbcdeaf", "gkijadfbec", "kijhdefacb"]predict("model.pth", "vocab.json", test_strings)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/64950.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

javaEE-多线程编程-3

目录 java 常见的包 : 回调函数: 什么是线程: 第一个线程: 验证多线程执行: 内核: 调用sleep()方法: 执行结果分析: 线程创建的几种方式: 1.继承Thread类,重写run()方法. 2.实现Runnable接口,重写run()方法. 3.继承Thread类,重写run()方法.但使用匿名内部类 4.实现…

怎么在idea中创建springboot项目

最近想系统学习下springboot,尝试一下全栈路线 从零开始,下面将叙述下如何创建项目 环境 首先确保自己环境没问题 jdkMavenidea 创建springboot项目 1.打开idea,选择file->New->Project 2.选择Spring Initializr->设置JDK->…

设计模式期末复习

一、设计模式的概念以及分类 是一套被反复使用,多数人知晓,经过分类编目,代码设计经验的总结,描述了在软件设计的过程中不断重复发生的问题,以及该问题的解决方案,他是解决特定问题的一系列套路&#xff0c…

哔哩哔哩视频能保存到本地吗

哔哩哔哩(B站)视频可以保存到本地,但需要根据具体情况选择方法。以下是一些常见的方式: 使用B站客户端的离线功能 B站官方客户端(移动端或PC端)提供了离线下载功能,适用于已开通权限的视频。 离…

FreeMarker语法

1. 查找转移 <#function getSubSlot x > <#return (x) ? switch( "1", "L", "2", "R", "" )> </#function> 2. 转换数字 ?number ${mergedMap[placement.sequence].material.subs…

OCR(五)linux 环境 基于c++的 paddle ocr 编译【CPU版本 】

1. 下载 下载opencv4.10 2. 编译opencv 2.1 安装依赖库 sudo apt install -y g ++ sudo apt install -y cmake sudo apt install -y make sudo apt install -y wget sudo apt install -y unzip sudo apt-get install build-essential libgtk2.0-dev libgtk-3-devlibavcodec-…

SQL Server 批量插入数据的方式汇总及优缺点分析

在 SQL Server 中,批量插入数据是非常常见的操作,尤其是在需要导入大量数据时。以下是几种常用的批量插入数据的方式: 1. 使用 INSERT INTO ... VALUES • 特点:适用于少量数据插入。 • 优点:简单易用。 • 缺点:不适合大量数据插入,性能较差。 • 示例:…

对于其他管理的理解(下)

信息系统项目管理师的论文还有可行性分析、安全管理、测试管理、招投标管理等 可行性分析(医生的手术前检查) 1. 明确目标&#xff1a;确定手术的必要性 知识点&#xff1a; 在项目启动阶段&#xff0c;需要明确项目的目的、目标和需求&#xff0c;确定是否有必要开展项目。这…

Linux增加回收站功能

功能简介 rm命令是非常危险的命令&#xff0c;为了防止用户误删文件&#xff0c;所以我们在执行rm命令时将文件添加到回收站&#xff0c;防止误删文件。 相关环境变量 名称描述TRASH_DIR 回收站目录&#xff0c;默认为/Recycle_Bin 文件命名规则 文件名生成格式为 原始文件名…

Github——网页版上传文件夹

第一步&#xff1a;创建一个新的仓库或进入已存在的仓库页面 第二步&#xff1a;点进对应的文件夹下&#xff0c;然后 点击 “Upload files” 第三步&#xff1a;将文件夹拖拽到上传区域 打开资源管理器&#xff0c;将要上传的文件夹从计算机中拖拽到上传区域。 注意&#xf…

高级的SQL查询技巧有哪些?

成长路上不孤单&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【14后&#x1f60a;///C爱好者&#x1f60a;///持续分享所学&#x1f60a;///如有需要欢迎收藏转发///&#x1f60a;】 今日分享关于高级SQL查询技巧方面的相关内容&#xf…

FastStone 10.x 注册码

简介 FastStone Capture是一款经典好用的屏幕截图软件&#xff0c;在屏幕截图领域具有广泛的应用和众多优势。 软件基本信息 FastStone Capture体积小巧&#xff0c;占用内存少&#xff0c;这使得它在运行时不会给计算机系统带来过多的负担&#xff0c;即使在配置较低的电脑…

数据库 SQL 常用语句全解析

数据库 SQL 常用语句全解析 在数据库领域&#xff0c;SQL&#xff08;Structured Query Language&#xff09;作为标准语言&#xff0c;掌控着数据的查询、插入、更新与删除等关键操作。无论是新手入门数据库&#xff0c;还是经验丰富的开发者日常工作&#xff0c;熟练掌握 SQ…

ADB在浏览器中的革命:ya-webadb项目解析及新手指南

ADB在浏览器中的革命&#xff1a;ya-webadb项目解析及新手指南 ya-webadb ADB in your browser [这里是图片001] 项目地址: https://gitcode.com/gh_mirrors/ya/ya-webadb ya-webadb是一个创新的开源项目&#xff0c;它将Android调试桥(ADB)的功能带入了基于Chromium的浏览器…

K8S详解(5万字详细教程)

目录 ​编辑 一、集群管理命令 二、命名空间 1. 获取命名空间列表 2. 创建命名空间 3. 删除命名空间 4. 查看命名空间详情 三、Pod 1. Pod概述 2. Pod相位状态 3. 管理命令 3.1 获取命名空间下容器(pod)列表 3.2 查看pod的详细信息 3.3 创建 && 运行 3.4 …

费舍尔信息矩阵全面讲述

费舍尔信息矩阵&#xff08;Fisher Information Matrix&#xff09; 费舍尔信息矩阵是统计学中一个非常重要的概念&#xff0c;尤其在参数估计、最大似然估计&#xff08;MLE&#xff09;和贝叶斯推断中具有广泛的应用。它反映了参数估计的不确定性程度&#xff0c;也可以用来…

PTA 时间几何

作者 Happyer 单位 湖北文理学院 乘火车或飞机常有由始发时间历经时间计算终到时间的事儿。我们通过三个 函数来完成&#xff0c;当然&#xff0c;为了存储几点几分这个时间&#xff0c;我们专门定义了一个结构体Time_gxx,你要完成的是写二个函数&#xff1a;1&#xff09;st…

Zookeeper的监听机制

Zookeeper的监听机制是其实现分布式协调服务的一个核心功能。 它允许客户端注册Watcher&#xff08;观察者&#xff09;来监听特定的Znode&#xff08;节点&#xff09;上的事件&#xff0c;当Znode的状态发生变化时&#xff0c;Zookeeper会向注册了Watcher的客户端发送通知。…

[原创](Modern C++)现代C++的第三方库的导入方式: 例如Visual Studio 2022导入GSL 4.1.0

[简介] 常用网名: 猪头三 出生日期: 1981.XX.XX 企鹅交流: 643439947 个人网站: 80x86汇编小站 编程生涯: 2001年~至今[共23年] 职业生涯: 21年 开发语言: C/C、80x86ASM、PHP、Perl、Objective-C、Object Pascal、C#、Python 开发工具: Visual Studio、Delphi、XCode、Eclipse…

2002 - Can‘t connect to server on ‘192.168.1.XX‘ (36)

参考:2002 - Can‘t connect to server on ‘192.168.1.XX‘ (36) ubantu20.04&#xff0c;mysql5.7.13 navicat 远程连接数据库报错 2002 - Can’t connect to server on ‘192.168.1.61’ (36) 一、查看数据库服务是否有启动&#xff0c;发现有启动 systemctl status mysql…