【深度学习实验】前馈神经网络(八):模型评价(自定义支持分批进行评价的Accuracy类)

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

 0. 导入必要的工具包

1. __init__(构造函数)

2. update函数(更新评价指标)

5. accumulate(计算准确率)

4. reset(重置评价指标)

5. 构造数据进行测试

6. 代码整合


一、实验介绍

       本文将实现一个辅助功能——计算预测的准确率。Accuracy支持对每一个回合中每批数据进行评价,并将结果累积,最终获得整批数据的评价结果。

  • 在训练或验证过程中迭代地调用update方法来更新评价指标;
  • 使用accumulate方法获取累计的准确率;
  • 通过reset方法重置评价指标,以便进行下一轮的计算。

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png​​

 0. 导入必要的工具包

import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader
  • DatasetDataLoader类用于处理数据集和数据加载

这段代码定义了一个名为Accuracy的类,用于支持分批进行模型评价,特别是在分类任务中计算准确率。

1. __init__(构造函数)

class Accuracy:def __init__(self, is_logist=True):self.num_correct = 0self.num_count = 0self.is_logist = is_logist
  • 构造函数在创建Accuracy对象时被调用。它接受一个可选的参数is_logist,默认为True,用于指示是否为logist形式的预测值。
  • self.num_correct用于记录正确预测的样本个数。
  • self.num_count用于记录总样本个数。
  • self.is_logist指示是否为logist形式的预测值。

2. update函数(更新评价指标)

def update(self, outputs, labels):if outputs.shape[1] == 1:outputs = outputs.squeeze(-1)if self.is_logist:preds = (outputs >= 0).long()else:preds = (outputs >= 0.5).long()else:preds = torch.argmax(outputs, dim=1).long()labels = labels.squeeze(-1)batch_correct = (preds==labels).float().sum()batch_count = len(labels)self.num_correct += batch_correctself.num_count += batch_count
  • update方法用于更新评价指标。它接受两个参数outputslabels,分别表示模型的预测输出和真实标签。
  • 根据outputs的形状判断任务类型。
    •  如果outputs是二维张量且第二维大小为1,那么表示是二分类任务。
      •   如果is_logist=True,则将outputs通过阈值(0)转换为预测值preds,并将其转换为整数类型。
      •   如果is_logist=False,则将outputs通过阈值(0.5)转换为预测值preds,并将其转换为整数类型。
    •  如果outputs是二维张量且第二维大小大于1,表示是多分类任务。此时,将outputs中概率最大的类别作为预测值preds
  • labels去除多余的维度,并计算本批数据中预测正确的样本个数batch_correct
  • 获取本批数据的样本个数batch_count
  • 更新num_correctnum_count,累积计算正确样本个数和总样本个数。

5. accumulate(计算准确率)

def accumulate(self):if self.num_count == 0:return 0return self.num_correct / self.num_count
  • accumulate方法用于计算准确率。
    •  如果num_count为0,表示没有进行过更新,返回0。
    • 否则,返回正确样本个数除以总样本个数的比例,即准确率

4. reset(重置评价指标)

def reset(self):self.num_correct = 0self.num_count = 0
  • reset方法用于重置评价指标,将num_correctnum_count重置为0,以便进行下一轮评价

5. 构造数据进行测试

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
acc = Accuracy()
acc.update(y_hat, y)
acc.num_correct

6. 代码整合

import torch# 支持分批进行模型评价的 Accuracy 类
class Accuracy:def __init__(self, is_logist=True):# 正确样本个数self.num_correct = 0# 样本总数self.num_count = 0self.is_logist = is_logistdef update(self, outputs, labels):# 判断是否为二分类任务if outputs.shape[1] == 1:outputs = outputs.squeeze(-1)# 判断是否是logit形式的预测值if self.is_logist:preds = (outputs >= 0).long()else:preds = (outputs >= 0.5).long()else:# 多分类任务时,计算最大元素索引作为类别preds = torch.argmax(outputs, dim=1).long()# 获取本批数据中预测正确的样本个数labels = labels.squeeze(-1)batch_correct = (preds == labels).float().sum()batch_count = len(labels)# 更新self.num_correct += batch_correctself.num_count += batch_countdef accumulate(self):# 使用累计的数据,计算总的评价指标if self.num_count == 0:return 0return self.num_correct / self.num_countdef reset(self):self.num_correct = 0self.num_count = 0y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
acc = Accuracy()
acc.update(y_hat, y)
acc.num_correct

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/87730.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

进行 XSS 攻击 和 如何防御

跨站脚本攻击(XSS 攻击)是 Web 开发中最危险的攻击之一。以下是它们的工作原理以及防御方法。 XSS 攻击 跨站脚本攻击就是在另一个用户的计算机上运行带有恶意的 JS 代码。假如我们的程序没有对这些恶意的脚本进行防御的话,他们就会由我们的…

李宏毅hw-10 ——adversarial attack

一、查漏补缺: 1.关于glob.glob的用法,返回一个文件路径的 列表: 当然,再套用1个sort,就是将所有的文件路径按照字母进行排序了 2.relpath relative_path返回相对于基准路径的相对路径的函数 二、代码剖析&#xff…

STM32单片机入门学习(四)-蜂鸣器

蜂鸣器接线 低平蜂鸣器,低电平发声,高电平不发声, 三个排针,VCC接3.3v,GND接地,I/O接A0口,如图: 蜂鸣器代码:响一秒停半秒 #include "stm32f10x.h" #includ…

LRU、LFU 内存淘汰算法的设计与实现

1、背景介绍 LRU、LFU都是内存管理淘汰算法,内存管理是计算机技术中重要的一环,也是多数操作系统中必备的模块。应用场景:假设 给定你一定内存空间,需要你维护一些缓存数据,LRU、LFU就是在内存已经满了的情况下&#…

2023-2024年最新大数据学习路线

文章目录 2023-2024年最新大数据学习路线大数据开发入门*01*阶段案例实战 大数据核心基础*02*阶段案例实战 千亿级数仓技术*03*阶段项目实战 PB级内存计算04阶段项目实战 亚秒级实时计算*05*阶段项目实战 大厂面试*06* 2023-2024年最新大数据学习路线 新路线图在Spark一章不再…

Android跨进程通信:Binder机制原理

目录 1. Binder到底是什么? 2. 知识储备 2.1 进程空间划分 2.2 进程隔离 & 跨进程通信( IPC ) 2.3 内存映射 2.3.1 作用 2.3.2 实现过程 2.3.3 特点 2.3.4 应用场景 2.3.5 实例讲解 ① 文件读 / 写操作 ② 跨进程通信 3. Bi…

Cron表达式_用于定时调度任务

一、Cron表达式简介 Cron表达式是一个用于设置计划任务的字符串,该字符串以5或6个空格分隔,分为6或7个域,每一个域代表任务在相应时间、日期或时间间隔执行的规则【Cron表达式最初是在类Unix操作中系统中使用的,但现在已经广泛应用…

人机融合需要在事实与价值之间构建新型的拓扑关系

人机融合,这是指将人类智慧(含艺术)与计算机科技相结合,共同解决复杂问题的一种新方式。在人机融合中,我们需要构建事实与价值之间的新型拓扑关系,以实现更有效的知识管理和决策支持。 事实是指客观存在的、…

Python爬虫爬取豆瓣电影短评(爬虫入门,Scrapy框架,Xpath解析网站,jieba分词)

声明:以下内容仅供学习参考,禁止用于任何商业用途 很久之前就想学爬虫了,但是一直没机会,这次终于有机会了 主要参考了《疯狂python讲义》的最后一章 首先安装Scrapy: pip install scrapy 然后创建爬虫项目&#…

力扣26:删除有序数组中的重复项

26. 删除有序数组中的重复项 - 力扣(LeetCode) 题目: 给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 …

Leetcode 95. 不同的二叉搜索树 II

文章目录 题目代码&#xff08;9.21 首刷看解析&#xff09; 题目 Leetcode 95. 不同的二叉搜索树 II 代码&#xff08;9.21 首刷看解析&#xff09; class Solution { public:vector<TreeNode*> generateTrees(int n) {return build(1,n);}vector<TreeNode*> bu…

将本地前端工程中的npm依赖上传到Nexus

【问题背景】 用Nexus搭建了内网的依赖仓库&#xff0c;需要将前端工程中node_modules中的依赖上传到Nexus上&#xff0c;但是node_modules中的依赖已经是解压后的状态&#xff0c;如果直接机械地将其简单地打包上传到Nexus&#xff0c;那么无法通过npm install下载使用。故有…

Jenkins Job的Migrate之旅

场景 使用Jenkins 做为应用的定时任务处理&#xff0c; 在上面建立的800个左右的Job, 这个环境运行了很多年&#xff0c; 当初安装的最新版本是Jenkins 1.642.3&#xff0c; 现在因为OS需要升级等原因&#xff0c; 驻在上面的Jenkins 服务器也需要一并升级&#xff0c;在新的服…

Mock.js之Element-ui搭建首页导航与左侧菜单

&#x1f3ac; 艳艳耶✌️&#xff1a;个人主页 &#x1f525; 个人专栏 &#xff1a;《Spring与Mybatis集成整合》《springMvc使用》 ⛺️ 生活的理想&#xff0c;为了不断更新自己 ! 1、Mock.js的使用 1.1.什么是Mock.js Mock.js是一个模拟数据的生成器&#xff0c;用来帮助前…

浅谈C++|文件篇

C中的文件操作是通过使用文件流来实现的。文件流提供了对文件的输入和输出功能。下面是C文件操作的基本步骤&#xff1a; 1. 包含头文件&#xff1a;首先&#xff0c;包含 <fstream> 头文件&#xff0c;它包含了进行文件操作所需的类和函数。 2 . 进行文件读写操作&#…

9领域事件

本系列包含以下文章&#xff1a; DDD入门DDD概念大白话战略设计代码工程结构请求处理流程聚合根与资源库实体与值对象应用服务与领域服务领域事件&#xff08;本文&#xff09;CQRS 案例项目介绍 # 既然DDD是“领域”驱动&#xff0c;那么我们便不能抛开业务而只讲技术&…

Windows专业版的Docker下载、安装与启用Kubenetes、访问Kubernetes Dashboard

到Docker 官网https://www.docker.com/ 下载windows操作系统对应的docker软件安装 Docker Desktop Installer-Win.exe 2023-09版本是4.23 下载后双击安装 重启windows后&#xff0c;继续安装 接受服务继续安装 解决碰到的Docker Engine stopped 打开 控制面板》程序》启用或关…

软件测试/测试开发丨利用人工智能ChatGPT自动生成PPT

点此获取更多相关资料 简介 PPT 已经渗透到我们的日常工作中&#xff0c;无论是工作汇报、商务报告、学术演讲、培训材料都常常要求编写一个正式的 PPT&#xff0c;协助完成一次汇报或一次演讲。PPT相比于传统文本的就是有布局、图片、动画效果等&#xff0c;可以给到观众更好…

第一百五十四回 如何实现滑动菜单

文章目录 概念介绍实现方法示例代码体验分享 我们在上一章回中介绍了滑动窗口相关的内容相关的内容&#xff0c;本章回中将介绍如何实现 滑动菜单.闲话休提&#xff0c;让我们一起Talk Flutter吧。 概念介绍 我们在本章回中介绍的滑动菜单表示屏幕上向左或者向右滑动滑动时弹…

自注意力机制

回顾以下注意力机制&#xff1a; 自注意力机制 Self-Attention的关键点 在于 K ≈ \approx ≈V ≈ \approx ≈Q 来源于同一个X&#xff0c;三者是同源的&#xff0c;通过 W Q W_Q WQ​, W K W_K WK​, W V W_V WV​做了一层线性变换。 接下来步骤和注意力机制一模一样。 …