基于pytorch_lightning测试resnet18不同激活方式在CIFAR10数据集上的精度

基于pytorch_lightning测试resnet18不同激活方式在CIFAR10数据集上的精度

  • 一.曲线
    • 1.train_acc
    • 2.val_acc
    • 3.train_loss
    • 4.lr
  • 二.代码

本文介绍了如何基于pytorch_lightning测试resnet18不同激活方式在CIFAR10数据集上的精度
特别说明:
1.NoActive:没有任何激活函数
2.SparseActivation:只保留topk的激活,其余清零,topk通过训练得到[初衷是想让激活变得稀疏]
3.SelectiveActive:通过训练得到使用的激活函数
可参考的代码片段
1.pytorch_lightning 如何使用
2.pytorch如何替换激活函数
3.如何对自定义权值做衰减

一.曲线

1.train_acc

在这里插入图片描述

2.val_acc

在这里插入图片描述

3.train_loss

在这里插入图片描述

4.lr

在这里插入图片描述

二.代码

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import os
import numpy as np
from pytorch_lightning.loggers import TensorBoardLogger#torch.set_float32_matmul_precision('medium')class ResidualBlock(nn.Module):def __init__(self, inchannel, outchannel, stride=1):super(ResidualBlock, self).__init__()self.left = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(outchannel))self.shortcut = nn.Sequential()if stride != 1 or inchannel != outchannel:self.shortcut = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(outchannel))self.act=nn.ReLU()def forward(self, x):out = self.left(x)out += self.shortcut(x)out = self.act(out)return outclass ResNet(nn.Module):def __init__(self, ResidualBlock, num_classes=10):super(ResNet, self).__init__()self.inchannel = 64self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(),)self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)self.fc = nn.Linear(512, num_classes)self.dropout=nn.Dropout(0.5)def make_layer(self, block, channels, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.inchannel, channels, stride))self.inchannel = channelsreturn nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = F.avg_pool2d(out, 4)out = out.view(out.size(0), -1)out = self.dropout(out)out = self.fc(out)return outclass SparseActivation(nn.Module):act_array=[x.cuda() for x in [nn.ReLU(),nn.ReLU6(),nn.Sigmoid(),nn.Hardsigmoid(),nn.GELU(),nn.SiLU(),nn.Mish(),nn.LeakyReLU(),nn.Hardswish(),nn.PReLU(),nn.SELU(),nn.Softplus(),nn.Softsign()]]def __init__(self,args):super(SparseActivation, self).__init__()self.input_weights = nn.Parameter(torch.randn(1)).cuda()self.act=SparseActivation.act_arrayself.act_weights = nn.Parameter(torch.randn(len(self.act))).cuda()self.args=argsdef forward(self, x):        index=self.args.actif index>=0:index=index-1if index==-1:prob=F.softmax(self.act_weights,dim=0)_, index = torch.topk(prob, 1, dim=0)x=self.act[index](x)if self.args.sparse==0:return xinput=x.flatten(1)input_weights = torch.sigmoid(self.input_weights)        topk = input.size(1)*input_weightstopk=topk.int()topk_vals, topk_indices = torch.topk(input, topk, dim=1)mask = torch.zeros_like(input).scatter(1, topk_indices, topk_vals)return mask.reshape(x.shape)class LitNet(pl.LightningModule):def __init__(self, args):super(LitNet, self).__init__()self.save_hyperparameters()self.args = argsself.resnet18 = ResNet(ResidualBlock)self.criterion = nn.CrossEntropyLoss()self.ws=[]self.replace_activation(self.resnet18,nn.ReLU, SparseActivation,self.ws)    def replace_activation(self,module, old_activation, new_activation,ws):for name, child in module.named_children():if isinstance(child, old_activation):op=new_activation(self.args)ws.append(op.input_weights)setattr(module, name,op)else:self.replace_activation(child, old_activation, new_activation,ws)        def forward(self, x):return self.resnet18(x)def on_train_epoch_start(self):self.train_total_loss=[]self.train_total_acc=[]def on_train_epoch_end(self):self.log('epoch_train_loss', np.mean(self.train_total_loss))self.log('epoch_train_acc', np.mean(self.train_total_acc)) self.log("lr",self.optimizer.state_dict()['param_groups'][0]['lr'])def training_step(self, batch, batch_idx):data, target = batchoutput = self(data)loss = self.criterion(output, target)l2_reg = torch.tensor(0.).cuda()l2_lambda=0.001for param in self.ws:l2_reg += torch.norm(param+4)                    loss += l2_lambda * l2_reg        self.log('iter_train_loss', loss)_, predicted = torch.max(output.data, 1)correct = (predicted == target).sum()acc = 100. * correct / target.size(0)      self.train_total_loss.append(loss.item())self.train_total_acc.append(acc.item())return loss       def on_validation_epoch_start(self):self.val_total_loss=[]self.val_total_acc=[]def on_validation_epoch_end(self):self.log('epoch_val_loss', np.mean(self.val_total_loss))self.log('epoch_val_acc', np.mean(self.val_total_acc))def validation_step(self, batch, batch_idx):data, target = batchoutput = self(data)_, predicted = torch.max(output.data, 1)correct = (predicted == target).sum()acc = 100. * correct / target.size(0)loss = self.criterion(output, target)        self.val_total_loss.append(loss.item())self.val_total_acc.append(acc.item())def test_step(self, batch, batch_idx):data, target = batchoutput = self(data)loss = self.criterion(output, target)self.log('test_loss', loss)return lossdef configure_optimizers(self):self.optimizer = optim.SGD(self.parameters(), lr=self.args.lr, momentum=0.9,weight_decay=5e-4)self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,step_size=10,gamma = 0.8)            return [self.optimizer],[self.scheduler]class CIFAR10DataModule(pl.LightningDataModule):def __init__(self, batch_size):super().__init__()self.batch_size = batch_sizedef setup(self, stage=None):transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])self.train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)self.test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)def train_dataloader(self):return DataLoader(self.train, batch_size=self.batch_size,shuffle=True,num_workers=2,persistent_workers=True)def val_dataloader(self):return DataLoader(self.test, batch_size=self.batch_size,shuffle=False,num_workers=2,persistent_workers=True)def test_dataloader(self):return DataLoader(self.test, batch_size=self.batch_size)def main():parser = argparse.ArgumentParser(description='PyTorch MNIST Example')parser.add_argument('--batch-size', type=int, default=128, metavar='N',help='input batch size for training (default: 64)')parser.add_argument('--epochs', type=int, default=100, metavar='N',help='number of epochs to train (default: 14)')parser.add_argument('--lr', type=float, default=0.01, metavar='LR',help='learning rate (default: 1.0)')parser.add_argument('--act', type=int, default=-1,help='learning rate (default: 1.0)')parser.add_argument('--sparse', type=int, default=0,help='learning rate (default: 1.0)')args = parser.parse_args()cifar10_data = CIFAR10DataModule(batch_size=args.batch_size)log_dir = "lightning_logs"args.sparse=0   #不开启稀疏args.act=0      #自适应激活model = LitNet(args)logger = TensorBoardLogger(save_dir=log_dir, name="SelectiveActive")    trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")trainer.fit(model, cifar10_data)    args.sparse=0     #不开启稀疏args.act=-1       #不用激活model = LitNet(args)    cifar10_data = CIFAR10DataModule(batch_size=args.batch_size)logger = TensorBoardLogger(save_dir=log_dir, name="NoActive")    trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")trainer.fit(model, cifar10_data)  args.sparse=1args.act=-1       #不用激活,开启稀疏model = LitNet(args)       logger = TensorBoardLogger(save_dir=log_dir, name="SparseActivation")    trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")trainer.fit(model, cifar10_data)  for idx,act_name in enumerate(SparseActivation.act_array):name=act_name.__class__.__name__print(name)args.act=idx+1args.sparse=0model = LitNet(args)     logger = TensorBoardLogger(save_dir=log_dir, name=name)    trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")trainer.fit(model, cifar10_data)if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/850982.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

调研管理系统的设计

管理员账户功能包括:系统首页,个人中心,管理员管理,基础数据管理,教师类型管理,课程类型管理,公告类型管理 前台账户功能包括:系统首页,个人中心,论坛&#…

腾讯云和windows11安装frp,实现内网穿透

一、内网穿透目的 实现公网上,访问到windows上启动的web服务 二、内网穿透的环境准备 公网服务器、windows11的电脑、frp软件(需要准备两个软件,一个是安装到公网服务器上的,一个是安装到windows上的) frp下载地址下载版本 1.此版本(老版…

论文阅读:Indoor Scene Layout Estimation from a Single Image

项目地址:https://github.com/leVirve/lsun-room/tree/master 发表时间:2018 icpr 场景理解,在现实交互的众多方面中,因其在增强现实(AR)等应用中的相关性而得到广泛关注。场景理解可以分为几个子任务&…

C++ 内联函数 auto关键字

内联函数 用inline修饰的函数会成为内联函数,内联函数会在编译的阶段在调用函数的位置进行展开,不会涉及建立栈帧以提高效率,同时每一次的函数调用都会展开整个函数导致内存消耗的增加,是以空间换时间,所以内联函数比…

SpringSecurity入门(二)

8、获取用户认证信息 三种策略模式,调整通过修改VM options // 如果没有设置自定义的策略,就采用MODE_THREADLOCAL模式 public static final String MODE_THREADLOCAL "MODE_THREADLOCAL"; // 采用InheritableThreadLocal,它是Th…

最新下载:Navicat for MySQL 11软件安装视频教程

软件简介: Navicat for MySQL 是一款强大的 MySQL 数据库管理和开发工具,它为专业开发者提供了一套强大的足够尖端的工具,但对于新用户仍然易于学习。Navicat For Mysql中文网站:http://www.formysql.com/ Navicat for MySQL 基于…

NLP实战入门——文本分类任务(TextRNN,TextCNN,TextRNN_Att,TextRCNN,FastText,DPCNN,BERT,ERNIE)

本文参考自https://github.com/649453932/Chinese-Text-Classification-Pytorch?tabreadme-ov-file,https://github.com/leerumor/nlp_tutorial?tabreadme-ov-file,https://zhuanlan.zhihu.com/p/73176084,是为了进行NLP的一些典型模型的总…

如何远程桌面连接?

远程桌面连接是一种方便快捷的方式,可以帮助用户在不同地区的设备之间实现信息的远程通信。我们将介绍一种名为【天联】的组网产品,它可以帮助用户轻松实现远程桌面连接。 【天联】组网是一款异地组网内网穿透产品,由北京金万维科技有限公司…

绿联Nas docker 中 redis 老访问失败的排查

部署了一些服务,老隔3-5 天其他服务就联不上 redis 了,未确定具体原因,只记录观察到的现象 宿主机访问 只有 ipv6 绑定了,ipv4 绑定挂掉了 其他容器访问 也无法访问成功 当重启容器后: 一切又恢复正常。 可能的解…

MATLAB | 透明度渐变颜色条

hey 各位好久不见,今天提供一段有趣的小代码,之前刷到公众号闻道研学的一篇推送MATLAB绘图技巧 | 设置颜色条的透明度(附完整代码)(https://mp.weixin.qq.com/s/bVx8AVL9jGlatja51v4H0A),文章希…

机器学习周记(第四十二周:AT-LSTM)2024.6.3~2024.6.9

目录 摘要Abstract一、文献阅读1. 题目2. abstract3. 网络架构3.1 LSTM3.2 注意力机制概述3.3 AT-LSTM3.4 数据预处理 4. 文献解读4.1 Introduction4.2 创新点4.3 实验过程4.3.1 训练参数4.3.2 数据集4.3.3 实验设置4.3.4 实验结果 5. 基于pytorch的transformer 摘要 本周阅读…

免费,C++蓝桥杯等级考试真题--第11级(含答案解析和代码)

C蓝桥杯等级考试真题--第11级 答案:D 解析: A. a b; b a; 这种方式会导致a和b最终都等于b原来的值,因为a的原始值在被b覆盖前没有保存。 B. swap(a,b); 如果没有自定义swap函数或者没有包含相应的库,这个选项会编…

【C++题解】1389 - 数据分析

问题:1389 - 数据分析 类型:简单循环 题目描述: 该方法的操作方式为,如果要传递 2 个数字信息给友军,会直接传递给友军一个整数 n(n 是一个 10 位以内的整数),该整数的长度代表要传…

汇编语言LDS指令

在8086架构的实模式下,LDS指令(Load Pointer Using DS)用于从内存中加载一个32位的指针到指定寄存器和DS寄存器。我们来详细解释一下这条指令为什么会修改DS段寄存器。 LDS指令的功能 LDS指令格式如下: LDS destination, sourc…

程序猿大战Python——运算符

常见的运算符 目标:了解Python中常见的运算符有哪些? 运算符是用于执行程序代码的操作运算。常见的运算符有: (1)算术运算符:、-、*、/、//、% 、**; (2)赋值运算符&am…

macOS - 终端快捷键

本文转自 Mac 上“终端”中的键盘快捷键 https://support.apple.com/zh-cn/guide/terminal/trmlshtcts/mac 以下基于系统版本 macOS Sonoma 14 文章目录 Mac 上“终端”中的键盘快捷键1、使用“终端”窗口和标签页2、编辑命令行3、在“终端”窗口中选择和查找文本4、使用标记和…

【Uniapp】uniapp微信小程序定义图片地址全局变量

错误写法: main.js Vue.prototype.$imgUrl 图片地址这么写之后 就发现压根不起作用;获取到的是undefined 正确写法: 返回函数,后面可以拼上OSS图片完整路径 Vue.prototype.$imgUrl (url) > {return ("https://地址…

Android——热点开关(优化中)

SoftAP打开与关闭 目录 1.三个名词的解释以及关系 Tethering——网络共享,WiFi热点、蓝牙、USB SoftAp——热点(无线接入点),临时接入点 Hostapd——Hostapd是用于Linux系统的软件,,支持多种无线认证和加密协议,将任…

LabVIEW进行图像拼接的实现方法与优化

在工业检测和科研应用中,对于大尺寸物体的拍摄需要通过多次拍摄后进行图像拼接。LabVIEW 作为强大的图形化编程工具,能够实现图像拼接处理。本文将详细介绍LabVIEW进行图像拼接的实现方法、注意事项和提高效率的策略。 图像拼接的实现方法 1. 图像采集…

c++引用的本质(反汇编角度分析)

目录 一、引用基础理论 二、 引用的本质 三、从反汇编角度进行分析 1.变量赋值 2.引用和指针初始化 3.通过引用和指针赋值 4.eaxd的作用 一、引用基础理论 在c中我们都知道,引用(&)就是变量的一个别名,它允许我们为已存…