Pytorch-day08-模型进阶训练技巧

PyTorch 模型进阶训练技巧

  • 自定义损失函数 如 cross_entropy + L2正则化
  • 动态调整学习率 如每十次 *0.1

典型案例:loss上下震荡
在这里插入图片描述

1、自定义损失函数

  • 1、PyTorch已经提供了很多常用的损失函数,但是有些非通用的损失函数并未提供,比如:DiceLoss、HuberLoss…等

  • 2、模型如果出现loss震荡,在经过调整数据集或超参后,现象依然存在,非通用损失函数或自定义损失函数针对特定模型会有更好的效果
    比如:DiceLoss是医学影像分割常用的损失函数,定义如下:
    在这里插入图片描述

  • Dice系数, 是一种集合相似度度量函数,通常用于计算两个样本的相似度(值范围为 [0, 1]):

  • ∣X∩Y∣表示X和Y之间的交集,∣ X ∣ 和∣ Y ∣ 分别表示X和Y的元素个数,其中,分子中的系数 2,是因为分母存在重复计算 X 和 Y 之间的共同元素的原因.

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import StepLR
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import time
import numpy as np
#DiceLoss 实现 Vnet 医学影像分割模型的损失函数
class DiceLoss(nn.Module):def __init__(self, weight=None, size_average=True):super(DiceLoss, self).__init__()def forward(self, inputs, targets, smooth=1):inputs = F.sigmoid(inputs)       inputs = inputs.view(-1)targets = targets.view(-1)intersection = (inputs * targets).sum()                  dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)return dice_loss
#自定义实现多分类损失函数 处理多分类
# cross_entropy + L2正则化
class MyLoss(torch.nn.Module):def __init__(self, weight_decay=0.01):super(MyLoss, self).__init__()self.weight_decay = weight_decaydef forward(self, inputs, targets):ce_loss = F.cross_entropy(inputs, targets)l2_loss = torch.tensor(0., requires_grad=True).to(inputs.device)for name, param in self.named_parameters():if 'weight' in name:l2_loss += torch.norm(param)loss = ce_loss + self.weight_decay * l2_lossreturn loss

注:

  • 在自定义损失函数时,涉及到数学运算时,我们最好全程使用PyTorch提供的张量计算接口
  • 利用Pytorch张量自带的求导机制
#超参数定义
# 批次的大小
batch_size = 16 #可选32、64、128
# 优化器的学习率
lr = 1e-4
#运行epoch
max_epochs = 2
# 方案二:使用“device”,后续对要使用GPU的变量用.to(device)即可
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") # 指明调用的GPU为1号
# 数据读取
#cifar10数据集为例给出构建Dataset类的方式
from torchvision import datasets#“data_transform”可以对图像进行一定的变换,如翻转、裁剪、归一化等操作,可自己定义
data_transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])train_cifar_dataset = datasets.CIFAR10('cifar10',train=True, download=False,transform=data_transform)
test_cifar_dataset = datasets.CIFAR10('cifar10',train=False, download=False,transform=data_transform)#构建好Dataset后,就可以使用DataLoader来按批次读入数据了
train_loader = torch.utils.data.DataLoader(train_cifar_dataset, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)test_loader = torch.utils.data.DataLoader(test_cifar_dataset, batch_size=batch_size, num_workers=4, shuffle=False)
# restnet50 pretrained
Resnet50 = torchvision.models.resnet50(pretrained=True)
Resnet50.fc.out_features=10
print(Resnet50)
ResNet((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)))(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=2048, out_features=10, bias=True)
)
#训练&验证# 定义损失函数和优化器
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 损失函数:自定义损失函数
criterion = MyLoss()
# 优化器
optimizer = torch.optim.Adam(Resnet50.parameters(), lr=lr)
epoch = max_epochs
Resnet50 = Resnet50.to(device)
total_step = len(train_loader)
train_all_loss = []
test_all_loss = []for i in range(epoch):Resnet50.train()train_total_loss = 0train_total_num = 0train_total_correct = 0for iter, (images,labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)outputs = Resnet50(images)loss = criterion(outputs,labels)train_total_correct += (outputs.argmax(1) == labels).sum().item()#backwordoptimizer.zero_grad()loss.backward()optimizer.step()train_total_num += labels.shape[0]train_total_loss += loss.item()print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:4f}".format(i+1,epoch,iter+1,total_step,loss.item()/labels.shape[0]))Resnet50.eval()test_total_loss = 0test_total_correct = 0test_total_num = 0for iter,(images,labels) in enumerate(test_loader):images = images.to(device)labels = labels.to(device)outputs = Resnet50(images)loss = criterion(outputs,labels)test_total_correct += (outputs.argmax(1) == labels).sum().item()test_total_loss += loss.item()test_total_num += labels.shape[0]print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100))train_all_loss.append(np.round(train_total_loss / train_total_num,4))test_all_loss.append(np.round(test_total_loss / test_total_num,4))
Epoch [1/10], Iter [3124/3125], train_loss:0.054007
Epoch [1/10], Iter [3125/3125], train_loss:0.042914---------------------------------------------------------------------------

2、动态调整学习率

2.1 torch.optim.lr_scheduler

学习率选择的问题:

  • 1、学习率设置过小,会极大降低收敛速度,增加训练时间
  • 2、学习率设置太大,可能导致参数在最优解两侧来回振荡

以上问题都是学习率设置不满足模型训练的需求,解决方案:

  • PyTorch中提供了scheduler

官方API提供的torch.optim.lr_scheduler动态学习率:

  • lr_scheduler.LambdaLR

  • lr_scheduler.MultiplicativeLR

  • lr_scheduler.StepLR

  • lr_scheduler.MultiStepLR

  • lr_scheduler.ExponentialLR

  • lr_scheduler.CosineAnnealingLR

  • lr_scheduler.ReduceLROnPlateau

  • lr_scheduler.CyclicLR

  • lr_scheduler.OneCycleLR

  • lr_scheduler.CosineAnnealingWarmRestarts

2.2、torch.optim.lr_scheduler.LambdaLR

torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=- 1, verbose=False)

# LambdaLR 实现
lr_lambda = f(epoch)
new_lr = lr_lambda * init_lr

思想:初始学习率乘以系数,由于每一次乘系数都是乘初始学习率,因此系数往往是epoch的函数。

#伪代码:Assuming optimizer has two groups.lambda1 = lambda epoch: 1 / (epoch+1)scheduler = LambdaLR(optimizer, lr_lambda=lambda1)for epoch in range(100):train(...)validate(...)scheduler.step()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-f4P8ROuA-1692613806234)(attachment:image-2.png)]

MultiplicativeLR

torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda, last_epoch=- 1, verbose=False)

与LambdaLR不同,该方法用前一次的学习率乘以lr_lambda,因此通常lr_lambda函数不需要与epoch有关。

new_lr = lr_lambda * old_lr

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-g2URgkPf-1692613806234)(attachment:image.png)]

2.2、自定义scheduler

官方给的动态学习率调整的API如果均不能满足我们的诉求,应该怎么办?

我们可以通过自定义函数adjust_learning_rate来改变param_group中lr的值
  • 1、官方的API均不能满足诉求
  • 2、我们根据adjust_learning_rate实现学习率调整方法
# 训练中调用学习率方法
optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
for epoch in range(10):train(...)validate(...)adjust_learning_rate(optimizer,epoch)
#函数:分段,每隔几(10)段个epoch,第一个epoch为序号0不计,使学习率变乘以0.1的epoch次方数
def adjust_learning_rate(optim, epoch, size=10, gamma=0.1):if (epoch + 1) % size == 0:pow = (epoch + 1) // sizelr = learning_rate * np.power(gamma, pow)for param_group in optim.param_groups:param_group['lr'] = lr

代码实例

  • lr_scheduler.LambdaLR
  • adjust_learning_rate

#训练&验证
writer = SummaryWriter("../train_skills")
# 定义损失函数和优化器
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.Adam(Resnet50.parameters(), lr=lr)# 自定义 scheduler 
scheduler_my = LambdaLR(optimizer, lr_lambda=lambda epoch: 1/(epoch+1),verbose = True)
print("初始化的学习率:", optimizer.defaults['lr'])epoch = max_epochs
Resnet50 = Resnet50.to(device)
total_step = len(train_loader)
train_all_loss = []
test_all_loss = []for i in range(epoch):Resnet50.train()train_total_loss = 0train_total_num = 0train_total_correct = 0for iter, (images,labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)outputs = Resnet50(images)loss = criterion(outputs,labels)train_total_correct += (outputs.argmax(1) == labels).sum().item()#backwordoptimizer.zero_grad()loss.backward()optimizer.step()train_total_num += labels.shape[0]train_total_loss += loss.item()print("Epoch [{}/{}], Iter [{}/{}], train_loss:{:4f}".format(i+1,epoch,iter+1,total_step,loss.item()/labels.shape[0]))writer.add_scalar("lr", optim.param_groups[0]['lr'], i)print("第%d个epoch的学习率:%f" % (epoch, optimizer.param_groups[0]['lr']))scheduler_my.step() #scheduler#自定义调整lr
#     adjust_learning_rate(optimizer, i)Resnet50.eval()test_total_loss = 0test_total_correct = 0test_total_num = 0for iter,(images,labels) in enumerate(test_loader):images = images.to(device)labels = labels.to(device)outputs = Resnet50(images)loss = criterion(outputs,labels)test_total_correct += (outputs.argmax(1) == labels).sum().item()test_total_loss += loss.item()test_total_num += labels.shape[0]print("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100))train_all_loss.append(np.round(train_total_loss / train_total_num,4))test_all_loss.append(np.round(test_total_loss / test_total_num,4))
writer.close()
Adjusting learning rate of group 0 to 1.0000e-04.
初始化的学习率: 0.0001
Epoch [1/2], Iter [1/3125], train_loss:0.777986

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/54784.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何优化因为高亮造成的大文本(大字段)检索缓慢问题

首先还是说一下背景,工作中用到了 elasticsearch 的检索以及高亮展示,但是索引中的content字段是读取的大文本内容,所以后果就是索引的单个字段很大,造成单独检索请求的时候速度还可以,但是加入高亮之后检索请求的耗时…

开始MySQL之路——MySQL约束概述详解

MySQL约束 create table [if not exists] 表名(字段名1 类型[(宽度)] [约束条件] [comment 字段说明],字段名2 类型[(宽度)] [约束条件] [comment 字段说明],字段名3 类型[(宽度)] [约束条件] [comment 字段说明] )[表的一些设置]; 概念 约束英文:constraint 约束实…

GeoHash之存储篇

前言: 在上一篇文章GeoHash——滴滴打车如何找出方圆一千米内的乘客主要介绍了GeoHash的应用是如何的,本篇文章我想要带大家探索一下使用什么样的数据结构去存储这些Base32编码的经纬度能够节省内存并且提高查询的效率。 前缀树、跳表介绍: …

数据结构队列的实现

本章介绍数据结构队列的内容,我们会从队列的定义以及使用和OJ题来了解队列,话不多说,我们来实现吧 队列 1。队列的概念及结构 队列:只允许在一端进行插入数据操作,在另一端进行删除数据操作的特殊线性表,…

centos7搭建apache作为文件站后,其他人无法访问解决办法

在公司内网的一个虚拟机上搭建了httpsd服务,准备作为内部小伙伴们的文件站,但是搭建好之后发现别的小伙伴是无法访问我机器的。 于是寻找一下原因,排查步骤如下: 1.netstat -lnp 和 ps aux 先看下端口和 服务情况 发现均正常 2.…

设计模式-工厂设计模式

核心思想 在简单工厂模式的基础上进一步的抽象化具备更多的可扩展和复用性,增强代码的可读性使添加产品不需要修改原来的代码,满足开闭原则 优缺点 优点 符合单一职责,每个工厂只负责生产对应的产品符合开闭原则,添加产品只需添…

探讨uniapp的路由与页面生命周期问题

1 首先我们引入页面路由 2 页面生命周期函数 onLoad() {console.log(页面加载)},onShow() {console.log(页面显示)},onReady(){console.log(页面初次显示)},onHide() {console.log(页面隐藏)},onUnload() {console.log(页面卸载)},onBackPress(){console.log(页面返回)}3 页面…

代码随想录算法训练营之JAVA|第三十九天|474. 一和零

今天是第39天刷leetcode,立个flag,打卡60天。 算法挑战链接 474. 一和零https://leetcode.cn/problems/ones-and-zeroes/ 第一想法 题目理解:找到符合条件的子集,这又是一个组合的问题。 看到这个题目的时候,我好像…

JAVA学习-愚见

JAVA学习-愚见 分享一下Java的学习路线,仅供参考【本人亲测,真实有效】 1、尽可能推荐较新的课程 2、大部分视频在B站上直接搜关键词就行【自学,B大的学生】 文章目录 JAVA学习-愚见前期准备Java基础课程练手项目 数据库JavaWeb前端基础 Vue…

学习设计模式之观察者模式,但是宝可梦

前言 作者在准备秋招中,学习设计模式,做点小笔记,用宝可梦为场景举例,有错误欢迎指出。 观察者模式 观察者模式定义了一种一对多的依赖关系,一个对象的状态改变,其他所有依赖者都会接收相应的通知。 所…

匈牙利算法 in 二分图匹配

https://www.luogu.com.cn/problem/P3386 重新看这个算法,才发现自己没有理解。 左边的点轮流匹配,看是否能匹配成功。对右边的点进行记录是否尝试过 然后有空就进,别人能退的就进 遍历左部点: 尝试匹配过程:

[C++] STL_vector 迭代器失效问题

文章目录 1、前言2、情况一:底层空间改变的操作3、情况二:指定位置元素的删除操作4、g编译器对迭代器失效检测4.1 扩容4.2 erase删除任意位置(非尾删)4.3 erase尾删 5、总结 1、前言 **迭代器的主要作用就是让算法能够不用关心底…

DataWhale 机器学习夏令营第三期——任务二:可视化分析

DataWhale 机器学习夏令营第三期 学习记录二 (2023.08.23)——可视化分析1.赛题理解2. 数据可视化分析2.1 用户维度特征分布分析2.2 时间特征分布分析 DataWhale 机器学习夏令营第三期 ——用户新增预测挑战赛 学习记录二 (2023.08.23)——可视化分析 2023.08.17 已跑通baseli…

Android沉浸式实现(记录)

沉浸式先看效果 直接上代码 Android manifest文件 android:theme"style/Theme.AppCompat.NoActionBar"布局文件 <?xml version"1.0" encoding"utf-8"?> <androidx.constraintlayout.widget.ConstraintLayout xmlns:android"ht…

mit s0681 lab2 Trace系统调用实现

实验一 实现一个用户级别的程序&#xff0c;功能为&#xff0c;指定系统调用后&#xff0c;跟踪程序的系统调用情况 分析实验 实验目标为实现一个程序去跟踪指定程序的系统调用。因此目标有两个 实现一个程序跟踪目标程序的系统调用 实现1&#xff0c;就需要在用户这边实…

4.18 TCP 和 UDP 可以使用同一个端口吗?

目录 TCP 和 UDP 可以同时绑定相同的端口吗&#xff1f; 多个 TCP 服务进程可以绑定同一个端口吗&#xff1f; 重启 TCP 服务进程时&#xff0c;为什么会有“Address in use”的报错信息&#xff1f; 重启 TCP 服务进程时&#xff0c;如何避免“Address in use”的报错信息…

HarmonyOS/OpenHarmony应用开发-ArkTS语言渲染控制LazyForEach数据懒加载

LazyForEach从提供的数据源中按需迭代数据&#xff0c;并在每次迭代过程中创建相应的组件。当LazyForEach在滚动容器中使用了&#xff0c;框架会根据滚动容器可视区域按需创建组件&#xff0c;当组件划出可视区域外时&#xff0c;框架会进行组件销毁回收以降低内存占用。一、接…

智驾算力芯片市场仍处于「波动」周期,Momenta曝光自研NPU

用「冷热不均」来形容当下的汽车芯片赛道&#xff0c;再合适不过了。 本周&#xff0c;英伟达公布的第二财季&#xff08;5-7月&#xff09;营收达到创纪录的135亿美元&#xff0c;大幅超出了此前市场普遍预期的略高于110亿美元&#xff0c;同比增速更是达到了101%。 其中&…

接口测试总结分享(http与rpc)

接口测试是测试系统组件间接口的一种测试。接口测试主要用于检测外部系统与系统之间以及内部各个子系统之间的交互点。测试的重点是要检查数据的交换&#xff0c;传递和控制管理过程&#xff0c;以及系统间的相互逻辑依赖关系等。 一、了解一下HTTP与RPC 1. HTTP&#xff08;H…

【python】输出高亮信息的内容

背景 日志是定位问题和数据分析的关键手段之一&#xff0c;尤其是在调试阶段&#xff0c;高效的、具有辨识度的日志可以非常快速准确的进行问题定位。shell中的echo命令自带文本格式化输出的功能&#xff0c;我们先来回顾下基本的语法&#xff0c;然后套用到python中即可。 s…