Simple-STNDT: Spike Representation Learning with Transformers, Part 3: Training and Evaluation

Contents

    • 1. Evaluation Metrics
    • 2. Training Setup
    • 3. Debug Test
    • 4. Train/Valid Functions

1. Evaluation Metrics
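Model quality is scored in bits per spike, the co-smoothing metric used by the Neural Latents Benchmark: the Poisson negative log-likelihood of the predicted firing rates is compared against a null model that predicts each neuron's mean firing rate, and the difference (converted to base-2 log units) is normalized by the total spike count. Positive values mean the model explains the spikes better than the mean rate does.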

import numpy as np
from scipy.special import gammaln
import torch

def neg_log_likelihood(rates, spikes, zero_warning=True):
    """Calculates Poisson negative log likelihood given rates and spikes.

    formula: -log(e^(-r) / n! * r^n) = r - n*log(r) + log(n!)

    Parameters
    ----------
    rates : np.ndarray
        numpy array containing rate predictions
    spikes : np.ndarray
        numpy array containing true spike counts
    zero_warning : bool, optional
        Whether to print out a warning about 0 rate predictions or not

    Returns
    -------
    float
        Total negative log-likelihood of the data
    """
    assert spikes.shape == rates.shape, \
        f"neg_log_likelihood: Rates and spikes should be of the same shape. spikes: {spikes.shape}, rates: {rates.shape}"

    # Drop NaN entries (e.g. padding) from both arrays before scoring
    if np.any(np.isnan(spikes)):
        mask = np.isnan(spikes)
        rates = rates[~mask]
        spikes = spikes[~mask]

    assert not np.any(np.isnan(rates)), \
        "neg_log_likelihood: NaN rate predictions found"
    assert np.all(rates >= 0), \
        "neg_log_likelihood: Negative rate predictions found"

    # Clamp exact zeros so log(rates) stays finite
    if np.any(rates == 0):
        if zero_warning:
            print("neg_log_likelihood: zero rate predictions found, replacing with 1e-9")
        rates[rates == 0] = 1e-9

    result = rates - spikes * np.log(rates) + gammaln(spikes + 1.0)
    return np.sum(result)

def bits_per_spike(rates, spikes):
    """Computes bits per spike of rate predictions given spikes.

    Bits per spike is equal to the difference between the log-likelihoods
    (in base 2) of the rate predictions and the null model (i.e. predicting
    the mean firing rate of each neuron), divided by the total number of spikes.

    Parameters
    ----------
    rates : np.ndarray
        3d numpy array containing rate predictions
    spikes : np.ndarray
        3d numpy array containing true spike counts

    Returns
    -------
    float
        Bits per spike of rate predictions
    """
    nll_model = neg_log_likelihood(rates, spikes)
    # Null model: each neuron's mean firing rate, tiled to the shape of spikes
    null_rates = np.tile(
        np.nanmean(spikes, axis=(0, 1), keepdims=True),
        (spikes.shape[0], spikes.shape[1], 1),
    )
    nll_null = neg_log_likelihood(null_rates, spikes, zero_warning=False)
    return (nll_null - nll_model) / np.nansum(spikes) / np.log(2)
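As a quick sanity check (my addition, not from the original post): scoring the null model against itself gives exactly 0 bits per spike by construction, while rates that match the true counts score positive.

import numpy as np

rng = np.random.default_rng(0)
spikes = rng.poisson(2.0, size=(8, 10, 5)).astype(float)    # fake spike counts
null_rates = np.tile(spikes.mean(axis=(0, 1), keepdims=True), (8, 10, 1))
print(bits_per_spike(null_rates, spikes))     # ~0.0: the null model scores zero against itself
print(bits_per_spike(spikes + 1e-9, spikes))  # > 0: matching the true counts beats the null model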

2. Training Setup

from torch.utils.data import DataLoader
from dataset import make_datasets, mask_batch
from model import SpatioTemporalNDT
from metric import bits_per_spike
import torch
from torch.optim import AdamW
from torch import nn

batch_size = 16
lr = 1e-3
train_dataset, val_dataset = make_datasets()
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size*2, shuffle=False)
trial_length = 160
neuron_num = 130
model = SpatioTemporalNDT(trial_length, neuron_num)
num_epochs = 50
optim = AdamW(model.parameters(), lr=lr)
log_interval = 20
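Here make_datasets, mask_batch, and SpatioTemporalNDT are the dataset utilities and model built in the earlier posts of this series. Note how the shapes fit together: trial_length = 160 is the 120 held-in time bins plus 40 forward-prediction bins, and neuron_num = 130 is the 98 held-in plus 32 held-out neurons, matching the tensor shapes printed by the debug test below.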

3. Debug Test

def param_num(model):
    return sum(param.numel() for param in model.parameters() if param.requires_grad)

def debug_test():
    spikes, heldout_spikes, forward_spikes = next(iter(train_dataloader))
    print(spikes.shape)             # [16, 120, 98]
    print(heldout_spikes.shape)     # [16, 120, 32]
    print(forward_spikes.shape)     # [16, 40, 130]
    masked_spikes, labels = mask_batch(spikes, heldout_spikes, forward_spikes)
    print(masked_spikes.shape)      # [16, 160, 130]
    print(labels.shape)             # [16, 160, 130]
    print(param_num(model))         # 256886
    loss, decoder_rates = model.forward(masked_spikes, labels)
    print(loss)                     # tensor(1.2356, grad_fn=<MeanBackward0>)
    print(decoder_rates.shape)      # torch.Size([16, 160, 130])
    val_loss, val_score = valid(val_dataloader, model)
    print(val_loss)
    print(val_score)
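Running debug_test() once before a full run is a cheap sanity check: it confirms that the masking logic and label construction line up with the model's expected input shape, that the model is small (about 257k trainable parameters), and that the loss carries a gradient.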

4. Train/Valid Functions

def train(model, dataloader, val_dataloader, num_epochs, optim):
    for epoch in range(num_epochs):
        model.train()  # restore training mode (valid() switches the model to eval mode)
        print(f"--------- Epoch{epoch:2d} ----------")
        train_loss = []
        for i, (spikes, heldout_spikes, forward_spikes) in enumerate(dataloader):
            masked_spikes, labels = mask_batch(spikes, heldout_spikes, forward_spikes)
            loss, decoder_rates = model(masked_spikes, labels)
            optim.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 200.0)
            optim.step()
            with torch.no_grad():
                train_loss.append(loss.item())
            if i % log_interval == 0:
                print(f"Train loss: {sum(train_loss)/len(train_loss)}")
        val_loss, val_score = valid(val_dataloader, model)
        print(f"val loss: {float(val_loss)}")
        print(f"val score: {float(val_score)}")
        print()

def valid(val_dataloader, model):
    model.eval()
    pred_rates = []
    heldout_spikes_full = []
    loss_list = []
    with torch.no_grad():
        for spikes, heldout_spikes, forward_spikes in val_dataloader:
            # Labels keep the held-in spikes; held-out neurons and forward
            # time bins are set to -100 so they are excluded from the loss
            no_mask_labels = spikes.clone()
            no_mask_labels = torch.cat([no_mask_labels, torch.zeros_like(heldout_spikes)], -1)
            no_mask_labels = torch.cat([no_mask_labels, torch.zeros_like(forward_spikes)], 1)
            no_mask_labels[:, :, -heldout_spikes.size(-1):] = -100  # unmasked label
            no_mask_labels[:, -forward_spikes.size(1):, :] = -100   # unmasked label
            # Inputs: zero-pad the held-out neurons and forward time bins
            spikes = torch.cat([spikes, torch.zeros_like(heldout_spikes)], -1)
            spikes = torch.cat([spikes, torch.zeros_like(forward_spikes)], 1)
            loss, batch_rates = model(spikes, no_mask_labels)
            pred_rates.append(batch_rates)
            heldout_spikes_full.append(heldout_spikes)
            loss_list.append(loss)
    heldout_spikes = torch.cat(heldout_spikes_full, dim=0)
    pred_rates = torch.cat(pred_rates, dim=0)
    # The decoder outputs log-rates, so exponentiate the held-out block before scoring
    eval_rates_heldout = torch.exp(pred_rates[:, :heldout_spikes.size(1), -heldout_spikes.size(-1):]).numpy().astype('float')
    eval_spikes_heldout = heldout_spikes.numpy().astype('float')
    # eval_rates_heldout.shape:  (270, 120, 32)
    # eval_spikes_heldout.shape: (270, 120, 32)
    return sum(loss_list), float(bits_per_spike(eval_rates_heldout, eval_spikes_heldout))
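The post does not show the main.py entry point that the command below invokes. A minimal sketch, assuming the setup code from section 2 lives in the same script, might look like this:

# main.py -- hypothetical entry point; only the "python main.py" command and its
# log output appear in the original post
if __name__ == "__main__":
    # debug_test()  # optional one-off shape check (section 3)
    train(model, train_dataloader, val_dataloader, num_epochs, optim)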

Finally, start training:

(stndt) D:\STNDT>python main.py
--------- Epoch 0 ----------
Train loss: 1.2486777305603027
Train loss: 0.5138219218878519
Train loss: 0.32351083744589876
val loss: 0.8636534214019775
val score: -0.39136893422272767

--------- Epoch 1 ----------
Train loss: 0.09501783549785614
Train loss: 0.09383604036910194
Train loss: 0.09296295773692248
val loss: 0.8206770420074463
val score: -0.09666108663240561

--------- Epoch 2 ----------
Train loss: 0.09622671455144882
Train loss: 0.09049306774423235
Train loss: 0.08994600358532696
val loss: 0.812911331653595
val score: -0.04202061410637105

--------- Epoch 3 ----------
Train loss: 0.09225568175315857
Train loss: 0.09019481816462108
Train loss: 0.08970968806888999
val loss: 0.8099062442779541
val score: -0.019777008609723395

--------- Epoch 4 ----------
Train loss: 0.08371596038341522
Train loss: 0.08918796905449458
Train loss: 0.0894875490083927
val loss: 0.8083348274230957
val score: -0.008896993842432857

--------- Epoch 5 ----------
Train loss: 0.09019782394170761
Train loss: 0.08884035441137496
Train loss: 0.08963883395602064
val loss: 0.8072853088378906
val score: -0.0026569800293788507

--------- Epoch 6 ----------
Train loss: 0.09667835384607315
Train loss: 0.09060979953833989
Train loss: 0.08956735653848183
val loss: 0.8064565658569336
val score: 0.0003163842262874261

--------- Epoch 7 ----------
Train loss: 0.08744495362043381
Train loss: 0.08888665941499528
Train loss: 0.08930287855427439
val loss: 0.8058080077171326
val score: 0.005321093845270125

--------- Epoch 8 ----------
Train loss: 0.10221674293279648
Train loss: 0.09078312771660942
Train loss: 0.08951869806865366
val loss: 0.8044026494026184
val score: 0.007113516568588765

--------- Epoch 9 ----------
Train loss: 0.09160886704921722
Train loss: 0.08984803798652831
Train loss: 0.0897282888976539
val loss: 0.803226113319397
val score: 0.01217366049067505

--------- Epoch10 ----------
Train loss: 0.09165512025356293
Train loss: 0.08854220310846965
Train loss: 0.08920388268988307
val loss: 0.8014105558395386
val score: 0.015657932109121083

--------- Epoch11 ----------
Train loss: 0.07934647053480148
Train loss: 0.08873837547642845
Train loss: 0.08900632345821799
val loss: 0.7992606163024902
val score: 0.017361369978752348

--------- Epoch12 ----------
Train loss: 0.08641393482685089
Train loss: 0.0893486404702777
Train loss: 0.08927923113834567
val loss: 0.7964036464691162
val score: 0.026846927269458674

--------- Epoch13 ----------
Train loss: 0.08859497308731079
Train loss: 0.08794442635206949
Train loss: 0.08938420000599652
val loss: 0.7929846048355103
val score: 0.033583528051411037

--------- Epoch14 ----------
Train loss: 0.08901184052228928
Train loss: 0.08875668652000882
Train loss: 0.08939630665430208
val loss: 0.7878748178482056
val score: 0.04465469491549107

--------- Epoch15 ----------
Train loss: 0.09487541764974594
Train loss: 0.08885077848320916
Train loss: 0.08909488651083737
val loss: 0.7851467728614807
val score: 0.046395409621300066

--------- Epoch16 ----------
Train loss: 0.0839885026216507
Train loss: 0.08959413000515529
Train loss: 0.08932711874566428
val loss: 0.7806612253189087
val score: 0.05012596379845563

--------- Epoch17 ----------
Train loss: 0.09544813632965088
Train loss: 0.08826960552306402
Train loss: 0.0890249778948179
val loss: 0.7787002325057983
val score: 0.05084565441331739

--------- Epoch18 ----------
Train loss: 0.09305278211832047
Train loss: 0.08740198683171045
Train loss: 0.08877205539767336
val loss: 0.7735776305198669
val score: 0.06808317309022775

--------- Epoch19 ----------
Train loss: 0.08946727961301804
Train loss: 0.0880857486100424
Train loss: 0.08832225821367125
val loss: 0.7722467184066772
val score: 0.0741929715804975

--------- Epoch20 ----------
Train loss: 0.09155283123254776
Train loss: 0.08762263329256148
Train loss: 0.08867140041618812
val loss: 0.774036705493927
val score: 0.06465988606612133

--------- Epoch21 ----------
Train loss: 0.08425123244524002
Train loss: 0.08848933414334342
Train loss: 0.08806171540806933
val loss: 0.7706096768379211
val score: 0.06233272968330965

--------- Epoch22 ----------
Train loss: 0.08672144263982773
Train loss: 0.08736556342669896
Train loss: 0.08800865782470238
val loss: 0.7690156698226929
val score: 0.07570956489538153

--------- Epoch23 ----------
Train loss: 0.09086063504219055
Train loss: 0.0895571896717662
Train loss: 0.08793148053128545
val loss: 0.7725724577903748
val score: 0.045295719065139656

--------- Epoch24 ----------
Train loss: 0.08895140141248703
Train loss: 0.08862598595165071
Train loss: 0.08853605389595032
val loss: 0.7674567103385925
val score: 0.07400126493414798

--------- Epoch25 ----------
Train loss: 0.08059882372617722
Train loss: 0.08788907066697166
Train loss: 0.08830737322568893
val loss: 0.7654385566711426
val score: 0.0783971076192251

--------- Epoch26 ----------
Train loss: 0.0904078260064125
Train loss: 0.08821353883970351
Train loss: 0.08813101125926506
val loss: 0.7648967504501343
val score: 0.06579874206738114

--------- Epoch27 ----------
Train loss: 0.0888797715306282
Train loss: 0.08781595457167853
Train loss: 0.08853465282335514
val loss: 0.765023946762085
val score: 0.06403537205845905

--------- Epoch28 ----------
Train loss: 0.0925334170460701
Train loss: 0.08814156835987455
Train loss: 0.08763645026015073
val loss: 0.7604566216468811
val score: 0.08386773786224676

--------- Epoch29 ----------
Train loss: 0.09102518111467361
Train loss: 0.08881006035066787
Train loss: 0.08800200536483671
val loss: 0.7639309167861938
val score: 0.05987701272594979

--------- Epoch30 ----------
Train loss: 0.08757702261209488
Train loss: 0.08790529945066997
Train loss: 0.08796896276677527
val loss: 0.7679344415664673
val score: 0.04645880716520806

--------- Epoch31 ----------
Train loss: 0.09563669562339783
Train loss: 0.08776313385793141
Train loss: 0.08768014010132813
val loss: 0.7532508969306946
val score: 0.09419951931221196

--------- Epoch32 ----------
Train loss: 0.08262639492750168
Train loss: 0.08920836945374806
Train loss: 0.08818964242208295
val loss: 0.7534663081169128
val score: 0.07980706821661744

--------- Epoch33 ----------
Train loss: 0.09010934829711914
Train loss: 0.08798151392312277
Train loss: 0.08814984251086305
val loss: 0.7573298215866089
val score: 0.0587445179781999

--------- Epoch34 ----------
Train loss: 0.09029105305671692
Train loss: 0.08793160106454577
Train loss: 0.087826013383342
val loss: 0.7541366219520569
val score: 0.04576204364697583

--------- Epoch35 ----------
Train loss: 0.09183177351951599
Train loss: 0.08813220936627615
Train loss: 0.08824214902592868
val loss: 0.7545167803764343
val score: 0.043795136749962035

--------- Epoch36 ----------
Train loss: 0.08738738298416138
Train loss: 0.08769806651842027
Train loss: 0.08801802520344897
val loss: 0.7475957870483398
val score: 0.07046052509968409

--------- Epoch37 ----------
Train loss: 0.08695636689662933
Train loss: 0.08928513243084862
Train loss: 0.08794533206922252
val loss: 0.7405006885528564
val score: 0.08250606459379788

--------- Epoch38 ----------
Train loss: 0.08741921186447144
Train loss: 0.08701477554582414
Train loss: 0.08772314776007722
val loss: 0.7421612739562988
val score: 0.07261544623998699

--------- Epoch39 ----------
Train loss: 0.08897516131401062
Train loss: 0.08884722207273756
Train loss: 0.08827457195375024
val loss: 0.7383261919021606
val score: 0.05041364027920663

--------- Epoch40 ----------
Train loss: 0.08877569437026978
Train loss: 0.08783218938679922
Train loss: 0.08838088319795888
val loss: 0.7311040759086609
val score: 0.05160266134263263

--------- Epoch41 ----------
Train loss: 0.0751330778002739
Train loss: 0.0872439131850288
Train loss: 0.08815818952351082
val loss: 0.723595917224884
val score: 0.08080731948303856

--------- Epoch42 ----------
Train loss: 0.09519665688276291
Train loss: 0.0866984451810519
Train loss: 0.08742059876279133
val loss: 0.7205336689949036
val score: 0.08327377202054256

--------- Epoch43 ----------
Train loss: 0.08966871351003647
Train loss: 0.08703825693754923
Train loss: 0.08704596176380064
val loss: 0.7158994078636169
val score: 0.05753987849499046

--------- Epoch44 ----------
Train loss: 0.08914705365896225
Train loss: 0.08722686128956932
Train loss: 0.08729714445951509
val loss: 0.7021420001983643
val score: 0.08133226152944593

--------- Epoch45 ----------
Train loss: 0.08485537022352219
Train loss: 0.08770599854843956
Train loss: 0.08782925693000235
val loss: 0.705651044845581
val score: 0.07325790592903407

--------- Epoch46 ----------
Train loss: 0.08972616493701935
Train loss: 0.088348921920572
Train loss: 0.08801035510330665
val loss: 0.6982176303863525
val score: 0.06009563284716213

--------- Epoch47 ----------
Train loss: 0.08506552129983902
Train loss: 0.08846274834303629
Train loss: 0.08772453265946085
val loss: 0.684754490852356
val score: 0.10142577749520322

--------- Epoch48 ----------
Train loss: 0.08494629710912704
Train loss: 0.08716638279812676
Train loss: 0.08738453831614518
val loss: 0.6825719475746155
val score: 0.087609587353269

--------- Epoch49 ----------
Train loss: 0.08093467354774475
Train loss: 0.08778195899157297
Train loss: 0.08736045422350489
val loss: 0.6823106408119202
val score: 0.06519610685639747
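Two trends stand out in the log: the training loss flattens out near 0.088 after the first epoch, while the validation loss keeps decreasing (from 0.86 to 0.68) and the held-out bits per spike climbs noisily from negative values to a peak of roughly 0.10 at epoch 47. Longer training, or early stopping on the validation score rather than a fixed 50 epochs, could plausibly help.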
