【深度学习实验】前馈神经网络(八):模型评价(自定义支持分批进行评价的Accuracy类)

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

 0. 导入必要的工具包

1. __init__(构造函数)

2. update函数(更新评价指标)

5. accumulate(计算准确率)

4. reset(重置评价指标)

5. 构造数据进行测试

6. 代码整合


一、实验介绍

       本文将实现一个辅助功能——计算预测的准确率。Accuracy支持对每一个回合中每批数据进行评价,并将结果累积,最终获得整批数据的评价结果。

  • 在训练或验证过程中迭代地调用update方法来更新评价指标;
  • 使用accumulate方法获取累计的准确率;
  • 通过reset方法重置评价指标,以便进行下一轮的计算。

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png​​

 0. 导入必要的工具包

import torch
from sklearn.datasets import load_iris
from torch.utils.data import Dataset, DataLoader
  • DatasetDataLoader类用于处理数据集和数据加载

这段代码定义了一个名为Accuracy的类,用于支持分批进行模型评价,特别是在分类任务中计算准确率。

1. __init__(构造函数)

class Accuracy:def __init__(self, is_logist=True):self.num_correct = 0self.num_count = 0self.is_logist = is_logist
  • 构造函数在创建Accuracy对象时被调用。它接受一个可选的参数is_logist,默认为True,用于指示是否为logist形式的预测值。
  • self.num_correct用于记录正确预测的样本个数。
  • self.num_count用于记录总样本个数。
  • self.is_logist指示是否为logist形式的预测值。

2. update函数(更新评价指标)

def update(self, outputs, labels):if outputs.shape[1] == 1:outputs = outputs.squeeze(-1)if self.is_logist:preds = (outputs >= 0).long()else:preds = (outputs >= 0.5).long()else:preds = torch.argmax(outputs, dim=1).long()labels = labels.squeeze(-1)batch_correct = (preds==labels).float().sum()batch_count = len(labels)self.num_correct += batch_correctself.num_count += batch_count
  • update方法用于更新评价指标。它接受两个参数outputslabels,分别表示模型的预测输出和真实标签。
  • 根据outputs的形状判断任务类型。
    •  如果outputs是二维张量且第二维大小为1,那么表示是二分类任务。
      •   如果is_logist=True,则将outputs通过阈值(0)转换为预测值preds,并将其转换为整数类型。
      •   如果is_logist=False,则将outputs通过阈值(0.5)转换为预测值preds,并将其转换为整数类型。
    •  如果outputs是二维张量且第二维大小大于1,表示是多分类任务。此时,将outputs中概率最大的类别作为预测值preds
  • labels去除多余的维度,并计算本批数据中预测正确的样本个数batch_correct
  • 获取本批数据的样本个数batch_count
  • 更新num_correctnum_count,累积计算正确样本个数和总样本个数。

5. accumulate(计算准确率)

def accumulate(self):if self.num_count == 0:return 0return self.num_correct / self.num_count
  • accumulate方法用于计算准确率。
    •  如果num_count为0,表示没有进行过更新,返回0。
    • 否则,返回正确样本个数除以总样本个数的比例,即准确率

4. reset(重置评价指标)

def reset(self):self.num_correct = 0self.num_count = 0
  • reset方法用于重置评价指标,将num_correctnum_count重置为0,以便进行下一轮评价

5. 构造数据进行测试

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
acc = Accuracy()
acc.update(y_hat, y)
acc.num_correct

6. 代码整合

import torch# 支持分批进行模型评价的 Accuracy 类
class Accuracy:def __init__(self, is_logist=True):# 正确样本个数self.num_correct = 0# 样本总数self.num_count = 0self.is_logist = is_logistdef update(self, outputs, labels):# 判断是否为二分类任务if outputs.shape[1] == 1:outputs = outputs.squeeze(-1)# 判断是否是logit形式的预测值if self.is_logist:preds = (outputs >= 0).long()else:preds = (outputs >= 0.5).long()else:# 多分类任务时,计算最大元素索引作为类别preds = torch.argmax(outputs, dim=1).long()# 获取本批数据中预测正确的样本个数labels = labels.squeeze(-1)batch_correct = (preds == labels).float().sum()batch_count = len(labels)# 更新self.num_correct += batch_correctself.num_count += batch_countdef accumulate(self):# 使用累计的数据,计算总的评价指标if self.num_count == 0:return 0return self.num_correct / self.num_countdef reset(self):self.num_correct = 0self.num_count = 0y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
acc = Accuracy()
acc.update(y_hat, y)
acc.num_correct

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/87730.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

进行 XSS 攻击 和 如何防御

跨站脚本攻击(XSS 攻击)是 Web 开发中最危险的攻击之一。以下是它们的工作原理以及防御方法。 XSS 攻击 跨站脚本攻击就是在另一个用户的计算机上运行带有恶意的 JS 代码。假如我们的程序没有对这些恶意的脚本进行防御的话,他们就会由我们的…

【Java】泛型 之 super通配符

我们前面已经讲到了泛型的继承关系&#xff1a;Pair<Integer>不是Pair<Number>的子类。 考察下面的set方法&#xff1a; void set(Pair<Integer> p, Integer first, Integer last) {p.setFirst(first);p.setLast(last); }传入Pair<Integer>是允许的&…

李宏毅hw-10 ——adversarial attack

一、查漏补缺&#xff1a; 1.关于glob.glob的用法&#xff0c;返回一个文件路径的 列表&#xff1a; 当然&#xff0c;再套用1个sort&#xff0c;就是将所有的文件路径按照字母进行排序了 2.relpath relative_path返回相对于基准路径的相对路径的函数 二、代码剖析&#xff…

Mybatis连接DB2数据库时,FETCH FIRST {n} ROWS ONLY不能参数化解决

Mybatis连接DB2数据为时 ......WHERE ROW_NUM_HAHA > #{start,jdbcTypeNUMERIC} FETCH FIRST #{pageSize,jbdcTypeNUMERIC} ROWS ONLY...... 如果像上面这样写是不行的。查过资料后&#xff0c;才发现FETCH FIRST后面的值是不能参数化的&#xff0c;只能写死。而Mybatis中…

STM32单片机入门学习(四)-蜂鸣器

蜂鸣器接线 低平蜂鸣器&#xff0c;低电平发声&#xff0c;高电平不发声&#xff0c; 三个排针&#xff0c;VCC接3.3v&#xff0c;GND接地&#xff0c;I/O接A0口&#xff0c;如图&#xff1a; 蜂鸣器代码&#xff1a;响一秒停半秒 #include "stm32f10x.h" #includ…

MySQL 排序规则

文章目录 1.简介2.支持的排序规则3.设置排序规则4.中文排序规则参考文献 1.简介 字符集是一组符号和编码。排序规则是一组用于比较字符集中的字符的规则。 每个 MySQL 字符集可以支持一个或者多个排序规则&#xff0c;用于定义每个字符的比较规则&#xff0c;包括是否区分大小…

软考高级系统架构设计师系列论文真题八:论企业集成平台的技术与应用

软考高级系统架构设计师系列论文真题八:论企业集成平台的技术与应用 一、论企业集成平台的技术与应用二、找准核心论点三、理论素材准备四、精品范文赏析1.摘要2.正文3.总结软考高级系统架构设计师系列论文之:百篇软考高级架构设计师论文范文软考高级系统架构设计师系列之:论…

LRU、LFU 内存淘汰算法的设计与实现

1、背景介绍 LRU、LFU都是内存管理淘汰算法&#xff0c;内存管理是计算机技术中重要的一环&#xff0c;也是多数操作系统中必备的模块。应用场景&#xff1a;假设 给定你一定内存空间&#xff0c;需要你维护一些缓存数据&#xff0c;LRU、LFU就是在内存已经满了的情况下&#…

以容器方式运行 windows 图形化界面系统,附docker详细配置步骤和yaml完整执行文件

以容器方式运行 windows 图形化界面系统,附docker详细配置步骤和yaml完整执行文件。 常规普通的docker中运行windows系统,只能运行无界面化的系统,例如: 要在Docker中运行Windows应用程序,需要使用Windows容器。以下是一些步骤: 确认您的操作系统支持Docker桌面应用程…

2023-2024年最新大数据学习路线

文章目录 2023-2024年最新大数据学习路线大数据开发入门*01*阶段案例实战 大数据核心基础*02*阶段案例实战 千亿级数仓技术*03*阶段项目实战 PB级内存计算04阶段项目实战 亚秒级实时计算*05*阶段项目实战 大厂面试*06* 2023-2024年最新大数据学习路线 新路线图在Spark一章不再…

Android跨进程通信:Binder机制原理

目录 1. Binder到底是什么&#xff1f; 2. 知识储备 2.1 进程空间划分 2.2 进程隔离 & 跨进程通信&#xff08; IPC &#xff09; 2.3 内存映射 2.3.1 作用 2.3.2 实现过程 2.3.3 特点 2.3.4 应用场景 2.3.5 实例讲解 ① 文件读 / 写操作 ② 跨进程通信 3. Bi…

【学习笔记】Prufer序列

Prufer序列 起源于对 C a y l e y Cayley Cayley定理的证明&#xff0c;但是其功能远不止于此 现在考虑将一棵n个节点的树与一个长度为n-2的prufer序列构造对应关系 T r e e − > P r u f e r : Tree->Prufer: Tree−>Prufer: ①从树上选择编号最小的叶子节点&#x…

Cron表达式_用于定时调度任务

一、Cron表达式简介 Cron表达式是一个用于设置计划任务的字符串&#xff0c;该字符串以5或6个空格分隔&#xff0c;分为6或7个域&#xff0c;每一个域代表任务在相应时间、日期或时间间隔执行的规则【Cron表达式最初是在类Unix操作中系统中使用的&#xff0c;但现在已经广泛应用…

人机融合需要在事实与价值之间构建新型的拓扑关系

人机融合&#xff0c;这是指将人类智慧&#xff08;含艺术&#xff09;与计算机科技相结合&#xff0c;共同解决复杂问题的一种新方式。在人机融合中&#xff0c;我们需要构建事实与价值之间的新型拓扑关系&#xff0c;以实现更有效的知识管理和决策支持。 事实是指客观存在的、…

Python爬虫爬取豆瓣电影短评(爬虫入门,Scrapy框架,Xpath解析网站,jieba分词)

声明&#xff1a;以下内容仅供学习参考&#xff0c;禁止用于任何商业用途 很久之前就想学爬虫了&#xff0c;但是一直没机会&#xff0c;这次终于有机会了 主要参考了《疯狂python讲义》的最后一章 首先安装Scrapy&#xff1a; pip install scrapy 然后创建爬虫项目&#…

EdgeMoE: Fast On-Device Inference of MoE-based Large Language Models

本文是LLM系列文章&#xff0c;针对《EdgeMoE: Fast On-Device Inference of MoE-based Large Language Models》的翻译。 EdgeMoE&#xff1a;基于MoE的大型语言模型的快速设备推理 摘要1 引言2 实验与分析3 EDGEMOE设计4 评估5 相关工作6 结论 摘要 GPT和LLaMa等大型语言模…

力扣26:删除有序数组中的重复项

26. 删除有序数组中的重复项 - 力扣&#xff08;LeetCode&#xff09; 题目&#xff1a; 给你一个 非严格递增排列 的数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使每个元素 只出现一次 &#xff0c;返回删除后数组的新长度。元素的 相对顺序 应该保持 …

关于业务库从MySQL迁移到DM8的操作指南

升级前准备 jdbc:dm://10.252.10.15:5237 username: datashare password: datashare123 把当前MySQL数据库下的数据库表以及数据迁移到DM8。通过达梦8自带的工具可以实现迁移&#xff08;仅支持Win&#xff09; DM8管理工具下载&#xff1a;https://www.dameng.com/DM8.html…

Leetcode 95. 不同的二叉搜索树 II

文章目录 题目代码&#xff08;9.21 首刷看解析&#xff09; 题目 Leetcode 95. 不同的二叉搜索树 II 代码&#xff08;9.21 首刷看解析&#xff09; class Solution { public:vector<TreeNode*> generateTrees(int n) {return build(1,n);}vector<TreeNode*> bu…

将本地前端工程中的npm依赖上传到Nexus

【问题背景】 用Nexus搭建了内网的依赖仓库&#xff0c;需要将前端工程中node_modules中的依赖上传到Nexus上&#xff0c;但是node_modules中的依赖已经是解压后的状态&#xff0c;如果直接机械地将其简单地打包上传到Nexus&#xff0c;那么无法通过npm install下载使用。故有…