Pytorch损失函数losses简介

一般来说,监督学习的目标函数由损失函数和正则化项组成。(Objective = Loss + Regularization)

Pytorch中的损失函数一般在训练模型时候指定。

注意Pytorch中内置的损失函数的参数和tensorflow不同,是y_pred在前,y_true在后,而Tensorflow是y_true在前,y_pred在后。

对于回归模型,通常使用的内置损失函数是均方损失函数nn.MSELoss 。

对于二分类模型,通常使用的是二元交叉熵损失函数nn.BCELoss (输入已经是sigmoid激活函数之后的结果) 或者 nn.BCEWithLogitsLoss (输入尚未经过nn.Sigmoid激活函数) 。

对于多分类模型,一般推荐使用交叉熵损失函数 nn.CrossEntropyLoss。 (y_true需要是一维的,是类别编码。y_pred未经过nn.Softmax激活。)

此外,如果多分类的y_pred经过了nn.LogSoftmax激活,可以使用nn.NLLLoss损失函数(The negative log likelihood loss)。 这种方法和直接使用nn.CrossEntropyLoss等价。

如果有需要,也可以自定义损失函数,自定义损失函数需要接收两个张量y_pred,y_true作为输入参数,并输出一个标量作为损失函数值。

Pytorch中的正则化项一般通过自定义的方式和损失函数一起添加作为目标函数。

一,内置损失函数

内置的损失函数一般有类的实现和函数的实现两种形式。

如:nn.BCE 和 F.binary_cross_entropy 都是二元交叉熵损失函数,前者是类的实现形式,后者是函数的实现形式。

实际上类的实现形式通常是调用函数的实现形式并用nn.Module封装后得到的。

一般我们常用的是类的实现形式。它们封装在torch.nn模块下,并且类名以Loss结尾。

常用的一些内置损失函数说明如下。

nn.MSELoss(均方误差损失,也叫做L2损失,用于回归)

nn.L1Loss (L1损失,也叫做绝对值误差损失,用于回归)

nn.SmoothL1Loss (平滑L1损失,当输入在-1到1之间时,平滑为L2损失,用于回归)

nn.BCELoss (二元交叉熵,用于二分类,输入已经过nn.Sigmoid激活,对不平衡数据集可以用weigths参数调整类别权重)

nn.BCEWithLogitsLoss (二元交叉熵,用于二分类,输入未经过nn.Sigmoid激活)

nn.CrossEntropyLoss (交叉熵,用于多分类,要求label为稀疏编码,输入未经过nn.Softmax激活,对不平衡数据集可以用weigths参数调整类别权重)

nn.NLLLoss (负对数似然损失,用于多分类,要求label为稀疏编码,输入经过nn.LogSoftmax激活)

nn.CosineSimilarity(余弦相似度,可用于多分类)

nn.AdaptiveLogSoftmaxWithLoss (一种适合非常多类别且类别分布很不均衡的损失函数,会自适应地将多个小类别合成一个cluster)

更多损失函数的介绍参考如下知乎文章:

《PyTorch的十八个损失函数》

二,自定义L1和L2正则化项

通常认为L1 正则化可以产生稀疏权值矩阵,即产生一个稀疏模型,可以用于特征选择。

而L2 正则化可以防止模型过拟合(overfitting)。一定程度上,L1也可以防止过拟合。

# L2正则化
def L2Loss(model,alpha):l2_loss = torch.tensor(0.0, requires_grad=True)for name, param in model.named_parameters():if 'bias' not in name: #一般不对偏置项使用正则l2_loss = l2_loss + (0.5 * alpha * torch.sum(torch.pow(param, 2)))return l2_loss# L1正则化
def L1Loss(model,beta):l1_loss = torch.tensor(0.0, requires_grad=True)for name, param in model.named_parameters():if 'bias' not in name:l1_loss = l1_loss +  beta * torch.sum(torch.abs(param))return l1_loss# 将L2正则和L1正则添加到FocalLoss损失,一起作为目标函数
def focal_loss_with_regularization(y_pred,y_true):focal = FocalLoss()(y_pred,y_true) l2_loss = L2Loss(model,0.001) #注意设置正则化项系数l1_loss = L1Loss(model,0.001)total_loss = focal + l2_loss + l1_lossreturn total_lossmodel.compile(loss_func =focal_loss_with_regularization,optimizer= torch.optim.Adam(model.parameters(),lr = 0.01),metrics_dict={"accuracy":accuracy})

只写了部分,具体的参考《20天吃透Pytorch》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389348.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

读取Mc1000的 唯一 ID 机器号

先引用Symbol.ResourceCoordination 然后引用命名空间 using System;using System.Security.Cryptography;using System.IO; 以下为类程序 /// <summary> /// 获取设备id /// </summary> /// <returns></returns> public static string GetDevi…

样本均值的抽样分布_抽样分布样本均值

样本均值的抽样分布One of the most important concepts discussed in the context of inferential data analysis is the idea of sampling distributions. Understanding sampling distributions helps us better comprehend and interpret results from our descriptive as …

玩转ceph性能测试---对象存储(一)

笔者最近在工作中需要测试ceph的rgw&#xff0c;于是边测试边学习。首先工具采用的intel的一个开源工具cosbench&#xff0c;这也是业界主流的对象存储测试工具。 1、cosbench的安装&#xff0c;启动下载最新的cosbench包wget https://github.com/intel-cloud/cosbench/release…

[BZOJ 4300]绝世好题

Description 题库链接 给定一个长度为 \(n\) 的数列 \(a_i\) &#xff0c;求 \(a_i\) 的子序列 \(b_i\) 的最长长度&#xff0c;满足 \(b_i\wedge b_{i-1}\neq 0\) &#xff08; \(\wedge\) 表示按位与&#xff09; \(1\leq n\leq 100000\) Solution 令 \(f_i\) 为二进制第 \(i…

因果关系和相关关系 大数据_数据科学中的相关性与因果关系

因果关系和相关关系 大数据Let’s jump into it right away.让我们马上进入。 相关性 (Correlation) Correlation means relationship and association to another variable. For example, a movement in one variable associates with the movement in another variable. For…

Pytorch构建模型的3种方法

这个地方一直是我思考的地方&#xff01;因为学的代码太多了&#xff0c;构建的模型各有不同&#xff0c;这里记录一下&#xff01; 可以使用以下3种方式构建模型&#xff1a; 1&#xff0c;继承nn.Module基类构建自定义模型。 2&#xff0c;使用nn.Sequential按层顺序构建模…

vue取数据第一个数据_我作为数据科学家的第一个月

vue取数据第一个数据A lot.很多。 I landed my first job as a Data Scientist at the beginning of August, and like any new job, there’s a lot of information to take in at once.我于8月初找到了数据科学家的第一份工作&#xff0c;并且像任何新工作一样&#xff0c;一…

Flask-SocketIO 简单使用指南

Flask-SocketIO 使 Flask 应用程序能够访问客户端和服务器之间的低延迟双向通信。客户端应用程序可以使用 Javascript&#xff0c;C &#xff0c;Java 和 Swift 中的任何 SocketIO 官方客户端库或任何兼容的客户端来建立与服务器的永久连接。 安装 直接使用 pip 来安装&#xf…

STL-开篇

基本概念 STL&#xff1a; Standard Template Library&#xff0c;标准模板库 定义&#xff1a; c引入的一个标准类库 特点&#xff1a;1&#xff09;数据结构和算法的 c实现&#xff08; 采用模板类和模板函数&#xff09;2&#xff09;数据的存储和算法的分离3&#xff09;高…

Symbol Mc1000 声音的设置以及播放

首先引用Symbol.Audio 加一命名空间using Symbol.Audio; /声音设备的设置 //Select Device from device list Symbol.Audio.Device MyDevice (Symbol.Audio.Device)Symbol.StandardForms.SelectDevice.Select( Symbol.Audio.Controller.Title, Symbol.Audio.Devic…

/bin/bash^M: 坏的解释器: 没有那个文件或目录

在win下编辑的时候&#xff0c;换行结尾是\n\r &#xff0c; 而在linux下 是\n&#xff0c;所以会多出来一个\r&#xff0c;这样会出现错误 此时执行 sed -i s/\r$// file.sh 将file.sh中的\r都替换为空白&#xff0c;问题解决转载于:https://www.cnblogs.com/zzdbullet/p/9890…

rcp rapido_为什么气流非常适合Rapido

rcp rapidoBack in 2019, when we were building our data platform, we started building the data platform with Hadoop 2.8 and Apache Hive, managing our own HDFS. The need for managing workflows whether it’s data pipelines, i.e. ETL’s, machine learning predi…

pandas处理丢失数据与数据导入导出

3.4pandas处理丢失数据 头文件&#xff1a; import numpy as np import pandas as pd丢弃数据部分&#xff1a; dates pd.date_range(20130101,periods6) df pd.DataFrame(np.random.randn(6,4),indexdates,columns[A,B,C,D]) df.iloc[0,1] np.nan df.iloc[1,2] np.nanp…

Mysql5.7开启远程

2019独角兽企业重金招聘Python工程师标准>>> 1.注掉bind-address #bind-address 127.0.0.1 2.开启远程访问权限 grant all privileges on *.* to root"xxx.xxx.xxx.xxx" identified by "密码"; 或 grant all privileges on *.* to root"%…

分类结果可视化python_可视化分类结果的另一种方法

分类结果可视化pythonI love good data visualizations. Back in the days when I did my PhD in particle physics, I was stunned by the histograms my colleagues built and how much information was accumulated in one single plot.我喜欢出色的数据可视化。 早在我获得…

算法组合 优化算法_算法交易简化了风险价值和投资组合优化

算法组合 优化算法Photo by Markus Spiske (left) and Jamie Street (right) on UnsplashMarkus Spiske (左)和Jamie Street(右)在Unsplash上的照片 In the last post, we saw how actual algorithms are developed and tested. In this post, we will figure out the level of…

Symbol Mc1000 快捷键 的 设置 事件 开发

switch (e.KeyCode) { ///数据 case Keys.F1://清除数据 if(File.Exists("Storage Card/CG.sdf")) { Mc.gConn.Close(); Mc.gConn.Dispose(); File.Delete("Storage Card/CG.sdf"); } MessageBox.S…

pandas合并concatmerge和plot画图

3.6&#xff0c;3.7pandas合并concat&merge 头文件&#xff1a; import pandas as pd import numpy as npconcat基础合并用法 df1 pd.DataFrame(np.ones((3,4))*0,columns [a,b,c,d]) df2 pd.DataFrame(np.ones((3,4))*1,columns [a,b,c,d]) df3 pd.DataFrame(np.ones…

Android跳转WIFI界面的四种方式

第一种 Intent intent new Intent(); intent.setAction("android.net.wifi.PICK_WIFI_NETWORK"); startActivity(intent); 第二种 startActivity(new Intent(android.provider.Settings.ACTION_WIFI_SETTINGS)); 第三种 Intent i new Intent(); if(android.os.Buil…

PS抠发丝技巧 「选择并遮住…」

PS抠发丝技巧 「选择并遮住…」 现在的海报设计&#xff0c;大多数都有模特MM&#xff0c;然而MM的头发实用太多了&#xff0c;有的还飘起来…… 对于设计师(特别是淘宝美工)没有一个强大、快速、实用的抠发丝技巧真的混不去哦。而PS CC 2017版本开始&#xff0c;就有了一个强大…