Embedding例子:简单NN网络、迁移学习例子

一、简单例子:构造简单NN网络生成Embedding

1、pytorch例子

2、tensorflow例子

# 1导入模块
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding
import numpy as np# 2构建语料库
corpus=[["The", "weather", "will", "be", "nice", "tomorrow"],["How", "are", "you", "doing", "today"],["Hello", "world", "!"]
]# 3生成字典
#获取语料不同单词,并过滤掉一些字符如"!"
word_set=set([i for item in corpus for i in item if i!='!']) 
word_dicts={}#索引从1开始,0用来填充
j=1
for i in word_set:word_dicts[i]=j j=j+1# 4用索引表示语料
raw_inputs=[]
for i in range(len(corpus)):raw_inputs.append([word_dicts[j]  for j in corpus[i] if j!="!"])padded_inputs = tf.keras.preprocessing.sequence.pad_sequences(raw_inputs,padding='post')print(padded_inputs)# 5构建网络
model = Sequential()
model.add(Embedding(20, 4, input_length=6,mask_zero=True))
model.compile('rmsprop', 'mse')
output_array = model.predict(padded_inputs)
output_array.shape
# 6 查看结果
output_array[1]

输出结果:

二、迁移学习: 使用预训练模型生成Embedding

1、使用Glove预训练数据集迁移学习

import osimdb_dir = './aclImdb' # 电影评论数据集
train_dir = os.path.join(imdb_dir, 'train')labels = []
texts = []for label_type in ['neg', 'pos']:dir_name = os.path.join(train_dir, label_type)for fname in os.listdir(dir_name):if fname[-4:] == '.txt':f = open(os.path.join(dir_name, fname))texts.append(f.read())f.close()if label_type == 'neg':labels.append(0)else:labels.append(1)from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as npmaxlen = 100  # 只保留前100单词的评论
training_samples = 200  # 在200个样本上训练
validation_samples = 10000  # W对10000个样品进行验证
max_words = 10000  # 只考虑数据集中最常见的10000 个单词tokenizer = Tokenizer(num_words=max_words)
tokenizer.fit_on_texts(texts)
sequences = tokenizer.texts_to_sequences(texts)word_index = tokenizer.word_index
print('Found %s unique tokens.' % len(word_index))data = pad_sequences(sequences, maxlen=maxlen)labels = np.asarray(labels)
print('Shape of data tensor:', data.shape)
print('Shape of label tensor:', labels.shape)# 将数据划分为训练集和验证集
# 首先打乱数据, 因一开始数据集是排序好的
# 负面评论在前, 正面评论在后
indices = np.arange(data.shape[0])
np.random.shuffle(indices)
data = data[indices]
labels = labels[indices]x_train = data[:training_samples]
y_train = labels[:training_samples]
x_val = data[training_samples: training_samples + validation_samples]
y_val = labels[training_samples: training_samples + validation_samples]glove_dir = './glove.6B/'embeddings_index = {}
f = open(os.path.join(glove_dir, 'glove.6B.100d.txt'))
for line in f:values = line.split()word = values[0]coefs = np.asarray(values[1:], dtype='float32')embeddings_index[word] = coefs
f.close()print('Found %s word vectors.' % len(embeddings_index))for key,value in embeddings_index.items():print(key,value)breakembedding_dim = 100embedding_matrix = np.zeros((max_words, embedding_dim))
for word, i in word_index.items():embedding_vector = embeddings_index.get(word)if i < max_words:if embedding_vector is not None:# 在嵌入索引(embedding index)找不到的词,其嵌入向量都设为0embedding_matrix[i] = embedding_vectorfrom tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Flatten, Densemodel = Sequential()
model.add(Embedding(max_words, embedding_dim, input_length=maxlen))
model.add(Flatten())
model.add(Dense(32, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.summary()model.layers[0].set_weights([embedding_matrix])
model.layers[0].trainable = Falsemodel.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['acc'])
history = model.fit(x_train, y_train,epochs=10,batch_size=32, validation_data=(x_val, y_val))
model.save_weights('pre_trained_glove_model.h5')
import matplotlib.pyplot as plt
%matplotlib inlineacc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(1, len(acc) + 1)plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

输出结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/431.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

每日两题 / 22. 括号生成 54. 螺旋矩阵(LeetCode热题100)

22. 括号生成 - 力扣&#xff08;LeetCode&#xff09; dfs生成合法的括号序列即可 class Solution { public:vector<string> ans;void dfs(int l, int r, int n, string& s){if (s.size() n * 2){ans.push_back(s);return;}if (l){s "(";dfs(l - 1, …

Burpsuite插件 BurpAPIFinder专为未授权/敏感信息/越权而生

BurpAPIFinder 攻防演练过程中&#xff0c;我们通常会用浏览器访问一些资产&#xff0c;但很多未授权/敏感信息/越权隐匿在已访问接口过html、JS文件等&#xff0c;通过该Burp插件我们可以&#xff1a; 1、发现通过某接口可以进行未授权/越权获取到所有的账号密码、私钥、凭证 …

记录方式重新打开人生

文章目录 引言节省生命感知细节的能力正视痛苦总结 引言 你是否遇到以下问题 时间过得很快&#xff0c;不知道过去在忙什么事情很多很杂&#xff0c;感觉一直都很忙但是好像也没啥收获生活、工作中不顺心的事情很多&#xff0c;心里比较烦躁压抑 那么可以尝试参考《奇特的一…

python环境引用《解读》----- 环境隔离

首先我先讲一下Anaconda&#xff0c;因为我用的是Anaconda进行包管理。方便后面好理解一点。 大家在python中引用环境的时候都会经历下面这一步&#xff1a; 那么好多人就会出现以下问题&#xff08;我就是遇到了这个问题&#xff09;&#xff1a; 我明明下载了包&#xff0c…

吴恩达深度学习笔记:深度学习的 实践层面 (Practical aspects of Deep Learning)1.1-1.3

目录 第一门课&#xff1a;第二门课 改善深层神经网络&#xff1a;超参数调试、正 则 化 以 及 优 化 (Improving Deep Neural Networks:Hyperparameter tuning, Regularization and Optimization)第一周&#xff1a;深度学习的 实践层面 (Practical aspects of Deep Learning)…

阶段性学习汇报 4月19日

一、毕业设计和毕业论文 毕业设计后端功能基本实现&#xff0c;但是还有些具体的细节需要优化&#xff0c;例如这些图片的显示问题&#xff0c;前端只有个前端页面以及部分交互逻辑&#xff0c;还需进一步完善。我想在疾病预测这里加一个创新点&#xff0c;基于推荐算法。小程序…

测绘管理与法律法规 | 中华人民共和国测绘法 | 学习笔记

《中华人民共和国测绘法》笔记&#xff1a; 第一章 总则 第一条&#xff1a;立法目的&#xff0c;即加强测绘管理&#xff0c;促进测绘事业发展&#xff0c;保障测绘事业为经济建设、国防建设、社会发展和生态保护服务&#xff0c;维护国家地理信息安全。 第二条&#xff1a;…

网络爬虫软件学习

1 什么是爬虫软件 爬虫软件&#xff0c;也称为网络爬虫或网络蜘蛛&#xff0c;是一种自动抓取万维网信息的程序或脚本。它基于一定的规则&#xff0c;自动地访问网页并抓取需要的信息。爬虫软件可以应用于大规模数据采集和分析&#xff0c;广泛应用于舆情监测、品牌竞争分析、…

ollama大语言模型

查看已经安装的大语言模型 ollama list运行大语言模型 ollama run llama2:latest

Qt实现Mysql数据库的连接,查询,修改,删除,增加功能

Qt实现Mysql数据库的连接&#xff0c;查询&#xff0c;修改&#xff0c;删除&#xff0c;增加功能 安装Mysql数据库&#xff0c;QtCreator Mysql选择Mysql Server 8.1版本安装。 Mysql Server 8.1安装过程 1.首先添加网络服务权限&#xff1a; WinR键输入compmgmt.msc进入…

Linux【实战】—— LAMP环境搭建 部署网站

目录 一、介绍 1.1什么是LAMP&#xff1f; 1.2LAMP的作用 二、部署静态网站 2.1 虚拟主机&#xff1a;一台服务器上部署多个网站 2.1.1 安装Apache服务 2.1.2 防火墙配置 2.1.3 准备网站目录 2.1.4 创建网站的配置文件 2.1.5 检查配置文件是否正确 2.1.6 Linux客户端…

web自动化系列-selenium的3种等待方式(十一)

在ui自动化测试中&#xff0c;几乎出现问题最多的情况就是定位不到元素 &#xff0c;当你的自动化在运行过程中 &#xff0c;突然发现报错走不下去了 。很大概率就是因为找不到元素 &#xff0c;而找不到元素的一个主要原因就是页面加载慢 &#xff0c;代码运行速度快导致 。 …

深入理解MySQL中的UPDATE JOIN语句

在MySQL数据库中&#xff0c;UPDATE语句用于修改表中现有的记录。有时&#xff0c;我们需要根据另一个相关联表中的条件来更新表中的数据。这时就需要使用UPDATE JOIN语句。最近我们遇到了这样的需求&#xff1a;我们有一张历史记录表&#xff0c;其中一个字段记录了用,连接的多…

【转】关于vsCode创建后,不显示NPM脚本解决

刚刚使用vue ui新建了个vue项目&#xff0c;打开vs-code发现&#xff0c;无论怎么设置都找不到NPM脚本显示&#xff0c;苦恼了很久&#xff0c;突然发现&#xff01;打开了package-lock.json&#xff0c;然后立马把vs-code关闭&#xff0c;重新打开&#xff0c;就显示了npm脚本…

DePT: Decoupled Prompt Tuning 论文阅读

DePT: Decoupled Prompt Tuning 了论文阅读 Abstract1. Introduction2. Methodology2.1. Preliminaries2.2. A Closer Look at the BNT Problem2.3. Decoupled Prompt Tuning 3. Experiments5. Conclusions 文章信息&#xff1a; 原文链接&#xff1a;https://arxiv.org/abs/…

【行为型模式】模板方法模式

一、模板方法模式概述 模板方法模式定义&#xff1a;在一个方法中定义一个算法的骨架,而将一些步骤延迟到子类中。模板方法使得子类可以在不改变算法结构的情况下,重新定义算法中的某些步骤。(类对象型模式) 模板方法中的基本方法是实现算法的各个步骤&#xff0c;是模板方法的…

rocketmq-dashboard打包测试报错

rocketmq-dashboard运行的时候没问题&#xff0c;但是打包执行测试的时候就是报错 这时候跳过测试就可以成功 报错为 There are test failures. Please refer to D:\CodeEn\rocketmq-dashboard\target\surefire-reports for the individual test results. 你只需要跳过测试就…

vue框架中的路由

vue框架中的路由 一.VueRouter的使用&#xff08;52&#xff09;二.路由模块封装三.声明式导航 - 导航链接1.router-link-active类名2.router-link-exact-active类名3.声明式导航-自定义类名 四.查询参数传参五.动态路由传参方式查询参数传参 VS 动态路由传参 六.动态路由参数的…

javaWeb项目-毕业生就业信息管理系统功能介绍

项目关键技术 开发工具&#xff1a;IDEA 、Eclipse 编程语言: Java 数据库: MySQL5.7 框架&#xff1a;ssm、Springboot 前端&#xff1a;Vue、ElementUI 关键技术&#xff1a;springboot、SSM、vue、MYSQL、MAVEN 数据库工具&#xff1a;Navicat、SQLyog 1、JSP技术 JSP(Jav…

【Canvas技法】四条C形色带填满一个圆/环形

【关键点】 通过三角函数计算控制点的位置。 【成果图】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head><title>四条C形色带填满一个…