Classifying MNIST Handwritten Digits with EfficientNetV2

Study notes for the "pytorch深度学习项目实战100例" series (100 hands-on PyTorch deep learning projects).

My environment:

Free-tier champion: Google Colab

Other environments would work too, but asking a beginner to set up an environment from scratch is a waste of time.

Paper overview

EfficientNetV2 was published by Google Research, Brain Team at ICML 2021. It combines neural architecture search (NAS) with model scaling to jointly optimize training speed and parameter efficiency, and it enriches the search space with new operations such as Fused-MBConv. EfficientNetV2 models train much faster than EfficientNetV1 while being up to 6.8x smaller.

Understanding and improving the training efficiency of EfficientNetV1

  1. Training with very large image sizes is slow

  2. Depthwise convolutions are slow in early layers but effective in later stages
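
To make the second point concrete, here is a minimal micro-benchmark sketch (my own addition; the tensor and layer sizes are illustrative assumptions, not values from the paper) comparing a depthwise 3×3 + pointwise 1×1 pair against a single fused 3×3 convolution on an early-stage feature map:

import time
import torch
from torch import nn

x = torch.randn(8, 24, 112, 112)  # early-stage tensor: large spatial size, few channels

depthwise_pair = nn.Sequential(
    nn.Conv2d(24, 24, 3, padding=1, groups=24, bias=False),  # depthwise 3x3
    nn.Conv2d(24, 96, 1, bias=False),                        # pointwise expansion
)
fused = nn.Conv2d(24, 96, 3, padding=1, bias=False)          # single fused 3x3

for name, layer in [("depthwise + pointwise", depthwise_pair), ("fused 3x3", fused)]:
    with torch.no_grad():
        layer(x)  # warm-up
        start = time.perf_counter()
        for _ in range(10):
            layer(x)
        print(f"{name}: {(time.perf_counter() - start) / 10 * 1e3:.1f} ms/iter")

Depthwise convolutions have far fewer FLOPs, but they often cannot saturate modern accelerators, which is the paper's motivation for fusing them in the early stages; actual timings depend heavily on the backend.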

Structure of MBConv and Fused-MBConv

The authors progressively replace the original MBConv blocks in EfficientNet-B4 with Fused-MBConv.

However, not every replacement improves accuracy or speed.

Overall architecture
Compared with EfficientNetV1, EfficientNetV2 has several major differences (summarized in the sketch after this list):

  • EfficientNetV2 extensively uses both MBConv and the newly added Fused-MBConv in its early layers.
  • EfficientNetV2 prefers smaller expansion ratios for MBConv, since smaller expansion ratios tend to incur less memory-access overhead.
  • EfficientNetV2 prefers a smaller 3×3 kernel size, but adds more layers to compensate for the reduced receptive field.
  • EfficientNetV2 entirely removes the last stride-1 stage of the original EfficientNet, likely because of its large parameter count and memory-access overhead.
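
These design choices can be read directly off the EfficientNetV2-S stage table used later in this post. The short sketch below just pretty-prints that configuration (values copied from the efficientnetv2_s / get_efficientnet_v2_structure definitions that follow):

stages = [
    # (repeats, kernel, stride, expand, in_ch, out_ch, fused, se)
    (2, 3, 1, 1, 24, 24, True, False),
    (4, 3, 2, 4, 24, 48, True, False),
    (4, 3, 2, 4, 48, 64, True, False),
    (6, 3, 2, 4, 64, 128, False, True),
    (9, 3, 1, 6, 128, 160, False, True),
    (15, 3, 2, 6, 160, 256, False, True),
]
for i, (n, k, s, e, c_in, c_out, fused, se) in enumerate(stages):
    kind = "Fused-MBConv" if fused else "MBConv+SE"
    print(f"stage {i}: {n:2d}x {kind:13s} k{k} s{s} expand={e} {c_in}->{c_out}")

Fused-MBConv with small expansion ratios sits in the early stages, MBConv with SE in the later ones, and every stage sticks to 3×3 kernels.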

Training
(1) The maximum inference image size is capped at 480, since very large images incur expensive memory and training-speed overhead; (2) as a heuristic, more layers are gradually added to the later stages (e.g., stages 5 and 6 in Table 4 of the paper) to increase network capacity without adding much runtime overhead.
The model performs best with weak augmentation when the image size is small, but benefits from stronger augmentation at larger image sizes. Progressive learning therefore starts with a small image size and weak regularization (epoch = 1) and gradually increases the training difficulty: larger image sizes together with a larger dropout rate, RandAugment magnitude, and mixup ratio.
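
As a rough sketch of progressive learning (my own illustration; the linear interpolation and the start/end values are assumptions, not the paper's official schedule), each regularization knob can be ramped together with the image size:

def progressive_settings(epoch, total_epochs,
                         size=(128, 300), dropout=(0.1, 0.3),
                         randaug=(5, 15), mixup=(0.0, 0.2)):
    # Linearly interpolate every knob from its initial to its final value.
    t = epoch / max(total_epochs - 1, 1)  # 0 at the start of training, 1 at the end
    interp = lambda lo_hi: lo_hi[0] + t * (lo_hi[1] - lo_hi[0])
    return {
        "image_size": int(interp(size)),
        "dropout": round(interp(dropout), 3),
        "randaug_magnitude": round(interp(randaug), 1),
        "mixup_alpha": round(interp(mixup), 3),
    }

for epoch in (0, 5, 9):
    print(epoch, progressive_settings(epoch, total_epochs=10))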


Because 阿光's original project uses MindSpore, which won't run on Colab TT, the MindSpore version is kept below for reference:

from collections import OrderedDict
from functools import partial
import argparse
import os

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
import mindspore.numpy as mnp
import mindspore.ops as ops
from mindspore import Model
from mindspore import context
from mindspore import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig


def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    uniform = ops.UniformReal(seed=1)  # uniform noise in [0, 1); the original StandardNormal would not binarize correctly
    random_tensor = keep_prob + uniform(shape)
    random_tensor = mnp.floor(random_tensor)  # binarize
    output = mnp.divide(x, keep_prob) * random_tensor
    return output


class DropPath(nn.Cell):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def construct(self, x):
        return drop_path(x, self.drop_prob, self.training)


class ConvBNAct(nn.Cell):
    def __init__(self,
                 in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer=None,
                 activation_layer=None):
        super(ConvBNAct, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.ReLU  # stand-in for SiLU/Swish
        # MindSpore's Conv2d defaults to pad_mode='same', so no explicit padding is needed
        self.conv = nn.Conv2d(in_channels=in_planes,
                              out_channels=out_planes,
                              kernel_size=kernel_size,
                              stride=stride,
                              group=groups,
                              has_bias=False)
        self.bn = norm_layer(out_planes)
        self.act = activation_layer()

    def construct(self, x):
        result = self.conv(x)
        result = self.bn(result)
        result = self.act(result)
        return result


class SqueezeExcite(nn.Cell):
    def __init__(self,
                 input_c: int,   # block input channels
                 expand_c: int,  # block expanded channels
                 se_ratio: float = 0.25):
        super(SqueezeExcite, self).__init__()
        squeeze_c = int(input_c * se_ratio)
        self.conv_reduce = nn.Conv2d(expand_c, squeeze_c, 1, pad_mode='valid')
        self.act1 = nn.ReLU()  # stand-in for SiLU/Swish
        self.conv_expand = nn.Conv2d(squeeze_c, expand_c, 1, pad_mode='valid')
        self.act2 = nn.Sigmoid()

    def construct(self, x):
        scale = x.mean((2, 3), keep_dims=True)  # global average pooling
        scale = self.conv_reduce(scale)
        scale = self.act1(scale)
        scale = self.conv_expand(scale)
        scale = self.act2(scale)
        return scale * x


class MBConv(nn.Cell):
    def __init__(self,
                 kernel_size: int,
                 input_c: int,
                 out_c: int,
                 expand_ratio: int,
                 stride: int,
                 se_ratio: float,
                 drop_rate: float,
                 norm_layer):
        super(MBConv, self).__init__()
        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        self.has_shortcut = (stride == 1 and input_c == out_c)
        activation_layer = nn.ReLU  # stand-in for SiLU/Swish
        expanded_c = input_c * expand_ratio

        # In EfficientNetV2, MBConv never uses expansion = 1, so expand_conv always exists
        assert expand_ratio != 1
        # Point-wise expansion
        self.expand_conv = ConvBNAct(input_c,
                                     expanded_c,
                                     kernel_size=1,
                                     norm_layer=norm_layer,
                                     activation_layer=activation_layer)
        # Depth-wise convolution
        self.dwconv = ConvBNAct(expanded_c,
                                expanded_c,
                                kernel_size=kernel_size,
                                stride=stride,
                                groups=expanded_c,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer)
        self.se = SqueezeExcite(input_c, expanded_c, se_ratio) if se_ratio > 0 else None
        # Point-wise linear projection; note there is no activation here, so pass Identity
        self.project_conv = ConvBNAct(expanded_c,
                                      out_planes=out_c,
                                      kernel_size=1,
                                      norm_layer=norm_layer,
                                      activation_layer=ops.Identity)
        self.out_channels = out_c
        # DropPath is only applied on the shortcut branch
        self.drop_rate = drop_rate
        if self.has_shortcut and drop_rate > 0:
            self.dropout = DropPath(drop_rate)

    def construct(self, x):
        result = self.expand_conv(x)
        result = self.dwconv(result)
        if self.se is not None:
            result = self.se(result)
        result = self.project_conv(result)
        if self.has_shortcut:
            if self.drop_rate > 0:
                result = self.dropout(result)
            result += x
        return result


class FusedMBConv(nn.Cell):
    def __init__(self,
                 kernel_size: int,
                 input_c: int,
                 out_c: int,
                 expand_ratio: int,
                 stride: int,
                 se_ratio: float,
                 drop_rate: float,
                 norm_layer):
        super(FusedMBConv, self).__init__()
        assert stride in [1, 2]
        assert se_ratio == 0
        self.has_shortcut = stride == 1 and input_c == out_c
        self.drop_rate = drop_rate
        self.has_expansion = expand_ratio != 1
        activation_layer = nn.ReLU  # stand-in for SiLU/Swish
        expanded_c = input_c * expand_ratio

        # expand_conv only exists when expand_ratio != 1
        if self.has_expansion:
            # Expansion convolution
            self.expand_conv = ConvBNAct(input_c,
                                         expanded_c,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         norm_layer=norm_layer,
                                         activation_layer=activation_layer)
            # note: no activation on the projection
            self.project_conv = ConvBNAct(expanded_c,
                                          out_c,
                                          kernel_size=1,
                                          norm_layer=norm_layer,
                                          activation_layer=ops.Identity)
        else:
            # expand_ratio == 1: a single fused convolution (note: with activation)
            self.project_conv = ConvBNAct(input_c,
                                          out_c,
                                          kernel_size=kernel_size,
                                          stride=stride,
                                          norm_layer=norm_layer,
                                          activation_layer=activation_layer)
        self.out_channels = out_c
        # DropPath is only applied on the shortcut branch
        self.drop_rate = drop_rate
        if self.has_shortcut and drop_rate > 0:
            self.dropout = DropPath(drop_rate)

    def construct(self, x):
        if self.has_expansion:
            result = self.expand_conv(x)
            result = self.project_conv(result)
        else:
            result = self.project_conv(x)
        if self.has_shortcut:
            if self.drop_rate > 0:
                result = self.dropout(result)
            result += x
        return result


class EfficientNetV2(nn.Cell):
    def __init__(self,
                 model_cnf: list,
                 num_classes: int = 10,
                 num_features: int = 128,
                 dropout_rate: float = 0.2,
                 drop_connect_rate: float = 0.2):
        super(EfficientNetV2, self).__init__()
        for cnf in model_cnf:
            assert len(cnf) == 8
        norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)
        stem_filter_num = model_cnf[0][4]
        # MNIST images have a single channel, hence in_channels=1
        self.stem = ConvBNAct(1,
                              stem_filter_num,
                              kernel_size=3,
                              stride=2,
                              norm_layer=norm_layer)
        total_blocks = sum([i[0] for i in model_cnf])
        block_id = 0
        blocks = []
        for cnf in model_cnf:
            repeats = cnf[0]
            op = FusedMBConv if cnf[-2] == 0 else MBConv
            for i in range(repeats):
                blocks.append(op(kernel_size=cnf[1],
                                 input_c=cnf[4] if i == 0 else cnf[5],
                                 out_c=cnf[5],
                                 expand_ratio=cnf[3],
                                 stride=cnf[2] if i == 0 else 1,
                                 se_ratio=cnf[-1],
                                 drop_rate=drop_connect_rate * block_id / total_blocks,
                                 norm_layer=norm_layer))
                block_id += 1
        self.blocks = nn.SequentialCell(*blocks)
        head_input_c = model_cnf[-1][-3]
        head = OrderedDict()
        head.update({"project_conv": ConvBNAct(head_input_c,
                                               num_features,
                                               kernel_size=1,
                                               norm_layer=norm_layer)})
        head.update({"flatten": nn.Flatten()})
        if dropout_rate > 0:
            # keep_prob is the probability of KEEPING a unit, so convert from the drop rate
            head.update({"dropout": nn.Dropout(keep_prob=1 - dropout_rate)})
        head.update({"classifier": nn.Dense(num_features, num_classes)})
        self.head = nn.SequentialCell(head)
        # for the 300x384 inputs used below, the final feature map is 10x12
        self.avgpool = nn.AvgPool2d(kernel_size=(10, 12), pad_mode='valid')

    def construct(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = self.avgpool(x)
        x = self.head(x)
        return x


def efficientnetv2_s(num_classes: int = 10):
    # repeats, kernel_size, stride, expand_ratio, in_channels, out_channels, operator (0: FusedMBConv, 1: MBConv), se_ratio
    model_config = [[2, 3, 1, 1, 24, 24, 0, 0],
                    [4, 3, 2, 4, 24, 48, 0, 0],
                    [4, 3, 2, 4, 48, 64, 0, 0],
                    [6, 3, 2, 4, 64, 128, 1, 0.25],
                    [9, 3, 1, 6, 128, 160, 1, 0.25],
                    [15, 3, 2, 6, 160, 256, 1, 0.25]]
    model = EfficientNetV2(model_cnf=model_config,
                           num_classes=num_classes,
                           dropout_rate=0.2)
    return model


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    # Load the MNIST dataset
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 300, 384  # matches the (10, 12) average-pooling kernel above

    # Define the map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)
    type_cast_op_image = C.TypeCast(mstype.float32)

    # Apply the operations to the dataset
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=type_cast_op_image, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Shuffle and batch
    buffer_size = 1000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    return mnist_ds


def train_net(args, model, batch_size, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    ds_train = create_dataset(os.path.join(data_path, 'train'), batch_size, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(10)], dataset_sink_mode=sink_mode)


def test_net(network, model, data_path):
    ds_eval = create_dataset(os.path.join(data_path, 'test'))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print('{}'.format(acc))


net = efficientnetv2_s(num_classes=10)

parser = argparse.ArgumentParser(description='MindSpore Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

# Define the loss function
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
# Checkpoint saving configuration
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
# Apply the checkpoint configuration
ckpoint = ModelCheckpoint(prefix='checkpoint_efficientvs', config=config_ck)

train_epoch = 10
mnist_path = './datasets/MNIST_Data'
dataset_size = 1
batch_size = 16
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
train_net(args, model, batch_size, train_epoch, mnist_path, dataset_size, ckpoint, False)
test_net(net, model, mnist_path)

So I rewrote a PyTorch version to retrain the model myself.

Network architecture

import copy
from collections import OrderedDict
from functools import partial

import torch
from torch import nn


def get_efficientnet_v2_structure(model_name):
    if 'efficientnet_v2_s' in model_name:
        return [
            # e k  s  in  out xN  se   fused
            (1, 3, 1, 24, 24, 2, False, True),
            (4, 3, 2, 24, 48, 4, False, True),
            (4, 3, 2, 48, 64, 4, False, True),
            (4, 3, 2, 64, 128, 6, True, False),
            (6, 3, 1, 128, 160, 9, True, False),
            (6, 3, 2, 160, 256, 15, True, False),
        ]
    elif 'efficientnet_v2_m' in model_name:
        return [
            # e k  s  in  out xN  se   fused
            (1, 3, 1, 24, 24, 3, False, True),
            (4, 3, 2, 24, 48, 5, False, True),
            (4, 3, 2, 48, 80, 5, False, True),
            (4, 3, 2, 80, 160, 7, True, False),
            (6, 3, 1, 160, 176, 14, True, False),
            (6, 3, 2, 176, 304, 18, True, False),
            (6, 3, 1, 304, 512, 5, True, False),
        ]
    elif 'efficientnet_v2_l' in model_name:
        return [
            # e k  s  in  out xN  se   fused
            (1, 3, 1, 32, 32, 4, False, True),
            (4, 3, 2, 32, 64, 7, False, True),
            (4, 3, 2, 64, 96, 7, False, True),
            (4, 3, 2, 96, 192, 10, True, False),
            (6, 3, 1, 192, 224, 19, True, False),
            (6, 3, 2, 224, 384, 25, True, False),
            (6, 3, 1, 384, 640, 7, True, False),
        ]
    elif 'efficientnet_v2_xl' in model_name:
        return [
            # e k  s  in  out xN  se   fused
            (1, 3, 1, 32, 32, 4, False, True),
            (4, 3, 2, 32, 64, 8, False, True),
            (4, 3, 2, 64, 96, 8, False, True),
            (4, 3, 2, 96, 192, 16, True, False),
            (6, 3, 1, 192, 256, 24, True, False),
            (6, 3, 2, 256, 512, 32, True, False),
            (6, 3, 1, 512, 640, 8, True, False),
        ]


class ConvBNAct(nn.Sequential):
    """Convolution-Normalization-Activation Module"""
    def __init__(self, in_channel, out_channel, kernel_size, stride, groups, norm_layer, act, conv_layer=nn.Conv2d):
        super(ConvBNAct, self).__init__(
            conv_layer(in_channel, out_channel, kernel_size, stride=stride,
                       padding=(kernel_size - 1) // 2, groups=groups, bias=False),
            norm_layer(out_channel),
            act(),
        )


class SEUnit(nn.Module):
    """Squeeze-Excitation Unit

    paper: https://openaccess.thecvf.com/content_cvpr_2018/html/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper
    """
    def __init__(self, in_channel, reduction_ratio=4, act1=partial(nn.SiLU, inplace=True), act2=nn.Sigmoid):
        super(SEUnit, self).__init__()
        hidden_dim = in_channel // reduction_ratio
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Conv2d(in_channel, hidden_dim, (1, 1), bias=True)
        self.fc2 = nn.Conv2d(hidden_dim, in_channel, (1, 1), bias=True)
        self.act1 = act1()
        self.act2 = act2()

    def forward(self, x):
        return x * self.act2(self.fc2(self.act1(self.fc1(self.avg_pool(x)))))


class StochasticDepth(nn.Module):
    """StochasticDepth

    paper: https://link.springer.com/chapter/10.1007/978-3-319-46493-0_39

    :arg
        - prob: probability of dying
        - mode: "row" or "all". "row" means that each row survives with a different probability
    """
    def __init__(self, prob, mode):
        super(StochasticDepth, self).__init__()
        self.prob = prob
        self.survival = 1.0 - prob
        self.mode = mode

    def forward(self, x):
        if self.prob == 0.0 or not self.training:
            return x
        else:
            shape = [x.size(0)] + [1] * (x.ndim - 1) if self.mode == 'row' else [1]
            return x * torch.empty(shape).bernoulli_(self.survival).div_(self.survival).to(x.device)


class MBConvConfig:
    """EfficientNet building block configuration"""
    def __init__(self, expand_ratio: float, kernel: int, stride: int, in_ch: int, out_ch: int, layers: int,
                 use_se: bool, fused: bool, act=nn.SiLU, norm_layer=nn.BatchNorm2d):
        self.expand_ratio = expand_ratio
        self.kernel = kernel
        self.stride = stride
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.num_layers = layers
        self.act = act
        self.norm_layer = norm_layer
        self.use_se = use_se
        self.fused = fused

    @staticmethod
    def adjust_channels(channel, factor, divisible=8):
        new_channel = channel * factor
        divisible_channel = max(divisible, (int(new_channel + divisible / 2) // divisible) * divisible)
        divisible_channel += divisible if divisible_channel < 0.9 * new_channel else 0
        return divisible_channel


class MBConv(nn.Module):
    """EfficientNet main building block

    :arg
        - c: MBConvConfig instance
        - sd_prob: stochastic path probability
    """
    def __init__(self, c, sd_prob=0.0):
        super(MBConv, self).__init__()
        inter_channel = c.adjust_channels(c.in_ch, c.expand_ratio)
        block = []

        if c.expand_ratio == 1:
            block.append(('fused', ConvBNAct(c.in_ch, inter_channel, c.kernel, c.stride, 1, c.norm_layer, c.act)))
        elif c.fused:
            block.append(('fused', ConvBNAct(c.in_ch, inter_channel, c.kernel, c.stride, 1, c.norm_layer, c.act)))
            block.append(('fused_point_wise', ConvBNAct(inter_channel, c.out_ch, 1, 1, 1, c.norm_layer, nn.Identity)))
        else:
            block.append(('linear_bottleneck', ConvBNAct(c.in_ch, inter_channel, 1, 1, 1, c.norm_layer, c.act)))
            block.append(('depth_wise', ConvBNAct(inter_channel, inter_channel, c.kernel, c.stride, inter_channel, c.norm_layer, c.act)))
            block.append(('se', SEUnit(inter_channel, 4 * c.expand_ratio)))
            block.append(('point_wise', ConvBNAct(inter_channel, c.out_ch, 1, 1, 1, c.norm_layer, nn.Identity)))

        self.block = nn.Sequential(OrderedDict(block))
        self.use_skip_connection = c.stride == 1 and c.in_ch == c.out_ch
        self.stochastic_path = StochasticDepth(sd_prob, "row")

    def forward(self, x):
        out = self.block(x)
        if self.use_skip_connection:
            out = x + self.stochastic_path(out)
        return out


class EfficientNetV2(nn.Module):
    """Pytorch Implementation of EfficientNetV2

    paper: https://arxiv.org/abs/2104.00298

    - reference 1 (pytorch): https://github.com/d-li14/efficientnetv2.pytorch/blob/main/effnetv2.py
    - reference 2 (official): https://github.com/google/automl/blob/master/efficientnetv2/effnetv2_configs.py

    :arg
        - layer_infos: list of MBConvConfig
        - out_channels: bottleneck channel
        - nclass: number of classes
        - dropout: dropout probability before the classifier layer
        - stochastic_depth: stochastic depth probability
    """
    def __init__(self, layer_infos, out_channels=1280, nclass=0, dropout=0.2, stochastic_depth=0.0,
                 block=MBConv, act_layer=nn.SiLU, norm_layer=nn.BatchNorm2d):
        super(EfficientNetV2, self).__init__()
        self.layer_infos = layer_infos
        self.norm_layer = norm_layer
        self.act = act_layer

        self.in_channel = layer_infos[0].in_ch
        self.final_stage_channel = layer_infos[-1].out_ch
        self.out_channels = out_channels

        self.cur_block = 0
        self.num_block = sum(stage.num_layers for stage in layer_infos)
        self.stochastic_depth = stochastic_depth

        self.stem = ConvBNAct(3, self.in_channel, 3, 2, 1, self.norm_layer, self.act)
        self.blocks = nn.Sequential(*self.make_stages(layer_infos, block))
        self.head = nn.Sequential(OrderedDict([
            ('bottleneck', ConvBNAct(self.final_stage_channel, out_channels, 1, 1, 1, self.norm_layer, self.act)),
            ('avgpool', nn.AdaptiveAvgPool2d((1, 1))),
            ('flatten', nn.Flatten()),
            ('dropout', nn.Dropout(p=dropout, inplace=True)),
            ('classifier', nn.Linear(out_channels, nclass) if nclass else nn.Identity()),
        ]))

    def make_stages(self, layer_infos, block):
        return [layer for layer_info in layer_infos for layer in self.make_layers(copy.copy(layer_info), block)]

    def make_layers(self, layer_info, block):
        layers = []
        for i in range(layer_info.num_layers):
            layers.append(block(layer_info, sd_prob=self.get_sd_prob()))
            layer_info.in_ch = layer_info.out_ch
            layer_info.stride = 1
        return layers

    def get_sd_prob(self):
        sd_prob = self.stochastic_depth * (self.cur_block / self.num_block)
        self.cur_block += 1
        return sd_prob

    def forward(self, x):
        return self.head(self.blocks(self.stem(x)))

    def change_dropout_rate(self, p):
        self.head[-2] = nn.Dropout(p=p, inplace=True)


def efficientnet_v2_init(model):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.01)
            nn.init.zeros_(m.bias)


def get_efficientnet_v2(model_name, pretrained, nclass=0, dropout=0.1, stochastic_depth=0.2, **kwargs):
    residual_config = [MBConvConfig(*layer_config) for layer_config in get_efficientnet_v2_structure(model_name)]
    model = EfficientNetV2(residual_config, 1280, nclass, dropout=dropout,
                           stochastic_depth=stochastic_depth, block=MBConv, act_layer=nn.SiLU)
    efficientnet_v2_init(model)
    return model

Instantiating the EfficientNet_V2_S structure

# Build the EfficientNet V2 S configuration
model_name = 'efficientnet_v2_s'
pretrained = False  # Assuming custom implementation without pretrained weights
nclass = 10  # For MNIST
dropout = 0.1
stochastic_depth = 0.2

model = get_efficientnet_v2(model_name, pretrained, nclass, dropout, stochastic_depth)
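
A quick sanity check (my own addition, assuming the definitions above) pushes a dummy batch through the network to confirm the classifier emits 10 logits per image:

import torch

dummy = torch.randn(2, 3, 224, 224)  # MNIST images are resized to 224x224 with 3 channels below
model.eval()
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([2, 10])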

Loading the MNIST dataset via torchvision

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match EfficientNet's expected input
    transforms.Grayscale(num_output_channels=3),  # Convert to 3-channel RGB
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize for 3 channels
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
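
Optionally, pull one batch to confirm the pipeline produces what the network expects:

images, labels = next(iter(train_loader))
print(images.shape)                              # torch.Size([32, 3, 224, 224])
print(images.min().item(), images.max().item())  # roughly [-1, 1] after Normalize
print(labels[:8])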

Training and optimizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
epochs = 100
# Training loop
model.train()
for epoch in range(epochs):
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")
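
After training, it may be worth persisting the weights; the filename below is an arbitrary choice of mine, not from the original post:

torch.save(model.state_dict(), 'efficientnet_v2_s_mnist.pth')  # illustrative path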

Testing

def test(model, device, test_loader, criterion):
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # No need to track gradients for validation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            test_loss += criterion(outputs, labels).item()  # Sum up batch loss
            pred = outputs.argmax(dim=1, keepdim=True)  # Get the index of the max logit
            correct += pred.eq(labels.view_as(pred)).sum().item()  # Count correct predictions
    test_loss /= len(test_loader)  # Average loss per batch
    accuracy = 100.0 * correct / len(test_loader.dataset)
    print(f"Test loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)")
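
With the function completed, a single call runs the evaluation:

test(model, device, test_loader, criterion)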

The full code will be uploaded once it passes review. Remember to like and bookmark!

