2-Embedding例子:简单NN网络、迁移学习例子(glove语料预训练)

一、简单例子:构造简单NN网络生成Embedding

1、pytorch例子

2、tensorflow例子

# 1导入模块
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding
import numpy as np# 2构建语料库
corpus=[["The", "weather", "will", "be", "nice", "tomorrow"],["How", "are", "you", "doing", "today"],["Hello", "world", "!"]
]# 3生成字典
#获取语料不同单词,并过滤掉一些字符如"!"
word_set=set([i for item in corpus for i in item if i!='!']) 
word_dicts={}#索引从1开始,0用来填充
j=1
for i in word_set:word_dicts[i]=j j=j+1# 4用索引表示语料
raw_inputs=[]
for i in range(len(corpus)):raw_inputs.append([word_dicts[j]  for j in corpus[i] if j!="!"])padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs,padding='post')print(padded_inputs)# 5构建网络
model = Sequential()
model.add(Embedding(20, 4, input_length=6,mask_zero=True))
model.compile('rmsprop', 'mse')
output_array = model.predict(padded_inputs)
output_array.shape
# 6 查看结果
output_array[1]

输出结果:

二、迁移学习: 使用预训练模型生成Embedding

1、什么是迁移学习?不同任务场景下,如何使用预训练模型?

迁移学习是在一个任务上学习到的模型(结构、权重)作为初始点,应用到另一个新的任务上。

那该如何使用预训练模型呢?

场景1: 数据集小,数据相似度高

去掉输出层,然后将剩下的整个网络当作一个固定的特征提取机,应用到新的数据集中。
过程如图3-11所示,调整分类器中的几个参数,其他模块保持“冻结”即可。
这种微调方法,有时又称为特征抽取,因为预训练模型可以作为目标数据的特征提取器。

场景2: 数据集大, 数据相似度高

因为目标数据与预训练模型的训练数据之间高度相似,故采用预训练模型会非常有效。
另外,训练系统有一个较大的数据集,采用冻结预处理模型中少量较低层,修改分类器,然后在新数据集的基础上重新开始训练是一种较好的方式,具体处理过程如图3-12所示。

场景3:  数据集小,数据相似度不高

在这种情况下,可以冻结预训练模型中较少的网络高层,然后重新训练后面的网络,修改分类器。因为数据的相似度不高,重新训练的过程就变得非常关键。而新数据集大小的不足,则是通过冻结预训练模型中一些较低的网络层进行弥补,具体处理过程如图3-13所示。

场景4: 数据集大, 数据相似度不高

在这种情况下,因为有一个很大的数据集,所以神经网络的训练过程将会比较有效率。然而,因为目标数据与预训练模型的训练数据之间存在很大差异,采用预训练模型不是一种高效的方式。因此最好的方法还是将预处理模型中的权重全都初始化后再到新数据集的基础上重新开始训练,具体处理过程如图3-14所示。

2、使用Glove预训练数据集迁移学习例子

import osimdb_dir = './aclImdb' # 电影评论数据集
train_dir = os.path.join(imdb_dir, 'train')labels = []
texts = []for label_type in ['neg', 'pos']:dir_name = os.path.join(train_dir, label_type)for fname in os.listdir(dir_name):if fname[-4:] == '.txt':f = open(os.path.join(dir_name, fname))texts.append(f.read())f.close()if label_type == 'neg':labels.append(0)else:labels.append(1)from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as npmaxlen = 100  # 只保留前100单词的评论
training_samples = 200  # 在200个样本上训练
validation_samples = 10000  # W对10000个样品进行验证
max_words = 10000  # 只考虑数据集中最常见的10000 个单词tokenizer = Tokenizer(num_words=max_words)
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)word_index = tokenizer.word_index
print('Found %s unique tokens.' % len(word_index))data = pad_sequences(sequences, maxlen=maxlen)labels = np.asarray(labels)
print('Shape of data tensor:', data.shape)
print('Shape of label tensor:', labels.shape)# 将数据划分为训练集和验证集
# 首先打乱数据, 因一开始数据集是排序好的
# 负面评论在前, 正面评论在后
indices = np.arange(data.shape[0])
np.random.shuffle(indices)
data = data[indices]
labels = labels[indices]x_train = data[:training_samples]
y_train = labels[:training_samples]
x_val = data[training_samples: training_samples + validation_samples]
y_val = labels[training_samples: training_samples + validation_samples]glove_dir = './glove.6B/'embeddings_index = {}
f = open(os.path.join(glove_dir, 'glove.6B.100d.txt'))
for line in f:values = line.split()word = values[0]coefs = np.asarray(values[1:], dtype='float32')embeddings_index[word] = coefs
f.close()print('Found %s word vectors.' % len(embeddings_index))for key,value in embeddings_index.items():print(key,value)breakembedding_dim = 100embedding_matrix = np.zeros((max_words, embedding_dim))
for word, i in word_index.items():embedding_vector = embeddings_index.get(word)if i < max_words:if embedding_vector is not None:# 在嵌入索引(embedding index)找不到的词,其嵌入向量都设为0embedding_matrix[i] = embedding_vectorfrom tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Flatten, Densemodel = Sequential()
model.add(Embedding(max_words, embedding_dim, input_length=maxlen))
model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()model.layers[0].set_weights([embedding_matrix])
model.layers[0].trainable = Falsemodel.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['acc'])
history = model.fit(x_train, y_train,epochs=10,batch_size=32, validation_data=(x_val, y_val))
model.save_weights('pre_trained_glove_model.h5')
import matplotlib.pyplot as plt
%matplotlib inlineacc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(1, len(acc) + 1)plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

输出结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/1701.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux命令接着学习

which命令&#xff0c;找到各种命令程序所处在的位置 语法&#xff1a;which查找的命令 那么对于我们想查找其他类型文件所在的位置&#xff0c;我们可以用到find命令 find命令 选项为-name&#xff0c;表示按照文件名进行查找 find命令中通配符 find命令和前面rm命令一样&…

MT3023 歌词中找单词

1.暴力 10/12 #include <bits/stdc.h> using namespace std; int n; string a[10005]; int main() {cin >> n;for (int i 0; i < n; i)cin >> a[i];string ll;cin >> ll;for (int i 0; i < n; i){string u a[i];int num 0;int j 0;for (in…

解线性方程组——追赶法解三对角方程组 | 北太天元

一、问题描述 对于线性方程组 A x b , A ( b 1 c 1 a 2 b 2 c 2 ⋱ ⋱ ⋱ ⋱ ⋱ ⋱ a n − 1 b n − 1 c n − 1 a n b n ) , b ( f 1 f 2 ⋮ f n ) Axb,\quad A\begin{pmatrix}b_1&c_1&&&&\\a_2&b_2&c_2&&&\\&\ddots&\d…

CentOS 7安装、卸载MySQL数据库(一)

说明&#xff1a;本文介绍如何在CentOS 7操作系统下使用yum方式安装MySQL数据库&#xff0c;及卸载&#xff1b; 安装 Step1&#xff1a;卸载mariadb 敲下面的命令&#xff0c;查看系统mariadb软件包 rpm -qa|grep mariadb跳出mariadb软件包信息后&#xff0c;敲下面的命令…

mysql基础14——视图

视图 视图是一种虚拟表 可以把一段查询语句作为视图存储在数据库中 需要的时候把视图看作一个表&#xff0c;对里面的数据进行查询 视图并没有真正存储数据 避免了数据存储过程中可能产生的冗余 提高了存储的效率 子查询 嵌套在另一个查询中的查询 派生表 如果在查询中…

六、项目发布 -- 4. 电子书详情页API开发、电子书列表API开发

电子书详情页API的编写 同理如下app.get中路由、回调&#xff1b;回调中要连接数据库、接收前端传过来的值、到数据库中做查询&#xff0c;然后回调&#xff08;如果回调失败返回什么JSON&#xff0c;如果回调成功返回什么JSON&#xff09;&#xff1b;最后千万别忘记了关闭数…

怎样快速打造二级分销小程序

乔拓云是一个专门开发小程序模板的平台&#xff0c;致力于帮助商家快速上线自己的小程序。通过套用乔拓云提供的精美模板&#xff0c;商家无需具备专业的技术背景&#xff0c;也能轻松打造出功能齐全、美观大方的小程序。 在乔拓云的官网&#xff0c;商家可以免费注册账号并登录…

全科都收!1区毕业水刊,影响因子狂涨至9.8,无预警记录!国人评价高!

本期&#xff0c;小编给大家解析的是一本创刊于2014年&#xff0c;且于同年被WOS数据库收录的毕业“水刊”——SCIENTIFIC DATA。 截图来源&#xff1a;期刊官网 SCIENTIFIC DATA&#xff08;ISSN&#xff1a;2052-4463&#xff09;是一本致力于数据的开放获取期刊&#xff0c…

可视化大屏在政务领域应用非常普遍,带你看看

可视化大屏在政务领域的应用非常普遍&#xff0c;政务领域需要处理大量的数据和信息&#xff0c;通过可视化大屏可以将这些数据以直观、易懂的方式展示出来&#xff0c;帮助政府决策者和工作人员更好地了解和分析数据&#xff0c;从而做出更准确、科学的决策。 在政务领域&…

xhEditor实现WORD粘贴图片自动上传

1.下载示例&#xff1a; 从官网下载 http://www.ncmem.com/webapp/wordpaster/versions.aspx 从gitee中下载 https://gitee.com/xproer/wordpaster-php-xheditor1x 2.将插件目录复制到项目中 3.引入插件文件 定义插件图标 初始化插件&#xff0c;在工具栏中添加插件按钮 效果…

B端界面:除了蓝色外,四条搞定清新明快的界面设计。

一、什么是清新明快风格 清新明快的设计风格是指在B端系统中使用明亮、清淡的色彩、简洁的布局和自然元素&#xff0c;以及轻快的动效&#xff0c;营造出轻松、愉悦的界面氛围。 二、哪些行业适用 这种设计风格适用于多个行业&#xff0c;特别是那些与创意、娱乐、健康、旅游…

安卓原生项目工程结构说明

.gradle 和 .idea (自动生成) .gradle 是gradle下载好的缓存&#xff0c;如果有配置好的 下载好的缓存 直接会拿来用 没有会下载 生成 .idea 是编辑器的配置 app 代码主逻辑 目录 项目中的代码 资源都会在里面 工作的时候的核心目录 gradle 下载安卓的构建器gradle相关的配置信…

V23092-A1024-A301 工业继电器 24V 6A 一组转换

V23092-A1024-A301是一款通用继电器。参数为24V 6A 该继电器适用于控制各种电气负载&#xff0c;如电机、加热器或其他高电流设备。广泛应用于各种工业控制和自动化系统中&#xff0c;它的封装尺寸和引脚排列符合标准的工业规范&#xff0c;便于安装和使用。 产品种类: 通用…

C语言——贪吃蛇游戏的实现

一. 贪吃蛇的介绍 我们都有玩过一个小游戏——贪吃蛇&#xff0c;贪吃蛇也是一个经典游戏。如上图所示&#xff0c;游戏玩法就是操控一个蛇&#xff0c;让它吃掉食物&#xff0c;每吃掉一个食物就会增加自己身体一格长度&#xff0c;并且保证自己不能撞到墙和自己本身&#xff…

Ubuntu系统安装Anaconda

1. 下载Anconda安装包 1.1 wget命令下载 当然还可以去清华大学开源软件镜像站&#xff1a;Index of /anaconda/archive/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror&#xff0c;下载各种版本的Anaconda。 wget下载命令如下&#xff1a; 我这里下载的是2024.02…

二百三十三、Flume——Flume采集JSON文件到Kafka,再用Flume采集Kafka数据到HDFS中

一、目的 由于使用了新的Kafka协议&#xff0c;因为根据新的协议推送模拟数据到Kafka中&#xff0c;再Flume采集Kafka数据到HDFS中 二、技术选型 &#xff08;一&#xff09;Kettle工具 准备使用Kettle的JSON input控件和Kafka producer控件&#xff0c;但是搞了1天没搞定&…

OSPF的LSA详解

一、什么是LSA&#xff1f;LSA作用&#xff1f; 在OSPF协议中&#xff0c;LSA全称链路状态通告&#xff0c;主要由LSA头部信息&#xff08;LSA摘要&#xff09;和链路状态组成。部分LSA只有LSA头部信息&#xff0c;无链路状态信息。使用LSA来传递路由信息和拓扑信息&#xff0c…

【STM32F4】STM32CUMX相关环境配置

一、环境配置 我们需要以下两个软件 &#xff08;一&#xff09;keil5 最正统&#xff0c;最经典的嵌入式MCU开发环境。 该环境的配置可以看看之前的文章 所需文件如下&#xff1a; 当时配置的是STC8H的环境&#xff0c;现在基于此&#xff0c;重新给STM32配置环境。能让STC…

运营商三要素核验接口-手机实名验证API

运营商三要素核验接口是一种API&#xff08;Application Programming Interface&#xff0c;应用程序编程接口&#xff09;&#xff0c;主要用于通过互联网技术对接通信运营商的实名制数据库&#xff0c;以验证用户提供的手机号码、身份证号码、姓名这三项关键信息&#xff08;…

Python | Leetcode Python题解之第37题解数独

题目&#xff1a; 题解&#xff1a; class Solution:def solveSudoku(self, board: List[List[str]]) -> None:def dfs(pos: int):nonlocal validif pos len(spaces):valid Truereturni, j spaces[pos]for digit in range(9):if line[i][digit] column[j][digit] bloc…