1 NLP分类之:FastText

0 数据

https://download.csdn.net/download/qq_28611929/88580520?spm=1001.2014.3001.5503

数据集合:0 NLP: 数据获取与EDA-CSDN博客

词嵌入向量文件: embedding_SougouNews.npz

词典文件:vocab.pkl

1 模型

基于fastText做词向量嵌入然后引入2-gram, 3-gram扩充,最后接入一个MLP即可;

fastText 是一个由 Facebook AI Research 实现的开源库,用于进行文本分类和词向量学习。它结合了传统的词袋模型和神经网络的优点,能够快速训练大规模的文本数据。

fastText 的主要特点包括:

1. 快速训练:fastText 使用了层次化 Softmax 和负采样等技术,大大加快了训练速度。

2. 子词嵌入:fastText 将单词表示为字符级别的 n-gram,并将其视为单词的子词。这样可以更好地处理未登录词和稀有词。

3. 文本分类:fastText 提供了一个简单而高效的文本分类接口,可以用于训练和预测多类别文本分类任务。

4. 多语言支持:fastText 支持多种语言,并且可以通过学习共享词向量来提高跨语言任务的性能。

需要注意的是,fastText 主要适用于文本分类任务,对于其他类型的自然语言处理任务(如命名实体识别、机器翻译等),可能需要使用其他模型或方法。

 

 2 代码

nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
 

`nn.Embedding.from_pretrained` 是 PyTorch 中的一个函数,用于从预训练的词向量加载 Embedding 层的权重。

在使用 `nn.Embedding.from_pretrained` 时,你需要提供一个预训练的词向量矩阵作为参数,

freeze 参数: 指定是否冻结该层的权重。预训练的词向量可以是从其他模型(如 Word2Vec 或 GloVe)中得到的。

y = nn.Embedding.from_pretrained (x)

x输入:词的索引

y返回: 词向量

 

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pickle as pkl
from tqdm import tqdm
import time
from torch.utils.data import Datasetfrom datetime import timedelta
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from torch.optim import AdamWUNK, PAD = '<UNK>', '<PAD>'  # 未知字,padding符号
RANDOM_SEED = 2023file_path = "./data/online_shopping_10_cats.csv"
vocab_file = "./data/vocab.pkl"
emdedding_file = "./data/embedding_SougouNews.npz"
vocab = pkl.load(open(vocab_file, 'rb'))class MyDataSet(Dataset):def __init__(self, df, vocab,pad_size=None):self.data_info = dfself.data_info['review'] = self.data_info['review'].apply(lambda x:str(x).strip())self.data_info = self.data_info[['review','label']].valuesself.vocab = vocab self.pad_size = pad_sizeself.buckets = 250499  def biGramHash(self,sequence, t):t1 = sequence[t - 1] if t - 1 >= 0 else 0return (t1 * 14918087) % self.bucketsdef triGramHash(self,sequence, t):t1 = sequence[t - 1] if t - 1 >= 0 else 0t2 = sequence[t - 2] if t - 2 >= 0 else 0return (t2 * 14918087 * 18408749 + t1 * 14918087) % self.bucketsdef __getitem__(self, item):result = {}view, label = self.data_info[item]result['view'] = view.strip()result['label'] = torch.tensor(label,dtype=torch.long)token = [i for i in view.strip()]seq_len = len(token)# 填充if self.pad_size:if len(token) < self.pad_size:token.extend([PAD] * (self.pad_size - len(token)))else:token = token[:self.pad_size]seq_len = self.pad_sizeresult['seq_len'] = seq_len# 词表的转换words_line = []for word in token:words_line.append(self.vocab.get(word, self.vocab.get(UNK)))result['input_ids'] = torch.tensor(words_line, dtype=torch.long) # bigram = []trigram = []for i in range(self.pad_size):bigram.append(self.biGramHash(words_line, i))trigram.append(self.triGramHash(words_line, i))result['bigram'] = torch.tensor(bigram, dtype=torch.long)result['trigram'] = torch.tensor(trigram, dtype=torch.long)return resultdef __len__(self):return len(self.data_info)
df = pd.read_csv("./data/online_shopping_10_cats.csv")#myDataset[0]
df_train, df_test = train_test_split(df, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)
df_train.shape, df_val.shape, df_test.shapedef create_data_loader(df,vocab,pad_size,batch_size=4):ds = MyDataSet(df,vocab,pad_size=pad_size)return DataLoader(ds,batch_size=batch_size)MAX_LEN = 256
BATCH_SIZE = 4
train_data_loader = create_data_loader(df_train,vocab,pad_size=MAX_LEN, batch_size=BATCH_SIZE)
val_data_loader = create_data_loader(df_val,vocab,pad_size=MAX_LEN, batch_size=BATCH_SIZE)
test_data_loader = create_data_loader(df_test,vocab,pad_size=MAX_LEN, batch_size=BATCH_SIZE)class Config(object):"""配置参数"""def __init__(self):self.model_name = 'FastText'self.embedding_pretrained = torch.tensor(np.load("./data/embedding_SougouNews.npz")["embeddings"].astype('float32'))  # 预训练词向量self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # 设备self.dropout = 0.5                                              # 随机失活self.require_improvement = 1000                                 # 若超过1000batch效果还没提升,则提前结束训练self.num_classes = 2                                            # 类别数self.n_vocab = 0                                                # 词表大小,在运行时赋值self.num_epochs = 20                                            # epoch数self.batch_size = 128                                           # mini-batch大小self.learning_rate = 1e-4                                       # 学习率self.embed = self.embedding_pretrained.size(1)\if self.embedding_pretrained is not None else 300           # 字向量维度self.hidden_size = 256                                          # 隐藏层大小self.n_gram_vocab = 250499                                      # ngram 词表大小class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.embedding_ngram2 = nn.Embedding(config.n_gram_vocab, config.embed)self.embedding_ngram3 = nn.Embedding(config.n_gram_vocab, config.embed)self.dropout = nn.Dropout(config.dropout)self.fc1 = nn.Linear(config.embed * 3, config.hidden_size)# self.dropout2 = nn.Dropout(config.dropout)self.fc2 = nn.Linear(config.hidden_size, config.num_classes)def forward(self, x):out_word = self.embedding(x['input_ids'])out_bigram = self.embedding_ngram2(x['bigram'])out_trigram = self.embedding_ngram3(x['trigram'])out = torch.cat((out_word, out_bigram, out_trigram), -1)out = out.mean(dim=1)out = self.dropout(out)out = self.fc1(out)out = F.relu(out)out = self.fc2(out)return outconfig = Config()
model = Model(config)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)EPOCHS = 5 # 训练轮数
optimizer = AdamW(model.parameters(),lr=2e-4)
total_steps = len(train_data_loader) * EPOCHS
# schedule = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=0,
#                                num_training_steps=total_steps)
loss_fn = nn.CrossEntropyLoss().to(device)def train_epoch(model,data_loader,loss_fn,device,n_exmaples,schedule=None):model = model.train()losses = []correct_predictions = 0for d in tqdm(data_loader):# input_ids = d['input_ids'].to(device)# attention_mask = d['attention_mask'].to(device)targets = d['label']#.to(device)outputs = model(d)_,preds = torch.max(outputs, dim=1)loss = loss_fn(outputs,targets)losses.append(loss.item())correct_predictions += torch.sum(preds==targets)loss.backward()nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()#scheduler.step()optimizer.zero_grad()return correct_predictions.double() / n_examples, np.mean(losses)def eval_model(model, data_loader, loss_fn, device, n_examples):model = model.eval() # 验证预测模式losses = []correct_predictions = 0with torch.no_grad():for d in data_loader:targets = d['label']#.to(device)outputs = model(d)_, preds = torch.max(outputs, dim=1)loss = loss_fn(outputs, targets)correct_predictions += torch.sum(preds == targets)losses.append(loss.item())return correct_predictions.double() / n_examples, np.mean(losses)# train model
EPOCHS = 5
history = defaultdict(list) # 记录10轮loss和acc
best_accuracy = 0for epoch in range(EPOCHS):print(f'Epoch {epoch + 1}/{EPOCHS}')print('-' * 10)train_acc, train_loss = train_epoch(model,train_data_loader,loss_fn,optimizer,device,len(df_train))print(f'Train loss {train_loss} accuracy {train_acc}')val_acc, val_loss = eval_model(model,val_data_loader,loss_fn,device,len(df_val))print(f'Val   loss {val_loss} accuracy {val_acc}')print()history['train_acc'].append(train_acc)history['train_loss'].append(train_loss)history['val_acc'].append(val_acc)history['val_loss'].append(val_loss)if val_acc > best_accuracy:torch.save(model.state_dict(), 'best_model_state.bin')best_accuracy = val_acc

备注: CPU训练模型很慢啊!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/182038.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3+vite+ts项目打包时出错

项目中引入了element-plus国家化的配置&#xff0c;然后进行项目打包&#xff0c;报下面的错误 解决方法&#xff1a; 在main.ts中添加 // ts-ignore

【存储】blotdb的原理及实现(2)

【存储】etcd的存储是如何实现的(3)-blotdb 在etcd系列中&#xff0c;我们对作为etcd底层kv存储的boltdb进行了比较全面的介绍。但是还有两个点没有涉及。 第一点是boltdb如何和磁盘文件交互。 持久化存储和我们一般业务应用程序的最大区别就是其强依赖磁盘文件。一方面文件数…

Linux系统之一次性计划任务at命令的基本使用

Linux系统之一次性计划任务at命令的基本使用 一、at命令介绍二、at命令的使用帮助2.1 at命令的help帮助信息2.2 at命令的语法解释 三、at命令的日常使用3.1 立即执行一次性任务3.2 指定时间执行一次性任务3.3 查询计划任务3.4 其他指定时间用法3.5 删除已经设置的计划任务3.6 显…

深度学习毕设项目 基于生成对抗网络的照片上色动态算法设计与实现 - 深度学习 opencv python

文章目录 1 前言1 课题背景2 GAN(生成对抗网络)2.1 简介2.2 基本原理 3 DeOldify 框架4 First Order Motion Model 1 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设题目缺少创新和亮点&#xff0c;往往达不到毕业答辩的要求&am…

echarts案例网站

一、ppchart 网站&#xff1a;https://ppchart.com/#/ 二、echarts官网示例 网站&#xff1a;https://echarts.apache.org/examples/zh/index.html

1992-2021年区县经过矫正的夜间灯光数据(GNLD、VIIRS)

1992-2021年区县经过矫正的夜间灯光数据&#xff08;GNLD、VIIRS&#xff09; 1、时间&#xff1a;1992-2021年3月&#xff0c;其中1992-2013年为年度数据&#xff0c;2013-2021年3月为月度数据 2、来源&#xff1a;DMSP、VIIRS 3、范围&#xff1a;区县数据 4、指标解释&a…

NeurIPS 2023|AI Agents先行者CAMEL:第一个基于大模型的多智能体框架

AI Agents是当下大模型领域备受关注的话题&#xff0c;用户可以引入多个扮演不同角色的LLM Agents参与到实际的任务中&#xff0c;Agents之间会进行竞争和协作等多种形式的动态交互&#xff0c;进而产生惊人的群体智能效果。本文介绍了来自KAUST研究团队的大模型心智交互CAMEL框…

双指针算法(题目与答案讲解)

文章目录 题目移动零复写零两数之和N数之和(>2个数) 答案讲解移动零复写零两数之和N数之和 题目 力扣 移动零 1、移动零:题目链接 复写零 2、复写零:题目链接 两数之和 3、两数之和题目链接 N数之和(>2个数) 4、N数之和(三个数、四个数) 三个数:题目链接 四个数题目链接…

Docker、Kubernetes、OCI、CRI-O、containerd、runc 之间的关系以及它们是如何一起工作的?

最近网上看到一张图片&#xff0c;能够很清晰地展现出 Docker、Kubernetes、OCI、CRI-O、containerd、runc 之间的关系以及它们是如何在一起工作的&#xff0c;如下&#xff1a; 本文可以作为之前一篇文章&#xff08;《K8s、Docker、CRI、OCI 之间的爱恨情仇》&#xff09;的…

依靠堡塔面板,飞速部署Java项目

依靠堡塔面板&#xff0c;飞速部署Java项目 环境介绍 环境介绍&#xff1a; 面板版本&#xff1a;8.0.26 操作系统版本&#xff1a;CentOS7.9.2009 Nginx版本&#xff1a;1.22 Java环境&#xff1a;Tomcat8&#xff0c;JDK&#xff1a;OpenJDK-1.8.0-internal MySQL版本&#…

CodeMeter软件保护及授权管理解决方案(二)

客户端管理工具 CodeMeter Runtime是CodeMeter解决方案中的重要组成部分&#xff0c;其为独立软件包&#xff0c;开发者需要把CodeMeter Runtime和加密后的软件一起发布。CodeMeter Runtim包括以下组件用于实现授权的使用&#xff1a; CodeMeter License Server授权服务器 Co…

7 种 JVM 垃圾收集器详解

一、概述 如果说收集算法是内存回收的方法论&#xff0c;那么垃圾收集器就是内存回收的具体实现。Java虚拟机规范中对垃圾收集器应该如何实现并没有任何规定&#xff0c;因此不同的厂商、版本的虚拟机所提供的垃圾收集器都可能会有很大差别&#xff0c;并且一般都会提供参数供用…

11月29日作业

作业&#xff1a; 自己封装一个矩形类(Rect)&#xff0c;拥有私有属性:宽度(width)、高度(height)&#xff0c; 定义公有成员函数: 初始化函数:void init(int w, int h) 更改宽度的函数:set_w(int w) 更改高度的函数:set_h(int h) 输出该矩形的周长和面积函数:void show(…

HarmonyOS4.0开发应用(一)【工具安装】

工具安装 地址:https://developer.harmonyos.com/cn/develop/deveco-studio#download 我是windows&#xff0c;所以安装的windows 解压后双击该文件进行安装 安装完成后可选择是否导入开发工具的设置 这里选择不导入进入到工具内 点击Next进行安装 都没问题最终就可以开…

linux 脚本之条件语句 if与case

1. 测试 条件测试&#xff1a;判断某需求是否满足&#xff0c;需要由测试机制来实现&#xff0c;专用的测试表达式需要由测试命令辅助完成。 测试过程&#xff0c;实现评估布尔声明&#xff0c;以便用在条件性环境下进行执行 若真&#xff0c;则状态码变量 $? 返回0若假&am…

接口性能测试 —— Jmeter并发与持续性压测

接口压测的方式&#xff1a; 1、同时并发&#xff1a;设置线程组、执行时间、循环次数&#xff0c;这种方式可以控制接口请求的次数 2、持续压测&#xff1a;设置线程组、循环次数&#xff0c;勾选“永远”&#xff0c;调度器&#xff08;持续时间&#xff09;&#xff0c;这种…

轻松整合Knife4j:快速搭建Swagger文档界面与接口调试

Knife4j 是一个为 Java 开发者提供的 Swagger 文档聚合工具&#xff0c;它是 Swagger-Bootstrap-UI 的升级版。它的主要功能是生成和展示 API 文档&#xff0c;让开发者能够更轻松地查看和测试接口。 整合 Knife4j&#xff08;Swagger-Bootstrap-UI 的升级版&#xff09;到 Spr…

从 Elasticsearch 到 SelectDB,观测云实现日志存储与分析的 10 倍性价比提升

作者&#xff1a;观测云 CEO 蒋烁淼 & 飞轮科技技术团队 在云计算逐渐成熟的当下&#xff0c;越来越多的企业开始将业务迁移到云端&#xff0c;传统的监控和故障排查方法已经无法满足企业的需求。在可观测理念逐渐深入人心的当下&#xff0c;人们越来越意识到通过多层次、…

第三方发起备份的ORA-00245问题

文章目录 前言一、信息确认共享目录位置控制文件快照位置节点1节点2 二、RAC修改snapshot controlfile 参数三、字典表确认以及测试 前言 在使用 AnyBackup 管理控制台发起 Oracle RAC 数据库备份后&#xff0c;在任务历史记录 > 执行输出中显示如下错误信息&#xff1a; c…

Unity版本使用情况统计(更新至2023年10月)

本期UWA发布的内容是第十三期Unity版本使用统计&#xff0c;统计周期为2023年5月至2023年10月&#xff0c;数据来源于UWA网站&#xff08;www.uwa4d.com&#xff09;性能诊断提测的项目。希望给Unity开发者提供相关的行业趋势&#xff0c;了解近半年来哪些Unity版本的使用概率更…