1 问题
不同的batch_size对训练集和验证集的精度和损失的影响有多大?
2 方法
针对不同的batch_size分别训练模型,计算每个batch_size对应的训练集精度、训练集损失以及验证集精度和验证集损失,再通过数据可视化将精度和损失展示出来,比较不同batch_size对它们的影响。
基础参数配置:
训练周期:50
学习率:0.2
优化器:SGD
batch_size:32 64 128 256
步骤:
设置不同的batch_size
通过 for batch_size in [32,64,128,256] 循环,分别得到每个batch_size对应的训练集精度、训练集损失以及验证集精度和验证集损失:
# Train the same network once per batch size and collect per-epoch histories.
# After the loop, result1..result4 each hold one history list per batch size,
# in the loop order 32, 64, 128, 256:
#   result1 -> training accuracy, result2 -> training loss,
#   result3 -> validation accuracy, result4 -> validation loss.
result1 = []
result2 = []
result3 = []
result4 = []

# Loop-invariant setup: none of this depends on batch_size, so build it once
# instead of re-creating (and re-checking the download of) MNIST every iteration.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_all_ds = torchvision.datasets.MNIST(root="data", download=True, train=True, transform=transform, )
# Split the full training set into 50000 train + 10000 validation samples.
# A fixed-seed generator, applied once, guarantees every batch size is trained
# and validated on exactly the same split; an unseeded split per iteration
# would confound the batch-size comparison.
train_ds, val_ds = torch.utils.data.random_split(
    train_all_ds, [50000, 10000], generator=torch.Generator().manual_seed(0))
test_ds = torchvision.datasets.MNIST(root="data", download=True, train=False, transform=transform, )
epoch = 50

for batch_size in [32, 64, 128, 256]:
    train_loader = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True, )
    val_loader = DataLoader(dataset=val_ds, batch_size=batch_size, )
    # Not read by this experiment; kept for a later evaluation on the test set.
    test_loader = DataLoader(dataset=test_ds, batch_size=batch_size, )

    # (5) Build a fresh, untrained network for this batch size. Re-seeding
    # first makes each run reproducible and, assuming MyNet draws its initial
    # weights from torch's default RNG, gives every batch size the same start.
    torch.manual_seed(0)
    net = MyNet().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
    loss_fn = torch.nn.CrossEntropyLoss()

    # (6) Training epochs.
    begin_time = time()  # recorded, but not read in the visible code
    train_accuracy_list = []
    train_loss_list = []
    val_accuracy_list = []
    val_loss_list = []
    for t in range(epoch):
        print(f"Epoch {t + 1}")
        # train()/val() each return (accuracy, loss) for one pass.
        train_accuracy, train_loss = train(train_loader, net, loss_fn, optimizer)
        train_accuracy_list.append(train_accuracy)
        train_loss_list.append(train_loss)
        print(f'Train Acc:{train_accuracy}, Train Loss:{train_loss}')
        val_accuracy, val_loss = val(val_loader, net, loss_fn)
        val_accuracy_list.append(val_accuracy)
        val_loss_list.append(val_loss)
        print(f'Val Acc:{val_accuracy}, Val Loss:{val_loss}')

    result1.append(train_accuracy_list)
    result2.append(train_loss_list)
    result3.append(val_accuracy_list)
    result4.append(val_loss_list)
# Give each collected history its own name. The numeric suffix follows the
# order in which batch sizes were run (32, 64, 128, 256), so e.g.
# val_loss_list_3 is the per-epoch validation loss for batch_size=128.
(train_accuracy_list_1, train_accuracy_list_2,
 train_accuracy_list_3, train_accuracy_list_4) = (result1[0], result1[1], result1[2], result1[3])

(train_loss_list_1, train_loss_list_2,
 train_loss_list_3, train_loss_list_4) = (result2[0], result2[1], result2[2], result2[3])

(val_accuracy_list_1, val_accuracy_list_2,
 val_accuracy_list_3, val_accuracy_list_4) = (result3[0], result3[1], result3[2], result3[3])

(val_loss_list_1, val_loss_list_2,
 val_loss_list_3, val_loss_list_4) = (result4[0], result4[1], result4[2], result4[3])
最后,通过数据可视化将精度和损失展示出来:
def _plot_batch_size_panel(axis, accuracy_curves, loss_curves, title):
    """Draw one subplot: accuracy curves, then loss curves, for four batch sizes.

    Curves are matched to batch sizes 32, 64, 128 and 256 by position. Only the
    accuracy curves carry legend labels, so the legend shows exactly one entry
    per batch size. The loss curve of a batch size reuses that batch size's
    colour.
    """
    colors = ('red', 'blue', 'green', 'black')
    labels = ('batch_size=32', 'batch_size=64', 'batch_size=128', 'batch_size=256')
    # Draw every accuracy curve before any loss curve, keeping the original
    # drawing order (which also controls which curve is painted on top).
    for curve, color, label in zip(accuracy_curves, colors, labels):
        axis.plot(range(len(curve)), curve, ls='-', color=color, label=label)
    for curve, color in zip(loss_curves, colors):
        axis.plot(range(len(curve)), curve, ls='-', color=color)
    # legend() with no arguments picks up only the labelled (accuracy) lines,
    # so the entries cannot drift out of sync with the plotting order.
    axis.legend()
    axis.set_title(title, fontsize=16)
    axis.set_xlabel('Epoch')


def picture(data1, data2, data3, data4, data5, data6, data7, data8,
            data9, data10, data11, data12, data13, data14, data15, data16):
    """Plot per-epoch accuracy and loss curves for four batch sizes, then show them.

    Draws a 1x2 grid of subplots on the current matplotlib figure. Each
    argument is a sequence of per-epoch values, plotted against its 0-based
    index. Within each group of four, the order is batch_size 32, 64, 128, 256.

    Per the subplot titles, the groups are:
      - left panel: data1..data4 (accuracy curves) and data5..data8 (loss
        curves), i.e. train_accuracy_list / train_loss_list;
      - right panel: data9..data12 (accuracy curves) and data13..data16 (loss
        curves), i.e. val_accuracy_list / val_loss_list.

    Side effect: updates matplotlib's global rcParams so the Chinese titles
    render with SimHei.
    """
    # Configure fonts once, before any text artists are created. SimHei lacks
    # the Unicode minus glyph, so fall back to ASCII '-' for negative ticks.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False

    _plot_batch_size_panel(
        plt.subplot(1, 2, 1),
        (data1, data2, data3, data4),
        (data5, data6, data7, data8),
        '上面是train_accuracy_list。下面是train_loss_list',
    )
    _plot_batch_size_panel(
        plt.subplot(1, 2, 2),
        (data9, data10, data11, data12),
        (data13, data14, data15, data16),
        '上面是val_accuracy_list。下面是val_loss_list',
    )
    plt.tight_layout()
    plt.show()
可视化结果:
3 结语
针对该问题,通过循环设置不同的batch_size,得到对应的训练集精度、训练集损失以及验证集精度和验证集损失,并将它们存入对应的列表;随后通过索引将它们取出,最后通过数据可视化将它们展示出来并比较结果。