Day09【基于jieba分词和RNN实现的简单中文分词】

基于jieba分词和RNN实现的中文分词

      • 目标
      • 数据准备
      • 主程序
      • 预测效果

在这里插入图片描述
在这里插入图片描述

目标

本文基于给定的中文词表,将输入的文本基于jieba分词分割为若干个词,词的末尾对应的标签为1,中间部分对应的标签为0,同时将分词后的单词基于中文词表做初步序列化,之后经过embeddingRNN循环神经网络等网络结构层,最后输出在两类别(词内部和词边界)标签上的概率分布,从而实现一个简单中文分词任务。

数据准备

词表文件chars.txt

中文语料文件corpus.txt中文语料文件

主程序

#coding:utf8import torch
import torch.nn as nn
import jieba
import numpy as np
import random
import json
from torch.utils.data import DataLoader"""
基于pytorch的网络编写一个分词模型
我们使用jieba分词的结果作为训练数据
看看是否可以得到一个效果接近的神经网络模型
"""class TorchModel(nn.Module):def __init__(self, input_dim, hidden_size, num_rnn_layers, vocab):super(TorchModel, self).__init__()self.embedding = nn.Embedding(len(vocab) + 1, input_dim) #shape=(vocab_size, dim)self.rnn_layer = nn.RNN(input_size=input_dim,hidden_size=hidden_size,batch_first=True,bidirectional=False,num_layers=num_rnn_layers,nonlinearity="relu",dropout=0.1)self.classify = nn.Linear(hidden_size, 2)self.loss_func = nn.CrossEntropyLoss(ignore_index=-100)#当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):x = self.embedding(x)  #output shape:(batch_size, sen_len, input_dim)x, _ = self.rnn_layer(x)  #output shape:(batch_size, sen_len, hidden_size)y_pred = self.classify(x)   #input shape:(batch_size, sen_len, class_num)if y is not None:#(batch_size * sen_len, class_num),   (batch_size * sen_len, 1)return self.loss_func(y_pred.view(-1, 2), y.view(-1))else:return y_predclass Dataset:def __init__(self, corpus_path, vocab, max_length):self.vocab = vocabself.corpus_path = corpus_pathself.max_length = max_lengthself.load()def load(self):self.data = []with open(self.corpus_path, encoding="utf8") as f:for line in f:sequence = sentence_to_sequence(line, self.vocab)label = sequence_to_label(line)sequence, label = self.padding(sequence, label)sequence = torch.LongTensor(sequence)label = torch.LongTensor(label)self.data.append([sequence, label])if len(self.data) > 10000:breakdef padding(self, sequence, label):sequence = sequence[:self.max_length]sequence += [0] * (self.max_length - len(sequence))label = label[:self.max_length]label += [-100] * (self.max_length - len(label))return sequence, labeldef __len__(self):return len(self.data)def __getitem__(self, item):return self.data[item]#文本转化为数字序列,为embedding做准备
def sentence_to_sequence(sentence, vocab):sequence = [vocab.get(char, vocab['unk']) for char in sentence]return sequence#基于结巴生成分级结果的标注
def sequence_to_label(sentence):words = jieba.lcut(sentence)label = [0] * len(sentence)pointer = 0for word in words:pointer += len(word)label[pointer - 1] = 1return label#加载字表
def build_vocab(vocab_path):vocab = {}with open(vocab_path, "r", encoding="utf8") as f:for index, line in enumerate(f):char = line.strip()vocab[char] = index + 1   #每个字对应一个序号vocab['unk'] = len(vocab) + 1return vocab#建立数据集
def build_dataset(corpus_path, vocab, max_length, batch_size):dataset = Dataset(corpus_path, vocab, max_length) #diy __len__ __getitem__data_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size) #torchreturn data_loaderdef main():epoch_num = 10        #训练轮数batch_size = 20       #每次训练样本个数char_dim = 50         #每个字的维度hidden_size = 100     #隐含层维度num_rnn_layers = 3    #rnn层数max_length = 20       #样本最大长度learning_rate = 1e-3  #学习率vocab_path = "chars.txt"  #字表文件路径corpus_path = "corpus.txt"  #语料文件路径vocab = build_vocab(vocab_path)       #建立字表data_loader = build_dataset(corpus_path, vocab, max_length, batch_size)  #建立数据集model = TorchModel(char_dim, hidden_size, num_rnn_layers, vocab)   #建立模型optim = torch.optim.Adam(model.parameters(), lr=learning_rate)     #建立优化器#训练开始for epoch in range(epoch_num):model.train()watch_loss = []for x, y in data_loader:optim.zero_grad()    #梯度归零loss = model(x, y)   #计算lossloss.backward()      #计算梯度optim.step()         #更新权重watch_loss.append(loss.item())print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))#保存模型torch.save(model.state_dict(), "model.pth")#保存词表writer = open("vocab.json", "w", encoding="utf8")writer.write(json.dumps(vocab, ensure_ascii=False, indent=2))writer.close()return#最终预测
def predict(model_path, vocab_path, input_strings):#配置保持和训练时一致char_dim = 50  # 每个字的维度hidden_size = 100  # 隐含层维度num_rnn_layers = 3  # rnn层数vocab = build_vocab(vocab_path)       #建立字表model = TorchModel(char_dim, hidden_size, num_rnn_layers, vocab)   #建立模型model.load_state_dict(torch.load(model_path))   #加载训练好的模型权重model.eval()for input_string in input_strings:#逐条预测x = sentence_to_sequence(input_string, vocab)# print(x)with torch.no_grad():result = model.forward(torch.LongTensor([x]))[0]result = torch.argmax(result, dim=-1)  #预测出的01序列print(result)#在预测为1的地方切分,将切分后文本打印出来for index, p in enumerate(result):if p == 1:print(input_string[index], end=" ")else:print(input_string[index], end="")print()if __name__ == "__main__":print(torch.backends.mps.is_available())main()input_strings = ["同时国内有望出台新汽车刺激方案","沪胶后市有望延续强势","经过两个交易日的强势调整后","昨日上海天然橡胶期货价格再度大幅上扬"]predict("model.pth", "chars.txt", input_strings)

主要实现了一个基于jieba分词的中文分词模型,模型采用 RNN(循环神经网络)来处理中文文本,通过对句子进行分词,预测每个字是否为词的结尾。具体内容如下:

  1. 模型结构(TorchModel

    • 使用 nn.Embedding 层将每个字符映射到一个高维空间。
    • 通过 nn.RNN 层处理字符序列,提取上下文信息,使用单向 RNNbidirectional=False)。
    • 最后通过 nn.Linear 层将 RNN 输出转化为每个字符的分类结果,分类为 0(非词结尾)或 1(词结尾)。
    • 损失函数为 CrossEntropyLoss,计算预测与真实标签的差异。
  2. 数据处理(Dataset

    • 使用 jieba 分词工具将文本切分为词,并为每个字符标注一个标签。标签为 1 表示该字符是词的结尾,0 表示不是词结尾。
    • 将文本转换为数字序列,并根据最大句子长度进行填充,使得输入数据的形状一致。
  3. 训练过程

    • 数据通过 DataLoader 按批加载,使用 Adam 优化器进行训练。
    • 在每一轮训练中,计算损失并通过反向传播优化模型权重,训练 10 轮。
  4. 预测功能(predict

    • 加载训练好的模型,使用 torch.no_grad() 禁用梯度计算,提高推理速度。
    • 对每个输入字符串进行分词预测,输出每个字是否为词的结尾。若为词结尾,则切分该词并打印。
  5. 核心流程

    • sentence_to_sequence 将文本转换为字符序列,sequence_to_label 生成对应的标签序列。
    • 训练完成后,保存模型和词表,以便后续加载和预测。

代码实现了一个简单的中文分词模型,通过标注每个字符是否为词的结尾,结合 RNN 提取上下文信息,从而实现文本分词功能。

预测效果

输入语句:
“同时国内有望出台新汽车刺激方案”,
“沪胶后市有望延续强势”,
“经过两个交易日的强势调整后”,
“昨日上海天然橡胶期货价格再度大幅上扬”

中文分词后结果:

tensor([0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1])
同时 国内 有 望 出台 新 汽车 刺激 方案 
tensor([1, 1, 0, 1, 1, 1, 0, 1, 0, 1])
沪 胶 后市 有 望 延续 强势 
tensor([0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1])
经过 两个 交易 日 的 强势 调整 后 
tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
昨日 上海 天然 橡胶 期货 价格 再度 大幅 上扬 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/901619.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux-服务器添加审计日志功能

#查看audit软件是否在运行(状态为active而且为绿色表示已经在运行) systemctl start auditd #如果没有在运行的话,查看是否被系统禁用 (audit为0表示被禁用) cat /proc/cmdline | grep -w "audit=0" #修改/etc/default/grub里面audit=0 改为audit=1 #更新GRUB…

uniappx项目上架各手机平台

前段时间用uniappx开发的App,领导要求要在各个主要手机平台上上架了,本来不是我的任务,后来其他人没有空交给我了,上架小白一枚,哭唧唧的自己研究吧,根据领导发的账号密码登录各个平台上架,花费…

第4次课 前缀和与差分 A

课堂学习 前缀和数组 前1个收购点&#xff1a;3箱 前2个收购点&#xff1a;325箱 前3个收购点&#xff1a;32510箱 以此类推… 数组a存储10个收购点的箱数。 收购点编号从1~10&#xff0c;数组下标也从1开始使用。 下标0位置直接赋值0 #include<bits/stdc.h> using nam…

MySQL部分总结

mysql学习笔记&#xff0c;如有不足还请指出&#xff0c;谢谢。 外连接&#xff0c;内连接&#xff0c;全连接 外连接&#xff1a;左外、右外 内连接&#xff1a;自己和自己连接 全连接&#xff1a;左外连接右外链接 mysql unique字段 unique可以在数据库层面避免插入相同…

Spring MVC 请求处理流程详解

步骤1&#xff1a;用户发起请求 所有请求首先被 DispatcherServlet&#xff08;前端控制器&#xff09;拦截&#xff0c;它是整个流程的入口。 DispatcherServlet 继承自 HttpServlet&#xff0c;通过 web.xml 或 WebApplicationInitializer 配置映射路径&#xff08;如 /&…

Vue 高级技巧深度解析

Vue 高级技巧深度解析 mindmaproot(Vue2高级技巧)组件通信EventBusprovide/inject$attrs/$listeners性能优化虚拟DOM优化函数式组件按需加载状态管理Vuex模块化持久化存储严格模式高级指令自定义指令动态组件异步组件渲染控制作用域插槽渲染函数JSX支持一、组件通信的进阶之道 …

2024年React最新高频面试题及核心考点解析,涵盖基础、进阶和新特性,助你高效备战

以下是2024年React最新高频面试题及核心考点解析&#xff0c;涵盖基础、进阶和新特性&#xff0c;助你高效备战&#xff1a; 一、基础篇 React虚拟DOM原理及Diff算法优化策略 • 必考点&#xff1a;虚拟DOM树对比&#xff08;同级比较、Key的作用、组件类型判断&#xff09; •…

Zookeeper单机三节点集群部署(docker-compose方式)

前提: 服务器需要有docker镜像zookeeper:3.9.3 或能连网拉取镜像 服务器上面新建文件夹: mkdir -p /data/zk-cluster/{data,zoo-cfg} 创建三个zookeeper配置文件zoo1.cfg、zoo2.cfg、zoo3.cfg,配置文件里面内容如下(三个文件内容一样): tickTime=2000 initLimit=10 …

面试题之数据库-mysql高阶及业务场景设计

最近开始面试了&#xff0c;410面试了一家公司 针对自己薄弱的面试题库&#xff0c;深入了解下&#xff0c;也应付下面试。在这里先祝愿大家在现有公司好好沉淀&#xff0c;定位好自己的目标&#xff0c;在自己的领域上发光发热&#xff0c;在自己想要的领域上&#xff08;技术…

数字内容体验案例解析与行业应用

数字内容案例深度解析 在零售行业头部品牌的实践中&#xff0c;数字内容体验的革新直接推动了用户行为模式的转变。某国际美妆集团通过搭建智能内容中台&#xff0c;将产品信息库与消费者行为数据实时对接&#xff0c;实现不同渠道的动态内容生成。其电商平台首页的交互式AR试…

4.15 代码随想录第四十四天打卡

99. 岛屿数量(深搜) (1)题目描述: (2)解题思路: #include <iostream> #include <vector> using namespace std;int dir[4][2] {0, 1, 1, 0, -1, 0, 0, -1}; // 四个方向 void dfs(const vector<vector<int>>& grid, vector<vector<bool&g…

【三维重建与生成】GenFusion:SVD统一重建和生成

标题:《GenFusion: Closing the Loop between Reconstruction and Generation via Videos》 来源&#xff1a;西湖大学&#xff1b;慕尼黑工业大学&#xff1b;上海科技大学&#xff1b;香港大学&#xff1b;图宾根大学 项目主页&#xff1a;https://genfusion.sibowu.com 文章…

Quipus,LightRag的Go版本的实现

1 项目简介 奇谱系统当前版本以知识库为核心&#xff0c;基于知识库可以快构建自己的问答系统。知识库的Rag模块的构建算法是参考了LightRag的算法流程的Go版本优化实现&#xff0c;它可以帮助你快速、准确地构建自己的知识库&#xff0c;搭建属于自己的AI智能助手。与当前LLM…

mysql 8 支持直方图

mysql 8 可以通过语句 ANALYZE TABLE table_name UPDATE HISTOGRAM ON column_name WITH 10 BUCKETS; 生产直方图&#xff0c;解决索引数据倾斜的问题 在之前的mysql5.7的版本上是没有的 参考&#xff1a; MySQL :: MySQL 8.0 Reference Manual :: 15.7.3.1 ANALYZE TABL…

力扣-hot100(最长连续序列 - Hash)

128. 最长连续序列 中等 给定一个未排序的整数数组 nums &#xff0c;找出数字连续的最长序列&#xff08;不要求序列元素在原数组中连续&#xff09;的长度。 请你设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 1&#xff1a; 输入&#xff1a;nums [100,4,200,…

RCEP框架下eBay日本站选品战略重构:五维解析关税红利机遇

2024年RCEP深化实施背景下&#xff0c;亚太跨境电商生态迎来结构性变革。作为协定核心成员的日本市场&#xff0c;其跨境电商平台正经历新一轮价值重构。本文将聚焦eBay日本站&#xff0c;从政策解读到实操路径&#xff0c;系统拆解跨境卖家的战略机遇。 一、关税递减机制下的…

Unity开发框架:输入事件管理类

开发程序的时候经常会出现更改操作方式的情况&#xff0c;这种时候就需要将操作模式以事件的方式注册到管理输入事件的类中&#xff0c;方便可以随时切换和调用 using System; using System.Collections.Generic; using UnityEngine;/// <summary> /// 记录鼠标事件的的…

【kind管理脚本-2】脚本使用说明文档 —— 便捷使用 kind 创建、删除、管理集群脚本

当然可以&#xff0c;以下是为你这份 Kind 管理脚本写的一份使用说明文档&#xff0c;可作为 README.md 或内部文档使用&#xff1a; &#x1f680; Kind 管理脚本说明文档 本脚本是一个便捷的工具&#xff0c;帮助你快速创建、管理和诊断基于 Kind (Kubernetes IN Docker) 的…

opencv常用边缘检测算子示例

opencv常用边缘检测算子示例 1. Canny算子2. Sobel算子3. Scharr算子4. Laplacian算子5. 对比 1. Canny算子 从不同视觉对象中提取有用的结构信息并大大减少要处理的数据量的一种技术&#xff0c;检测算法可以分为以下5个步骤&#xff1a; 噪声过滤&#xff08;高斯滤波&…

Token安全存储的几种方式

文章目录 1. EncryptedSharedPreferences示例代码 2. SQLCipher示例代码 3.使用 Android Keystore加密后存储示例代码1. 生成密钥对2. 使用 KeystoreManager 代码说明安全性建议加密后的几种存储方式1. 加密后采用 SharedPreferences存储2. 加密后采用SQLite数据库存储1. Token…