不同的batch_size对精度和损失的影响研究

1 问题

不同的batch_size对训练集和验证集的精度和损失的影响有多大?

2 方法

通过设置不同batch_size算出不同batch_size对应的训练集精度、训练集损失和验证集的精度和损失,通过数据可视化将精度和损失展示出来,比较出不同batch_size对他们的影响

基础参数配置:

  1. 训练周期:50

  2. 学习率:0.2

  3. 优化器:SGD

  4. batch_size:32 64 128 256

步骤:

  1. 设置不同的batch_szie
    for batch_size in [32,64,128,256]:

  2. 得到不同batch_size对应的训练集精度、训练集损失和验证集的精度和损失,

result1= []

    result2= []

    result3= []

    result4= []

    for batch_size in [32,64,128,256]:

        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        train_all_ds = torchvision.datasets.MNIST(root="data", download=True, train=True, transform=transform, )

        # 将训练集划分为训练集+验证集

        train_ds, val_ds = torch.utils.data.random_split(train_all_ds, [50000,10000])

        test_ds = torchvision.datasets.MNIST(root="data", download=True, train=False, transform=transform, )

        train_loader = DataLoader(dataset=train_ds,batch_size=batch_size, shuffle=True,)

        val_loader = DataLoader(dataset=val_ds,batch_size=batch_size,)

        test_loader = DataLoader(dataset=test_ds,batch_size=batch_size,)

        # (5) 网络的输入、输出以及测试网络的性能(不经过任何训练的网络)

        net = MyNet().to(device)

        optimizer = torch.optim.SGD(net.parameters(), lr=0.2)

        loss_fn = torch.nn.CrossEntropyLoss()

        # (6)训练周期

        begin_time = time()

        train_accuracy_list = []

        train_loss_list = []

        val_accuracy_list = []

        val_loss_list = []

        epoch = 50

        for t in range(epoch):

            print(f"Epoch {t + 1}")

            train_accuracy, train_loss = train(train_loader, net, loss_fn, optimizer)

            train_accuracy_list.append(train_accuracy)

            train_loss_list.append(train_loss)

            print(f'Train Acc:{train_accuracy}, Train Val:{train_loss}')

            val_accuracy, val_loss = val(val_loader, net, loss_fn)

            val_accuracy_list.append(val_accuracy)

            val_loss_list.append(val_loss)

        # print(train_accuracy_list)

        # print(train_loss_list)

        # print(val_accuracy_list)

        # print(val_loss_list)

        result1.append(train_accuracy_list)

        result2.append(train_loss_list)

        result3.append(val_accuracy_list)

        result4.append(val_loss_list)

    train_accuracy_list_1=result1[0]

    train_accuracy_list_2=result1[1]

    train_accuracy_list_3=result1[2]

    train_accuracy_list_4=result1[3]

# print(result1)

# print(train_accuracy_list_1)

    train_loss_list_1=result2[0]

    train_loss_list_2=result2[1]

    train_loss_list_3=result2[2]

    train_loss_list_4=result2[3]

    val_accuracy_list_1=result3[0]

    val_accuracy_list_2=result3[1]

    val_accuracy_list_3=result3[2]

    val_accuracy_list_4=result3[3]

    val_loss_list_1=result4[0]

    val_loss_list_2=result4[1]

    val_loss_list_3=result4[2]

    val_loss_list_4=result4[3]

3.数据可视化将精度和损失展示出来

def picture(data1,data2,data3,data4,data5,data6,data7,data8,data9,data10,data11,data12,data13,data14,data15,data16):

    ax = plt.subplot(1, 2, 1)

    ax.plot(range(len(data1)), data1, ls='-', color='red',label='batch_size=32')

    ax.plot(range(len(data2)), data2, ls='-', color='blue',label='batch_size=64')

    ax.plot(range(len(data3)), data3, ls='-', color='green',label='batch_size=128')

    ax.plot(range(len(data4)), data4, ls='-', color='black',label='batch_size=256')

    ax.plot(range(len(data5)), data5, ls='-', color='red')

    ax.plot(range(len(data6)), data6, ls='-', color='blue')

    ax.plot(range(len(data7)), data7, ls='-', color='green')

    ax.plot(range(len(data8)), data8, ls='-', color='black')

    ax.legend(['batch_size=32', 'batch_size=64','batch_size=128','batch_size=256'])

    plt.rcParams['font.sans-serif']=['SimHei']

    ax.set_title('上面是train_accuracy_list。下面是train_loss_list', fontsize=16)

    ax.set_xlabel('Epcho')

    ax1 = plt.subplot(1, 2, 2)

    ax1.plot(range(len(data9)), data9, ls='-', color='red')

    ax1.plot(range(len(data10)), data10, ls='-', color='blue')

    ax1.plot(range(len(data11)), data11, ls='-', color='green')

    ax1.plot(range(len(data12)), data12, ls='-', color='black')

    ax1.plot(range(len(data13)), data13, ls='-', color='red')

    ax1.plot(range(len(data14)), data14, ls='-', color='blue')

    ax1.plot(range(len(data15)), data15, ls='-', color='green')

    ax1.plot(range(len(data16)), data16, ls='-', color='black')

    ax1.legend(['batch_size=32', 'batch_size=64','batch_size=128','batch_size=256'])

    plt.rcParams['font.sans-serif']=['SimHei']

    ax1.set_title('上面是(val_accuracy_list,c下面是val_loss_list', fontsize=16)

    ax1.set_xlabel('Epcho')

    plt.show()

可视化结果:

ae63ded059cb1d7bd0b189d8005ef024.png

3 结语

针对该问题通过循环设置不同的batch_size设置得到对应的训练集精度、训练集损失和验证集的精度和损失,然后将它们存到对应的列表里面,然后通过索引将它们拿出来,最后通过数据可视化将它们展示出来,比较结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/782757.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CTK插件框架学习-插件注册调用(03)

CTK插件框架学习-新建插件(02)https://mp.csdn.net/mp_blog/creation/editor/136923735 一、CTK插件组成 接口类:对外暴露的接口,供其他插件调用实现类:实现接口内的方法激活类:负责将插件注册到CTK框架中 二、接口、插件、服务…

文生视频大模型Sora的复现经验

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的…

BFS专题

1、BFS解决FloodFill算法 1、1图像渲染 733. 图像渲染 - 力扣(LeetCode) class Solution {typedef pair<int,int> PII;int dx[4] = {0,0,1,-1};int dy[4] = {1,-1,0,0}; public:vector<vector<int>> floodFill(vector<vector<int>>& i…

RIP环境下的MGRE 综合实验

实验题目及要求&#xff1a; 1.R5为ISP&#xff0c;只能进行IP地址配置&#xff0c;其所有地址均配为公有IP地址 2.R1和R5间使用PPP的PAP认证&#xff0c;R5为主认证方; R2于R5之间使用PPP的chap认证&#xff0c;R5为主认证方&#xff1b; R3于R5之间使用HDLC封装。 3.R1/…

【C++】为什么能实现函数重载

从C语言一路学到C的途中&#xff0c;C语言C语言相比&#xff0c;多了个函数重载&#xff0c;那么函数重载是如何实现的呢&#xff0c;为什么C语言无法支持&#xff0c;在本篇博客中&#xff0c;将会讲解C为何能实现函数重载。 一.编译过程 C能实现函数重载&#xff0c;而C语言不…

QT 二维坐标系显示坐标点及点与点的连线-通过定时器自动添加随机数据点

QT 二维坐标系显示坐标点及点与点的连线-通过定时器自动添加随机数据点 功能介绍头文件C文件运行过程 功能介绍 上面的代码实现了一个简单的 Qt 应用程序&#xff0c;其功能包括&#xff1a; 创建一个 MainWindow 类&#xff0c;继承自 QMainWindow&#xff0c;作为应用程序的…

2024软件设计师备考讲义——UML(统一建模语言)

UML的概念 用例图的概念 包含 <<include>>扩展<<exted>>泛化 用例图&#xff08;也可称用例建模&#xff09;描述的是外部执行者&#xff08;Actor&#xff09;所理解的系统功能。用例图用于需求分析阶段&#xff0c;它的建立是系统开发者和用户反复…

Pyppeteer中Chromium安装步骤

1、下载压缩文件 在官网下载chrome-win.zip文件 2、终端下载pyppeteer 首先在Pycharm终端运行pip install pyppeteer 3、查找文件默认路径 在运行以下代码&#xff0c;找到可执行文件默认路径 import pyppeteer.chromium_downloader print(默认版本是&#xff1a;{}.forma…

牛角工具箱源码 轻松打造个性化在线工具箱

&#x1f389; Whats this&#xff1f; 这是一款在线工具箱程序&#xff0c;您可以通过安装扩展增强她的功能 通过插件模板的功能&#xff0c;您也可以把她当做网页导航来使用~ 觉得该项目不错的可以给个Star~ &#x1f63a; 演示地址 https://tool.aoaostar.com &#x1f…

TCP网络协议栈和Posix网络部分API总结

文章目录 Posix网络部分API综述TCP协议栈通信过程TCP三次握手和四次挥手&#xff08;看下图&#xff09;三次握手常见问题&#xff1f;为什么是三次握手而不是两次&#xff1f;三次握手和哪些函数有关&#xff1f;TCP的生命周期是从什么时候开始的&#xff1f; 四次挥手通信状态…

HarmonyOS实战开发-如何实现一个自定义抽奖圆形转盘

介绍 本篇Codelab是基于画布组件、显式动画&#xff0c;实现的一个自定义抽奖圆形转盘。包含如下功能&#xff1a; 通过画布组件Canvas&#xff0c;画出抽奖圆形转盘。通过显式动画启动抽奖功能。通过自定义弹窗弹出抽中的奖品。 相关概念 Stack组件&#xff1a;堆叠容器&am…

从0开始搭建基于VUE的前端项目(一) 项目创建和配置

准备与版本 安装nodejs(v20.11.1)安装vue脚手架(@vue/cli 5.0.8) ,参考(https://cli.vuejs.org/zh/)vue版本(2.7.16),vue2的最后一个版本vue.config.js的配置详解(https://cli.vuejs.org/zh/config/)element-ui(2.15.14)(https://element.eleme.io/)vuex(3.6.2) (https://…

K8S命令行可视化实验

以下为K8s命令行可视化工具的实验内容&#xff0c;相比于直接使用命令行&#xff0c;可视化工具可能更直观、更易于操作。 Lens Lens是用于监控和调试的K8S IDE。可以在Windows、Linux以及Mac桌面上完美运行。在 Kubernetes 上&#xff1a; 托管地址&#xff1a;github/lensa…

机器人运动控制

一、基础 1.1 矢量速度和旋转速度 矢量速度用来控制运动方向&#xff0c;任何一个方向都可以看成x、y、z三轴方向的合。单位规定是m/s。 旋转速度用来控制旋转方向&#xff0c;可以看成x、y、z三轴方向旋转的合。单位规定是pi/s。 速度消息包&#xff0c;可以在ROS Index上搜…

助力福建新型职业农民培育 北方天途推进无人机植保应用培训

为加强新型职业农民的职业培育&#xff0c;扩展新型农民的知识范围和专业技术水平&#xff0c;推进农业供给侧结构性改革。日前&#xff0c;在农业部门的大力支持下&#xff0c;北方天途航空和宁德天禾科技服务携手为福建省农民朋友开展了植保无人机驾驶员的应用培训。福建省农…

网页布局案例 浮动

这里主要讲浮动 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><style>*{padding: 0;margin: 0;}.header{height: 40px;background-color: #333;}.nav{width: 1226px;heig…

深入理解数据结构(2):顺序表和链表详解

文章主题&#xff1a;顺序表和链表详解&#x1f331;所属专栏&#xff1a;深入理解数据结构&#x1f4d8;作者简介&#xff1a;更新有关深入理解数据结构知识的博主一枚&#xff0c;记录分享自己对数据结构的深入解读。&#x1f604;个人主页&#xff1a;[₽]的个人主页&#x…

机器学习——降维算法-奇异值分解(SVD)

机器学习——降维算法-奇异值分解&#xff08;SVD&#xff09; 在机器学习中&#xff0c;降维是一种常见的数据预处理技术&#xff0c;用于减少数据集中特征的数量&#xff0c;同时保留数据集的主要信息。奇异值分解&#xff08;Singular Value Decomposition&#xff0c;简称…

csp资料

头文件 #include <bits/stdc.h> using namespace std isdigit(c); isalpha(c); switch(type){case value : 操作 } continue;//结束本轮循环 break;//结束所在的整个循环tips: //除法变乘法来算 //减法变加法 num1e42;//"1e4"表示10的4次方//用于移除容器中相…

某国投集团知识竞赛活动方案

一、抽签分组办法 1.抽签&#xff1a;参赛队伍赛前进行抽签分组。 2.分组&#xff1a;全部报名参赛队伍按照抽签顺序分为4组&#xff0c;每组7支队伍进行预赛&#xff0c;9月16日上午1、2组进行初赛&#xff0c;9月16日下午3、4组进行初赛。每组决出的前三名进入决赛。 二、初…