利用(Transfer Learning)迁移学习在IMDB数据上训练一个文本分类模型

1. 背景

有些场景下,开始的时候数据量很小,如果我们用一个几千条数据训练一个全新的深度机器学习的文本分类模型,效果不会很好。这个时候你有两种选择,1.用传统的机器学习训练,2.利用迁移学习在一个预训练的模型上训练。本博客教你怎么用tensorflow Hub和keras 在少量的数据上训练一个文本分类模型。

2. 实践

2.1. 下载IMDB 数据集,参考下面博客。

Imdb影评的数据集介绍与下载_imdb影评数据集-CSDN博客

2.2.  预处理数据

替换掉imdb目录 (imdb_raw_data_dir). 创建dataset目录。

import numpy as np
import os as osimport re
from sklearn.model_selection import train_test_splitvocab_size = 30000
maxlen = 200
imdb_raw_data_dir = "/Users/harry/Documents/apps/ml/aclImdb"
save_dir = "dataset"def get_data(datapath =r'D:\train_data\aclImdb\aclImdb\train' ):pos_files = os.listdir(datapath + '/pos')neg_files = os.listdir(datapath + '/neg')print(len(pos_files))print(len(neg_files))pos_all = []neg_all = []for pf, nf in zip(pos_files, neg_files):with open(datapath + '/pos' + '/' + pf, encoding='utf-8') as f:s = f.read()s = process(s)pos_all.append(s)with open(datapath + '/neg' + '/' + nf, encoding='utf-8') as f:s = f.read()s = process(s)neg_all.append(s)print(len(pos_all))# print(pos_all[0])print(len(neg_all))X_orig= np.array(pos_all + neg_all)# print(X_orig)Y_orig = np.array([1 for _ in range(len(pos_all))] + [0 for _ in range(len(neg_all))])print("X_orig:", X_orig.shape)print("Y_orig:", Y_orig.shape)return X_orig, Y_origdef generate_dataset():X_orig, Y_orig = get_data(imdb_raw_data_dir + r'/train')X_orig_test, Y_orig_test = get_data(imdb_raw_data_dir + r'/test')X_orig = np.concatenate([X_orig, X_orig_test])Y_orig = np.concatenate([Y_orig, Y_orig_test])X = X_origY = Y_orignp.random.seed = 1random_indexs = np.random.permutation(len(X))X = X[random_indexs]Y = Y[random_indexs]X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3)print("X_train:", X_train.shape)print("y_train:", y_train.shape)print("X_test:", X_test.shape)print("y_test:", y_test.shape)np.savez(save_dir + '/train_test', X_train=X_train, y_train=y_train, X_test= X_test, y_test=y_test )def rm_tags(text):re_tag = re.compile(r'<[^>]+>')return re_tag.sub(' ', text)def clean_str(string):string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)string = re.sub(r"\'s", " \'s", string)  # it's -> it 'sstring = re.sub(r"\'ve", " \'ve", string) # I've -> I 'vestring = re.sub(r"n\'t", " n\'t", string) # doesn't -> does n'tstring = re.sub(r"\'re", " \'re", string) # you're -> you arestring = re.sub(r"\'d", " \'d", string)  # you'd -> you 'dstring = re.sub(r"\'ll", " \'ll", string) # you'll -> you 'llstring = re.sub(r"\'m", " \'m", string) # I'm -> I 'mstring = re.sub(r",", " , ", string)string = re.sub(r"!", " ! ", string)string = re.sub(r"\(", " \( ", string)string = re.sub(r"\)", " \) ", string)string = re.sub(r"\?", " \? ", string)string = re.sub(r"\s{2,}", " ", string)return string.strip().lower()def process(text):text = clean_str(text)text = rm_tags(text)#text = text.lower()return  textif __name__ == '__main__':generate_dataset()

执行完后,产生train_test.npz 文件

2.3.  训练模型

1. 取数据集

def get_dataset_to_train():train_test = np.load('dataset/train_test.npz', allow_pickle=True)x_train =  train_test['X_train']y_train = train_test['y_train']x_test =  train_test['X_test']y_test = train_test['y_test']return x_train, y_train, x_test, y_test

2. 创建模型

基于nnlm-en-dim50/2 预训练的文本嵌入向量,在模型外面加了两层全连接。

def get_model():hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)# Build the modelmodel = Sequential([hub_layer,Dense(16, activation='relu'),Dropout(0.5),Dense(2, activation='softmax')])print(model.summary())model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=[keras.metrics.SparseCategoricalAccuracy()])return model

还可以使用来自 TFHub 的许多其他预训练文本嵌入向量:

  • google/nnlm-en-dim128/2 - 基于与 google/nnlm-en-dim50/2 相同的数据并使用相同的 NNLM 架构进行训练,但具有更大的嵌入向量维度。更大维度的嵌入向量可以改进您的任务,但可能需要更长的时间来训练您的模型。
  • google/nnlm-en-dim128-with-normalization/2 - 与 google/nnlm-en-dim128/2 相同,但具有额外的文本归一化,例如移除标点符号。如果您的任务中的文本包含附加字符或标点符号,这会有所帮助。
  • google/universal-sentence-encoder/4 - 一个可产生 512 维嵌入向量的更大模型,使用深度平均网络 (DAN) 编码器训练。

还有很多!在 TFHub 上查找更多文本嵌入向量模型。

3. 评估你的模型

def evaluate_model(test_data, test_labels):model = load_trained_model()# Evaluate the modelresults = model.evaluate(test_data, test_labels, verbose=2)print("Test accuracy:", results[1])def load_trained_model():# model = get_model()# model.load_weights('./models/model_new1.h5')model = tf.keras.models.load_model('models_pb')return model

4. 测试几个例子

def predict(real_data):model  = load_trained_model()probabilities = model.predict([real_data]);print("probabilities :",probabilities)result =  get_label(probabilities)return resultdef get_label(probabilities):index = np.argmax(probabilities[0])print("index :" + str(index))result_str =  index_dic.get(str(index))# result_str = list(index_dic.keys())[list(index_dic.values()).index(index)]return result_strdef predict_my_module():# review = "I don't like it"# review = "this is bad movie "# review = "This is good movie"review = " this is terrible movie"# review = "This isn‘t great movie"# review = "i think this is bad movie"# review = "I'm not very disappoint for this movie"# review = "I'm not very disappoint for this movie"# review = "I am very happy for this movie"#neg:0 postive:1s = predict(review)print(s)if __name__ == '__main__':x_train, y_train, x_test, y_test = get_dataset_to_train()model = get_model()model = train(model, x_train, y_train, x_test, y_test)evaluate_model(x_test, y_test)predict_my_module()

完整代码

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Dropout
import keras as keras
from keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow_hub as hubembedding_url = "https://tfhub.dev/google/nnlm-en-dim50/2"index_dic = {"0":"negative", "1": "positive"}def get_dataset_to_train():train_test = np.load('dataset/train_test.npz', allow_pickle=True)x_train =  train_test['X_train']y_train = train_test['y_train']x_test =  train_test['X_test']y_test = train_test['y_test']return x_train, y_train, x_test, y_testdef get_model():hub_layer = hub.KerasLayer(embedding_url, input_shape=[], dtype=tf.string, trainable=True)# Build the modelmodel = Sequential([hub_layer,Dense(16, activation='relu'),Dropout(0.5),Dense(2, activation='softmax')])print(model.summary())model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),metrics=[keras.metrics.SparseCategoricalAccuracy()])return modeldef train(model , train_data, train_labels, test_data, test_labels):# train_data, train_labels, test_data, test_labels = get_dataset_to_train()train_data = [tf.compat.as_str(tf.compat.as_bytes(str(x))) for x in train_data]test_data = [tf.compat.as_str(tf.compat.as_bytes(str(x))) for x in test_data]train_data = np.asarray(train_data)  # Convert to numpy arraytest_data = np.asarray(test_data)  # Convert to numpy arrayprint(train_data.shape, test_data.shape)early_stop = EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=4, mode='max', verbose=1)# 定义ModelCheckpoint回调函数# checkpoint = ModelCheckpoint( './models/model_new1.h5', monitor='val_sparse_categorical_accuracy', save_best_only=True,#                              mode='max', verbose=1)checkpoint_pb = ModelCheckpoint(filepath="./models_pb/",  monitor='val_sparse_categorical_accuracy', save_weights_only=False, save_best_only=True)history = model.fit(train_data[:2000], train_labels[:2000], epochs=45, batch_size=45, validation_data=(test_data, test_labels), shuffle=True,verbose=1, callbacks=[early_stop, checkpoint_pb])print("history", history)return modeldef evaluate_model(test_data, test_labels):model = load_trained_model()# Evaluate the modelresults = model.evaluate(test_data, test_labels, verbose=2)print("Test accuracy:", results[1])def predict(real_data):model  = load_trained_model()probabilities = model.predict([real_data]);print("probabilities :",probabilities)result =  get_label(probabilities)return resultdef get_label(probabilities):index = np.argmax(probabilities[0])print("index :" + str(index))result_str =  index_dic.get(str(index))# result_str = list(index_dic.keys())[list(index_dic.values()).index(index)]return result_strdef load_trained_model():# model = get_model()# model.load_weights('./models/model_new1.h5')model = tf.keras.models.load_model('models_pb')return modeldef predict_my_module():# review = "I don't like it"# review = "this is bad movie "# review = "This is good movie"review = " this is terrible movie"# review = "This isn‘t great movie"# review = "i think this is bad movie"# review = "I'm not very disappoint for this movie"# review = "I'm not very disappoint for this movie"# review = "I am very happy for this movie"#neg:0 postive:1s = predict(review)print(s)if __name__ == '__main__':x_train, y_train, x_test, y_test = get_dataset_to_train()model = get_model()model = train(model, x_train, y_train, x_test, y_test)evaluate_model(x_test, y_test)predict_my_module()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/145629.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA JPA 使用实体类注解 @CreatedDate @LastModifiedDate自动生成创建和修改时间

JPA 使用实体类注解 CreatedDate LastModifiedDate自动生成创建和修改时间 说明&#xff1a;jpa实体添加数据库自动生成创建和修改时间 1.ApplicationBootstrap增加以下注解 EnableJpaAuditing2.实体类增加注解以下注解 Table(name"user") JsonIgnoreProperties(v…

【Linux专题】firewalld 过滤出接口流量

【赠送】IT技术视频教程&#xff0c;白拿不谢&#xff01;思科、华为、红帽、数据库、云计算等等_厦门微思网络的博客-CSDN博客文章浏览阅读428次。风和日丽&#xff0c;小微给你送福利~如果你是小微的老粉&#xff0c;这里有一份粉丝福利待领取...如果你是新粉关注到了小微&am…

智慧工地解决方案,实现安全预警、机械智能监控、作业指导、绿色施工、劳务管理、工程进度监控、施工质量检查

智慧工地云平台全套源码 智慧工地平台采用先进的云计算、物联网和大数据技术&#xff0c;可以实现智慧工地方案的落地。能够实现实时掌控工地活动及各项进度&#xff0c;有效预防违章施工。能够为工地提供多项服务&#xff0c;如安全预警、机械智能监控、作业指导、绿色施工、劳…

JVM bash:jmap:未找到命令 解决

如果我们在使用JVM的jmap命令时遇到了"bash: jmap: 未找到命令"的错误&#xff0c;这可能是因为jmap命令没有在系统的可执行路径中。 要解决这个问题&#xff0c;可以尝试以下几种方法&#xff1a; 1. 检查Java安装&#xff1a;确保您已正确安装了Java Development …

Spring Boot EasyPOI 使用指定模板导出Excel

相信大家都遇到过&#xff0c;用户提出要把界面上的数据导成一个Excel&#xff0c;还得是用户指定的Excel格式&#xff0c;用原生的POI&#xff0c;需要自己去实现&#xff0c;相信是比较麻烦的&#xff0c;所以我们可以使用开源的EasyPOI. 先上个图&#xff0c;看看是不是大家…

ES Kibana windows 安装

ES & Kibana windows 安装 声明&#xff1a; 本文没有实际操作过&#xff0c;只记录。具体操作请参考 ES & Kibana 安装 该文章 JDK1.8&#xff0c;最低要求&#xff01;ElasticSearch客户端&#xff0c;界面工具&#xff01; Java开发&#xff0c;ElasticSearch的版…

Newman

近期在复习Postman的基础知识&#xff0c;在小破站上跟着百里老师系统复习了一遍&#xff0c;也做了一些笔记&#xff0c;希望可以给大家一点点启发。 一&#xff09;如何安装Newman 1、下载并安装NodeJs 在官网下载NodeJs&#xff1a; Download | Node.js&#xff08;官网的…

2023.11.14 关于 Spring Boot 创建和使用

目录 Spring Boot Spring Boot 项目的创建 网页版创建 Spring Boot 项目 Spring Boot 目录说明 项目运行 Spring Boot Spring Boot 是基于 Spring 设计的一个全新的框架&#xff0c;其目的是用来简化 Spring 的应用、初始搭建、开发的整个过程Spring Boot 就是一个整合了…

初认识vue,v-for,v-if,v-bind,v-model,v-html等指令

vue 一.vue3介绍 1.为什么data是函数而不是对象? 因为vue是组件开发,组件会多次复用,data如果是对象,多次复用是共享,必须函数返回一个新的对象 1. 官网初识 Vue (发音为 /vjuː/&#xff0c;类似 view) 是一款用于构建用户界面的 JavaScript 框架。它基于标准 HTML、CSS …

UTC秒数与年月日时分秒的转换

闰年是指公历中每四年中有一个多出来的闰日&#xff08;2月29日&#xff09;的年份。判断一个年份是否为闰年&#xff0c;可以按照以下规则进行判断&#xff1a; 如果一个年份能够被4整除&#xff0c;但不能被100整除&#xff0c;则它是闰年。例如&#xff0c;2004年是闰年&am…

大数据爬虫分析基于Python+Django旅游大数据分析系统

欢迎大家点赞、收藏、关注、评论啦 &#xff0c;由于篇幅有限&#xff0c;只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 基于Python和Django的旅游大数据分析系统是一种使用Python编程语言和Django框架开发的系统&#xff0c;用于处理和分…

DSP生成hex方法

以下使用两种方法生成的HEX文件&#xff0c;亲测可用 &#xff08;1&#xff09;万能法 不管.out文件是哪个版本CCS编译器生成的&#xff0c;只要用HEX2000.exe软件&#xff0c;翻译都可以使用。方法&#xff1a; hex2000 -romwidth 16 -memwidth 16 -i -o 20170817chuankou…

12V升压18V 1A 内置MOS 升压芯片5-35V输入内置MOS升压IC

12V升压18V 1A 内置MOS 升压芯片5-35V输入内置MOS升压IC

vnodeToString函数把vnode转为string(innerhtml)

函数 function vnodeToString(vnode) {// 如果是文本节点&#xff0c;直接返回文本内容if ([string, boolean, undefined, null, number].includes(typeof vnode)) {return vnode;}// 转换节点的属性为字符串形式const attrs Object.keys(vnode.attrs || {}).map((key) > …

数据治理入门

处理模式 模式名称常见场景常见框架批处理夜间几个小时&#xff0c;无人值守hive spark datax流处理7*24H一直运行&#xff0c;无人值守maxwell, flink, flume, kafka即席处理人机交互接口访问 web页面 数据治理的意义 数据质量低&#xff1a;数据错误&#xff0c;不准确或不…

构建坚固防线:提升网站整体安防水平的有效途径

在当今数字化时代&#xff0c;网站安全问题日益突出&#xff0c;如何提升网站整体安防水平成为网站运营者不可忽视的挑战。本文将从高防服务器与CDN的局限性出发&#xff0c;探讨提升网站安防的有效途径&#xff0c;并最终总结其必要性。 高防服务器与CDN的局限性 高防服务器的…

别再吐槽大学教材了,来看看这些网友强推的数学神作!

前言 关于大学数学教材的吐槽似乎从来没停止过。有人慨叹&#xff1a;数学教材晦涩难懂。错&#xff01;难懂&#xff0c;起码还可以读懂。数学教材你根本读不懂&#xff1b;也有人说&#xff1a;数学教材简直就是天书。 数学教材有好有坏&#xff0c;这话不假&#xff0c;但更…

二次元商业计划书PPT模版

二次元商业计划书PPT模版 共&#xff1a;9页 PPT模版&#xff1a; 百度网盘 请输入提取码&#xff1a;ax48

pytorch 模型复现

一般来说&#xff0c;设置相同的随机种子&#xff0c;在相同参数条件下&#xff0c;能使pytorch模型复现出相同的结果。随机种子的设置代码如下&#xff1a; def get_random_seed(seed):random.seed(seed)os.environ[PYTHONHASHSEED] str(seed)np.random.seed(seed)torch.man…

FPGA基础以太网

以太网数据通信 物理层&#xff1a;网线网卡&#xff08;PHY芯片&#xff09; 数据链路层&#xff1a;Mac层(数据有效传输&#xff09; 如图所示&#xff1a;FPGA中的Mac层中的MII接口负责控制PHY芯片&#xff0c;PHY芯片通过网线与PC端进行以太网数据传输。 FPGA中&#xff…