自然语言处理实验2 字符级RNN分类实验

实验2 字符级RNN分类实验

必做题:

(1)数据准备:academy_titles.txt为“考硕考博”板块的帖子标题,job_titles.txt为“招聘信息”板块的帖子标题,将上述两个txt进行划分,其中训练集为70%,测试集为30%。二分类标签:考硕考博为0,招聘信息为1。字符使用One-hot方法表示。

(2)设计模型:在训练集上训练字符级RNN模型。注意,字符级不用分词,是将文本的每个字依次送入模型。

(3)将训练好的模型在测试数据集上进行验证,计算准确率,并分析实验结果。要给出每一部分的代码。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.model_selection import train_test_split# 读取academy_titles文件内容
with open('C:\\Users\\hp\\Desktop\\academy_titles.txt', 'r', encoding='utf-8') as file:academy_titles = file.readlines()# 读取job_titles文件内容
with open('C:\\Users\\hp\\Desktop\\job_titles.txt', 'r', encoding='utf-8') as file:job_titles = file.readlines()# 将招聘信息与学术信息分开
academy_titles = [title.strip() for title in academy_titles]
job_titles = [title.strip() for title in job_titles]# 构建标签和数据
X = academy_titles + job_titles
y = [0] * len(academy_titles) + [1] * len(job_titles)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 构建字符到索引的映射
all_chars = set(''.join(academy_titles + job_titles))
char_to_index = {char: i for i, char in enumerate(all_chars)}# 将文本转换为模型可接受的输入形式
def text_to_input(text, max_len, char_to_index):X_indices = np.zeros((len(text), max_len, len(char_to_index)), dtype=np.float32)for i, title in enumerate(text):for t, char in enumerate(title):X_indices[i, t, char_to_index[char]] = 1return torch.tensor(X_indices)max_len = max([len(title) for title in X])
X_train_indices = text_to_input(X_train, max_len, char_to_index)
X_test_indices = text_to_input(X_test, max_len, char_to_index)# 构建字符级RNN模型
class CharRNN(nn.Module):def __init__(self, input_size, hidden_size):super(CharRNN, self).__init__()self.hidden_size = hidden_sizeself.i2h = nn.LSTM(input_size, hidden_size)self.fc = nn.Linear(hidden_size, 1)self.sigmoid = nn.Sigmoid()def forward(self, input):hidden, _ = self.i2h(input)output = self.fc(hidden[-1])output = self.sigmoid(output)return outputmodel = CharRNN(input_size=len(char_to_index), hidden_size=128)# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 转换数据为PyTorch张量
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)# 定义新的训练周期数和学习率
num_epochs = 30
learning_rate = 0.01# 定义新的优化器
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_accuracy = 0.0
best_model = None# 训练模型并输出每一轮的准确率
for epoch in range(num_epochs):optimizer.zero_grad()output = model(X_train_indices)output = output.view(-1, 1)loss = criterion(output, y_train_tensor[:output.size(0)])loss.backward()optimizer.step()# 计算训练集准确率predictions = (output > 0.5).float()correct = (predictions == y_train_tensor[:output.size(0)]).float()accuracy = correct.sum() / len(correct)print(f'Epoch {epoch+1}, 训练集准确率: {accuracy.item()}')# 保存准确率最高的模型if accuracy > best_accuracy:best_accuracy = accuracybest_model = model.state_dict().copy()# 加载最佳模型参数
model.load_state_dict(best_model)# 使用测试集上准确率最高的模型进行测试
test_output = model(X_test_indices)
test_output = test_output.view(-1, 1)
test_loss = criterion(test_output, y_test_tensor[:test_output.size(0)])
predictions = (test_output > 0.5).float()
correct = (predictions == y_test_tensor[:test_output.size(0)]).float()
accuracy = correct.sum() / len(correct)print(f'使用测试集上准确率最高的模型进行测试,准确率: {accuracy.item()}')

 这个实验准确率目前是偏低的,但是我没有很多时间去一直调整参数

希望后面有需要的同学,可以去调整参数!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/745368.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

服务器Debian 12.x中安装Jupyer并配置远程访问

服务器系统:Debian 12.x;IP地址:10.100.2.138 客户端:Windows 10;IP地址:10.100.2.38 利用ssh登录服务器: 1.安装python3 #apt install python3 2.安装pip #apt install python3-pip … 3.安装virtualen…

Unity Timeline学习笔记(3) - SignalTrack信号轨道和自定义带参数的Marker信号和轨道

信号轨道,顾名思义就是运行到某处发送一个信号。 普通用法 普通用法就是没有任何封装的,个人感觉特别难用,但是有必要理解一下工作原理。 添加信号 我们添加一个信号资源 生成后可以看到资源文件,这个是可以拖到SignalTrack上…

【Python数据结构与判断7/7】数据结构小结

目录 序言 整体回忆 定义方式 访问元素 访问单个元素 访问多个与元素 修改元素 添加元素 列表里添加元素 字典里添加元素 删除元素 in运算符 实战案例 总结 序言 今天将对前面学过的三种数据结构:元组(tuple)、列表(…

微前端框架 qiankun 配置使用【基于 vue/react脚手架创建项目 】

qiankun官方文档:qiankun - qiankun 一、创建主应用: 这里以 vue 为主应用,vue版本:2.x // 全局安装vue脚手架 npm install -g vue/clivue create main-app 省略 vue 创建项目过程,若不会可以自行百度查阅教程 …

java垃圾回收-三色标记法

三色标记法 引言什么是三色标记法白色灰色黑色 三色标记过程三色标记带来的问题多标问题漏标问题 如何弥补漏标问题增量更新原始快照总结 引言 在CMS,G1这种并发的垃圾收集器收集对象时,假如一个对象A被GC线程标记为不可达对象,但是用户线程又把A对象做…

数字化经济的前沿:深入了解 Web3 的商业模式

随着区块链技术的迅速发展,Web3作为一种新型的互联网范式,正逐渐引起人们的关注。它不仅仅是一种技术革新,更是一种商业模式和价值观的转变。本文将深入探讨Web3的商业模式,以及它对数字化经济的影响。 1. 理解Web3的商业模式 We…

算法---滑动窗口练习-4(无重复字符的最长子串)

无重复字符的最长子串 1. 题目解析2. 讲解算法原理3. 编写代码 1. 题目解析 题目地址:点这里 2. 讲解算法原理 算法的主要思想是使用滑动窗口来维护一个不含重复字符的子串。定义两个指针 left 和 right 分别表示窗口的左边界和右边界。还定义了一个数组 hash 来记…

Apache Paimon 的 CDC Ingestion 概述

CDC Ingestion 1)概述 Paimon支持schema evolution将数据插入到Paimon表中,添加的列将实时同步到Paimon表,并且无需重启同步作业。 目前支持的同步方式如下: MySQL Synchronizing Table: 将MySQL中的一个或多个表同步到一个Pa…

【算法与数据结构】深入解析二叉树(一)

文章目录 📝数概念及结构🌠 树的概念🌉树的表示🌠 树在实际中的运用(表示文件系统的目录树结构) 🌉二叉树概念及结构🌠概念🌉数据结构中的二叉树🌠特殊的二叉…

Spring web MVC(2)

1、RequestMapping称为路由映射(既是类注解也是方法注解提供访问路径) 2、RequestParam起到重命名的作用,也起到绑定的作用,传递集合list时会用到,多个值绑定给list,默认是必传参数如果不传参数需要设置re…

如何在Windows 10上打开和关闭平板模式?这里提供详细步骤

前言 默认情况下,当你将可翻转PC重新配置为平板模式时,Windows 10会自动切换到平板模式。如果你希望手动打开或关闭平板模式,有几种方法可以实现。​ 自动平板模式在Windows 10上如何工作 如果你使用的是二合一可翻转笔记本电脑&#xff0…

Spring, SpringBoot, SpringCloud,微服务

1,SSM (Spring+SpringMVC+MyBatis) SSM框架集由Spring、MyBatis两个开源框架整合而成(SpringMVC是Spring中的部分内容),常作为数据源较简单的web项目的框架。 Spring MVC 是 Spring 提供的一个基于 MVC 设计模式的轻量级 Web 开发框架,本质上相当于 Servlet,Controlle…

vue 基于elementUI/antd-vue, h函数实现message中嵌套链接跳转到指定路由 (h函数点击事件的写法)

效果如图: 点击message 组件中的 工单管理, 跳转到工单管理页面。 以下是基于vue3 antd-vue 代码如下: import { message } from ant-design-vue; import { h, reactive, ref, watch } from vue; import { useRouter } from vue-router; c…

PY32离线烧录器功能介绍,可批量烧录,支持PY32系列多款单片机

PY32离线烧录器可以对PY系列单片机进行批量烧录,现支持PY32F002A/002B/002/003/030/071/072/040/403/303芯片各封装和XL2409,XL32F001/003等芯片。PY32离线烧录器需要搭配上位机软件才能使用,上位机软件在我们官网(www.xinlinggo.…

【软考】UML中的图之对象图

目录 1. 说明2. 图示3. 特性 1. 说明 1.对象图即object diagram2.展现了某一时刻一组对象以及它们之间的关系3.描述了在类图中所建立的事物的实例的静态快照4.对象图一般包括对象和链5.对象图展示的是对象之间关系,不存在交互,所以不是交互图 2. 图示 …

#微信小程序(一个emo文案界面)

1.IDE:微信开发者工具 2.实验:一个emo文案界面 (1)最好使用rpx (2)图片宽度占不满,在CSS中设置width为100% (3)imag图片全部为网页链接图片 3.记录 4.代码 index.htm…

Jmeter+ant,ant安装与配置

1.ant含义 ant:Ant翻译过来是蚂蚁的意思,在我们做接口测试的时候,是可以用来做JMeter接口测试生成测试报告的工具 2.ant下载 下载地址:Apache Ant - Ant Manual Distributions download中选择ant 下载安装最新版zip文件 3.…

阿里云国际放行DDoS高防回源IP

如果源站服务器上设置了IP白名单访问控制(如安全软件、安全组),由于设置了DDoS高防后,回源IP是高防回源IP段,您需要将DDoS高防的回源IP段的地址加入安全软件和安全组的白名单中,避免DDoS高防的回源流量被误…

导入fetch_california_housing 加州房价数据集报错解决(HTTPError: HTTP Error 403: Forbidden)

报错 HTTPError Traceback (most recent call last) Cell In[3], line 52 from sklearn.datasets import fetch_california_housing3 from sklearn.model_selection import train_test_split ----> 5 X, Y fetch_california_housing(retu…

发布组件到npm

1.环境准备&#xff0c;需要装好node&#xff0c;注册号npm账号,这里不做详解 2.创建编写组件和方法的文件夹package 3.在文件夹中创建需要定义的组件&#xff0c;并且加上name属性 //组件 <template><div><button>按钮组件</button></div> &…