Simple-STNDT使用Transformer进行Spike信号的表征学习(一)数据处理篇

文章目录

    • 1.数据处理部分
      • 1.1 下载数据集
      • 1.2 数据集预处理
      • 1.3 划分train-val并创建Dataset对象
      • 1.4 掩码mask操作

数据、评估标准见NLB2021
https://neurallatents.github.io/

以下代码依据
https://github.com/trungle93/STNDT

原代码使用了 Ray+Config文件进行了参数搜索,库依赖较多,数据流过程不明显,代码冗杂,这里进行了抽丝剥茧,将其中最核心的部分提取出来。

1.数据处理部分

1.1 下载数据集

需要依赖 pip install dandi
downald.py

root = "D:/NeuralLatent/"
def downald_data():from dandi.download import downloaddownload("https://dandiarchive.org/dandiset/000128", root)download("https://dandiarchive.org/dandiset/000138", root)download("https://dandiarchive.org/dandiset/000139", root)download("https://dandiarchive.org/dandiset/000140", root)download("https://dandiarchive.org/dandiset/000129", root)download("https://dandiarchive.org/dandiset/000127", root)download("https://dandiarchive.org/dandiset/000130", root)

1.2 数据集预处理

需要依赖官方工具包pip install nlb_tools
主要是加载锋值序列数据,将其采样为5ms的时间槽
preprocess.py

## 以下为参数示例
# data_path = root + "/000129/sub-Indy/"
# dataset_name = "mc_rtt"
## 注意 "./data" 必须提前创建好from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, combine_h5def preprocess(data_path, dataset_name=None):dataset = NWBDataset(datapath)bin_width = 5dataset.resample(bin_width)make_train_input_tensors(dataset, dataset_name=dataset_name, trial_split="train", include_behavior=True, include_forward_pred=True, save_file=True,save_path=f"./data/{dataset_name}_train.h5")make_eval_input_tensors(dataset, dataset_name=dataset_name, trial_split="val", save_file=True, save_path=f"./data/{dataset_name}_val.h5")combine_h5([f"./data/{dataset_name}_train.h5", f"./data/{dataset_name}_val.h5"], save_path=f"./data/{dataset_name}_full.h5")## './data/mc_rtt_full.h5' 将成为后续的主要分析数据

1.3 划分train-val并创建Dataset对象

读取'./data/mc_rtt_full.h5'中的数据并创建dataset
dataset.py

import h5py
import numpy as np
import torch
from torch.utils import data
# data_path = "./data/mc_rtt_full.h5"class SpikesDataset(data.Dataset):def __init__(self, spikes, heldout_spikes, forward_spikes) -> None:self.spikes = spikesself.heldout_spikes = heldout_spikesself.forward_spikes = forward_spikesdef __len__(self):return self.spikes.size(0)def __getitem__(self, index):r"""Return spikes and rates, shaped T x N (num_neurons)"""return self.spikes[index], self.heldout_spikes[index], self.forward_spikes[index]def make_datasets(data_path):with h5py.File(data_path, 'r') as h5file:h5dict = {key: h5file[key][()] for key in h5file.keys()}if 'eval_spikes_heldin' in h5dict: # NLB dataget_key = lambda key: h5dict[key].astype(np.float32)train_data = get_key('train_spikes_heldin')train_data_fp = get_key('train_spikes_heldin_forward')train_data_heldout_fp = get_key('train_spikes_heldout_forward')train_data_all_fp = np.concatenate([train_data_fp, train_data_heldout_fp], -1)valid_data = get_key('eval_spikes_heldin')train_data_heldout = get_key('train_spikes_heldout')if 'eval_spikes_heldout' in h5dict:valid_data_heldout = get_key('eval_spikes_heldout')else:valid_data_heldout = np.zeros((valid_data.shape[0], valid_data.shape[1], train_data_heldout.shape[2]), dtype=np.float32)if 'eval_spikes_heldin_forward' in h5dict:valid_data_fp = get_key('eval_spikes_heldin_forward')valid_data_heldout_fp = get_key('eval_spikes_heldout_forward')valid_data_all_fp = np.concatenate([valid_data_fp, valid_data_heldout_fp], -1)else:valid_data_all_fp = np.zeros((valid_data.shape[0], train_data_fp.shape[1], valid_data.shape[2] + valid_data_heldout.shape[2]), dtype=np.float32)train_dataset = SpikesDataset(torch.tensor(train_data).long(),            # [810, 120, 98]torch.tensor(train_data_heldout).long(),    # [810, 120, 32]torch.tensor(train_data_all_fp).long(),     # [810, 40, 130])val_dataset = SpikesDataset(torch.tensor(valid_data).long(),            # [810, 120, 98]torch.tensor(valid_data_heldout).long(),    # [810, 120, 32]torch.tensor(valid_data_all_fp).long(),     # [810, 40, 130])return train_dataset, val_dataset

1.4 掩码mask操作

dataset.py

# Some infeasibly high spike count
UNMASKED_LABEL = -100def mask_batch(batch, heldout_spikes, forward_spikes):batch = batch.clone() # make sure we don't corrupt the input data (which is stored in memory)mask_ratio = 0.31254mask_random_ratio = 0.876mask_token_ratio = 0.527labels = batch.clone()mask_probs = torch.full(labels.shape, mask_ratio)# If we want any tokens to not get masked, do it here (but we don't currently have any)mask = torch.bernoulli(mask_probs)mask = mask.bool()labels[~mask] = UNMASKED_LABEL  # No ground truth for unmasked - use this to mask loss# We use random assignment so the model learns embeddings for non-mask tokens, and must rely on context# Most times, we replace tokens with MASK tokenindices_replaced = torch.bernoulli(torch.full(labels.shape, mask_token_ratio)).bool() & maskbatch[indices_replaced] = 0# Random % of the time, we replace masked input tokens with random value (the rest are left intact)indices_random = torch.bernoulli(torch.full(labels.shape, mask_random_ratio)).bool() & mask & ~indices_replacedrandom_spikes = torch.randint(batch.max(), labels.shape, dtype=torch.long)batch[indices_random] = random_spikes[indices_random]# heldout spikes are all maskedbatch = torch.cat([batch, torch.zeros_like(heldout_spikes)], -1)labels = torch.cat([labels, heldout_spikes.to(batch.device)], -1)batch = torch.cat([batch, torch.zeros_like(forward_spikes)], 1)labels = torch.cat([labels, forward_spikes.to(batch.device)], 1)# Leave the other 10% alonereturn batch, labels

下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906187

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/34344.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【华为OD机试 2023】快递投放问题(C++ Java JavaScript Python)

题目 题目描述 有N个快递站点用字符串标识,某些站点之间有道路连接。 每个站点有一些包裹要运输,每个站点间的包裹不重复,路上有检查站会导致部分货物无法通行,计算哪些货物无法正常投递? 输入描述 第一行输入M N,M个包裹N个道路信息.后面M行分别输入包裹名、包裹起点、包…

期末成绩怎么快速发给家长

Hey各位老师们,今天来聊一个超级实用的话题:如何快速高效的向家长们传达学生的期末成绩。你可能会想,这不是很简单吗?直接班级群发个消息不就得了?但别忘了,保护学生隐私和自尊心也是很重要的哦&#xff01…

GB28181视频汇聚平台EasyCVR接入Ehome设备视频播放出现异常是什么原因?

多协议接入视频汇聚平台EasyCVR视频监控系统采用了开放式的架构,系统可兼容多协议接入,包括市场标准协议:国标GB/T 28181协议、GA/T 1400协议、JT808、RTMP、RTSP/Onvif协议;以及主流厂家私有协议及SDK,如:…

视频融合共享平台LntonCVS视频监控平台在农场果园等场景的使用方案

我国大江南北遍布着各类果园。传统的安全防范方式主要是建立围墙,但这种方式难以彻底阻挡不法分子的入侵和破坏。因此,需要一套先进、科学、实用且稳定的安全防范报警系统,以及时发现并处理潜在问题。 需求分析 由于果园地处偏远且缺乏有效防…

【已解决】后端接口返回的是文件流(数据流),前端代码如何实现下载文件流--封装代码

后端接口返回的是文件流(数据流),前端代码如何实现下载文件流--封装代码 实例代码环境 前端框架:vue3.0 请求插件:axios 1.6.2 export async function downFile(url, params, config) {downloadLoadingInstance ElLoa…

redis以后台的方式启动

文章目录 1、查看redis安装的目录2、Redis以后台的方式启动3、通过客户端连接redis4、连接后,测试与redis的连通性 1、查看redis安装的目录 [rootlocalhost ~]# cd /usr/local/redis/ [rootlocalhost redis]# ll 总用量 112 drwxr-xr-x. 2 root root 150 12月 6…

【从零开始认识AI】梯度下降法

目录 1. 原理介绍 2. 代码实现 1. 原理介绍 梯度下降法(Gradient Descent)是一种用于优化函数的迭代算法,广泛应用于机器学习和深度学习中,用来最小化一个目标函数。该目标函数通常代表模型误差或损失。 基本思想是从一个初始…

Docker镜像仓库:存储与分发Docker镜像的中央仓库

探索Docker镜像仓库:存储与分发Docker镜像的中央仓库 如果你是Docker的新手,或者已经在使用Docker但还不太了解Docker镜像仓库,那么这篇博客将是你的最佳指南。我们将从基础概念开始,逐步深入,帮助你全面掌握Docker注…

JS中的三种事件模型

JavaScript 中的事件模型主要有三种: 传统事件模型(DOM Level 0)标准事件模型(DOM Level 2)IE 事件模型(非标准,仅限于旧版本的 Internet Explorer) 下面分别介绍这三种事件模型&…

【JavaScript 小工具】——获取 URL 中的参数

要从 location.href 中获取指定参数,你可以使用 JavaScript 来解析 URL 并提取参数值。以下是一种常见的方法: // 获取当前页面的 URL var urlString window.location.href;// 解析 URL,获取参数部分 var url new URL(urlString);// 获取参…

C#写一个WebService服务器

首先在NuGet中下载Fleck动态库 创建一个WebSocketHelper类 public class WebSocketHelper {//客户端url以及其对应的Socket对象字典IDictionary<string, IWebSocketConnection> dic_Sockets new Dictionary<string, IWebSocketConnection>();//创建一个 websock…

软件测试计划审核表、试运行审核、试运行申请表、开工申请表

1、系统测试计划审核表 2、系统试运行审核表 3、系统试运行申请表 4、开工申请表 5、开工令 6、项目经理授权书 软件全套资料获取&#xff1a;本文末个人名片直接获取或者进主页。 系统测试计划审核表 系统试运行审核表 系统试运行申请表 开工申请表 开工令 项目经理授权书

青否数字人实时直播带货手机版发布!

青否数字人6大核心 AIGC 技术&#xff0c;让新手小白也能轻松搞定数字人在全平台的稳定直播&#xff0c;并有效规避违规风险&#xff0c;赋能商家开播即赚钱&#xff01; AI主播 只需要录制主播1分钟的绿幕视频&#xff0c;1秒钟就能克隆出一个数字人主播形象。S级真人深度学习…

快速鲁棒的 ICP (Fast and Robust Iterative Closest Point)

迭代最近点&#xff08;Iterative Closet Point&#xff0c;ICP&#xff09;算法及其变体是两个点集之间刚性配准的基本技术&#xff0c;在机器人技术和三维重建等领域有着广泛的应用。ICP的主要缺点是&#xff1a;收敛速度慢&#xff0c;以及对异常值、缺失数据和部分重叠的敏…

未来已来!GPT-5震撼登场,工作与生活面临新变革!

随着科技界领袖对AI系统发展之快的惊叹&#xff0c;新一代大语言模型GPT-5即将登场&#xff0c;引发了我们对工作和日常生活的新一轮思考。微软CTO Kevin Scott和阿里巴巴董事长蔡崇信等人的言论为我们描绘了一幅生动的未来图景&#xff0c;即AI将在我们的生活中扮演越来越重要…

力扣刷题总结 -- 数组28

82. 无法吃午餐的学生数量&#xff08;简单&#xff09; 题目要求&#xff1a; 学校的自助午餐提供圆形和方形的三明治&#xff0c;分别用数字 0 和 1 表示。所有学生站在一个队列里&#xff0c;每个学生要么喜欢圆形的要么喜欢方形的。 餐厅里三明治的数量与学生的数量相同。…

el-form-item的label设置两端对齐

<style scoped> ::v-deep .el-form-item__label {display: inline;text-align-last: justify; } </style>效果如图所示

数据分析必备:一步步教你如何用matplotlib做数据可视化(11)

1、Matplotlib 三维绘图 尽管Matplotlib最初设计时只考虑了二维绘图&#xff0c;但是在后来的版本中&#xff0c;Matplotlib的二维显示器上构建了一些三维绘图实用程序&#xff0c;以提供一组三维数据可视化工具。通过导入Matplotlib包中包含的mplot3d工具包&#xff0c;可以启…

双 μC 的 PWM 频率和分辨率

该方法是过滤 PWM 信号的 HF 分量&#xff0c;只留下与占空比成正比的 LF 或 DC 分量。然而&#xff0c;低通滤波器并不能完全滤除PWM频率&#xff0c;因此LF/DC信号一般会有一些纹波。 有两种方法可以降低 PWM DAC 的纹波。可以降低低通滤波器的截止频率&#xff0c;或者提高…