简易机器学习笔记(五)更换损失函数:交叉熵

前言

我们之前用的是均方差作为我们神经网络的损失函数评估值,但是我们对于结果,比如给定你一张应该是0的照片,它识别成了6,这个时候这个均方差表达了什么特别的含义吗?显然你识别成6并不代表它比识别成1的情况误差更大。

所以说我们需要一种全新的方式,基于概率的方案来对结果进行规范。也就是我们说的交叉熵损失函数。

至于什么是交叉熵损失函数,由于本文不涉及实际的数学论证,感兴趣可以看这个简单的小视频:
你真的理解交叉熵损失函数了吗?

修改之处

我这里由于是笔记,就不过多对理论论证了,这里只说在实际开发中需要修改哪些地方:

  1. 在数据处理部分,需要修改标签变量Label的格式,代码如下所示。

从:label = np.reshape(labels[i], [1]).astype(‘float32’)
到:label = np.reshape(labels[i], [1]).astype(‘int64’)
(注意,一般情况下paddle的库对于label的认知就是默认是int64类型)

  1. 在网络定义部分,需要修改输出层结构,代码如下所示。
    从:self.fc = Linear(in_features=980, out_features=1)
    到:self.fc = Linear(in_features=980, out_features=10)

先说明为什么是这样,因为输出的数字理论上只有0 - 9 十个数字,这里输出层实际上要对所有的可能进行规范,有几个输出就要给定多少可能。

  1. 修改计算损失的函数,从均方误差(常用于回归问题)到交叉熵误差(常用于分类问题),代码如下所示。

从:loss = paddle.nn.functional.square_error_cost(predict, label)
到:loss = paddle.nn.functional.cross_entropy(predict, label)

实际代码

#修改损失函数的初体验# 损失函数是模型优化的目标,用于在众多的参数取值中,识别最理想的取值。
# 损失函数的计算在训练过程的代码中,每一轮模型训练的过程都相同,分如下三步:#1. 先根据输入数据正向计算预测输出。
#2. 再根据预测值和真实值计算损失。
#3. 最后根据损失反向传播梯度并更新参数。# 在之前的方案中,我们复用了房价预测模型的损失函数-均方误差。从预测效果来看,虽然损失不断下降,模型的预测值逐渐逼近真实值,但模型的最终效果不够理想。究其根本,不同的深度学习任务需要有各自适宜的损失函数。
# 我们以房价预测和手写数字识别两个任务为例,详细剖析其中的缘由如下:# 1. 房价预测是回归任务,而手写数字识别是分类任务,使用均方误差作为分类任务的损失函数存在逻辑和效果上的缺欠。
# 2. 房价可以是大于0的任何浮点数,而手写数字识别的输出只可能是0~9之间的10个整数,相当于一种标签。
# 3. 在房价预测的案例中,由于房价本身是一个连续的实数值,因此以模型输出的数值和真实房价差距作为损失函数(Loss)是符合道理的。
# 但对于分类问题,真实结果是分类标签,而模型输出是实数值,导致以两者相减作为损失不具备物理含义。# 如果模型能输出10个标签的概率,对应真实标签的概率输出尽可能接近100%,而其他标签的概率输出尽可能接近0%,且所有输出概率之和为1。
# 这是一种更合理的假设!与此对应,真实的标签值可以转变成一个10维度的one-hot向量
# 在对应数字的位置上为1,其余位置为0,比如标签“6”可以转变成[0,0,0,0,0,0,1,0,0,0]。#数据处理部分之前的代码,保持不变# 在手写数字识别任务中,仅改动三行代码,就可以将在现有模型的损失函数替换成交叉熵(Cross_entropy)。# 在读取数据部分,将标签的类型设置成int,体现它是一个标签而不是实数值(飞桨框架默认将标签处理成int64)。
# 在网络定义部分,将输出层改成“输出十个标签的概率”的模式。
# 在训练过程部分,将损失函数从均方误差换成交叉熵。# 在数据处理部分,需要修改标签变量Label的格式,代码如下所示。# 从:label = np.reshape(labels[i], [1]).astype(‘float32’)
# 到:label = np.reshape(labels[i], [1]).astype(‘int64’)import os
import random
import paddle
import numpy as np
import matplotlib.pyplot as plt
from PIL import Imageimport gzip
import json# 创建一个类MnistDataset,继承paddle.io.Dataset 这个类
# MnistDataset的作用和上面load_data()函数的作用相同,均是构建一个迭代器
class MnistDataset(paddle.io.Dataset):def __init__(self, mode):datafile = './work/mnist.json.gz'data = json.load(gzip.open(datafile))# 读取到的数据区分训练集,验证集,测试集train_set, val_set, eval_set = data# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLSself.IMG_ROWS = 28self.IMG_COLS = 28if mode=='train':# 获得训练数据集imgs, labels = train_set[0], train_set[1]elif mode=='valid':# 获得验证数据集imgs, labels = val_set[0], val_set[1]elif mode=='eval':# 获得测试数据集imgs, labels = eval_set[0], eval_set[1]else:raise Exception("mode can only be one of ['train', 'valid', 'eval']")# 校验数据imgs_length = len(imgs)assert len(imgs) == len(labels), \"length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(labels))self.imgs = imgsself.labels = labelsdef __getitem__(self, idx):img = np.reshape(self.imgs[idx], [1, self.IMG_ROWS, self.IMG_COLS]).astype('float32')label = np.reshape(self.labels[idx], [1]).astype('int64')return img, labeldef __len__(self):return len(self.imgs)
# 声明数据加载函数,使用训练模式,MnistDataset构建的迭代器每次迭代只返回batch=1的数据
train_dataset = MnistDataset(mode='train')
# 使用paddle.io.DataLoader 定义DataLoader对象用于加载Python生成器产生的数据,
# DataLoader 返回的是一个批次数据迭代器,并且是异步的;
train_loader = paddle.io.DataLoader(train_dataset, batch_size=100, shuffle=True, drop_last=True)
val_dataset = MnistDataset(mode='valid')
val_loader = paddle.io.DataLoader(val_dataset, batch_size=128,drop_last=True)# 在网络定义部分,需要修改输出层结构,代码如下所示。# 从:self.fc = Linear(in_features=980, out_features=1)
# 到:self.fc = Linear(in_features=980, out_features=10)# 定义 SimpleNet 网络结构
import paddle
from paddle.nn import Conv2D, MaxPool2D, Linear
import paddle.nn.functional as F
# 多层卷积神经网络实现
class MNIST(paddle.nn.Layer):def __init__(self):super(MNIST, self).__init__()# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)# 定义池化层,池化核的大小kernel_size为2,池化步长为2self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)# 定义池化层,池化核的大小kernel_size为2,池化步长为2self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)# 定义一层全连接层,输出维度是10self.fc = Linear(in_features=980, out_features=10)# 定义网络前向计算过程,卷积后紧接着使用池化层,最后使用全连接层计算最终输出# 卷积层激活函数使用Reludef forward(self, inputs):x = self.conv1(inputs)x = F.relu(x)x = self.max_pool1(x)x = self.conv2(x)x = F.relu(x)x = self.max_pool2(x)x = paddle.reshape(x, [x.shape[0], 980])x = self.fc(x)return x#  修改计算损失的函数,从均方误差(常用于回归问题)到交叉熵误差(常用于分类问题),代码如下所示。
# 从:loss = paddle.nn.functional.square_error_cost(predict, label)
# 到:loss = paddle.nn.functional.cross_entropy(predict, label)def evaluation(model, datasets):model.eval()acc_set = list()for batch_id, data in enumerate(datasets()):images, labels = dataimages = paddle.to_tensor(images)labels = paddle.to_tensor(labels)pred = model(images)   # 获取预测值acc = paddle.metric.accuracy(input=pred, label=labels)acc_set.extend(acc.numpy())# #计算多个batch的准确率acc_val_mean = np.array(acc_set).mean()return acc_val_mean#仅修改计算损失的函数,从均方误差(常用于回归问题)到交叉熵误差(常用于分类问题)
def train(model):model.train()#调用加载数据的函数# train_loader = load_data('train')# val_loader = load_data('valid')opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())EPOCH_NUM = 10for epoch_id in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):#准备数据images, labels = dataimages = paddle.to_tensor(images)labels = paddle.to_tensor(labels)#前向计算的过程predicts = model(images)#计算损失,使用交叉熵损失函数,取一个批次样本损失的平均值loss = F.cross_entropy(predicts, labels)avg_loss = paddle.mean(loss)#每训练了200批次的数据,打印下当前Loss的情况if batch_id % 200 == 0:print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))#后向传播,更新参数的过程avg_loss.backward()# 最小化loss,更新参数opt.step()# 清除梯度opt.clear_grad()# acc_train_mean = evaluation(model, train_loader)# acc_val_mean = evaluation(model, val_loader)# print('train_acc: {}, val acc: {}'.format(acc_train_mean, acc_val_mean))   #保存模型参数paddle.save(model.state_dict(), 'mnist.pdparams')# #创建模型    
# model = MNIST()
# #启动训练过程
# train(model)# 读取一张本地的样例图片,转变成模型输入的格式
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')im = im.resize((28, 28), Image.LANCZOS)im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)# 图像归一化im = 1.0 - im / 255.return im# 定义预测过程
model = MNIST()
params_file_path = 'mnist.pdparams'
img_path = 'work/example_6.jpg'
# 加载模型参数
param_dict = paddle.load(params_file_path)
model.load_dict(param_dict)
# 灌入数据
model.eval()
tensor_img = load_image(img_path)
#模型反馈10个分类标签的对应概率
results = model(paddle.to_tensor(tensor_img))
#取概率最大的标签作为预测输出
lab = np.argsort(results.numpy())
print("本次预测的数字是: ", lab[0][-1])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/609093.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

各类Java对象

概念的混淆: 新一代的开发者,学习某一概念的不同描述或是案例后,人脑会去抽象、提取其特征。这一过程可能造成语义扩散、概念扭曲。这是一个盲人摸象的过程。 写到这里时,我在想:“盲人摸象”与“抽象”的“象”是不是一个意思呢&…

灵魂三连问:是5G卡吗?支持5G吗?是5G套餐吗

关于5G的问题,小伙伴们的疑问是不是很多,它和4G到底有什么区别呢?什么是5G卡?什么是5G套餐?支持5G吗?什么是5G基站?我想大家现在一定是晕的,下面小编来给大家解惑! 1&…

【elfboard linux 开发板】9. 虚拟机扩容和内核编译

1. 虚拟机扩容 需要将虚拟机的快照全都删除,并且将运行的系统关机点击扩展,改为需要的磁盘大小安装gparted工具 sudo apt-get install gparted 如果报错,则按照出错内容修改,一般是出现下载错误,可以使用下列命令&…

PHP 平滑重启 kill -SIGUSR2 <PID>

在 PHP 中,平滑重启通常涉及向 PHP 进程发送特定的信号。以下是使用信号进行平滑重启的一般步骤: 1. 查找 PHP 进程的主进程 ID (PID): 首先,您需要找到正在运行的 PHP 进程的主进程 ID (PID)。您可以使用 ps 命令来查找 ps a…

Ubuntu22.04安装VTK8.2

1. 安装ccmake 和 VTK 的依赖项: sudo apt-get install cmake-curses-gui sudo apt-get install freeglut3-dev2.下载VTK-8.2.0库 VTK官方网址 自己选择合适的版本进行下载,解压到VTK文件夹下,再新建文件下名为build 3. 配置VTK 进入buil…

【Python百宝箱】模拟未见之境:精准工具畅游分子动力学风景

分子演绎:模拟工具的综合探索 前言 在当今科学研究中,分子动力学模拟成为解析原子和分子行为的关键工具之一。本文将深入探讨几种领先的分子动力学模拟工具,包括MDTraj、ASE(原子模拟环境)、OpenMM和CHARMM。这些工具…

控制器转盘错误

目录 起因 经过(调试) 点焊机数据修改 结果 今天来记录设备的维修记录,下面来根据起因,经过,结果来说明情况!!!,希望对读者的你有帮助!!!

强化学习Double DQN方法玩雅达利Breakout游戏完整实现代码与评估pytorch

1. 实验环境 1.1 硬件配置 处理器:2*AMD EPYC 7773X 64-Core内存:1.5TB显卡:8*NVIDIA GeForce RTX 3090 24GB 1.2 工具环境 Python:3.10.12Anaconda:23.7.4系统:Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-…

软件测试|解决‘pip‘ 不是内部或外部命令,也不是可运行的程序或批处理文件

前言 很多Python初学者在使用Python时,会遇到环境的问题,比如无法使用pip命令安装第三方库的问题,如下图: 当出现错误信息 "pip 不是内部或外部命令,也不是可运行的程序或批处理文件" 时,这通常…

echarts柱状图加单位,底部文本溢出展示

刚开始设置了半天都不展示单位,后来发现是被挡住了,需要调高top值 // 基于准备好的dom,初始化echarts实例var myChart echarts.init(document.getElementById("echartD"));rankOption {// backgroundColor: #00265f,tooltip: {…

树定义及遍历

1、定义树 可以参考链表,链表遍历不方便,如果单链表有多个next指针,则就形成了树。 Java: public class TreeNode {int val;TreeNode left, right;TreeNode(int val) { this.val val; this.left null;this.right null;} } Python&#…

WIN32 桌面应用编程综合实验一学习记录

文章目录 引用传递和指针传递的区别和联系如何创建一个空的WINDOWS桌面项目C编程中函数声明、定义和链接的基本概念 引用传递和指针传递的区别和联系 case ID_SETTING_FONT:GetDrawFont(hWnd, gs_logFont, &gs_TextColor); break;logFont 和 pColor 的用法体现了 C 中两种…

stm32的规则采样与注入采样的理解

规则与注入转换 在STM32中,规则采样(Regular Conversion)和注入采样(Injected Conversion)是用于模数转换的两种不同模式。 规则采样(Regular Conversion):规则采样是STM32中最常用…

面试算法105:最大的岛屿

题目 海洋岛屿地图可以用由0、1组成的二维数组表示,水平或竖直方向相连的一组1表示一个岛屿,请计算最大的岛屿的面积(即岛屿中1的数目)。例如,在下图中有4个岛屿,其中最大的岛屿的面积为5。 分析 将岛屿…

力扣-34. 在排序数组中查找元素的第一个和最后一个位置

文章目录 力扣题目代码 力扣题目 给你一个按照非递减顺序排列的整数数组 nums,和一个目标值 target。请你找出给定目标值在数组中的开始位置和结束位置。 如果数组中不存在目标值 target,返回 [-1, -1]。 你必须设计并实现时间复杂度为 O(log n) 的算…

山东名岳轩印刷包装携专业包装袋盛装亮相2024济南生物发酵展

山东名岳轩印刷包装有限公司盛装亮相2024第12届国际生物发酵展,3月5-7日山东国际会展中心与您相约! 展位号:1号馆F17 山东名岳轩印刷包装有限公司是一家拥有南北两个生产厂区,设计、制版、印刷,营销策划为一体的专业…

python运行报错_ModuleNotFoundError: No module named ‘xxx‘,调用自己定义的文件报错。

问题描述:cifar10.py文件调用non_stationary.py文件的方法 目录结构: project_directory/ └── continuum/├── dataset_scripts/│ └── cifar10.py├── __init__.py├── continuum.py└── non_stationary.py# cifar10.pyfrom continuum…

JavaSec基础 反射修改Final修饰的属性及绕过高版本反射限制

反射重拾 半年没碰java了 先写点基础回忆一下 反射弹计算器 public class Test {public static void main(String[] args) throws Exception {Class<?> clazz Class.forName("java.lang.Runtime");clazz.getDeclaredMethod("exec", String.cla…

springBoot-自动配置原理

以下笔记内容&#xff0c; 整理自B站黑马springBoot视频&#xff0c;抖音Holis 1、自动配置原理 1.收集Spring开发者的编程习惯&#xff0c;整理开发过程使用的常用技术列表一>(技术集A) 2.收集常用技术(技术集A)的使用参数&#xff0c;整理开发过程中每个技术的常用设置列表…

灵活轻巧的java接口自动化测试实战

前言 无论是自动化测试还是自动化部署&#xff0c;撸码肯定少不了&#xff0c;所以下面的基于java语言的接口自动化测试&#xff0c;要想在业务上实现接口自动化&#xff0c;前提是要有一定的java基础。 如果没有java基础&#xff0c;也没关系。这里小编也为大家提供了一套jav…