勾八头歌之RNN

一、RNN快速入门

1.学习单步的RNN:RNNCell

# -*- coding: utf-8 -*-
import tensorflow as tf# 参数 a 是 BasicRNNCell所含的神经元数, 参数 b 是 batch_size, 参数 c 是单个 input 的维数,shape = [ b , c ]
def creatRNNCell(a,b,c):# 请在此添加代码 完成本关任务# ********** Begin *********#x1=tf.placeholder(tf.float32,[b,c])cell=tf.nn.rnn_cell.BasicRNNCell(num_units=a)h0=cell.zero_state(batch_size=b,dtype=tf.float32)output,h1=cell.__call__(x1,h0)print(cell.state_size)print(h1)# ********** End **********#

2.探幽入微LSTM

# -*- coding: utf-8 -*-
import tensorflow as tf# 参数 a 是 BasicLSTMCell所含的神经元数, 参数 b 是 batch_size, 参数 c 是单个 input 的维数,shape = [ b , c ]
def creatLSTMCell(a,b,c):# 请在此添加代码 完成本关任务# ********** Begin *********#x1=tf.placeholder(tf.float32,[b,c])cell=tf.nn.rnn_cell.BasicLSTMCell(num_units=a)h0=cell.zero_state(batch_size=b,dtype=tf.float32)output,h1=cell.__call__(x1,h0)print(h1.h)print(h1.c)# ********** End **********#

3.进阶RNN:学习一次执行多步以及堆叠RNN

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np# 参数 a 是RNN的层数, 参数 b 是每个BasicRNNCell包含的神经元数即state_size
# 参数 c 是输入序列的批量大小即batch_size,参数 d 是时间序列的步长即time_steps,参数 e 是单个输入input的维数即input_size
def MultiRNNCell_dynamic_call(a,b,c,d,e):# 用tf.nn.rnn_cell MultiRNNCell创建a层RNN,并调用tf.nn.dynamic_rnn# 请在此添加代码 完成本关任务# ********** Begin *********#cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicRNNCell(num_units=b) for _ in range(a)]) # a层RNNinputs = tf.placeholder(np.float32, shape=(c, d, e)) # a 是 batch_size,d 是time_steps, e 是input_sizeh0=cell.zero_state(batch_size=c,dtype=tf.float32)output, h1 = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0)print(output)# ********** End **********#

二、RNN循环神经网络

1.Attention注意力机制(A  ABC  B  C  A)

2.Seq2Seq

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variabledtype = torch.FloatTensor
char_list = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
char_dic = {n: i for i, n in enumerate(char_list)}
seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]
seq_len = 8
n_hidden = 128
n_class = len(char_list)
batch_size = len(seq_data)##########Begin##########
#对数据进行编码部分
##########End##########
def make_batch(seq_data):batch_size = len(seq_data)input_batch, output_batch, target_batch = [], [], []for seq in seq_data:for i in range(2):seq[i] += 'P' * (seq_len - len(seq[i]))input = [char_dic[n] for n in seq[0]]output = [char_dic[n] for n in ('S' + seq[1])]target = [char_dic[n] for n in (seq[1] + 'E')]input_batch.append(np.eye(n_class)[input])output_batch.append(np.eye(n_class)[output])target_batch.append(target)return Variable(torch.Tensor(input_batch)), Variable(torch.Tensor(output_batch)), Variable(torch.LongTensor(target_batch))##########Begin##########
#模型类定义
input_batch, output_batch, target_batch = make_batch(seq_data)
class Seq2Seq(nn.Module):def __init__(self):super(Seq2Seq, self).__init__()self.encoder = nn.RNN(input_size=n_class, hidden_size=n_hidden)self.decoder = nn.RNN(input_size=n_class, hidden_size=n_hidden)self.fc = nn.Linear(n_hidden, n_class)def forward(self, enc_input, enc_hidden, dec_input):enc_input = enc_input.transpose(0, 1)dec_input = dec_input.transpose(0, 1)_, h_states = self.encoder(enc_input, enc_hidden)outputs, _ = self.decoder(dec_input, h_states)outputs = self.fc(outputs)return outputs
##########End##########model = Seq2Seq()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)##########Begin##########
#模型训练过程
for epoch in range(5001):hidden = Variable(torch.zeros(1, batch_size, n_hidden))optimizer.zero_grad()outputs = model(input_batch, hidden, output_batch)outputs = outputs.transpose(0, 1)loss = 0for i in range(batch_size):loss += criterion(outputs[i], target_batch[i])loss.backward()optimizer.step()
##########End####################Begin##########
#模型验证过程函数
def translated(word):input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])hidden = Variable(torch.zeros(1, 1, n_hidden))outputs = model(input_batch, hidden, output_batch)predict = outputs.data.max(2, keepdim=True)[1]decode = [char_list[i] for i in predict]end = decode.index('P')translated = ''.join(decode[:end])print(translated)
##########End##########translated('highh')
translated('kingh')

三、RNN和LSTM

1.循环神经网络简介

import torchdef rnn(input,state,params):"""循环神经网络的前向传播:param input: 输入,形状为 [ batch_size,num_inputs ]:param state: 上一时刻循环神经网络的状态,形状为 [ batch_size,num_hiddens ]:param params: 循环神经网络的所使用的权重以及偏置:return: 输出结果和此时刻网络的状态"""W_xh,W_hh,b_h,W_hq,b_q = params"""W_xh : 输入层到隐藏层的权重W_hh : 上一时刻状态隐藏层到当前时刻的权重b_h : 隐藏层偏置W_hq : 隐藏层到输出层的权重b_q : 输出层偏置"""H = state# 输入层到隐藏层H = torch.matmul(input, W_xh) + torch.matmul(H, W_hh) + b_hH = torch.tanh(H)# 隐藏层到输出层Y = torch.matmul(H, W_hq) + b_qreturn Y,Hdef init_rnn_state(num_inputs,num_hiddens):"""循环神经网络的初始状态的初始化:param num_inputs: 输入层中神经元的个数:param num_hiddens: 隐藏层中神经元的个数:return: 循环神经网络初始状态"""init_state = torch.zeros((num_inputs,num_hiddens),dtype=torch.float32)return init_state

2.长短时记忆网络

import torchdef lstm(X,state,params):"""LSTM:param X: 输入:param state: 上一时刻的单元状态和输出:param params: LSTM 中所有的权值矩阵以及偏置:return: 当前时刻的单元状态和输出"""W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params"""W_xi,W_hi,b_i : 输入门中计算i的权值矩阵和偏置W_xf,W_hf,b_f : 遗忘门的权值矩阵和偏置W_xo,W_ho,b_o : 输出门的权值矩阵和偏置W_xc,W_hc,b_c : 输入门中计算c_tilde的权值矩阵和偏置W_hq,b_q : 输出层的权值矩阵和偏置"""#上一时刻的输出 H 和 单元状态 C。(H,C) = state# 遗忘门F = torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_fF = torch.sigmoid(F)# 输入门I = torch.sigmoid(torch.matmul(X,W_xi)+torch.matmul(H,W_hi) + b_i)C_tilde = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)C = F * C + I * C_tilde# 输出门O = torch.sigmoid(torch.matmul(X,W_xo)+torch.matmul(H,W_ho) + b_o)H = O * C.tanh()# 输出层Y = torch.matmul(H,W_hq) + b_qreturn Y,(H,C)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2617.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sprinboot+人大金仓配置

1. .yml 配置 spring:datasource:type: com.alibaba.druid.pool.DruidDataSource#driverClassName: dm.jdbc.driver.DmDriver## todo 人大金仓driverClassName: com.kingbase8.Driverdruid:## todo 人大金仓master:url: jdbc:kingbase8://111.111.111.111:54321/dbname?cu…

粘合聚酰亚胺PI塑料材料使用UV胶的优势有哪些? (三十四)

聚酰亚胺PI难于粘接,在PI粘接方法中使用UV胶粘剂粘接PI的优势有哪些? 聚酰亚胺(PI)是一种具有耐高低温性能、高绝缘性、耐化性、低热膨胀系数的材料,广泛用于FPC基材和各种耐高温电机电器的绝缘材料。然而,…

MySQL常见的约束

什么是约束? 限制,限制我们表中的数据,保证添加到数据表中的数据准确和可靠性!凡是不符合约束的数据,插入时就会失败,插入不进去的! 比如:学生信息表中,学号就会约束不…

Java | Leetcode Java题解之第45题跳跃游戏II

题目&#xff1a; 题解&#xff1a; class Solution {public int jump(int[] nums) {int length nums.length;int end 0;int maxPosition 0; int steps 0;for (int i 0; i < length - 1; i) {maxPosition Math.max(maxPosition, i nums[i]); if (i end) {end maxP…

POP —— 简介

目录 Emitting Applying forces Reacting to surfaces Limiting particle speed Following a leader or leaders Swirling particles around vortex filaments Visualizing Forces Collisions Instancing and Rendering Sprite Particles Streams Writing particle…

编程基础“四大件”

基础四大件包括&#xff1a;数据结构和算法,计算机网络,操作系统,设计模式 这跟学什么编程语言,后续从事什么编程方向均无关&#xff0c;只要做编程开发&#xff0c;这四个计算机基础就无法避开。可以这么说&#xff0c;这基础四大件真的比编程语言重要&#xff01;&#xff0…

色温的介绍

文章目录 色温的概念照明领域显示技术领域 色温的概念 色温是描述光源色彩特性的一个重要参数&#xff0c;通常用来表征光的暖冷程度。它以开尔文&#xff08;Kelvin&#xff0c;K&#xff09;为单位来表示&#xff0c;通常简写为K。色温越高&#xff0c;光线看起来就越接近于…

如何用PHP语言实现远程语音播报

如何用PHP语言实现远程语音播报呢&#xff1f; 本文描述了使用PHP语言调用HTTP接口&#xff0c;实现语音播报。通过发送文本信息&#xff0c;来实现远程语音播报、语音提醒、语音警报等。 可选用产品&#xff1a;可根据实际场景需求&#xff0c;选择对应的规格 序号设备名称1…

比特币之路:技术突破、创新思维与领军人物

比特币的兴起是一段充满技术突破、创新思维和领军人物的传奇之路。在这篇文章中&#xff0c;我们将探讨比特币发展的历程&#xff0c;以及那些在这一过程中发挥重要作用的关键人物。 技术突破与前奏 比特币的诞生并非凭空而来&#xff0c;而是建立在先前的技术储备之上。在密码…

机器学习中常见的数据分析,处理方式(以泰坦尼克号为例)

数据分析 读取数据查看数据各个参数信息查看有无空值如何填充空值一些特殊字段如何处理读取数据查看数据中的参数信息实操具体问题具体分析年龄问题 重新划分数据集如何删除含有空白值的行根据条件删除一些行查看特征和标签的相关性 读取数据 查看数据各个参数信息 查看有无空…

TCP三次握手详解

目录 什么是TCP TCP头格式组成 三次握手 第一次握手 第二次握手 第三次握手 三次握手的好处 为什么需要三次握手&#xff1f; 什么是TCP 传输控制协议(TCP)是Internet一个重要的传输层协议。TCP提供面向连接、可靠、有序、字节流传输服务。 面向连接&#xff1a; 应用…

百度糯米携手中山大学举办“开学流水宴”

热游圈消息&#xff1a; 百度糯米携手中山大学&#xff0c;于9月13日在“百团大战”游园会上举办了一场别开生面的“开学流水宴”&#xff0c;吸引了众多新生和百度糯米用户参与。这场长达20米的流水宴不仅为新生们带来了美味佳肴&#xff0c;更为他们提供了结交新朋友、增进同…

编写你的第一个java 程序

1.安装 jdk 网址&#xff1a; Java Downloads | Oracle 一般我们安装jdk 17 就行了 自己练习 自己学习 真正的开发中我们使用jdk 8 这个是最适合开发java 应用程序的 当然你也可以选择你的 系统 来安装这个java 在文件资源管理器打开JDK的安装目录的bin目录&#xff0c;会发…

pycharm远程连接server

1.工具–部署–配置 2.部署完成后&#xff0c;将现有的项目的解释器设置为ssh 解释器。实现在远端开发 解释器可以使用/usr/bin/python3

ROC和AUC

什么是ROC和AUC ROC曲线&#xff08;Receiver Operating Characteristic curve&#xff09;和AUC&#xff08;Area Under the Curve&#xff09;是用于评估二分类模型性能的重要工具。 ROC曲线以真正例率&#xff08;True Positive Rate&#xff0c;也称为召回率或灵敏度&…

Scala的函数至简原则

对于scala语言来说&#xff0c;函数的至简原则是它的一大特色。下面让我们一起来看看分别有什么吧&#xff01; 函数至简原则&#xff1a;能省则省&#xff01; 初始函数 def test(name:String):String{return name }1、return可以省略&#xff0c;Scala会使用函数体的最后一…

【Ubuntu20.04+Noetic】UR5e+Gazebo+Moveit

环境准备 创建工作空间 mkdir -p ur5e_ws/src cd ur5e_ws/srcUR机械臂软件包 UR官方没更新最新的noetic的分支,因此安装melodic,并需要改动相关文件。 安装UR的模型配置包,包里面有UR模型文件,moveit配置等: cd ~/ur5e_ws/src git clone -b melodic-devel https://git…

探索未来的区块链DApp应用,畅享数字世界的无限可能

随着区块链技术的飞速发展&#xff0c;分布式应用&#xff08;DApp&#xff09;正成为数字经济中的一股强劲力量。DApp以其去中心化、透明公正的特点&#xff0c;为用户带来了全新的数字体验&#xff0c;开创了数字经济的新潮流。作为一家专业的区块链DApp应用开发公司&#xf…

3月黄油奶酪行业数据分析:安佳和妙可蓝多领军市场

近些年来&#xff0c;随着新消费主义盛行&#xff0c;老少皆宜的黄油和奶酪逐渐成为都市年轻人的烘培“新宠”。 今年3月份&#xff0c;黄油奶酪表现的中规中矩&#xff0c;处在稳定发展阶段。根据鲸参谋数据显示&#xff0c;3月份&#xff0c;在线上综合电商平台&#xff08;…

凌恩病原微生物检测系统上线啦,助力环境病原微生物检测

病原微生物是指能够引起人类或动物疾病的微生物&#xff0c;包括病毒、细菌、真菌、衣原体和支原体等。病原微生物可以通过空气、体液等介质传播&#xff0c;危害人体健康&#xff0c;造成财产损失。因此&#xff0c;快速、准确地检测病原微生物对于疫情防控和保障人民生命健康…