以下代码是以CIFAR10这个10分类的图片数据集训练过程的完整的代码。
训练部分
train.py主要包含以下几个部件:
- 准备训练、测试数据集
- 用DateLoader加载两个数据集,要设置好batchsize
- 创建网络模型(具体模型在model.py中)
- 设置损失函数
- 设置优化器,其中要包含优化的参数和学习率
- 初始化一些参数,如训练测试的次数、以及训练的轮数epoch
- 以训练轮数为循环进入训练
- 从训练数据中加载数据,将数据(模型的输出和目标(标签))送进损失函数中计算损失
- 梯度清零,并且反向传播损失函数,用优化器进行参数更新,并累计训练步数。
- 在保证不调优的情况下看正确率(with)
从测试集中拿数据,一样的讨论算损失,但是要算正确率- 用tensorboard可是话训练的结果
关于imgs, targets =data这句代码中的targets解释
-
imgs
(Images): 这个变量通常包含一批图像数据。在计算机视觉任务中,这些图像是模型的输入,可以是任何形式的视觉数据,比如照片、视频帧、医学影像等。在训练过程中,这些图像通过神经网络进行前向传播以生成预测结果。 -
targets
(Targets): 这个变量包含与imgs
中每个图像对应的标签或目标。标签的具体形式取决于执行的任务:- 在分类任务中,
targets
可能是类别标签,例如识别图像中的对象(猫、狗、汽车等)。 - 在对象检测任务中,
targets
可能包括对象的边界框(bounding boxes)和类别。 - 在语义分割任务中,
targets
可能是每个像素的类别标签。 - 在回归任务中,
targets
可能是一些连续值,如在面部关键点检测中的坐标点。
- 在分类任务中,
在训练过程中targets用于损失函数(交叉熵损失、均方误差等),这是模型学习并优化其参数的基础。损失函数衡量了模型预测和真实目标之间的差异,训练目标是最小化差异。
关于optimizer.step()的解释
在机器学习中,这玩意是个关键操作,就是用来根性模型参数的。
优化器和梯度下降,常用的优化算法(SGD、Adam、RMSprop等)来调整网络参数(如权重和偏差),以最小化损失函数。这个过程被称为梯度下降。
训练过程中的步骤:
- Forward Pass:输入数据进行前向计算,生成预测。
- 计算损失函数,比较网络的预测和真实计算损失
- 反向传播:通过反向传播损失,计算每个参数梯度 loss.backward()来完成。
- 更新参数optimizer.step()被调用来更新网络的参数。根据计算出的梯度和定义的优化算法,它会调整参数以减小损失。
注意:
optimizer.step()根据优化器预定义的规则和计算出的梯度来更新模型参数。在调用它之后,会执行
optimizer.zero_grad(),以便下一次迭代时从干净的状态开始。
import torch.nn
import torchvision
from torch.utils.tensorboard import SummaryWriterfrom model import *
from torch.utils.data import DataLoader
from torch import nn#准备数据集
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
#测试数据集
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
#length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))#利用DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)#创建网络模型
tudui = Tudui()#损失函数
loss_fn = nn.CrossEntropyLoss()#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)#设置训练网络的一些参数
#记录训练的次数
total_train_step = 0
#测试的次数
total_test_step = 0
#训练的轮数
epoch = 10#添加 tensorboard
writer = SummaryWriter("../logs_train")for i in range(epoch):print("-------------第{}轮训练开始-------------".format(i+1))#训练步骤开始#并不需要把网络设置成训练状态才能进行训练tudui.train()for data in train_dataloader:imgs, targets =dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)#梯度清零#优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1#避免无用信息覆盖if total_train_step % 100 == 0:print("训练次数: {},loss: {}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)#测试步骤#也不是需要把网络设置成eval状态才能进行网络的一个测试tudui.eval()total_test_loss = 0#看正确率total_accuracy = 0#在with里面的代码没有了梯度,保证不会进行调优with torch.no_grad():for data in test_dataloader:imgs, targets =dataoutputs = tudui(imgs)#一部分数据在网络模型上的损失loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + lossaccuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("train_loss", loss.item(), total_test_step)writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)#测试的步骤+1否则图画不出来total_train_step = total_test_step + 1torch.save(tudui, "tudui_{}.pth".format(i))print("模型已保存")
writer.close()
上面是一个训练过程,下面介绍一下训练准确率怎么得来的。
假设有一个2分类的模型
Model(2分类)
#下面是得分
Outputs = [[0.2,0.3],[0.1,0.4]]
#通过Argmax 变成
Preds = [1]
[1]
Inputs target=[0][1]
Preds==inputs target
#上面的这个式子返回的就是T or F
#加起来就是分类正确的个数了。
[false,true].sum()=1
这边注意一下output.argmax(x)的方向,x是0或是1,0的方向是竖着来的,1的方向是横着来的。
import torch outputs = torch.tensor([[0.1,0.2],[0.3,0.4]]) print(outputs.argmax(1)) preds = outputs.argmax(1) targets = torch.tensor([0,1]) print((preds == targets).sum())
-----------------------------------------------------未完待续1-------------------------------------------------------------
训练的一些细节:
如果有Dropout和BatchNorm等一些特殊层,需要