李宏毅 2022机器学习 HW3 boss baseline 上分记录

作业数据是所有数据都有标签的版本。

李宏毅 2022机器学习 HW3 boss baseline 上分记录

    • 1. 训练数据增强, private 0.76056
    • 2. cross validation&ensemble, private 0.81647
    • 3. test dataset augmentation, private 0.82458
    • 4. resnet, private 0.86555
    • 5. Image Normalization, private 0.87494
    • 6. 减小batch_size, private 0.895

1. 训练数据增强, private 0.76056

结论:训练数据增强、更长时间的训练、dropout都证明很有效果,实验效果提升至接近strong baseline

增强1:crop + geometry
增强2:crop + geometry + gray
另外epochs数目增加到100,patience增加到10个epochs,FC层增加 dropout(0.3)

增强代码如下

#训练数据增强代码train_tfm = transforms.Compose([# Resize the image into a fixed shape (height = width = 128)# transforms.Resize((128, 128)),transforms.RandomResizedCrop(size=(128, 128), scale=(0.8, 1)),# 几何变换transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomRotation(degrees=180),transforms.RandomAffine(degrees=30),#像素变换transforms.RandomGrayscale(p=0.2), # You may add some transforms here.# ToTensor() should be the last one of the transforms.transforms.ToTensor(),
])

具体实验结果如下:
在这里插入图片描述

2. cross validation&ensemble, private 0.81647

使用5-fold cross validation,划分的时候使用分层抽样,
2.1)epochs=100, patience=10
训练时发现通常在60 epochs左右就early stop了,最终public score不如之前,但private score有提升,说明cross validation在过拟合上还是有效果的。
在这里插入图片描述
2.2)epochs=100, patience=16,再看看效果
patience增大后,效果有了一个非常明显的提升,超过strong baseline。具体看实验过程,会发现之前patience=10的时候,基本60epochs就停了,而现在patience=100的时候,early stop没有起作用,都是训练满100个epochs。猜测应该是使用5-fold的cross validation时,对比默认的train/valid,一方面训练数据更多,另一方面valid数据变少波动性更大,所以应该给更多的时间训练。
在这里插入图片描述

3. test dataset augmentation, private 0.82458

结论:此方式有效,分数进一步提升
在这里插入图片描述
测试数据的具体增强方式如下:
在步骤2的基础上,对test数据集使用了train数据集的数据增强方式,生成5张图片预测,对预测结果值平均,然后再用这个结果与原预测结果平均。以下为作业PPT相关部分。
在这里插入图片描述

4. resnet, private 0.86555

使用torchvision自带的resnet模型(按照作业要求,pretrained=False),尝试了resnet18和resnet50,效果进一步有了明显提升。public榜上超过bossline,但是从private榜上,可以看出存在一定过拟合。 另外resnet50的效果并没有比resnet18好,可能是小数据集的原因。这里均使用epochs=200,patience=16, lr=0.0003, weight_decay=1e-5。
在这里插入图片描述
在这里插入图片描述

两个注意点:
1,图片size设成224x224(论文中使用的图片尺寸),对比了128和224,两者差别很大。
2,resnet中的全连接层需要从原来的1000改成此次任务预测的类别数目11,代码如下:

def model_resnet():resnet = resnet18(pretrained=False)resnet.fc = nn.Sequential(nn.Linear(resnet.fc.in_features, 512),nn.ReLU(),nn.Dropout(0.3),nn.Linear(512, 11))return resnet

5. Image Normalization, private 0.87494

尝试了image normalization,略微有一些提升,尤其是做了test augmentation的private榜,达到了目前最高分。
在这里插入图片描述
image normalization的代码如下,注意train和test需要加上同样的norm。

height, width = 224, 224# Normally, We don't need augmentations in testing and validation.
# All we need here is to resize the PIL image and transform it into Tensor.
test_tfm = transforms.Compose([# transforms.Resize((128, 128)),transforms.Resize((height, width)),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])# However, it is also possible to use augmentation in the testing phase.
# You may use train_tfm to produce a variety of images and then test using ensemble methods
train_tfm = transforms.Compose([# Resize the image into a fixed shape (height = width = 128)# transforms.Resize((128, 128)),# transforms.RandomResizedCrop(size=(128, 128), scale=(0.8, 1)),transforms.RandomResizedCrop(size=(height, width), scale=(0.8, 1)),# 几何变换transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomRotation(degrees=180),transforms.RandomAffine(degrees=30),#像素变换transforms.RandomGrayscale(p=0.2), # You may add some transforms here.# ToTensor() should be the last one of the transforms.transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

6. 减小batch_size, private 0.895

将batch_size从64减小到16,模型效果进一步提升,加上image normalization后,private和public双双达到目前最高分。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/97880.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1024 画跳动的爱心#程序代码 #编程语言 #计算机

废话不多说 直接开干! 用到库 random time tkinter 快速镜像 pip install -i https://pypi.tuna.tsinghua.edu.cn/simple tkinter 上代码 import random import time from math import sin, cos, pi, log from tkinter import *CANVAS_WIDTH 640 # 画布的宽 CANVAS_HEIGH…

vue启动项目,npm run dev出现error:0308010C:digital envelope routines::unsupported

运行vue项目,npm run dev的时候出现不支持错误error:0308010C:digital envelope routines::unsupported。 在网上找了很多,大部分都是因为版本问题,修改环境之类的,原因是对的但是大多还是没能解决。经过摸索终于解决了。 方法如…

LLMs与外部应用程序交互 Interacting with external applications

在上一节中,您看到了LLM如何与外部数据集进行交互。现在让我们看看它们如何与外部应用程序进行交互。为了激发需要这种LLM增强的问题和用例的类型,您将重新审视之前在课程中看到的客户服务机器人示例。在这次浏览中,您将查看需要的集成&#…

传输层协议——TCP、UDP

目录 1、UDP 协议(用户数据报协议) 协议特点 报文首部格式 2、TCP 协议(传输控制协议) 协议特点 报文首部格式 TCP连接建立时的三次握手 TCP拆除连接的四次挥手 TCP的流量控制 TCP的拥塞控制 3、传输层端口号 三类端口…

自动驾驶学习笔记(二)——Apollo入门

#Apollo开发者# 学习课程的传送门如下,当您也准备学习自动驾驶时,可以和我一同前往: 《自动驾驶新人之旅》免费课程—> 传送门 《2023星火培训【感知专项营】》免费课程—>传送门 文章目录 前言 Ubuntu Linux文件系统 Linux指令…

Apache Tomcat安装、运行

介绍 Apache Tomcat是下面多个规范的一个开源实现:Jakarta Servlet、Jakarta Server Pages、Jakarta Expression Language、Jakarta WebSocket、Jakarta Annotations 和 Jakarta Authentication。这些规范是 Jakarta EE 平台的一部分。 Jakarta EE 平台是Java EE平…

Springboot项目log4j与logback的Jar包冲突问题

异常信息关键词: SLF4J: Class path contains multiple SLF4J bindings. ERROR in ch.qos.logback.core.joran.spi.Interpreter24:14 - no applicable action for [properties], current ElementPath is [[configuration][properties]] 详细异常信息&#xff1a…

C/C++ 进程间通信system V IPC对象超详细讲解(系统性学习day9)

目录 前言 一、system V IPC对象图解 1.流程图解: ​编辑 2.查看linux内核中的ipc对象: 二、消息队列 1.消息队列的原理 2.消息队列相关的API 2.1 获取或创建消息队列(msgget) 实例代码如下: 2.2 发送消息到消…

c++视觉图像线性混合

图像线性混合 使用 cv::addWeighted() 函数对两幅图像进行线性混合。alpha 和 beta 是两幅图像的权重,它们之和应该等于1。gamma 是一个可选的增益,这里设置为0。 你可以通过调整 alpha 的值来改变混合比例。如果 alpha0.5,则两幅图像等权重…

最短路径专题8 交通枢纽 (Floyd求最短路 )

题目: 样例: 输入 4 5 2 0 1 1 0 2 5 0 3 3 1 2 2 2 3 4 0 2 输出 0 7 思路: 由题意,绘制了该城市的地图之后,由给出的 k 个编号作为起点,求该点到各个点之间的最短距离之和最小的点是哪个,并…

C语言学生成绩录入系统

一、系统概述 该系统是一个由链表创建主菜单的框架,旨在快速创建学生成绩录入系统的主菜单结构。其主要任务包括: 实现链表的创建、插入和遍历功能,用于存储和展示学生成绩录入系统各个模块的菜单项。 2. 提供用户友好的主菜单界面&#xf…

Redis的五种常用数据类型

1.字符串 String的数据结构是简单的Key-Value模型,Value可以是字符串,也可以是数字。 String是Redis最基本的类型,是二进制安全的,意味着Redis的string可以包含任何数据,比如jpg图片。 一个redis中字符串value最大是…

AT9110H-单通道低压 H桥电机驱动芯片

AT9110H能够驱动一个直流有刷电机或其它诸如螺线管的器件。输出驱动模块由PMOSNMOS功率管构成的H桥组成,以驱动电机绕组。AT9110H能够提供高达12V1A的驱动输出。 AT9110H是SOP8封装,且是无铅产品,符合环保标准。 AT9110H具有一个PWM (IN1/IN2…

SpringBoot-黑马程序员-学习笔记(一)

8.pom文件中的parent 我们使用普通maven项目导入依赖时,通常需要在导入依赖的时候指定版本号,而springboot项目不需要指定版本号,会根据当前springboot的版本来下载对应的最稳定的依赖版本。 点开pom文件会看到这个: 继承了一个…

arm-三盏灯流水

.text .global _start _start: 1.设置GPIOE寄存器的时钟使能 RCC_MP_AHB4ENSETR[4]->1 0x50000a28 LDR R0,0x50000A28 LDR R1,[R0] ORR R1,R1,#(0x3<<4) 第四位第五位都设置为1 STR R1,[R0] 写回2.设置PE10管脚为输出模式 GPIOE_MODER[21:20]->01 0x5000…

扭线机控制

扭线机属于线缆加工设备&#xff0c;线缆加工设备种类非常多。有用于网线绞合的单绞&#xff0c;双绞机等&#xff0c;有关单绞机相关算法介绍&#xff0c;大家可以查看专栏相关文章&#xff0c;有详细介绍&#xff0c;常用链接如下&#xff1a; 线缆行业单绞机控制算法&#…

【软考】9.1 顺序表/链表/栈和队列

《线性结构》 顺序存储和链表存储 每个元素最多只有一个出度和一个入度&#xff0c;表现为一条线状链表存储结构&#xff1a;每个节点有两个域&#xff0c;即数据&#xff0c;指针域&#xff08;指向下一个逻辑上相邻的节点&#xff09; 时间复杂度&#xff1a;与其数量级成正…

【Docker】简易版harbor部署

文章目录 依赖于docker-compose下载添加执行权限测试 安装harbor下载解压修改配置文件部署配置开机自启动登录验证 使用harbor登录打标签上传下载 常见问题 依赖于docker-compose 下载 curl -L “https://github.com/docker/compose/releases/download/2.22.0/docker-compose-…

基于SVM+TensorFlow+Django的酒店评论打分智能推荐系统——机器学习算法应用(含python工程源码)+数据集+模型(三)

目录 前言总体设计系统整体结构图系统流程图 运行环境模块实现1. 数据预处理2. 模型训练及保存3. 模型应用 系统测试1. 训练准确率2. 测试效果3. 模型应用 相关其它博客工程源代码下载其它资料下载 前言 本项目以支持向量机&#xff08;SVM&#xff09;技术为核心&#xff0c;…

C语言 - 数组

目录 1. 一维数组的创建和初始化 1.1 数组的创建 1.2 数组的初始化 1.3 一维数组的使用 1.4 一维数组在内存中的存储 2. 二维数组的创建和初始化 2.1 二维数组的创建 2.2 二维数组的初始化 2.3 二维数组的使用 2.4 二维数组在内存中的存储 3. 数组越界 4. 数组作为函数参数 4.1…