Simple-STNDT使用Transformer进行Spike信号的表征学习(二)模型结构

文章目录

    • 1. 位置编码
    • 1.2 EncoderLayer
    • 1.3 Encoder
    • 1.4 STNDT

1. 位置编码

model.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
import math
UNMASKED_LABEL = -100class PositionalEncoding(nn.Module):def __init__(self, trial_length, d_model, dropout):super().__init__()self.dropout = nn.Dropout(dropout)pe = torch.zeros(trial_length, d_model)position = torch.arange(0, trial_length, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)if d_model % 2 == 0:pe[:, 1::2] = torch.cos(position * div_term)else:pe[:, 1::2] = torch.cos(position * div_term[:-1])pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):x = x + self.pe[:x.size(0), :]return self.dropout(x)

1.2 EncoderLayer

model.py
核心编码层,加入了将空间注意力编码

class STNTransformerEncoderLayer(TransformerEncoderLayer):def __init__(self, d_model, d_model_s, num_heads=2,  dim_feedforward=128, dropout=0.1, activation='relu'):super().__init__(d_model,nhead=num_heads,dim_feedforward=dim_feedforward,dropout=dropout,activation=activation)self.num_heads = num_headsself.num_input = d_modelself.d_model_s = d_model_s      # d_model_s: 时间步数(例如 160), 用于空间自注意力self.spatial_self_attn = MultiheadAttention(embed_dim=d_model_s, num_heads=num_heads)self.spatial_norm1 = nn.LayerNorm(d_model_s)self.ts_norm1 = nn.LayerNorm(d_model)self.ts_norm2 = nn.LayerNorm(d_model)self.ts_linear1 = nn.Linear(d_model, dim_feedforward)self.ts_linear2 = nn.Linear(dim_feedforward, d_model)self.ts_dropout1 = nn.Dropout(dropout)self.ts_dropout2 = nn.Dropout(dropout)self.ts_dropout3 = nn.Dropout(dropout)def attend(self, src, context_mask=None, **kwargs):attn_res = self.self_attn(src, src, src, attn_mask=context_mask, **kwargs)return (*attn_res, torch.tensor(0, device=src.device, dtype=torch.float))def spatial_attend(self, src, context_mask=None, **kwargs):r"""Attends over spatial dimensionArgs:src: spatiotemporal neural population inputcontext_mask: spatial context maskReturns:spatiotemporal neural population activity transformed by spatial attention"""attn_res = self.spatial_self_attn(src, src, src, attn_mask=context_mask, **kwargs)return (*attn_res, torch.tensor(0, device=src.device, dtype=torch.float))def forward(self, src, spatial_src, src_mask=None, spatial_src_mask=None, src_key_padding_mask=None):# temporalresidual = srcsrc = self.norm1(src)t_out, t_weights, _ = self.attend(src, context_mask=src_mask, key_padding_mask=src_key_padding_mask)src = residual + self.dropout1(t_out)residual = srcsrc = self.norm2(src)src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))src = residual + self.dropout2(src2)# spatialspatial_src = self.spatial_norm1(spatial_src)spatial_out, spatial_weights, _ = self.spatial_attend(spatial_src,context_mask=spatial_src_mask, key_padding_mask=None)# spatio-temporal feature mixturets_residual = srcsrc = self.ts_norm1(src)ts_out = torch.bmm(spatial_weights, src.permute(1, 2, 0)).permute(2, 0, 1)ts_out = ts_residual + self.ts_dropout1(ts_out)ts_residual = ts_outts_out = self.ts_norm2(ts_out)ts_out = self.ts_linear2(self.ts_dropout2(self.activation(self.ts_linear1(ts_out))))ts_out = ts_residual + self.ts_dropout3(ts_out)return ts_out

1.3 Encoder

model.py

class STNTransformerEncoder(TransformerEncoder):def __init__(self, encoder_layer, num_layers, norm=None):super().__init__(encoder_layer, num_layers, norm)def forward(self, src, spatial_src, mask=None, spatial_mask=None):for i, mod in enumerate(self.layers):if i == 0:src = mod(src, spatial_src, src_mask=mask, spatial_src_mask=spatial_mask)else:src = mod(src, src.permute(2, 1, 0), src_mask=mask, spatial_src_mask=spatial_mask)if self.norm is not None:src = self.norm(src)return src

1.4 STNDT

model.py

class SpatioTemporalNDT(nn.Module):def __init__(self, trial_length, num_neurons, temperature=0.1, c_lambda=0.3, dropout=0.2, pos_drop=0.1, enc_layers=1, log_rates=True,enc_heads=2,  enc_dff=128, enc_drop=0.1) -> None:super().__init__()self.src_mask = Noneself.num_input = num_neuronsself.num_spatial_input = trial_lengthself.embedder = nn.Identity()self.spatial_embedder = nn.Identity()self.scale = math.sqrt(num_neurons)self.spatial_scale = math.sqrt(trial_length)self.src_pos_encoder = PositionalEncoding(trial_length, num_neurons, pos_drop)self.spatial_pos_encoder = PositionalEncoding(num_neurons, trial_length, pos_drop)self.projector = nn.Identity()self.spatial_projector = nn.Identity()self.n_views = 2self.temperature = temperatureself.contrast_lambda = c_lambdaself.cel = nn.CrossEntropyLoss(reduction='none')self.mse = nn.MSELoss(reduction='mean')encoder_layer =STNTransformerEncoderLayer(d_model=self.num_input,d_model_s=self.num_spatial_input, num_heads=enc_heads,dim_feedforward=enc_dff,dropout=enc_drop)self.transformer_encoder = STNTransformerEncoder(encoder_layer, enc_layers, nn.LayerNorm(self.num_input))self.rate_dropout = nn.Dropout(dropout)self.src_decoder = nn.Linear(num_neurons, self.num_input)self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=log_rates)def _get_mask(self, src, do_convert=True):if self.src_mask is not None:return self.src_masksize = src.size(0)context_forward = 13context_backward = 79mask = (torch.triu(torch.ones(size, size), diagonal=-context_forward) == 1).transpose(0, 1)back_mask = (torch.triu(torch.ones(size, size), diagonal=-context_backward) == 1)mask = mask & back_maskmask = mask.float()mask = binary_mask_to_attn_mask(mask)self.src_mask = maskreturn self.src_maskdef forward(self, src: torch.Tensor, mask_labels: torch.Tensor):src = src.float()spatial_src = src.permute(2,0,1)spatial_src = self.spatial_embedder(spatial_src) * self.spatial_scalespatial_src = self.spatial_pos_encoder(spatial_src)src = src.permute(1,0,2)src = self.embedder(src) * self.scalesrc = self.src_pos_encoder(src)src_mask = self._get_mask(src)spatial_src_mask = Noneencoder_output = self.transformer_encoder(src, spatial_src, src_mask, spatial_src_mask)encoder_output = self.rate_dropout(encoder_output)decoder_output = self.src_decoder(encoder_output)decoder_rates = decoder_output.permute(1, 0, 2)decoder_loss = self.classifier(decoder_rates, mask_labels)masked_decoder_loss = decoder_loss[mask_labels != UNMASKED_LABEL]masked_decoder_loss = masked_decoder_loss.mean()return masked_decoder_loss, decoder_ratesdef binary_mask_to_attn_mask(x):return x.float().masked_fill(x == 0, float('-inf')).masked_fill(x == 1, float(0.0))

下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906391

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/32904.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RMDA通信1:通信过程和优势,以太网socket为何用户空间拷贝到内核空间

视频分享: 1.1 RDMA基本原理和优势,以太网socket通信为什么要用户空间拷贝到内核空间_哔哩哔哩_bilibili 一、以太网socket通信 1.1 以太网socket通信过程 1、发送端发起一次通信操作,数据由用户空间拷贝到内核空间。拷贝由CPU完成&#x…

Java基础--AOP--1.概述

一、AOP简介 AOP(Aspect Oriented )即为面向切面编程,也可称为面向方法编程,是方法增强的一种途径,通常可用于记录操作日志、权限空值、事务管理等等;Spring框架中的事务底层就是AOP。 二、AOP的组成 1、连接点&…

【YOLO 系列】基于YOLO V8的车载摄像头交通信号灯检测识别系统【python源码+Pyqt5界面+数据集+训练代码】

前言 随着智能交通系统的发展,交通信号灯的准确识别对于提高道路安全和交通效率具有至关重要的作用。传统的交通信号灯识别方法依赖于固定的传感器和摄像头,存在安装成本高、维护困难等问题。为了解决这些问题,我们启动了这个项目&#xff0…

中文邮件模板之向论文的作者咨询论文相关问题

目录 1. 内容 2. 邮件昵称 3. 格式很重要! 1. 内容 尊敬的: 您好,很抱歉在您百忙之中打扰您。 我是大学的一名硕士生,最近在做项目【】,您发表的论文【】给了我很大的启发。 论文中没有给出具体参数,如和…

React+TS 从零开始教程(2):简中简 HelloWolrd

源码链接:https://pan.quark.cn/s/c6fbc31dcb02 这一节,我们来见识ReactTS的威力,开始上手开发第一个组件,什么组件呢? 当然是简中简的 HelloWolrd组件啦。 在src下创建一个components,然后新建Hello.tsx …

CVPR2023论文速览Transformer

Paper1 TrojViT: Trojan Insertion in Vision Transformers 摘要原文: Vision Transformers (ViTs) have demonstrated the state-of-the-art performance in various vision-related tasks. The success of ViTs motivates adversaries to perform backdoor attacks on ViTs.…

C++系统相关操作3 - 获取操作系统的平台类型

1. 关键词2. sysutil.h3. sysutil.cpp4. 测试代码5. 运行结果6. 源码地址 1. 关键词 C 系统调用 操作系统平台类型 跨平台 2. sysutil.h #pragma once#include <cstdint> #include <string>namespace cutl {/*** brief Operating system platform type.**/enum…

详解 ClickHouse 的语法优化规则

ClickHouse 的 SQL 优化规则是基于 RBO(Rule Based Optimization) 一、count 优化 --1. count()、count(1) 和 count(*)&#xff0c;且没有 where 条件&#xff0c;则会直接使用 system.tables 的 total_rows EXPLAIN SELECT count()FROM datasets.hits_v1;--2. count(column)…

一款有趣的Python库绘制风向图,小白容易上手

利用 Python 绘制风向图 绘制风向图通常使用 matplotlib 库的 Barbs 类来实现.这个类用于绘制风向和风速的矢量场,可以实现不同的风向图风格. 安装 ## 命令安装 matplotlib 库&#xff1a;pip install matplotlib用法 下面是一个简单的示例代码,绘制风向图&#xff1a; 使…

代码随想录算法训练营Day46|动态规划:121.买卖股票的最佳时机I、122.买卖股票的最佳时机II、123.买卖股票的最佳时机III

买卖股票的最佳时机I 121. 买卖股票的最佳时机 - 力扣&#xff08;LeetCode&#xff09; 之前用贪心算法做过相同的题&#xff0c;这次考虑使用动态规划来完成。 dp[i]表示前i天的最大利润 我们已知每一天的价格price[i]&#xff0c;则dp[i]为每一天的价格price[i]减去当初…

论文学习_恶意代码家族检测关键技术研究

0. 摘要 研究背景:近年来,恶意代码的数量和规模在以指数级别增长,威胁和影响力与日俱增,造成的经济损失和社会损失也越来越高。因此,如何快速地识别出恶意代码的变种信息,掌握其家族等属性,能够有效辅助网络安全人员掌握其功能性和危害性,具有重要的研究价值。 研究内…

虚拟现实环境下的远程教育和智能评估系统(十三)

管理/教师端前端工作汇总education-admin&#xff1a; 首先是登录注册页面的展示 管理员 首页 管理员登录后的首页如下图所示 管理员拥有所有的权限 课程管理 1、可以查看、修改、增添、删除课程列表内容 2、可以对课程资源进行操作 3、可以对课程的类别信息进行管理&…

java的输出流File OutputStream

一、字节输出流FileOutput Stream 1、定义 使用OutputStream类的FileOutput Stream子类向文本文件写入的数据。 2.常用构造方法 3.创建文件输出流对象的常用方式 二、输出流FileOutputStream类的应用示例 1.示例 2、实现步骤 今天的总结就到此结束啦&#xff0c;拜拜&#x…

【Web APIs】DOM 文档对象模型 ⑤ ( 获取特殊元素 | 获取 html 元素 | 获取 body 元素 )

文章目录 一、获取特殊元素1、获取 html 元素2、获取 body 元素3、完整代码示例 本博客相关参考文档 : WebAPIs 参考文档 : https://developer.mozilla.org/zh-CN/docs/Web/APIgetElementById 函数参考文档 : https://developer.mozilla.org/zh-CN/docs/Web/API/Document/getE…

I2C总线8位IO扩展器PCF8574

PCF8574用于I2C总线的远程8位I/O扩展器 PCF8574国产有多个厂家有替代产品&#xff0c;图示为其中一款HT8574 1 产品特点 低待机电流消耗&#xff1a;10 uA&#xff08;最大值&#xff09; I2C 转并行端口扩展器 漏极开路中断输出 与大多数微控制器兼容 具有大电流驱动能力的闭…

嵌入式系统中的加解密签名

笔者来了解一下嵌入式系统中的加解密 1、背景与名词解释 笔者最近在做安全升级相关的模块&#xff0c;碰到了一些相关的概念和一些应用场景&#xff0c;特来学习记录一下。 1.1 名词解释 对称加密&#xff1a;对称加密是一种加密方法&#xff0c;使用相同的密钥&#xff08;…

IDEA各种实体类运行爆红,不运行就没事

1.问题描述 如图所示&#xff0c;后端项目的import的各种entity爆红&#xff0c;点击也有导入包的提示&#xff0c;且这种报红几乎遍布了整个工程项目 2.我的解决方案 清空缓存&#xff0c;然后把target文件删掉&#xff0c;重新跑 3.小结 idea项目有时候就是一个核弹&…

kubernetes排错(六)-Pod 状态一直 Terminating

查看 Pod 事件: $ kubectl describe pod apigateway-6dc48bf8b6-clcwk -n cn-staging 报错有以下几种情况&#xff0c;不同情况处理方式不同&#xff1a; 1&#xff09;Need to kill Pod Normal Killing 39s (x735 over 15h) kubelet, 10.179.80.31 Killing container …

软件设计师笔记-系统开发和运行知识(一)

软件工程 软件工程是一门研究用工程化方法构建和维护有效、实用和高质量软件的学科。它涉及计算机科学、数学、管理科学等多领域的原理和技术。其核心目标是应用这些原理和技术来提高软件的生产效率、质量,并降低其成本。 关键组成部分: 计算机科学:提供了软件开发所需的基…

微信小程序api和注册

微信小程序API学习总结 引言 随着移动互联网的快速发展&#xff0c;微信小程序已经成为开发者们关注的热点之一。微信小程序以其轻便、快捷、易于开发的特点吸引了大量的开发者。本文将对微信小程序的学习过程进行总结&#xff0c;希望能够帮助读者更好地掌握微信小程序的开发技…