(Transfer Learning)迁移学习在IMDB上训练情感分析模型

1. 背景

有些场景下,开始的时候数据量很小,如果我们用一个几千条数据训练一个全新的深度机器学习的文本分类模型,效果不会很好。这个时候你有两种选择,1.用传统的机器学习训练,2.利用迁移学习在一个预训练的模型上训练。本博客教你怎么用tensorflow Hub和keras 在少量的数据上训练一个文本分类模型。

2. 实践

2.1. 下载IMDB 数据集,参考下面博客。

Imdb影评的数据集介绍与下载_imdb影评数据集-CSDN博客

2.2.  预处理数据

替换掉imdb目录 (imdb_raw_data_dir). 创建dataset目录。

import numpy as np
import os as osimport re
from sklearn.model_selection import train_test_splitvocab_size = 30000
maxlen = 200
imdb_raw_data_dir = "/Users/harry/Documents/apps/ml/aclImdb"
save_dir = "dataset"def get_data(datapath =r'D:\train_data\aclImdb\aclImdb\train' ):pos_files = os.listdir(datapath + '/pos')neg_files = os.listdir(datapath + '/neg')print(len(pos_files))print(len(neg_files))pos_all = []neg_all = []for pf, nf in zip(pos_files, neg_files):with open(datapath + '/pos' + '/' + pf, encoding='utf-8') as f:s = f.read()s = process(s)pos_all.append(s)with open(datapath + '/neg' + '/' + nf, encoding='utf-8') as f:s = f.read()s = process(s)neg_all.append(s)print(len(pos_all))# print(pos_all[0])print(len(neg_all))X_orig= np.array(pos_all + neg_all)# print(X_orig)Y_orig = np.array([1 for _ in range(len(pos_all))] + [0 for _ in range(len(neg_all))])print("X_orig:", X_orig.shape)print("Y_orig:", Y_orig.shape)return X_orig, Y_origdef generate_dataset():X_orig, Y_orig = get_data(imdb_raw_data_dir + r'/train')X_orig_test, Y_orig_test = get_data(imdb_raw_data_dir + r'/test')X_orig = np.concatenate([X_orig, X_orig_test])Y_orig = np.concatenate([Y_orig, Y_orig_test])X = X_origY = Y_orignp.random.seed = 1random_indexs = np.random.permutation(len(X))X = X[random_indexs]Y = Y[random_indexs]X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)print("X_train:", X_train.shape)print("y_train:", y_train.shape)print("X_test:", X_test.shape)print("y_test:", y_test.shape)np.savez(save_dir + '/train_test', X_train=X_train, y_train=y_train, X_test= X_test, y_test=y_test )def rm_tags(text):re_tag = re.compile(r'<[^>]+>')return re_tag.sub(' ', text)def clean_str(string):string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)string = re.sub(r"\'s", " \'s", string)  # it's -> it 'sstring = re.sub(r"\'ve", " \'ve", string) # I've -> I 'vestring = re.sub(r"n\'t", " n\'t", string) # doesn't -> does n'tstring = re.sub(r"\'re", " \'re", string) # you're -> you arestring = re.sub(r"\'d", " \'d", string)  # you'd -> you 'dstring = re.sub(r"\'ll", " \'ll", string) # you'll -> you 'llstring = re.sub(r"\'m", " \'m", string) # I'm -> I 'mstring = re.sub(r",", " , ", string)string = re.sub(r"!", " ! ", string)string = re.sub(r"\(", " \( ", string)string = re.sub(r"\)", " \) ", string)string = re.sub(r"\?", " \? ", string)string = re.sub(r"\s{2,}", " ", string)return string.strip().lower()def process(text):text = clean_str(text)text = rm_tags(text)#text = text.lower()return  textif __name__ == '__main__':generate_dataset()

执行完后,产生train_test.npz 文件

2.3.  训练模型

1. 取数据集

def get_dataset_to_train():train_test = np.load('dataset/train_test.npz', allow_pickle=True)x_train =  train_test['X_train']y_train = train_test['y_train']x_test =  train_test['X_test']y_test = train_test['y_test']return x_train, y_train, x_test, y_test

2. 创建模型

基于nnlm-en-dim50/2 预训练的文本嵌入向量,在模型外面加了两层全连接。

def get_model():hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)# Build the modelmodel = Sequential([hub_layer,Dense(16, activation='relu'),Dropout(0.5),Dense(2, activation='softmax')])print(model.summary())model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=[keras.metrics.SparseCategoricalAccuracy()])return model

还可以使用来自 TFHub 的许多其他预训练文本嵌入向量:

  • google/nnlm-en-dim128/2 - 基于与 google/nnlm-en-dim50/2 相同的数据并使用相同的 NNLM 架构进行训练,但具有更大的嵌入向量维度。更大维度的嵌入向量可以改进您的任务,但可能需要更长的时间来训练您的模型。
  • google/nnlm-en-dim128-with-normalization/2 - 与 google/nnlm-en-dim128/2 相同,但具有额外的文本归一化,例如移除标点符号。如果您的任务中的文本包含附加字符或标点符号,这会有所帮助。
  • google/universal-sentence-encoder/4 - 一个可产生 512 维嵌入向量的更大模型,使用深度平均网络 (DAN) 编码器训练。

还有很多!在 TFHub 上查找更多文本嵌入向量模型。

3. 评估你的模型

def evaluate_model(test_data, test_labels):model = load_trained_model()# Evaluate the modelresults = model.evaluate(test_data, test_labels, verbose=2)print("Test accuracy:", results[1])def load_trained_model():# model = get_model()# model.load_weights('./models/model_new1.h5')model = tf.keras.models.load_model('models_pb')return model

4. 测试几个例子

def predict(real_data):model  = load_trained_model()probabilities = model.predict([real_data]);print("probabilities :",probabilities)result =  get_label(probabilities)return resultdef get_label(probabilities):index = np.argmax(probabilities[0])print("index :" + str(index))result_str =  index_dic.get(str(index))# result_str = list(index_dic.keys())[list(index_dic.values()).index(index)]return result_strdef predict_my_module():# review = "I don't like it"# review = "this is bad movie "# review = "This is good movie"review = " this is terrible movie"# review = "This isn‘t great movie"# review = "i think this is bad movie"# review = "I'm not very disappoint for this movie"# review = "I'm not very disappoint for this movie"# review = "I am very happy for this movie"#neg:0 postive:1s = predict(review)print(s)if __name__ == '__main__':x_train, y_train, x_test, y_test = get_dataset_to_train()model = get_model()model = train(model, x_train, y_train, x_test, y_test)evaluate_model(x_test, y_test)predict_my_module()

完整代码

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Dropout
import keras as keras
from keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow_hub as hubembedding_url = "https://tfhub.dev/google/nnlm-en-dim50/2"index_dic = {"0":"negative", "1": "positive"}def get_dataset_to_train():train_test = np.load('dataset/train_test.npz', allow_pickle=True)x_train =  train_test['X_train']y_train = train_test['y_train']x_test =  train_test['X_test']y_test = train_test['y_test']return x_train, y_train, x_test, y_testdef get_model():hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)# Build the modelmodel = Sequential([hub_layer,Dense(16, activation='relu'),Dropout(0.5),Dense(2, activation='softmax')])print(model.summary())model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=[keras.metrics.SparseCategoricalAccuracy()])return modeldef train(model , train_data, train_labels, test_data, test_labels):# train_data, train_labels, test_data, test_labels = get_dataset_to_train()train_data = [tf.compat.as_str(tf.compat.as_bytes(str(x))) for x in train_data]test_data = [tf.compat.as_str(tf.compat.as_bytes(str(x))) for x in test_data]train_data = np.asarray(train_data)  # Convert to numpy arraytest_data = np.asarray(test_data)  # Convert to numpy arrayprint(train_data.shape, test_data.shape)early_stop = EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=4, mode='max', verbose=1)# 定义ModelCheckpoint回调函数# checkpoint = ModelCheckpoint( './models/model_new1.h5', monitor='val_sparse_categorical_accuracy', save_best_only=True,#                              mode='max', verbose=1)checkpoint_pb = ModelCheckpoint(filepath="./models_pb/",  monitor='val_sparse_categorical_accuracy', save_weights_only=False, save_best_only=True)history = model.fit(train_data[:2000], train_labels[:2000], epochs=45, batch_size=45, validation_data=(test_data, test_labels), shuffle=True,verbose=1, callbacks=[early_stop, checkpoint_pb])print("history", history)return modeldef evaluate_model(test_data, test_labels):model = load_trained_model()# Evaluate the modelresults = model.evaluate(test_data, test_labels, verbose=2)print("Test accuracy:", results[1])def predict(real_data):model  = load_trained_model()probabilities = model.predict([real_data]);print("probabilities :",probabilities)result =  get_label(probabilities)return resultdef get_label(probabilities):index = np.argmax(probabilities[0])print("index :" + str(index))result_str =  index_dic.get(str(index))# result_str = list(index_dic.keys())[list(index_dic.values()).index(index)]return result_strdef load_trained_model():# model = get_model()# model.load_weights('./models/model_new1.h5')model = tf.keras.models.load_model('models_pb')return modeldef predict_my_module():# review = "I don't like it"# review = "this is bad movie "# review = "This is good movie"review = " this is terrible movie"# review = "This isn‘t great movie"# review = "i think this is bad movie"# review = "I'm not very disappoint for this movie"# review = "I'm not very disappoint for this movie"# review = "I am very happy for this movie"#neg:0 postive:1s = predict(review)print(s)if __name__ == '__main__':x_train, y_train, x_test, y_test = get_dataset_to_train()model = get_model()model = train(model, x_train, y_train, x_test, y_test)evaluate_model(x_test, y_test)predict_my_module()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/151968.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python如何实现原型设计模式?什么是原型设计模式?Python 原型设计模式示例代码

什么是原型&#xff08;ProtoType&#xff09;设计模式&#xff1f; 原型模式&#xff08;Prototype Pattern&#xff09;是一种创建型设计模式&#xff0c;旨在通过复制现有对象来创建新对象&#xff0c;而无需通过标准的构造方式。它允许我们基于现有对象创建新对象&#xf…

PHPmail 发送邮件错误 550 的原因是什么?

电子邮件错误消息链接到简单邮件传输协议 (SMTP)&#xff0c;这是一组发送和接收电子邮件的标准化规则。因此&#xff0c;它也称为 SMTP 550 错误代码。在某些情况下&#xff0c;电子邮件错误 550 是由收件人一方的问题引起的。 以下是电子邮件错误 550 的一些可能原因&#x…

MFC/QT 一些快要遗忘的细节:

1&#xff1a;企业应用中&#xff0c;MFC平台除了用常见的对话框模式还有一种常用的就是单文档模式&#xff0c; 维护别人的代码&#xff0c;不容易区分,其实找与程序同名的cpp就知道了&#xff0c;比如项目名称为 DoCMFCDemo&#xff0c;那么就看BOOL CDocMFCDemoApp::InitI…

代码随想录算法训练营第23期day59|503.下一个更大元素II、42. 接雨水

一、503.下一个更大元素II 力扣题目链接 可以不扩充nums&#xff0c;在遍历的过程中模拟走两边nums class Solution { public:vector<int> nextGreaterElements(vector<int>& nums) {vector<int> result(nums.size(), -1);if (nums.size() 0) return…

外贸自建站什么意思?自建独立网站的好处?

外贸自建站的含义是什么&#xff1f;如何区分自建站和独立站&#xff1f; 随着全球贸易的不断发展&#xff0c;越来越多的企业开始关注外贸自建站。那么&#xff0c;“外贸自建站”到底是什么意思呢&#xff1f;海洋建站将为您详细解析这个问题&#xff0c;带您深入了解这一新…

Spring cloud负载均衡@LoadBalanced LoadBalancerClient

LoadBalance vs Ribbon 由于Spring cloud2020之后移除了Ribbon&#xff0c;直接使用Spring Cloud LoadBalancer作为客户端负载均衡组件&#xff0c;我们讨论Spring负载均衡以Spring Cloud2020之后版本为主&#xff0c;学习Spring Cloud LoadBalance&#xff0c;暂不讨论Ribbon…

Win10 开始菜单、微软app和设置都打不开(未解决)

环境&#xff1a; Win10专业版 问题描述&#xff1a; Win10 开始菜单、微软app和设置都打不开,桌面个性话打开就报错&#xff0c;打开个性化该文件没有与之关联的程序来执行该操作 解决方案&#xff1a; 一般造成原因是MS-Settings文件系统错误 1.先重启电脑&#xff08;重…

【开源】基于Vue.js的高校宿舍调配管理系统

项目编号&#xff1a; S 051 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S051&#xff0c;文末获取源码。} 项目编号&#xff1a;S051&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能需求2.1 学生端2.2 宿管2.3 老师端 三、系统…

Pycharm中添加Python库指南

一、介绍 Pycharm是一款为Python开发者提供的集成开发环境&#xff08;IDE&#xff09;&#xff0c;支持执行、调试Python代码&#xff0c;并提供了许多有用的工具和功能&#xff0c;其中之一就是在Pycharm中添加Python库。 添加Python库有许多好处&#xff0c;比如能够增加开…

【云栖 2023】林伟:大数据 AI 一体化的解读

云布道师 本文根据 2023 云栖大会演讲实录整理而成&#xff0c;演讲信息如下&#xff1a; 演讲人&#xff1a;林伟 | 阿里云研究员&#xff0c;阿里云计算平台事业部首席架构师&#xff0c;阿里云人工智能平台 PAI 和大数据开发治理平台 DataWorks 负责人 演讲主题&#xff1a…

推荐10 本软件架构技术的好书【赠书活动|第11期《高并发架构实战》】

相信大家都对未来的职业发展有着憧憬和规划&#xff0c;要做架构师、要做技术总监、要做CTO。对于如何实现自己的职业规划也都信心满满&#xff0c;努力工作、好好学习、不断提升自己。 文章目录 《高并发架构实战&#xff1a;从需求分析到系统设计》《架构师的自我修炼&#x…

手写消息队列(基于RabbitMQ)

一、什么是消息队列&#xff1f; 提到消息队列是否唤醒了你脑海深处的记忆&#xff1f;回看前面的这篇文章&#xff1a;《Java 多线程系列Ⅳ&#xff08;单例模式阻塞式队列定时器线程池&#xff09;》&#xff0c;其中我们在介绍阻塞队列时说过&#xff0c;阻塞队列最大的用途…

在VSCode创建vue项目,出现“因为在此系统上禁止运行脚本”问题

问题&#xff1a;vue : 无法加载文件 C:\Users\***\***\Roaming\npm\vue.ps1&#xff0c;因为在此系统上禁止运行脚本。有关详细信息&#xff0c;请参阅 ht tps:/go.microsoft.com/fwlink/?LinkID135170 中的 about_Execution_Policies。 所在位置 行:1 字符: 1 解决&#xff…

Elasticsearch文档操作

一、Elasticsearch的CURD 1、CURD之Create PUT lqz/doc/1 {"name":"顾老二","age":30,"from": "gu","desc": "皮肤黑、武器长、性格直","tags": ["黑", "长", "直…

目标检测标注工具AutoDistill

引言 在快速发展的机器学习领域&#xff0c;有一个方面一直保持不变&#xff1a;繁琐和耗时的数据标注任务。无论是用于图像分类、目标检测还是语义分割&#xff0c;长期以来人工标记的数据集一直是监督学习的基础。 然而&#xff0c;由于一个创新性的工具 AutoDistill&#x…

【Python 千题 —— 基础篇】输出可以被5整除的数

题目描述 题目描述 输出40以内可以被5整除的数&#xff0c;每一个数字间隔一个空格。 输入描述 无输入。 输出描述 输出40以内可以被5整除的数。 示例 示例 ① 输出&#xff1a; 0 5 10 15 20 25 30 35 40 代码讲解 下面是本题的代码&#xff1a; # 描述: 输出40以…

Linux shell编程学习笔记26:stty(set tty)

之前我们探讨了Linux中的tty&#xff0c;tty命令的主要功能是显示当前使用的终端名称。 如果我们想进一步对tty进行设置&#xff0c;就要用到stty。 stty的功能&#xff1a;显示和修改终端特性&#xff08;Print or change terminal characteristics&#xff09;。 1 stty -…

编译安装适用于树梅派4B的android系统

一、下载android源码 aosp 一般来说需要通过storage.googleapis.com 下载&#xff0c;但是由于网络限制的原因&#xff0c;采用通过清华源镜像来下载。 1.打开 清华源AOSP镜像 (可以参考里面步骤下载) 2.下载repo 工具 mkdir ~/bin PATH~/bin:$PATH curl https://storage…

轻量封装WebGPU渲染系统示例<35>- HDR环境数据应用到PBR渲染材质

当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/BasePbrMaterialTest.ts 当前示例运行效果: 微调参数之后的效果: 此示例基于此渲染系统实现&#xff0c;当前示例TypeScript源码如下: export class BasePbrMateri…

Sentinel 熔断规则 (DegradeRule)

Sentinel 是面向分布式、多语言异构化服务架构的流量治理组件&#xff0c;主要以流量为切入点&#xff0c;从流量路由、流量控制、流量整形、熔断降级、系统自适应过载保护、热点流量防护等多个维度来帮助开发者保障微服务的稳定性。 SpringbootDubboNacos 集成 Sentinel&…