# 1: 利用 SpikingJelly 实现 MNIST 数据集分类
# 设置仿真时间 T=10
import argparse
import time

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import encoding, functional, layer, neuron, surrogate
from torch.utils.data import DataLoader, TensorDataset
# Directory containing the MNIST CSV files (mnist_train.csv / mnist_test.csv).
# NOTE(review): this name shadows the built-in dir(); it is kept unchanged
# because other code in this file refers to it.
dir = 'E:/MNIST_CSV/'

# Read both CSV splits once, at import time (train first, then test).
mnist_data_train, mnist_data_test = (
    pd.read_csv(dir + csv_name)
    for csv_name in ('mnist_train.csv', 'mnist_test.csv')
)

# SNN network
class SNN(nn.Module):
    """Single-layer spiking classifier: Flatten -> Linear(784 -> 10, no bias) -> LIF.

    Args:
        tau: membrane time constant of the LIF neurons.
    """

    def __init__(self, tau):
        super().__init__()
        self.layer = nn.Sequential(
            layer.Flatten(),
            layer.Linear(28 * 28, 10, bias=False),
            # ATan is the surrogate gradient used to backpropagate through spikes.
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
        )

    def forward(self, x: torch.Tensor):
        """Run one simulation time step and return the output of the 10 LIF neurons."""
        return self.layer(x)


def _frame_to_dataset(frame):
    """Convert an MNIST CSV DataFrame into a TensorDataset of (image, label) pairs.

    Column 0 holds the label and the remaining columns hold the pixels. Pixels
    are rescaled to [0, 1] and reshaped to (N, 28, 28).
    """
    labels = torch.tensor(frame.iloc[:, 0].values, dtype=torch.long)
    pixels = torch.tensor(frame.iloc[:, 1:].values, dtype=torch.float32)
    # PoissonEncoder uses each input value as a per-step firing probability, so
    # raw intensities (assumed 0-255, as in the usual MNIST CSV export -- confirm
    # for your files) must be scaled into [0, 1]; otherwise every non-zero
    # pixel fires on every time step and the encoding carries no intensity.
    pixels = pixels / 255.0
    return TensorDataset(pixels.reshape(-1, 28, 28), labels)


def _mean_firing_rate(net, encoder, img, T):
    """Present ``img`` to ``net`` for ``T`` steps, drawing a fresh Poisson
    encoding each step, and return the averaged output (the firing rate)."""
    out_fr = 0.
    for _ in range(T):
        out_fr += net(encoder(img))
    return out_fr / T


def main(argv=None):
    """Train the FC-LIF network on MNIST and report train/test metrics per epoch.

    Hyper-parameters are read from the command line. Every default reproduces
    the previously hard-coded setting: T=10 time steps, batch size 2,
    10 epochs, SGD with lr=1e-3, tau=2.0, running on CPU.

    Args:
        argv: optional argument list to parse instead of ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='LIF MNIST Training')
    parser.add_argument('-T', default=10, type=int, help='simulating time-steps')
    parser.add_argument('-device', default='cpu', help='device, e.g. cpu or cuda:0')
    parser.add_argument('-b', default=2, type=int, help='batch size')
    parser.add_argument('-epochs', default=10, type=int, help='number of total epochs to run')
    parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron')
    # parse_known_args tolerates foreign arguments, such as the '-f <file>'
    # that Jupyter passes to the kernel, instead of exiting.
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        print(f'Ignoring unrecognized arguments: {unknown}')

    net = SNN(tau=args.tau).to(args.device)
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr)
    encoder = encoding.PoissonEncoder()

    # Reuse the DataFrames already loaded at module level instead of parsing
    # the (large) CSV files a second time.
    train_data_loader = DataLoader(
        _frame_to_dataset(mnist_data_train), batch_size=args.b, shuffle=True)
    test_data_loader = DataLoader(
        _frame_to_dataset(mnist_data_test), batch_size=args.b, shuffle=False)

    for epoch in range(args.epochs):
        start_time = time.time()

        net.train()
        train_loss, train_acc, train_samples = 0., 0., 0
        for img, label in train_data_loader:
            img, label = img.to(args.device), label.to(args.device)
            label_onehot = F.one_hot(label, 10).float()

            optimizer.zero_grad()
            out_fr = _mean_firing_rate(net, encoder, img, args.T)
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            # The predicted class is the output neuron with the highest firing rate.
            train_acc += (out_fr.argmax(1) == label).float().sum().item()
            # LIF neurons keep their membrane state between calls, so it must be
            # cleared after every batch.
            functional.reset_net(net)

        net.eval()
        test_loss, test_acc, test_samples = 0., 0., 0
        with torch.no_grad():
            for img, label in test_data_loader:
                img, label = img.to(args.device), label.to(args.device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = _mean_firing_rate(net, encoder, img, args.T)
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)

        # max(..., 1) guards against an empty split.
        train_samples = max(train_samples, 1)
        test_samples = max(test_samples, 1)
        print(
            f'epoch = {epoch}, '
            f'train_loss = {train_loss / train_samples:.4f}, '
            f'train_acc = {train_acc / train_samples:.4f}, '
            f'test_loss = {test_loss / test_samples:.4f}, '
            f'test_acc = {test_acc / test_samples:.4f}, '
            f'time = {time.time() - start_time:.1f}s'
        )


if __name__ == '__main__':
    main()
# 2: 附件
# SpikingJelly 实现 MNIST 数据集分类的完整代码:
# https://spikingjelly.readthedocs.io/zh-cn/0.0.0.0.14/activation_based/lif_fc_mnist.html