nn.LSTM个人记录

简介

 

nn.LSTM参数

torch.nn.lstm(input_size,   "输入的嵌入向量维度,例如每个单词用50维向量表示,input_size就是50"hidden_size,  "隐藏层节点数量,也是输出的嵌入向量维度"num_layers,   "lstm 隐层的层数,默认为1"bias,         "隐层是否带 bias,默认为 true"batch_first,  "True 或者 False,如果是 True,则 input 为(batchsize, len, input_size),默认值为:False(len, batchsize, input_size)"dropout,      "除最后一层,每一层的输出都进行dropout,默认值0"bidirectional "如果设置为 True, 则表示双向 LSTM,默认为 False")

维度

batch_first=True,输入维度(batchsize,len,input_size)

batch_first=False,输入维度(len,batchsize, input_size)

batch_first=False,输出维度(len,batchsize,hidden_size)

举例嵌入向量维度为1

 假如输入x为(batchsize,len)的序列,即嵌入向量维度为1,进行一个回归预测。

如果将嵌入向量维度维度设为1就不太合理,因为如果len非常长例如几w,那么经过几w的时间步得到的得到的h维度为(batchsize,1),序列太长丢失很多信息,再输入全连接层预测效果不好。并且lstm实际上将嵌入向量维度从input_size规约到hidden_size。

所以在这里我们将len作为input_size,嵌入向量维度1作为len(即对调了一下)

添加一个维度:

x = x.unsqueeze(0)

x维度变为(1,batchsize,len),相当于设置数据的长度为1,嵌入向量维度为len,通过nn.LSTM输入到网络中。

#lstm为定义的网络
#h[-1]为最后输入到全连接层的嵌入矩阵 但是由于此问题中len为1,所以x等于h[-1]
x, (h, c) = lstm(x)

x维度变为(1,batchsize,hidden_size)

h为每层lstm最后一个时间步的输出一般可以输入到后续的全连接层),维度为(num_layers,batchsize,hidden_size)

c为最后一个时间步 LSTM cell 的状态(记忆单元,一般用不到),维度为(num_layers,batchsize,hidden_size)

移除张量中所有尺寸为 1 的维度,即将第一个维度移除掉:

lstm_out = x.squeeze(0)

x维度变为(batchsize,hidden_size) ,输入到全连接层(线性层,维度(hidden_size,num_class))中,最终输出维度(batchsize,num_class)

参考:

Pytorch — LSTM (nn.LSTM & nn.LSTMCell)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/241434.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python---静态Web服务器-多任务版

1. 静态Web服务器的问题 目前的Web服务器,不能支持多用户同时访问,只能一个一个的处理客户端的请求,那么如何开发多任务版的web服务器同时处理 多个客户端的请求? 可以使用多线程,比进程更加节省内存资源。 多任务版web服务器…

计算机网络——网络层(四)

前言: 前面我们已经对物理层和数据链路层有了一个简单的认识与了解,现在我们需要对数据链路层再往上的一个层,网络层进行一个简单的学习与认识,网络层有着极其重要的作用,让我们对网络层进行一个简单的认识与学习吧 目…

Ubuntu:VS Code上C++的环境配置

使用 VSCode 开发 C/C 程序 , 涉及到 工作区的.vscode文件夹下的3个配置文件(均可以手动创建) : ① tasks.json : 编译器构建 配置文件 ; ② launch.json : 调试器设置 配置文件 ; ③ c_cpp_properties.json : 编译器路径和智能代码提示 配置文件 ; …

神经网络:机器学习基础

【一】什么是模型的偏差和方差? 误差(Error) 偏差(Bias) 方差(Variance) 噪声(Noise),一般地,我们把机器学习模型的预测输出与样本的真实label…

详解FreeRTOS:专栏总述

目录 1、理论篇 2、基础篇 3、进阶篇 4、高级篇 5、拓展篇 本专栏基于FreeRTOS底层源码介绍了嵌入式实时操作系统的概念,FreeRTOS任务创建、任务调度、任务同步与消息传递,软件定时器、事件通知等知识。 主要分为5方面内容:理论篇、基础…

Python中json模块的使用与pyecharts绘图的基本介绍

文章目录 json模块json与Python数据的相互转化 pyecharts模块pyecharts基本操作基础折线图配置选项全局配置选项 json模块的数据处理折线图示例示例代码 json模块 json实际上是一种数据存储格式,是一种轻量级的数据交互格式,可以把他理解成一个特定格式…

spring核心组件详细分析图

文章目录 spring核心组件 spring核心组件 spring七大组成部分 1、核心容器 这是Spring框架最基础的部分,它提供了依赖注入(DependencyInjection)特征来实现容器对Bean的管理。这里最基本的概念是BeanFactory,它是任何Spring应用…

python dash 写一个登陆页 4

界面 代码: 这里引入了dash_bootstrap_components 进行界面美化 ,要记一些className,也不是原来说的不用写CSS了。 from dash import Dash, html, dcc, callback, Output, Input, State import dash_bootstrap_components as dbcapp Dash(…

深入理解.net运行时方法表

在.net运行时,每一个类型在创建第一个实例,或者静态成员被第一次访问,或者被反射创建时,就会创建一个与该类型关联的方法表: 基本结构大概如下: -------------------------- | Method Table …

持续集成交付CICD:Jira 远程触发 Jenkins 实现更新 GitLab 分支

目录 一、实验 1.环境 2.GitLab 查看项目 3.Jira新建模块 4. Jira 通过Webhook 触发Jenkins流水线 3.Jira 远程触发 Jenkins 实现更新 GitLab 分支 二、问题 1.Jira 配置网络钩子失败 2. Jira 远程触发Jenkins 报错 一、实验 1.环境 (1)主机 …

HarmonyOS构建第一个JS应用(FA模型)

构建第一个JS应用(FA模型) 创建JS工程 若首次打开DevEco Studio,请点击Create Project创建工程。如果已经打开了一个工程,请在菜单栏选择File > New > Create Project来创建一个新工程。 选择Application应用开发&#xf…

Docker知识总结

Docker 学习目标: 掌握Docker基础知识,能够理解Docker镜像与容器的概念 完成Docker安装与启动 掌握Docker镜像与容器相关命令 掌握Tomcat Nginx 等软件的常用应用的安装 掌握docker迁移与备份相关命令 能够运用Dockerfile编写创建容器的脚本 能够…

go语言学习计划。

第1周:Go语言概述与环境搭建 内容:了解Go语言的历史、特点和应用场景。安装Go环境,配置工作区。实践:编写第一个Go程序,了解Go的编译运行流程。 第2周:基本语法与数据类型 内容:学习基本数据…

排序算法——基数排序

将需要排序的各个数当做元素,集合组成数组,对数组中的元素进行排序,再开辟一个临时数组的空间将数组中已有的元素数值当做临时数组的下标储存在临时数组中,然后用区别初始化值的方法区别出临时数组中待排数组的元素,以…

全方位掌握卷积神经网络:理解原理 优化实践应用

计算机视觉CV的发展 检测任务 分类与检索 超分辨率重构 医学任务 无人驾驶 整体网络架构 卷积层和激活函数(ReLU)的组合是网络的核心组成部分 激活函数(ReLU) 引入非线性,增强网络的表达能力。 卷积层 负责特征提取 池化层…

[MySQL] 二进制文件

文章目录 日志文件binlog是什么简介产生方式文件格式statementrowmixed 怎么办设置文件存储格式缓冲区大小的调整方式缓冲区写入文件的时机 sync_binlog 日志文件 官网 https://dev.mysql.com/doc/refman/8.0/en/server-logs.html中文版 https://mysql.net.cn/doc/refman/8.0/e…

OpenCV | 霍夫变换:以车道线检测为例

霍夫变换 霍夫变换只能灰度图,彩色图会报错 lines cv2.HoughLinesP(edge_img,1,np.pi/180,15,minLineLength40,maxLineGap20) 参数1:要检测的图片矩阵参数2:距离r的精度,值越大,考虑越多的线参数3:距离…

快速安装方式安装开源OpenSIPS和CP控制界面

OpenSIPS是目前世界上主流的两个SIP软交换引擎(其中另外一个是kamailio)或者SIP信令服务器(个人认为是比较正确的称谓)。关于Opensips的基础和一些参数配置和安装方式笔者在很久以前的历史文档中有非常多的介绍。最近,很多用户使用OpenSIPS软…

《PySpark大数据分析实战》-18.什么是数据分析

📋 博主简介 💖 作者简介:大家好,我是wux_labs。😜 热衷于各种主流技术,热爱数据科学、机器学习、云计算、人工智能。 通过了TiDB数据库专员(PCTA)、TiDB数据库专家(PCTP…

EPROM 作为存储器的 8 位单片机

一、基本概述 TX-P01I83 是以 EPROM 作为存储器的 8 位单片机,专为多 IO 产品的应用而设计,例如遥控器、风扇/灯光控制或是 玩具周边等等。采用 CMOS 制程并同时提供客户低成本、高性能等显着优势。TX-P01I83 核心建立在 RISC 精简指 令集架构可以很容易…