用C语言构建一个手写数字识别神经网络

59bc0d11101b414b8ec3a3ce89e498b7.jpg

(原理和程序基本框架请参见前一篇 "用C语言构建了一个简单的神经网路")

1.准备训练和测试数据集
从http://yann.lecun.com/exdb/mnist/下载手写数字训练数据集, 包括图像数据train-images-idx3-ubyte.gz 和标签数据 train-labels-idx1-ubyte.gz.
分别将他们解压后放在本地文件夹中,解压后文件名为train-images-idx3-ubyte和train-labels-idx1-ubyte. 训练数据集一共包含了6万个手写数字灰度图和对应的标签.
为图方便,我们直接从训练数据集中提取5000个作为测试数据.当然,实际训练数据中并不包含这些测试数据.


2.设计神经网络
采用简单的三层全连接神经网络,包括输入层(wi),中间层(wm)和输出层(wo).这里暂时不使用卷积层,下次替换后进行比较.
输入层: 一共20个神经元,每一张手写数字的图片大小为28x28,将全部展平后的784个灰度数据归一化,即除以255.0, 使其数值位于[0 1]区间,这样可以防止数据在层层计算和传递后变得过分大.将这784个[0 1]之间的数据与20个神经元进行全连接.神经元激活函数用func_ReLU.
中间层: 一共20个神经元,与输入层的20个神经元输出进行全连接.神经元激活函数用func_ReLU.
输出层: 一共10个神经元,分别对应0~9数字的可能性,与中间层的20个神经元输出进行全连接.层的激活函数用func_softmax.
特别地,神经元的激活函数在new_nvcell()中设定,层的激活函数直接赋给nerve_layer->transfunc.
损失函数: 采用期望和预测值的交叉熵损失函数func_lossCrossEntropy. 损失函数在nvnet_feed_forward()中以参数形式输入.

3.训练神经网络
由于整个程序是以nvcell神经元结构为基础进行构建的,其不同于矩阵/张量形式的批量数据描述,因此这个神经网络只能以神经元为单位,逐个逐层地进行前向和反向传导.
相应地,这里采用SGD(Stochastic Gradient Descent)梯度下降更新法,即对每一个样本先进行前向和反向传导计算,接着根据计算得到的梯度值马上更新所有参数.与此不同,mini-batch GD采用小批量样本进行前向和反向传导计算,然后根据累积的梯度数值做1次参数更新.显然,采用SGD方法参数更新更加频繁,计算时间相应也变长了.不过,据网文分析,采用SGD也更容易趋近全局最优解,尽管逼近的途径会比较曲折.本文程序中的分批计算是为了方便监控计算过程和打印中间值.(当然,要实现mini-batch GD也是可以的,先完成一批量样本的前后传导计算,期间将各参数的梯度累计起来,  最后取其平均值更新一次参数.)
这里使用平均损失值mean_err<=0.0025来作为训练的终止条件,为防止无法收敛到此数值,同时设置最大的epoch计数.
训练的样本数量由TRAIN_IMGTOTAL来设定, 训练时,先读取一个样本数据和一个标签,分别存入到data_input[28*28]和data_target[10], 为了配合应用softmax函数,这里data_target[]是one-hot编码格式.读入样本数据后先进行前向传导计算nvnet_feed_forward(),接着进行反向传导计算nvnet_feed_backward(), 最后更新参数nvnet_update_params(), 这样就完成了一个样本的训练.如此循环计算,完成一次所有样本的训练(epoch)后计算mean_err, 看是否达到预设目标.

4.测试训练后的神经网络
训练完成后,对模型进行简单评估.方法就是用训练后的模型来预测(predict)或推理(infer)前面的测试数据集中的图像数据,将结果与对应的标签值做对比.
同样,将一个测试样本加载到data_input[], 跑一次nvnet_feed_forward(),直接读取输出层的wo_layer->douts[k] (k=0~9),如果其值大于0.5,就认为模型预测图像上的数字是k.

5.小结
取5万条训练样本进行训练,训练后再进行测试,其准确率可接近94%.
与卷积神经网络相比较,为达到相同的结果,全连接的神经网络的所需要的训练时间会更长.

6.实验和改进
可以先将28*28的图片下采样到14*14后再连接到输入层.这样可以提高速度.

可以试着调整输入层和中间层的神经元数目.

也可以试着调整单个神经元的输入连接方式...

源代码:
https://github.com/midaszhou/nnc
下载后编译:
make TEST_NAME=test_nnc2

ca7ccbb483734fa191163ce2b05ce968.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/15403.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

芯片制造详解.光刻技术与基本流程.学习笔记(四)

本篇文章是看了以下视频后的笔记提炼&#xff0c;欢迎各位观看原视频&#xff0c;这里附上地址 芯片制造详解04&#xff1a;光刻技术与基本流程&#xff5c;国产之路不容易 芯片制造详解.光刻技术与基本流程.学习笔记 四 一、引子二、光刻(1).光掩膜(2).光刻机(3).光刻胶(4).挖…

宝塔设置云服务器mysql端口转发,实现本地电脑访问云mysql

环境&#xff1a;centos系统使用宝塔面板 实现功能&#xff1a;宝塔设置云服务器mysql端口转发&#xff0c;实现本地电脑访问mysql 1.安装mysql、PHP-7.4.33、phpMyAdmin 5.0 软件商店》搜索 mysql安装即可 软件商店》搜索 PHP安装7.4.33即可&#xff08;只需要勾选快速安装&…

按键消抖(有/无状态机)

一&#xff0c;理论概念 按键抖动 按键抖动&#xff1a;按键抖动通常的按键所用开关为机械弹性开关&#xff0c;当机械触点断开、闭合时&#xff0c;由于机械触点的弹性作用&#xff0c;一个按键开关在闭合时不会马上稳定地接通&#xff0c;在断开时也不会一下子断开。因而在闭…

数据结构: 线性表(顺序表实现)

文章目录 1. 线性表的定义2. 线性表的顺序表示:顺序表2.1 概念及结构2.2 接口实现2.2.1 顺序表初始化 (SeqListInit)2.2.2 顺序表尾插 (SeqListPushBack)2.2.3 顺序表打印 (SeqListPrint)2.2.6 顺序表销毁 (SeqListDestroy)2.2.5 顺序表尾删 (SeqListPopBack)2.2.6 顺序表头插 …

Windows 错误代码 0xc00000e的应对方案

Windows 上的错误代码 0xc00000e 是什么? 我们在日常的电脑使用过程中,错误代码 0xc00000e 是一个 BSOD 错误,有时会在启动过程中出现黑屏。错误代码通常伴随着一条消息,说明 – “您的电脑需要维修。 所需设备未连接或无法访问。” 该错误消息表明 Windows 操作系统在启动…

Matlab实现光伏仿真(附上30个完整仿真源码)

光伏发电电池模型是描述光伏电池在不同条件下产生电能的数学模型。该模型可以用于预测光伏电池的输出功率&#xff0c;并为优化光伏电池系统设计和控制提供基础。本文将介绍如何使用Matlab实现光伏发电电池模型。 文章目录 1、光伏发电电池模型2、使用Matlab实现光伏发电电池模…

安全学习DAY08_算法加密

算法加密 漏洞分析、漏洞勘测、漏洞探针、挖漏洞时要用到的技术知识 存储密码加密-应用对象传输加密编码-发送回显数据传输格式-统一格式代码特性混淆-开发语言 传输数据 – 加密型&编码型 安全测试时&#xff0c;通常会进行数据的修改增加提交测试 数据在传输的时候进行…

【Linux】关于Bad magic number in super-block 当尝试打开/dev/sda1 时找不到有效的文件系统超级块

每个区段与 superblock 的信息都可以使用 dumpe2fs 这个指令来查询的&#xff01; 不过可惜的是&#xff0c;我们的 CentOS 7 现在是以 xfs 为默认文件系统&#xff0c; 所以目前你的系统应该无法使用 dumpe2fs 去查询任何文件系统的。 因为目前两个版本系统的根目录使用的文…

IT职场笔记

MySQL笔记之一致性视图与MVCC实现 一致性读视图是InnoDB在实现MVCC用到的虚拟结构&#xff0c;用于读提交&#xff08;RC&#xff09;和可重复度&#xff08;RR&#xff09;隔离级别的实现。 一致性视图没有物理结构&#xff0c;主要是在事务执行期间用来定义该事物可以看到什…

gitlab迁移

下载地址 安装 https://packages.gitlab.com/gitlab/gitlab-ce/packages/el/7/gitlab-ce-12.6.2-ce.0.el7.x86_64.rpm/download.rpm yum -y install gitlab-ce-12.6.2-ce.0.el7.x86_64.rpm 官方文档 https://docs.gitlab.cn/jh/raketasks/backup_restore.html#restore-gitlab …

护网行动:ADSelfService Plus引领企业网络安全新纪元

随着信息技术的飞速发展&#xff0c;企业网络的重要性变得愈发显著。然而&#xff0c;随之而来的网络安全威胁也日益增多&#xff0c;网络黑客和恶意软件不断涌现&#xff0c;给企业的数据和机密信息带来巨大风险。在这个信息安全威胁层出不穷的时代&#xff0c;企业急需一款强…

在CSDN学Golang云原生(服务网格istio)

一&#xff0c;在Kubernetes上部署istio 在Kubernetes上部署istio&#xff0c;可以按照以下步骤进行&#xff1a; 安装Istio 使用以下命令从Istio官网下载最新版本的Istio&#xff1a; curl -L https://istio.io/downloadIstio | ISTIO_VERSION<VERSION> sh - 其中&…

【HDFS】LocatedBlocks、LocatedBlock、LocatedStripedBlock、ExtendedBlock类分析

本文主要介绍如下内容: 1、 介绍标题中类的功能及相关字段 2、 与字段初始化相关的一些细节 一、ExtendedBlock类 在Block Pools之间唯一标识一个块。 直白点就是一个Block再加一个块池id。 块池的概念是HDFS联邦集群之后产生的,因为一台DataNode的主机可以作为多个HDFS集群…

Ubuntu的安装与部分配置

该教程使用的虚拟机是virtuabox&#xff0c;镜像源的版本是ubuntu20.04.5桌面版 可通过下面的链接在Ubuntu官网下载&#xff1a;Alternative downloads | Ubuntu 也可直接通过下面的链接进入百度网盘下载【有Ubuntu20.04.5与hadoop3.3.2以及jdk1.8.0_162&#xff0c;该篇需要使…

npm install pnpm -g报错解决!

目录 报错信息&#xff1a;&#xff08;反正就是各种err&#xff09; 报错分析&#xff1a; 错误处理&#xff1a; 其它pnpm报错传送门&#xff1a; 报错信息&#xff1a;&#xff08;反正就是各种err&#xff09; npm ERR! code EPERM npm ERR! syscall mkdir npm ERR! pa…

idea 关于高亮显示与选中字符串相同的内容

dea 关于高亮显示与选中字符串相同的内容&#xff0c;本文作为个人备忘的同时也希望可以作为大家的参考。 依次修改File-settings-Editor-Color Scheme-General菜单下的Code-Identifier under caret和Identifier under caret(write)的Backgroud色值&#xff0c;可以参考下图。…

【案例】--GPT衍生应用案例

目录 一、前言二、GPT实现智能问答架构2.1、基本的GPT实现智能问答架构2.2、可应用的GPT实现智能问答架构1、语义转换2、相似度关键字矩阵3、ES中搜索相似度关键字矩阵三、后续一、前言 GPT,全称Generative Pre-trained Transformer ,中文名可译作生成式预训练Transformer。…

算法leetcode|64. 最小路径和(rust重拳出击)

文章目录 64. 最小路径和&#xff1a;样例 1&#xff1a;样例 2&#xff1a;提示&#xff1a; 分析&#xff1a;题解&#xff1a;rust&#xff1a;go&#xff1a;c&#xff1a;python&#xff1a;java&#xff1a; 64. 最小路径和&#xff1a; 给定一个包含非负整数的 m x n 网…

【linux--->传输层协议】

文章目录 [TOC](文章目录) 一、端口号1.端口号划分范围2.常用知名端口号 二、网络命令1.netstat 命令2.pidof 命令 三、UDP协议1.格式2.协议的分离和合并3.特点4.缓冲区 四、TCP协议1.格式2.4位的数据偏移3.确认应答机制4.序号与确认序号5.16位窗口6.标志位7.超时重传8.三次握手…

大脑睡眠是否因智力的不同而不同?

摘要 目的&#xff1a;比较不同智力水平儿童的睡眠脑电图。 方法&#xff1a;根据韦氏儿童智力量表(WISC)评分进行分组(17名智商正常[NIQ组]&#xff0c;24名高智商[HIQ组])。采用方差分析和线性回归模型(根据年龄和性别进行校正)比较组间频谱功率及其与WISC评分的关系。 结…